Added functionality for retrieving variables from control dependencies (#220)

* Added test for retriving variables from an optimizer

* Added comments to test

* Addressed comments

* Fixed travis bug

* Added fix to circular controls

* Added set for explored operations and duplicate prefix stripping

* Removed embeded ipython

* Removed prefix, use seperate graph for each network

* Removed redundant imports

* Addressed comments and added separate graph to initializer

* fix typos

* get rid of prefix in documentation
This commit is contained in:
Wapaul1
2017-01-30 19:17:42 -08:00
committed by Philipp Moritz
parent 6703f7be6f
commit db7297865f
4 changed files with 84 additions and 36 deletions
+22 -9
View File
@@ -28,28 +28,41 @@ class TensorFlowVariables(object):
assignment_placeholders (List[tf.placeholders]): The nodes that weights get
passed to.
assignment_nodes (List[tf.Tensor]): The nodes that assign the weights.
prefix (Bool): Boolean for if there is a prefix on the variable names.
"""
def __init__(self, loss, sess=None, prefix=False):
def __init__(self, loss, sess=None):
"""Creates a TensorFlowVariables instance."""
import tensorflow as tf
self.sess = sess
self.loss = loss
self.prefix = prefix
queue = deque([loss])
variable_names = []
explored_inputs = set([loss])
# We do a BFS on the dependency graph of the input function to find
# the variables.
while len(queue) != 0:
op = queue.popleft().op
queue.extend(op.inputs)
if op.node_def.op == "Variable":
variable_names.append(op.node_def.name)
tf_obj = queue.popleft()
# The object put into the queue is not necessarily an operation, so we
# want the op attribute to get the operation underlying the object.
# Only operations contain the inputs that we can explore.
if hasattr(tf_obj, "op"):
tf_obj = tf_obj.op
for input_op in tf_obj.inputs:
if input_op not in explored_inputs:
queue.append(input_op)
explored_inputs.add(input_op)
# Tensorflow control inputs can be circular, so we keep track of
# explored operations.
for control in tf_obj.control_inputs:
if control not in explored_inputs:
queue.append(control)
explored_inputs.add(control)
if tf_obj.node_def.op == "Variable":
variable_names.append(tf_obj.node_def.name)
self.variables = OrderedDict()
for v in [v for v in tf.global_variables() if v.op.node_def.name in variable_names]:
name = v.op.node_def.name.split("/", 1 if prefix else 0)[-1]
self.variables[name] = v
self.variables[v.op.node_def.name] = v
self.assignment_placeholders = dict()
self.assignment_nodes = []