I am trying to simply save and restore a graph, but the simplest example does not work as expected (this is with version 0.9.0 or 0.10.0 on 64-bit Linux without CUDA, using Python 2.7 or 3.5.2).
First I save the graph like this:
import tensorflow as tf
v1 = tf.placeholder('float32')
v2 = tf.placeholder('float32')
v3 = tf.mul(v1,v2)
c1 = tf.constant(22.0)
v4 = tf.add(v3,c1)
sess = tf.Session()
result = sess.run(v4,feed_dict={v1:12.0, v2:3.3})
g1 = tf.train.export_meta_graph("file")
## alternately I also tried:
## g1 = tf.train.export_meta_graph("file",collection_list=["v4"])
This creates a file "file" that is non-empty and also sets g1 to something that looks like a proper graph definition.
Then I try to restore this graph:
import tensorflow as tf
g=tf.train.import_meta_graph("file")
This works without an error, but does not return anything at all.
Can anyone provide the necessary code to simply just save the graph for "v4" and completely restore it so that running this in a new session will produce the same result?
Answers
To save and restore a TensorFlow graph along with its variables, you use the export_meta_graph and import_meta_graph functions together with the tf.train.Saver class: export_meta_graph saves the graph structure to a meta graph file, while tf.train.Saver saves the variable values to a checkpoint file. Your graph contains only placeholders and a constant, so there are no variables to checkpoint and the meta graph alone is enough; the Saver only becomes necessary once the graph has variables.
Here's how you can achieve this:
Save the Graph
First, define your graph, run a session to compute a result, and export the graph structure:
import tensorflow as tf
# Define the graph; naming the tensors makes them easy to look up after restoring
v1 = tf.placeholder('float32', name='v1')
v2 = tf.placeholder('float32', name='v2')
v3 = tf.multiply(v1, v2, name='v3')   # tf.mul in TF 0.x
c1 = tf.constant(22.0, name='c1')
v4 = tf.add(v3, c1, name='v4')
# Create a session and run it
sess = tf.Session()
result = sess.run(v4, feed_dict={v1: 12.0, v2: 3.3})
print("Result before saving: ", result)
# Save the graph structure to a meta graph file.
# Note: tf.train.Saver() would raise "ValueError: No variables to save" here,
# because this graph contains only placeholders and a constant. If the graph
# had variables, you would also create a Saver and call saver.save() to write
# their values to a checkpoint file.
tf.train.export_meta_graph('./model.meta')
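For comparison, here is a minimal sketch of the save step for a graph that does contain a variable (the variable name my_var, the tensor name out, and the checkpoint path are only illustrative). Note that saver.save also writes the meta graph to './model.ckpt.meta' by default (write_meta_graph=True), so a separate export_meta_graph call is usually unnecessary in that case:
import tensorflow as tf
v1 = tf.placeholder('float32', name='v1')
w = tf.Variable(2.0, name='my_var')                  # illustrative variable to checkpoint
out = tf.add(tf.multiply(v1, w), 22.0, name='out')
sess = tf.Session()
sess.run(tf.global_variables_initializer())          # variables must be initialized before saving
saver = tf.train.Saver()                             # now there is a variable, so this no longer fails
saver.save(sess, './model.ckpt')                     # writes the checkpoint and ./model.ckpt.meta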
Restore the Graph
To restore the graph, import the meta graph file; if variable values had been saved to a checkpoint, you would also restore them with the Saver returned by tf.train.import_meta_graph:
import tensorflow as tf
# Import the graph structure from the meta graph file. import_meta_graph returns
# a Saver built from the saver information stored in the meta graph, or None if
# there is none, as is the case here where no Saver was used when saving.
saver = tf.train.import_meta_graph('./model.meta')
new_sess = tf.Session()
if saver is not None:
    saver.restore(new_sess, './model.ckpt')  # only needed when a checkpoint was saved
# Get the restored operations and placeholders by name
graph = tf.get_default_graph()
v1 = graph.get_tensor_by_name('v1:0')
v2 = graph.get_tensor_by_name('v2:0')
v4 = graph.get_tensor_by_name('v4:0')
# Run the session with the restored graph
result = new_sess.run(v4, feed_dict={v1: 12.0, v2: 3.3})
print("Result after restoring: ", result)
Explanation:
- Saving:
  - Define the graph and run a session to compute a result.
  - Use tf.train.export_meta_graph to save the graph structure to a meta graph file (model.meta).
  - If the graph contains variables, also use tf.train.Saver to save their values to a checkpoint file (model.ckpt); the graph in the question has only placeholders and a constant, so no checkpoint is needed.
- Restoring:
  - Use tf.train.import_meta_graph to import the graph structure from the meta graph file. It returns a Saver built from the saver information stored in the meta graph, or None if there is none, which is why the call in the question appeared to return nothing.
  - If a checkpoint was saved, use saver.restore to restore the variable values.
  - Access the restored operations and placeholders with graph.get_tensor_by_name.
This approach saves the graph structure (and, when present, the variable values) and restores them so that running the restored graph in a new session produces the same result.
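Since the question also experimented with collection_list, here is an alternative sketch that avoids hard-coding tensor names: register the tensors of interest in collections before exporting, and read them back from the collections after importing (the collection keys 'inputs' and 'outputs' are arbitrary):
import tensorflow as tf
# Saving side: put the tensors you will need later into collections,
# which are serialized into the meta graph file
v1 = tf.placeholder('float32', name='v1')
v2 = tf.placeholder('float32', name='v2')
v4 = tf.add(tf.multiply(v1, v2), 22.0, name='v4')
tf.add_to_collection('inputs', v1)
tf.add_to_collection('inputs', v2)
tf.add_to_collection('outputs', v4)
tf.train.export_meta_graph('./model_with_collections.meta')
# Restoring side: start from an empty graph, as a fresh process would
tf.reset_default_graph()
tf.train.import_meta_graph('./model_with_collections.meta')
v1, v2 = tf.get_collection('inputs')
v4 = tf.get_collection('outputs')[0]
with tf.Session() as sess:
    print(sess.run(v4, feed_dict={v1: 12.0, v2: 3.3}))    # 12.0 * 3.3 + 22.0 = 61.6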