diff --git a/python/ray/experimental/tfutils.py b/python/ray/experimental/tfutils.py index f52a94682..e33b33cad 100644 --- a/python/ray/experimental/tfutils.py +++ b/python/ray/experimental/tfutils.py @@ -28,7 +28,7 @@ class TensorFlowVariables(object): assignment_nodes (Dict[str, tf.Tensor]): Nodes that assign weights. """ - def __init__(self, loss, sess=None, input_variables=None): + def __init__(self, output, sess=None, input_variables=None): """Creates TensorFlowVariables containing extracted variables. The variables are extracted by performing a BFS search on the @@ -38,8 +38,8 @@ class TensorFlowVariables(object): variable has a placeholder and assignment operation created for it. Args: - loss (tf.Operation): The tensorflow operation to extract all - variables from. + output (tf.Operation, List[tf.Operation]): The tensorflow + operation to extract all variables from. sess (tf.Session): Session used for running the get and set methods. input_variables (List[tf.Variables]): Variables to include in the @@ -47,9 +47,11 @@ class TensorFlowVariables(object): """ import tensorflow as tf self.sess = sess - queue = deque([loss]) + if not isinstance(output, (list, tuple)): + output = [output] + queue = deque(output) variable_names = [] - explored_inputs = {loss} + explored_inputs = set(output) # We do a BFS on the dependency graph of the input function to find # the variables. diff --git a/test/tensorflow_test.py b/test/tensorflow_test.py index 192f355e9..ca3f92299 100644 --- a/test/tensorflow_test.py +++ b/test/tensorflow_test.py @@ -128,7 +128,7 @@ class TensorFlowTest(unittest.TestCase): variables2.set_flat(flat_weights) assert_almost_equal(flat_weights, variables2.get_flat()) - variables3 = ray.experimental.TensorFlowVariables(loss2) + variables3 = ray.experimental.TensorFlowVariables([loss2]) self.assertEqual(variables3.sess, None) sess = tf.Session() variables3.set_session(sess)