From 7749c628adad8fcf029537b0dab76cfddb3515b1 Mon Sep 17 00:00:00 2001 From: Hao Date: Mon, 18 Jun 2018 13:58:10 +1200 Subject: [PATCH] add a tool to extract weights from tf models --- extract_tf_weights.py | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) create mode 100644 extract_tf_weights.py diff --git a/extract_tf_weights.py b/extract_tf_weights.py new file mode 100644 index 00000000..a694b0fd --- /dev/null +++ b/extract_tf_weights.py @@ -0,0 +1,31 @@ +import tensorflow as tf +from tensorflow.python.platform import gfile +from tensorflow.python.framework import tensor_util +import sys +import pickle + + +def read_weights(frozen_model): + weights = {} + with tf.Session() as sess: + with gfile.FastGFile(frozen_model, 'rb') as f: + graph_def = tf.GraphDef() + graph_def.ParseFromString(f.read()) + tf.import_graph_def(graph_def) + for n in graph_def.node: + if n.op == 'Const': + weights[n.name] = tensor_util.MakeNdarray(n.attr['value'].tensor) + print("Name:", n.name, "Shape:", weights[n.name].shape) + return weights + + +if len(sys.argv) < 3: + print("Usage: python extract_tf_weights.py ") + +frozen_model = sys.argv[1] +weights_file = sys.argv[2] + +weights = read_weights(frozen_model) +with open(weights_file, "wb") as f: + pickle.dump(weights, f) + print(f"Saved weights to {weights_file}.") \ No newline at end of file