Technology-CheckPoint 转 Pb 格式

Tensorfow 的 CheckPoint 格式转 Pb 格式。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
def export_model(input_checkpoint, output_graph):
#这个可以加载saver的模型
saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
graph = tf.get_default_graph() # 获得默认的图
# print(graph.get_operations())
# print(graph.get_operations()[-1].name)
input_graph_def = graph.as_graph_def() # 返回一个序列化的图代表当前的图

with tf.Session() as sess:
sess.run(tf.global_variables_initializer())

saver.restore(sess, input_checkpoint)
output_graph_def = graph_util.convert_variables_to_constants( # 模型持久化,将变量值固定
sess=sess,
input_graph_def=input_graph_def,# 等于:sess.graph_def
output_node_names=[graph.get_operations()[-1].name])# 如果有多个输出节点,以逗号隔开这个是重点,输入和输出的参数都需要在这里记录

with tf.gfile.GFile(output_graph, "wb") as f: #保存模型
f.write(output_graph_def.SerializeToString()) #序列化输出


export_model('../checkpoint-7000', "../ge.pb")