From c4ccbfdfa96ffddeeb170b7cb79b80c7e1a6c9b4 Mon Sep 17 00:00:00 2001 From: Sven Mika Date: Thu, 2 Jul 2020 13:03:10 +0200 Subject: [PATCH] [RLlib] tf-eager support for ES and ARS (tf2.x preparation). (#9207) --- python/ray/experimental/tf_utils.py | 58 +++++++++++------------------ python/ray/tests/test_tensorflow.py | 5 +-- rllib/agents/ars/ars_tf_policy.py | 55 ++++++++++++++++++--------- rllib/agents/ars/tests/test_ars.py | 4 +- rllib/agents/es/es_tf_policy.py | 44 ++++++++++++++++------ rllib/agents/es/tests/test_es.py | 4 +- 6 files changed, 99 insertions(+), 71 deletions(-) diff --git a/python/ray/experimental/tf_utils.py b/python/ray/experimental/tf_utils.py index 6677161a4..0490f3400 100644 --- a/python/ray/experimental/tf_utils.py +++ b/python/ray/experimental/tf_utils.py @@ -44,6 +44,7 @@ class TensorFlowVariables: operation to extract all variables from. sess (Optional[tf.Session]): Optional tf.Session used for running the get and set methods in tf graph mode. + Use None for tf eager. input_variables (List[tf.Variables]): Variables to include in the list. """ @@ -103,14 +104,6 @@ class TensorFlowVariables: for v in variable_list: self.variables[v.name] = v - def set_session(self, sess): - """Sets the current session used by the class. - - Args: - sess (tf.Session): Session to set the attribute with. - """ - self.sess = sess - def get_flat_size(self): """Returns the total length of all of the flattened variables. @@ -120,22 +113,12 @@ class TensorFlowVariables: return sum( np.prod(v.get_shape().as_list()) for v in self.variables.values()) - def _check_sess(self): - """Checks if the session is set, and if not throw an error message.""" - if tf1.executing_eagerly(): - return - assert self.sess is not None, \ - "The session is not set. Set the session either by passing it " \ - "into the TensorFlowVariables constructor or by calling " \ - "set_session(sess)." - def get_flat(self): """Gets the weights and returns them as a flat array. Returns: 1D Array containing the flattened weights. """ - self._check_sess() # Eager mode. if not self.sess: return np.concatenate( @@ -156,7 +139,6 @@ class TensorFlowVariables: Args: new_weights (np.ndarray): Flat array containing weights. """ - self._check_sess() shapes = [v.get_shape().as_list() for v in self.variables.values()] arrays = unflatten(new_weights, shapes) if not self.sess: @@ -176,7 +158,6 @@ class TensorFlowVariables: Returns: Dictionary mapping variable names to their weights. """ - self._check_sess() # Eager mode. if not self.sess: return self.variables @@ -194,20 +175,23 @@ class TensorFlowVariables: new_weights (Dict): Dictionary mapping variable names to their weights. """ - self._check_sess() - assign_list = [ - self.assignment_nodes[name] for name in new_weights.keys() - if name in self.assignment_nodes - ] - assert assign_list, ("No variables in the input matched those in the " - "network. Possible cause: Two networks were " - "defined in the same TensorFlow graph. To fix " - "this, place each network definition in its own " - "tf.Graph.") - self.sess.run( - assign_list, - feed_dict={ - self.placeholders[name]: value - for (name, value) in new_weights.items() - if name in self.placeholders - }) + if self.sess is None: + for name, var in self.variables.items(): + var.assign(new_weights[name]) + else: + assign_list = [ + self.assignment_nodes[name] for name in new_weights.keys() + if name in self.assignment_nodes + ] + assert assign_list, \ + "No variables in the input matched those in the network. " \ + "Possible cause: Two networks were defined in the same " \ + "TensorFlow graph. To fix this, place each network " \ + "definition in its own tf.Graph." + self.sess.run( + assign_list, + feed_dict={ + self.placeholders[name]: value + for (name, value) in new_weights.items() + if name in self.placeholders + }) diff --git a/python/ray/tests/test_tensorflow.py b/python/ray/tests/test_tensorflow.py index ce4834992..705bf833b 100644 --- a/python/ray/tests/test_tensorflow.py +++ b/python/ray/tests/test_tensorflow.py @@ -122,10 +122,9 @@ def test_tensorflow_variables(ray_start_2_cpus): variables2.set_flat(flat_weights) assert_almost_equal(flat_weights, variables2.get_flat()) - variables3 = ray.experimental.tf_utils.TensorFlowVariables([loss2]) - assert variables3.sess is None sess = tf.Session() - variables3.set_session(sess) + variables3 = ray.experimental.tf_utils.TensorFlowVariables( + [loss2], sess=sess) assert variables3.sess == sess diff --git a/rllib/agents/ars/ars_tf_policy.py b/rllib/agents/ars/ars_tf_policy.py index 6c2f38022..f3a321770 100644 --- a/rllib/agents/ars/ars_tf_policy.py +++ b/rllib/agents/ars/ars_tf_policy.py @@ -9,11 +9,13 @@ import ray.experimental.tf_utils from ray.rllib.agents.es.es_tf_policy import make_session from ray.rllib.models import ModelCatalog from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.utils import try_import_tree from ray.rllib.utils.filter import get_filter from ray.rllib.utils.framework import try_import_tf from ray.rllib.utils.spaces.space_utils import unbatch tf1, tf, tfv = try_import_tf() +tree = try_import_tree() class ARSTFPolicy: @@ -27,13 +29,17 @@ class ARSTFPolicy: self.preprocessor.shape) self.single_threaded = config.get("single_threaded", False) - self.sess = make_session(single_threaded=self.single_threaded) - - self.inputs = tf1.placeholder(tf.float32, - [None] + list(self.preprocessor.shape)) + if config["framework"] == "tf": + self.sess = make_session(single_threaded=self.single_threaded) + self.inputs = tf1.placeholder( + tf.float32, [None] + list(self.preprocessor.shape)) + else: + if not tf1.executing_eagerly(): + tf1.enable_eager_execution() + self.sess = self.inputs = None # Policy network. - dist_class, dist_dim = ModelCatalog.get_action_dist( + self.dist_class, dist_dim = ModelCatalog.get_action_dist( self.action_space, config["model"], dist_type="deterministic") self.model = ModelCatalog.get_model_v2( @@ -41,18 +47,22 @@ class ARSTFPolicy: action_space=self.action_space, num_outputs=dist_dim, model_config=config["model"]) - dist_inputs, _ = self.model({SampleBatch.CUR_OBS: self.inputs}) - dist = dist_class(dist_inputs, self.model) - self.sampler = dist.sample() - - self.variables = ray.experimental.tf_utils.TensorFlowVariables( - dist_inputs, self.sess) + self.sampler = None + if self.sess: + dist_inputs, _ = self.model({SampleBatch.CUR_OBS: self.inputs}) + dist = self.dist_class(dist_inputs, self.model) + self.sampler = dist.sample() + self.variables = ray.experimental.tf_utils.TensorFlowVariables( + dist_inputs, self.sess) + self.sess.run(tf1.global_variables_initializer()) + else: + self.variables = ray.experimental.tf_utils.TensorFlowVariables( + [], None, self.model.variables()) self.num_params = sum( np.prod(variable.shape.as_list()) for _, variable in self.variables.variables.items()) - self.sess.run(tf1.global_variables_initializer()) def compute_actions(self, observation, @@ -64,12 +74,23 @@ class ARSTFPolicy: observation = observation[0] observation = self.preprocessor.transform(observation) observation = self.observation_filter(observation[None], update=update) - action = self.sess.run( - self.sampler, feed_dict={self.inputs: observation}) - action = unbatch(action) + + # `actions` is a list of (component) batches. + # Eager mode. + if not self.sess: + dist_inputs, _ = self.model({SampleBatch.CUR_OBS: observation}) + dist = self.dist_class(dist_inputs, self.model) + actions = dist.sample() + actions = tree.map_structure(lambda a: a.numpy(), actions) + # Graph mode. + else: + actions = self.sess.run( + self.sampler, feed_dict={self.inputs: observation}) + + actions = unbatch(actions) if add_noise and isinstance(self.action_space, gym.spaces.Box): - action += np.random.randn(*action.shape) * self.action_noise_std - return action + actions += np.random.randn(*actions.shape) * self.action_noise_std + return actions def compute_single_action(self, observation, diff --git a/rllib/agents/ars/tests/test_ars.py b/rllib/agents/ars/tests/test_ars.py index c3a3c1d6f..2bf0b5470 100644 --- a/rllib/agents/ars/tests/test_ars.py +++ b/rllib/agents/ars/tests/test_ars.py @@ -17,7 +17,7 @@ class TestARS(unittest.TestCase): num_iterations = 2 - for _ in framework_iterator(config, ("tf", "torch")): + for _ in framework_iterator(config): plain_config = config.copy() trainer = ars.ARSTrainer(config=plain_config, env="CartPole-v0") for i in range(num_iterations): @@ -25,6 +25,8 @@ class TestARS(unittest.TestCase): print(results) check_compute_single_action(trainer) + trainer.stop() + ray.shutdown() if __name__ == "__main__": diff --git a/rllib/agents/es/es_tf_policy.py b/rllib/agents/es/es_tf_policy.py index c739a906c..242eaf6c3 100644 --- a/rllib/agents/es/es_tf_policy.py +++ b/rllib/agents/es/es_tf_policy.py @@ -76,29 +76,40 @@ class ESTFPolicy: self.observation_filter = get_filter(config["observation_filter"], self.preprocessor.shape) self.single_threaded = config.get("single_threaded", False) - self.sess = make_session(single_threaded=self.single_threaded) - self.inputs = tf1.placeholder(tf.float32, - [None] + list(self.preprocessor.shape)) + if config["framework"] == "tf": + self.sess = make_session(single_threaded=self.single_threaded) + self.inputs = tf1.placeholder( + tf.float32, [None] + list(self.preprocessor.shape)) + else: + if not tf1.executing_eagerly(): + tf1.enable_eager_execution() + self.sess = self.inputs = None # Policy network. - dist_class, dist_dim = ModelCatalog.get_action_dist( + self.dist_class, dist_dim = ModelCatalog.get_action_dist( self.action_space, config["model"], dist_type="deterministic") + self.model = ModelCatalog.get_model_v2( obs_space=self.preprocessor.observation_space, action_space=action_space, num_outputs=dist_dim, model_config=config["model"]) - dist_inputs, _ = self.model({SampleBatch.CUR_OBS: self.inputs}) - dist = dist_class(dist_inputs, self.model) - self.sampler = dist.sample() - self.variables = ray.experimental.tf_utils.TensorFlowVariables( - dist_inputs, self.sess) + self.sampler = None + if self.sess: + dist_inputs, _ = self.model({SampleBatch.CUR_OBS: self.inputs}) + dist = self.dist_class(dist_inputs, self.model) + self.sampler = dist.sample() + self.variables = ray.experimental.tf_utils.TensorFlowVariables( + dist_inputs, self.sess) + self.sess.run(tf1.global_variables_initializer()) + else: + self.variables = ray.experimental.tf_utils.TensorFlowVariables( + [], None, self.model.variables()) self.num_params = sum( np.prod(variable.shape.as_list()) for _, variable in self.variables.variables.items()) - self.sess.run(tf1.global_variables_initializer()) def compute_actions(self, observation, @@ -111,8 +122,17 @@ class ESTFPolicy: observation = self.preprocessor.transform(observation) observation = self.observation_filter(observation[None], update=update) # `actions` is a list of (component) batches. - actions = self.sess.run( - self.sampler, feed_dict={self.inputs: observation}) + # Eager mode. + if not self.sess: + dist_inputs, _ = self.model({SampleBatch.CUR_OBS: observation}) + dist = self.dist_class(dist_inputs, self.model) + actions = dist.sample() + actions = tree.map_structure(lambda a: a.numpy(), actions) + # Graph mode. + else: + actions = self.sess.run( + self.sampler, feed_dict={self.inputs: observation}) + if add_noise: actions = tree.map_structure(self._add_noise, actions, self.action_space_struct) diff --git a/rllib/agents/es/tests/test_es.py b/rllib/agents/es/tests/test_es.py index 38033b2bb..17982e5be 100644 --- a/rllib/agents/es/tests/test_es.py +++ b/rllib/agents/es/tests/test_es.py @@ -18,7 +18,7 @@ class TestES(unittest.TestCase): num_iterations = 2 - for _ in framework_iterator(config, ("tf", "torch")): + for _ in framework_iterator(config): plain_config = config.copy() trainer = es.ESTrainer(config=plain_config, env="CartPole-v0") for i in range(num_iterations): @@ -26,6 +26,8 @@ class TestES(unittest.TestCase): print(results) check_compute_single_action(trainer) + trainer.stop() + ray.shutdown() if __name__ == "__main__":