TensorFlow — Saving and Restoring a Frozen Graph (.pb)

son John
1 min read · Apr 16, 2020

Save:

import tensorflow as tf
from tensorflow.python.framework import graph_util
# Build a tiny graph computing v1/add = x * b + zz, freeze it (turn the
# variable `b` into a constant holding its current value), and serialize the
# resulting GraphDef to ./filename.pb.  Uses the TensorFlow 1.x graph/session
# API.
with tf.Session(graph=tf.Graph()) as sess:
    with tf.variable_scope('v1'):
        x = tf.placeholder(tf.int32, name='x')
        zz = tf.placeholder(tf.int32, name='zz')
        b = tf.Variable(2, name='b')
        xb = tf.multiply(x, b)
        result = tf.add(xb, zz, name='add')
    print(result.name)  # the tensor name to look up after restoring: 'v1/add:0'

    # Variables must hold a value before they can be folded into constants.
    sess.run(tf.global_variables_initializer())

    # Replace every variable reachable from the listed output nodes with a
    # Const node carrying its current value.  Nodes not needed to compute
    # these outputs are pruned from the GraphDef.
    constant_graph = graph_util.convert_variables_to_constants(
        sess,
        sess.graph.as_graph_def(),
        ['v1/zz', 'v1/x', 'v1/b', 'v1/add'])

    # GFile replaces the deprecated FastGFile and works with local paths as
    # well as TensorFlow-supported remote filesystems.
    with tf.gfile.GFile('./filename.pb', mode='wb') as f:
        f.write(constant_graph.SerializeToString())

Restore:

import tensorflow as tf
from tensorflow.python.framework import graph_util
# Load the frozen GraphDef written by the save script and evaluate v1/add.
# The frozen graph contains no variables, so no initializer needs to run
# (running one before import was a no-op on an empty graph).
with tf.Session(graph=tf.Graph()) as sess:
    graph_def = tf.GraphDef()
    with tf.gfile.GFile('./filename.pb', 'rb') as f:
        graph_def.ParseFromString(f.read())

    # name="" keeps the original node names (no "import/" prefix), so tensors
    # can be fetched as 'v1/...:0'.
    tf.import_graph_def(graph_def, name="")

    x = sess.graph.get_tensor_by_name('v1/x:0')
    zz = sess.graph.get_tensor_by_name('v1/zz:0')
    b = sess.graph.get_tensor_by_name('v1/b:0')
    add = sess.graph.get_tensor_by_name('v1/add:0')

    # Feeding `b` overrides the value that was frozen into the graph (2 at
    # save time), so this prints 2 * 3 + 5 = 11 rather than 2 * 2 + 5 = 9.
    print(sess.run(add, feed_dict={x: 2, zz: 5, b: 3}))
    # x * b + zz = 11

--

--