Skip to content

Commit

Permalink
clean the unused control flow input nodes from the graph (tensorflow#…
Browse files Browse the repository at this point in the history
…2287)

BUG
* clean the unused control flow input nodes from the graph

* increase tfjs-converter cloudbuild disk size

* use high cpu cloud build type
  • Loading branch information
pyu10055 authored Oct 30, 2019
1 parent 9a1e1b0 commit 9e0bdb5
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 0 deletions.
1 change: 1 addition & 0 deletions tfjs-converter/cloudbuild.yml
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,4 @@ substitutions:
options:
logStreamingOption: 'STREAM_ON'
substitution_option: 'ALLOW_LOOSE'
machineType: 'N1_HIGHCPU_8'
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@

import numpy as np
import tensorflow as tf
from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import node_def_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import device_properties_pb2
from tensorflow.core.protobuf import meta_graph_pb2
Expand Down Expand Up @@ -166,6 +168,7 @@ def optimize_graph(graph, signature_def, output_graph,
]

optimized_graph = _run_grappler(config, optimized_graph, graph, signature_def)
optimized_graph = _remove_unused_control_flow_inputs(optimized_graph)

# Because TF break the Prelu op into 6 ops, for performance we are
# fusing those ops into a single prelu
Expand Down Expand Up @@ -268,6 +271,17 @@ def write_artifacts(topology,
with open(output_graph, 'wt') as f:
json.dump(model_json, f)

def _remove_unused_control_flow_inputs(input_graph_def):
result_graph_def = graph_pb2.GraphDef()
for node in input_graph_def.node:
if (node.op == 'Placeholder' and
node.name.startswith('unused_control_flow_input')):
continue
new_node = node_def_pb2.NodeDef()
new_node.CopyFrom(node)
result_graph_def.node.extend([new_node])

return result_graph_def

def _check_signature_in_model(saved_model, signature_name):
if signature_name not in saved_model.signatures:
Expand Down

0 comments on commit 9e0bdb5

Please sign in to comment.