mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 19:16:19 +08:00
[RLlib] Trajectory view API example script (enhancements and tf2 support). (#13786)
This commit is contained in:
+2
-2
@@ -2114,7 +2114,7 @@ py_test(
|
||||
tags = ["examples", "examples_T"],
|
||||
size = "medium",
|
||||
srcs = ["examples/trajectory_view_api.py"],
|
||||
args = ["--as-test", "--framework=tf", "--stop-reward=80.0"]
|
||||
args = ["--as-test", "--framework=tf", "--stop-reward=100.0"]
|
||||
)
|
||||
|
||||
py_test(
|
||||
@@ -2123,7 +2123,7 @@ py_test(
|
||||
tags = ["examples", "examples_T"],
|
||||
size = "medium",
|
||||
srcs = ["examples/trajectory_view_api.py"],
|
||||
args = ["--as-test", "--framework=torch", "--stop-reward=80.0"]
|
||||
args = ["--as-test", "--framework=torch", "--stop-reward=100.0"]
|
||||
)
|
||||
|
||||
py_test(
|
||||
|
||||
@@ -3,6 +3,8 @@ from ray.rllib.models.torch.misc import SlimFC
|
||||
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
|
||||
from ray.rllib.policy.view_requirement import ViewRequirement
|
||||
from ray.rllib.utils.framework import try_import_tf, try_import_torch
|
||||
from ray.rllib.utils.tf_ops import one_hot
|
||||
from ray.rllib.utils.torch_ops import one_hot as torch_one_hot
|
||||
|
||||
tf1, tf, tfv = try_import_tf()
|
||||
torch, nn = try_import_torch()
|
||||
@@ -28,27 +30,42 @@ class FrameStackingCartPoleModel(TFModelV2):
|
||||
|
||||
# Construct actual (very simple) FC model.
|
||||
assert len(obs_space.shape) == 1
|
||||
input_ = tf.keras.layers.Input(
|
||||
obs = tf.keras.layers.Input(
|
||||
shape=(self.num_frames, obs_space.shape[0]))
|
||||
reshaped = tf.keras.layers.Reshape(
|
||||
[obs_space.shape[0] * self.num_frames])(input_)
|
||||
layer1 = tf.keras.layers.Dense(64, activation=tf.nn.relu)(reshaped)
|
||||
out = tf.keras.layers.Dense(self.num_outputs)(layer1)
|
||||
obs_reshaped = tf.keras.layers.Reshape(
|
||||
[obs_space.shape[0] * self.num_frames])(obs)
|
||||
rewards = tf.keras.layers.Input(shape=(self.num_frames))
|
||||
rewards_reshaped = tf.keras.layers.Reshape([self.num_frames])(rewards)
|
||||
actions = tf.keras.layers.Input(
|
||||
shape=(self.num_frames, self.action_space.n))
|
||||
actions_reshaped = tf.keras.layers.Reshape(
|
||||
[action_space.n * self.num_frames])(actions)
|
||||
input_ = tf.keras.layers.Concatenate(axis=-1)(
|
||||
[obs_reshaped, actions_reshaped, rewards_reshaped])
|
||||
layer1 = tf.keras.layers.Dense(256, activation=tf.nn.relu)(input_)
|
||||
layer2 = tf.keras.layers.Dense(256, activation=tf.nn.relu)(layer1)
|
||||
out = tf.keras.layers.Dense(self.num_outputs)(layer2)
|
||||
values = tf.keras.layers.Dense(1)(layer1)
|
||||
self.base_model = tf.keras.models.Model([input_], [out, values])
|
||||
|
||||
self.base_model = tf.keras.models.Model([obs, actions, rewards],
|
||||
[out, values])
|
||||
self._last_value = None
|
||||
|
||||
self.view_requirements["prev_n_obs"] = ViewRequirement(
|
||||
data_col="obs",
|
||||
shift="-{}:0".format(num_frames - 1),
|
||||
space=obs_space)
|
||||
self.view_requirements["prev_rewards"] = ViewRequirement(
|
||||
data_col="rewards", shift=-1)
|
||||
self.view_requirements["prev_n_rewards"] = ViewRequirement(
|
||||
data_col="rewards", shift="-{}:-1".format(self.num_frames))
|
||||
self.view_requirements["prev_n_actions"] = ViewRequirement(
|
||||
data_col="actions",
|
||||
shift="-{}:-1".format(self.num_frames),
|
||||
space=self.action_space)
|
||||
|
||||
def forward(self, input_dict, states, seq_lens):
|
||||
obs = input_dict["prev_n_obs"]
|
||||
out, self._last_value = self.base_model(obs)
|
||||
obs = tf.cast(input_dict["prev_n_obs"], tf.float32)
|
||||
rewards = tf.cast(input_dict["prev_n_rewards"], tf.float32)
|
||||
actions = one_hot(input_dict["prev_n_actions"], self.action_space)
|
||||
out, self._last_value = self.base_model([obs, actions, rewards])
|
||||
return out, []
|
||||
|
||||
def value_function(self):
|
||||
@@ -77,13 +94,13 @@ class TorchFrameStackingCartPoleModel(TorchModelV2, nn.Module):
|
||||
|
||||
# Construct actual (very simple) FC model.
|
||||
assert len(obs_space.shape) == 1
|
||||
in_size = self.num_frames * (obs_space.shape[0] + action_space.n + 1)
|
||||
self.layer1 = SlimFC(
|
||||
in_size=obs_space.shape[0] * self.num_frames,
|
||||
out_size=64,
|
||||
activation_fn="relu")
|
||||
in_size=in_size, out_size=256, activation_fn="relu")
|
||||
self.layer2 = SlimFC(in_size=256, out_size=256, activation_fn="relu")
|
||||
self.out = SlimFC(
|
||||
in_size=64, out_size=self.num_outputs, activation_fn="linear")
|
||||
self.values = SlimFC(in_size=64, out_size=1, activation_fn="linear")
|
||||
in_size=256, out_size=self.num_outputs, activation_fn="linear")
|
||||
self.values = SlimFC(in_size=256, out_size=1, activation_fn="linear")
|
||||
|
||||
self._last_value = None
|
||||
|
||||
@@ -91,14 +108,26 @@ class TorchFrameStackingCartPoleModel(TorchModelV2, nn.Module):
|
||||
data_col="obs",
|
||||
shift="-{}:0".format(num_frames - 1),
|
||||
space=obs_space)
|
||||
self.view_requirements["prev_rewards"] = ViewRequirement(
|
||||
data_col="rewards", shift=-1)
|
||||
self.view_requirements["prev_n_rewards"] = ViewRequirement(
|
||||
data_col="rewards", shift="-{}:-1".format(self.num_frames))
|
||||
self.view_requirements["prev_n_actions"] = ViewRequirement(
|
||||
data_col="actions",
|
||||
shift="-{}:-1".format(self.num_frames),
|
||||
space=self.action_space)
|
||||
|
||||
def forward(self, input_dict, states, seq_lens):
|
||||
obs = input_dict["prev_n_obs"]
|
||||
obs = torch.reshape(obs,
|
||||
[-1, self.obs_space.shape[0] * self.num_frames])
|
||||
features = self.layer1(obs)
|
||||
rewards = torch.reshape(input_dict["prev_n_rewards"],
|
||||
[-1, self.num_frames])
|
||||
actions = torch_one_hot(input_dict["prev_n_actions"],
|
||||
self.action_space)
|
||||
actions = torch.reshape(actions,
|
||||
[-1, self.num_frames * actions.shape[-1]])
|
||||
input_ = torch.cat([obs, actions, rewards], dim=-1)
|
||||
features = self.layer1(input_)
|
||||
features = self.layer2(features)
|
||||
out = self.out(features)
|
||||
self._last_value = self.values(features)
|
||||
return out, []
|
||||
|
||||
@@ -2,6 +2,7 @@ import argparse
|
||||
|
||||
import ray
|
||||
from ray import tune
|
||||
from ray.rllib.examples.env.stateless_cartpole import StatelessCartPole
|
||||
from ray.rllib.examples.models.trajectory_view_utilizing_models import \
|
||||
FrameStackingCartPoleModel, TorchFrameStackingCartPoleModel
|
||||
from ray.rllib.models.catalog import ModelCatalog
|
||||
@@ -16,7 +17,7 @@ parser.add_argument(
|
||||
"--framework", choices=["tf2", "tf", "tfe", "torch"], default="tf")
|
||||
parser.add_argument("--as-test", action="store_true")
|
||||
parser.add_argument("--stop-iters", type=int, default=50)
|
||||
parser.add_argument("--stop-timesteps", type=int, default=100000)
|
||||
parser.add_argument("--stop-timesteps", type=int, default=200000)
|
||||
parser.add_argument("--stop-reward", type=float, default=150.0)
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -26,13 +27,14 @@ if __name__ == "__main__":
|
||||
ModelCatalog.register_custom_model(
|
||||
"frame_stack_model", FrameStackingCartPoleModel
|
||||
if args.framework != "torch" else TorchFrameStackingCartPoleModel)
|
||||
tune.register_env("stateless_cartpole", lambda c: StatelessCartPole())
|
||||
|
||||
config = {
|
||||
"env": "CartPole-v0",
|
||||
"env": "stateless_cartpole",
|
||||
"model": {
|
||||
"custom_model": "frame_stack_model",
|
||||
"custom_model_config": {
|
||||
"num_frames": 4,
|
||||
"num_frames": 16,
|
||||
}
|
||||
},
|
||||
"framework": args.framework,
|
||||
|
||||
@@ -139,8 +139,9 @@ class SlimFC(nn.Module):
|
||||
layers = []
|
||||
# Actual nn.Linear layer (including correct initialization logic).
|
||||
linear = nn.Linear(in_size, out_size, bias=use_bias)
|
||||
if initializer:
|
||||
initializer(linear.weight)
|
||||
if initializer is None:
|
||||
initializer = nn.init.xavier_uniform_
|
||||
initializer(linear.weight)
|
||||
if use_bias is True:
|
||||
nn.init.constant_(linear.bias, bias_init)
|
||||
layers.append(linear)
|
||||
|
||||
@@ -5,6 +5,7 @@ It supports both traced and non-traced eager execution modes."""
|
||||
import functools
|
||||
import logging
|
||||
import threading
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from ray.util.debug import log_once
|
||||
from ray.rllib.models.catalog import ModelCatalog
|
||||
@@ -18,6 +19,7 @@ from ray.rllib.utils.framework import try_import_tf
|
||||
from ray.rllib.utils.tf_ops import convert_to_non_tf_type
|
||||
from ray.rllib.utils.threading import with_lock
|
||||
from ray.rllib.utils.tracking_dict import UsageTrackingDict
|
||||
from ray.rllib.utils.typing import TensorType
|
||||
|
||||
tf1, tf, tfv = try_import_tf()
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -361,10 +363,7 @@ def build_eager_tf_policy(name,
|
||||
grads = [g for g, v in grads_and_vars]
|
||||
return grads, stats
|
||||
|
||||
@with_lock
|
||||
@override(Policy)
|
||||
@convert_eager_inputs
|
||||
@convert_eager_outputs
|
||||
def compute_actions(self,
|
||||
obs_batch,
|
||||
state_batches=None,
|
||||
@@ -376,16 +375,9 @@ def build_eager_tf_policy(name,
|
||||
timestep=None,
|
||||
**kwargs):
|
||||
|
||||
explore = explore if explore is not None else \
|
||||
self.config["explore"]
|
||||
timestep = timestep if timestep is not None else \
|
||||
self.global_timestep
|
||||
|
||||
# TODO: remove python side effect to cull sources of bugs.
|
||||
self._is_training = False
|
||||
self._is_recurrent = \
|
||||
state_batches is not None and state_batches != []
|
||||
self._state_in = state_batches or []
|
||||
|
||||
if not tf1.executing_eagerly():
|
||||
tf1.enable_eager_execution()
|
||||
@@ -394,8 +386,6 @@ def build_eager_tf_policy(name,
|
||||
SampleBatch.CUR_OBS: tf.convert_to_tensor(obs_batch),
|
||||
"is_training": tf.constant(False),
|
||||
}
|
||||
batch_size = input_dict[SampleBatch.CUR_OBS].shape[0]
|
||||
seq_lens = tf.ones(batch_size, dtype=tf.int32)
|
||||
if obs_include_prev_action_reward:
|
||||
if prev_action_batch is not None:
|
||||
input_dict[SampleBatch.PREV_ACTIONS] = \
|
||||
@@ -404,6 +394,50 @@ def build_eager_tf_policy(name,
|
||||
input_dict[SampleBatch.PREV_REWARDS] = \
|
||||
tf.convert_to_tensor(prev_reward_batch)
|
||||
|
||||
return self._compute_action_helper(input_dict, state_batches,
|
||||
episodes, explore, timestep)
|
||||
|
||||
@override(Policy)
|
||||
def compute_actions_from_input_dict(
|
||||
self,
|
||||
input_dict: Dict[str, TensorType],
|
||||
explore: bool = None,
|
||||
timestep: Optional[int] = None,
|
||||
**kwargs
|
||||
) -> Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
|
||||
|
||||
if not tf1.executing_eagerly():
|
||||
tf1.enable_eager_execution()
|
||||
|
||||
# Pass lazy (torch) tensor dict to Model as `input_dict`.
|
||||
input_dict = self._lazy_tensor_dict(input_dict)
|
||||
# Pack internal state inputs into (separate) list.
|
||||
state_batches = [
|
||||
input_dict[k] for k in input_dict.keys() if "state_in" in k[:8]
|
||||
]
|
||||
|
||||
return self._compute_action_helper(input_dict, state_batches, None,
|
||||
explore, timestep)
|
||||
|
||||
@with_lock
|
||||
@convert_eager_inputs
|
||||
@convert_eager_outputs
|
||||
def _compute_action_helper(self, input_dict, state_batches, episodes,
|
||||
explore, timestep):
|
||||
|
||||
explore = explore if explore is not None else \
|
||||
self.config["explore"]
|
||||
timestep = timestep if timestep is not None else \
|
||||
self.global_timestep
|
||||
if isinstance(timestep, tf.Tensor):
|
||||
timestep = int(timestep.numpy())
|
||||
self._is_training = False
|
||||
self._state_in = state_batches or []
|
||||
# Calculate RNN sequence lengths.
|
||||
batch_size = input_dict[SampleBatch.CUR_OBS].shape[0]
|
||||
seq_lens = tf.ones(batch_size, dtype=tf.int32) if state_batches \
|
||||
else None
|
||||
|
||||
# Use Exploration object.
|
||||
with tf.variable_creator_scope(_disallow_var_creation):
|
||||
if action_sampler_fn:
|
||||
@@ -496,8 +530,6 @@ def build_eager_tf_policy(name,
|
||||
input_dict[SampleBatch.CUR_OBS],
|
||||
explore=False,
|
||||
is_training=False)
|
||||
action_dist = dist_class(dist_inputs, self.model)
|
||||
log_likelihoods = action_dist.logp(actions)
|
||||
# Default log-likelihood calculation.
|
||||
else:
|
||||
dist_inputs, _ = self.model(input_dict, state_batches,
|
||||
|
||||
@@ -159,9 +159,6 @@ class TorchPolicy(Policy):
|
||||
**kwargs) -> \
|
||||
Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
|
||||
|
||||
explore = explore if explore is not None else self.config["explore"]
|
||||
timestep = timestep if timestep is not None else self.global_timestep
|
||||
|
||||
with torch.no_grad():
|
||||
seq_lens = torch.ones(len(obs_batch), dtype=torch.int32)
|
||||
input_dict = self._lazy_tensor_dict({
|
||||
@@ -190,9 +187,6 @@ class TorchPolicy(Policy):
|
||||
**kwargs) -> \
|
||||
Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
|
||||
|
||||
explore = explore if explore is not None else self.config["explore"]
|
||||
timestep = timestep if timestep is not None else self.global_timestep
|
||||
|
||||
with torch.no_grad():
|
||||
# Pass lazy (torch) tensor dict to Model as `input_dict`.
|
||||
input_dict = self._lazy_tensor_dict(input_dict)
|
||||
@@ -216,6 +210,8 @@ class TorchPolicy(Policy):
|
||||
Tuple:
|
||||
- actions, state_out, extra_fetches, logp.
|
||||
"""
|
||||
explore = explore if explore is not None else self.config["explore"]
|
||||
timestep = timestep if timestep is not None else self.global_timestep
|
||||
self._is_recurrent = state_batches is not None and state_batches != []
|
||||
|
||||
# Switch to eval mode.
|
||||
|
||||
Reference in New Issue
Block a user