mirror of
https://github.com/wassname/ray.git
synced 2026-07-01 11:27:06 +08:00
[RLlib] Tf2x preparation; part 2 (upgrading try_import_tf()). (#9136)
* WIP.
* Fixes.
* LINT.
* WIP.
* WIP.
* Fixes.
* Fixes.
* Fixes.
* Fixes.
* WIP.
* Fixes.
* Test
* Fix.
* Fixes and LINT.
* Fixes and LINT.
* LINT.
This commit is contained in:
@@ -4,7 +4,7 @@ import numpy as np
|
||||
from ray.rllib.utils import force_list
|
||||
from ray.rllib.utils.framework import try_import_tf
|
||||
|
||||
tf = try_import_tf()
|
||||
tf1, tf, tfv = try_import_tf()
|
||||
|
||||
|
||||
def unflatten(vector, shapes):
|
||||
@@ -79,24 +79,29 @@ class TensorFlowVariables:
|
||||
variable_names.append(tf_obj.node_def.name)
|
||||
self.variables = OrderedDict()
|
||||
variable_list = [
|
||||
v for v in tf.global_variables()
|
||||
v for v in tf1.global_variables()
|
||||
if v.op.node_def.name in variable_names
|
||||
]
|
||||
if input_variables is not None:
|
||||
variable_list += input_variables
|
||||
for v in variable_list:
|
||||
self.variables[v.op.node_def.name] = v
|
||||
|
||||
self.placeholders = {}
|
||||
self.assignment_nodes = {}
|
||||
if not tf1.executing_eagerly():
|
||||
for v in variable_list:
|
||||
self.variables[v.op.node_def.name] = v
|
||||
|
||||
# Create new placeholders to put in custom weights.
|
||||
for k, var in self.variables.items():
|
||||
self.placeholders[k] = tf.placeholder(
|
||||
var.value().dtype,
|
||||
var.get_shape().as_list(),
|
||||
name="Placeholder_" + k)
|
||||
self.assignment_nodes[k] = var.assign(self.placeholders[k])
|
||||
self.placeholders = {}
|
||||
self.assignment_nodes = {}
|
||||
|
||||
# Create new placeholders to put in custom weights.
|
||||
for k, var in self.variables.items():
|
||||
self.placeholders[k] = tf1.placeholder(
|
||||
var.value().dtype,
|
||||
var.get_shape().as_list(),
|
||||
name="Placeholder_" + k)
|
||||
self.assignment_nodes[k] = var.assign(self.placeholders[k])
|
||||
else:
|
||||
for v in variable_list:
|
||||
self.variables[v.name] = v
|
||||
|
||||
def set_session(self, sess):
|
||||
"""Sets the current session used by the class.
|
||||
@@ -117,10 +122,12 @@ class TensorFlowVariables:
|
||||
|
||||
def _check_sess(self):
|
||||
"""Checks if the session is set, and if not throw an error message."""
|
||||
assert self.sess is not None, ("The session is not set. Set the "
|
||||
"session either by passing it into the "
|
||||
"TensorFlowVariables constructor or by "
|
||||
"calling set_session(sess).")
|
||||
if tf1.executing_eagerly():
|
||||
return
|
||||
assert self.sess is not None, \
|
||||
"The session is not set. Set the session either by passing it " \
|
||||
"into the TensorFlowVariables constructor or by calling " \
|
||||
"set_session(sess)."
|
||||
|
||||
def get_flat(self):
|
||||
"""Gets the weights and returns them as a flat array.
|
||||
@@ -129,6 +136,11 @@ class TensorFlowVariables:
|
||||
1D Array containing the flattened weights.
|
||||
"""
|
||||
self._check_sess()
|
||||
# Eager mode.
|
||||
if not self.sess:
|
||||
return np.concatenate(
|
||||
[v.numpy().flatten() for v in self.variables.values()])
|
||||
# Graph mode.
|
||||
return np.concatenate([
|
||||
v.eval(session=self.sess).flatten()
|
||||
for v in self.variables.values()
|
||||
@@ -147,12 +159,16 @@ class TensorFlowVariables:
|
||||
self._check_sess()
|
||||
shapes = [v.get_shape().as_list() for v in self.variables.values()]
|
||||
arrays = unflatten(new_weights, shapes)
|
||||
placeholders = [
|
||||
self.placeholders[k] for k, v in self.variables.items()
|
||||
]
|
||||
self.sess.run(
|
||||
list(self.assignment_nodes.values()),
|
||||
feed_dict=dict(zip(placeholders, arrays)))
|
||||
if not self.sess:
|
||||
for v, a in zip(self.variables.values(), arrays):
|
||||
v.assign(a)
|
||||
else:
|
||||
placeholders = [
|
||||
self.placeholders[k] for k, v in self.variables.items()
|
||||
]
|
||||
self.sess.run(
|
||||
list(self.assignment_nodes.values()),
|
||||
feed_dict=dict(zip(placeholders, arrays)))
|
||||
|
||||
def get_weights(self):
|
||||
"""Returns a dictionary containing the weights of the network.
|
||||
@@ -161,6 +177,10 @@ class TensorFlowVariables:
|
||||
Dictionary mapping variable names to their weights.
|
||||
"""
|
||||
self._check_sess()
|
||||
# Eager mode.
|
||||
if not self.sess:
|
||||
return self.variables
|
||||
# Graph mode.
|
||||
return self.sess.run(self.variables)
|
||||
|
||||
def set_weights(self, new_weights):
|
||||
|
||||
Reference in New Issue
Block a user