TransWikia.com

How to load a saved model in TensorFlow?

Data Science Asked by Remy on February 12, 2021

This is my code in Python:

from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf
from matplotlib import pyplot as plt
import numpy as np

I checked if the saved model is there using the following code:

tf.compat.v1.saved_model.contains_saved_model(
    '/Link_to_the_saved_model_directory/'
)

which returns True.
As far as I understood, I can use the following code to further make sure the model is saved correctly:

tf.saved_model.Asset(
    '/Link_to_the_saved_model_directory/'
)

which returns this:

<tensorflow.python.training.tracking.tracking.Asset at 0x2aad125e5710>

So everything looks fine. But when I use the following script to load the model, I get an error.

LoadedModel = tf.saved_model.load(
    export_dir='/Link_to_the_saved_model_directory/', tags=None
)

Error:

OSError: Cannot parse file b'/Link_to_the_saved_model_directory/saved_model.pbtxt': 1:1 : Message type "tensorflow.SavedModel" has no field named "node"..

The files in the /Link_to_the_saved_model_directory/ look like this:

['saved_model.pbtxt',
 'saved_model.ckpt-0.data-00000-of-00001',
 'saved_model.ckpt-0.meta',
 'checkpoint',
 'saved_model.ckpt-0.index']

Any suggestions on how to load the model so that I can reuse it for transfer learning would be greatly appreciated. It might also be that the model was partly written with a previous version of TensorFlow (e.g. TensorFlow 1.x) and that this error is due to a compatibility issue, but I have not found a solution for that yet.

Update: I tried the following method, but it does not work (tf was imported using the compatibility module: import tensorflow.compat.v1 as tf):

with tf.Session() as sess:
    saver = tf.train.import_meta_graph('/dir_to_the_model_files/saved_model.ckpt-0.meta')
    saver.restore(sess, "/dir_to_the_model_files/saved_model.ckpt-0")
    loaded = tf.saved_model.load(sess, tags=None, export_dir="/dir_to_the_model_files", import_scope=None)

It returns the following warnings and errors:

WARNING:tensorflow:The saved meta_graph is possibly from an older release:
'metric_variables' collection should be of type 'byte_list', but instead is of type 'node_list'.
INFO:tensorflow:Restoring parameters from /dir_to_the_model_files/saved_model.ckpt-0
<tensorflow.python.training.saver.Saver object at 0x2aaab4824a50>
WARNING:tensorflow:From <ipython-input-3-b8fd24f6b841>:9: load (from tensorflow.python.saved_model.loader_impl) is deprecated and will be removed in a future version.
Instructions for updating:
This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.loader.load or tf.compat.v1.saved_model.load. There will be a new function for importing SavedModels in Tensorflow 2.0.

OSError: Cannot parse file b'/dir_to_the_model_files/saved_model.pbtxt': 1:1 : Message type "tensorflow.SavedModel" has no field named "node"..

One Answer

The TensorFlow documentation for tf.saved_model.load might help:

SavedModels from tf.estimator.Estimator or 1.x SavedModel APIs have a flat graph instead of tf.function objects. These SavedModels will have functions corresponding to their signatures in the .signatures attribute, but also have a .prune method which allows you to extract functions for new subgraphs. This is equivalent to importing the SavedModel and naming feeds and fetches in a Session from TensorFlow 1.x.
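
Whether the quoted passage applies depends on what the directory actually contains. The error in the question says the SavedModel message has no field named node; node is a field of the GraphDef message, so saved_model.pbtxt may be a plain graph definition rather than a SavedModel, and the checkpoint, .index, .data-* and .meta files listed in the question look like a TF1 training checkpoint. A rough way to check, using only the standard library (the helper names first_field and describe_model_dir are made up for this sketch, and the heuristics are assumptions, not an official TensorFlow check):

```python
import os

# File-name endings that TF1 checkpoints typically use (assumed heuristic).
CHECKPOINT_SUFFIXES = (".index", ".meta")


def first_field(pbtxt_path):
    """Return the first field name in a text-format protobuf file, or None."""
    with open(pbtxt_path, "r", encoding="utf-8", errors="replace") as handle:
        for line in handle:
            stripped = line.strip()
            if stripped and not stripped.startswith("#"):
                # A field name ends at the first ':', '{' or space.
                for sep in (":", "{", " "):
                    stripped = stripped.split(sep, 1)[0]
                return stripped.strip()
    return None


def describe_model_dir(model_dir):
    """Collect hints about whether model_dir is a SavedModel or a checkpoint."""
    names = sorted(os.listdir(model_dir))
    report = {
        # A SavedModel normally keeps its weights under a variables/ subdirectory.
        "has_variables_dir": os.path.isdir(os.path.join(model_dir, "variables")),
        # Files that look like pieces of a TF1 training checkpoint.
        "checkpoint_files": [
            name for name in names
            if name == "checkpoint"
            or name.endswith(CHECKPOINT_SUFFIXES)
            or ".data-" in name
        ],
        "pbtxt_first_field": None,
    }
    pbtxt_path = os.path.join(model_dir, "saved_model.pbtxt")
    if os.path.isfile(pbtxt_path):
        report["pbtxt_first_field"] = first_field(pbtxt_path)
    return report
```

A first field of node together with no variables subdirectory would suggest a checkpoint plus a graph definition; a genuine text-format SavedModel would be expected to start with saved_model_schema_version or meta_graphs.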

You might have to use the deprecated v1 API call: https://www.tensorflow.org/api_docs/python/tf/compat/v1/saved_model/load
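
If the directory holds a TF1 checkpoint (the .meta, .index and .data-* files in the question) rather than a SavedModel, the SavedModel loaders may keep failing on saved_model.pbtxt regardless of API version. Judging from the log in the question, the import_meta_graph and restore steps ran and only the later saved_model.load call raised, so one option is to stop after the restore and work with the restored graph directly. A minimal sketch under those assumptions (restore_v1_checkpoint and checkpoint_prefix are hypothetical helper names; TensorFlow is imported inside the function so the helpers can be imported without it):

```python
def checkpoint_prefix(meta_path):
    """Strip the '.meta' suffix to get the prefix that Saver.restore expects."""
    suffix = ".meta"
    if not meta_path.endswith(suffix):
        raise ValueError("expected a path ending in '.meta': %r" % (meta_path,))
    return meta_path[: -len(suffix)]


def restore_v1_checkpoint(meta_path):
    """Rebuild the graph from a .meta file and restore its variables.

    Returns (session, graph). The caller is responsible for closing the session.
    """
    # Lazy import: requires TensorFlow with the compat.v1 module available.
    import tensorflow.compat.v1 as tf

    graph = tf.Graph()
    with graph.as_default():
        # Recreate the graph structure stored alongside the checkpoint.
        saver = tf.train.import_meta_graph(meta_path)
        sess = tf.Session(graph=graph)
        # Load the variable values; no saved_model.load call is made here.
        saver.restore(sess, checkpoint_prefix(meta_path))
    return sess, graph
```

Calling restore_v1_checkpoint with the .meta path from the question returns a session and a graph; graph.get_operations() then lists the operation names from which to pick tensors to feed, fetch, or build on for transfer learning. The session should be closed once it is no longer needed.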

Answered by Brian Spiering on February 12, 2021
