mirror of
https://github.com/wassname/ray.git
synced 2026-06-29 06:08:03 +08:00
Polished TensorFlowVariables code and documentation (#566)
This commit is contained in:
committed by
Robert Nishihara
parent
ca0f08d100
commit
f2b6a7b58d
@@ -18,24 +18,34 @@ def unflatten(vector, shapes):
|
||||
|
||||
|
||||
class TensorFlowVariables(object):
|
||||
"""An object used to extract variables from a loss function.
|
||||
|
||||
This object also provides methods for getting and setting the weights of
|
||||
the relevant variables.
|
||||
"""A class used to set and get weights for Tensorflow networks.
|
||||
|
||||
Attributes:
|
||||
sess (tf.Session): The tensorflow session used to run assignment.
|
||||
loss: The loss function passed in by the user.
|
||||
variables (List[tf.Variable]): Extracted variables from the loss.
|
||||
assignment_placeholders (List[tf.placeholders]): The nodes that weights
|
||||
get passed to.
|
||||
assignment _nodes (List[tf.Tensor]): The nodes that assign the weights.
|
||||
variables (Dict[str, tf.Variable]): Extracted variables from the loss
|
||||
or additional variables that are passed in.
|
||||
placeholders (Dict[str, tf.placeholders]): Placeholders for weights.
|
||||
assignment_nodes (Dict[str, tf.Tensor]): Nodes that assign weights.
|
||||
"""
|
||||
def __init__(self, loss, sess=None):
|
||||
"""Creates a TensorFlowVariables instance."""
|
||||
def __init__(self, loss, sess=None, input_variables=None):
|
||||
"""Creates TensorFlowVariables containing extracted variables.
|
||||
|
||||
The variables are extracted by performing a BFS search on the
|
||||
dependency graph with loss as the root node. After the tree is
|
||||
traversed and those variables are collected, we append input_variables
|
||||
to the collected variables. For each variable in the list, the
|
||||
variable has a placeholder and assignment operation created for it.
|
||||
|
||||
Args:
|
||||
loss (tf.Operation): The tensorflow operation to extract all
|
||||
variables from.
|
||||
sess (tf.Session): Session used for running the get and set
|
||||
methods.
|
||||
input_variables (List[tf.Variables]): Variables to include in the
|
||||
list.
|
||||
"""
|
||||
import tensorflow as tf
|
||||
self.sess = sess
|
||||
self.loss = loss
|
||||
queue = deque([loss])
|
||||
variable_names = []
|
||||
explored_inputs = set([loss])
|
||||
@@ -44,9 +54,10 @@ class TensorFlowVariables(object):
|
||||
# the variables.
|
||||
while len(queue) != 0:
|
||||
tf_obj = queue.popleft()
|
||||
|
||||
# The object put into the queue is not necessarily an operation, so
|
||||
# we want the op attribute to get the operation underlying the
|
||||
if tf_obj is None:
|
||||
continue
|
||||
# The object put into the queue is not necessarily an operation,
|
||||
# so we want the op attribute to get the operation underlying the
|
||||
# object. Only operations contain the inputs that we can explore.
|
||||
if hasattr(tf_obj, "op"):
|
||||
tf_obj = tf_obj.op
|
||||
@@ -63,23 +74,37 @@ class TensorFlowVariables(object):
|
||||
if "Variable" in tf_obj.node_def.op:
|
||||
variable_names.append(tf_obj.node_def.name)
|
||||
self.variables = OrderedDict()
|
||||
for v in [v for v in tf.global_variables()
|
||||
if v.op.node_def.name in variable_names]:
|
||||
variable_list = [v for v in tf.global_variables()
|
||||
if v.op.node_def.name in variable_names]
|
||||
if input_variables is not None:
|
||||
variable_list += input_variables
|
||||
for v in variable_list:
|
||||
self.variables[v.op.node_def.name] = v
|
||||
|
||||
self.placeholders = dict()
|
||||
self.assignment_nodes = []
|
||||
self.assignment_nodes = dict()
|
||||
|
||||
# Create new placeholders to put in custom weights.
|
||||
for k, var in self.variables.items():
|
||||
self.placeholders[k] = tf.placeholder(var.value().dtype,
|
||||
var.get_shape().as_list())
|
||||
self.assignment_nodes.append(var.assign(self.placeholders[k]))
|
||||
var.get_shape().as_list(),
|
||||
name="Placeholder_" + k)
|
||||
self.assignment_nodes[k] = var.assign(self.placeholders[k])
|
||||
|
||||
def set_session(self, sess):
|
||||
"""Modifies the current session used by the class."""
|
||||
"""Sets the current session used by the class.
|
||||
|
||||
Args:
|
||||
sess (tf.Session): Session to set the attribute with.
|
||||
"""
|
||||
self.sess = sess
|
||||
|
||||
def get_flat_size(self):
|
||||
"""Returns the total length of all of the flattened variables.
|
||||
|
||||
Returns:
|
||||
The length of all flattened variables concatenated.
|
||||
"""
|
||||
return sum([np.prod(v.get_shape().as_list())
|
||||
for v in self.variables.values()])
|
||||
|
||||
@@ -91,31 +116,64 @@ class TensorFlowVariables(object):
|
||||
"calling set_session(sess).")
|
||||
|
||||
def get_flat(self):
|
||||
"""Gets the weights and returns them as a flat array."""
|
||||
"""Gets the weights and returns them as a flat array.
|
||||
|
||||
Returns:
|
||||
1D Array containing the flattened weights.
|
||||
"""
|
||||
self._check_sess()
|
||||
return np.concatenate([v.eval(session=self.sess).flatten()
|
||||
for v in self.variables.values()])
|
||||
|
||||
def set_flat(self, new_weights):
|
||||
"""Sets the weights to new_weights, converting from a flat array."""
|
||||
"""Sets the weights to new_weights, converting from a flat array.
|
||||
|
||||
Note:
|
||||
You can only set all weights in the network using this function,
|
||||
i.e., the length of the array must match get_flat_size.
|
||||
|
||||
Args:
|
||||
new_weights (np.ndarray): Flat array containing weights.
|
||||
"""
|
||||
self._check_sess()
|
||||
shapes = [v.get_shape().as_list() for v in self.variables.values()]
|
||||
arrays = unflatten(new_weights, shapes)
|
||||
placeholders = [self.placeholders[k]
|
||||
for k, v in self.variables.items()]
|
||||
self.sess.run(self.assignment_nodes,
|
||||
placeholders = [self.placeholders[k] for k, v
|
||||
in self.variables.items()]
|
||||
self.sess.run(list(self.assignment_nodes.values()),
|
||||
feed_dict=dict(zip(placeholders, arrays)))
|
||||
|
||||
def get_weights(self):
|
||||
"""Returns a list of the weights of the loss function variables."""
|
||||
"""Returns a dictionary containing the weights of the network.
|
||||
|
||||
Returns:
|
||||
Dictionary mapping variable names to their weights.
|
||||
"""
|
||||
self._check_sess()
|
||||
return {k: v.eval(session=self.sess)
|
||||
for k, v in self.variables.items()}
|
||||
return {k: v.eval(session=self.sess) for k, v
|
||||
in self.variables.items()}
|
||||
|
||||
def set_weights(self, new_weights):
|
||||
"""Sets the weights to new_weights."""
|
||||
"""Sets the weights to new_weights.
|
||||
|
||||
Note:
|
||||
Can set subsets of variables as well, by only passing in the
|
||||
variables you want to be set.
|
||||
|
||||
Args:
|
||||
new_weights (Dict): Dictionary mapping variable names to their
|
||||
weights.
|
||||
"""
|
||||
self._check_sess()
|
||||
self.sess.run(self.assignment_nodes,
|
||||
assign_list = [self.assignment_nodes[name]
|
||||
for name in new_weights.keys()
|
||||
if name in self.assignment_nodes]
|
||||
assert assign_list, ("No variables in the input matched those in the "
|
||||
"network. Possible cause: Two networks were "
|
||||
"defined in the same TensorFlow graph. To fix "
|
||||
"this, place each network definition in its own "
|
||||
"tf.Graph.")
|
||||
self.sess.run(assign_list,
|
||||
feed_dict={self.placeholders[name]: value
|
||||
for (name, value) in new_weights.items()
|
||||
if name in self.placeholders})
|
||||
|
||||
Reference in New Issue
Block a user