Polished TensorFlowVariables code and documentation (#566)

This commit is contained in:
William Paul
2018-02-12 15:38:58 -08:00
committed by Robert Nishihara
parent ca0f08d100
commit f2b6a7b58d
3 changed files with 210 additions and 43 deletions
+89 -31
View File
@@ -18,24 +18,34 @@ def unflatten(vector, shapes):
class TensorFlowVariables(object):
"""An object used to extract variables from a loss function.
This object also provides methods for getting and setting the weights of
the relevant variables.
"""A class used to set and get weights for Tensorflow networks.
Attributes:
sess (tf.Session): The tensorflow session used to run assignment.
loss: The loss function passed in by the user.
variables (List[tf.Variable]): Extracted variables from the loss.
assignment_placeholders (List[tf.placeholders]): The nodes that weights
get passed to.
assignment _nodes (List[tf.Tensor]): The nodes that assign the weights.
variables (Dict[str, tf.Variable]): Extracted variables from the loss
or additional variables that are passed in.
placeholders (Dict[str, tf.Tensor]): Placeholders for weights.
assignment_nodes (Dict[str, tf.Tensor]): Nodes that assign weights.
"""
def __init__(self, loss, sess=None):
"""Creates a TensorFlowVariables instance."""
def __init__(self, loss, sess=None, input_variables=None):
"""Creates TensorFlowVariables containing extracted variables.
The variables are extracted by performing a breadth-first search on the
dependency graph with loss as the root node. After the graph has been
traversed and those variables are collected, input_variables are appended
to the collected variables. A placeholder and an assignment operation are
then created for each variable in the resulting collection.
Args:
loss (tf.Operation): The tensorflow operation to extract all
variables from.
sess (tf.Session): Session used for running the get and set
methods.
input_variables (List[tf.Variable]): Variables to include in the
list.
"""
import tensorflow as tf
self.sess = sess
self.loss = loss
queue = deque([loss])
variable_names = []
explored_inputs = set([loss])
@@ -44,9 +54,10 @@ class TensorFlowVariables(object):
# the variables.
while len(queue) != 0:
tf_obj = queue.popleft()
# The object put into the queue is not necessarily an operation, so
# we want the op attribute to get the operation underlying the
if tf_obj is None:
continue
# The object put into the queue is not necessarily an operation,
# so we want the op attribute to get the operation underlying the
# object. Only operations contain the inputs that we can explore.
if hasattr(tf_obj, "op"):
tf_obj = tf_obj.op
@@ -63,23 +74,37 @@ class TensorFlowVariables(object):
if "Variable" in tf_obj.node_def.op:
variable_names.append(tf_obj.node_def.name)
self.variables = OrderedDict()
for v in [v for v in tf.global_variables()
if v.op.node_def.name in variable_names]:
variable_list = [v for v in tf.global_variables()
if v.op.node_def.name in variable_names]
if input_variables is not None:
variable_list += input_variables
for v in variable_list:
self.variables[v.op.node_def.name] = v
self.placeholders = dict()
self.assignment_nodes = []
self.assignment_nodes = dict()
# Create new placeholders to put in custom weights.
for k, var in self.variables.items():
self.placeholders[k] = tf.placeholder(var.value().dtype,
var.get_shape().as_list())
self.assignment_nodes.append(var.assign(self.placeholders[k]))
var.get_shape().as_list(),
name="Placeholder_" + k)
self.assignment_nodes[k] = var.assign(self.placeholders[k])
def set_session(self, sess):
"""Modifies the current session used by the class."""
"""Sets the current session used by the class.
Args:
sess (tf.Session): Session to set the attribute with.
"""
self.sess = sess
def get_flat_size(self):
"""Returns the total length of all of the flattened variables.
Returns:
The length of all flattened variables concatenated.
"""
return sum([np.prod(v.get_shape().as_list())
for v in self.variables.values()])
@@ -91,31 +116,64 @@ class TensorFlowVariables(object):
"calling set_session(sess).")
def get_flat(self):
"""Gets the weights and returns them as a flat array."""
"""Gets the weights and returns them as a flat array.
Returns:
1D Array containing the flattened weights.
"""
self._check_sess()
return np.concatenate([v.eval(session=self.sess).flatten()
for v in self.variables.values()])
def set_flat(self, new_weights):
"""Sets the weights to new_weights, converting from a flat array."""
"""Sets the weights to new_weights, converting from a flat array.
Note:
You can only set all weights in the network using this function,
i.e., the length of the array must match get_flat_size.
Args:
new_weights (np.ndarray): Flat array containing weights.
"""
self._check_sess()
shapes = [v.get_shape().as_list() for v in self.variables.values()]
arrays = unflatten(new_weights, shapes)
placeholders = [self.placeholders[k]
for k, v in self.variables.items()]
self.sess.run(self.assignment_nodes,
placeholders = [self.placeholders[k] for k, v
in self.variables.items()]
self.sess.run(list(self.assignment_nodes.values()),
feed_dict=dict(zip(placeholders, arrays)))
def get_weights(self):
"""Returns a list of the weights of the loss function variables."""
"""Returns a dictionary containing the weights of the network.
Returns:
Dictionary mapping variable names to their weights.
"""
self._check_sess()
return {k: v.eval(session=self.sess)
for k, v in self.variables.items()}
return {k: v.eval(session=self.sess) for k, v
in self.variables.items()}
def set_weights(self, new_weights):
"""Sets the weights to new_weights."""
"""Sets the weights to new_weights.
Note:
A subset of the variables can also be set by passing in only the
entries for the variables you want to be set.
Args:
new_weights (Dict): Dictionary mapping variable names to their
weights.
"""
self._check_sess()
self.sess.run(self.assignment_nodes,
assign_list = [self.assignment_nodes[name]
for name in new_weights.keys()
if name in self.assignment_nodes]
assert assign_list, ("No variables in the input matched those in the "
"network. Possible cause: Two networks were "
"defined in the same TensorFlow graph. To fix "
"this, place each network definition in its own "
"tf.Graph.")
self.sess.run(assign_list,
feed_dict={self.placeholders[name]: value
for (name, value) in new_weights.items()
if name in self.placeholders})