mirror of
https://github.com/wassname/ray.git
synced 2026-07-05 09:24:28 +08:00
Support constructing TensorFlowVariables from multiple tf operations (#2182)
This commit is contained in:
committed by
Robert Nishihara
parent
d699bfbf10
commit
19d6ca0670
@@ -28,7 +28,7 @@ class TensorFlowVariables(object):
|
||||
assignment_nodes (Dict[str, tf.Tensor]): Nodes that assign weights.
|
||||
"""
|
||||
|
||||
def __init__(self, loss, sess=None, input_variables=None):
|
||||
def __init__(self, output, sess=None, input_variables=None):
|
||||
"""Creates TensorFlowVariables containing extracted variables.
|
||||
|
||||
The variables are extracted by performing a BFS search on the
|
||||
@@ -38,8 +38,8 @@ class TensorFlowVariables(object):
|
||||
variable has a placeholder and assignment operation created for it.
|
||||
|
||||
Args:
|
||||
loss (tf.Operation): The tensorflow operation to extract all
|
||||
variables from.
|
||||
output (tf.Operation, List[tf.Operation]): The tensorflow
|
||||
operation to extract all variables from.
|
||||
sess (tf.Session): Session used for running the get and set
|
||||
methods.
|
||||
input_variables (List[tf.Variables]): Variables to include in the
|
||||
@@ -47,9 +47,11 @@ class TensorFlowVariables(object):
|
||||
"""
|
||||
import tensorflow as tf
|
||||
self.sess = sess
|
||||
queue = deque([loss])
|
||||
if not isinstance(output, (list, tuple)):
|
||||
output = [output]
|
||||
queue = deque(output)
|
||||
variable_names = []
|
||||
explored_inputs = {loss}
|
||||
explored_inputs = set(output)
|
||||
|
||||
# We do a BFS on the dependency graph of the input function to find
|
||||
# the variables.
|
||||
|
||||
Reference in New Issue
Block a user