[RLlib] Allow SAC to use custom models as Q- or policy nets and deprecate "state-preprocessor" for image spaces. (#13522)

This commit is contained in:
Sven Mika
2021-02-02 13:05:58 +01:00
committed by GitHub
parent fa4290090d
commit 52c94b7ee9
25 changed files with 1011 additions and 611 deletions
+1 -1
View File
@@ -453,7 +453,7 @@ with the remaining non-image (flat) inputs (the 1D Box and discrete/one-hot comp
Take a look at this model example that does exactly that:
.. literalinclude:: ../../rllib/examples/models/cnn_plus_fc_concat_model.py
.. literalinclude:: ../../rllib/models/tf/complex_input_net.py
:language: python
:start-after: __sphinx_doc_begin__
:end-before: __sphinx_doc_end__
+31 -10
View File
@@ -16,6 +16,7 @@ from ray.rllib.agents.trainer import with_common_config
from ray.rllib.agents.dqn.dqn import GenericOffPolicyTrainer
from ray.rllib.agents.sac.sac_tf_policy import SACTFPolicy
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.deprecation import DEPRECATED_VALUE, deprecation_warning
from ray.rllib.utils.typing import TrainerConfigDict
logger = logging.getLogger(__name__)
@@ -39,16 +40,37 @@ DEFAULT_CONFIG = with_common_config({
# Use a e.g. conv2D state preprocessing network before concatenating the
# resulting (feature) vector with the action input for the input to
# the Q-networks.
"use_state_preprocessor": False,
# Model options for the Q network(s).
"use_state_preprocessor": DEPRECATED_VALUE,
# Model options for the Q network(s). These will override MODEL_DEFAULTS.
# The `Q_model` dict is treated just as the top-level `model` dict in
# setting up the Q-network(s) (2 if twin_q=True).
# That means, you can do for different observation spaces:
# obs=Box(1D) -> Tuple(Box(1D) + Action) -> concat -> post_fcnet
# obs=Box(3D) -> Tuple(Box(3D) + Action) -> vision-net -> concat w/ action
# -> post_fcnet
# obs=Tuple(Box(1D), Box(3D)) -> Tuple(Box(1D), Box(3D), Action)
# -> vision-net -> concat w/ Box(1D) and action -> post_fcnet
# You can also have SAC use your custom_model as Q-model(s), by simply
# specifying the `custom_model` sub-key in below dict (just like you would
# do in the top-level `model` dict.
"Q_model": {
"fcnet_activation": "relu",
"fcnet_hiddens": [256, 256],
"fcnet_activation": "relu",
"post_fcnet_hiddens": [],
"post_fcnet_activation": None,
"custom_model": None, # Use this to define custom Q-model(s).
"custom_model_config": {},
},
# Model options for the policy function.
# Model options for the policy function (see `Q_model` above for details).
# The difference to `Q_model` above is that no action concat'ing is
# performed before the post_fcnet stack.
"policy_model": {
"fcnet_activation": "relu",
"fcnet_hiddens": [256, 256],
"fcnet_activation": "relu",
"post_fcnet_hiddens": [],
"post_fcnet_activation": None,
"custom_model": None, # Use this to define a custom policy model.
"custom_model_config": {},
},
# Unsquash actions to the upper and lower bounds of env's action space.
# Ignored for discrete action spaces.
@@ -145,11 +167,10 @@ def validate_config(config: TrainerConfigDict) -> None:
Raises:
ValueError: In case something is wrong with the config.
"""
if config["model"].get("custom_model"):
logger.warning(
"Setting use_state_preprocessor=True since a custom model "
"was specified.")
config["use_state_preprocessor"] = True
if config["use_state_preprocessor"] != DEPRECATED_VALUE:
deprecation_warning(
old="config['use_state_preprocessor']", error=False)
config["use_state_preprocessor"] = DEPRECATED_VALUE
if config["grad_clip"] is not None and config["grad_clip"] <= 0.0:
raise ValueError("`grad_clip` value must be > 0.0!")
+139 -73
View File
@@ -1,9 +1,12 @@
import gym
from gym.spaces import Box, Discrete
import numpy as np
from typing import Optional, Tuple
from typing import Dict, List, Optional
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.utils import force_list
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.spaces.simplex import Simplex
from ray.rllib.utils.typing import ModelConfigDict, TensorType
@@ -14,14 +17,21 @@ tf1, tf, tfv = try_import_tf()
class SACTFModel(TFModelV2):
"""Extension of the standard TFModelV2 for SAC.
Instances of this Model get created via wrapping this class around another
default- or custom model (inside
rllib/agents/sac/sac_tf_policy.py::build_sac_model). Doing so simply adds
this class' methods (`get_q_values`, etc..) to the wrapped model, such that
the wrapped model can be used by the SAC algorithm.
To customize, do one of the following:
- sub-class SACTFModel and override one or more of its methods.
- Use SAC's `Q_model` and `policy_model` keys to tweak the default model
behaviors (e.g. fcnet_hiddens, conv_filters, etc..).
- Use SAC's `Q_model->custom_model` and `policy_model->custom_model` keys
to specify your own custom Q-model(s) and policy-models, which will be
created within this SACTFModel (see `build_policy_model` and
`build_q_model`.
Note: It is not recommended to override the `forward` method for SAC. This
would lead to shared weights (between policy and Q-nets), which will then
not be optimized by either of the critic- or actor-optimizers!
Data flow:
`obs` -> forward() -> `model_out`
`obs` -> forward() (should stay a noop method!) -> `model_out`
`model_out` -> get_policy_output() -> pi(actions|obs)
`model_out`, `actions` -> get_q_values() -> Q(s, a)
`model_out`, `actions` -> get_twin_q_values() -> Q_twin(s, a)
@@ -33,20 +43,18 @@ class SACTFModel(TFModelV2):
num_outputs: Optional[int],
model_config: ModelConfigDict,
name: str,
actor_hidden_activation: str = "relu",
actor_hiddens: Tuple[int] = (256, 256),
critic_hidden_activation: str = "relu",
critic_hiddens: Tuple[int] = (256, 256),
policy_model_config: ModelConfigDict = None,
q_model_config: ModelConfigDict = None,
twin_q: bool = False,
initial_alpha: float = 1.0,
target_entropy: Optional[float] = None):
"""Initialize a SACTFModel instance.
Args:
actor_hidden_activation (str): Activation for the actor network.
actor_hiddens (list): Hidden layers sizes for the actor network.
critic_hidden_activation (str): Activation for the critic network.
critic_hiddens (list): Hidden layers sizes for the critic network.
policy_model_config (ModelConfigDict): The config dict for the
policy network.
q_model_config (ModelConfigDict): The config dict for the
Q-network(s) (2 if twin_q=True).
twin_q (bool): Build twin Q networks (Q-net and target) for more
stable Q-learning.
initial_alpha (float): The initial value for the to-be-optimized
@@ -77,54 +85,15 @@ class SACTFModel(TFModelV2):
action_outs = self.action_dim
q_outs = 1
self.model_out = tf.keras.layers.Input(
shape=(self.num_outputs, ), name="model_out")
self.action_model = tf.keras.Sequential([
tf.keras.layers.Dense(
units=hidden,
activation=getattr(tf.nn, actor_hidden_activation, None),
name="action_{}".format(i + 1))
for i, hidden in enumerate(actor_hiddens)
] + [
tf.keras.layers.Dense(
units=action_outs, activation=None, name="action_out")
])
self.shift_and_log_scale_diag = self.action_model(self.model_out)
self.actions_input = None
if not self.discrete:
self.actions_input = tf.keras.layers.Input(
shape=(self.action_dim, ), name="actions")
def build_q_net(name, observations, actions):
# For continuous actions: Feed obs and actions (concatenated)
# through the NN. For discrete actions, only obs.
q_net = tf.keras.Sequential(([
tf.keras.layers.Concatenate(axis=1),
] if not self.discrete else []) + [
tf.keras.layers.Dense(
units=units,
activation=getattr(tf.nn, critic_hidden_activation, None),
name="{}_hidden_{}".format(name, i))
for i, units in enumerate(critic_hiddens)
] + [
tf.keras.layers.Dense(
units=q_outs, activation=None, name="{}_out".format(name))
])
# TODO(hartikainen): Remove the unnecessary Model calls here
if self.discrete:
q_net = tf.keras.Model(observations, q_net(observations))
else:
q_net = tf.keras.Model([observations, actions],
q_net([observations, actions]))
return q_net
self.q_net = build_q_net("q", self.model_out, self.actions_input)
self.action_model = self.build_policy_model(
self.obs_space, action_outs, policy_model_config, "policy_model")
self.q_net = self.build_q_model(self.obs_space, self.action_space,
q_outs, q_model_config, "q")
if twin_q:
self.twin_q_net = build_q_net("twin_q", self.model_out,
self.actions_input)
self.twin_q_net = self.build_q_model(self.obs_space,
self.action_space, q_outs,
q_model_config, "twin_q")
else:
self.twin_q_net = None
@@ -143,6 +112,80 @@ class SACTFModel(TFModelV2):
target_entropy = -np.prod(action_space.shape)
self.target_entropy = target_entropy
@override(TFModelV2)
def forward(self, input_dict: Dict[str, TensorType],
state: List[TensorType],
seq_lens: TensorType) -> (TensorType, List[TensorType]):
"""The common (Q-net and policy-net) forward pass.
NOTE: It is not(!) recommended to override this method as it would
introduce a shared pre-network, which would be updated by both
actor- and critic optimizers.
"""
return input_dict["obs"], state
def build_policy_model(self, obs_space, num_outputs, policy_model_config,
name):
"""Builds the policy model used by this SAC.
Override this method in a sub-class of SACTFModel to implement your
own policy net. Alternatively, simply set `custom_model` within the
top level SAC `policy_model` config key to make this default
implementation of `build_policy_model` use your custom policy network.
Returns:
TFModelV2: The TFModelV2 policy sub-model.
"""
model = ModelCatalog.get_model_v2(
obs_space,
self.action_space,
num_outputs,
policy_model_config,
framework="tf",
name=name)
return model
def build_q_model(self, obs_space, action_space, num_outputs,
q_model_config, name):
"""Builds one of the (twin) Q-nets used by this SAC.
Override this method in a sub-class of SACTFModel to implement your
own Q-nets. Alternatively, simply set `custom_model` within the
top level SAC `Q_model` config key to make this default implementation
of `build_q_model` use your custom Q-nets.
Returns:
TFModelV2: The TFModelV2 Q-net sub-model.
"""
self.concat_obs_and_actions = False
if self.discrete:
input_space = obs_space
else:
orig_space = getattr(obs_space, "original_space", obs_space)
if isinstance(orig_space, Box) and len(orig_space.shape) == 1:
input_space = Box(
float("-inf"),
float("inf"),
shape=(orig_space.shape[0] + action_space.shape[0], ))
self.concat_obs_and_actions = True
else:
if isinstance(orig_space, gym.spaces.Tuple):
spaces = orig_space.spaces
elif isinstance(orig_space, gym.spaces.Dict):
spaces = list(orig_space.spaces.values())
else:
spaces = [obs_space]
input_space = gym.spaces.Tuple(spaces + [action_space])
model = ModelCatalog.get_model_v2(
input_space,
action_space,
num_outputs,
q_model_config,
framework="tf",
name=name)
return model
def get_q_values(self,
model_out: TensorType,
actions: Optional[TensorType] = None) -> TensorType:
@@ -161,12 +204,7 @@ class SACTFModel(TFModelV2):
Returns:
TensorType: Q-values tensor of shape [BATCH_SIZE, 1].
"""
# Continuous case -> concat actions to model_out.
if actions is not None:
return self.q_net([model_out, actions])
# Discrete case -> return q-vals for all actions.
else:
return self.q_net(model_out)
return self._get_q_value(model_out, actions, self.q_net)
def get_twin_q_values(self,
model_out: TensorType,
@@ -185,12 +223,32 @@ class SACTFModel(TFModelV2):
Returns:
TensorType: Q-values tensor of shape [BATCH_SIZE, 1].
"""
return self._get_q_value(model_out, actions, self.twin_q_net)
def _get_q_value(self, model_out, actions, net):
# Model outs may come as original Tuple/Dict observations, concat them
# here if this is the case.
if isinstance(net.obs_space, Box):
if isinstance(model_out, (list, tuple)):
model_out = tf.concat(model_out, axis=-1)
elif isinstance(model_out, dict):
model_out = list(model_out.values())
# Continuous case -> concat actions to model_out.
if actions is not None:
return self.twin_q_net([model_out, actions])
if self.concat_obs_and_actions:
input_dict = {"obs": tf.concat([model_out, actions], axis=-1)}
else:
input_dict = {"obs": force_list(model_out) + [actions]}
# Discrete case -> return q-vals for all actions.
else:
return self.twin_q_net(model_out)
input_dict = {"obs": model_out}
# Switch on training mode (when getting Q-values, we are usually in
# training).
input_dict["is_training"] = True
out, _ = net(input_dict, [], None)
return out
def get_policy_output(self, model_out: TensorType) -> TensorType:
"""Returns policy outputs, given the output of self.__call__().
@@ -207,15 +265,23 @@ class SACTFModel(TFModelV2):
Returns:
TensorType: Distribution inputs for sampling actions.
"""
return self.action_model(model_out)
# Model outs may come as original Tuple observations, concat them
# here if this is the case.
if isinstance(self.action_model.obs_space, Box):
if isinstance(model_out, (list, tuple)):
model_out = tf.concat(model_out, axis=-1)
elif isinstance(model_out, dict):
model_out = tf.concat(list(model_out.values()), axis=-1)
out, _ = self.action_model({"obs": model_out}, [], None)
return out
def policy_variables(self):
"""Return the list of variables for the policy net."""
return list(self.action_model.variables)
return self.action_model.variables()
def q_variables(self):
"""Return the list of variables for Q / twin Q nets."""
return self.q_net.variables + (self.twin_q_net.variables
if self.twin_q_net else [])
return self.q_net.variables() + (self.twin_q_net.variables()
if self.twin_q_net else [])
+24 -29
View File
@@ -6,6 +6,7 @@ import gym
from gym.spaces import Box, Discrete
from functools import partial
import logging
import numpy as np
from typing import Dict, List, Optional, Tuple, Type, Union
import ray
@@ -17,7 +18,7 @@ from ray.rllib.agents.dqn.dqn_tf_policy import postprocess_nstep_and_prio, \
from ray.rllib.agents.sac.sac_tf_model import SACTFModel
from ray.rllib.agents.sac.sac_torch_model import SACTorchModel
from ray.rllib.evaluation.episode import MultiAgentEpisode
from ray.rllib.models import ModelCatalog
from ray.rllib.models import ModelCatalog, MODEL_DEFAULTS
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.tf.tf_action_dist import Beta, Categorical, \
DiagGaussian, Dirichlet, SquashedGaussian, TFActionDistribution
@@ -55,40 +56,35 @@ def build_sac_model(policy: Policy, obs_space: gym.spaces.Space,
`policy.target_model`.
"""
# With separate state-preprocessor (before obs+action concat).
if config["use_state_preprocessor"]:
num_outputs = 256 # Flatten last Conv2D to this many nodes.
# No separate state-preprocessor: concat obs+actions right away.
else:
num_outputs = 0
# No state preprocessor: fcnet_hiddens should be empty.
if config["model"]["fcnet_hiddens"]:
logger.warning(
"When not using a state-preprocessor with SAC, `fcnet_hiddens`"
" will be set to an empty list! Any hidden layer sizes are "
"defined via `policy_model.fcnet_hiddens` and "
"`Q_model.fcnet_hiddens`.")
config["model"]["fcnet_hiddens"] = []
num_outputs = int(np.product(obs_space.shape))
# Force-ignore any additionally provided hidden layer sizes.
# Everything should be configured using SAC's "Q_model" and "policy_model"
# settings.
policy_model_config = MODEL_DEFAULTS.copy()
policy_model_config.update(config["policy_model"])
q_model_config = MODEL_DEFAULTS.copy()
q_model_config.update(config["Q_model"])
default_model_cls = SACTorchModel if config["framework"] == "torch" \
else SACTFModel
model = ModelCatalog.get_model_v2(
obs_space=obs_space,
action_space=action_space,
num_outputs=num_outputs,
model_config=config["model"],
framework=config["framework"],
model_interface=SACTorchModel
if config["framework"] == "torch" else SACTFModel,
default_model=default_model_cls,
name="sac_model",
actor_hidden_activation=config["policy_model"]["fcnet_activation"],
actor_hiddens=config["policy_model"]["fcnet_hiddens"],
critic_hidden_activation=config["Q_model"]["fcnet_activation"],
critic_hiddens=config["Q_model"]["fcnet_hiddens"],
policy_model_config=policy_model_config,
q_model_config=q_model_config,
twin_q=config["twin_q"],
initial_alpha=config["initial_alpha"],
target_entropy=config["target_entropy"])
assert isinstance(model, default_model_cls)
# Create an exact copy of the model and store it in `policy.target_model`.
# This will be used for tau-synched Q-target models that run behind the
# actual Q-networks and are used for target q-value calculations in the
@@ -99,17 +95,16 @@ def build_sac_model(policy: Policy, obs_space: gym.spaces.Space,
num_outputs=num_outputs,
model_config=config["model"],
framework=config["framework"],
model_interface=SACTorchModel
if config["framework"] == "torch" else SACTFModel,
default_model=default_model_cls,
name="target_sac_model",
actor_hidden_activation=config["policy_model"]["fcnet_activation"],
actor_hiddens=config["policy_model"]["fcnet_hiddens"],
critic_hidden_activation=config["Q_model"]["fcnet_activation"],
critic_hiddens=config["Q_model"]["fcnet_hiddens"],
policy_model_config=policy_model_config,
q_model_config=q_model_config,
twin_q=config["twin_q"],
initial_alpha=config["initial_alpha"],
target_entropy=config["target_entropy"])
assert isinstance(policy.target_model, default_model_cls)
return model
@@ -198,14 +193,14 @@ def get_distribution_inputs_and_class(
dist inputs, dist class, and a list of internal state outputs
(in the RNN case).
"""
# Get base-model output (w/o the SAC specific parts of the network).
model_out, state_out = model({
# Get base-model (forward) output (this should be a noop call).
forward_out, state_out = model({
"obs": obs_batch,
"is_training": policy._get_is_training_placeholder(),
}, [], None)
# Use the base output to get the policy outputs from the SAC model's
# policy components.
distribution_inputs = model.get_policy_output(model_out)
distribution_inputs = model.get_policy_output(forward_out)
# Get a distribution class to be used with the just calculated dist-inputs.
action_dist_class = _get_dist_class(policy.config, policy.action_space)
+140 -82
View File
@@ -1,11 +1,12 @@
import gym
from gym.spaces import Box, Discrete
import numpy as np
from typing import Optional, Tuple
from typing import Dict, List, Optional
from ray.rllib.models.torch.misc import SlimFC
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.models.utils import get_activation_fn
from ray.rllib.utils import force_list
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.spaces.simplex import Simplex
from ray.rllib.utils.typing import ModelConfigDict, TensorType
@@ -16,14 +17,21 @@ torch, nn = try_import_torch()
class SACTorchModel(TorchModelV2, nn.Module):
"""Extension of the standard TorchModelV2 for SAC.
Instances of this Model get created via wrapping this class around another
default- or custom model (inside
rllib/agents/sac/sac_torch_policy.py::build_sac_model). Doing so simply
adds this class' methods (`get_q_values`, etc..) to the wrapped model, such
that the wrapped model can be used by the SAC algorithm.
To customize, do one of the following:
- sub-class SACTorchModel and override one or more of its methods.
- Use SAC's `Q_model` and `policy_model` keys to tweak the default model
behaviors (e.g. fcnet_hiddens, conv_filters, etc..).
- Use SAC's `Q_model->custom_model` and `policy_model->custom_model` keys
to specify your own custom Q-model(s) and policy-models, which will be
created within this SACTFModel (see `build_policy_model` and
`build_q_model`.
Note: It is not recommended to override the `forward` method for SAC. This
would lead to shared weights (between policy and Q-nets), which will then
not be optimized by either of the critic- or actor-optimizers!
Data flow:
`obs` -> forward() -> `model_out`
`obs` -> forward() (should stay a noop method!) -> `model_out`
`model_out` -> get_policy_output() -> pi(actions|obs)
`model_out`, `actions` -> get_q_values() -> Q(s, a)
`model_out`, `actions` -> get_twin_q_values() -> Q_twin(s, a)
@@ -35,20 +43,18 @@ class SACTorchModel(TorchModelV2, nn.Module):
num_outputs: Optional[int],
model_config: ModelConfigDict,
name: str,
actor_hidden_activation: str = "relu",
actor_hiddens: Tuple[int] = (256, 256),
critic_hidden_activation: str = "relu",
critic_hiddens: Tuple[int] = (256, 256),
policy_model_config: ModelConfigDict = None,
q_model_config: ModelConfigDict = None,
twin_q: bool = False,
initial_alpha: float = 1.0,
target_entropy: Optional[float] = None):
"""Initializes a SACTorchModel instance.
7
Args:
actor_hidden_activation (str): Activation for the actor network.
actor_hiddens (list): Hidden layers sizes for the actor network.
critic_hidden_activation (str): Activation for the critic network.
critic_hiddens (list): Hidden layers sizes for the critic network.
policy_model_config (ModelConfigDict): The config dict for the
policy network.
q_model_config (ModelConfigDict): The config dict for the
Q-network(s) (2 if twin_q=True).
twin_q (bool): Build twin Q networks (Q-net and target) for more
stable Q-learning.
initial_alpha (float): The initial value for the to-be-optimized
@@ -69,74 +75,29 @@ class SACTorchModel(TorchModelV2, nn.Module):
self.action_dim = action_space.n
self.discrete = True
action_outs = q_outs = self.action_dim
action_ins = None # No action inputs for the discrete case.
elif isinstance(action_space, Box):
self.action_dim = np.product(action_space.shape)
self.discrete = False
action_outs = 2 * self.action_dim
action_ins = self.action_dim
q_outs = 1
else:
assert isinstance(action_space, Simplex)
self.action_dim = np.product(action_space.shape)
self.discrete = False
action_outs = self.action_dim
action_ins = self.action_dim
q_outs = 1
# Build the policy network.
self.action_model = nn.Sequential()
ins = self.num_outputs
self.obs_ins = ins
activation = get_activation_fn(
actor_hidden_activation, framework="torch")
for i, n in enumerate(actor_hiddens):
self.action_model.add_module(
"action_{}".format(i),
SlimFC(
ins,
n,
initializer=torch.nn.init.xavier_uniform_,
activation_fn=activation))
ins = n
self.action_model.add_module(
"action_out",
SlimFC(
ins,
action_outs,
initializer=torch.nn.init.xavier_uniform_,
activation_fn=None))
self.action_model = self.build_policy_model(
self.obs_space, action_outs, policy_model_config, "policy_model")
# Build the Q-net(s), including target Q-net(s).
def build_q_net(name_):
activation = get_activation_fn(
critic_hidden_activation, framework="torch")
# For continuous actions: Feed obs and actions (concatenated)
# through the NN. For discrete actions, only obs.
q_net = nn.Sequential()
ins = self.obs_ins + (0 if self.discrete else action_ins)
for i, n in enumerate(critic_hiddens):
q_net.add_module(
"{}_hidden_{}".format(name_, i),
SlimFC(
ins,
n,
initializer=torch.nn.init.xavier_uniform_,
activation_fn=activation))
ins = n
q_net.add_module(
"{}_out".format(name_),
SlimFC(
ins,
q_outs,
initializer=torch.nn.init.xavier_uniform_,
activation_fn=None))
return q_net
self.q_net = build_q_net("q")
# Build the Q-network(s).
self.q_net = self.build_q_model(self.obs_space, self.action_space,
q_outs, q_model_config, "q")
if twin_q:
self.twin_q_net = build_q_net("twin_q")
self.twin_q_net = self.build_q_model(self.obs_space,
self.action_space, q_outs,
q_model_config, "twin_q")
else:
self.twin_q_net = None
@@ -157,6 +118,80 @@ class SACTorchModel(TorchModelV2, nn.Module):
self.target_entropy = torch.tensor(
data=[target_entropy], dtype=torch.float32, requires_grad=False)
@override(TorchModelV2)
def forward(self, input_dict: Dict[str, TensorType],
state: List[TensorType],
seq_lens: TensorType) -> (TensorType, List[TensorType]):
"""The common (Q-net and policy-net) forward pass.
NOTE: It is not(!) recommended to override this method as it would
introduce a shared pre-network, which would be updated by both
actor- and critic optimizers.
"""
return input_dict["obs"], state
def build_policy_model(self, obs_space, num_outputs, policy_model_config,
name):
"""Builds the policy model used by this SAC.
Override this method in a sub-class of SACTFModel to implement your
own policy net. Alternatively, simply set `custom_model` within the
top level SAC `policy_model` config key to make this default
implementation of `build_policy_model` use your custom policy network.
Returns:
TorchModelV2: The TorchModelV2 policy sub-model.
"""
model = ModelCatalog.get_model_v2(
obs_space,
self.action_space,
num_outputs,
policy_model_config,
framework="torch",
name=name)
return model
def build_q_model(self, obs_space, action_space, num_outputs,
q_model_config, name):
"""Builds one of the (twin) Q-nets used by this SAC.
Override this method in a sub-class of SACTFModel to implement your
own Q-nets. Alternatively, simply set `custom_model` within the
top level SAC `Q_model` config key to make this default implementation
of `build_q_model` use your custom Q-nets.
Returns:
TorchModelV2: The TorchModelV2 Q-net sub-model.
"""
self.concat_obs_and_actions = False
if self.discrete:
input_space = obs_space
else:
orig_space = getattr(obs_space, "original_space", obs_space)
if isinstance(orig_space, Box) and len(orig_space.shape) == 1:
input_space = Box(
float("-inf"),
float("inf"),
shape=(orig_space.shape[0] + action_space.shape[0], ))
self.concat_obs_and_actions = True
else:
if isinstance(orig_space, gym.spaces.Tuple):
spaces = orig_space.spaces
elif isinstance(orig_space, gym.spaces.Dict):
spaces = list(orig_space.spaces.values())
else:
spaces = [obs_space]
input_space = gym.spaces.Tuple(spaces + [action_space])
model = ModelCatalog.get_model_v2(
input_space,
action_space,
num_outputs,
q_model_config,
framework="torch",
name=name)
return model
def get_q_values(self,
model_out: TensorType,
actions: Optional[TensorType] = None) -> TensorType:
@@ -175,12 +210,7 @@ class SACTorchModel(TorchModelV2, nn.Module):
Returns:
TensorType: Q-values tensor of shape [BATCH_SIZE, 1].
"""
# Continuous case -> concat actions to model_out.
if actions is not None:
return self.q_net(torch.cat([model_out, actions], -1))
# Discrete case -> return q-vals for all actions.
else:
return self.q_net(model_out)
return self._get_q_value(model_out, actions, self.q_net)
def get_twin_q_values(self,
model_out: TensorType,
@@ -199,12 +229,32 @@ class SACTorchModel(TorchModelV2, nn.Module):
Returns:
TensorType: Q-values tensor of shape [BATCH_SIZE, 1].
"""
return self._get_q_value(model_out, actions, self.twin_q_net)
def _get_q_value(self, model_out, actions, net):
# Model outs may come as original Tuple observations, concat them
# here if this is the case.
if isinstance(net.obs_space, Box):
if isinstance(model_out, (list, tuple)):
model_out = torch.cat(model_out, dim=-1)
elif isinstance(model_out, dict):
model_out = list(model_out.values())
# Continuous case -> concat actions to model_out.
if actions is not None:
return self.twin_q_net(torch.cat([model_out, actions], -1))
if self.concat_obs_and_actions:
input_dict = {"obs": torch.cat([model_out, actions], dim=-1)}
else:
input_dict = {"obs": force_list(model_out) + [actions]}
# Discrete case -> return q-vals for all actions.
else:
return self.twin_q_net(model_out)
input_dict = {"obs": model_out}
# Switch on training mode (when getting Q-values, we are usually in
# training).
input_dict["is_training"] = True
out, _ = net(input_dict, [], None)
return out
def get_policy_output(self, model_out: TensorType) -> TensorType:
"""Returns policy outputs, given the output of self.__call__().
@@ -221,15 +271,23 @@ class SACTorchModel(TorchModelV2, nn.Module):
Returns:
TensorType: Distribution inputs for sampling actions.
"""
return self.action_model(model_out)
# Model outs may come as original Tuple observations, concat them
# here if this is the case.
if isinstance(self.action_model.obs_space, Box):
if isinstance(model_out, (list, tuple)):
model_out = torch.cat(model_out, dim=-1)
elif isinstance(model_out, dict):
model_out = torch.cat(list(model_out.values()), dim=-1)
out, _ = self.action_model({"obs": model_out}, [], None)
return out
def policy_variables(self):
"""Return the list of variables for the policy net."""
return list(self.action_model.parameters())
return self.action_model.variables()
def q_variables(self):
"""Return the list of variables for Q / twin Q nets."""
return list(self.q_net.parameters()) + \
(list(self.twin_q_net.parameters()) if self.twin_q_net else [])
return self.q_net.variables() + (self.twin_q_net.variables()
if self.twin_q_net else [])
+127 -96
View File
@@ -1,5 +1,5 @@
from gym import Env
from gym.spaces import Box
from gym.spaces import Box, Discrete, Tuple
import numpy as np
import re
import unittest
@@ -9,6 +9,10 @@ import ray.rllib.agents.sac as sac
from ray.rllib.agents.sac.sac_tf_policy import sac_actor_critic_loss as tf_loss
from ray.rllib.agents.sac.sac_torch_policy import actor_critic_loss as \
loss_torch
from ray.rllib.examples.env.random_env import RandomEnv
from ray.rllib.examples.models.batch_norm_model import KerasBatchNormModel, \
TorchBatchNormModel
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.models.tf.tf_action_dist import Dirichlet
from ray.rllib.models.torch.torch_action_dist import TorchDirichlet
from ray.rllib.execution.replay_buffer import LocalReplayBuffer
@@ -52,7 +56,7 @@ class SimpleEnv(Env):
class TestSAC(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
ray.init()
ray.init(local_mode=True)
@classmethod
def tearDownClass(cls) -> None:
@@ -61,22 +65,46 @@ class TestSAC(unittest.TestCase):
def test_sac_compilation(self):
"""Tests whether an SACTrainer can be built with all frameworks."""
config = sac.DEFAULT_CONFIG.copy()
config["Q_model"] = sac.DEFAULT_CONFIG["Q_model"].copy()
config["num_workers"] = 0 # Run locally.
config["twin_q"] = True
config["soft_horizon"] = True
config["clip_actions"] = False
config["normalize_actions"] = True
config["learning_starts"] = 0
config["prioritized_replay"] = True
config["rollout_fragment_length"] = 10
config["train_batch_size"] = 10
num_iterations = 1
for _ in framework_iterator(config):
ModelCatalog.register_custom_model("batch_norm", KerasBatchNormModel)
ModelCatalog.register_custom_model("batch_norm_torch",
TorchBatchNormModel)
image_space = Box(-1.0, 1.0, shape=(84, 84, 3))
simple_space = Box(-1.0, 1.0, shape=(3, ))
for fw in framework_iterator(config):
# Test for different env types (discrete w/ and w/o image, + cont).
for env in [
"Pendulum-v0", "MsPacmanNoFrameskip-v4", "CartPole-v0"
RandomEnv,
"MsPacmanNoFrameskip-v4",
"CartPole-v0",
]:
print("Env={}".format(env))
config["use_state_preprocessor"] = \
env == "MsPacmanNoFrameskip-v4"
if env == RandomEnv:
config["env_config"] = {
"observation_space": Tuple(
[simple_space,
Discrete(2), image_space]),
"action_space": Box(-1.0, 1.0, shape=(1, )),
}
else:
config["env_config"] = {}
# Test making the Q-model a custom one for CartPole, otherwise,
# use the default model.
config["Q_model"]["custom_model"] = "batch_norm{}".format(
"_torch"
if fw == "torch" else "") if env == "CartPole-v0" else None
trainer = sac.SACTrainer(config=config, env=env)
for i in range(num_iterations):
results = trainer.train()
@@ -103,49 +131,56 @@ class TestSAC(unittest.TestCase):
config["env_config"] = {"simplex_actions": True}
map_ = {
# Normal net.
"default_policy/sequential/action_1/kernel": "action_model."
"action_0._model.0.weight",
"default_policy/sequential/action_1/bias": "action_model."
"action_0._model.0.bias",
"default_policy/sequential/action_out/kernel": "action_model."
"action_out._model.0.weight",
"default_policy/sequential/action_out/bias": "action_model."
"action_out._model.0.bias",
"default_policy/sequential_1/q_hidden_0/kernel": "q_net."
"q_hidden_0._model.0.weight",
"default_policy/sequential_1/q_hidden_0/bias": "q_net."
"q_hidden_0._model.0.bias",
"default_policy/sequential_1/q_out/kernel": "q_net."
"q_out._model.0.weight",
"default_policy/sequential_1/q_out/bias": "q_net."
"q_out._model.0.bias",
"default_policy/value_out/kernel": "_value_branch."
# Action net.
"default_policy/fc_1/kernel": "action_model._hidden_layers.0."
"_model.0.weight",
"default_policy/value_out/bias": "_value_branch."
"default_policy/fc_1/bias": "action_model._hidden_layers.0."
"_model.0.bias",
"default_policy/fc_out/kernel": "action_model."
"_logits._model.0.weight",
"default_policy/fc_out/bias": "action_model._logits._model.0.bias",
"default_policy/value_out/kernel": "action_model."
"_value_branch._model.0.weight",
"default_policy/value_out/bias": "action_model."
"_value_branch._model.0.bias",
# Q-net.
"default_policy/fc_1_1/kernel": "q_net."
"_hidden_layers.0._model.0.weight",
"default_policy/fc_1_1/bias": "q_net."
"_hidden_layers.0._model.0.bias",
"default_policy/fc_out_1/kernel": "q_net._logits._model.0.weight",
"default_policy/fc_out_1/bias": "q_net._logits._model.0.bias",
"default_policy/value_out_1/kernel": "q_net."
"_value_branch._model.0.weight",
"default_policy/value_out_1/bias": "q_net."
"_value_branch._model.0.bias",
"default_policy/log_alpha": "log_alpha",
# Target net.
"default_policy/sequential_2/action_1/kernel": "action_model."
"action_0._model.0.weight",
"default_policy/sequential_2/action_1/bias": "action_model."
"action_0._model.0.bias",
"default_policy/sequential_2/action_out/kernel": "action_model."
"action_out._model.0.weight",
"default_policy/sequential_2/action_out/bias": "action_model."
"action_out._model.0.bias",
"default_policy/sequential_3/q_hidden_0/kernel": "q_net."
"q_hidden_0._model.0.weight",
"default_policy/sequential_3/q_hidden_0/bias": "q_net."
"q_hidden_0._model.0.bias",
"default_policy/sequential_3/q_out/kernel": "q_net."
"q_out._model.0.weight",
"default_policy/sequential_3/q_out/bias": "q_net."
"q_out._model.0.bias",
"default_policy/value_out_1/kernel": "_value_branch."
"_model.0.weight",
"default_policy/value_out_1/bias": "_value_branch."
"_model.0.bias",
# Target action-net.
"default_policy/fc_1_2/kernel": "action_model."
"_hidden_layers.0._model.0.weight",
"default_policy/fc_1_2/bias": "action_model."
"_hidden_layers.0._model.0.bias",
"default_policy/fc_out_2/kernel": "action_model."
"_logits._model.0.weight",
"default_policy/fc_out_2/bias": "action_model."
"_logits._model.0.bias",
"default_policy/value_out_2/kernel": "action_model."
"_value_branch._model.0.weight",
"default_policy/value_out_2/bias": "action_model."
"_value_branch._model.0.bias",
# Target Q-net
"default_policy/fc_1_3/kernel": "q_net."
"_hidden_layers.0._model.0.weight",
"default_policy/fc_1_3/bias": "q_net."
"_hidden_layers.0._model.0.bias",
"default_policy/fc_out_3/kernel": "q_net."
"_logits._model.0.weight",
"default_policy/fc_out_3/bias": "q_net."
"_logits._model.0.bias",
"default_policy/value_out_3/kernel": "q_net."
"_value_branch._model.0.weight",
"default_policy/value_out_3/bias": "q_net."
"_value_branch._model.0.bias",
"default_policy/log_alpha_1": "log_alpha",
}
@@ -225,10 +260,12 @@ class TestSAC(unittest.TestCase):
policy.td_error,
policy.optimizer().compute_gradients(
policy.critic_loss[0],
policy.model.q_variables()),
[v for v in policy.model.q_variables() if
"value_" not in v.name]),
policy.optimizer().compute_gradients(
policy.actor_loss,
policy.model.policy_variables()),
[v for v in policy.model.policy_variables() if
"value_" not in v.name]),
policy.optimizer().compute_gradients(
policy.alpha_loss, policy.model.log_alpha)],
feed_dict=policy._get_loss_inputs_dict(
@@ -261,8 +298,6 @@ class TestSAC(unittest.TestCase):
a.backward()
# `actor_loss` depends on Q-net vars (but these grads must
# be ignored and overridden in critic_loss.backward!).
assert not any(v.grad is None
for v in policy.model.q_variables())
assert not all(
torch.mean(v.grad) == 0
for v in policy.model.policy_variables())
@@ -273,45 +308,38 @@ class TestSAC(unittest.TestCase):
# Compare with tf ones.
torch_a_grads = [
v.grad for v in policy.model.policy_variables()
if v.grad is not None
]
for tf_g, torch_g in zip(tf_a_grads, torch_a_grads):
if tf_g.shape != torch_g.shape:
check(tf_g, np.transpose(torch_g.detach().cpu()))
else:
check(tf_g, torch_g)
check(tf_a_grads[2],
np.transpose(torch_a_grads[0].detach().cpu()))
# Test critic gradients.
policy.critic_optims[0].zero_grad()
assert all(
torch.mean(v.grad) == 0.0
for v in policy.model.q_variables())
for v in policy.model.q_variables() if v.grad is not None)
assert all(
torch.min(v.grad) == 0.0
for v in policy.model.q_variables())
for v in policy.model.q_variables() if v.grad is not None)
assert policy.model.log_alpha.grad is None
c[0].backward()
assert not all(
torch.mean(v.grad) == 0
for v in policy.model.q_variables())
for v in policy.model.q_variables() if v.grad is not None)
assert not all(
torch.min(v.grad) == 0 for v in policy.model.q_variables())
torch.min(v.grad) == 0 for v in policy.model.q_variables()
if v.grad is not None)
assert policy.model.log_alpha.grad is None
# Compare with tf ones.
torch_c_grads = [v.grad for v in policy.model.q_variables()]
for tf_g, torch_g in zip(tf_c_grads, torch_c_grads):
if tf_g.shape != torch_g.shape:
check(tf_g, np.transpose(torch_g.detach().cpu()))
else:
check(tf_g, torch_g)
check(tf_c_grads[0],
np.transpose(torch_c_grads[2].detach().cpu()))
# Compare (unchanged(!) actor grads) with tf ones.
torch_a_grads = [
v.grad for v in policy.model.policy_variables()
]
for tf_g, torch_g in zip(tf_a_grads, torch_a_grads):
if tf_g.shape != torch_g.shape:
check(tf_g, np.transpose(torch_g.detach().cpu()))
else:
check(tf_g, torch_g)
check(tf_a_grads[2],
np.transpose(torch_a_grads[0].detach().cpu()))
# Test alpha gradient.
policy.alpha_optim.zero_grad()
@@ -336,7 +364,7 @@ class TestSAC(unittest.TestCase):
prev_fw_loss = (c, a, e, t)
# Update weights from our batch (n times).
for update_iteration in range(10):
for update_iteration in range(5):
print("train iteration {}".format(update_iteration))
if fw == "tf":
in_ = self._get_batch_helper(obs_size, actions, batch_size)
@@ -350,10 +378,9 @@ class TestSAC(unittest.TestCase):
# Net must have changed.
if tf_updated_weights:
check(
updated_weights[
"default_policy/sequential/action_1/kernel"],
updated_weights["default_policy/fc_1/kernel"],
tf_updated_weights[-1][
"default_policy/sequential/action_1/kernel"],
"default_policy/fc_1/kernel"],
false=True)
tf_updated_weights.append(updated_weights)
@@ -367,7 +394,9 @@ class TestSAC(unittest.TestCase):
buf._fake_batch = in_
trainer.train()
# Compare updated model.
for tf_key in sorted(tf_weights.keys())[2:10]:
for tf_key in sorted(tf_weights.keys()):
if re.search("_[23]|alpha", tf_key):
continue
tf_var = tf_weights[tf_key]
torch_var = policy.model.state_dict()[map_[tf_key]]
if tf_var.shape != torch_var.shape:
@@ -381,7 +410,9 @@ class TestSAC(unittest.TestCase):
check(policy.model.log_alpha,
tf_weights["default_policy/log_alpha"])
# Compare target nets.
for tf_key in sorted(tf_weights.keys())[10:18]:
for tf_key in sorted(tf_weights.keys()):
if not re.search("_[23]", tf_key):
continue
tf_var = tf_weights[tf_key]
torch_var = policy.target_model.state_dict()[map_[
tf_key]]
@@ -437,9 +468,9 @@ class TestSAC(unittest.TestCase):
fc(
relu(
fc(model_out_t,
weights[ks[3]],
weights[ks[2]],
framework=fw)), weights[ks[5]], weights[ks[4]]), None)
weights[ks[1]],
weights[ks[0]],
framework=fw)), weights[ks[9]], weights[ks[8]]), None)
policy_t = action_dist_t.deterministic_sample()
log_pis_t = action_dist_t.logp(policy_t)
if sess:
@@ -452,9 +483,9 @@ class TestSAC(unittest.TestCase):
fc(
relu(
fc(model_out_tp1,
weights[ks[3]],
weights[ks[2]],
framework=fw)), weights[ks[5]], weights[ks[4]]), None)
weights[ks[1]],
weights[ks[0]],
framework=fw)), weights[ks[9]], weights[ks[8]]), None)
policy_tp1 = action_dist_tp1.deterministic_sample()
log_pis_tp1 = action_dist_tp1.logp(policy_tp1)
if sess:
@@ -468,11 +499,11 @@ class TestSAC(unittest.TestCase):
relu(
fc(np.concatenate(
[model_out_t, train_batch[SampleBatch.ACTIONS]], -1),
weights[ks[7]],
weights[ks[6]],
weights[ks[3]],
weights[ks[2]],
framework=fw)),
weights[ks[9]],
weights[ks[8]],
weights[ks[11]],
weights[ks[10]],
framework=fw)
# Q-values for current policy in given current state.
@@ -480,11 +511,11 @@ class TestSAC(unittest.TestCase):
q_t_det_policy = fc(
relu(
fc(np.concatenate([model_out_t, policy_t], -1),
weights[ks[7]],
weights[ks[6]],
weights[ks[3]],
weights[ks[2]],
framework=fw)),
weights[ks[9]],
weights[ks[8]],
weights[ks[11]],
weights[ks[10]],
framework=fw)
# Target q network evaluation.
@@ -493,11 +524,11 @@ class TestSAC(unittest.TestCase):
q_tp1 = fc(
relu(
fc(np.concatenate([target_model_out_tp1, policy_tp1], -1),
weights[ks[15]],
weights[ks[14]],
weights[ks[7]],
weights[ks[6]],
framework=fw)),
weights[ks[17]],
weights[ks[16]],
weights[ks[15]],
weights[ks[14]],
framework=fw)
else:
assert fw == "tfe"
@@ -538,9 +569,9 @@ class TestSAC(unittest.TestCase):
map_[k]: convert_to_torch_tensor(
np.transpose(v) if re.search("kernel", k) else np.array([v])
if re.search("log_alpha", k) else v)
for k, v in weights_dict.items()
if re.search("(sequential(/|_1)|value_out/|log_alpha)", k)
for i, (k, v) in enumerate(weights_dict.items()) if i < 13
}
return model_dict
def _translate_tfe_weights(self, weights_dict, map_):
+15 -8
View File
@@ -32,7 +32,7 @@ from ray.rllib.policy.torch_policy import TorchPolicy
from ray.rllib.utils import merge_dicts
from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.utils.debug import summarize
from ray.rllib.utils.deprecation import deprecation_warning
from ray.rllib.utils.deprecation import DEPRECATED_VALUE, deprecation_warning
from ray.rllib.utils.filter import get_filter, Filter
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.sgd import do_minibatch_sgd
@@ -396,15 +396,22 @@ class RolloutWorker(ParallelIteratorWorker):
if clip_rewards is None:
clip_rewards = True
# framestacking via trajectory view API is enabled.
num_framestacks = model_config.get("num_framestacks", 0)
if not policy_config["_use_trajectory_view_api"]:
model_config["num_framestacks"] = num_framestacks = 0
elif num_framestacks == "auto":
model_config["num_framestacks"] = num_framestacks = 4
framestack_traj_view = num_framestacks > 1
# Deprecated way of framestacking is used.
framestack = model_config.get("framestack") is True
# framestacking via trajectory view API is enabled.
num_framestacks = model_config.get("num_framestacks", 0)
# No trajectory view API: No traj. view based framestacking.
if not policy_config["_use_trajectory_view_api"]:
model_config["num_framestacks"] = num_framestacks = 0
# Trajectory view API is on and num_framestacks=auto: Only
# stack traj. view based if old `framestack=[invalid value]`.
elif num_framestacks == "auto":
if framestack == DEPRECATED_VALUE:
model_config["num_framestacks"] = num_framestacks = 4
else:
model_config["num_framestacks"] = num_framestacks = 0
framestack_traj_view = num_framestacks > 1
def wrap(env):
env = wrap_deepmind(
@@ -1,218 +0,0 @@
from gym.spaces import Discrete, Tuple
from ray.rllib.examples.models.impala_vision_nets import TorchImpalaVisionNet
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.tf.misc import normc_initializer
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.models.torch.misc import normc_initializer as \
torch_normc_initializer, SlimFC
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.models.utils import get_filter_config
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_tf, try_import_torch
tf1, tf, tfv = try_import_tf()
torch, nn = try_import_torch()
# __sphinx_doc_begin__
class CNNPlusFCConcatModel(TFModelV2):
"""TFModelV2 concat'ing CNN outputs to flat input(s), followed by FC(s).
Note: This model should be used for complex (Dict or Tuple) observation
spaces that have one or more image components.
"""
def __init__(self, obs_space, action_space, num_outputs, model_config,
name):
# TODO: (sven) Support Dicts as well.
assert isinstance(obs_space.original_space, (Tuple)), \
"`obs_space.original_space` must be Tuple!"
super().__init__(obs_space, action_space, num_outputs, model_config,
name)
# Build the CNN(s) given obs_space's image components.
self.cnns = {}
concat_size = 0
for i, component in enumerate(obs_space.original_space):
# Image space.
if len(component.shape) == 3:
config = {
"conv_filters": model_config.get(
"conv_filters", get_filter_config(component.shape)),
"conv_activation": model_config.get("conv_activation"),
}
cnn = ModelCatalog.get_model_v2(
component,
action_space,
num_outputs=None,
model_config=config,
framework="tf",
name="cnn_{}".format(i))
concat_size += cnn.num_outputs
self.cnns[i] = cnn
# Discrete inputs -> One-hot encode.
elif isinstance(component, Discrete):
concat_size += component.n
# TODO: (sven) Multidiscrete (see e.g. our auto-LSTM wrappers).
# Everything else (1D Box).
else:
assert len(component.shape) == 1, \
"Only input Box 1D or 3D spaces allowed!"
concat_size += component.shape[-1]
self.logits_and_value_model = None
self._value_out = None
if num_outputs:
# Action-distribution head.
concat_layer = tf.keras.layers.Input((concat_size, ))
logits_layer = tf.keras.layers.Dense(
num_outputs,
activation=tf.keras.activations.linear,
name="logits")(concat_layer)
# Create the value branch model.
value_layer = tf.keras.layers.Dense(
1,
name="value_out",
activation=None,
kernel_initializer=normc_initializer(0.01))(concat_layer)
self.logits_and_value_model = tf.keras.models.Model(
concat_layer, [logits_layer, value_layer])
else:
self.num_outputs = concat_size
@override(ModelV2)
def forward(self, input_dict, state, seq_lens):
# Push image observations through our CNNs.
outs = []
for i, component in enumerate(input_dict["obs"]):
if i in self.cnns:
cnn_out, _ = self.cnns[i]({"obs": component})
outs.append(cnn_out)
else:
outs.append(component)
# Concat all outputs and the non-image inputs.
out = tf.concat(outs, axis=1)
if not self.logits_and_value_model:
return out, []
# Value branch.
logits, values = self.logits_and_value_model(out)
self._value_out = tf.reshape(values, [-1])
return logits, []
@override(ModelV2)
def value_function(self):
return self._value_out
# __sphinx_doc_end__
class TorchCNNPlusFCConcatModel(TorchModelV2, nn.Module):
"""TorchModelV2 concat'ing CNN outputs to flat input(s), followed by FC(s).
Note: This model should be used for complex (Dict or Tuple) observation
spaces that have one or more image components.
"""
def __init__(self, obs_space, action_space, num_outputs, model_config,
name):
# TODO: (sven) Support Dicts as well.
assert isinstance(obs_space.original_space, (Tuple)), \
"`obs_space.original_space` must be Tuple!"
nn.Module.__init__(self)
TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
model_config, name)
# Atari type CNNs or IMPALA type CNNs (with residual layers)?
self.cnn_type = self.model_config["custom_model_config"].get(
"conv_type", "atari")
# Build the CNN(s) given obs_space's image components.
self.cnns = {}
concat_size = 0
for i, component in enumerate(obs_space.original_space):
# Image space.
if len(component.shape) == 3:
config = {
"conv_filters": model_config.get(
"conv_filters", get_filter_config(component.shape)),
"conv_activation": model_config.get("conv_activation"),
}
if self.cnn_type == "atari":
cnn = ModelCatalog.get_model_v2(
component,
action_space,
num_outputs=None,
model_config=config,
framework="torch",
name="cnn_{}".format(i))
else:
cnn = TorchImpalaVisionNet(
component,
action_space,
num_outputs=None,
model_config=config,
name="cnn_{}".format(i))
concat_size += cnn.num_outputs
self.cnns[i] = cnn
self.add_module("cnn_{}".format(i), cnn)
# Discrete inputs -> One-hot encode.
elif isinstance(component, Discrete):
concat_size += component.n
# TODO: (sven) Multidiscrete (see e.g. our auto-LSTM wrappers).
# Everything else (1D Box).
else:
assert len(component.shape) == 1, \
"Only input Box 1D or 3D spaces allowed!"
concat_size += component.shape[-1]
self.logits_layer = None
self.value_layer = None
self._value_out = None
if num_outputs:
# Action-distribution head.
self.logits_layer = SlimFC(
in_size=concat_size,
out_size=num_outputs,
activation_fn=None,
)
# Create the value branch model.
self.value_layer = SlimFC(
in_size=concat_size,
out_size=1,
activation_fn=None,
initializer=torch_normc_initializer(0.01))
else:
self.num_outputs = concat_size
@override(ModelV2)
def forward(self, input_dict, state, seq_lens):
# Push image observations through our CNNs.
outs = []
for i, component in enumerate(input_dict["obs"]):
if i in self.cnns:
cnn_out, _ = self.cnns[i]({"obs": component})
outs.append(cnn_out)
else:
outs.append(component)
# Concat all outputs and the non-image inputs.
out = torch.cat(outs, dim=1)
if self.logits_layer is None:
return out, []
# Value branch.
logits, values = self.logits_layer(out), self.value_layer(out)
self._value_out = torch.reshape(values, [-1])
return logits, []
@override(ModelV2)
def value_function(self):
return self._value_out
+38 -8
View File
@@ -19,7 +19,7 @@ from ray.rllib.models.torch.torch_action_dist import TorchCategorical, \
TorchDeterministic, TorchDiagGaussian, \
TorchMultiActionDistribution, TorchMultiCategorical
from ray.rllib.utils.annotations import DeveloperAPI, PublicAPI
from ray.rllib.utils.deprecation import DEPRECATED_VALUE, deprecation_warning
from ray.rllib.utils.deprecation import DEPRECATED_VALUE
from ray.rllib.utils.error import UnsupportedSpaceException
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.spaces.simplex import Simplex
@@ -56,6 +56,18 @@ MODEL_DEFAULTS: ModelConfigDict = {
# "linear" (or None).
"conv_activation": "relu",
# Some default models support a final FC stack of n Dense layers with given
# activation:
# - Complex observation spaces: Image components are fed through
# VisionNets, flat Boxes are left as-is, Discrete are one-hot'd, then
# everything is concated and pushed through this final FC stack.
# - VisionNets (CNNs), e.g. after the CNN stack, there may be
# additional Dense layers.
# - FullyConnectedNetworks will have this additional FCStack as well
# (that's why it's empty by default).
"post_fcnet_hiddens": [],
"post_fcnet_activation": "relu",
# For DiagGaussian action distributions, make the second half of the model
# outputs floating bias variables instead of state-dependent. This only
# has an effect is using the default fully connected net.
@@ -688,17 +700,22 @@ class ModelCatalog:
framework: str = "tf") -> Type[ModelV2]:
VisionNet = None
ComplexNet = None
if framework in ["tf2", "tf", "tfe"]:
from ray.rllib.models.tf.fcnet import \
FullyConnectedNetwork as FCNet
from ray.rllib.models.tf.visionnet import \
VisionNetwork as VisionNet
from ray.rllib.models.tf.complex_input_net import \
ComplexInputNetwork as ComplexNet
elif framework == "torch":
from ray.rllib.models.torch.fcnet import (FullyConnectedNetwork as
FCNet)
from ray.rllib.models.torch.visionnet import (VisionNetwork as
VisionNet)
from ray.rllib.models.torch.complex_input_net import \
ComplexInputNetwork as ComplexNet
elif framework == "jax":
from ray.rllib.models.jax.fcnet import (FullyConnectedNetwork as
FCNet)
@@ -710,16 +727,29 @@ class ModelCatalog:
# Discrete/1D obs-spaces or 2D obs space but traj. view framestacking
# disabled.
num_framestacks = model_config.get("num_framestacks", "auto")
# Tuple space, where at least one sub-space is image.
# -> Complex input model.
space_to_check = input_space if not hasattr(
input_space, "original_space") else input_space.original_space
if isinstance(input_space,
Tuple) or (isinstance(space_to_check, Tuple) and any(
isinstance(s, Box) and len(s.shape) >= 2
for s in space_to_check.spaces)):
return ComplexNet
# Single, flattenable/one-hot-abe space -> Simple FCNet.
if isinstance(input_space, (Discrete, MultiDiscrete)) or \
len(input_space.shape) == 1 or (
len(input_space.shape) == 2 and (
num_framestacks == "auto" or num_framestacks <= 1)):
return FCNet
# Default Conv2D net.
else:
if framework == "jax":
raise NotImplementedError("No Conv2D default net for JAX yet!")
return VisionNet
elif framework == "jax":
raise NotImplementedError("No non-FC default net for JAX yet!")
# Last resort: Conv2D stack for single image spaces.
return VisionNet
@staticmethod
def _get_multi_action_distribution(dist_class, action_space, config,
@@ -768,8 +798,8 @@ class ModelCatalog:
"framework=jax so far!")
if config.get("framestack") != DEPRECATED_VALUE:
deprecation_warning(
old="framestack", new="num_framestacks (int)", error=False)
# deprecation_warning(
# old="framestack", new="num_framestacks (int)", error=False)
# If old behavior is desired, disable traj. view-style
# framestacking.
config["num_framestacks"] = 0
+17 -23
View File
@@ -203,9 +203,13 @@ class ModelV2:
restored = input_dict.copy()
restored["obs"] = restore_original_dimensions(
input_dict["obs"], self.obs_space, self.framework)
if len(input_dict["obs"].shape) > 2:
restored["obs_flat"] = flatten(input_dict["obs"], self.framework)
else:
try:
if len(input_dict["obs"].shape) > 2:
restored["obs_flat"] = flatten(input_dict["obs"],
self.framework)
else:
restored["obs_flat"] = input_dict["obs"]
except AttributeError:
restored["obs_flat"] = input_dict["obs"]
with self.context():
res = self.forward(restored, state or [], seq_lens)
@@ -216,15 +220,6 @@ class ModelV2:
"got {}".format(res))
outputs, state = res
try:
shape = outputs.shape
except AttributeError:
raise ValueError("Output is not a tensor: {}".format(outputs))
else:
if len(shape) != 2 or int(shape[1]) != self.num_outputs:
raise ValueError(
"Expected output shape of [None, {}], got {}".format(
self.num_outputs, shape))
if not isinstance(state, list):
raise ValueError("State output is not a list: {}".format(state))
@@ -418,15 +413,15 @@ def restore_original_dimensions(obs: TensorType,
observation space.
"""
if hasattr(obs_space, "original_space"):
if tensorlib == "tf":
tensorlib = tf
elif tensorlib == "torch":
assert torch is not None
tensorlib = torch
return _unpack_obs(obs, obs_space.original_space, tensorlib=tensorlib)
else:
if tensorlib == "tf":
tensorlib = tf
elif tensorlib == "torch":
assert torch is not None
tensorlib = torch
original_space = getattr(obs_space, "original_space", obs_space)
if original_space is obs_space:
return obs
return _unpack_obs(obs, original_space, tensorlib=tensorlib)
# Cache of preprocessors, for if the user is calling unpack obs often.
@@ -490,7 +485,8 @@ def _unpack_obs(obs: TensorType, space: gym.Space,
tensorlib.reshape(obs_slice, batch_dims + list(p.shape)),
v,
tensorlib=tensorlib)
elif isinstance(space, Repeated):
# Repeated space.
else:
assert isinstance(prep, RepeatedValuesPreprocessor), prep
child_size = prep.child_preprocessor.size
# The list lengths are stored in the first slot of the flat obs.
@@ -503,8 +499,6 @@ def _unpack_obs(obs: TensorType, space: gym.Space,
with_repeat_dim, space.child_space, tensorlib=tensorlib)
return RepeatedValues(
u, lengths=lengths, max_len=prep._obs_space.max_len)
else:
assert False, space
return u
else:
return obs
+156
View File
@@ -0,0 +1,156 @@
from gym.spaces import Box, Discrete, Tuple
import numpy as np
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.models.modelv2 import ModelV2, restore_original_dimensions
from ray.rllib.models.tf.misc import normc_initializer
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.models.utils import get_filter_config
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.tf_ops import one_hot
tf1, tf, tfv = try_import_tf()
# __sphinx_doc_begin__
class ComplexInputNetwork(TFModelV2):
"""TFModelV2 concat'ing CNN outputs to flat input(s), followed by FC(s).
Note: This model should be used for complex (Dict or Tuple) observation
spaces that have one or more image components.
The data flow is as follows:
`obs` (e.g. Tuple[img0, img1, discrete0]) -> `CNN0 + CNN1 + ONE-HOT`
`CNN0 + CNN1 + ONE-HOT` -> concat all flat outputs -> `out`
`out` -> (optional) FC-stack -> `out2`
`out2` -> action (logits) and vaulue heads.
"""
def __init__(self, obs_space, action_space, num_outputs, model_config,
name):
# TODO: (sven) Support Dicts as well.
self.original_space = obs_space.original_space if \
hasattr(obs_space, "original_space") else obs_space
assert isinstance(self.original_space, (Tuple)), \
"`obs_space.original_space` must be Tuple!"
super().__init__(self.original_space, action_space, num_outputs,
model_config, name)
# Build the CNN(s) given obs_space's image components.
self.cnns = {}
self.one_hot = {}
self.flatten = {}
concat_size = 0
for i, component in enumerate(self.original_space):
# Image space.
if len(component.shape) == 3:
config = {
"conv_filters": model_config.get(
"conv_filters", get_filter_config(component.shape)),
"conv_activation": model_config.get("conv_activation"),
"post_fcnet_hiddens": [],
}
cnn = ModelCatalog.get_model_v2(
component,
action_space,
num_outputs=None,
model_config=config,
framework="tf",
name="cnn_{}".format(i))
concat_size += cnn.num_outputs
self.cnns[i] = cnn
# Discrete inputs -> One-hot encode.
elif isinstance(component, Discrete):
self.one_hot[i] = True
concat_size += component.n
# TODO: (sven) Multidiscrete (see e.g. our auto-LSTM wrappers).
# Everything else (1D Box).
else:
self.flatten[i] = int(np.product(component.shape))
concat_size += self.flatten[i]
# Optional post-concat FC-stack.
post_fc_stack_config = {
"fcnet_hiddens": model_config.get("post_fcnet_hiddens", []),
"fcnet_activation": model_config.get("post_fcnet_activation",
"relu")
}
self.post_fc_stack = ModelCatalog.get_model_v2(
Box(float("-inf"),
float("inf"),
shape=(concat_size, ),
dtype=np.float32),
self.action_space,
None,
post_fc_stack_config,
framework="tf",
name="post_fc_stack")
# Actions and value heads.
self.logits_and_value_model = None
self._value_out = None
if num_outputs:
# Action-distribution head.
concat_layer = tf.keras.layers.Input(
(self.post_fc_stack.num_outputs, ))
logits_layer = tf.keras.layers.Dense(
num_outputs,
activation=tf.keras.activations.linear,
name="logits")(concat_layer)
# Create the value branch model.
value_layer = tf.keras.layers.Dense(
1,
name="value_out",
activation=None,
kernel_initializer=normc_initializer(0.01))(concat_layer)
self.logits_and_value_model = tf.keras.models.Model(
concat_layer, [logits_layer, value_layer])
else:
self.num_outputs = self.post_fc_stack.num_outputs
@override(ModelV2)
def forward(self, input_dict, state, seq_lens):
if SampleBatch.OBS in input_dict and "obs_flat" in input_dict:
orig_obs = input_dict[SampleBatch.OBS]
else:
orig_obs = restore_original_dimensions(input_dict[SampleBatch.OBS],
self.obs_space, "tf")
# Push image observations through our CNNs.
outs = []
for i, component in enumerate(orig_obs):
if i in self.cnns:
cnn_out, _ = self.cnns[i]({SampleBatch.OBS: component})
outs.append(cnn_out)
elif i in self.one_hot:
if component.dtype in [tf.int32, tf.int64, tf.uint8]:
outs.append(
one_hot(component, self.original_space.spaces[i]))
else:
outs.append(component)
else:
outs.append(tf.reshape(component, [-1, self.flatten[i]]))
# Concat all outputs and the non-image inputs.
out = tf.concat(outs, axis=1)
# Push through (optional) FC-stack (this may be an empty stack).
out, _ = self.post_fc_stack({SampleBatch.OBS: out}, [], None)
# No logits/value branches.
if not self.logits_and_value_model:
return out, []
# Logits- and value branches.
logits, values = self.logits_and_value_model(out)
self._value_out = tf.reshape(values, [-1])
return logits, []
@override(ModelV2)
def value_function(self):
return self._value_out
# __sphinx_doc_end__
+6 -2
View File
@@ -19,8 +19,12 @@ class FullyConnectedNetwork(TFModelV2):
super(FullyConnectedNetwork, self).__init__(
obs_space, action_space, num_outputs, model_config, name)
activation = get_activation_fn(model_config.get("fcnet_activation"))
hiddens = model_config.get("fcnet_hiddens", [])
hiddens = model_config.get("fcnet_hiddens", []) + \
model_config.get("post_fcnet_hiddens", [])
activation = model_config.get("fcnet_activation")
if not model_config.get("fcnet_hiddens", []):
activation = model_config.get("post_fcnet_activation")
activation = get_activation_fn(activation)
no_final_linear = model_config.get("no_final_linear")
vf_share_layers = model_config.get("vf_share_layers")
free_log_std = model_config.get("free_log_std")
+4 -3
View File
@@ -107,7 +107,8 @@ class TFModelV2(ModelV2):
if isinstance(struct, tf.keras.models.Model):
ret = {}
for var in struct.variables:
key = current_key + "." + re.sub("/", ".", var.name)
name = re.sub("/", ".", var.name)
key = current_key + "." + name
ret[key] = var
return ret
# Other TFModelV2: Include its vars into ours.
@@ -118,7 +119,7 @@ class TFModelV2(ModelV2):
}
# tf.Variable
elif isinstance(struct, tf.Variable):
return {current_key + "." + struct.name: struct}
return {current_key: struct}
# List/Tuple.
elif isinstance(struct, (tuple, list)):
ret = {}
@@ -133,7 +134,7 @@ class TFModelV2(ModelV2):
current_key += "_"
ret = {}
for key, value in struct.items():
sub_vars = TFModelV2._find_sub_modules(current_key + key,
sub_vars = TFModelV2._find_sub_modules(current_key + str(key),
value)
ret.update(sub_vars)
return ret
+70 -15
View File
@@ -13,7 +13,17 @@ tf1, tf, tfv = try_import_tf()
class VisionNetwork(TFModelV2):
"""Generic vision network implemented in ModelV2 API."""
"""Generic vision network implemented in ModelV2 API.
An additional post-conv fully connected stack can be added and configured
via the config keys:
`post_fcnet_hiddens`: Dense layer sizes after the Conv2D stack.
`post_fcnet_activation`: Activation function to use for this FC stack.
Examples:
"""
def __init__(self, obs_space: gym.spaces.Space,
action_space: gym.spaces.Space, num_outputs: int,
@@ -29,6 +39,12 @@ class VisionNetwork(TFModelV2):
filters = self.model_config["conv_filters"]
assert len(filters) > 0,\
"Must provide at least 1 entry in `conv_filters`!"
# Post FC net config.
post_fcnet_hiddens = model_config.get("post_fcnet_hiddens", [])
post_fcnet_activation = get_activation_fn(
model_config.get("post_fcnet_activation"), framework="tf")
no_final_linear = self.model_config.get("no_final_linear")
vf_share_layers = self.model_config.get("vf_share_layers")
self.traj_view_framestacking = False
@@ -62,17 +78,29 @@ class VisionNetwork(TFModelV2):
out_size, kernel, stride = filters[-1]
# No final linear: Last layer is a Conv2D and uses num_outputs.
# No final linear: Last layer has activation function and exits with
# num_outputs nodes (this could be a 1x1 conv or a FC layer, depending
# on `post_fcnet_...` settings).
if no_final_linear and num_outputs:
last_layer = tf.keras.layers.Conv2D(
num_outputs,
out_size if post_fcnet_hiddens else num_outputs,
kernel,
strides=(stride, stride),
activation=activation,
padding="valid",
data_format="channels_last",
name="conv_out")(last_layer)
conv_out = last_layer
# Add (optional) post-fc-stack after last Conv2D layer.
layer_sizes = post_fcnet_hiddens[:-1] + ([num_outputs]
if post_fcnet_hiddens else
[])
for i, out_size in enumerate(layer_sizes):
last_layer = tf.keras.layers.Dense(
out_size,
name="post_fcnet_{}".format(i),
activation=post_fcnet_activation,
kernel_initializer=normc_initializer(1.0))(last_layer)
# Finish network normally (w/o overriding last layer size with
# `num_outputs`), then add another linear one of size `num_outputs`.
else:
@@ -88,29 +116,56 @@ class VisionNetwork(TFModelV2):
# num_outputs defined. Use that to create an exact
# `num_output`-sized (1,1)-Conv2D.
if num_outputs:
conv_out = tf.keras.layers.Conv2D(
num_outputs, [1, 1],
activation=None,
padding="same",
data_format="channels_last",
name="conv_out")(last_layer)
if post_fcnet_hiddens:
last_cnn = last_layer = tf.keras.layers.Conv2D(
post_fcnet_hiddens[0], [1, 1],
activation=post_fcnet_activation,
padding="same",
data_format="channels_last",
name="conv_out")(last_layer)
# Add (optional) post-fc-stack after last Conv2D layer.
for i, out_size in enumerate(post_fcnet_hiddens[1:] +
[num_outputs]):
last_layer = tf.keras.layers.Dense(
out_size,
name="post_fcnet_{}".format(i + 1),
activation=post_fcnet_activation
if i < len(post_fcnet_hiddens) - 1 else None,
kernel_initializer=normc_initializer(1.0))(
last_layer)
else:
last_cnn = last_layer = tf.keras.layers.Conv2D(
num_outputs, [1, 1],
activation=None,
padding="same",
data_format="channels_last",
name="conv_out")(last_layer)
if conv_out.shape[1] != 1 or conv_out.shape[2] != 1:
if last_cnn.shape[1] != 1 or last_cnn.shape[2] != 1:
raise ValueError(
"Given `conv_filters` ({}) do not result in a [B, 1, "
"1, {} (`num_outputs`)] shape (but in {})! Please "
"adjust your Conv2D stack such that the dims 1 and 2 "
"are both 1.".format(self.model_config["conv_filters"],
self.num_outputs,
list(conv_out.shape)))
list(last_cnn.shape)))
# num_outputs not known -> Flatten, then set self.num_outputs
# to the resulting number of nodes.
else:
self.last_layer_is_flattened = True
conv_out = tf.keras.layers.Flatten(
last_layer = tf.keras.layers.Flatten(
data_format="channels_last")(last_layer)
self.num_outputs = conv_out.shape[1]
# Add (optional) post-fc-stack after last Conv2D layer.
for i, out_size in enumerate(post_fcnet_hiddens):
last_layer = tf.keras.layers.Dense(
out_size,
name="post_fcnet_{}".format(i),
activation=post_fcnet_activation,
kernel_initializer=normc_initializer(1.0))(last_layer)
self.num_outputs = last_layer.shape[1]
logits_out = last_layer
# Build the value layers
if vf_share_layers:
@@ -151,7 +206,7 @@ class VisionNetwork(TFModelV2):
value_out = tf.keras.layers.Lambda(
lambda x: tf.squeeze(x, axis=[1, 2]))(last_layer)
self.base_model = tf.keras.Model(inputs, [conv_out, value_out])
self.base_model = tf.keras.Model(inputs, [logits_out, value_out])
# Optional: framestacking obs/new_obs for Atari.
if self.traj_view_framestacking:
+163
View File
@@ -0,0 +1,163 @@
from gym.spaces import Box, Discrete, Tuple
import numpy as np
# TODO (sven): add IMPALA-style option.
# from ray.rllib.examples.models.impala_vision_nets import TorchImpalaVisionNet
from ray.rllib.models.torch.misc import normc_initializer as \
torch_normc_initializer, SlimFC
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.models.utils import get_filter_config
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.torch_ops import one_hot
torch, nn = try_import_torch()
class ComplexInputNetwork(TorchModelV2, nn.Module):
"""TorchModelV2 concat'ing CNN outputs to flat input(s), followed by FC(s).
Note: This model should be used for complex (Dict or Tuple) observation
spaces that have one or more image components.
The data flow is as follows:
`obs` (e.g. Tuple[img0, img1, discrete0]) -> `CNN0 + CNN1 + ONE-HOT`
`CNN0 + CNN1 + ONE-HOT` -> concat all flat outputs -> `out`
`out` -> (optional) FC-stack -> `out2`
`out2` -> action (logits) and vaulue heads.
"""
def __init__(self, obs_space, action_space, num_outputs, model_config,
name):
# TODO: (sven) Support Dicts as well.
self.original_space = obs_space.original_space if \
hasattr(obs_space, "original_space") else obs_space
assert isinstance(self.original_space, (Tuple)), \
"`obs_space.original_space` must be Tuple!"
nn.Module.__init__(self)
TorchModelV2.__init__(self, self.original_space, action_space,
num_outputs, model_config, name)
# Atari type CNNs or IMPALA type CNNs (with residual layers)?
# self.cnn_type = self.model_config["custom_model_config"].get(
# "conv_type", "atari")
# Build the CNN(s) given obs_space's image components.
self.cnns = {}
self.one_hot = {}
self.flatten = {}
concat_size = 0
for i, component in enumerate(self.original_space):
# Image space.
if len(component.shape) == 3:
config = {
"conv_filters": model_config.get(
"conv_filters", get_filter_config(component.shape)),
"conv_activation": model_config.get("conv_activation"),
"post_fcnet_hiddens": [],
}
# if self.cnn_type == "atari":
cnn = ModelCatalog.get_model_v2(
component,
action_space,
num_outputs=None,
model_config=config,
framework="torch",
name="cnn_{}".format(i))
# TODO (sven): add IMPALA-style option.
# else:
# cnn = TorchImpalaVisionNet(
# component,
# action_space,
# num_outputs=None,
# model_config=config,
# name="cnn_{}".format(i))
concat_size += cnn.num_outputs
self.cnns[i] = cnn
self.add_module("cnn_{}".format(i), cnn)
# Discrete inputs -> One-hot encode.
elif isinstance(component, Discrete):
self.one_hot[i] = True
concat_size += component.n
# TODO: (sven) Multidiscrete (see e.g. our auto-LSTM wrappers).
# Everything else (1D Box).
else:
self.flatten[i] = int(np.product(component.shape))
concat_size += self.flatten[i]
# Optional post-concat FC-stack.
post_fc_stack_config = {
"fcnet_hiddens": model_config.get("post_fcnet_hiddens", []),
"fcnet_activation": model_config.get("post_fcnet_activation",
"relu")
}
self.post_fc_stack = ModelCatalog.get_model_v2(
Box(float("-inf"),
float("inf"),
shape=(concat_size, ),
dtype=np.float32),
self.action_space,
None,
post_fc_stack_config,
framework="torch",
name="post_fc_stack")
# Actions and value heads.
self.logits_layer = None
self.value_layer = None
self._value_out = None
if num_outputs:
# Action-distribution head.
self.logits_layer = SlimFC(
in_size=self.post_fc_stack.num_outputs,
out_size=num_outputs,
activation_fn=None,
)
# Create the value branch model.
self.value_layer = SlimFC(
in_size=self.post_fc_stack.num_outputs,
out_size=1,
activation_fn=None,
initializer=torch_normc_initializer(0.01))
else:
self.num_outputs = concat_size
@override(ModelV2)
def forward(self, input_dict, state, seq_lens):
# Push image observations through our CNNs.
outs = []
for i, component in enumerate(input_dict["obs"]):
if i in self.cnns:
cnn_out, _ = self.cnns[i]({"obs": component})
outs.append(cnn_out)
elif i in self.one_hot:
if component.dtype in [torch.int32, torch.int64, torch.uint8]:
outs.append(
one_hot(component, self.original_space.spaces[i]))
else:
outs.append(component)
else:
outs.append(torch.reshape(component, [-1, self.flatten[i]]))
# Concat all outputs and the non-image inputs.
out = torch.cat(outs, dim=1)
# Push through (optional) FC-stack (this may be an empty stack).
out, _ = self.post_fc_stack({"obs": out}, [], None)
# No logits/value branches.
if self.logits_layer is None:
return out, []
# Logits- and value branches.
logits, values = self.logits_layer(out), self.value_layer(out)
self._value_out = torch.reshape(values, [-1])
return logits, []
@override(ModelV2)
def value_function(self):
return self._value_out
+4 -1
View File
@@ -24,8 +24,11 @@ class FullyConnectedNetwork(TorchModelV2, nn.Module):
model_config, name)
nn.Module.__init__(self)
hiddens = model_config.get("fcnet_hiddens", []) + \
model_config.get("post_fcnet_hiddens", [])
activation = model_config.get("fcnet_activation")
hiddens = model_config.get("fcnet_hiddens", [])
if not model_config.get("fcnet_hiddens", []):
activation = model_config.get("post_fcnet_activation")
no_final_linear = model_config.get("no_final_linear")
self.vf_share_layers = model_config.get("vf_share_layers")
self.free_log_std = model_config.get("free_log_std")
+64 -20
View File
@@ -5,7 +5,7 @@ import gym
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.models.torch.misc import normc_initializer, same_padding, \
SlimConv2d, SlimFC
from ray.rllib.models.utils import get_filter_config
from ray.rllib.models.utils import get_activation_fn, get_filter_config
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.view_requirement import ViewRequirement
from ray.rllib.utils.annotations import override
@@ -33,6 +33,12 @@ class VisionNetwork(TorchModelV2, nn.Module):
filters = self.model_config["conv_filters"]
assert len(filters) > 0,\
"Must provide at least 1 entry in `conv_filters`!"
# Post FC net config.
post_fcnet_hiddens = model_config.get("post_fcnet_hiddens", [])
post_fcnet_activation = get_activation_fn(
model_config.get("post_fcnet_activation"), framework="torch")
no_final_linear = self.model_config.get("no_final_linear")
vf_share_layers = self.model_config.get("vf_share_layers")
@@ -68,17 +74,33 @@ class VisionNetwork(TorchModelV2, nn.Module):
out_channels, kernel, stride = filters[-1]
# No final linear: Last layer is a Conv2D and uses num_outputs.
# No final linear: Last layer has activation function and exits with
# num_outputs nodes (this could be a 1x1 conv or a FC layer, depending
# on `post_fcnet_...` settings).
if no_final_linear and num_outputs:
out_channels = out_channels if post_fcnet_hiddens else num_outputs
layers.append(
SlimConv2d(
in_channels,
num_outputs,
out_channels,
kernel,
stride,
None, # padding=valid
activation_fn=activation))
out_channels = num_outputs
# Add (optional) post-fc-stack after last Conv2D layer.
layer_sizes = post_fcnet_hiddens[:-1] + ([num_outputs]
if post_fcnet_hiddens else
[])
for i, out_size in enumerate(layer_sizes):
layers.append(
SlimFC(
in_size=out_channels,
out_size=out_size,
activation_fn=post_fcnet_activation,
initializer=normc_initializer(1.0)))
out_channels = out_size
# Finish network normally (w/o overriding last layer size with
# `num_outputs`), then add another linear one of size `num_outputs`.
else:
@@ -99,12 +121,31 @@ class VisionNetwork(TorchModelV2, nn.Module):
np.ceil((in_size[1] - kernel[1]) / stride)
]
padding, _ = same_padding(in_size, [1, 1], [1, 1])
self._logits = SlimConv2d(
out_channels,
num_outputs, [1, 1],
1,
padding,
activation_fn=None)
if post_fcnet_hiddens:
layers.append(nn.Flatten())
in_size = out_channels
# Add (optional) post-fc-stack after last Conv2D layer.
for i, out_size in enumerate(post_fcnet_hiddens +
[num_outputs]):
layers.append(
SlimFC(
in_size=in_size,
out_size=out_size,
activation_fn=post_fcnet_activation
if i < len(post_fcnet_hiddens) - 1 else None,
initializer=normc_initializer(1.0)))
in_size = out_size
# Last layer is logits layer.
self._logits = layers.pop()
else:
self._logits = SlimConv2d(
out_channels,
num_outputs, [1, 1],
1,
padding,
activation_fn=None)
# num_outputs not known -> Flatten, then set self.num_outputs
# to the resulting number of nodes.
else:
@@ -196,16 +237,19 @@ class VisionNetwork(TorchModelV2, nn.Module):
if not self.last_layer_is_flattened:
if self._logits:
conv_out = self._logits(conv_out)
if conv_out.shape[2] != 1 or conv_out.shape[3] != 1:
raise ValueError(
"Given `conv_filters` ({}) do not result in a [B, {} "
"(`num_outputs`), 1, 1] shape (but in {})! Please adjust "
"your Conv2D stack such that the last 2 dims are both "
"1.".format(self.model_config["conv_filters"],
self.num_outputs, list(conv_out.shape)))
logits = conv_out.squeeze(3)
logits = logits.squeeze(2)
if len(conv_out.shape) == 4:
if conv_out.shape[2] != 1 or conv_out.shape[3] != 1:
raise ValueError(
"Given `conv_filters` ({}) do not result in a [B, {} "
"(`num_outputs`), 1, 1] shape (but in {})! Please "
"adjust your Conv2D stack such that the last 2 dims "
"are both 1.".format(self.model_config["conv_filters"],
self.num_outputs,
list(conv_out.shape)))
logits = conv_out.squeeze(3)
logits = logits.squeeze(2)
else:
logits = conv_out
return logits, state
else:
return conv_out, state
@@ -177,8 +177,8 @@ class TestComputeLogLikelihood(unittest.TestCase):
config,
prev_a,
continuous=True,
layer_key=("sequential/action", (2, 4),
("action_model.action_0.", "action_model.action_out.")),
layer_key=("fc", (0, 2), ("action_model._hidden_layers.0.",
"action_model._logits.")),
logp_func=logp_func)
def test_sac_discr(self):
@@ -188,12 +188,7 @@ class TestComputeLogLikelihood(unittest.TestCase):
config["policy_model"]["fcnet_activation"] = "linear"
prev_a = np.array(0)
do_test_log_likelihood(
sac.SACTrainer,
config,
prev_a,
layer_key=("sequential/action", (0, 2),
("action_model.action_0.", "action_model.action_out.")))
do_test_log_likelihood(sac.SACTrainer, config, prev_a)
if __name__ == "__main__":
+5 -1
View File
@@ -37,6 +37,10 @@ parser.add_argument(
"--yaml-dir",
type=str,
help="The directory in which to find all yamls to test.")
parser.add_argument(
"--local-mode",
action="store_true",
help="Run ray in local mode for easier debugging.")
# Obsoleted arg, use --framework=torch instead.
parser.add_argument(
@@ -92,7 +96,7 @@ if __name__ == "__main__":
passed = False
for i in range(3):
try:
ray.init(num_cpus=5)
ray.init(num_cpus=5, local_mode=args.local_mode)
trials = run_experiments(experiments, resume=False, verbose=2)
finally:
ray.shutdown()
@@ -333,7 +333,7 @@ class NestedSpacesTest(unittest.TestCase):
def test_invalid_model2(self):
ModelCatalog.register_custom_model("invalid2", InvalidModel2)
self.assertRaisesRegexp(
ValueError, "Expected output shape of",
ValueError, "State output is not a list",
lambda: PGTrainer(
env="CartPole-v0", config={
"model": {
+1 -3
View File
@@ -15,7 +15,7 @@ from ray.rllib.utils.test_utils import framework_iterator
ACTION_SPACES_TO_TEST = {
"discrete": Discrete(5),
"vector": Box(-1.0, 1.0, (5, ), dtype=np.float32),
# "vector2": Box(-1.0, 1.0, (5, 5), dtype=np.float32),
"vector2": Box(-1.0, 1.0, (5, 5), dtype=np.float32),
"multidiscrete": MultiDiscrete([1, 2, 3, 4]),
"tuple": Tuple(
[Discrete(2),
@@ -63,8 +63,6 @@ def check_support(alg, config, train=True, check_bounds=False, tfe=False):
p_done=1.0,
check_action_bounds=check_bounds)))
stat = "ok"
if alg == "SAC":
config["use_state_preprocessor"] = o_name in ["atari", "image"]
try:
a = get_agent_class(alg)(config=config, env=RandomEnv)
-2
View File
@@ -14,8 +14,6 @@ atari-sac-tf-and-torch:
framework:
grid_search: [tf, torch]
gamma: 0.99
# state-preprocessor=Our default Atari Conv2D-net.
use_state_preprocessor: true
Q_model:
hidden_activation: relu
hidden_layer_sizes: [512]
@@ -11,8 +11,6 @@ mspacman-sac-tf:
# Works for both torch and tf.
framework: tf
gamma: 0.99
# state-preprocessor=Our default Atari Conv2D-net.
use_state_preprocessor: true
Q_model:
fcnet_hiddens: [512]
fcnet_activation: relu
+1 -4
View File
@@ -301,13 +301,10 @@ def check_compute_single_action(trainer,
assert worker_set
if isinstance(worker_set, list):
obs_space = trainer.get_policy().observation_space
try:
obs_space = obs_space.original_space
except AttributeError:
pass
else:
obs_space = worker_set.local_worker().for_policy(
lambda p: p.observation_space)
obs_space = getattr(obs_space, "original_space", obs_space)
else:
method_to_test = pol.compute_single_action
obs_space = pol.observation_space
+1 -1
View File
@@ -22,6 +22,6 @@ def with_lock(func: Callable):
except AttributeError:
raise AttributeError(
"Object {} must have a `self._lock` property (assigned to a "
"threading.Lock() object in its constructor)!".format(self))
"threading.RLock() object in its constructor)!".format(self))
return wrapper