[RLlib] Tf2x preparation; part 2 (upgrading try_import_tf()). (#9136)

* WIP.

* Fixes.

* LINT.

* WIP.

* WIP.

* Fixes.

* Fixes.

* Fixes.

* Fixes.

* WIP.

* Fixes.

* Test

* Fix.

* Fixes and LINT.

* Fixes and LINT.

* LINT.
This commit is contained in:
Sven Mika
2020-06-30 10:13:20 +02:00
committed by GitHub
parent fb074da7c3
commit 43043ee4d5
125 changed files with 617 additions and 584 deletions
+43 -23
View File
@@ -4,7 +4,7 @@ import numpy as np
from ray.rllib.utils import force_list
from ray.rllib.utils.framework import try_import_tf
tf = try_import_tf()
tf1, tf, tfv = try_import_tf()
def unflatten(vector, shapes):
@@ -79,24 +79,29 @@ class TensorFlowVariables:
variable_names.append(tf_obj.node_def.name)
self.variables = OrderedDict()
variable_list = [
v for v in tf.global_variables()
v for v in tf1.global_variables()
if v.op.node_def.name in variable_names
]
if input_variables is not None:
variable_list += input_variables
for v in variable_list:
self.variables[v.op.node_def.name] = v
self.placeholders = {}
self.assignment_nodes = {}
if not tf1.executing_eagerly():
for v in variable_list:
self.variables[v.op.node_def.name] = v
# Create new placeholders to put in custom weights.
for k, var in self.variables.items():
self.placeholders[k] = tf.placeholder(
var.value().dtype,
var.get_shape().as_list(),
name="Placeholder_" + k)
self.assignment_nodes[k] = var.assign(self.placeholders[k])
self.placeholders = {}
self.assignment_nodes = {}
# Create new placeholders to put in custom weights.
for k, var in self.variables.items():
self.placeholders[k] = tf1.placeholder(
var.value().dtype,
var.get_shape().as_list(),
name="Placeholder_" + k)
self.assignment_nodes[k] = var.assign(self.placeholders[k])
else:
for v in variable_list:
self.variables[v.name] = v
def set_session(self, sess):
"""Sets the current session used by the class.
@@ -117,10 +122,12 @@ class TensorFlowVariables:
def _check_sess(self):
"""Checks if the session is set, and if not throw an error message."""
assert self.sess is not None, ("The session is not set. Set the "
"session either by passing it into the "
"TensorFlowVariables constructor or by "
"calling set_session(sess).")
if tf1.executing_eagerly():
return
assert self.sess is not None, \
"The session is not set. Set the session either by passing it " \
"into the TensorFlowVariables constructor or by calling " \
"set_session(sess)."
def get_flat(self):
"""Gets the weights and returns them as a flat array.
@@ -129,6 +136,11 @@ class TensorFlowVariables:
1D Array containing the flattened weights.
"""
self._check_sess()
# Eager mode.
if not self.sess:
return np.concatenate(
[v.numpy().flatten() for v in self.variables.values()])
# Graph mode.
return np.concatenate([
v.eval(session=self.sess).flatten()
for v in self.variables.values()
@@ -147,12 +159,16 @@ class TensorFlowVariables:
self._check_sess()
shapes = [v.get_shape().as_list() for v in self.variables.values()]
arrays = unflatten(new_weights, shapes)
placeholders = [
self.placeholders[k] for k, v in self.variables.items()
]
self.sess.run(
list(self.assignment_nodes.values()),
feed_dict=dict(zip(placeholders, arrays)))
if not self.sess:
for v, a in zip(self.variables.values(), arrays):
v.assign(a)
else:
placeholders = [
self.placeholders[k] for k, v in self.variables.items()
]
self.sess.run(
list(self.assignment_nodes.values()),
feed_dict=dict(zip(placeholders, arrays)))
def get_weights(self):
"""Returns a dictionary containing the weights of the network.
@@ -161,6 +177,10 @@ class TensorFlowVariables:
Dictionary mapping variable names to their weights.
"""
self._check_sess()
# Eager mode.
if not self.sess:
return self.variables
# Graph mode.
return self.sess.run(self.variables)
def set_weights(self, new_weights):