mirror of
https://github.com/wassname/ray.git
synced 2026-06-29 09:06:10 +08:00
Fixed lbfgs for ray-cluster (#180)
* Updated lbfgs example to include TensorflowVariables * Whitespace.
This commit is contained in:
@@ -1,6 +1,18 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
import numpy as np
|
||||
|
||||
def unflatten(vector, shapes):
|
||||
i = 0
|
||||
arrays = []
|
||||
for shape in shapes:
|
||||
size = np.prod(shape)
|
||||
array = vector[i:(i + size)].reshape(shape)
|
||||
arrays.append(array)
|
||||
i += size
|
||||
assert len(vector) == i, "Passed weight does not have the correct shape."
|
||||
return arrays
|
||||
|
||||
class TensorFlowVariables(object):
|
||||
"""An object used to extract variables from a loss function.
|
||||
@@ -35,12 +47,32 @@ class TensorFlowVariables(object):
|
||||
"""Modifies the current session used by the class."""
|
||||
self.sess = sess
|
||||
|
||||
def get_flat_size(self):
|
||||
return sum([np.prod(v.get_shape().as_list()) for v in self.variables])
|
||||
|
||||
def _check_sess(self):
|
||||
"""Checks if the session is set, and if not throw an error message."""
|
||||
assert self.sess is not None, "The session is not set. Set the session either by passing it into the TensorFlowVariables constructor or by calling set_session(sess)."
|
||||
|
||||
def get_flat(self):
|
||||
"""Gets the weights and returns them as a flat array."""
|
||||
self._check_sess()
|
||||
return np.concatenate([v.eval(session=self.sess).flatten() for v in self.variables])
|
||||
|
||||
def set_flat(self, new_weights):
|
||||
"""Sets the weights to new_weights, converting from a flat array."""
|
||||
self._check_sess()
|
||||
shapes = [v.get_shape().as_list() for v in self.variables]
|
||||
arrays = unflatten(new_weights, shapes)
|
||||
placeholders = [self.assignment_placeholders[v.op.node_def.name] for v in self.variables]
|
||||
self.sess.run(self.assignment_nodes, feed_dict=dict(zip(placeholders,arrays)))
|
||||
|
||||
def get_weights(self):
|
||||
"""Returns the weights of the variables of the loss function in a list."""
|
||||
assert self.sess is not None, "The session is not set. Set the session either by passing it into the TensorFlowVariables constructor or by calling set_session(sess)."
|
||||
self._check_sess()
|
||||
return {v.op.node_def.name: v.eval(session=self.sess) for v in self.variables}
|
||||
|
||||
def set_weights(self, new_weights):
|
||||
"""Sets the weights to new_weights."""
|
||||
assert self.sess is not None, "The session is not set. Set the session either by passing it into the TensorFlowVariables constructor or by calling set_session(sess)."
|
||||
self._check_sess()
|
||||
self.sess.run(self.assignment_nodes, feed_dict={self.assignment_placeholders[name]: value for (name, value) in new_weights.items()})
|
||||
|
||||
Reference in New Issue
Block a user