mirror of
https://github.com/wassname/ray.git
synced 2026-06-29 13:15:35 +08:00
[RLlib] tf-eager support for ES and ARS (tf2.x preparation). (#9207)
This commit is contained in:
@@ -44,6 +44,7 @@ class TensorFlowVariables:
|
||||
operation to extract all variables from.
|
||||
sess (Optional[tf.Session]): Optional tf.Session used for running
|
||||
the get and set methods in tf graph mode.
|
||||
Use None for tf eager.
|
||||
input_variables (List[tf.Variables]): Variables to include in the
|
||||
list.
|
||||
"""
|
||||
@@ -103,14 +104,6 @@ class TensorFlowVariables:
|
||||
for v in variable_list:
|
||||
self.variables[v.name] = v
|
||||
|
||||
def set_session(self, sess):
|
||||
"""Sets the current session used by the class.
|
||||
|
||||
Args:
|
||||
sess (tf.Session): Session to set the attribute with.
|
||||
"""
|
||||
self.sess = sess
|
||||
|
||||
def get_flat_size(self):
|
||||
"""Returns the total length of all of the flattened variables.
|
||||
|
||||
@@ -120,22 +113,12 @@ class TensorFlowVariables:
|
||||
return sum(
|
||||
np.prod(v.get_shape().as_list()) for v in self.variables.values())
|
||||
|
||||
def _check_sess(self):
|
||||
"""Checks if the session is set, and if not throw an error message."""
|
||||
if tf1.executing_eagerly():
|
||||
return
|
||||
assert self.sess is not None, \
|
||||
"The session is not set. Set the session either by passing it " \
|
||||
"into the TensorFlowVariables constructor or by calling " \
|
||||
"set_session(sess)."
|
||||
|
||||
def get_flat(self):
|
||||
"""Gets the weights and returns them as a flat array.
|
||||
|
||||
Returns:
|
||||
1D Array containing the flattened weights.
|
||||
"""
|
||||
self._check_sess()
|
||||
# Eager mode.
|
||||
if not self.sess:
|
||||
return np.concatenate(
|
||||
@@ -156,7 +139,6 @@ class TensorFlowVariables:
|
||||
Args:
|
||||
new_weights (np.ndarray): Flat array containing weights.
|
||||
"""
|
||||
self._check_sess()
|
||||
shapes = [v.get_shape().as_list() for v in self.variables.values()]
|
||||
arrays = unflatten(new_weights, shapes)
|
||||
if not self.sess:
|
||||
@@ -176,7 +158,6 @@ class TensorFlowVariables:
|
||||
Returns:
|
||||
Dictionary mapping variable names to their weights.
|
||||
"""
|
||||
self._check_sess()
|
||||
# Eager mode.
|
||||
if not self.sess:
|
||||
return self.variables
|
||||
@@ -194,20 +175,23 @@ class TensorFlowVariables:
|
||||
new_weights (Dict): Dictionary mapping variable names to their
|
||||
weights.
|
||||
"""
|
||||
self._check_sess()
|
||||
assign_list = [
|
||||
self.assignment_nodes[name] for name in new_weights.keys()
|
||||
if name in self.assignment_nodes
|
||||
]
|
||||
assert assign_list, ("No variables in the input matched those in the "
|
||||
"network. Possible cause: Two networks were "
|
||||
"defined in the same TensorFlow graph. To fix "
|
||||
"this, place each network definition in its own "
|
||||
"tf.Graph.")
|
||||
self.sess.run(
|
||||
assign_list,
|
||||
feed_dict={
|
||||
self.placeholders[name]: value
|
||||
for (name, value) in new_weights.items()
|
||||
if name in self.placeholders
|
||||
})
|
||||
if self.sess is None:
|
||||
for name, var in self.variables.items():
|
||||
var.assign(new_weights[name])
|
||||
else:
|
||||
assign_list = [
|
||||
self.assignment_nodes[name] for name in new_weights.keys()
|
||||
if name in self.assignment_nodes
|
||||
]
|
||||
assert assign_list, \
|
||||
"No variables in the input matched those in the network. " \
|
||||
"Possible cause: Two networks were defined in the same " \
|
||||
"TensorFlow graph. To fix this, place each network " \
|
||||
"definition in its own tf.Graph."
|
||||
self.sess.run(
|
||||
assign_list,
|
||||
feed_dict={
|
||||
self.placeholders[name]: value
|
||||
for (name, value) in new_weights.items()
|
||||
if name in self.placeholders
|
||||
})
|
||||
|
||||
@@ -122,10 +122,9 @@ def test_tensorflow_variables(ray_start_2_cpus):
|
||||
variables2.set_flat(flat_weights)
|
||||
assert_almost_equal(flat_weights, variables2.get_flat())
|
||||
|
||||
variables3 = ray.experimental.tf_utils.TensorFlowVariables([loss2])
|
||||
assert variables3.sess is None
|
||||
sess = tf.Session()
|
||||
variables3.set_session(sess)
|
||||
variables3 = ray.experimental.tf_utils.TensorFlowVariables(
|
||||
[loss2], sess=sess)
|
||||
assert variables3.sess == sess
|
||||
|
||||
|
||||
|
||||
@@ -9,11 +9,13 @@ import ray.experimental.tf_utils
|
||||
from ray.rllib.agents.es.es_tf_policy import make_session
|
||||
from ray.rllib.models import ModelCatalog
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.utils import try_import_tree
|
||||
from ray.rllib.utils.filter import get_filter
|
||||
from ray.rllib.utils.framework import try_import_tf
|
||||
from ray.rllib.utils.spaces.space_utils import unbatch
|
||||
|
||||
tf1, tf, tfv = try_import_tf()
|
||||
tree = try_import_tree()
|
||||
|
||||
|
||||
class ARSTFPolicy:
|
||||
@@ -27,13 +29,17 @@ class ARSTFPolicy:
|
||||
self.preprocessor.shape)
|
||||
|
||||
self.single_threaded = config.get("single_threaded", False)
|
||||
self.sess = make_session(single_threaded=self.single_threaded)
|
||||
|
||||
self.inputs = tf1.placeholder(tf.float32,
|
||||
[None] + list(self.preprocessor.shape))
|
||||
if config["framework"] == "tf":
|
||||
self.sess = make_session(single_threaded=self.single_threaded)
|
||||
self.inputs = tf1.placeholder(
|
||||
tf.float32, [None] + list(self.preprocessor.shape))
|
||||
else:
|
||||
if not tf1.executing_eagerly():
|
||||
tf1.enable_eager_execution()
|
||||
self.sess = self.inputs = None
|
||||
|
||||
# Policy network.
|
||||
dist_class, dist_dim = ModelCatalog.get_action_dist(
|
||||
self.dist_class, dist_dim = ModelCatalog.get_action_dist(
|
||||
self.action_space, config["model"], dist_type="deterministic")
|
||||
|
||||
self.model = ModelCatalog.get_model_v2(
|
||||
@@ -41,18 +47,22 @@ class ARSTFPolicy:
|
||||
action_space=self.action_space,
|
||||
num_outputs=dist_dim,
|
||||
model_config=config["model"])
|
||||
dist_inputs, _ = self.model({SampleBatch.CUR_OBS: self.inputs})
|
||||
dist = dist_class(dist_inputs, self.model)
|
||||
|
||||
self.sampler = dist.sample()
|
||||
|
||||
self.variables = ray.experimental.tf_utils.TensorFlowVariables(
|
||||
dist_inputs, self.sess)
|
||||
self.sampler = None
|
||||
if self.sess:
|
||||
dist_inputs, _ = self.model({SampleBatch.CUR_OBS: self.inputs})
|
||||
dist = self.dist_class(dist_inputs, self.model)
|
||||
self.sampler = dist.sample()
|
||||
self.variables = ray.experimental.tf_utils.TensorFlowVariables(
|
||||
dist_inputs, self.sess)
|
||||
self.sess.run(tf1.global_variables_initializer())
|
||||
else:
|
||||
self.variables = ray.experimental.tf_utils.TensorFlowVariables(
|
||||
[], None, self.model.variables())
|
||||
|
||||
self.num_params = sum(
|
||||
np.prod(variable.shape.as_list())
|
||||
for _, variable in self.variables.variables.items())
|
||||
self.sess.run(tf1.global_variables_initializer())
|
||||
|
||||
def compute_actions(self,
|
||||
observation,
|
||||
@@ -64,12 +74,23 @@ class ARSTFPolicy:
|
||||
observation = observation[0]
|
||||
observation = self.preprocessor.transform(observation)
|
||||
observation = self.observation_filter(observation[None], update=update)
|
||||
action = self.sess.run(
|
||||
self.sampler, feed_dict={self.inputs: observation})
|
||||
action = unbatch(action)
|
||||
|
||||
# `actions` is a list of (component) batches.
|
||||
# Eager mode.
|
||||
if not self.sess:
|
||||
dist_inputs, _ = self.model({SampleBatch.CUR_OBS: observation})
|
||||
dist = self.dist_class(dist_inputs, self.model)
|
||||
actions = dist.sample()
|
||||
actions = tree.map_structure(lambda a: a.numpy(), actions)
|
||||
# Graph mode.
|
||||
else:
|
||||
actions = self.sess.run(
|
||||
self.sampler, feed_dict={self.inputs: observation})
|
||||
|
||||
actions = unbatch(actions)
|
||||
if add_noise and isinstance(self.action_space, gym.spaces.Box):
|
||||
action += np.random.randn(*action.shape) * self.action_noise_std
|
||||
return action
|
||||
actions += np.random.randn(*actions.shape) * self.action_noise_std
|
||||
return actions
|
||||
|
||||
def compute_single_action(self,
|
||||
observation,
|
||||
|
||||
@@ -17,7 +17,7 @@ class TestARS(unittest.TestCase):
|
||||
|
||||
num_iterations = 2
|
||||
|
||||
for _ in framework_iterator(config, ("tf", "torch")):
|
||||
for _ in framework_iterator(config):
|
||||
plain_config = config.copy()
|
||||
trainer = ars.ARSTrainer(config=plain_config, env="CartPole-v0")
|
||||
for i in range(num_iterations):
|
||||
@@ -25,6 +25,8 @@ class TestARS(unittest.TestCase):
|
||||
print(results)
|
||||
|
||||
check_compute_single_action(trainer)
|
||||
trainer.stop()
|
||||
ray.shutdown()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -76,29 +76,40 @@ class ESTFPolicy:
|
||||
self.observation_filter = get_filter(config["observation_filter"],
|
||||
self.preprocessor.shape)
|
||||
self.single_threaded = config.get("single_threaded", False)
|
||||
self.sess = make_session(single_threaded=self.single_threaded)
|
||||
self.inputs = tf1.placeholder(tf.float32,
|
||||
[None] + list(self.preprocessor.shape))
|
||||
if config["framework"] == "tf":
|
||||
self.sess = make_session(single_threaded=self.single_threaded)
|
||||
self.inputs = tf1.placeholder(
|
||||
tf.float32, [None] + list(self.preprocessor.shape))
|
||||
else:
|
||||
if not tf1.executing_eagerly():
|
||||
tf1.enable_eager_execution()
|
||||
self.sess = self.inputs = None
|
||||
|
||||
# Policy network.
|
||||
dist_class, dist_dim = ModelCatalog.get_action_dist(
|
||||
self.dist_class, dist_dim = ModelCatalog.get_action_dist(
|
||||
self.action_space, config["model"], dist_type="deterministic")
|
||||
|
||||
self.model = ModelCatalog.get_model_v2(
|
||||
obs_space=self.preprocessor.observation_space,
|
||||
action_space=action_space,
|
||||
num_outputs=dist_dim,
|
||||
model_config=config["model"])
|
||||
dist_inputs, _ = self.model({SampleBatch.CUR_OBS: self.inputs})
|
||||
dist = dist_class(dist_inputs, self.model)
|
||||
self.sampler = dist.sample()
|
||||
|
||||
self.variables = ray.experimental.tf_utils.TensorFlowVariables(
|
||||
dist_inputs, self.sess)
|
||||
self.sampler = None
|
||||
if self.sess:
|
||||
dist_inputs, _ = self.model({SampleBatch.CUR_OBS: self.inputs})
|
||||
dist = self.dist_class(dist_inputs, self.model)
|
||||
self.sampler = dist.sample()
|
||||
self.variables = ray.experimental.tf_utils.TensorFlowVariables(
|
||||
dist_inputs, self.sess)
|
||||
self.sess.run(tf1.global_variables_initializer())
|
||||
else:
|
||||
self.variables = ray.experimental.tf_utils.TensorFlowVariables(
|
||||
[], None, self.model.variables())
|
||||
|
||||
self.num_params = sum(
|
||||
np.prod(variable.shape.as_list())
|
||||
for _, variable in self.variables.variables.items())
|
||||
self.sess.run(tf1.global_variables_initializer())
|
||||
|
||||
def compute_actions(self,
|
||||
observation,
|
||||
@@ -111,8 +122,17 @@ class ESTFPolicy:
|
||||
observation = self.preprocessor.transform(observation)
|
||||
observation = self.observation_filter(observation[None], update=update)
|
||||
# `actions` is a list of (component) batches.
|
||||
actions = self.sess.run(
|
||||
self.sampler, feed_dict={self.inputs: observation})
|
||||
# Eager mode.
|
||||
if not self.sess:
|
||||
dist_inputs, _ = self.model({SampleBatch.CUR_OBS: observation})
|
||||
dist = self.dist_class(dist_inputs, self.model)
|
||||
actions = dist.sample()
|
||||
actions = tree.map_structure(lambda a: a.numpy(), actions)
|
||||
# Graph mode.
|
||||
else:
|
||||
actions = self.sess.run(
|
||||
self.sampler, feed_dict={self.inputs: observation})
|
||||
|
||||
if add_noise:
|
||||
actions = tree.map_structure(self._add_noise, actions,
|
||||
self.action_space_struct)
|
||||
|
||||
@@ -18,7 +18,7 @@ class TestES(unittest.TestCase):
|
||||
|
||||
num_iterations = 2
|
||||
|
||||
for _ in framework_iterator(config, ("tf", "torch")):
|
||||
for _ in framework_iterator(config):
|
||||
plain_config = config.copy()
|
||||
trainer = es.ESTrainer(config=plain_config, env="CartPole-v0")
|
||||
for i in range(num_iterations):
|
||||
@@ -26,6 +26,8 @@ class TestES(unittest.TestCase):
|
||||
print(results)
|
||||
|
||||
check_compute_single_action(trainer)
|
||||
trainer.stop()
|
||||
ray.shutdown()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user