mirror of
https://github.com/wassname/ray.git
synced 2026-06-30 14:42:31 +08:00
Added functionality for retrieving variables from control dependencies (#220)
* Added test for retriving variables from an optimizer * Added comments to test * Addressed comments * Fixed travis bug * Added fix to circular controls * Added set for explored operations and duplicate prefix stripping * Removed embeded ipython * Removed prefix, use seperate graph for each network * Removed redundant imports * Addressed comments and added separate graph to initializer * fix typos * get rid of prefix in documentation
This commit is contained in:
@@ -28,28 +28,41 @@ class TensorFlowVariables(object):
|
||||
assignment_placeholders (List[tf.placeholders]): The nodes that weights get
|
||||
passed to.
|
||||
assignment_nodes (List[tf.Tensor]): The nodes that assign the weights.
|
||||
prefix (Bool): Boolean for if there is a prefix on the variable names.
|
||||
"""
|
||||
def __init__(self, loss, sess=None, prefix=False):
|
||||
def __init__(self, loss, sess=None):
|
||||
"""Creates a TensorFlowVariables instance."""
|
||||
import tensorflow as tf
|
||||
self.sess = sess
|
||||
self.loss = loss
|
||||
self.prefix = prefix
|
||||
queue = deque([loss])
|
||||
variable_names = []
|
||||
explored_inputs = set([loss])
|
||||
|
||||
# We do a BFS on the dependency graph of the input function to find
|
||||
# the variables.
|
||||
while len(queue) != 0:
|
||||
op = queue.popleft().op
|
||||
queue.extend(op.inputs)
|
||||
if op.node_def.op == "Variable":
|
||||
variable_names.append(op.node_def.name)
|
||||
tf_obj = queue.popleft()
|
||||
|
||||
# The object put into the queue is not necessarily an operation, so we
|
||||
# want the op attribute to get the operation underlying the object.
|
||||
# Only operations contain the inputs that we can explore.
|
||||
if hasattr(tf_obj, "op"):
|
||||
tf_obj = tf_obj.op
|
||||
for input_op in tf_obj.inputs:
|
||||
if input_op not in explored_inputs:
|
||||
queue.append(input_op)
|
||||
explored_inputs.add(input_op)
|
||||
# Tensorflow control inputs can be circular, so we keep track of
|
||||
# explored operations.
|
||||
for control in tf_obj.control_inputs:
|
||||
if control not in explored_inputs:
|
||||
queue.append(control)
|
||||
explored_inputs.add(control)
|
||||
if tf_obj.node_def.op == "Variable":
|
||||
variable_names.append(tf_obj.node_def.name)
|
||||
self.variables = OrderedDict()
|
||||
for v in [v for v in tf.global_variables() if v.op.node_def.name in variable_names]:
|
||||
name = v.op.node_def.name.split("/", 1 if prefix else 0)[-1]
|
||||
self.variables[name] = v
|
||||
self.variables[v.op.node_def.name] = v
|
||||
self.assignment_placeholders = dict()
|
||||
self.assignment_nodes = []
|
||||
|
||||
|
||||
Reference in New Issue
Block a user