[RLlib] Preparatory PR for: Documentation on Model Building. (#13260)

This commit is contained in:
Sven Mika
2021-01-08 10:56:09 +01:00
committed by GitHub
parent a247c71e2e
commit 6f342a2221
13 changed files with 896 additions and 37 deletions
+35
View File
@@ -1808,6 +1808,23 @@ py_test(
args = ["--stop-iters=2"]
)
# Runs examples/custom_model_api.py without any extra args, i.e. with the
# script's own default settings.
py_test(
name = "examples/custom_model_api_tf",
main = "examples/custom_model_api.py",
tags = ["examples", "examples_C"],
size = "small",
srcs = ["examples/custom_model_api.py"],
)
# Runs examples/custom_model_api.py with `--framework=torch`.
py_test(
name = "examples/custom_model_api_torch",
main = "examples/custom_model_api.py",
tags = ["examples", "examples_C"],
size = "small",
srcs = ["examples/custom_model_api.py"],
args = ["--framework=torch"],
)
py_test(
name = "examples/custom_observation_filters",
main = "examples/custom_observation_filters.py",
@@ -2047,6 +2064,24 @@ py_test(
args = ["--as-test", "--torch"],
)
# Runs examples/trajectory_view_api.py as a test (`--as-test`) with
# `--framework=tf` and a stop reward of 80.0.
py_test(
name = "examples/trajectory_view_api_tf",
main = "examples/trajectory_view_api.py",
tags = ["examples", "examples_T"],
size = "medium",
srcs = ["examples/trajectory_view_api.py"],
args = ["--as-test", "--framework=tf", "--stop-reward=80.0"]
)
# Runs examples/trajectory_view_api.py as a test (`--as-test`) with
# `--framework=torch` and a stop reward of 80.0.
py_test(
name = "examples/trajectory_view_api_torch",
main = "examples/trajectory_view_api.py",
tags = ["examples", "examples_T"],
size = "medium",
srcs = ["examples/trajectory_view_api.py"],
args = ["--as-test", "--framework=torch", "--stop-reward=80.0"]
)
py_test(
name = "examples/two_trainer_workflow_tf",
main = "examples/two_trainer_workflow.py",
@@ -166,6 +166,8 @@ class _AgentCollector:
# resulting data=[[-3, -2, -1], [7, 8, 9]]
# Range of 3 consecutive items repeats every 10 timesteps.
if view_req.shift_from is not None:
# Batch repeat value > 1: Only repeat the shift_from/to range
# every n timesteps.
if view_req.batch_repeat_value > 1:
count = int(
math.ceil((len(np_data[data_col]) - self.shift_before)
@@ -179,11 +181,20 @@ class _AgentCollector:
view_req.shift_to + 1 + obs_shift]
for i in range(count)
])
# Batch repeat value = 1: Repeat the shift_from/to range at
# each timestep.
else:
data = np_data[data_col][self.shift_before +
view_req.shift_from +
obs_shift:self.shift_before +
view_req.shift_to + 1 + obs_shift]
d = np_data[data_col]
# TODO: For now, assume simple 1D data (B x x).
# Will expand this for Atari examples.
assert len(d.shape) == 2
shift_win = view_req.shift_to - view_req.shift_from + 1
data_size = d.itemsize * int(np.product(d.shape[1:]))
data = np.lib.stride_tricks.as_strided(
d[self.shift_before - shift_win:],
[self.agent_steps, shift_win, d.shape[1]],
[data_size, data_size, d.itemsize])
# Set of (probably non-consecutive) indices.
# Example:
# shift=[-3, 0]
+103
View File
@@ -0,0 +1,103 @@
import argparse
from gym.spaces import Box, Discrete
import numpy as np
from ray.rllib.examples.models.custom_model_api import DuelingQModel, \
TorchDuelingQModel, ContActionQModel, TorchContActionQModel
from ray.rllib.models.catalog import ModelCatalog, MODEL_DEFAULTS
from ray.rllib.utils.framework import try_import_tf, try_import_torch
tf1, tf, tfv = try_import_tf()
torch, _ = try_import_torch()
parser = argparse.ArgumentParser()
parser.add_argument(
    "--framework", choices=["tf2", "tf", "tfe", "torch"], default="tf")

if __name__ == "__main__":
    # Script: builds two models through `ModelCatalog.get_model_v2`, each
    # wrapped with a custom model API class (`model_interface`), and checks
    # the output shapes of the wrapped model and of the extra API method.
    args = parser.parse_args()

    # Test API wrapper for dueling Q-head.
    obs_space = Box(-1.0, 1.0, (3, ))
    action_space = Discrete(3)

    # Run in eager mode for value checking and debugging.
    # Only done for the TF frameworks: a torch run does not need TF at all.
    if args.framework != "torch":
        tf1.enable_eager_execution()

    # __sphinx_doc_model_construct_begin__
    my_dueling_model = ModelCatalog.get_model_v2(
        obs_space=obs_space,
        action_space=action_space,
        num_outputs=action_space.n,
        model_config=MODEL_DEFAULTS,
        framework=args.framework,
        # Providing the `model_interface` arg will make the factory
        # wrap the chosen default model with our new model API class
        # (DuelingQModel). This way, both `forward` and `get_q_values`
        # are available in the returned class.
        model_interface=DuelingQModel
        if args.framework != "torch" else TorchDuelingQModel,
        name="dueling_q_model",
    )
    # __sphinx_doc_model_construct_end__

    batch_size = 10
    input_ = np.array([obs_space.sample() for _ in range(batch_size)])
    # Note that for PyTorch, you will have to provide torch tensors here.
    if args.framework == "torch":
        input_ = torch.from_numpy(input_)

    input_dict = {
        "obs": input_,
        "is_training": False,
    }
    out, state_outs = my_dueling_model(input_dict=input_dict)
    # The wrapped model was built without an output layer, so `out` is
    # expected to have the feature size 256 rather than `action_space.n`.
    assert out.shape == (batch_size, 256)
    # Pass `out` into `get_q_values`
    q_values = my_dueling_model.get_q_values(out)
    assert q_values.shape == (batch_size, action_space.n)

    # Test API wrapper for single value Q-head from obs/action input.
    obs_space = Box(-1.0, 1.0, (3, ))
    # Use a proper [-1.0, 1.0] range. With low == high == -1.0, every
    # sampled action would be the constant -1.0.
    action_space = Box(-1.0, 1.0, (2, ))

    # __sphinx_doc_model_construct_begin__
    my_cont_action_q_model = ModelCatalog.get_model_v2(
        obs_space=obs_space,
        action_space=action_space,
        num_outputs=2,
        model_config=MODEL_DEFAULTS,
        framework=args.framework,
        # Providing the `model_interface` arg will make the factory
        # wrap the chosen default model with our new model API class
        # (ContActionQModel). This way, both `forward` and
        # `get_single_q_value` are available in the returned class.
        model_interface=ContActionQModel
        if args.framework != "torch" else TorchContActionQModel,
        name="cont_action_q_model",
    )
    # __sphinx_doc_model_construct_end__

    batch_size = 10
    input_ = np.array([obs_space.sample() for _ in range(batch_size)])
    # Note that for PyTorch, you will have to provide torch tensors here.
    if args.framework == "torch":
        input_ = torch.from_numpy(input_)

    input_dict = {
        "obs": input_,
        "is_training": False,
    }
    out, state_outs = my_cont_action_q_model(input_dict=input_dict)
    assert out.shape == (batch_size, 256)
    # Pass `out` and an action into `my_cont_action_q_model`
    action = np.array([action_space.sample() for _ in range(batch_size)])
    if args.framework == "torch":
        action = torch.from_numpy(action)

    q_value = my_cont_action_q_model.get_single_q_value(out, action)
    assert q_value.shape == (batch_size, 1)
+63
View File
@@ -0,0 +1,63 @@
import numpy as np
import ray
import ray.rllib.agents.ppo as ppo
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.utils.framework import try_import_torch
torch, _ = try_import_torch()
# __sphinx_doc_begin__
# The custom model that will be wrapped by an LSTM.
class MyCustomModel(TorchModelV2):
    """Minimal custom torch model meant to be auto-wrapped by an LSTM.

    The model has no trainable layers of its own: `forward` simply returns
    the flattened observations times two, and `value_function` returns
    zeros. Its `num_outputs` is overwritten with the flattened observation
    size.
    """

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        """Initializes the model and sets `num_outputs` to the flat obs size.

        Args:
            obs_space: Observation space; its shape determines
                `self.num_outputs`.
            action_space: Action space (passed on to TorchModelV2).
            num_outputs: Passed on to TorchModelV2, then overwritten here.
            model_config: Model config dict (passed on to TorchModelV2).
            name: Model name (passed on to TorchModelV2).
        """
        super().__init__(obs_space, action_space, num_outputs, model_config,
                         name)
        # `np.prod` instead of the `np.product` alias, which was removed in
        # NumPy 2.0. Both compute the same product over the shape.
        self.num_outputs = int(np.prod(self.obs_space.shape))
        self._last_batch_size = None

    # Implement your own forward logic, whose output will then be sent
    # through an LSTM.
    def forward(self, input_dict, state, seq_lens):
        """Returns 2x the flattened observations and an empty state list."""
        obs = input_dict["obs_flat"]
        # Store last batch size for value_function output.
        self._last_batch_size = obs.shape[0]
        # Return 2x the obs (and empty states).
        # This will further be sent through an automatically provided
        # LSTM head (b/c we are setting use_lstm=True below).
        return obs * 2.0, []

    def value_function(self):
        """Returns a zero tensor sized by the most recent forward batch."""
        return torch.from_numpy(np.zeros(shape=(self._last_batch_size, )))
if __name__ == "__main__":
    # Script entry point: register the custom model, build a PPO trainer
    # that uses it (auto-wrapped by an LSTM) and run one training iteration.
    ray.init()

    # Make the custom model class available under a string key.
    ModelCatalog.register_custom_model("my_torch_model", MyCustomModel)

    model_settings = {
        # Auto-wrap the custom(!) model with an LSTM.
        "use_lstm": True,
        # To further customize the LSTM auto-wrapper.
        "lstm_cell_size": 64,
        # Specify our custom model from above.
        "custom_model": "my_torch_model",
        # Extra kwargs to be passed to your model's c'tor.
        "custom_model_config": {},
    }
    trainer_config = {
        "framework": "torch",
        "model": model_settings,
    }

    # Create the Trainer and train once.
    trainer = ppo.PPOTrainer(env="CartPole-v0", config=trainer_config)
    trainer.train()
# __sphinx_doc_end__
@@ -0,0 +1,220 @@
from gym.spaces import Discrete, Tuple
from ray.rllib.examples.models.impala_vision_nets import TorchImpalaVisionNet
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.tf.misc import normc_initializer
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.models.torch.misc import normc_initializer as \
torch_normc_initializer, SlimFC
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.models.utils import get_filter_config
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_tf, try_import_torch
tf1, tf, tfv = try_import_tf()
torch, nn = try_import_torch()
# __sphinx_doc_begin__
class CNNPlusFCConcatModel(TFModelV2):
    """TFModelV2 concat'ing CNN outputs to flat input(s), followed by FC(s).

    Note: This model should be used for complex observation spaces that have
    one or more image components. Only Tuple spaces are accepted for now
    (Dict support is still a TODO).
    """

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        """Builds one CNN per image component plus optional output heads.

        Args:
            obs_space: Observation space whose `original_space` must be a
                gym Tuple.
            action_space: Action space; handed on to the nested CNN models.
            num_outputs: Size of the logits output. If falsy, no logits/value
                heads are built and `self.num_outputs` is set to the size of
                the concatenated features instead.
            model_config: Model config dict; `conv_filters` and
                `conv_activation` are read from it for the CNNs.
            name: Model name (passed on to TFModelV2).
        """
        # TODO: (sven) Support Dicts as well.
        assert isinstance(obs_space.original_space, (Tuple)), \
            "`obs_space.original_space` must be Tuple!"

        super().__init__(obs_space, action_space, num_outputs, model_config,
                         name)

        # Build the CNN(s) given obs_space's image components.
        self.cnns = {}
        concat_size = 0
        for i, component in enumerate(obs_space.original_space):
            # Image space.
            if len(component.shape) == 3:
                # Only look up a default filter setup if the config has no
                # `conv_filters` key at all. Evaluating the default eagerly
                # (as the fallback arg of `dict.get`) would run the lookup
                # even when filters are given explicitly.
                if "conv_filters" in model_config:
                    conv_filters = model_config["conv_filters"]
                else:
                    conv_filters = get_filter_config(component.shape)
                config = {
                    "conv_filters": conv_filters,
                    "conv_activation": model_config.get("conv_activation"),
                }
                cnn = ModelCatalog.get_model_v2(
                    component,
                    action_space,
                    num_outputs=None,
                    model_config=config,
                    framework="tf",
                    name="cnn_{}".format(i))
                concat_size += cnn.num_outputs
                self.cnns[i] = cnn
                self.register_variables(cnn.variables())
            # Discrete inputs -> One-hot encode.
            elif isinstance(component, Discrete):
                concat_size += component.n
            # TODO: (sven) Multidiscrete (see e.g. our auto-LSTM wrappers).
            # Everything else (1D Box).
            else:
                assert len(component.shape) == 1, \
                    "Only input Box 1D or 3D spaces allowed!"
                concat_size += component.shape[-1]

        self.logits_and_value_model = None
        self._value_out = None
        if num_outputs:
            # Action-distribution head.
            concat_layer = tf.keras.layers.Input((concat_size, ))
            logits_layer = tf.keras.layers.Dense(
                num_outputs,
                activation=tf.keras.activations.linear,
                name="logits")(concat_layer)

            # Create the value branch model.
            value_layer = tf.keras.layers.Dense(
                1,
                name="value_out",
                activation=None,
                kernel_initializer=normc_initializer(0.01))(concat_layer)
            self.logits_and_value_model = tf.keras.models.Model(
                concat_layer, [logits_layer, value_layer])
            self.register_variables(self.logits_and_value_model.variables)
        else:
            # No heads: the model's output is the concatenated features.
            self.num_outputs = concat_size

    @override(ModelV2)
    def forward(self, input_dict, state, seq_lens):
        """Runs image components through their CNNs and concats everything.

        Returns:
            Tuple of (logits, []) if output heads were built, otherwise
            (concatenated features, []).
        """
        # Push image observations through our CNNs.
        outs = []
        for i, component in enumerate(input_dict["obs"]):
            if i in self.cnns:
                cnn_out, _ = self.cnns[i]({"obs": component})
                outs.append(cnn_out)
            else:
                outs.append(component)
        # Concat all outputs and the non-image inputs.
        out = tf.concat(outs, axis=1)
        if not self.logits_and_value_model:
            return out, []

        # Value branch.
        logits, values = self.logits_and_value_model(out)
        self._value_out = tf.reshape(values, [-1])
        return logits, []

    @override(ModelV2)
    def value_function(self):
        """Returns the value output stored by the most recent `forward`."""
        return self._value_out
# __sphinx_doc_end__
class TorchCNNPlusFCConcatModel(TorchModelV2, nn.Module):
    """TorchModelV2 concat'ing CNN outputs to flat input(s), followed by FC(s).

    Note: This model should be used for complex observation spaces that have
    one or more image components. Only Tuple spaces are accepted for now
    (Dict support is still a TODO).
    """

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        """Builds one CNN per image component plus optional output heads.

        Args:
            obs_space: Observation space whose `original_space` must be a
                gym Tuple.
            action_space: Action space; handed on to the nested CNN models.
            num_outputs: Size of the logits output. If falsy, no logits/value
                layers are built and `self.num_outputs` is set to the size of
                the concatenated features instead.
            model_config: Model config dict; `conv_filters`,
                `conv_activation` and `custom_model_config["conv_type"]`
                are read from it.
            name: Model name (passed on to TorchModelV2).
        """
        # TODO: (sven) Support Dicts as well.
        assert isinstance(obs_space.original_space, (Tuple)), \
            "`obs_space.original_space` must be Tuple!"

        nn.Module.__init__(self)
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                              model_config, name)

        # Atari type CNNs or IMPALA type CNNs (with residual layers)?
        self.cnn_type = self.model_config["custom_model_config"].get(
            "conv_type", "atari")

        # Build the CNN(s) given obs_space's image components.
        self.cnns = {}
        concat_size = 0
        for i, component in enumerate(obs_space.original_space):
            # Image space.
            if len(component.shape) == 3:
                # Only look up a default filter setup if the config has no
                # `conv_filters` key at all. Evaluating the default eagerly
                # (as the fallback arg of `dict.get`) would run the lookup
                # even when filters are given explicitly.
                if "conv_filters" in model_config:
                    conv_filters = model_config["conv_filters"]
                else:
                    conv_filters = get_filter_config(component.shape)
                config = {
                    "conv_filters": conv_filters,
                    "conv_activation": model_config.get("conv_activation"),
                }
                if self.cnn_type == "atari":
                    cnn = ModelCatalog.get_model_v2(
                        component,
                        action_space,
                        num_outputs=None,
                        model_config=config,
                        framework="torch",
                        name="cnn_{}".format(i))
                else:
                    cnn = TorchImpalaVisionNet(
                        component,
                        action_space,
                        num_outputs=None,
                        model_config=config,
                        name="cnn_{}".format(i))

                concat_size += cnn.num_outputs
                self.cnns[i] = cnn
                # `self.cnns` is a plain dict, so register each CNN as a
                # sub-module explicitly.
                self.add_module("cnn_{}".format(i), cnn)
            # Discrete inputs -> One-hot encode.
            elif isinstance(component, Discrete):
                concat_size += component.n
            # TODO: (sven) Multidiscrete (see e.g. our auto-LSTM wrappers).
            # Everything else (1D Box).
            else:
                assert len(component.shape) == 1, \
                    "Only input Box 1D or 3D spaces allowed!"
                concat_size += component.shape[-1]

        self.logits_layer = None
        self.value_layer = None
        self._value_out = None

        if num_outputs:
            # Action-distribution head.
            self.logits_layer = SlimFC(
                in_size=concat_size,
                out_size=num_outputs,
                activation_fn=None,
            )
            # Create the value branch model.
            self.value_layer = SlimFC(
                in_size=concat_size,
                out_size=1,
                activation_fn=None,
                initializer=torch_normc_initializer(0.01))
        else:
            # No heads: the model's output is the concatenated features.
            self.num_outputs = concat_size

    @override(ModelV2)
    def forward(self, input_dict, state, seq_lens):
        """Runs image components through their CNNs and concats everything.

        Returns:
            Tuple of (logits, []) if output layers were built, otherwise
            (concatenated features, []).
        """
        # Push image observations through our CNNs.
        outs = []
        for i, component in enumerate(input_dict["obs"]):
            if i in self.cnns:
                cnn_out, _ = self.cnns[i]({"obs": component})
                outs.append(cnn_out)
            else:
                outs.append(component)
        # Concat all outputs and the non-image inputs.
        out = torch.cat(outs, dim=1)
        if self.logits_layer is None:
            return out, []

        # Value branch.
        logits, values = self.logits_layer(out), self.value_layer(out)
        self._value_out = torch.reshape(values, [-1])
        return logits, []

    @override(ModelV2)
    def value_function(self):
        """Returns the value output stored by the most recent `forward`."""
        return self._value_out
+173
View File
@@ -0,0 +1,173 @@
from gym.spaces import Box
from ray.rllib.models.tf.fcnet import FullyConnectedNetwork
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.models.torch.fcnet import FullyConnectedNetwork as \
TorchFullyConnectedNetwork
from ray.rllib.models.torch.misc import SlimFC
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.utils.framework import try_import_tf, try_import_torch
tf1, tf, tfv = try_import_tf()
torch, nn = try_import_torch()
# __sphinx_doc_model_api_tf_begin__
class DuelingQModel(TFModelV2):  # or: TorchModelV2
    """A simple, hard-coded dueling head model."""

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        """Sets up the advantage (`A`) and value (`V`) heads.

        `num_outputs` is only used as the size of the advantage head; the
        super constructor is called with `num_outputs=None`.
        """
        # Pass num_outputs=None into super constructor (so that no action/
        # logits output layer is built).
        # Alternatively, you can pass in num_outputs=[last layer size of
        # config[model][fcnet_hiddens]] AND set no_last_linear=True, but
        # this seems more tedious as you will have to explain users of this
        # class that num_outputs is NOT the size of your Q-output layer.
        super().__init__(obs_space, action_space, None, model_config, name)
        # Now: self.num_outputs contains the last layer's size, which
        # we can use to construct the dueling head (see torch: SlimFC
        # below).

        # Construct advantage head ...
        self.A = tf.keras.layers.Dense(num_outputs)
        # torch:
        # self.A = SlimFC(
        #     in_size=self.num_outputs, out_size=num_outputs)

        # ... and value head.
        self.V = tf.keras.layers.Dense(1)
        # torch:
        # self.V = SlimFC(in_size=self.num_outputs, out_size=1)

    def get_q_values(self, underlying_output):
        """Returns value + mean-centered advantages for the given features."""
        state_value = self.V(underlying_output)
        advantages = self.A(underlying_output)
        # Center the per-action advantages around their mean (over axis 1).
        mean_advantage = tf.expand_dims(tf.reduce_mean(advantages, 1), 1)
        centered = advantages - mean_advantage
        return state_value + centered
class TorchDuelingQModel(TorchModelV2):
    """A simple, hard-coded dueling head model."""

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        """Sets up the advantage (`A`) and value (`V`) heads.

        `num_outputs` is only used as the size of the advantage head; the
        super constructor is called with `num_outputs=None`.
        """
        # Pass num_outputs=None into super constructor (so that no action/
        # logits output layer is built).
        # Alternatively, you can pass in num_outputs=[last layer size of
        # config[model][fcnet_hiddens]] AND set no_last_linear=True, but
        # this seems more tedious as you will have to explain users of this
        # class that num_outputs is NOT the size of your Q-output layer.
        nn.Module.__init__(self)
        super().__init__(obs_space, action_space, None, model_config, name)
        # Now: self.num_outputs contains the last layer's size, which
        # we can use to construct the dueling head.
        feature_size = self.num_outputs

        # Construct advantage head ...
        self.A = SlimFC(in_size=feature_size, out_size=num_outputs)
        # ... and value head.
        self.V = SlimFC(in_size=feature_size, out_size=1)

    def get_q_values(self, underlying_output):
        """Returns value + mean-centered advantages for the given features."""
        state_value = self.V(underlying_output)
        advantages = self.A(underlying_output)
        # Center the per-action advantages around their mean (over dim 1).
        mean_advantage = torch.unsqueeze(torch.mean(advantages, 1), 1)
        centered = advantages - mean_advantage
        return state_value + centered
# __sphinx_doc_model_api_tf_end__
class ContActionQModel(TFModelV2):
    """A simple, q-value-from-cont-action model (for e.g. SAC type algos)."""

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        """Sets up a nested FC network (`q_head`) with a single output.

        The super constructor is called with `num_outputs=None`; the
        `num_outputs` arg itself is not used here.
        """
        # Pass num_outputs=None into super constructor (so that no action/
        # logits output layer is built).
        # Alternatively, you can pass in num_outputs=[last layer size of
        # config[model][fcnet_hiddens]] AND set no_last_linear=True, but
        # this seems more tedious as you will have to explain users of this
        # class that num_outputs is NOT the size of your Q-output layer.
        super().__init__(obs_space, action_space, None, model_config, name)

        # Now: self.num_outputs contains the last layer's size, which
        # we can use to construct the single q-value computing head.

        # Nest an RLlib FullyConnectedNetwork (torch or tf) into this one here
        # to be used for Q-value calculation.
        # Its input is the wrapped model's output (current value of
        # self.num_outputs) plus the action.
        q_input_size = self.num_outputs + action_space.shape[0]
        combined_space = Box(-1.0, 1.0, (q_input_size, ))
        self.q_head = FullyConnectedNetwork(combined_space, action_space, 1,
                                            model_config, "q_head")

        # Missing here: Probably still have to provide action output layer
        # and value layer and make sure self.num_outputs is correctly set.

    def get_single_q_value(self, underlying_output, action):
        """Returns `q_head`'s output for the features concat'd with `action`."""
        # Concat the underlying output with the given action and wrap it
        # into a simple input_dict (self.q_head is an RLlib ModelV2).
        q_input = {"obs": tf.concat([underlying_output, action], axis=-1)}
        # Ignore state outputs.
        q_out, _ = self.q_head(q_input)
        return q_out
# __sphinx_doc_model_api_torch_start__
class TorchContActionQModel(TorchModelV2):
    """A simple, q-value-from-cont-action model (for e.g. SAC type algos)."""

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        """Sets up a nested torch FC network (`q_head`) with a single output.

        The super constructor is called with `num_outputs=None`; the
        `num_outputs` arg itself is not used here.
        """
        nn.Module.__init__(self)
        # Pass num_outputs=None into super constructor (so that no action/
        # logits output layer is built).
        # Alternatively, you can pass in num_outputs=[last layer size of
        # config[model][fcnet_hiddens]] AND set no_last_linear=True, but
        # this seems more tedious as you will have to explain users of this
        # class that num_outputs is NOT the size of your Q-output layer.
        super().__init__(obs_space, action_space, None, model_config, name)

        # Now: self.num_outputs contains the last layer's size, which
        # we can use to construct the single q-value computing head.

        # Nest an RLlib FullyConnectedNetwork (torch or tf) into this one here
        # to be used for Q-value calculation.
        # Its input is the wrapped model's output (current value of
        # self.num_outputs) plus the action.
        q_input_size = self.num_outputs + action_space.shape[0]
        combined_space = Box(-1.0, 1.0, (q_input_size, ))
        self.q_head = TorchFullyConnectedNetwork(combined_space, action_space,
                                                 1, model_config, "q_head")

        # Missing here: Probably still have to provide action output layer
        # and value layer and make sure self.num_outputs is correctly set.

    def get_single_q_value(self, underlying_output, action):
        """Returns `q_head`'s output for the features concat'd with `action`."""
        # Concat the underlying output with the given action and wrap it
        # into a simple input_dict (self.q_head is an RLlib ModelV2).
        q_input = {"obs": torch.cat([underlying_output, action], dim=-1)}
        # Ignore state outputs.
        q_out, _ = self.q_head(q_input)
        return q_out
# __sphinx_doc_model_api_torch_end__
@@ -0,0 +1,108 @@
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.models.torch.misc import SlimFC
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.policy.view_requirement import ViewRequirement
from ray.rllib.utils.framework import try_import_tf, try_import_torch
tf1, tf, tfv = try_import_tf()
torch, nn = try_import_torch()
# __sphinx_doc_model_api_begin__
class FrameStackingCartPoleModel(TFModelV2):
    """A simple FC model that takes the last n observations as input."""

    def __init__(self,
                 obs_space,
                 action_space,
                 num_outputs,
                 model_config,
                 name,
                 num_frames=3):
        """Builds the keras model and requests the `prev_n_obs` view.

        Args:
            obs_space: 1D observation space (asserted below).
            action_space: Action space (passed on to TFModelV2).
            num_outputs: Size of the model's main output layer.
            model_config: Model config dict (passed on to TFModelV2).
            name: Model name (passed on to TFModelV2).
            num_frames: Number of observations the model takes as input.
        """
        super().__init__(obs_space, action_space, None, model_config, name)

        self.num_frames = num_frames
        self.num_outputs = num_outputs

        # Construct actual (very simple) FC model.
        assert len(obs_space.shape) == 1
        obs_dim = obs_space.shape[0]
        frames_in = tf.keras.layers.Input(shape=(self.num_frames, obs_dim))
        # Flatten the (frames x obs) input into a single vector.
        flat = tf.keras.layers.Reshape([obs_dim * self.num_frames])(frames_in)
        hidden = tf.keras.layers.Dense(64, activation=tf.nn.relu)(flat)
        main_out = tf.keras.layers.Dense(self.num_outputs)(hidden)
        value_out = tf.keras.layers.Dense(1)(hidden)
        self.base_model = tf.keras.models.Model([frames_in],
                                                [main_out, value_out])
        self.register_variables(self.base_model.variables)

        self._last_value = None

        # Extra model inputs: a shifted range over "obs" and the "rewards"
        # column shifted by -1.
        self.view_requirements["prev_n_obs"] = ViewRequirement(
            data_col="obs",
            shift="-{}:0".format(num_frames - 1),
            space=obs_space)
        self.view_requirements["prev_rewards"] = ViewRequirement(
            data_col="rewards", shift=-1)

    def forward(self, input_dict, states, seq_lens):
        """Feeds `prev_n_obs` through the base model; stores the value out."""
        stacked_obs = input_dict["prev_n_obs"]
        model_out, self._last_value = self.base_model(stacked_obs)
        return model_out, []

    def value_function(self):
        """Returns the last stored value output with its last dim squeezed."""
        return tf.squeeze(self._last_value, -1)
# __sphinx_doc_model_api_end__
class TorchFrameStackingCartPoleModel(TorchModelV2, nn.Module):
    """A simple FC model that takes the last n observations as input."""

    def __init__(self,
                 obs_space,
                 action_space,
                 num_outputs,
                 model_config,
                 name,
                 num_frames=3):
        """Builds the FC layers and requests the `prev_n_obs` view.

        Args:
            obs_space: 1D observation space (asserted below).
            action_space: Action space (passed on to TorchModelV2).
            num_outputs: Size of the model's main output layer.
            model_config: Model config dict (passed on to TorchModelV2).
            name: Model name (passed on to TorchModelV2).
            num_frames: Number of observations the model takes as input.
        """
        nn.Module.__init__(self)
        super().__init__(obs_space, action_space, None, model_config, name)

        self.num_frames = num_frames
        self.num_outputs = num_outputs

        # Construct actual (very simple) FC model.
        assert len(obs_space.shape) == 1
        hidden_size = 64
        self.layer1 = SlimFC(
            in_size=obs_space.shape[0] * self.num_frames,
            out_size=hidden_size,
            activation_fn="relu")
        self.out = SlimFC(
            in_size=hidden_size,
            out_size=self.num_outputs,
            activation_fn="linear")
        self.values = SlimFC(
            in_size=hidden_size, out_size=1, activation_fn="linear")

        self._last_value = None

        # Extra model inputs: a shifted range over "obs" and the "rewards"
        # column shifted by -1.
        self.view_requirements["prev_n_obs"] = ViewRequirement(
            data_col="obs",
            shift="-{}:0".format(num_frames - 1),
            space=obs_space)
        self.view_requirements["prev_rewards"] = ViewRequirement(
            data_col="rewards", shift=-1)

    def forward(self, input_dict, states, seq_lens):
        """Flattens `prev_n_obs`, runs the FC layers, stores the value out."""
        flat_size = self.obs_space.shape[0] * self.num_frames
        flat_obs = torch.reshape(input_dict["prev_n_obs"], [-1, flat_size])
        hidden = self.layer1(flat_obs)
        model_out = self.out(hidden)
        self._last_value = self.values(hidden)
        return model_out, []

    def value_function(self):
        """Returns the last stored value output with its last dim squeezed."""
        return torch.squeeze(self._last_value, -1)
+50
View File
@@ -0,0 +1,50 @@
import argparse
import ray
from ray import tune
from ray.rllib.examples.models.trajectory_view_utilizing_models import \
FrameStackingCartPoleModel, TorchFrameStackingCartPoleModel
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.test_utils import check_learning_achieved
tf1, tf, tfv = try_import_tf()
# Command line options of this example script. The table keeps flag names and
# their argparse settings side by side; options are added in table order.
parser = argparse.ArgumentParser()
for _flag, _settings in (
        ("--run", {"type": str, "default": "PPO"}),
        ("--framework", {
            "choices": ["tf2", "tf", "tfe", "torch"],
            "default": "tf"
        }),
        ("--as-test", {"action": "store_true"}),
        ("--stop-iters", {"type": int, "default": 50}),
        ("--stop-timesteps", {"type": int, "default": 100000}),
        ("--stop-reward", {"type": float, "default": 150.0}),
):
    parser.add_argument(_flag, **_settings)
if __name__ == "__main__":
    # Script entry point: register the frame stacking model for the chosen
    # framework, run the experiment via tune and (with --as-test) check the
    # results against the stop reward.
    args = parser.parse_args()
    ray.init(num_cpus=3)

    # Pick the model class that matches the chosen framework.
    if args.framework != "torch":
        frame_stack_model_cls = FrameStackingCartPoleModel
    else:
        frame_stack_model_cls = TorchFrameStackingCartPoleModel
    ModelCatalog.register_custom_model("frame_stack_model",
                                       frame_stack_model_cls)

    # Stopping criteria handed to tune.
    stop = {
        "training_iteration": args.stop_iters,
        "timesteps_total": args.stop_timesteps,
        "episode_reward_mean": args.stop_reward,
    }

    config = {
        "env": "CartPole-v0",
        "model": {
            "custom_model": "frame_stack_model",
            # Extra kwargs for the custom model's constructor.
            "custom_model_config": {
                "num_frames": 4,
            }
        },
        "framework": args.framework,
    }

    results = tune.run(args.run, config=config, stop=stop, verbose=2)

    if args.as_test:
        check_learning_achieved(results, args.stop_reward)
    ray.shutdown()
+17 -4
View File
@@ -33,14 +33,27 @@ logger = logging.getLogger(__name__)
# __sphinx_doc_begin__
MODEL_DEFAULTS: ModelConfigDict = {
# === Built-in options ===
# Number of hidden layers for fully connected net
# FullyConnectedNetwork (tf and torch): rllib.models.tf|torch.fcnet.py
# These are used if no custom model is specified and the input space is 1D.
# Number of hidden layers to be used.
"fcnet_hiddens": [256, 256],
# Nonlinearity for fully connected net (tanh, relu)
# Activation function descriptor.
# Supported values are: "tanh", "relu", "swish" (or "silu"),
# "linear" (or None).
"fcnet_activation": "tanh",
# Filter config. List of [out_channels, kernel, stride] for each filter
# VisionNetwork (tf and torch): rllib.models.tf|torch.visionnet.py
# These are used if no custom model is specified and the input space is 2D.
# Filter config: List of [out_channels, kernel, stride] for each filter.
# Example:
# Use None for making RLlib try to find a default filter setup given the
# observation space.
"conv_filters": None,
# Nonlinearity for built-in convnet
# Activation function descriptor.
# Supported values are: "tanh", "relu", "swish" (or "silu"),
# "linear" (or None).
"conv_activation": "relu",
# For DiagGaussian action distributions, make the second half of the model
# outputs floating bias variables instead of state-dependent. This only
# has an effect if using the default fully connected net.
+10 -2
View File
@@ -346,16 +346,24 @@ class ModelV2:
data_col = view_req.data_col or view_col
if index == "last":
data_col = last_mappings.get(data_col, data_col)
# Range needed.
if view_req.shift_from is not None:
data = sample_batch[view_col][-1]
traj_len = len(sample_batch[data_col])
missing_at_end = traj_len % view_req.batch_repeat_value
obs_shift = -1 if data_col in [
SampleBatch.OBS, SampleBatch.NEXT_OBS
] else 0
from_ = view_req.shift_from + obs_shift
to_ = view_req.shift_to + obs_shift + 1
if to_ == 0:
to_ = None
input_dict[view_col] = np.array([
np.concatenate([
data, sample_batch[data_col][-missing_at_end:]
])[view_req.shift_from:view_req.shift_to +
1 if view_req.shift_to != -1 else None]
])[from_:to_]
])
# Single index.
else:
data = sample_batch[data_col][-1]
input_dict[view_col] = np.array([data])
+92 -19
View File
@@ -329,7 +329,7 @@ class TFPolicy(Policy):
builder = TFRunBuilder(self._sess, "compute_actions")
to_fetch = self._build_compute_actions(
builder,
obs_batch,
obs_batch=obs_batch,
state_batches=state_batches,
prev_action_batch=prev_action_batch,
prev_reward_batch=prev_reward_batch,
@@ -345,6 +345,33 @@ class TFPolicy(Policy):
return fetched
@override(Policy)
def compute_actions_from_input_dict(
        self,
        input_dict: Dict[str, TensorType],
        explore: bool = None,
        timestep: Optional[int] = None,
        episodes: Optional[List["MultiAgentEpisode"]] = None,
        **kwargs) -> \
        Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
    """Computes actions from an input dict via one session run.

    Args:
        input_dict: Dict of model inputs; must contain `SampleBatch.OBS`.
        explore: Exploration flag. If None, `self.config["explore"]` is
            used.
        timestep: Timestep to use. If None, `self.global_timestep` is used.
        episodes: Not used by this method.
        **kwargs: Not used by this method.

    Returns:
        Whatever the run builder fetches for the built compute-actions
        ops.
    """
    # Fall back to the policy's own settings where nothing was passed in.
    if explore is None:
        explore = self.config["explore"]
    if timestep is None:
        timestep = self.global_timestep

    builder = TFRunBuilder(self._sess, "compute_actions_from_input_dict")
    obs_batch = input_dict[SampleBatch.OBS]
    fetches = self._build_compute_actions(
        builder, input_dict=input_dict, explore=explore, timestep=timestep)

    # Execute session run to get action (and other fetches).
    fetched = builder.get(fetches)

    # Update our global timestep by the batch size.
    if isinstance(obs_batch, list):
        batch_size = len(obs_batch)
    else:
        batch_size = obs_batch.shape[0]
    self.global_timestep += batch_size

    return fetched
@override(Policy)
def compute_log_likelihoods(
self,
@@ -700,15 +727,15 @@ class TFPolicy(Policy):
def _build_compute_actions(self,
builder,
obs_batch,
*,
input_dict=None,
obs_batch=None,
state_batches=None,
prev_action_batch=None,
prev_reward_batch=None,
episodes=None,
explore=None,
timestep=None):
explore = explore if explore is not None else self.config["explore"]
timestep = timestep if timestep is not None else self.global_timestep
@@ -716,27 +743,73 @@ class TFPolicy(Policy):
self.exploration.before_compute_actions(
timestep=timestep, explore=explore, tf_sess=self.get_session())
state_batches = state_batches or []
if len(self._state_inputs) != len(state_batches):
raise ValueError(
"Must pass in RNN state batches for placeholders {}, got {}".
format(self._state_inputs, state_batches))
builder.add_feed_dict(self.extra_compute_action_feed_dict())
builder.add_feed_dict({self._obs_input: obs_batch})
if state_batches:
builder.add_feed_dict({self._seq_lens: np.ones(len(obs_batch))})
if self._prev_action_input is not None and \
prev_action_batch is not None:
builder.add_feed_dict({self._prev_action_input: prev_action_batch})
if self._prev_reward_input is not None and \
prev_reward_batch is not None:
builder.add_feed_dict({self._prev_reward_input: prev_reward_batch})
# `input_dict` given: Simply build what's in that dict.
if input_dict is not None:
if hasattr(self, "_input_dict"):
for key, value in input_dict.items():
if key in self._input_dict:
builder.add_feed_dict({self._input_dict[key]: value})
# For policies that inherit directly from TFPolicy.
else:
builder.add_feed_dict({
self._obs_input: input_dict[SampleBatch.OBS]
})
if SampleBatch.PREV_ACTIONS in input_dict:
builder.add_feed_dict({
self._prev_action_input: input_dict[
SampleBatch.PREV_ACTIONS]
})
if SampleBatch.PREV_REWARDS in input_dict:
builder.add_feed_dict({
self._prev_reward_input: input_dict[
SampleBatch.PREV_REWARDS]
})
state_batches = []
i = 0
while "state_in_{}".format(i) in input_dict:
state_batches.append(input_dict["state_in_{}".format(i)])
i += 1
builder.add_feed_dict(
dict(zip(self._state_inputs, state_batches)))
if "state_in_0" in input_dict:
builder.add_feed_dict({
self._seq_lens: np.ones(len(input_dict["state_in_0"]))
})
# Hardcoded old way: Build fixed fields, if provided.
# TODO: (sven) This can be deprecated after trajectory view API flag is
# removed and always True.
else:
state_batches = state_batches or []
if len(self._state_inputs) != len(state_batches):
raise ValueError(
"Must pass in RNN state batches for placeholders {}, "
"got {}".format(self._state_inputs, state_batches))
builder.add_feed_dict({self._obs_input: obs_batch})
if state_batches:
builder.add_feed_dict({
self._seq_lens: np.ones(len(obs_batch))
})
if self._prev_action_input is not None and \
prev_action_batch is not None:
builder.add_feed_dict({
self._prev_action_input: prev_action_batch
})
if self._prev_reward_input is not None and \
prev_reward_batch is not None:
builder.add_feed_dict({
self._prev_reward_input: prev_reward_batch
})
builder.add_feed_dict(dict(zip(self._state_inputs, state_batches)))
builder.add_feed_dict({self._is_training: False})
builder.add_feed_dict({self._is_exploring: explore})
if timestep is not None:
builder.add_feed_dict({self._timestep: timestep})
builder.add_feed_dict(dict(zip(self._state_inputs, state_batches)))
# Determine, what exactly to fetch from the graph.
to_fetch = [self._sampled_action] + self._state_outputs + \
+7 -7
View File
@@ -311,13 +311,13 @@ class TestMultiAgentEnv(unittest.TestCase):
def test_returning_model_based_rollouts_data(self):
class ModelBasedPolicy(DQNTFPolicy):
def compute_actions(self,
obs_batch,
state_batches,
prev_action_batch=None,
prev_reward_batch=None,
episodes=None,
**kwargs):
def compute_actions_from_input_dict(self,
input_dict,
explore=None,
timestep=None,
episodes=None,
**kwargs):
obs_batch = input_dict["obs"]
# In policy loss initialization phase, no episodes are passed
# in.
if episodes is not None:
+3 -1
View File
@@ -276,7 +276,7 @@ def get_activation_fn(name: Optional[str] = None, framework: str = "tf"):
if framework == "torch":
if name in ["linear", None]:
return None
if name == "swish":
if name in ["swish", "silu"]:
from ray.rllib.utils.torch_ops import Swish
return Swish
_, nn = try_import_torch()
@@ -297,6 +297,8 @@ def get_activation_fn(name: Optional[str] = None, framework: str = "tf"):
else:
if name in ["linear", None]:
return None
if name == "swish":
name = "silu"
tf1, tf, tfv = try_import_tf()
fn = getattr(tf.nn, name, None)
if fn is not None: