Fixed lbfgs for ray-cluster (#180)

* Updated lbfgs example to include TensorflowVariables

* Whitespace.
This commit is contained in:
Wapaul1
2017-01-10 18:40:06 -08:00
committed by Philipp Moritz
parent be4a37bf37
commit aaf3be3c53
3 changed files with 141 additions and 97 deletions
+34 -2
View File
@@ -1,6 +1,18 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
def unflatten(vector, shapes):
i = 0
arrays = []
for shape in shapes:
size = np.prod(shape)
array = vector[i:(i + size)].reshape(shape)
arrays.append(array)
i += size
assert len(vector) == i, "Passed weight does not have the correct shape."
return arrays
class TensorFlowVariables(object):
"""An object used to extract variables from a loss function.
@@ -35,12 +47,32 @@ class TensorFlowVariables(object):
"""Modifies the current session used by the class."""
self.sess = sess
def get_flat_size(self):
return sum([np.prod(v.get_shape().as_list()) for v in self.variables])
def _check_sess(self):
"""Checks if the session is set, and if not throw an error message."""
assert self.sess is not None, "The session is not set. Set the session either by passing it into the TensorFlowVariables constructor or by calling set_session(sess)."
def get_flat(self):
"""Gets the weights and returns them as a flat array."""
self._check_sess()
return np.concatenate([v.eval(session=self.sess).flatten() for v in self.variables])
def set_flat(self, new_weights):
"""Sets the weights to new_weights, converting from a flat array."""
self._check_sess()
shapes = [v.get_shape().as_list() for v in self.variables]
arrays = unflatten(new_weights, shapes)
placeholders = [self.assignment_placeholders[v.op.node_def.name] for v in self.variables]
self.sess.run(self.assignment_nodes, feed_dict=dict(zip(placeholders,arrays)))
def get_weights(self):
"""Returns the weights of the variables of the loss function in a list."""
assert self.sess is not None, "The session is not set. Set the session either by passing it into the TensorFlowVariables constructor or by calling set_session(sess)."
self._check_sess()
return {v.op.node_def.name: v.eval(session=self.sess) for v in self.variables}
def set_weights(self, new_weights):
"""Sets the weights to new_weights."""
assert self.sess is not None, "The session is not set. Set the session either by passing it into the TensorFlowVariables constructor or by calling set_session(sess)."
self._check_sess()
self.sess.run(self.assignment_nodes, feed_dict={self.assignment_placeholders[name]: value for (name, value) in new_weights.items()})