[RLlib] tf-eager support for ES and ARS (tf2.x preparation). (#9207)

This commit is contained in:
Sven Mika
2020-07-02 13:03:10 +02:00
committed by GitHub
parent 8a1cc7f8f9
commit c4ccbfdfa9
6 changed files with 99 additions and 71 deletions
+21 -37
View File
@@ -44,6 +44,7 @@ class TensorFlowVariables:
operation to extract all variables from.
sess (Optional[tf.Session]): Optional tf.Session used for running
the get and set methods in tf graph mode.
Use None for tf eager.
input_variables (List[tf.Variables]): Variables to include in the
list.
"""
@@ -103,14 +104,6 @@ class TensorFlowVariables:
for v in variable_list:
self.variables[v.name] = v
def set_session(self, sess):
"""Sets the current session used by the class.
Args:
sess (tf.Session): Session to set the attribute with.
"""
self.sess = sess
def get_flat_size(self):
"""Returns the total length of all of the flattened variables.
@@ -120,22 +113,12 @@ class TensorFlowVariables:
return sum(
np.prod(v.get_shape().as_list()) for v in self.variables.values())
def _check_sess(self):
"""Checks if the session is set, and if not throw an error message."""
if tf1.executing_eagerly():
return
assert self.sess is not None, \
"The session is not set. Set the session either by passing it " \
"into the TensorFlowVariables constructor or by calling " \
"set_session(sess)."
def get_flat(self):
"""Gets the weights and returns them as a flat array.
Returns:
1D Array containing the flattened weights.
"""
self._check_sess()
# Eager mode.
if not self.sess:
return np.concatenate(
@@ -156,7 +139,6 @@ class TensorFlowVariables:
Args:
new_weights (np.ndarray): Flat array containing weights.
"""
self._check_sess()
shapes = [v.get_shape().as_list() for v in self.variables.values()]
arrays = unflatten(new_weights, shapes)
if not self.sess:
@@ -176,7 +158,6 @@ class TensorFlowVariables:
Returns:
Dictionary mapping variable names to their weights.
"""
self._check_sess()
# Eager mode.
if not self.sess:
return self.variables
@@ -194,20 +175,23 @@ class TensorFlowVariables:
new_weights (Dict): Dictionary mapping variable names to their
weights.
"""
self._check_sess()
assign_list = [
self.assignment_nodes[name] for name in new_weights.keys()
if name in self.assignment_nodes
]
assert assign_list, ("No variables in the input matched those in the "
"network. Possible cause: Two networks were "
"defined in the same TensorFlow graph. To fix "
"this, place each network definition in its own "
"tf.Graph.")
self.sess.run(
assign_list,
feed_dict={
self.placeholders[name]: value
for (name, value) in new_weights.items()
if name in self.placeholders
})
if self.sess is None:
for name, var in self.variables.items():
var.assign(new_weights[name])
else:
assign_list = [
self.assignment_nodes[name] for name in new_weights.keys()
if name in self.assignment_nodes
]
assert assign_list, \
"No variables in the input matched those in the network. " \
"Possible cause: Two networks were defined in the same " \
"TensorFlow graph. To fix this, place each network " \
"definition in its own tf.Graph."
self.sess.run(
assign_list,
feed_dict={
self.placeholders[name]: value
for (name, value) in new_weights.items()
if name in self.placeholders
})
+2 -3
View File
@@ -122,10 +122,9 @@ def test_tensorflow_variables(ray_start_2_cpus):
variables2.set_flat(flat_weights)
assert_almost_equal(flat_weights, variables2.get_flat())
variables3 = ray.experimental.tf_utils.TensorFlowVariables([loss2])
assert variables3.sess is None
sess = tf.Session()
variables3.set_session(sess)
variables3 = ray.experimental.tf_utils.TensorFlowVariables(
[loss2], sess=sess)
assert variables3.sess == sess