Data Science Asked by JacoSolari on April 23, 2021
I recently switched from PyTorch to TF (1 and 2) and I am trying to establish a good workflow with it.
The simple things I want to do are the following:
model.summary() to inspect the network architecture of the loaded model.
I know that TF has the concept of graph and weights, as opposed to PyTorch, which just has models encompassing everything.
Nevertheless, I could not find a simple, recommended way to load a pretrained model, and the internet is full of different answers for different TF versions.
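For context, the two loading paths in TF 2.x behave quite differently, and a minimal sketch helps show why summary() is sometimes unavailable. `Doubler` is a hypothetical stand-in for a downloaded model, and the temporary export folder replaces a real zoo `saved_model` directory; this is an illustration under those assumptions, not the zoo's own export code.

```python
import tempfile

import tensorflow as tf


class Doubler(tf.Module):
    """Hypothetical stand-in for a downloaded model (not a zoo model)."""

    def __init__(self):
        super().__init__()
        self.scale = tf.Variable(2.0)

    @tf.function(input_signature=[tf.TensorSpec([None], tf.float32)])
    def __call__(self, x):
        return {"y": x * self.scale}


# Path A: a Keras model keeps its layer structure, so summary() is available.
keras_model = tf.keras.Sequential(
    [tf.keras.Input(shape=(4,)), tf.keras.layers.Dense(3)]
)
keras_model.summary()

# Path B: tf.saved_model.load() restores a generic trackable object that
# carries weights and serving signatures, but no Keras summary().
module = Doubler()
export_dir = tempfile.mkdtemp()
tf.saved_model.save(module, export_dir,
                    signatures=module.__call__.get_concrete_function())
restored = tf.saved_model.load(export_dir)

print(hasattr(restored, "summary"))
print(list(restored.signatures.keys()))
serve = restored.signatures["serving_default"]
print(serve(x=tf.constant([1.0, 3.0]))["y"].numpy())
```

In short, summary() is a Keras feature; a SavedModel exported from non-Keras code (as the detection zoo models appear to be) restores as a plain object that you inspect through its signatures instead.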
I am really confused because, to achieve the points above, I have so many different files available when I download a pretrained model from the TF1 zoo (or the TF2 zoo).
Take this one, for instance: the first in the list of the TF1 zoo.
I have the saved_model folder with the saved_model.pb and the (empty) variables folder, the frozen_inference_graph.pb, the model.ckpt files, the pipeline.config and, in some cases, an event file.
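The model.ckpt files can be inspected on their own, which shows what they do and do not contain. The sketch below writes a tiny TF2-style checkpoint as a stand-in for the downloaded one (the variable names `kernel` and `bias` are hypothetical) and lists it with `tf.train.list_variables`. A checkpoint stores variable names, shapes and values but no graph; in the TF1 zoo, the graph presumably lives in the companion model.ckpt.meta file.

```python
import os
import tempfile

import tensorflow as tf

# Write a tiny checkpoint as a stand-in for the downloaded model.ckpt files.
ckpt_dir = tempfile.mkdtemp()
ckpt = tf.train.Checkpoint(
    kernel=tf.Variable(tf.zeros([3, 2])),
    bias=tf.Variable(tf.zeros([2])),
)
ckpt_prefix = ckpt.save(os.path.join(ckpt_dir, "model.ckpt"))

# A checkpoint holds variable names, shapes and values, but no graph.
variables = dict(tf.train.list_variables(ckpt_prefix))
for name, shape in sorted(variables.items()):
    print(name, shape)

# Individual values can be read back as NumPy arrays.
kernel = tf.train.load_variable(ckpt_prefix, "kernel/.ATTRIBUTES/VARIABLE_VALUE")
print(kernel.shape)
```

With a real download you would pass the checkpoint prefix (for example the path ending in model.ckpt, without the .index or .data suffix) instead of `ckpt_prefix`.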
Are all these different files really necessary to encode graph structure and weights? Am I missing something, or is this just more complicated than necessary?
In addition, the file/folder structure is different if you download a model from the TF2 zoo (see image below).
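As a rough orientation, the files in both zoo layouts can be grouped by what they encode. The helper below is a stdlib-only sketch; the role descriptions are an informal summary rather than an official specification, and the demo folder only mimics a TF2-zoo layout with empty placeholder files. The empty variables folder in the TF1 export is presumably because the weights were frozen into saved_model.pb as constants.

```python
import tempfile
from pathlib import Path


def file_role(path: Path) -> str:
    """Informal role of a file found in a TF1/TF2 model-zoo download."""
    name = path.name
    if name == "saved_model.pb":
        return "SavedModel graph + signatures"
    if path.parent.name == "variables":
        return "SavedModel variable values"
    if name == "frozen_inference_graph.pb":
        return "frozen GraphDef (weights stored as constants)"
    if name.endswith(".meta"):
        return "TF1 checkpoint graph (MetaGraphDef)"
    if ".data-" in name or name.endswith(".index"):
        return "checkpoint variable values"
    if name == "checkpoint":
        return "checkpoint bookkeeping (latest checkpoint name)"
    if name == "pipeline.config":
        return "Object Detection API pipeline config"
    if name.startswith("events.out.tfevents"):
        return "TensorBoard event log"
    return "other"


def describe_model_files(root):
    """Map every file under root (as a relative POSIX path) to its role."""
    root = Path(root)
    return {
        p.relative_to(root).as_posix(): file_role(p)
        for p in sorted(root.rglob("*"))
        if p.is_file()
    }


# Fake TF2-zoo-style layout: file names are illustrative, contents empty.
demo_root = Path(tempfile.mkdtemp())
for rel in [
    "checkpoint/checkpoint",
    "checkpoint/ckpt-0.index",
    "checkpoint/ckpt-0.data-00000-of-00001",
    "saved_model/saved_model.pb",
    "saved_model/variables/variables.index",
    "saved_model/variables/variables.data-00000-of-00001",
    "pipeline.config",
]:
    (demo_root / rel).parent.mkdir(parents=True, exist_ok=True)
    (demo_root / rel).touch()

for rel_path, role in describe_model_files(demo_root).items():
    print(f"{rel_path:55} {role}")
```

Read this way, SavedModel, frozen graph and checkpoint (+ .meta) are three alternative packagings of the same network, while pipeline.config and the event file are training metadata rather than part of the model itself.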
What I tried
import pathlib

import tensorflow as tf  # (v2.4)

# Root folder where the zoo downloads were extracted (placeholder path)
TRAINED_MODEL_DIR = pathlib.Path("pretrained_models")


def load_pretrained_model(saved_model_sub_folder, mode):
    model_dir = TRAINED_MODEL_DIR / saved_model_sub_folder
    # 1. This only loads an AutoTrackable object that can be used for
    #    inference, but it is not a Keras model (no model.summary())
    if mode == '.pb':
        # load(export_dir, tags=None, options=None): no extra arguments needed
        model = tf.saved_model.load(str(model_dir / "saved_model"))
        detection_model = model.signatures['serving_default']
    # 2. import_graph_def() on its own returns None: it adds the nodes to the
    #    current default graph, so import into an explicit tf.Graph instead
    elif mode == '.graph':
        frozen_graph_filename = model_dir / "frozen_inference_graph.pb"
        with tf.io.gfile.GFile(str(frozen_graph_filename), "rb") as f:
            graph_def = tf.compat.v1.GraphDef()
            graph_def.ParseFromString(f.read())
        detection_model = tf.Graph()
        with detection_model.as_default():
            tf.compat.v1.import_graph_def(graph_def, name="")
    else:
        detection_model = None
    return detection_model
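For the '.pb' branch, the loaded signature is a ConcreteFunction, so it can at least be inspected even without summary(). The sketch below assumes a tiny hypothetical `TinyModel` in place of the zoo model (with a real detection model, the input and output names would differ); it prints the expected inputs, the outputs, and the op types in the signature's graph. The model body may appear there only as a call op into the function library, so this is a shallow view of the graph.

```python
import tempfile

import tensorflow as tf


class TinyModel(tf.Module):
    """Hypothetical stand-in for a zoo detection model."""

    def __init__(self):
        super().__init__()
        self.w = tf.Variable([[1.0], [2.0]])

    @tf.function(input_signature=[tf.TensorSpec([None, 2], tf.float32, name="images")])
    def __call__(self, images):
        return {"scores": tf.matmul(images, self.w)}


module = TinyModel()
export_dir = tempfile.mkdtemp()
tf.saved_model.save(module, export_dir,
                    signatures=module.__call__.get_concrete_function())

loaded = tf.saved_model.load(export_dir)
infer = loaded.signatures["serving_default"]

# What the signature expects and returns (names, dtypes, shapes).
print(infer.structured_input_signature)
print(infer.structured_outputs)

# The signature is backed by a graph whose operations can be listed.
op_types = {op.type for op in infer.graph.get_operations()}
print(sorted(op_types))

# Signatures are called with keyword arguments and return a dict of tensors.
result = infer(images=tf.constant([[1.0, 1.0]]))["scores"]
print(result.numpy())
```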
Can someone answer some of the points (1 to 5) above regarding how to load a complete (graph, weights, everything...) customizable TensorFlow 1 or TensorFlow 2 model in Python 3?
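For the TF1 frozen_inference_graph.pb specifically, one approach described in TensorFlow's TF1-to-TF2 migration material is to import the GraphDef inside `tf.compat.v1.wrap_function` and prune it into a callable. The sketch below builds a tiny in-memory graph as a stand-in for the real file; the tensor names "x:0" and "y:0" belong to that toy graph. For the TF1 detection models the names are typically "image_tensor:0", "detection_boxes:0" and so on, but that is an assumption worth checking against the node names in your own GraphDef.

```python
import tensorflow as tf


def wrap_frozen_graph(graph_def, inputs, outputs):
    """Wrap a frozen TF1 GraphDef as a callable TF2 function."""

    def _imports_graph_def():
        tf.compat.v1.import_graph_def(graph_def, name="")

    wrapped_import = tf.compat.v1.wrap_function(_imports_graph_def, [])
    import_graph = wrapped_import.graph
    # prune() keeps only the subgraph between the given input and output tensors.
    return wrapped_import.prune(
        tf.nest.map_structure(import_graph.as_graph_element, inputs),
        tf.nest.map_structure(import_graph.as_graph_element, outputs),
    )


# Stand-in for frozen_inference_graph.pb: a tiny graph with no variables.
# With the real file you would instead do:
#   graph_def = tf.compat.v1.GraphDef()
#   with tf.io.gfile.GFile("frozen_inference_graph.pb", "rb") as f:
#       graph_def.ParseFromString(f.read())
toy_graph = tf.Graph()
with toy_graph.as_default():
    x = tf.compat.v1.placeholder(tf.float32, shape=[None], name="x")
    tf.multiply(x, 2.0, name="y")
graph_def = toy_graph.as_graph_def()

# The node names in the GraphDef tell you which tensors to feed and fetch.
node_names = [node.name for node in graph_def.node]
print(node_names)

frozen_fn = wrap_frozen_graph(graph_def, inputs="x:0", outputs="y:0")
print(frozen_fn(tf.constant([1.0, 2.0])).numpy())
```

Because the weights in a frozen graph are constants, this gives inference only; fine-tuning would need the checkpoint or SavedModel instead.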