Support constructing TensorFlowVariables from multiple tf operations (#2182)

This commit is contained in:
Binglin Chang
2018-06-03 09:13:52 +08:00
committed by Robert Nishihara
parent d699bfbf10
commit 19d6ca0670
2 changed files with 8 additions and 6 deletions
+7 -5
View File
@@ -28,7 +28,7 @@ class TensorFlowVariables(object):
assignment_nodes (Dict[str, tf.Tensor]): Nodes that assign weights.
"""
def __init__(self, loss, sess=None, input_variables=None):
def __init__(self, output, sess=None, input_variables=None):
"""Creates TensorFlowVariables containing extracted variables.
The variables are extracted by performing a BFS search on the
@@ -38,8 +38,8 @@ class TensorFlowVariables(object):
variable has a placeholder and assignment operation created for it.
Args:
loss (tf.Operation): The tensorflow operation to extract all
variables from.
output (tf.Operation, List[tf.Operation]): The tensorflow
operation to extract all variables from.
sess (tf.Session): Session used for running the get and set
methods.
input_variables (List[tf.Variables]): Variables to include in the
@@ -47,9 +47,11 @@ class TensorFlowVariables(object):
"""
import tensorflow as tf
self.sess = sess
queue = deque([loss])
if not isinstance(output, (list, tuple)):
output = [output]
queue = deque(output)
variable_names = []
explored_inputs = {loss}
explored_inputs = set(output)
# We do a BFS on the dependency graph of the input function to find
# the variables.