Selects from all variables now independent of graph, and uses standar… (#199)

* Smarter variable retrieval and doc update

* doc update and small fixes

* Addressing Robert's comments
This commit is contained in:
Wapaul1
2017-01-18 17:36:58 -08:00
committed by Robert Nishihara
parent 303d0fed3e
commit 6fe69bec11
3 changed files with 204 additions and 50 deletions
+26 -11
View File
@@ -2,6 +2,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from collections import deque, OrderedDict
def unflatten(vector, shapes):
i = 0
@@ -27,28 +28,42 @@ class TensorFlowVariables(object):
assignment_placeholders (List[tf.placeholders]): The nodes that weights get
passed to.
assignment_nodes (List[tf.Tensor]): The nodes that assign the weights.
prefix (Bool): Boolean for if there is a prefix on the variable names.
"""
def __init__(self, loss, sess=None, prefix=False):
    """Creates a TensorFlowVariables instance.

    Args:
        loss: The node whose dependency graph is searched for variables. It
            must expose an ``op`` attribute.
        sess: The session used to evaluate and assign the variables. It may
            be left as None and supplied later.
        prefix (bool): If True, everything up to and including the first "/"
            in a variable's name is dropped before the name is used as a key.
    """
    import tensorflow as tf
    self.sess = sess
    self.loss = loss
    self.prefix = prefix

    # We do a BFS on the dependency graph of the input function to find
    # the variables. Operations that were already expanded are remembered so
    # that a node shared by several consumers is only walked once.
    queue = deque([loss])
    explored_ops = set()
    variable_names = set()
    while queue:
        op = queue.popleft().op
        if op in explored_ops:
            continue
        explored_ops.add(op)
        queue.extend(op.inputs)
        # "VariableV2" is the op type newer TensorFlow 1.x releases use for
        # tf.Variable; "Variable" is the legacy type.
        if op.node_def.op in ("Variable", "VariableV2"):
            variable_names.add(op.node_def.name)

    # Keep the variables in the order tf.global_variables() reports them so
    # that the flat get/set methods agree on the layout.
    self.variables = OrderedDict()
    maxsplit = 1 if prefix else 0
    for v in tf.global_variables():
        full_name = v.op.node_def.name
        if full_name in variable_names:
            self.variables[full_name.split("/", maxsplit)[-1]] = v

    self.assignment_placeholders = dict()
    self.assignment_nodes = []
    # Create new placeholders to put in custom weights.
    for k, var in self.variables.items():
        placeholder = tf.placeholder(var.value().dtype,
                                     var.get_shape().as_list())
        self.assignment_placeholders[k] = placeholder
        self.assignment_nodes.append(var.assign(placeholder))
def set_session(self, sess):
    """Replaces the session stored on this instance with `sess`."""
    self.sess = sess
def get_flat_size(self):
    """Returns the total number of elements across all tracked variables.

    Returns:
        int: The sum, over every variable in self.variables, of the product
            of the variable's shape dimensions.
    """
    # self.variables is a mapping from name to variable, so iterate over the
    # values; iterating over the mapping itself would yield the name strings.
    # np.prod of an empty shape (a scalar variable) is the float 1.0, hence
    # the explicit conversion back to int.
    return int(sum(np.prod(v.get_shape().as_list())
                   for v in self.variables.values()))
def _check_sess(self):
"""Checks if the session is set, and if not throw an error message."""
@@ -57,20 +72,20 @@ class TensorFlowVariables(object):
def get_flat(self):
    """Gets the weights and returns them as a flat array.

    Returns:
        A one-dimensional numpy array made by flattening the evaluated value
        of every variable and concatenating them in the iteration order of
        self.variables.
    """
    self._check_sess()
    # Iterate over the variables themselves (the mapping's values), not the
    # name keys.
    return np.concatenate([v.eval(session=self.sess).flatten()
                           for v in self.variables.values()])
def set_flat(self, new_weights):
    """Sets the weights to new_weights, converting from a flat array.

    Args:
        new_weights: Flat array of values. It is handed to unflatten together
            with the shapes of the variables, taken in the iteration order of
            self.variables, and the resulting arrays are assigned in that
            same order.
    """
    self._check_sess()
    # Shapes come from the variables (the mapping's values); placeholders are
    # looked up by the mapping's keys, so both lists share one ordering.
    shapes = [v.get_shape().as_list() for v in self.variables.values()]
    arrays = unflatten(new_weights, shapes)
    placeholders = [self.assignment_placeholders[k] for k in self.variables]
    self.sess.run(self.assignment_nodes,
                  feed_dict=dict(zip(placeholders, arrays)))
def get_weights(self):
    """Returns the weights of the variables of the loss function.

    Returns:
        dict: Maps each key of self.variables to the evaluated value of the
            corresponding variable.
    """
    self._check_sess()
    # Keys are the (possibly prefix-stripped) names stored in self.variables.
    return {k: v.eval(session=self.sess) for k, v in self.variables.items()}
def set_weights(self, new_weights):
"""Sets the weights to new_weights."""