tensorflow.train.import_meta_graph does not work?

ghz ⋅ 7 months ago

I am trying to simply save and restore a graph, but even the simplest example does not work as expected (tested with TensorFlow 0.9.0 or 0.10.0 on 64-bit Linux without CUDA, using Python 2.7 or 3.5.2).

First I save the graph like this:

import tensorflow as tf
v1 = tf.placeholder('float32') 
v2 = tf.placeholder('float32')
v3 = tf.mul(v1,v2)
c1 = tf.constant(22.0)
v4 = tf.add(v3,c1)
sess = tf.Session()
result = sess.run(v4,feed_dict={v1:12.0, v2:3.3})
g1 = tf.train.export_meta_graph("file")
## alternately I also tried:
## g1 = tf.train.export_meta_graph("file",collection_list=["v4"])

This creates a file "file" that is non-empty and also sets g1 to something that looks like a proper graph definition.

Then I try to restore this graph:

import tensorflow as tf
g=tf.train.import_meta_graph("file")

This works without an error, but does not return anything at all.

Can anyone provide the code needed to save the graph for "v4" and restore it completely, so that running it in a new session produces the same result?

Answers

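tf.train.import_meta_graph does not return the graph. It adds the saved nodes to the current default graph and returns a tf.train.Saver for the imported variables, or None when the graph contains no variables, which is why the call in the question appears to return nothing. To save and restore a graph together with its variable values, combine the tf.train.Saver class with the export_meta_graph and import_meta_graph functions: export_meta_graph saves the graph structure to a meta graph file, while tf.train.Saver saves the variable values to a checkpoint file.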

Here's how you can achieve this:

Save the Graph and Variables

First, define your graph, run a session to compute some values, and save both the graph structure and variable values:

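import tensorflow as tf

# Define the graph. Explicit names make the tensors easy to look up after restoring.
# (tf.multiply and tf.global_variables_initializer need TensorFlow 1.0 or later.)
v1 = tf.placeholder('float32', name='v1')
v2 = tf.placeholder('float32', name='v2')
v3 = tf.multiply(v1, v2, name='v3')
# c1 is a Variable rather than a constant so that the Saver has something to save:
# tf.train.Saver() raises "ValueError: No variables to save" on a variable-free graph.
c1 = tf.Variable(22.0, name='c1')
v4 = tf.add(v3, c1, name='v4')

# Create a session, initialize the variable and run the graph
sess = tf.Session()
sess.run(tf.global_variables_initializer())
result = sess.run(v4, feed_dict={v1: 12.0, v2: 3.3})
print("Result before saving: ", result)

# Save the variable values (checkpoint) and the graph structure (meta graph)
saver = tf.train.Saver()
saver.save(sess, './model.ckpt')
tf.train.export_meta_graph('./model.meta')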
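
By default, Saver.save also writes the graph structure next to the checkpoint, as `<checkpoint prefix>.meta` (its `write_meta_graph` argument defaults to True), so the separate export_meta_graph call is optional. The minimal sketch below checks this. It assumes TensorFlow 2.x with the `tf.compat.v1` API (on TF 1.x, `import tensorflow as tf` exposes the same calls), and the temporary directory is only an illustrative location.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf  # assumption: TF 2.x; on TF 1.x use `import tensorflow as tf`

tf.disable_eager_execution()  # build graphs and run Sessions as in TF 1.x

ckpt_dir = tempfile.mkdtemp()  # illustrative output location

with tf.Graph().as_default():
    v1 = tf.placeholder(tf.float32, name='v1')
    v2 = tf.placeholder(tf.float32, name='v2')
    c1 = tf.Variable(22.0, name='c1')
    v4 = tf.add(tf.multiply(v1, v2), c1, name='v4')
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # save() returns the checkpoint prefix; with the default
        # write_meta_graph=True it also writes '<prefix>.meta'
        prefix = saver.save(sess, os.path.join(ckpt_dir, 'model.ckpt'))

meta_written = os.path.exists(prefix + '.meta')
print(prefix + '.meta', meta_written)
```

With this, the restore step can point import_meta_graph at `model.ckpt.meta` instead of a separately exported file.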

Restore the Graph and Variables

To restore the graph and its variables, import the meta graph and then restore the variables with the tf.train.Saver that import_meta_graph returns:

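import tensorflow as tf

# Run this in a fresh Python process so the imported node names do not
# collide with nodes already in the default graph.
new_sess = tf.Session()
# import_meta_graph adds the saved nodes to the default graph and returns a
# Saver for its variables (it returns None if the graph has no variables)
saver = tf.train.import_meta_graph('./model.meta')
saver.restore(new_sess, './model.ckpt')

# Look up the restored placeholders and output tensor by name
# ('v4:0' is output 0 of the operation named 'v4')
graph = tf.get_default_graph()
v1 = graph.get_tensor_by_name('v1:0')
v2 = graph.get_tensor_by_name('v2:0')
v4 = graph.get_tensor_by_name('v4:0')

# Run the restored graph with the same inputs
result = new_sess.run(v4, feed_dict={v1: 12.0, v2: 3.3})
print("Result after restoring: ", result)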
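
To check that the restored graph really produces the same result, the save and restore steps can also be run in one script by giving each its own tf.Graph, so the imported names cannot collide with the nodes that were just built. This is a minimal sketch assuming TensorFlow 2.x with the `tf.compat.v1` API; the temporary paths are illustrative, and feeding and fetching by tensor name (`'v1:0'`, `'v4:0'`) is used so the same dictionary works in both graphs.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf  # assumption: TF 2.x; on TF 1.x use `import tensorflow as tf`

tf.disable_eager_execution()

save_dir = tempfile.mkdtemp()  # illustrative location for the files
ckpt_path = os.path.join(save_dir, 'model.ckpt')
meta_path = os.path.join(save_dir, 'model.meta')
feed = {'v1:0': 12.0, 'v2:0': 3.3}  # keyed by tensor name, valid in either graph

# Graph A: build, run and save
with tf.Graph().as_default():
    v1 = tf.placeholder(tf.float32, name='v1')
    v2 = tf.placeholder(tf.float32, name='v2')
    c1 = tf.Variable(22.0, name='c1')
    v4 = tf.add(tf.multiply(v1, v2), c1, name='v4')
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        before = sess.run('v4:0', feed_dict=feed)
        saver.save(sess, ckpt_path)
        tf.train.export_meta_graph(meta_path)

# Graph B: a fresh, empty graph that receives the imported nodes
with tf.Graph().as_default() as graph:
    restorer = tf.train.import_meta_graph(meta_path)  # a Saver, since c1 is a variable
    with tf.Session() as sess:
        restorer.restore(sess, ckpt_path)
        after = sess.run(graph.get_tensor_by_name('v4:0'), feed_dict=feed)

print(before, after)
```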

Explanation:

  1. Saving:

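    • Define the graph and run a session to compute some values. The graph must contain at least one tf.Variable, otherwise tf.train.Saver() raises ValueError: No variables to save.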
    • Use tf.train.Saver to save the variable values to a checkpoint file (model.ckpt).
    • Use tf.train.export_meta_graph to save the graph structure to a meta graph file (model.meta).
  2. Restoring:

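    • Use tf.train.import_meta_graph to import the graph structure from the meta graph file into the default graph. It returns a tf.train.Saver (or None if there are no variables), not the graph itself.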
    • Use saver.restore to restore the variable values from the checkpoint file.
    • Access the restored operations and placeholders using graph.get_tensor_by_name.

This approach ensures that you save both the graph structure and the variable values, and restore them correctly to get the same result.
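
The graph in the question has no variables, so no checkpoint is needed at all: export_meta_graph already stores the placeholders and the constant, and import_meta_graph returning None is expected. The nodes are still added to the default graph and can be fetched by name. A minimal sketch of that case, again assuming TensorFlow 2.x with the `tf.compat.v1` API; the file name `graph.meta` is illustrative.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf  # assumption: TF 2.x; on TF 1.x use `import tensorflow as tf`

tf.disable_eager_execution()

meta_path = os.path.join(tempfile.mkdtemp(), 'graph.meta')  # illustrative file name

# The graph from the question: placeholders and a constant, no variables
with tf.Graph().as_default():
    v1 = tf.placeholder(tf.float32, name='v1')
    v2 = tf.placeholder(tf.float32, name='v2')
    v3 = tf.multiply(v1, v2, name='v3')
    c1 = tf.constant(22.0, name='c1')
    v4 = tf.add(v3, c1, name='v4')
    tf.train.export_meta_graph(meta_path)  # no session or checkpoint needed

with tf.Graph().as_default() as graph:
    # Returns None (nothing to restore), but the nodes are added to `graph`
    returned = tf.train.import_meta_graph(meta_path)
    with tf.Session() as sess:
        result = sess.run(graph.get_tensor_by_name('v4:0'),
                          feed_dict={graph.get_tensor_by_name('v1:0'): 12.0,
                                     graph.get_tensor_by_name('v2:0'): 3.3})

print(returned, result)
```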