mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 19:16:19 +08:00
[RLlib] Preparatory PR for: Documentation on Model Building. (#13260)
This commit is contained in:
+35
@@ -1808,6 +1808,23 @@ py_test(
|
||||
args = ["--stop-iters=2"]
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "examples/custom_model_api_tf",
|
||||
main = "examples/custom_model_api.py",
|
||||
tags = ["examples", "examples_C"],
|
||||
size = "small",
|
||||
srcs = ["examples/custom_model_api.py"],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "examples/custom_model_api_torch",
|
||||
main = "examples/custom_model_api.py",
|
||||
tags = ["examples", "examples_C"],
|
||||
size = "small",
|
||||
srcs = ["examples/custom_model_api.py"],
|
||||
args = ["--framework=torch"],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "examples/custom_observation_filters",
|
||||
main = "examples/custom_observation_filters.py",
|
||||
@@ -2047,6 +2064,24 @@ py_test(
|
||||
args = ["--as-test", "--torch"],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "examples/trajectory_view_api_tf",
|
||||
main = "examples/trajectory_view_api.py",
|
||||
tags = ["examples", "examples_T"],
|
||||
size = "medium",
|
||||
srcs = ["examples/trajectory_view_api.py"],
|
||||
args = ["--as-test", "--framework=tf", "--stop-reward=80.0"]
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "examples/trajectory_view_api_torch",
|
||||
main = "examples/trajectory_view_api.py",
|
||||
tags = ["examples", "examples_T"],
|
||||
size = "medium",
|
||||
srcs = ["examples/trajectory_view_api.py"],
|
||||
args = ["--as-test", "--framework=torch", "--stop-reward=80.0"]
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "examples/two_trainer_workflow_tf",
|
||||
main = "examples/two_trainer_workflow.py",
|
||||
|
||||
@@ -166,6 +166,8 @@ class _AgentCollector:
|
||||
# resulting data=[[-3, -2, -1], [7, 8, 9]]
|
||||
# Range of 3 consecutive items repeats every 10 timesteps.
|
||||
if view_req.shift_from is not None:
|
||||
# Batch repeat value > 1: Only repeat the shift_from/to range
|
||||
# every n timesteps.
|
||||
if view_req.batch_repeat_value > 1:
|
||||
count = int(
|
||||
math.ceil((len(np_data[data_col]) - self.shift_before)
|
||||
@@ -179,11 +181,20 @@ class _AgentCollector:
|
||||
view_req.shift_to + 1 + obs_shift]
|
||||
for i in range(count)
|
||||
])
|
||||
# Batch repeat value = 1: Repeat the shift_from/to range at
|
||||
# each timestep.
|
||||
else:
|
||||
data = np_data[data_col][self.shift_before +
|
||||
view_req.shift_from +
|
||||
obs_shift:self.shift_before +
|
||||
view_req.shift_to + 1 + obs_shift]
|
||||
d = np_data[data_col]
|
||||
# TODO: For now, assume simple 1D data (B x x).
|
||||
# Will expand this for Atari examples.
|
||||
assert len(d.shape) == 2
|
||||
shift_win = view_req.shift_to - view_req.shift_from + 1
|
||||
data_size = d.itemsize * int(np.product(d.shape[1:]))
|
||||
|
||||
data = np.lib.stride_tricks.as_strided(
|
||||
d[self.shift_before - shift_win:],
|
||||
[self.agent_steps, shift_win, d.shape[1]],
|
||||
[data_size, data_size, d.itemsize])
|
||||
# Set of (probably non-consecutive) indices.
|
||||
# Example:
|
||||
# shift=[-3, 0]
|
||||
|
||||
@@ -0,0 +1,103 @@
|
||||
import argparse
|
||||
from gym.spaces import Box, Discrete
|
||||
import numpy as np
|
||||
|
||||
from ray.rllib.examples.models.custom_model_api import DuelingQModel, \
|
||||
TorchDuelingQModel, ContActionQModel, TorchContActionQModel
|
||||
from ray.rllib.models.catalog import ModelCatalog, MODEL_DEFAULTS
|
||||
from ray.rllib.utils.framework import try_import_tf, try_import_torch
|
||||
|
||||
tf1, tf, tfv = try_import_tf()
|
||||
torch, _ = try_import_torch()
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--framework", choices=["tf2", "tf", "tfe", "torch"], default="tf")
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
|
||||
# Test API wrapper for dueling Q-head.
|
||||
|
||||
obs_space = Box(-1.0, 1.0, (3, ))
|
||||
action_space = Discrete(3)
|
||||
|
||||
# Run in eager mode for value checking and debugging.
|
||||
tf1.enable_eager_execution()
|
||||
|
||||
# __sphinx_doc_model_construct_begin__
|
||||
my_dueling_model = ModelCatalog.get_model_v2(
|
||||
obs_space=obs_space,
|
||||
action_space=action_space,
|
||||
num_outputs=action_space.n,
|
||||
model_config=MODEL_DEFAULTS,
|
||||
framework=args.framework,
|
||||
# Providing the `model_interface` arg will make the factory
|
||||
# wrap the chosen default model with our new model API class
|
||||
# (DuelingQModel). This way, both `forward` and `get_q_values`
|
||||
# are available in the returned class.
|
||||
model_interface=DuelingQModel
|
||||
if args.framework != "torch" else TorchDuelingQModel,
|
||||
name="dueling_q_model",
|
||||
)
|
||||
# __sphinx_doc_model_construct_end__
|
||||
|
||||
batch_size = 10
|
||||
input_ = np.array([obs_space.sample() for _ in range(batch_size)])
|
||||
# Note that for PyTorch, you will have to provide torch tensors here.
|
||||
if args.framework == "torch":
|
||||
input_ = torch.from_numpy(input_)
|
||||
|
||||
input_dict = {
|
||||
"obs": input_,
|
||||
"is_training": False,
|
||||
}
|
||||
out, state_outs = my_dueling_model(input_dict=input_dict)
|
||||
assert out.shape == (10, 256)
|
||||
# Pass `out` into `get_q_values`
|
||||
q_values = my_dueling_model.get_q_values(out)
|
||||
assert q_values.shape == (10, action_space.n)
|
||||
|
||||
# Test API wrapper for single value Q-head from obs/action input.
|
||||
|
||||
obs_space = Box(-1.0, 1.0, (3, ))
|
||||
action_space = Box(-1.0, -1.0, (2, ))
|
||||
|
||||
# __sphinx_doc_model_construct_begin__
|
||||
my_cont_action_q_model = ModelCatalog.get_model_v2(
|
||||
obs_space=obs_space,
|
||||
action_space=action_space,
|
||||
num_outputs=2,
|
||||
model_config=MODEL_DEFAULTS,
|
||||
framework=args.framework,
|
||||
# Providing the `model_interface` arg will make the factory
|
||||
# wrap the chosen default model with our new model API class
|
||||
# (DuelingQModel). This way, both `forward` and `get_q_values`
|
||||
# are available in the returned class.
|
||||
model_interface=ContActionQModel
|
||||
if args.framework != "torch" else TorchContActionQModel,
|
||||
name="cont_action_q_model",
|
||||
)
|
||||
# __sphinx_doc_model_construct_end__
|
||||
|
||||
batch_size = 10
|
||||
input_ = np.array([obs_space.sample() for _ in range(batch_size)])
|
||||
|
||||
# Note that for PyTorch, you will have to provide torch tensors here.
|
||||
if args.framework == "torch":
|
||||
input_ = torch.from_numpy(input_)
|
||||
|
||||
input_dict = {
|
||||
"obs": input_,
|
||||
"is_training": False,
|
||||
}
|
||||
# Note that for PyTorch, you will have to provide torch tensors here.
|
||||
out, state_outs = my_cont_action_q_model(input_dict=input_dict)
|
||||
assert out.shape == (10, 256)
|
||||
# Pass `out` and an action into `my_cont_action_q_model`
|
||||
action = np.array([action_space.sample() for _ in range(batch_size)])
|
||||
if args.framework == "torch":
|
||||
action = torch.from_numpy(action)
|
||||
|
||||
q_value = my_cont_action_q_model.get_single_q_value(out, action)
|
||||
assert q_value.shape == (10, 1)
|
||||
@@ -0,0 +1,63 @@
|
||||
import numpy as np
|
||||
|
||||
import ray
|
||||
import ray.rllib.agents.ppo as ppo
|
||||
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
|
||||
from ray.rllib.models.catalog import ModelCatalog
|
||||
from ray.rllib.utils.framework import try_import_torch
|
||||
|
||||
torch, _ = try_import_torch()
|
||||
|
||||
# __sphinx_doc_begin__
|
||||
|
||||
|
||||
# The custom model that will be wrapped by an LSTM.
|
||||
class MyCustomModel(TorchModelV2):
|
||||
def __init__(self, obs_space, action_space, num_outputs, model_config,
|
||||
name):
|
||||
super().__init__(obs_space, action_space, num_outputs, model_config,
|
||||
name)
|
||||
self.num_outputs = int(np.product(self.obs_space.shape))
|
||||
self._last_batch_size = None
|
||||
|
||||
# Implement your own forward logic, whose output will then be sent
|
||||
# through an LSTM.
|
||||
def forward(self, input_dict, state, seq_lens):
|
||||
obs = input_dict["obs_flat"]
|
||||
# Store last batch size for value_function output.
|
||||
self._last_batch_size = obs.shape[0]
|
||||
# Return 2x the obs (and empty states).
|
||||
# This will further be sent through an automatically provided
|
||||
# LSTM head (b/c we are setting use_lstm=True below).
|
||||
return obs * 2.0, []
|
||||
|
||||
def value_function(self):
|
||||
return torch.from_numpy(np.zeros(shape=(self._last_batch_size, )))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
ray.init()
|
||||
|
||||
# Register the above custom model.
|
||||
ModelCatalog.register_custom_model("my_torch_model", MyCustomModel)
|
||||
|
||||
# Create the Trainer.
|
||||
trainer = ppo.PPOTrainer(
|
||||
env="CartPole-v0",
|
||||
config={
|
||||
"framework": "torch",
|
||||
"model": {
|
||||
# Auto-wrap the custom(!) model with an LSTM.
|
||||
"use_lstm": True,
|
||||
# To further customize the LSTM auto-wrapper.
|
||||
"lstm_cell_size": 64,
|
||||
|
||||
# Specify our custom model from above.
|
||||
"custom_model": "my_torch_model",
|
||||
# Extra kwargs to be passed to your model's c'tor.
|
||||
"custom_model_config": {},
|
||||
},
|
||||
})
|
||||
trainer.train()
|
||||
|
||||
# __sphinx_doc_end__
|
||||
@@ -0,0 +1,220 @@
|
||||
from gym.spaces import Discrete, Tuple
|
||||
|
||||
from ray.rllib.examples.models.impala_vision_nets import TorchImpalaVisionNet
|
||||
from ray.rllib.models.catalog import ModelCatalog
|
||||
from ray.rllib.models.modelv2 import ModelV2
|
||||
from ray.rllib.models.tf.misc import normc_initializer
|
||||
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
|
||||
from ray.rllib.models.torch.misc import normc_initializer as \
|
||||
torch_normc_initializer, SlimFC
|
||||
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
|
||||
from ray.rllib.models.utils import get_filter_config
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.framework import try_import_tf, try_import_torch
|
||||
|
||||
tf1, tf, tfv = try_import_tf()
|
||||
torch, nn = try_import_torch()
|
||||
|
||||
|
||||
# __sphinx_doc_begin__
|
||||
class CNNPlusFCConcatModel(TFModelV2):
|
||||
"""TFModelV2 concat'ing CNN outputs to flat input(s), followed by FC(s).
|
||||
|
||||
Note: This model should be used for complex (Dict or Tuple) observation
|
||||
spaces that have one or more image components.
|
||||
"""
|
||||
|
||||
def __init__(self, obs_space, action_space, num_outputs, model_config,
|
||||
name):
|
||||
# TODO: (sven) Support Dicts as well.
|
||||
assert isinstance(obs_space.original_space, (Tuple)), \
|
||||
"`obs_space.original_space` must be Tuple!"
|
||||
|
||||
super().__init__(obs_space, action_space, num_outputs, model_config,
|
||||
name)
|
||||
|
||||
# Build the CNN(s) given obs_space's image components.
|
||||
self.cnns = {}
|
||||
concat_size = 0
|
||||
for i, component in enumerate(obs_space.original_space):
|
||||
# Image space.
|
||||
if len(component.shape) == 3:
|
||||
config = {
|
||||
"conv_filters": model_config.get(
|
||||
"conv_filters", get_filter_config(component.shape)),
|
||||
"conv_activation": model_config.get("conv_activation"),
|
||||
}
|
||||
cnn = ModelCatalog.get_model_v2(
|
||||
component,
|
||||
action_space,
|
||||
num_outputs=None,
|
||||
model_config=config,
|
||||
framework="tf",
|
||||
name="cnn_{}".format(i))
|
||||
concat_size += cnn.num_outputs
|
||||
self.cnns[i] = cnn
|
||||
self.register_variables(cnn.variables())
|
||||
# Discrete inputs -> One-hot encode.
|
||||
elif isinstance(component, Discrete):
|
||||
concat_size += component.n
|
||||
# TODO: (sven) Multidiscrete (see e.g. our auto-LSTM wrappers).
|
||||
# Everything else (1D Box).
|
||||
else:
|
||||
assert len(component.shape) == 1, \
|
||||
"Only input Box 1D or 3D spaces allowed!"
|
||||
concat_size += component.shape[-1]
|
||||
|
||||
self.logits_and_value_model = None
|
||||
self._value_out = None
|
||||
if num_outputs:
|
||||
# Action-distribution head.
|
||||
concat_layer = tf.keras.layers.Input((concat_size, ))
|
||||
logits_layer = tf.keras.layers.Dense(
|
||||
num_outputs,
|
||||
activation=tf.keras.activations.linear,
|
||||
name="logits")(concat_layer)
|
||||
|
||||
# Create the value branch model.
|
||||
value_layer = tf.keras.layers.Dense(
|
||||
1,
|
||||
name="value_out",
|
||||
activation=None,
|
||||
kernel_initializer=normc_initializer(0.01))(concat_layer)
|
||||
self.logits_and_value_model = tf.keras.models.Model(
|
||||
concat_layer, [logits_layer, value_layer])
|
||||
self.register_variables(self.logits_and_value_model.variables)
|
||||
else:
|
||||
self.num_outputs = concat_size
|
||||
|
||||
@override(ModelV2)
|
||||
def forward(self, input_dict, state, seq_lens):
|
||||
# Push image observations through our CNNs.
|
||||
outs = []
|
||||
for i, component in enumerate(input_dict["obs"]):
|
||||
if i in self.cnns:
|
||||
cnn_out, _ = self.cnns[i]({"obs": component})
|
||||
outs.append(cnn_out)
|
||||
else:
|
||||
outs.append(component)
|
||||
# Concat all outputs and the non-image inputs.
|
||||
out = tf.concat(outs, axis=1)
|
||||
if not self.logits_and_value_model:
|
||||
return out, []
|
||||
|
||||
# Value branch.
|
||||
logits, values = self.logits_and_value_model(out)
|
||||
self._value_out = tf.reshape(values, [-1])
|
||||
return logits, []
|
||||
|
||||
@override(ModelV2)
|
||||
def value_function(self):
|
||||
return self._value_out
|
||||
|
||||
|
||||
# __sphinx_doc_end__
|
||||
|
||||
|
||||
class TorchCNNPlusFCConcatModel(TorchModelV2, nn.Module):
|
||||
"""TorchModelV2 concat'ing CNN outputs to flat input(s), followed by FC(s).
|
||||
|
||||
Note: This model should be used for complex (Dict or Tuple) observation
|
||||
spaces that have one or more image components.
|
||||
"""
|
||||
|
||||
def __init__(self, obs_space, action_space, num_outputs, model_config,
|
||||
name):
|
||||
# TODO: (sven) Support Dicts as well.
|
||||
assert isinstance(obs_space.original_space, (Tuple)), \
|
||||
"`obs_space.original_space` must be Tuple!"
|
||||
|
||||
nn.Module.__init__(self)
|
||||
TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
|
||||
model_config, name)
|
||||
|
||||
# Atari type CNNs or IMPALA type CNNs (with residual layers)?
|
||||
self.cnn_type = self.model_config["custom_model_config"].get(
|
||||
"conv_type", "atari")
|
||||
|
||||
# Build the CNN(s) given obs_space's image components.
|
||||
self.cnns = {}
|
||||
concat_size = 0
|
||||
for i, component in enumerate(obs_space.original_space):
|
||||
# Image space.
|
||||
if len(component.shape) == 3:
|
||||
config = {
|
||||
"conv_filters": model_config.get(
|
||||
"conv_filters", get_filter_config(component.shape)),
|
||||
"conv_activation": model_config.get("conv_activation"),
|
||||
}
|
||||
if self.cnn_type == "atari":
|
||||
cnn = ModelCatalog.get_model_v2(
|
||||
component,
|
||||
action_space,
|
||||
num_outputs=None,
|
||||
model_config=config,
|
||||
framework="torch",
|
||||
name="cnn_{}".format(i))
|
||||
else:
|
||||
cnn = TorchImpalaVisionNet(
|
||||
component,
|
||||
action_space,
|
||||
num_outputs=None,
|
||||
model_config=config,
|
||||
name="cnn_{}".format(i))
|
||||
|
||||
concat_size += cnn.num_outputs
|
||||
self.cnns[i] = cnn
|
||||
self.add_module("cnn_{}".format(i), cnn)
|
||||
# Discrete inputs -> One-hot encode.
|
||||
elif isinstance(component, Discrete):
|
||||
concat_size += component.n
|
||||
# TODO: (sven) Multidiscrete (see e.g. our auto-LSTM wrappers).
|
||||
# Everything else (1D Box).
|
||||
else:
|
||||
assert len(component.shape) == 1, \
|
||||
"Only input Box 1D or 3D spaces allowed!"
|
||||
concat_size += component.shape[-1]
|
||||
|
||||
self.logits_layer = None
|
||||
self.value_layer = None
|
||||
self._value_out = None
|
||||
|
||||
if num_outputs:
|
||||
# Action-distribution head.
|
||||
self.logits_layer = SlimFC(
|
||||
in_size=concat_size,
|
||||
out_size=num_outputs,
|
||||
activation_fn=None,
|
||||
)
|
||||
# Create the value branch model.
|
||||
self.value_layer = SlimFC(
|
||||
in_size=concat_size,
|
||||
out_size=1,
|
||||
activation_fn=None,
|
||||
initializer=torch_normc_initializer(0.01))
|
||||
else:
|
||||
self.num_outputs = concat_size
|
||||
|
||||
@override(ModelV2)
|
||||
def forward(self, input_dict, state, seq_lens):
|
||||
# Push image observations through our CNNs.
|
||||
outs = []
|
||||
for i, component in enumerate(input_dict["obs"]):
|
||||
if i in self.cnns:
|
||||
cnn_out, _ = self.cnns[i]({"obs": component})
|
||||
outs.append(cnn_out)
|
||||
else:
|
||||
outs.append(component)
|
||||
# Concat all outputs and the non-image inputs.
|
||||
out = torch.cat(outs, dim=1)
|
||||
if self.logits_layer is None:
|
||||
return out, []
|
||||
|
||||
# Value branch.
|
||||
logits, values = self.logits_layer(out), self.value_layer(out)
|
||||
self._value_out = torch.reshape(values, [-1])
|
||||
return logits, []
|
||||
|
||||
@override(ModelV2)
|
||||
def value_function(self):
|
||||
return self._value_out
|
||||
@@ -0,0 +1,173 @@
|
||||
from gym.spaces import Box
|
||||
|
||||
from ray.rllib.models.tf.fcnet import FullyConnectedNetwork
|
||||
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
|
||||
from ray.rllib.models.torch.fcnet import FullyConnectedNetwork as \
|
||||
TorchFullyConnectedNetwork
|
||||
from ray.rllib.models.torch.misc import SlimFC
|
||||
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
|
||||
from ray.rllib.utils.framework import try_import_tf, try_import_torch
|
||||
|
||||
tf1, tf, tfv = try_import_tf()
|
||||
torch, nn = try_import_torch()
|
||||
|
||||
|
||||
# __sphinx_doc_model_api_tf_begin__
|
||||
class DuelingQModel(TFModelV2): # or: TorchModelV2
|
||||
"""A simple, hard-coded dueling head model."""
|
||||
|
||||
def __init__(self, obs_space, action_space, num_outputs, model_config,
|
||||
name):
|
||||
# Pass num_outputs=None into super constructor (so that no action/
|
||||
# logits output layer is built).
|
||||
# Alternatively, you can pass in num_outputs=[last layer size of
|
||||
# config[model][fcnet_hiddens]] AND set no_last_linear=True, but
|
||||
# this seems more tedious as you will have to explain users of this
|
||||
# class that num_outputs is NOT the size of your Q-output layer.
|
||||
super(DuelingQModel, self).__init__(obs_space, action_space, None,
|
||||
model_config, name)
|
||||
# Now: self.num_outputs contains the last layer's size, which
|
||||
# we can use to construct the dueling head (see torch: SlimFC
|
||||
# below).
|
||||
|
||||
# Construct advantage head ...
|
||||
self.A = tf.keras.layers.Dense(num_outputs)
|
||||
# torch:
|
||||
# self.A = SlimFC(
|
||||
# in_size=self.num_outputs, out_size=num_outputs)
|
||||
|
||||
# ... and value head.
|
||||
self.V = tf.keras.layers.Dense(1)
|
||||
# torch:
|
||||
# self.V = SlimFC(in_size=self.num_outputs, out_size=1)
|
||||
|
||||
def get_q_values(self, underlying_output):
|
||||
# Calculate q-values following dueling logic:
|
||||
v = self.V(underlying_output) # value
|
||||
a = self.A(underlying_output) # advantages (per action)
|
||||
advantages_mean = tf.reduce_mean(a, 1)
|
||||
advantages_centered = a - tf.expand_dims(advantages_mean, 1)
|
||||
return v + advantages_centered # q-values
|
||||
|
||||
|
||||
class TorchDuelingQModel(TorchModelV2):
|
||||
"""A simple, hard-coded dueling head model."""
|
||||
|
||||
def __init__(self, obs_space, action_space, num_outputs, model_config,
|
||||
name):
|
||||
# Pass num_outputs=None into super constructor (so that no action/
|
||||
# logits output layer is built).
|
||||
# Alternatively, you can pass in num_outputs=[last layer size of
|
||||
# config[model][fcnet_hiddens]] AND set no_last_linear=True, but
|
||||
# this seems more tedious as you will have to explain users of this
|
||||
# class that num_outputs is NOT the size of your Q-output layer.
|
||||
nn.Module.__init__(self)
|
||||
super(TorchDuelingQModel, self).__init__(obs_space, action_space, None,
|
||||
model_config, name)
|
||||
# Now: self.num_outputs contains the last layer's size, which
|
||||
# we can use to construct the dueling head (see torch: SlimFC
|
||||
# below).
|
||||
|
||||
# Construct advantage head ...
|
||||
self.A = SlimFC(in_size=self.num_outputs, out_size=num_outputs)
|
||||
|
||||
# ... and value head.
|
||||
self.V = SlimFC(in_size=self.num_outputs, out_size=1)
|
||||
|
||||
def get_q_values(self, underlying_output):
|
||||
# Calculate q-values following dueling logic:
|
||||
v = self.V(underlying_output) # value
|
||||
a = self.A(underlying_output) # advantages (per action)
|
||||
advantages_mean = torch.mean(a, 1)
|
||||
advantages_centered = a - torch.unsqueeze(advantages_mean, 1)
|
||||
return v + advantages_centered # q-values
|
||||
|
||||
|
||||
# __sphinx_doc_model_api_tf_end__
|
||||
|
||||
|
||||
class ContActionQModel(TFModelV2):
|
||||
"""A simple, q-value-from-cont-action model (for e.g. SAC type algos)."""
|
||||
|
||||
def __init__(self, obs_space, action_space, num_outputs, model_config,
|
||||
name):
|
||||
# Pass num_outputs=None into super constructor (so that no action/
|
||||
# logits output layer is built).
|
||||
# Alternatively, you can pass in num_outputs=[last layer size of
|
||||
# config[model][fcnet_hiddens]] AND set no_last_linear=True, but
|
||||
# this seems more tedious as you will have to explain users of this
|
||||
# class that num_outputs is NOT the size of your Q-output layer.
|
||||
super(ContActionQModel, self).__init__(obs_space, action_space, None,
|
||||
model_config, name)
|
||||
|
||||
# Now: self.num_outputs contains the last layer's size, which
|
||||
# we can use to construct the single q-value computing head.
|
||||
|
||||
# Nest an RLlib FullyConnectedNetwork (torch or tf) into this one here
|
||||
# to be used for Q-value calculation.
|
||||
# Use the current value of self.num_outputs, which is the wrapped
|
||||
# model's output layer size.
|
||||
combined_space = Box(-1.0, 1.0,
|
||||
(self.num_outputs + action_space.shape[0], ))
|
||||
self.q_head = FullyConnectedNetwork(combined_space, action_space, 1,
|
||||
model_config, "q_head")
|
||||
|
||||
# Missing here: Probably still have to provide action output layer
|
||||
# and value layer and make sure self.num_outputs is correctly set.
|
||||
|
||||
def get_single_q_value(self, underlying_output, action):
|
||||
# Calculate the q-value after concating the underlying output with
|
||||
# the given action.
|
||||
input_ = tf.concat([underlying_output, action], axis=-1)
|
||||
# Construct a simple input_dict (needed for self.q_head as it's an
|
||||
# RLlib ModelV2).
|
||||
input_dict = {"obs": input_}
|
||||
# Ignore state outputs.
|
||||
q_values, _ = self.q_head(input_dict)
|
||||
return q_values
|
||||
|
||||
|
||||
# __sphinx_doc_model_api_torch_start__
|
||||
class TorchContActionQModel(TorchModelV2):
|
||||
"""A simple, q-value-from-cont-action model (for e.g. SAC type algos)."""
|
||||
|
||||
def __init__(self, obs_space, action_space, num_outputs, model_config,
|
||||
name):
|
||||
nn.Module.__init__(self)
|
||||
# Pass num_outputs=None into super constructor (so that no action/
|
||||
# logits output layer is built).
|
||||
# Alternatively, you can pass in num_outputs=[last layer size of
|
||||
# config[model][fcnet_hiddens]] AND set no_last_linear=True, but
|
||||
# this seems more tedious as you will have to explain users of this
|
||||
# class that num_outputs is NOT the size of your Q-output layer.
|
||||
super(TorchContActionQModel, self).__init__(obs_space, action_space,
|
||||
None, model_config, name)
|
||||
|
||||
# Now: self.num_outputs contains the last layer's size, which
|
||||
# we can use to construct the single q-value computing head.
|
||||
|
||||
# Nest an RLlib FullyConnectedNetwork (torch or tf) into this one here
|
||||
# to be used for Q-value calculation.
|
||||
# Use the current value of self.num_outputs, which is the wrapped
|
||||
# model's output layer size.
|
||||
combined_space = Box(-1.0, 1.0,
|
||||
(self.num_outputs + action_space.shape[0], ))
|
||||
self.q_head = TorchFullyConnectedNetwork(combined_space, action_space,
|
||||
1, model_config, "q_head")
|
||||
|
||||
# Missing here: Probably still have to provide action output layer
|
||||
# and value layer and make sure self.num_outputs is correctly set.
|
||||
|
||||
def get_single_q_value(self, underlying_output, action):
|
||||
# Calculate the q-value after concating the underlying output with
|
||||
# the given action.
|
||||
input_ = torch.cat([underlying_output, action], dim=-1)
|
||||
# Construct a simple input_dict (needed for self.q_head as it's an
|
||||
# RLlib ModelV2).
|
||||
input_dict = {"obs": input_}
|
||||
# Ignore state outputs.
|
||||
q_values, _ = self.q_head(input_dict)
|
||||
return q_values
|
||||
|
||||
|
||||
# __sphinx_doc_model_api_torch_end__
|
||||
@@ -0,0 +1,108 @@
|
||||
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
|
||||
from ray.rllib.models.torch.misc import SlimFC
|
||||
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
|
||||
from ray.rllib.policy.view_requirement import ViewRequirement
|
||||
from ray.rllib.utils.framework import try_import_tf, try_import_torch
|
||||
|
||||
tf1, tf, tfv = try_import_tf()
|
||||
torch, nn = try_import_torch()
|
||||
|
||||
# __sphinx_doc_model_api_begin__
|
||||
|
||||
|
||||
class FrameStackingCartPoleModel(TFModelV2):
|
||||
"""A simple FC model that takes the last n observations as input."""
|
||||
|
||||
def __init__(self,
|
||||
obs_space,
|
||||
action_space,
|
||||
num_outputs,
|
||||
model_config,
|
||||
name,
|
||||
num_frames=3):
|
||||
super(FrameStackingCartPoleModel, self).__init__(
|
||||
obs_space, action_space, None, model_config, name)
|
||||
|
||||
self.num_frames = num_frames
|
||||
self.num_outputs = num_outputs
|
||||
|
||||
# Construct actual (very simple) FC model.
|
||||
assert len(obs_space.shape) == 1
|
||||
input_ = tf.keras.layers.Input(
|
||||
shape=(self.num_frames, obs_space.shape[0]))
|
||||
reshaped = tf.keras.layers.Reshape(
|
||||
[obs_space.shape[0] * self.num_frames])(input_)
|
||||
layer1 = tf.keras.layers.Dense(64, activation=tf.nn.relu)(reshaped)
|
||||
out = tf.keras.layers.Dense(self.num_outputs)(layer1)
|
||||
values = tf.keras.layers.Dense(1)(layer1)
|
||||
self.base_model = tf.keras.models.Model([input_], [out, values])
|
||||
self.register_variables(self.base_model.variables)
|
||||
|
||||
self._last_value = None
|
||||
|
||||
self.view_requirements["prev_n_obs"] = ViewRequirement(
|
||||
data_col="obs",
|
||||
shift="-{}:0".format(num_frames - 1),
|
||||
space=obs_space)
|
||||
self.view_requirements["prev_rewards"] = ViewRequirement(
|
||||
data_col="rewards", shift=-1)
|
||||
|
||||
def forward(self, input_dict, states, seq_lens):
|
||||
obs = input_dict["prev_n_obs"]
|
||||
out, self._last_value = self.base_model(obs)
|
||||
return out, []
|
||||
|
||||
def value_function(self):
|
||||
return tf.squeeze(self._last_value, -1)
|
||||
|
||||
|
||||
# __sphinx_doc_model_api_end__
|
||||
|
||||
|
||||
class TorchFrameStackingCartPoleModel(TorchModelV2, nn.Module):
|
||||
"""A simple FC model that takes the last n observations as input."""
|
||||
|
||||
def __init__(self,
|
||||
obs_space,
|
||||
action_space,
|
||||
num_outputs,
|
||||
model_config,
|
||||
name,
|
||||
num_frames=3):
|
||||
nn.Module.__init__(self)
|
||||
super(TorchFrameStackingCartPoleModel, self).__init__(
|
||||
obs_space, action_space, None, model_config, name)
|
||||
|
||||
self.num_frames = num_frames
|
||||
self.num_outputs = num_outputs
|
||||
|
||||
# Construct actual (very simple) FC model.
|
||||
assert len(obs_space.shape) == 1
|
||||
self.layer1 = SlimFC(
|
||||
in_size=obs_space.shape[0] * self.num_frames,
|
||||
out_size=64,
|
||||
activation_fn="relu")
|
||||
self.out = SlimFC(
|
||||
in_size=64, out_size=self.num_outputs, activation_fn="linear")
|
||||
self.values = SlimFC(in_size=64, out_size=1, activation_fn="linear")
|
||||
|
||||
self._last_value = None
|
||||
|
||||
self.view_requirements["prev_n_obs"] = ViewRequirement(
|
||||
data_col="obs",
|
||||
shift="-{}:0".format(num_frames - 1),
|
||||
space=obs_space)
|
||||
self.view_requirements["prev_rewards"] = ViewRequirement(
|
||||
data_col="rewards", shift=-1)
|
||||
|
||||
def forward(self, input_dict, states, seq_lens):
|
||||
obs = input_dict["prev_n_obs"]
|
||||
obs = torch.reshape(obs,
|
||||
[-1, self.obs_space.shape[0] * self.num_frames])
|
||||
features = self.layer1(obs)
|
||||
out = self.out(features)
|
||||
self._last_value = self.values(features)
|
||||
return out, []
|
||||
|
||||
def value_function(self):
|
||||
return torch.squeeze(self._last_value, -1)
|
||||
@@ -0,0 +1,50 @@
|
||||
import argparse
|
||||
|
||||
import ray
|
||||
from ray import tune
|
||||
from ray.rllib.examples.models.trajectory_view_utilizing_models import \
|
||||
FrameStackingCartPoleModel, TorchFrameStackingCartPoleModel
|
||||
from ray.rllib.models.catalog import ModelCatalog
|
||||
from ray.rllib.utils.framework import try_import_tf
|
||||
from ray.rllib.utils.test_utils import check_learning_achieved
|
||||
|
||||
tf1, tf, tfv = try_import_tf()
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--run", type=str, default="PPO")
|
||||
parser.add_argument(
|
||||
"--framework", choices=["tf2", "tf", "tfe", "torch"], default="tf")
|
||||
parser.add_argument("--as-test", action="store_true")
|
||||
parser.add_argument("--stop-iters", type=int, default=50)
|
||||
parser.add_argument("--stop-timesteps", type=int, default=100000)
|
||||
parser.add_argument("--stop-reward", type=float, default=150.0)
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
ray.init(num_cpus=3)
|
||||
|
||||
ModelCatalog.register_custom_model(
|
||||
"frame_stack_model", FrameStackingCartPoleModel
|
||||
if args.framework != "torch" else TorchFrameStackingCartPoleModel)
|
||||
|
||||
config = {
|
||||
"env": "CartPole-v0",
|
||||
"model": {
|
||||
"custom_model": "frame_stack_model",
|
||||
"custom_model_config": {
|
||||
"num_frames": 4,
|
||||
}
|
||||
},
|
||||
"framework": args.framework,
|
||||
}
|
||||
|
||||
stop = {
|
||||
"training_iteration": args.stop_iters,
|
||||
"timesteps_total": args.stop_timesteps,
|
||||
"episode_reward_mean": args.stop_reward,
|
||||
}
|
||||
results = tune.run(args.run, config=config, stop=stop, verbose=2)
|
||||
|
||||
if args.as_test:
|
||||
check_learning_achieved(results, args.stop_reward)
|
||||
ray.shutdown()
|
||||
+17
-4
@@ -33,14 +33,27 @@ logger = logging.getLogger(__name__)
|
||||
# __sphinx_doc_begin__
|
||||
MODEL_DEFAULTS: ModelConfigDict = {
|
||||
# === Built-in options ===
|
||||
# Number of hidden layers for fully connected net
|
||||
# FullyConnectedNetwork (tf and torch): rllib.models.tf|torch.fcnet.py
|
||||
# These are used if no custom model is specified and the input space is 1D.
|
||||
# Number of hidden layers to be used.
|
||||
"fcnet_hiddens": [256, 256],
|
||||
# Nonlinearity for fully connected net (tanh, relu)
|
||||
# Activation function descriptor.
|
||||
# Supported values are: "tanh", "relu", "swish" (or "silu"),
|
||||
# "linear" (or None).
|
||||
"fcnet_activation": "tanh",
|
||||
# Filter config. List of [out_channels, kernel, stride] for each filter
|
||||
|
||||
# VisionNetwork (tf and torch): rllib.models.tf|torch.visionnet.py
|
||||
# These are used if no custom model is specified and the input space is 2D.
|
||||
# Filter config: List of [out_channels, kernel, stride] for each filter.
|
||||
# Example:
|
||||
# Use None for making RLlib try to find a default filter setup given the
|
||||
# observation space.
|
||||
"conv_filters": None,
|
||||
# Nonlinearity for built-in convnet
|
||||
# Activation function descriptor.
|
||||
# Supported values are: "tanh", "relu", "swish" (or "silu"),
|
||||
# "linear" (or None).
|
||||
"conv_activation": "relu",
|
||||
|
||||
# For DiagGaussian action distributions, make the second half of the model
|
||||
# outputs floating bias variables instead of state-dependent. This only
|
||||
# has an effect is using the default fully connected net.
|
||||
|
||||
+10
-2
@@ -346,16 +346,24 @@ class ModelV2:
|
||||
data_col = view_req.data_col or view_col
|
||||
if index == "last":
|
||||
data_col = last_mappings.get(data_col, data_col)
|
||||
# Range needed.
|
||||
if view_req.shift_from is not None:
|
||||
data = sample_batch[view_col][-1]
|
||||
traj_len = len(sample_batch[data_col])
|
||||
missing_at_end = traj_len % view_req.batch_repeat_value
|
||||
obs_shift = -1 if data_col in [
|
||||
SampleBatch.OBS, SampleBatch.NEXT_OBS
|
||||
] else 0
|
||||
from_ = view_req.shift_from + obs_shift
|
||||
to_ = view_req.shift_to + obs_shift + 1
|
||||
if to_ == 0:
|
||||
to_ = None
|
||||
input_dict[view_col] = np.array([
|
||||
np.concatenate([
|
||||
data, sample_batch[data_col][-missing_at_end:]
|
||||
])[view_req.shift_from:view_req.shift_to +
|
||||
1 if view_req.shift_to != -1 else None]
|
||||
])[from_:to_]
|
||||
])
|
||||
# Single index.
|
||||
else:
|
||||
data = sample_batch[data_col][-1]
|
||||
input_dict[view_col] = np.array([data])
|
||||
|
||||
+92
-19
@@ -329,7 +329,7 @@ class TFPolicy(Policy):
|
||||
builder = TFRunBuilder(self._sess, "compute_actions")
|
||||
to_fetch = self._build_compute_actions(
|
||||
builder,
|
||||
obs_batch,
|
||||
obs_batch=obs_batch,
|
||||
state_batches=state_batches,
|
||||
prev_action_batch=prev_action_batch,
|
||||
prev_reward_batch=prev_reward_batch,
|
||||
@@ -345,6 +345,33 @@ class TFPolicy(Policy):
|
||||
|
||||
return fetched
|
||||
|
||||
@override(Policy)
|
||||
def compute_actions_from_input_dict(
|
||||
self,
|
||||
input_dict: Dict[str, TensorType],
|
||||
explore: bool = None,
|
||||
timestep: Optional[int] = None,
|
||||
episodes: Optional[List["MultiAgentEpisode"]] = None,
|
||||
**kwargs) -> \
|
||||
Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
|
||||
|
||||
explore = explore if explore is not None else self.config["explore"]
|
||||
timestep = timestep if timestep is not None else self.global_timestep
|
||||
|
||||
builder = TFRunBuilder(self._sess, "compute_actions_from_input_dict")
|
||||
obs_batch = input_dict[SampleBatch.OBS]
|
||||
to_fetch = self._build_compute_actions(
|
||||
builder, input_dict=input_dict, explore=explore, timestep=timestep)
|
||||
|
||||
# Execute session run to get action (and other fetches).
|
||||
fetched = builder.get(to_fetch)
|
||||
|
||||
# Update our global timestep by the batch size.
|
||||
self.global_timestep += len(obs_batch) if isinstance(obs_batch, list) \
|
||||
else obs_batch.shape[0]
|
||||
|
||||
return fetched
|
||||
|
||||
@override(Policy)
|
||||
def compute_log_likelihoods(
|
||||
self,
|
||||
@@ -700,15 +727,15 @@ class TFPolicy(Policy):
|
||||
|
||||
def _build_compute_actions(self,
|
||||
builder,
|
||||
obs_batch,
|
||||
*,
|
||||
input_dict=None,
|
||||
obs_batch=None,
|
||||
state_batches=None,
|
||||
prev_action_batch=None,
|
||||
prev_reward_batch=None,
|
||||
episodes=None,
|
||||
explore=None,
|
||||
timestep=None):
|
||||
|
||||
explore = explore if explore is not None else self.config["explore"]
|
||||
timestep = timestep if timestep is not None else self.global_timestep
|
||||
|
||||
@@ -716,27 +743,73 @@ class TFPolicy(Policy):
|
||||
self.exploration.before_compute_actions(
|
||||
timestep=timestep, explore=explore, tf_sess=self.get_session())
|
||||
|
||||
state_batches = state_batches or []
|
||||
if len(self._state_inputs) != len(state_batches):
|
||||
raise ValueError(
|
||||
"Must pass in RNN state batches for placeholders {}, got {}".
|
||||
format(self._state_inputs, state_batches))
|
||||
|
||||
builder.add_feed_dict(self.extra_compute_action_feed_dict())
|
||||
builder.add_feed_dict({self._obs_input: obs_batch})
|
||||
if state_batches:
|
||||
builder.add_feed_dict({self._seq_lens: np.ones(len(obs_batch))})
|
||||
if self._prev_action_input is not None and \
|
||||
prev_action_batch is not None:
|
||||
builder.add_feed_dict({self._prev_action_input: prev_action_batch})
|
||||
if self._prev_reward_input is not None and \
|
||||
prev_reward_batch is not None:
|
||||
builder.add_feed_dict({self._prev_reward_input: prev_reward_batch})
|
||||
|
||||
# `input_dict` given: Simply build what's in that dict.
|
||||
if input_dict is not None:
|
||||
if hasattr(self, "_input_dict"):
|
||||
for key, value in input_dict.items():
|
||||
if key in self._input_dict:
|
||||
builder.add_feed_dict({self._input_dict[key]: value})
|
||||
# For policies that inherit directly from TFPolicy.
|
||||
else:
|
||||
builder.add_feed_dict({
|
||||
self._obs_input: input_dict[SampleBatch.OBS]
|
||||
})
|
||||
if SampleBatch.PREV_ACTIONS in input_dict:
|
||||
builder.add_feed_dict({
|
||||
self._prev_action_input: input_dict[
|
||||
SampleBatch.PREV_ACTIONS]
|
||||
})
|
||||
if SampleBatch.PREV_REWARDS in input_dict:
|
||||
builder.add_feed_dict({
|
||||
self._prev_reward_input: input_dict[
|
||||
SampleBatch.PREV_REWARDS]
|
||||
})
|
||||
state_batches = []
|
||||
i = 0
|
||||
while "state_in_{}".format(i) in input_dict:
|
||||
state_batches.append(input_dict["state_in_{}".format(i)])
|
||||
i += 1
|
||||
builder.add_feed_dict(
|
||||
dict(zip(self._state_inputs, state_batches)))
|
||||
|
||||
if "state_in_0" in input_dict:
|
||||
builder.add_feed_dict({
|
||||
self._seq_lens: np.ones(len(input_dict["state_in_0"]))
|
||||
})
|
||||
|
||||
# Hardcoded old way: Build fixed fields, if provided.
|
||||
# TODO: (sven) This can be deprecated after trajectory view API flag is
|
||||
# removed and always True.
|
||||
else:
|
||||
state_batches = state_batches or []
|
||||
if len(self._state_inputs) != len(state_batches):
|
||||
raise ValueError(
|
||||
"Must pass in RNN state batches for placeholders {}, "
|
||||
"got {}".format(self._state_inputs, state_batches))
|
||||
|
||||
builder.add_feed_dict({self._obs_input: obs_batch})
|
||||
if state_batches:
|
||||
builder.add_feed_dict({
|
||||
self._seq_lens: np.ones(len(obs_batch))
|
||||
})
|
||||
if self._prev_action_input is not None and \
|
||||
prev_action_batch is not None:
|
||||
builder.add_feed_dict({
|
||||
self._prev_action_input: prev_action_batch
|
||||
})
|
||||
if self._prev_reward_input is not None and \
|
||||
prev_reward_batch is not None:
|
||||
builder.add_feed_dict({
|
||||
self._prev_reward_input: prev_reward_batch
|
||||
})
|
||||
builder.add_feed_dict(dict(zip(self._state_inputs, state_batches)))
|
||||
|
||||
builder.add_feed_dict({self._is_training: False})
|
||||
builder.add_feed_dict({self._is_exploring: explore})
|
||||
if timestep is not None:
|
||||
builder.add_feed_dict({self._timestep: timestep})
|
||||
builder.add_feed_dict(dict(zip(self._state_inputs, state_batches)))
|
||||
|
||||
# Determine, what exactly to fetch from the graph.
|
||||
to_fetch = [self._sampled_action] + self._state_outputs + \
|
||||
|
||||
@@ -311,13 +311,13 @@ class TestMultiAgentEnv(unittest.TestCase):
|
||||
|
||||
def test_returning_model_based_rollouts_data(self):
|
||||
class ModelBasedPolicy(DQNTFPolicy):
|
||||
def compute_actions(self,
|
||||
obs_batch,
|
||||
state_batches,
|
||||
prev_action_batch=None,
|
||||
prev_reward_batch=None,
|
||||
episodes=None,
|
||||
**kwargs):
|
||||
def compute_actions_from_input_dict(self,
|
||||
input_dict,
|
||||
explore=None,
|
||||
timestep=None,
|
||||
episodes=None,
|
||||
**kwargs):
|
||||
obs_batch = input_dict["obs"]
|
||||
# In policy loss initialization phase, no episodes are passed
|
||||
# in.
|
||||
if episodes is not None:
|
||||
|
||||
@@ -276,7 +276,7 @@ def get_activation_fn(name: Optional[str] = None, framework: str = "tf"):
|
||||
if framework == "torch":
|
||||
if name in ["linear", None]:
|
||||
return None
|
||||
if name == "swish":
|
||||
if name in ["swish", "silu"]:
|
||||
from ray.rllib.utils.torch_ops import Swish
|
||||
return Swish
|
||||
_, nn = try_import_torch()
|
||||
@@ -297,6 +297,8 @@ def get_activation_fn(name: Optional[str] = None, framework: str = "tf"):
|
||||
else:
|
||||
if name in ["linear", None]:
|
||||
return None
|
||||
if name == "swish":
|
||||
name = "silu"
|
||||
tf1, tf, tfv = try_import_tf()
|
||||
fn = getattr(tf.nn, name, None)
|
||||
if fn is not None:
|
||||
|
||||
Reference in New Issue
Block a user