From 0a0d9183feec47cc3a8e26adea687a4ea4e5c243 Mon Sep 17 00:00:00 2001 From: Sven Mika Date: Tue, 2 Feb 2021 18:42:18 +0100 Subject: [PATCH] [RLlib] Trajectory view API example script (enhancements and tf2 support). (#13786) --- rllib/BUILD | 4 +- .../trajectory_view_utilizing_models.py | 67 +++++++++++++------ rllib/examples/trajectory_view_api.py | 8 ++- rllib/models/torch/misc.py | 5 +- rllib/policy/eager_tf_policy.py | 60 +++++++++++++---- rllib/policy/torch_policy.py | 8 +-- 6 files changed, 106 insertions(+), 46 deletions(-) diff --git a/rllib/BUILD b/rllib/BUILD index 9658983ab..cfe22c60f 100644 --- a/rllib/BUILD +++ b/rllib/BUILD @@ -2114,7 +2114,7 @@ py_test( tags = ["examples", "examples_T"], size = "medium", srcs = ["examples/trajectory_view_api.py"], - args = ["--as-test", "--framework=tf", "--stop-reward=80.0"] + args = ["--as-test", "--framework=tf", "--stop-reward=100.0"] ) py_test( @@ -2123,7 +2123,7 @@ py_test( tags = ["examples", "examples_T"], size = "medium", srcs = ["examples/trajectory_view_api.py"], - args = ["--as-test", "--framework=torch", "--stop-reward=80.0"] + args = ["--as-test", "--framework=torch", "--stop-reward=100.0"] ) py_test( diff --git a/rllib/examples/models/trajectory_view_utilizing_models.py b/rllib/examples/models/trajectory_view_utilizing_models.py index 41f53d872..0fd4e22cb 100644 --- a/rllib/examples/models/trajectory_view_utilizing_models.py +++ b/rllib/examples/models/trajectory_view_utilizing_models.py @@ -3,6 +3,8 @@ from ray.rllib.models.torch.misc import SlimFC from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 from ray.rllib.policy.view_requirement import ViewRequirement from ray.rllib.utils.framework import try_import_tf, try_import_torch +from ray.rllib.utils.tf_ops import one_hot +from ray.rllib.utils.torch_ops import one_hot as torch_one_hot tf1, tf, tfv = try_import_tf() torch, nn = try_import_torch() @@ -28,27 +30,42 @@ class FrameStackingCartPoleModel(TFModelV2): # Construct actual (very simple) FC model. assert len(obs_space.shape) == 1 - input_ = tf.keras.layers.Input( + obs = tf.keras.layers.Input( shape=(self.num_frames, obs_space.shape[0])) - reshaped = tf.keras.layers.Reshape( - [obs_space.shape[0] * self.num_frames])(input_) - layer1 = tf.keras.layers.Dense(64, activation=tf.nn.relu)(reshaped) - out = tf.keras.layers.Dense(self.num_outputs)(layer1) + obs_reshaped = tf.keras.layers.Reshape( + [obs_space.shape[0] * self.num_frames])(obs) + rewards = tf.keras.layers.Input(shape=(self.num_frames)) + rewards_reshaped = tf.keras.layers.Reshape([self.num_frames])(rewards) + actions = tf.keras.layers.Input( + shape=(self.num_frames, self.action_space.n)) + actions_reshaped = tf.keras.layers.Reshape( + [action_space.n * self.num_frames])(actions) + input_ = tf.keras.layers.Concatenate(axis=-1)( + [obs_reshaped, actions_reshaped, rewards_reshaped]) + layer1 = tf.keras.layers.Dense(256, activation=tf.nn.relu)(input_) + layer2 = tf.keras.layers.Dense(256, activation=tf.nn.relu)(layer1) + out = tf.keras.layers.Dense(self.num_outputs)(layer2) values = tf.keras.layers.Dense(1)(layer1) - self.base_model = tf.keras.models.Model([input_], [out, values]) - + self.base_model = tf.keras.models.Model([obs, actions, rewards], + [out, values]) self._last_value = None self.view_requirements["prev_n_obs"] = ViewRequirement( data_col="obs", shift="-{}:0".format(num_frames - 1), space=obs_space) - self.view_requirements["prev_rewards"] = ViewRequirement( - data_col="rewards", shift=-1) + self.view_requirements["prev_n_rewards"] = ViewRequirement( + data_col="rewards", shift="-{}:-1".format(self.num_frames)) + self.view_requirements["prev_n_actions"] = ViewRequirement( + data_col="actions", + shift="-{}:-1".format(self.num_frames), + space=self.action_space) def forward(self, input_dict, states, seq_lens): - obs = input_dict["prev_n_obs"] - out, self._last_value = self.base_model(obs) + obs = tf.cast(input_dict["prev_n_obs"], tf.float32) + rewards = tf.cast(input_dict["prev_n_rewards"], tf.float32) + actions = one_hot(input_dict["prev_n_actions"], self.action_space) + out, self._last_value = self.base_model([obs, actions, rewards]) return out, [] def value_function(self): @@ -77,13 +94,13 @@ class TorchFrameStackingCartPoleModel(TorchModelV2, nn.Module): # Construct actual (very simple) FC model. assert len(obs_space.shape) == 1 + in_size = self.num_frames * (obs_space.shape[0] + action_space.n + 1) self.layer1 = SlimFC( - in_size=obs_space.shape[0] * self.num_frames, - out_size=64, - activation_fn="relu") + in_size=in_size, out_size=256, activation_fn="relu") + self.layer2 = SlimFC(in_size=256, out_size=256, activation_fn="relu") self.out = SlimFC( - in_size=64, out_size=self.num_outputs, activation_fn="linear") - self.values = SlimFC(in_size=64, out_size=1, activation_fn="linear") + in_size=256, out_size=self.num_outputs, activation_fn="linear") + self.values = SlimFC(in_size=256, out_size=1, activation_fn="linear") self._last_value = None @@ -91,14 +108,26 @@ class TorchFrameStackingCartPoleModel(TorchModelV2, nn.Module): data_col="obs", shift="-{}:0".format(num_frames - 1), space=obs_space) - self.view_requirements["prev_rewards"] = ViewRequirement( - data_col="rewards", shift=-1) + self.view_requirements["prev_n_rewards"] = ViewRequirement( + data_col="rewards", shift="-{}:-1".format(self.num_frames)) + self.view_requirements["prev_n_actions"] = ViewRequirement( + data_col="actions", + shift="-{}:-1".format(self.num_frames), + space=self.action_space) def forward(self, input_dict, states, seq_lens): obs = input_dict["prev_n_obs"] obs = torch.reshape(obs, [-1, self.obs_space.shape[0] * self.num_frames]) - features = self.layer1(obs) + rewards = torch.reshape(input_dict["prev_n_rewards"], + [-1, self.num_frames]) + actions = torch_one_hot(input_dict["prev_n_actions"], + self.action_space) + actions = torch.reshape(actions, + [-1, self.num_frames * actions.shape[-1]]) + input_ = torch.cat([obs, actions, rewards], dim=-1) + features = self.layer1(input_) + features = self.layer2(features) out = self.out(features) self._last_value = self.values(features) return out, [] diff --git a/rllib/examples/trajectory_view_api.py b/rllib/examples/trajectory_view_api.py index 400051ad5..a72061779 100644 --- a/rllib/examples/trajectory_view_api.py +++ b/rllib/examples/trajectory_view_api.py @@ -2,6 +2,7 @@ import argparse import ray from ray import tune +from ray.rllib.examples.env.stateless_cartpole import StatelessCartPole from ray.rllib.examples.models.trajectory_view_utilizing_models import \ FrameStackingCartPoleModel, TorchFrameStackingCartPoleModel from ray.rllib.models.catalog import ModelCatalog @@ -16,7 +17,7 @@ parser.add_argument( "--framework", choices=["tf2", "tf", "tfe", "torch"], default="tf") parser.add_argument("--as-test", action="store_true") parser.add_argument("--stop-iters", type=int, default=50) -parser.add_argument("--stop-timesteps", type=int, default=100000) +parser.add_argument("--stop-timesteps", type=int, default=200000) parser.add_argument("--stop-reward", type=float, default=150.0) if __name__ == "__main__": @@ -26,13 +27,14 @@ if __name__ == "__main__": ModelCatalog.register_custom_model( "frame_stack_model", FrameStackingCartPoleModel if args.framework != "torch" else TorchFrameStackingCartPoleModel) + tune.register_env("stateless_cartpole", lambda c: StatelessCartPole()) config = { - "env": "CartPole-v0", + "env": "stateless_cartpole", "model": { "custom_model": "frame_stack_model", "custom_model_config": { - "num_frames": 4, + "num_frames": 16, } }, "framework": args.framework, diff --git a/rllib/models/torch/misc.py b/rllib/models/torch/misc.py index 830e8bc33..9f6d8234e 100644 --- a/rllib/models/torch/misc.py +++ b/rllib/models/torch/misc.py @@ -139,8 +139,9 @@ class SlimFC(nn.Module): layers = [] # Actual nn.Linear layer (including correct initialization logic). linear = nn.Linear(in_size, out_size, bias=use_bias) - if initializer: - initializer(linear.weight) + if initializer is None: + initializer = nn.init.xavier_uniform_ + initializer(linear.weight) if use_bias is True: nn.init.constant_(linear.bias, bias_init) layers.append(linear) diff --git a/rllib/policy/eager_tf_policy.py b/rllib/policy/eager_tf_policy.py index 805cacaaa..1e1f42c05 100644 --- a/rllib/policy/eager_tf_policy.py +++ b/rllib/policy/eager_tf_policy.py @@ -5,6 +5,7 @@ It supports both traced and non-traced eager execution modes.""" import functools import logging import threading +from typing import Dict, List, Optional, Tuple from ray.util.debug import log_once from ray.rllib.models.catalog import ModelCatalog @@ -18,6 +19,7 @@ from ray.rllib.utils.framework import try_import_tf from ray.rllib.utils.tf_ops import convert_to_non_tf_type from ray.rllib.utils.threading import with_lock from ray.rllib.utils.tracking_dict import UsageTrackingDict +from ray.rllib.utils.typing import TensorType tf1, tf, tfv = try_import_tf() logger = logging.getLogger(__name__) @@ -361,10 +363,7 @@ def build_eager_tf_policy(name, grads = [g for g, v in grads_and_vars] return grads, stats - @with_lock @override(Policy) - @convert_eager_inputs - @convert_eager_outputs def compute_actions(self, obs_batch, state_batches=None, @@ -376,16 +375,9 @@ def build_eager_tf_policy(name, timestep=None, **kwargs): - explore = explore if explore is not None else \ - self.config["explore"] - timestep = timestep if timestep is not None else \ - self.global_timestep - - # TODO: remove python side effect to cull sources of bugs. self._is_training = False self._is_recurrent = \ state_batches is not None and state_batches != [] - self._state_in = state_batches or [] if not tf1.executing_eagerly(): tf1.enable_eager_execution() @@ -394,8 +386,6 @@ def build_eager_tf_policy(name, SampleBatch.CUR_OBS: tf.convert_to_tensor(obs_batch), "is_training": tf.constant(False), } - batch_size = input_dict[SampleBatch.CUR_OBS].shape[0] - seq_lens = tf.ones(batch_size, dtype=tf.int32) if obs_include_prev_action_reward: if prev_action_batch is not None: input_dict[SampleBatch.PREV_ACTIONS] = \ @@ -404,6 +394,50 @@ def build_eager_tf_policy(name, input_dict[SampleBatch.PREV_REWARDS] = \ tf.convert_to_tensor(prev_reward_batch) + return self._compute_action_helper(input_dict, state_batches, + episodes, explore, timestep) + + @override(Policy) + def compute_actions_from_input_dict( + self, + input_dict: Dict[str, TensorType], + explore: bool = None, + timestep: Optional[int] = None, + **kwargs + ) -> Tuple[TensorType, List[TensorType], Dict[str, TensorType]]: + + if not tf1.executing_eagerly(): + tf1.enable_eager_execution() + + # Pass lazy (torch) tensor dict to Model as `input_dict`. + input_dict = self._lazy_tensor_dict(input_dict) + # Pack internal state inputs into (separate) list. + state_batches = [ + input_dict[k] for k in input_dict.keys() if "state_in" in k[:8] + ] + + return self._compute_action_helper(input_dict, state_batches, None, + explore, timestep) + + @with_lock + @convert_eager_inputs + @convert_eager_outputs + def _compute_action_helper(self, input_dict, state_batches, episodes, + explore, timestep): + + explore = explore if explore is not None else \ + self.config["explore"] + timestep = timestep if timestep is not None else \ + self.global_timestep + if isinstance(timestep, tf.Tensor): + timestep = int(timestep.numpy()) + self._is_training = False + self._state_in = state_batches or [] + # Calculate RNN sequence lengths. + batch_size = input_dict[SampleBatch.CUR_OBS].shape[0] + seq_lens = tf.ones(batch_size, dtype=tf.int32) if state_batches \ + else None + # Use Exploration object. with tf.variable_creator_scope(_disallow_var_creation): if action_sampler_fn: @@ -496,8 +530,6 @@ def build_eager_tf_policy(name, input_dict[SampleBatch.CUR_OBS], explore=False, is_training=False) - action_dist = dist_class(dist_inputs, self.model) - log_likelihoods = action_dist.logp(actions) # Default log-likelihood calculation. else: dist_inputs, _ = self.model(input_dict, state_batches, diff --git a/rllib/policy/torch_policy.py b/rllib/policy/torch_policy.py index 19d576d37..e492a5048 100644 --- a/rllib/policy/torch_policy.py +++ b/rllib/policy/torch_policy.py @@ -159,9 +159,6 @@ class TorchPolicy(Policy): **kwargs) -> \ Tuple[TensorType, List[TensorType], Dict[str, TensorType]]: - explore = explore if explore is not None else self.config["explore"] - timestep = timestep if timestep is not None else self.global_timestep - with torch.no_grad(): seq_lens = torch.ones(len(obs_batch), dtype=torch.int32) input_dict = self._lazy_tensor_dict({ @@ -190,9 +187,6 @@ class TorchPolicy(Policy): **kwargs) -> \ Tuple[TensorType, List[TensorType], Dict[str, TensorType]]: - explore = explore if explore is not None else self.config["explore"] - timestep = timestep if timestep is not None else self.global_timestep - with torch.no_grad(): # Pass lazy (torch) tensor dict to Model as `input_dict`. input_dict = self._lazy_tensor_dict(input_dict) @@ -216,6 +210,8 @@ class TorchPolicy(Policy): Tuple: - actions, state_out, extra_fetches, logp. """ + explore = explore if explore is not None else self.config["explore"] + timestep = timestep if timestep is not None else self.global_timestep self._is_recurrent = state_batches is not None and state_batches != [] # Switch to eval mode.