tensorflow加载多个计算图的冲突解决

2021/4/23 18:58:18

本文主要是介绍tensorflow加载多个计算图的冲突解决,对大家解决编程问题具有一定的参考价值,需要的程序猿们随着小编来一起学习吧!

需求:顺序加载多个计算图时,会导致第二个计算图后变量  不可用,在程序初始化中解决该问题(一下代码没有做优化,请读者自行修正)

class BertEncoder(object):
    """ model
    """
    def __init__(self, OUTPUT_GRAPH, OUT_TENSOR):
        self.max_length = 30
        self.tokenizer = TOKENIZER
        self.out_graph = os.path.join(CURRENT_DIR, "pb_model", OUTPUT_GRAPH)
        self.out_tensor = OUT_TENSOR
        self.model_graph = {}
        graph = tf.Graph()
        with graph.as_default():
            self.model_graph['output_graph_def'] = tf.compat.v1.GraphDef()
            with open(self.out_graph, "rb") as f:
                self.model_graph['output_graph_def'].ParseFromString(f.read())
            self.model_graph['sess'] = tf.Session(graph=graph)
        with self.model_graph['sess'].as_default():
            with graph.as_default():
                self.model_graph['sess'].run(tf.compat.v1.global_variables_initializer())
                tf.import_graph_def(self.model_graph['output_graph_def'], name="")
                self.input_ids_p = self.model_graph['sess'].graph.get_tensor_by_name("input_ids:0")
                self.input_mask_p = self.model_graph['sess'].graph.get_tensor_by_name("input_mask:0")
                self.output_tensor = self.model_graph['sess'].graph.get_tensor_by_name(self.out_tensor)


    def predict(self, to_predict):
        """pb predict
        """
        sentence = [each.lower() for each in to_predict]
        input_ids, input_mask, = self.convert(sentence)
        feed_dict = {self.input_ids_p: input_ids,
                     self.input_mask_p: input_mask}
        sess = self.model_graph['sess']
        output_emb = sess.run(self.output_tensor, feed_dict)
        return output_emb

 



这篇关于tensorflow加载多个计算图的冲突解决的文章就介绍到这儿,希望我们推荐的文章对大家有所帮助,也希望大家多多支持为之网!


扫一扫关注最新编程教程