[RLlib] Trajectory view API example script (enhancements and tf2 support). (#13786)

This commit is contained in:
Sven Mika
2021-02-02 18:42:18 +01:00
committed by GitHub
parent a6138ca31f
commit 0a0d9183fe
6 changed files with 106 additions and 46 deletions
+2 -2
View File
@@ -2114,7 +2114,7 @@ py_test(
tags = ["examples", "examples_T"],
size = "medium",
srcs = ["examples/trajectory_view_api.py"],
args = ["--as-test", "--framework=tf", "--stop-reward=80.0"]
args = ["--as-test", "--framework=tf", "--stop-reward=100.0"]
)
py_test(
@@ -2123,7 +2123,7 @@ py_test(
tags = ["examples", "examples_T"],
size = "medium",
srcs = ["examples/trajectory_view_api.py"],
args = ["--as-test", "--framework=torch", "--stop-reward=80.0"]
args = ["--as-test", "--framework=torch", "--stop-reward=100.0"]
)
py_test(
@@ -3,6 +3,8 @@ from ray.rllib.models.torch.misc import SlimFC
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.policy.view_requirement import ViewRequirement
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.tf_ops import one_hot
from ray.rllib.utils.torch_ops import one_hot as torch_one_hot
tf1, tf, tfv = try_import_tf()
torch, nn = try_import_torch()
@@ -28,27 +30,42 @@ class FrameStackingCartPoleModel(TFModelV2):
# Construct actual (very simple) FC model.
assert len(obs_space.shape) == 1
input_ = tf.keras.layers.Input(
obs = tf.keras.layers.Input(
shape=(self.num_frames, obs_space.shape[0]))
reshaped = tf.keras.layers.Reshape(
[obs_space.shape[0] * self.num_frames])(input_)
layer1 = tf.keras.layers.Dense(64, activation=tf.nn.relu)(reshaped)
out = tf.keras.layers.Dense(self.num_outputs)(layer1)
obs_reshaped = tf.keras.layers.Reshape(
[obs_space.shape[0] * self.num_frames])(obs)
rewards = tf.keras.layers.Input(shape=(self.num_frames))
rewards_reshaped = tf.keras.layers.Reshape([self.num_frames])(rewards)
actions = tf.keras.layers.Input(
shape=(self.num_frames, self.action_space.n))
actions_reshaped = tf.keras.layers.Reshape(
[action_space.n * self.num_frames])(actions)
input_ = tf.keras.layers.Concatenate(axis=-1)(
[obs_reshaped, actions_reshaped, rewards_reshaped])
layer1 = tf.keras.layers.Dense(256, activation=tf.nn.relu)(input_)
layer2 = tf.keras.layers.Dense(256, activation=tf.nn.relu)(layer1)
out = tf.keras.layers.Dense(self.num_outputs)(layer2)
values = tf.keras.layers.Dense(1)(layer1)
self.base_model = tf.keras.models.Model([input_], [out, values])
self.base_model = tf.keras.models.Model([obs, actions, rewards],
[out, values])
self._last_value = None
self.view_requirements["prev_n_obs"] = ViewRequirement(
data_col="obs",
shift="-{}:0".format(num_frames - 1),
space=obs_space)
self.view_requirements["prev_rewards"] = ViewRequirement(
data_col="rewards", shift=-1)
self.view_requirements["prev_n_rewards"] = ViewRequirement(
data_col="rewards", shift="-{}:-1".format(self.num_frames))
self.view_requirements["prev_n_actions"] = ViewRequirement(
data_col="actions",
shift="-{}:-1".format(self.num_frames),
space=self.action_space)
def forward(self, input_dict, states, seq_lens):
obs = input_dict["prev_n_obs"]
out, self._last_value = self.base_model(obs)
obs = tf.cast(input_dict["prev_n_obs"], tf.float32)
rewards = tf.cast(input_dict["prev_n_rewards"], tf.float32)
actions = one_hot(input_dict["prev_n_actions"], self.action_space)
out, self._last_value = self.base_model([obs, actions, rewards])
return out, []
def value_function(self):
@@ -77,13 +94,13 @@ class TorchFrameStackingCartPoleModel(TorchModelV2, nn.Module):
# Construct actual (very simple) FC model.
assert len(obs_space.shape) == 1
in_size = self.num_frames * (obs_space.shape[0] + action_space.n + 1)
self.layer1 = SlimFC(
in_size=obs_space.shape[0] * self.num_frames,
out_size=64,
activation_fn="relu")
in_size=in_size, out_size=256, activation_fn="relu")
self.layer2 = SlimFC(in_size=256, out_size=256, activation_fn="relu")
self.out = SlimFC(
in_size=64, out_size=self.num_outputs, activation_fn="linear")
self.values = SlimFC(in_size=64, out_size=1, activation_fn="linear")
in_size=256, out_size=self.num_outputs, activation_fn="linear")
self.values = SlimFC(in_size=256, out_size=1, activation_fn="linear")
self._last_value = None
@@ -91,14 +108,26 @@ class TorchFrameStackingCartPoleModel(TorchModelV2, nn.Module):
data_col="obs",
shift="-{}:0".format(num_frames - 1),
space=obs_space)
self.view_requirements["prev_rewards"] = ViewRequirement(
data_col="rewards", shift=-1)
self.view_requirements["prev_n_rewards"] = ViewRequirement(
data_col="rewards", shift="-{}:-1".format(self.num_frames))
self.view_requirements["prev_n_actions"] = ViewRequirement(
data_col="actions",
shift="-{}:-1".format(self.num_frames),
space=self.action_space)
def forward(self, input_dict, states, seq_lens):
obs = input_dict["prev_n_obs"]
obs = torch.reshape(obs,
[-1, self.obs_space.shape[0] * self.num_frames])
features = self.layer1(obs)
rewards = torch.reshape(input_dict["prev_n_rewards"],
[-1, self.num_frames])
actions = torch_one_hot(input_dict["prev_n_actions"],
self.action_space)
actions = torch.reshape(actions,
[-1, self.num_frames * actions.shape[-1]])
input_ = torch.cat([obs, actions, rewards], dim=-1)
features = self.layer1(input_)
features = self.layer2(features)
out = self.out(features)
self._last_value = self.values(features)
return out, []
+5 -3
View File
@@ -2,6 +2,7 @@ import argparse
import ray
from ray import tune
from ray.rllib.examples.env.stateless_cartpole import StatelessCartPole
from ray.rllib.examples.models.trajectory_view_utilizing_models import \
FrameStackingCartPoleModel, TorchFrameStackingCartPoleModel
from ray.rllib.models.catalog import ModelCatalog
@@ -16,7 +17,7 @@ parser.add_argument(
"--framework", choices=["tf2", "tf", "tfe", "torch"], default="tf")
parser.add_argument("--as-test", action="store_true")
parser.add_argument("--stop-iters", type=int, default=50)
parser.add_argument("--stop-timesteps", type=int, default=100000)
parser.add_argument("--stop-timesteps", type=int, default=200000)
parser.add_argument("--stop-reward", type=float, default=150.0)
if __name__ == "__main__":
@@ -26,13 +27,14 @@ if __name__ == "__main__":
ModelCatalog.register_custom_model(
"frame_stack_model", FrameStackingCartPoleModel
if args.framework != "torch" else TorchFrameStackingCartPoleModel)
tune.register_env("stateless_cartpole", lambda c: StatelessCartPole())
config = {
"env": "CartPole-v0",
"env": "stateless_cartpole",
"model": {
"custom_model": "frame_stack_model",
"custom_model_config": {
"num_frames": 4,
"num_frames": 16,
}
},
"framework": args.framework,
+3 -2
View File
@@ -139,8 +139,9 @@ class SlimFC(nn.Module):
layers = []
# Actual nn.Linear layer (including correct initialization logic).
linear = nn.Linear(in_size, out_size, bias=use_bias)
if initializer:
initializer(linear.weight)
if initializer is None:
initializer = nn.init.xavier_uniform_
initializer(linear.weight)
if use_bias is True:
nn.init.constant_(linear.bias, bias_init)
layers.append(linear)
+46 -14
View File
@@ -5,6 +5,7 @@ It supports both traced and non-traced eager execution modes."""
import functools
import logging
import threading
from typing import Dict, List, Optional, Tuple
from ray.util.debug import log_once
from ray.rllib.models.catalog import ModelCatalog
@@ -18,6 +19,7 @@ from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.tf_ops import convert_to_non_tf_type
from ray.rllib.utils.threading import with_lock
from ray.rllib.utils.tracking_dict import UsageTrackingDict
from ray.rllib.utils.typing import TensorType
tf1, tf, tfv = try_import_tf()
logger = logging.getLogger(__name__)
@@ -361,10 +363,7 @@ def build_eager_tf_policy(name,
grads = [g for g, v in grads_and_vars]
return grads, stats
@with_lock
@override(Policy)
@convert_eager_inputs
@convert_eager_outputs
def compute_actions(self,
obs_batch,
state_batches=None,
@@ -376,16 +375,9 @@ def build_eager_tf_policy(name,
timestep=None,
**kwargs):
explore = explore if explore is not None else \
self.config["explore"]
timestep = timestep if timestep is not None else \
self.global_timestep
# TODO: remove python side effect to cull sources of bugs.
self._is_training = False
self._is_recurrent = \
state_batches is not None and state_batches != []
self._state_in = state_batches or []
if not tf1.executing_eagerly():
tf1.enable_eager_execution()
@@ -394,8 +386,6 @@ def build_eager_tf_policy(name,
SampleBatch.CUR_OBS: tf.convert_to_tensor(obs_batch),
"is_training": tf.constant(False),
}
batch_size = input_dict[SampleBatch.CUR_OBS].shape[0]
seq_lens = tf.ones(batch_size, dtype=tf.int32)
if obs_include_prev_action_reward:
if prev_action_batch is not None:
input_dict[SampleBatch.PREV_ACTIONS] = \
@@ -404,6 +394,50 @@ def build_eager_tf_policy(name,
input_dict[SampleBatch.PREV_REWARDS] = \
tf.convert_to_tensor(prev_reward_batch)
return self._compute_action_helper(input_dict, state_batches,
episodes, explore, timestep)
@override(Policy)
def compute_actions_from_input_dict(
self,
input_dict: Dict[str, TensorType],
explore: bool = None,
timestep: Optional[int] = None,
**kwargs
) -> Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
if not tf1.executing_eagerly():
tf1.enable_eager_execution()
# Pass lazy (torch) tensor dict to Model as `input_dict`.
input_dict = self._lazy_tensor_dict(input_dict)
# Pack internal state inputs into (separate) list.
state_batches = [
input_dict[k] for k in input_dict.keys() if "state_in" in k[:8]
]
return self._compute_action_helper(input_dict, state_batches, None,
explore, timestep)
@with_lock
@convert_eager_inputs
@convert_eager_outputs
def _compute_action_helper(self, input_dict, state_batches, episodes,
explore, timestep):
explore = explore if explore is not None else \
self.config["explore"]
timestep = timestep if timestep is not None else \
self.global_timestep
if isinstance(timestep, tf.Tensor):
timestep = int(timestep.numpy())
self._is_training = False
self._state_in = state_batches or []
# Calculate RNN sequence lengths.
batch_size = input_dict[SampleBatch.CUR_OBS].shape[0]
seq_lens = tf.ones(batch_size, dtype=tf.int32) if state_batches \
else None
# Use Exploration object.
with tf.variable_creator_scope(_disallow_var_creation):
if action_sampler_fn:
@@ -496,8 +530,6 @@ def build_eager_tf_policy(name,
input_dict[SampleBatch.CUR_OBS],
explore=False,
is_training=False)
action_dist = dist_class(dist_inputs, self.model)
log_likelihoods = action_dist.logp(actions)
# Default log-likelihood calculation.
else:
dist_inputs, _ = self.model(input_dict, state_batches,
+2 -6
View File
@@ -159,9 +159,6 @@ class TorchPolicy(Policy):
**kwargs) -> \
Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
explore = explore if explore is not None else self.config["explore"]
timestep = timestep if timestep is not None else self.global_timestep
with torch.no_grad():
seq_lens = torch.ones(len(obs_batch), dtype=torch.int32)
input_dict = self._lazy_tensor_dict({
@@ -190,9 +187,6 @@ class TorchPolicy(Policy):
**kwargs) -> \
Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
explore = explore if explore is not None else self.config["explore"]
timestep = timestep if timestep is not None else self.global_timestep
with torch.no_grad():
# Pass lazy (torch) tensor dict to Model as `input_dict`.
input_dict = self._lazy_tensor_dict(input_dict)
@@ -216,6 +210,8 @@ class TorchPolicy(Policy):
Tuple:
- actions, state_out, extra_fetches, logp.
"""
explore = explore if explore is not None else self.config["explore"]
timestep = timestep if timestep is not None else self.global_timestep
self._is_recurrent = state_batches is not None and state_batches != []
# Switch to eval mode.