mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 21:38:18 +08:00
Selects from all variables now independent of graph, and uses standar… (#199)
* Smarter variable retrieval and doc update * doc update and small fixes * addressing robert's comments
This commit is contained in:
committed by
Robert Nishihara
parent
303d0fed3e
commit
6fe69bec11
@@ -2,6 +2,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
import numpy as np
|
||||
from collections import deque, OrderedDict
|
||||
|
||||
def unflatten(vector, shapes):
|
||||
i = 0
|
||||
@@ -27,28 +28,42 @@ class TensorFlowVariables(object):
|
||||
assignment_placeholders (List[tf.placeholders]): The nodes that weights get
|
||||
passed to.
|
||||
assignment_nodes (List[tf.Tensor]): The nodes that assign the weights.
|
||||
prefix (Bool): Boolean for if there is a prefix on the variable names.
|
||||
"""
|
||||
def __init__(self, loss, sess=None):
|
||||
def __init__(self, loss, sess=None, prefix=False):
|
||||
"""Creates a TensorFlowVariables instance."""
|
||||
import tensorflow as tf
|
||||
self.sess = sess
|
||||
self.loss = loss
|
||||
variable_names = [op.node_def.name for op in loss.graph.get_operations() if op.node_def.op == "Variable"]
|
||||
self.variables = [v for v in tf.trainable_variables() if v.op.node_def.name in variable_names]
|
||||
self.prefix = prefix
|
||||
queue = deque([loss])
|
||||
variable_names = []
|
||||
|
||||
# We do a BFS on the dependency graph of the input function to find
|
||||
# the variables.
|
||||
while len(queue) != 0:
|
||||
op = queue.popleft().op
|
||||
queue.extend(op.inputs)
|
||||
if op.node_def.op == "Variable":
|
||||
variable_names.append(op.node_def.name)
|
||||
self.variables = OrderedDict()
|
||||
for v in [v for v in tf.global_variables() if v.op.node_def.name in variable_names]:
|
||||
name = v.op.node_def.name.split("/", 1 if prefix else 0)[-1]
|
||||
self.variables[name] = v
|
||||
self.assignment_placeholders = dict()
|
||||
self.assignment_nodes = []
|
||||
|
||||
# Create new placeholders to put in custom weights.
|
||||
for var in self.variables:
|
||||
self.assignment_placeholders[var.op.node_def.name] = tf.placeholder(var.value().dtype, var.get_shape().as_list())
|
||||
self.assignment_nodes.append(var.assign(self.assignment_placeholders[var.op.node_def.name]))
|
||||
for k, var in self.variables.items():
|
||||
self.assignment_placeholders[k] = tf.placeholder(var.value().dtype, var.get_shape().as_list())
|
||||
self.assignment_nodes.append(var.assign(self.assignment_placeholders[k]))
|
||||
|
||||
def set_session(self, sess):
|
||||
"""Modifies the current session used by the class."""
|
||||
self.sess = sess
|
||||
|
||||
def get_flat_size(self):
|
||||
return sum([np.prod(v.get_shape().as_list()) for v in self.variables])
|
||||
return sum([np.prod(v.get_shape().as_list()) for v in self.variables.values()])
|
||||
|
||||
def _check_sess(self):
|
||||
"""Checks if the session is set, and if not throw an error message."""
|
||||
@@ -57,20 +72,20 @@ class TensorFlowVariables(object):
|
||||
def get_flat(self):
|
||||
"""Gets the weights and returns them as a flat array."""
|
||||
self._check_sess()
|
||||
return np.concatenate([v.eval(session=self.sess).flatten() for v in self.variables])
|
||||
return np.concatenate([v.eval(session=self.sess).flatten() for v in self.variables.values()])
|
||||
|
||||
def set_flat(self, new_weights):
|
||||
"""Sets the weights to new_weights, converting from a flat array."""
|
||||
self._check_sess()
|
||||
shapes = [v.get_shape().as_list() for v in self.variables]
|
||||
shapes = [v.get_shape().as_list() for v in self.variables.values()]
|
||||
arrays = unflatten(new_weights, shapes)
|
||||
placeholders = [self.assignment_placeholders[v.op.node_def.name] for v in self.variables]
|
||||
placeholders = [self.assignment_placeholders[k] for k, v in self.variables.items()]
|
||||
self.sess.run(self.assignment_nodes, feed_dict=dict(zip(placeholders,arrays)))
|
||||
|
||||
def get_weights(self):
|
||||
"""Returns the weights of the variables of the loss function in a list."""
|
||||
self._check_sess()
|
||||
return {v.op.node_def.name: v.eval(session=self.sess) for v in self.variables}
|
||||
return {k: v.eval(session=self.sess) for k, v in self.variables.items()}
|
||||
|
||||
def set_weights(self, new_weights):
|
||||
"""Sets the weights to new_weights."""
|
||||
|
||||
Reference in New Issue
Block a user