mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 19:16:19 +08:00
[RLlib] Allow SAC to use custom models as Q- or policy nets and deprecate "state-preprocessor" for image spaces. (#13522)
This commit is contained in:
@@ -453,7 +453,7 @@ with the remaining non-image (flat) inputs (the 1D Box and discrete/one-hot comp
|
||||
|
||||
Take a look at this model example that does exactly that:
|
||||
|
||||
.. literalinclude:: ../../rllib/examples/models/cnn_plus_fc_concat_model.py
|
||||
.. literalinclude:: ../../rllib/models/tf/complex_input_net.py
|
||||
:language: python
|
||||
:start-after: __sphinx_doc_begin__
|
||||
:end-before: __sphinx_doc_end__
|
||||
|
||||
+31
-10
@@ -16,6 +16,7 @@ from ray.rllib.agents.trainer import with_common_config
|
||||
from ray.rllib.agents.dqn.dqn import GenericOffPolicyTrainer
|
||||
from ray.rllib.agents.sac.sac_tf_policy import SACTFPolicy
|
||||
from ray.rllib.policy.policy import Policy
|
||||
from ray.rllib.utils.deprecation import DEPRECATED_VALUE, deprecation_warning
|
||||
from ray.rllib.utils.typing import TrainerConfigDict
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -39,16 +40,37 @@ DEFAULT_CONFIG = with_common_config({
|
||||
# Use e.g. a conv2D state preprocessing network before concatenating the
|
||||
# resulting (feature) vector with the action input for the input to
|
||||
# the Q-networks.
|
||||
"use_state_preprocessor": False,
|
||||
# Model options for the Q network(s).
|
||||
"use_state_preprocessor": DEPRECATED_VALUE,
|
||||
# Model options for the Q network(s). These will override MODEL_DEFAULTS.
|
||||
# The `Q_model` dict is treated just as the top-level `model` dict in
|
||||
# setting up the Q-network(s) (2 if twin_q=True).
|
||||
# That means, you can do for different observation spaces:
|
||||
# obs=Box(1D) -> Tuple(Box(1D) + Action) -> concat -> post_fcnet
|
||||
# obs=Box(3D) -> Tuple(Box(3D) + Action) -> vision-net -> concat w/ action
|
||||
# -> post_fcnet
|
||||
# obs=Tuple(Box(1D), Box(3D)) -> Tuple(Box(1D), Box(3D), Action)
|
||||
# -> vision-net -> concat w/ Box(1D) and action -> post_fcnet
|
||||
# You can also have SAC use your custom_model as Q-model(s), by simply
|
||||
# specifying the `custom_model` sub-key in below dict (just like you would
|
||||
# do in the top-level `model` dict).
|
||||
"Q_model": {
|
||||
"fcnet_activation": "relu",
|
||||
"fcnet_hiddens": [256, 256],
|
||||
"fcnet_activation": "relu",
|
||||
"post_fcnet_hiddens": [],
|
||||
"post_fcnet_activation": None,
|
||||
"custom_model": None, # Use this to define custom Q-model(s).
|
||||
"custom_model_config": {},
|
||||
},
|
||||
# Model options for the policy function.
|
||||
# Model options for the policy function (see `Q_model` above for details).
|
||||
# The difference to `Q_model` above is that no action concat'ing is
|
||||
# performed before the post_fcnet stack.
|
||||
"policy_model": {
|
||||
"fcnet_activation": "relu",
|
||||
"fcnet_hiddens": [256, 256],
|
||||
"fcnet_activation": "relu",
|
||||
"post_fcnet_hiddens": [],
|
||||
"post_fcnet_activation": None,
|
||||
"custom_model": None, # Use this to define a custom policy model.
|
||||
"custom_model_config": {},
|
||||
},
|
||||
# Unsquash actions to the upper and lower bounds of env's action space.
|
||||
# Ignored for discrete action spaces.
|
||||
@@ -145,11 +167,10 @@ def validate_config(config: TrainerConfigDict) -> None:
|
||||
Raises:
|
||||
ValueError: In case something is wrong with the config.
|
||||
"""
|
||||
if config["model"].get("custom_model"):
|
||||
logger.warning(
|
||||
"Setting use_state_preprocessor=True since a custom model "
|
||||
"was specified.")
|
||||
config["use_state_preprocessor"] = True
|
||||
if config["use_state_preprocessor"] != DEPRECATED_VALUE:
|
||||
deprecation_warning(
|
||||
old="config['use_state_preprocessor']", error=False)
|
||||
config["use_state_preprocessor"] = DEPRECATED_VALUE
|
||||
|
||||
if config["grad_clip"] is not None and config["grad_clip"] <= 0.0:
|
||||
raise ValueError("`grad_clip` value must be > 0.0!")
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
import gym
|
||||
from gym.spaces import Box, Discrete
|
||||
import numpy as np
|
||||
from typing import Optional, Tuple
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from ray.rllib.models.catalog import ModelCatalog
|
||||
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
|
||||
from ray.rllib.utils import force_list
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.framework import try_import_tf
|
||||
from ray.rllib.utils.spaces.simplex import Simplex
|
||||
from ray.rllib.utils.typing import ModelConfigDict, TensorType
|
||||
@@ -14,14 +17,21 @@ tf1, tf, tfv = try_import_tf()
|
||||
class SACTFModel(TFModelV2):
|
||||
"""Extension of the standard TFModelV2 for SAC.
|
||||
|
||||
Instances of this Model get created via wrapping this class around another
|
||||
default- or custom model (inside
|
||||
rllib/agents/sac/sac_tf_policy.py::build_sac_model). Doing so simply adds
|
||||
this class' methods (`get_q_values`, etc..) to the wrapped model, such that
|
||||
the wrapped model can be used by the SAC algorithm.
|
||||
To customize, do one of the following:
|
||||
- sub-class SACTFModel and override one or more of its methods.
|
||||
- Use SAC's `Q_model` and `policy_model` keys to tweak the default model
|
||||
behaviors (e.g. fcnet_hiddens, conv_filters, etc..).
|
||||
- Use SAC's `Q_model->custom_model` and `policy_model->custom_model` keys
|
||||
to specify your own custom Q-model(s) and policy-models, which will be
|
||||
created within this SACTFModel (see `build_policy_model` and
|
||||
`build_q_model`).
|
||||
|
||||
Note: It is not recommended to override the `forward` method for SAC. This
|
||||
would lead to shared weights (between policy and Q-nets), which will then
|
||||
not be optimized by either of the critic- or actor-optimizers!
|
||||
|
||||
Data flow:
|
||||
`obs` -> forward() -> `model_out`
|
||||
`obs` -> forward() (should stay a noop method!) -> `model_out`
|
||||
`model_out` -> get_policy_output() -> pi(actions|obs)
|
||||
`model_out`, `actions` -> get_q_values() -> Q(s, a)
|
||||
`model_out`, `actions` -> get_twin_q_values() -> Q_twin(s, a)
|
||||
@@ -33,20 +43,18 @@ class SACTFModel(TFModelV2):
|
||||
num_outputs: Optional[int],
|
||||
model_config: ModelConfigDict,
|
||||
name: str,
|
||||
actor_hidden_activation: str = "relu",
|
||||
actor_hiddens: Tuple[int] = (256, 256),
|
||||
critic_hidden_activation: str = "relu",
|
||||
critic_hiddens: Tuple[int] = (256, 256),
|
||||
policy_model_config: ModelConfigDict = None,
|
||||
q_model_config: ModelConfigDict = None,
|
||||
twin_q: bool = False,
|
||||
initial_alpha: float = 1.0,
|
||||
target_entropy: Optional[float] = None):
|
||||
"""Initialize a SACTFModel instance.
|
||||
|
||||
Args:
|
||||
actor_hidden_activation (str): Activation for the actor network.
|
||||
actor_hiddens (list): Hidden layers sizes for the actor network.
|
||||
critic_hidden_activation (str): Activation for the critic network.
|
||||
critic_hiddens (list): Hidden layers sizes for the critic network.
|
||||
policy_model_config (ModelConfigDict): The config dict for the
|
||||
policy network.
|
||||
q_model_config (ModelConfigDict): The config dict for the
|
||||
Q-network(s) (2 if twin_q=True).
|
||||
twin_q (bool): Build twin Q networks (Q-net and target) for more
|
||||
stable Q-learning.
|
||||
initial_alpha (float): The initial value for the to-be-optimized
|
||||
@@ -77,54 +85,15 @@ class SACTFModel(TFModelV2):
|
||||
action_outs = self.action_dim
|
||||
q_outs = 1
|
||||
|
||||
self.model_out = tf.keras.layers.Input(
|
||||
shape=(self.num_outputs, ), name="model_out")
|
||||
self.action_model = tf.keras.Sequential([
|
||||
tf.keras.layers.Dense(
|
||||
units=hidden,
|
||||
activation=getattr(tf.nn, actor_hidden_activation, None),
|
||||
name="action_{}".format(i + 1))
|
||||
for i, hidden in enumerate(actor_hiddens)
|
||||
] + [
|
||||
tf.keras.layers.Dense(
|
||||
units=action_outs, activation=None, name="action_out")
|
||||
])
|
||||
self.shift_and_log_scale_diag = self.action_model(self.model_out)
|
||||
|
||||
self.actions_input = None
|
||||
if not self.discrete:
|
||||
self.actions_input = tf.keras.layers.Input(
|
||||
shape=(self.action_dim, ), name="actions")
|
||||
|
||||
def build_q_net(name, observations, actions):
|
||||
# For continuous actions: Feed obs and actions (concatenated)
|
||||
# through the NN. For discrete actions, only obs.
|
||||
q_net = tf.keras.Sequential(([
|
||||
tf.keras.layers.Concatenate(axis=1),
|
||||
] if not self.discrete else []) + [
|
||||
tf.keras.layers.Dense(
|
||||
units=units,
|
||||
activation=getattr(tf.nn, critic_hidden_activation, None),
|
||||
name="{}_hidden_{}".format(name, i))
|
||||
for i, units in enumerate(critic_hiddens)
|
||||
] + [
|
||||
tf.keras.layers.Dense(
|
||||
units=q_outs, activation=None, name="{}_out".format(name))
|
||||
])
|
||||
|
||||
# TODO(hartikainen): Remove the unnecessary Model calls here
|
||||
if self.discrete:
|
||||
q_net = tf.keras.Model(observations, q_net(observations))
|
||||
else:
|
||||
q_net = tf.keras.Model([observations, actions],
|
||||
q_net([observations, actions]))
|
||||
return q_net
|
||||
|
||||
self.q_net = build_q_net("q", self.model_out, self.actions_input)
|
||||
self.action_model = self.build_policy_model(
|
||||
self.obs_space, action_outs, policy_model_config, "policy_model")
|
||||
|
||||
self.q_net = self.build_q_model(self.obs_space, self.action_space,
|
||||
q_outs, q_model_config, "q")
|
||||
if twin_q:
|
||||
self.twin_q_net = build_q_net("twin_q", self.model_out,
|
||||
self.actions_input)
|
||||
self.twin_q_net = self.build_q_model(self.obs_space,
|
||||
self.action_space, q_outs,
|
||||
q_model_config, "twin_q")
|
||||
else:
|
||||
self.twin_q_net = None
|
||||
|
||||
@@ -143,6 +112,80 @@ class SACTFModel(TFModelV2):
|
||||
target_entropy = -np.prod(action_space.shape)
|
||||
self.target_entropy = target_entropy
|
||||
|
||||
@override(TFModelV2)
|
||||
def forward(self, input_dict: Dict[str, TensorType],
|
||||
state: List[TensorType],
|
||||
seq_lens: TensorType) -> (TensorType, List[TensorType]):
|
||||
"""The common (Q-net and policy-net) forward pass.
|
||||
|
||||
NOTE: It is not(!) recommended to override this method as it would
|
||||
introduce a shared pre-network, which would be updated by both
|
||||
actor- and critic optimizers.
|
||||
"""
|
||||
return input_dict["obs"], state
|
||||
|
||||
def build_policy_model(self, obs_space, num_outputs, policy_model_config,
|
||||
name):
|
||||
"""Builds the policy model used by this SAC.
|
||||
|
||||
Override this method in a sub-class of SACTFModel to implement your
|
||||
own policy net. Alternatively, simply set `custom_model` within the
|
||||
top level SAC `policy_model` config key to make this default
|
||||
implementation of `build_policy_model` use your custom policy network.
|
||||
|
||||
Returns:
|
||||
TFModelV2: The TFModelV2 policy sub-model.
|
||||
"""
|
||||
model = ModelCatalog.get_model_v2(
|
||||
obs_space,
|
||||
self.action_space,
|
||||
num_outputs,
|
||||
policy_model_config,
|
||||
framework="tf",
|
||||
name=name)
|
||||
return model
|
||||
|
||||
def build_q_model(self, obs_space, action_space, num_outputs,
|
||||
q_model_config, name):
|
||||
"""Builds one of the (twin) Q-nets used by this SAC.
|
||||
|
||||
Override this method in a sub-class of SACTFModel to implement your
|
||||
own Q-nets. Alternatively, simply set `custom_model` within the
|
||||
top level SAC `Q_model` config key to make this default implementation
|
||||
of `build_q_model` use your custom Q-nets.
|
||||
|
||||
Returns:
|
||||
TFModelV2: The TFModelV2 Q-net sub-model.
|
||||
"""
|
||||
self.concat_obs_and_actions = False
|
||||
if self.discrete:
|
||||
input_space = obs_space
|
||||
else:
|
||||
orig_space = getattr(obs_space, "original_space", obs_space)
|
||||
if isinstance(orig_space, Box) and len(orig_space.shape) == 1:
|
||||
input_space = Box(
|
||||
float("-inf"),
|
||||
float("inf"),
|
||||
shape=(orig_space.shape[0] + action_space.shape[0], ))
|
||||
self.concat_obs_and_actions = True
|
||||
else:
|
||||
if isinstance(orig_space, gym.spaces.Tuple):
|
||||
spaces = orig_space.spaces
|
||||
elif isinstance(orig_space, gym.spaces.Dict):
|
||||
spaces = list(orig_space.spaces.values())
|
||||
else:
|
||||
spaces = [obs_space]
|
||||
input_space = gym.spaces.Tuple(spaces + [action_space])
|
||||
|
||||
model = ModelCatalog.get_model_v2(
|
||||
input_space,
|
||||
action_space,
|
||||
num_outputs,
|
||||
q_model_config,
|
||||
framework="tf",
|
||||
name=name)
|
||||
return model
|
||||
|
||||
def get_q_values(self,
|
||||
model_out: TensorType,
|
||||
actions: Optional[TensorType] = None) -> TensorType:
|
||||
@@ -161,12 +204,7 @@ class SACTFModel(TFModelV2):
|
||||
Returns:
|
||||
TensorType: Q-values tensor of shape [BATCH_SIZE, 1].
|
||||
"""
|
||||
# Continuous case -> concat actions to model_out.
|
||||
if actions is not None:
|
||||
return self.q_net([model_out, actions])
|
||||
# Discrete case -> return q-vals for all actions.
|
||||
else:
|
||||
return self.q_net(model_out)
|
||||
return self._get_q_value(model_out, actions, self.q_net)
|
||||
|
||||
def get_twin_q_values(self,
|
||||
model_out: TensorType,
|
||||
@@ -185,12 +223,32 @@ class SACTFModel(TFModelV2):
|
||||
Returns:
|
||||
TensorType: Q-values tensor of shape [BATCH_SIZE, 1].
|
||||
"""
|
||||
return self._get_q_value(model_out, actions, self.twin_q_net)
|
||||
|
||||
def _get_q_value(self, model_out, actions, net):
|
||||
# Model outs may come as original Tuple/Dict observations, concat them
|
||||
# here if this is the case.
|
||||
if isinstance(net.obs_space, Box):
|
||||
if isinstance(model_out, (list, tuple)):
|
||||
model_out = tf.concat(model_out, axis=-1)
|
||||
elif isinstance(model_out, dict):
|
||||
model_out = list(model_out.values())
|
||||
|
||||
# Continuous case -> concat actions to model_out.
|
||||
if actions is not None:
|
||||
return self.twin_q_net([model_out, actions])
|
||||
if self.concat_obs_and_actions:
|
||||
input_dict = {"obs": tf.concat([model_out, actions], axis=-1)}
|
||||
else:
|
||||
input_dict = {"obs": force_list(model_out) + [actions]}
|
||||
# Discrete case -> return q-vals for all actions.
|
||||
else:
|
||||
return self.twin_q_net(model_out)
|
||||
input_dict = {"obs": model_out}
|
||||
# Switch on training mode (when getting Q-values, we are usually in
|
||||
# training).
|
||||
input_dict["is_training"] = True
|
||||
|
||||
out, _ = net(input_dict, [], None)
|
||||
return out
|
||||
|
||||
def get_policy_output(self, model_out: TensorType) -> TensorType:
|
||||
"""Returns policy outputs, given the output of self.__call__().
|
||||
@@ -207,15 +265,23 @@ class SACTFModel(TFModelV2):
|
||||
Returns:
|
||||
TensorType: Distribution inputs for sampling actions.
|
||||
"""
|
||||
return self.action_model(model_out)
|
||||
# Model outs may come as original Tuple observations, concat them
|
||||
# here if this is the case.
|
||||
if isinstance(self.action_model.obs_space, Box):
|
||||
if isinstance(model_out, (list, tuple)):
|
||||
model_out = tf.concat(model_out, axis=-1)
|
||||
elif isinstance(model_out, dict):
|
||||
model_out = tf.concat(list(model_out.values()), axis=-1)
|
||||
out, _ = self.action_model({"obs": model_out}, [], None)
|
||||
return out
|
||||
|
||||
def policy_variables(self):
|
||||
"""Return the list of variables for the policy net."""
|
||||
|
||||
return list(self.action_model.variables)
|
||||
return self.action_model.variables()
|
||||
|
||||
def q_variables(self):
|
||||
"""Return the list of variables for Q / twin Q nets."""
|
||||
|
||||
return self.q_net.variables + (self.twin_q_net.variables
|
||||
if self.twin_q_net else [])
|
||||
return self.q_net.variables() + (self.twin_q_net.variables()
|
||||
if self.twin_q_net else [])
|
||||
|
||||
@@ -6,6 +6,7 @@ import gym
|
||||
from gym.spaces import Box, Discrete
|
||||
from functools import partial
|
||||
import logging
|
||||
import numpy as np
|
||||
from typing import Dict, List, Optional, Tuple, Type, Union
|
||||
|
||||
import ray
|
||||
@@ -17,7 +18,7 @@ from ray.rllib.agents.dqn.dqn_tf_policy import postprocess_nstep_and_prio, \
|
||||
from ray.rllib.agents.sac.sac_tf_model import SACTFModel
|
||||
from ray.rllib.agents.sac.sac_torch_model import SACTorchModel
|
||||
from ray.rllib.evaluation.episode import MultiAgentEpisode
|
||||
from ray.rllib.models import ModelCatalog
|
||||
from ray.rllib.models import ModelCatalog, MODEL_DEFAULTS
|
||||
from ray.rllib.models.modelv2 import ModelV2
|
||||
from ray.rllib.models.tf.tf_action_dist import Beta, Categorical, \
|
||||
DiagGaussian, Dirichlet, SquashedGaussian, TFActionDistribution
|
||||
@@ -55,40 +56,35 @@ def build_sac_model(policy: Policy, obs_space: gym.spaces.Space,
|
||||
`policy.target_model`.
|
||||
"""
|
||||
# With separate state-preprocessor (before obs+action concat).
|
||||
if config["use_state_preprocessor"]:
|
||||
num_outputs = 256 # Flatten last Conv2D to this many nodes.
|
||||
# No separate state-preprocessor: concat obs+actions right away.
|
||||
else:
|
||||
num_outputs = 0
|
||||
# No state preprocessor: fcnet_hiddens should be empty.
|
||||
if config["model"]["fcnet_hiddens"]:
|
||||
logger.warning(
|
||||
"When not using a state-preprocessor with SAC, `fcnet_hiddens`"
|
||||
" will be set to an empty list! Any hidden layer sizes are "
|
||||
"defined via `policy_model.fcnet_hiddens` and "
|
||||
"`Q_model.fcnet_hiddens`.")
|
||||
config["model"]["fcnet_hiddens"] = []
|
||||
num_outputs = int(np.product(obs_space.shape))
|
||||
|
||||
# Force-ignore any additionally provided hidden layer sizes.
|
||||
# Everything should be configured using SAC's "Q_model" and "policy_model"
|
||||
# settings.
|
||||
policy_model_config = MODEL_DEFAULTS.copy()
|
||||
policy_model_config.update(config["policy_model"])
|
||||
q_model_config = MODEL_DEFAULTS.copy()
|
||||
q_model_config.update(config["Q_model"])
|
||||
|
||||
default_model_cls = SACTorchModel if config["framework"] == "torch" \
|
||||
else SACTFModel
|
||||
|
||||
model = ModelCatalog.get_model_v2(
|
||||
obs_space=obs_space,
|
||||
action_space=action_space,
|
||||
num_outputs=num_outputs,
|
||||
model_config=config["model"],
|
||||
framework=config["framework"],
|
||||
model_interface=SACTorchModel
|
||||
if config["framework"] == "torch" else SACTFModel,
|
||||
default_model=default_model_cls,
|
||||
name="sac_model",
|
||||
actor_hidden_activation=config["policy_model"]["fcnet_activation"],
|
||||
actor_hiddens=config["policy_model"]["fcnet_hiddens"],
|
||||
critic_hidden_activation=config["Q_model"]["fcnet_activation"],
|
||||
critic_hiddens=config["Q_model"]["fcnet_hiddens"],
|
||||
policy_model_config=policy_model_config,
|
||||
q_model_config=q_model_config,
|
||||
twin_q=config["twin_q"],
|
||||
initial_alpha=config["initial_alpha"],
|
||||
target_entropy=config["target_entropy"])
|
||||
|
||||
assert isinstance(model, default_model_cls)
|
||||
|
||||
# Create an exact copy of the model and store it in `policy.target_model`.
|
||||
# This will be used for tau-synched Q-target models that run behind the
|
||||
# actual Q-networks and are used for target q-value calculations in the
|
||||
@@ -99,17 +95,16 @@ def build_sac_model(policy: Policy, obs_space: gym.spaces.Space,
|
||||
num_outputs=num_outputs,
|
||||
model_config=config["model"],
|
||||
framework=config["framework"],
|
||||
model_interface=SACTorchModel
|
||||
if config["framework"] == "torch" else SACTFModel,
|
||||
default_model=default_model_cls,
|
||||
name="target_sac_model",
|
||||
actor_hidden_activation=config["policy_model"]["fcnet_activation"],
|
||||
actor_hiddens=config["policy_model"]["fcnet_hiddens"],
|
||||
critic_hidden_activation=config["Q_model"]["fcnet_activation"],
|
||||
critic_hiddens=config["Q_model"]["fcnet_hiddens"],
|
||||
policy_model_config=policy_model_config,
|
||||
q_model_config=q_model_config,
|
||||
twin_q=config["twin_q"],
|
||||
initial_alpha=config["initial_alpha"],
|
||||
target_entropy=config["target_entropy"])
|
||||
|
||||
assert isinstance(policy.target_model, default_model_cls)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
@@ -198,14 +193,14 @@ def get_distribution_inputs_and_class(
|
||||
dist inputs, dist class, and a list of internal state outputs
|
||||
(in the RNN case).
|
||||
"""
|
||||
# Get base-model output (w/o the SAC specific parts of the network).
|
||||
model_out, state_out = model({
|
||||
# Get base-model (forward) output (this should be a noop call).
|
||||
forward_out, state_out = model({
|
||||
"obs": obs_batch,
|
||||
"is_training": policy._get_is_training_placeholder(),
|
||||
}, [], None)
|
||||
# Use the base output to get the policy outputs from the SAC model's
|
||||
# policy components.
|
||||
distribution_inputs = model.get_policy_output(model_out)
|
||||
distribution_inputs = model.get_policy_output(forward_out)
|
||||
# Get a distribution class to be used with the just calculated dist-inputs.
|
||||
action_dist_class = _get_dist_class(policy.config, policy.action_space)
|
||||
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
import gym
|
||||
from gym.spaces import Box, Discrete
|
||||
import numpy as np
|
||||
from typing import Optional, Tuple
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from ray.rllib.models.torch.misc import SlimFC
|
||||
from ray.rllib.models.catalog import ModelCatalog
|
||||
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
|
||||
from ray.rllib.models.utils import get_activation_fn
|
||||
from ray.rllib.utils import force_list
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.framework import try_import_torch
|
||||
from ray.rllib.utils.spaces.simplex import Simplex
|
||||
from ray.rllib.utils.typing import ModelConfigDict, TensorType
|
||||
@@ -16,14 +17,21 @@ torch, nn = try_import_torch()
|
||||
class SACTorchModel(TorchModelV2, nn.Module):
|
||||
"""Extension of the standard TorchModelV2 for SAC.
|
||||
|
||||
Instances of this Model get created via wrapping this class around another
|
||||
default- or custom model (inside
|
||||
rllib/agents/sac/sac_torch_policy.py::build_sac_model). Doing so simply
|
||||
adds this class' methods (`get_q_values`, etc..) to the wrapped model, such
|
||||
that the wrapped model can be used by the SAC algorithm.
|
||||
To customize, do one of the following:
|
||||
- sub-class SACTorchModel and override one or more of its methods.
|
||||
- Use SAC's `Q_model` and `policy_model` keys to tweak the default model
|
||||
behaviors (e.g. fcnet_hiddens, conv_filters, etc..).
|
||||
- Use SAC's `Q_model->custom_model` and `policy_model->custom_model` keys
|
||||
to specify your own custom Q-model(s) and policy-models, which will be
|
||||
created within this SACTorchModel (see `build_policy_model` and
|
||||
`build_q_model`).
|
||||
|
||||
Note: It is not recommended to override the `forward` method for SAC. This
|
||||
would lead to shared weights (between policy and Q-nets), which will then
|
||||
not be optimized by either of the critic- or actor-optimizers!
|
||||
|
||||
Data flow:
|
||||
`obs` -> forward() -> `model_out`
|
||||
`obs` -> forward() (should stay a noop method!) -> `model_out`
|
||||
`model_out` -> get_policy_output() -> pi(actions|obs)
|
||||
`model_out`, `actions` -> get_q_values() -> Q(s, a)
|
||||
`model_out`, `actions` -> get_twin_q_values() -> Q_twin(s, a)
|
||||
@@ -35,20 +43,18 @@ class SACTorchModel(TorchModelV2, nn.Module):
|
||||
num_outputs: Optional[int],
|
||||
model_config: ModelConfigDict,
|
||||
name: str,
|
||||
actor_hidden_activation: str = "relu",
|
||||
actor_hiddens: Tuple[int] = (256, 256),
|
||||
critic_hidden_activation: str = "relu",
|
||||
critic_hiddens: Tuple[int] = (256, 256),
|
||||
policy_model_config: ModelConfigDict = None,
|
||||
q_model_config: ModelConfigDict = None,
|
||||
twin_q: bool = False,
|
||||
initial_alpha: float = 1.0,
|
||||
target_entropy: Optional[float] = None):
|
||||
"""Initializes a SACTorchModel instance.
|
||||
7
|
||||
Args:
|
||||
actor_hidden_activation (str): Activation for the actor network.
|
||||
actor_hiddens (list): Hidden layers sizes for the actor network.
|
||||
critic_hidden_activation (str): Activation for the critic network.
|
||||
critic_hiddens (list): Hidden layers sizes for the critic network.
|
||||
policy_model_config (ModelConfigDict): The config dict for the
|
||||
policy network.
|
||||
q_model_config (ModelConfigDict): The config dict for the
|
||||
Q-network(s) (2 if twin_q=True).
|
||||
twin_q (bool): Build twin Q networks (Q-net and target) for more
|
||||
stable Q-learning.
|
||||
initial_alpha (float): The initial value for the to-be-optimized
|
||||
@@ -69,74 +75,29 @@ class SACTorchModel(TorchModelV2, nn.Module):
|
||||
self.action_dim = action_space.n
|
||||
self.discrete = True
|
||||
action_outs = q_outs = self.action_dim
|
||||
action_ins = None # No action inputs for the discrete case.
|
||||
elif isinstance(action_space, Box):
|
||||
self.action_dim = np.product(action_space.shape)
|
||||
self.discrete = False
|
||||
action_outs = 2 * self.action_dim
|
||||
action_ins = self.action_dim
|
||||
q_outs = 1
|
||||
else:
|
||||
assert isinstance(action_space, Simplex)
|
||||
self.action_dim = np.product(action_space.shape)
|
||||
self.discrete = False
|
||||
action_outs = self.action_dim
|
||||
action_ins = self.action_dim
|
||||
q_outs = 1
|
||||
|
||||
# Build the policy network.
|
||||
self.action_model = nn.Sequential()
|
||||
ins = self.num_outputs
|
||||
self.obs_ins = ins
|
||||
activation = get_activation_fn(
|
||||
actor_hidden_activation, framework="torch")
|
||||
for i, n in enumerate(actor_hiddens):
|
||||
self.action_model.add_module(
|
||||
"action_{}".format(i),
|
||||
SlimFC(
|
||||
ins,
|
||||
n,
|
||||
initializer=torch.nn.init.xavier_uniform_,
|
||||
activation_fn=activation))
|
||||
ins = n
|
||||
self.action_model.add_module(
|
||||
"action_out",
|
||||
SlimFC(
|
||||
ins,
|
||||
action_outs,
|
||||
initializer=torch.nn.init.xavier_uniform_,
|
||||
activation_fn=None))
|
||||
self.action_model = self.build_policy_model(
|
||||
self.obs_space, action_outs, policy_model_config, "policy_model")
|
||||
|
||||
# Build the Q-net(s), including target Q-net(s).
|
||||
def build_q_net(name_):
|
||||
activation = get_activation_fn(
|
||||
critic_hidden_activation, framework="torch")
|
||||
# For continuous actions: Feed obs and actions (concatenated)
|
||||
# through the NN. For discrete actions, only obs.
|
||||
q_net = nn.Sequential()
|
||||
ins = self.obs_ins + (0 if self.discrete else action_ins)
|
||||
for i, n in enumerate(critic_hiddens):
|
||||
q_net.add_module(
|
||||
"{}_hidden_{}".format(name_, i),
|
||||
SlimFC(
|
||||
ins,
|
||||
n,
|
||||
initializer=torch.nn.init.xavier_uniform_,
|
||||
activation_fn=activation))
|
||||
ins = n
|
||||
|
||||
q_net.add_module(
|
||||
"{}_out".format(name_),
|
||||
SlimFC(
|
||||
ins,
|
||||
q_outs,
|
||||
initializer=torch.nn.init.xavier_uniform_,
|
||||
activation_fn=None))
|
||||
return q_net
|
||||
|
||||
self.q_net = build_q_net("q")
|
||||
# Build the Q-network(s).
|
||||
self.q_net = self.build_q_model(self.obs_space, self.action_space,
|
||||
q_outs, q_model_config, "q")
|
||||
if twin_q:
|
||||
self.twin_q_net = build_q_net("twin_q")
|
||||
self.twin_q_net = self.build_q_model(self.obs_space,
|
||||
self.action_space, q_outs,
|
||||
q_model_config, "twin_q")
|
||||
else:
|
||||
self.twin_q_net = None
|
||||
|
||||
@@ -157,6 +118,80 @@ class SACTorchModel(TorchModelV2, nn.Module):
|
||||
self.target_entropy = torch.tensor(
|
||||
data=[target_entropy], dtype=torch.float32, requires_grad=False)
|
||||
|
||||
@override(TorchModelV2)
|
||||
def forward(self, input_dict: Dict[str, TensorType],
|
||||
state: List[TensorType],
|
||||
seq_lens: TensorType) -> (TensorType, List[TensorType]):
|
||||
"""The common (Q-net and policy-net) forward pass.
|
||||
|
||||
NOTE: It is not(!) recommended to override this method as it would
|
||||
introduce a shared pre-network, which would be updated by both
|
||||
actor- and critic optimizers.
|
||||
"""
|
||||
return input_dict["obs"], state
|
||||
|
||||
def build_policy_model(self, obs_space, num_outputs, policy_model_config,
|
||||
name):
|
||||
"""Builds the policy model used by this SAC.
|
||||
|
||||
Override this method in a sub-class of SACTorchModel to implement your
|
||||
own policy net. Alternatively, simply set `custom_model` within the
|
||||
top level SAC `policy_model` config key to make this default
|
||||
implementation of `build_policy_model` use your custom policy network.
|
||||
|
||||
Returns:
|
||||
TorchModelV2: The TorchModelV2 policy sub-model.
|
||||
"""
|
||||
model = ModelCatalog.get_model_v2(
|
||||
obs_space,
|
||||
self.action_space,
|
||||
num_outputs,
|
||||
policy_model_config,
|
||||
framework="torch",
|
||||
name=name)
|
||||
return model
|
||||
|
||||
def build_q_model(self, obs_space, action_space, num_outputs,
|
||||
q_model_config, name):
|
||||
"""Builds one of the (twin) Q-nets used by this SAC.
|
||||
|
||||
Override this method in a sub-class of SACTorchModel to implement your
|
||||
own Q-nets. Alternatively, simply set `custom_model` within the
|
||||
top level SAC `Q_model` config key to make this default implementation
|
||||
of `build_q_model` use your custom Q-nets.
|
||||
|
||||
Returns:
|
||||
TorchModelV2: The TorchModelV2 Q-net sub-model.
|
||||
"""
|
||||
self.concat_obs_and_actions = False
|
||||
if self.discrete:
|
||||
input_space = obs_space
|
||||
else:
|
||||
orig_space = getattr(obs_space, "original_space", obs_space)
|
||||
if isinstance(orig_space, Box) and len(orig_space.shape) == 1:
|
||||
input_space = Box(
|
||||
float("-inf"),
|
||||
float("inf"),
|
||||
shape=(orig_space.shape[0] + action_space.shape[0], ))
|
||||
self.concat_obs_and_actions = True
|
||||
else:
|
||||
if isinstance(orig_space, gym.spaces.Tuple):
|
||||
spaces = orig_space.spaces
|
||||
elif isinstance(orig_space, gym.spaces.Dict):
|
||||
spaces = list(orig_space.spaces.values())
|
||||
else:
|
||||
spaces = [obs_space]
|
||||
input_space = gym.spaces.Tuple(spaces + [action_space])
|
||||
|
||||
model = ModelCatalog.get_model_v2(
|
||||
input_space,
|
||||
action_space,
|
||||
num_outputs,
|
||||
q_model_config,
|
||||
framework="torch",
|
||||
name=name)
|
||||
return model
|
||||
|
||||
def get_q_values(self,
|
||||
model_out: TensorType,
|
||||
actions: Optional[TensorType] = None) -> TensorType:
|
||||
@@ -175,12 +210,7 @@ class SACTorchModel(TorchModelV2, nn.Module):
|
||||
Returns:
|
||||
TensorType: Q-values tensor of shape [BATCH_SIZE, 1].
|
||||
"""
|
||||
# Continuous case -> concat actions to model_out.
|
||||
if actions is not None:
|
||||
return self.q_net(torch.cat([model_out, actions], -1))
|
||||
# Discrete case -> return q-vals for all actions.
|
||||
else:
|
||||
return self.q_net(model_out)
|
||||
return self._get_q_value(model_out, actions, self.q_net)
|
||||
|
||||
def get_twin_q_values(self,
|
||||
model_out: TensorType,
|
||||
@@ -199,12 +229,32 @@ class SACTorchModel(TorchModelV2, nn.Module):
|
||||
Returns:
|
||||
TensorType: Q-values tensor of shape [BATCH_SIZE, 1].
|
||||
"""
|
||||
return self._get_q_value(model_out, actions, self.twin_q_net)
|
||||
|
||||
def _get_q_value(self, model_out, actions, net):
    """Runs `net` on the given model output (plus actions, if provided).

    Shared implementation behind the Q-value getters; `net` selects which
    Q-network is evaluated.

    Args:
        model_out: Output of the underlying model. May be a single tensor
            or a list/tuple/dict of tensors (original, non-flattened
            observation components).
        actions: Actions to evaluate (continuous case), or None
            (discrete case: Q-values for all actions are returned).
        net: The Q-network to call. It is invoked as
            `net(input_dict, [], None)` and must return an
            `(output, state)` pair.

    Returns:
        The first element of the pair returned by `net`.
    """
    # Model outs may come as original Tuple observations, concat them
    # here if this is the case.
    if isinstance(net.obs_space, Box):
        if isinstance(model_out, (list, tuple)):
            model_out = torch.cat(model_out, dim=-1)
        elif isinstance(model_out, dict):
            # Concatenate the dict's values into one tensor (same handling
            # as for list/tuple inputs). Merely converting to a list would
            # break the `torch.cat([model_out, actions])` call below.
            model_out = torch.cat(list(model_out.values()), dim=-1)

    # Continuous case -> concat actions to model_out.
    if actions is not None:
        if self.concat_obs_and_actions:
            input_dict = {"obs": torch.cat([model_out, actions], dim=-1)}
        else:
            # Keep the components separate and append the actions as one
            # more input component.
            input_dict = {"obs": force_list(model_out) + [actions]}
    # Discrete case -> return q-vals for all actions.
    else:
        input_dict = {"obs": model_out}
    # Switch on training mode (when getting Q-values, we are usually in
    # training).
    input_dict["is_training"] = True

    out, _ = net(input_dict, [], None)
    return out
|
||||
|
||||
def get_policy_output(self, model_out: TensorType) -> TensorType:
|
||||
"""Returns policy outputs, given the output of self.__call__().
|
||||
@@ -221,15 +271,23 @@ class SACTorchModel(TorchModelV2, nn.Module):
|
||||
Returns:
|
||||
TensorType: Distribution inputs for sampling actions.
|
||||
"""
|
||||
return self.action_model(model_out)
|
||||
# Model outs may come as original Tuple observations, concat them
|
||||
# here if this is the case.
|
||||
if isinstance(self.action_model.obs_space, Box):
|
||||
if isinstance(model_out, (list, tuple)):
|
||||
model_out = torch.cat(model_out, dim=-1)
|
||||
elif isinstance(model_out, dict):
|
||||
model_out = torch.cat(list(model_out.values()), dim=-1)
|
||||
out, _ = self.action_model({"obs": model_out}, [], None)
|
||||
return out
|
||||
|
||||
def policy_variables(self):
    """Return the list of variables for the policy net.

    Returns:
        Whatever `self.action_model.variables()` returns.
    """
    # Single return only: a second, unreachable `return` statement used to
    # follow here. The `variables()` accessor is used because the policy
    # net is invoked like a ModelV2 (`net(input_dict, state, seq_lens)`).
    return self.action_model.variables()
|
||||
|
||||
def q_variables(self):
    """Return the list of variables for Q / twin Q nets.

    Returns:
        The Q-net's variables, followed by the twin Q-net's variables if
        `self.twin_q_net` is set (truthy).
    """
    # Single return only: a second, unreachable `return` statement used to
    # follow here.
    return self.q_net.variables() + (
        self.twin_q_net.variables() if self.twin_q_net else [])
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from gym import Env
|
||||
from gym.spaces import Box
|
||||
from gym.spaces import Box, Discrete, Tuple
|
||||
import numpy as np
|
||||
import re
|
||||
import unittest
|
||||
@@ -9,6 +9,10 @@ import ray.rllib.agents.sac as sac
|
||||
from ray.rllib.agents.sac.sac_tf_policy import sac_actor_critic_loss as tf_loss
|
||||
from ray.rllib.agents.sac.sac_torch_policy import actor_critic_loss as \
|
||||
loss_torch
|
||||
from ray.rllib.examples.env.random_env import RandomEnv
|
||||
from ray.rllib.examples.models.batch_norm_model import KerasBatchNormModel, \
|
||||
TorchBatchNormModel
|
||||
from ray.rllib.models.catalog import ModelCatalog
|
||||
from ray.rllib.models.tf.tf_action_dist import Dirichlet
|
||||
from ray.rllib.models.torch.torch_action_dist import TorchDirichlet
|
||||
from ray.rllib.execution.replay_buffer import LocalReplayBuffer
|
||||
@@ -52,7 +56,7 @@ class SimpleEnv(Env):
|
||||
class TestSAC(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls) -> None:
|
||||
ray.init()
|
||||
ray.init(local_mode=True)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls) -> None:
|
||||
@@ -61,22 +65,46 @@ class TestSAC(unittest.TestCase):
|
||||
def test_sac_compilation(self):
|
||||
"""Tests whether an SACTrainer can be built with all frameworks."""
|
||||
config = sac.DEFAULT_CONFIG.copy()
|
||||
config["Q_model"] = sac.DEFAULT_CONFIG["Q_model"].copy()
|
||||
config["num_workers"] = 0 # Run locally.
|
||||
config["twin_q"] = True
|
||||
config["soft_horizon"] = True
|
||||
config["clip_actions"] = False
|
||||
config["normalize_actions"] = True
|
||||
config["learning_starts"] = 0
|
||||
config["prioritized_replay"] = True
|
||||
config["rollout_fragment_length"] = 10
|
||||
config["train_batch_size"] = 10
|
||||
num_iterations = 1
|
||||
for _ in framework_iterator(config):
|
||||
|
||||
ModelCatalog.register_custom_model("batch_norm", KerasBatchNormModel)
|
||||
ModelCatalog.register_custom_model("batch_norm_torch",
|
||||
TorchBatchNormModel)
|
||||
|
||||
image_space = Box(-1.0, 1.0, shape=(84, 84, 3))
|
||||
simple_space = Box(-1.0, 1.0, shape=(3, ))
|
||||
|
||||
for fw in framework_iterator(config):
|
||||
# Test for different env types (discrete w/ and w/o image, + cont).
|
||||
for env in [
|
||||
"Pendulum-v0", "MsPacmanNoFrameskip-v4", "CartPole-v0"
|
||||
RandomEnv,
|
||||
"MsPacmanNoFrameskip-v4",
|
||||
"CartPole-v0",
|
||||
]:
|
||||
print("Env={}".format(env))
|
||||
config["use_state_preprocessor"] = \
|
||||
env == "MsPacmanNoFrameskip-v4"
|
||||
if env == RandomEnv:
|
||||
config["env_config"] = {
|
||||
"observation_space": Tuple(
|
||||
[simple_space,
|
||||
Discrete(2), image_space]),
|
||||
"action_space": Box(-1.0, 1.0, shape=(1, )),
|
||||
}
|
||||
else:
|
||||
config["env_config"] = {}
|
||||
# Test making the Q-model a custom one for CartPole, otherwise,
|
||||
# use the default model.
|
||||
config["Q_model"]["custom_model"] = "batch_norm{}".format(
|
||||
"_torch"
|
||||
if fw == "torch" else "") if env == "CartPole-v0" else None
|
||||
trainer = sac.SACTrainer(config=config, env=env)
|
||||
for i in range(num_iterations):
|
||||
results = trainer.train()
|
||||
@@ -103,49 +131,56 @@ class TestSAC(unittest.TestCase):
|
||||
config["env_config"] = {"simplex_actions": True}
|
||||
|
||||
map_ = {
|
||||
# Normal net.
|
||||
"default_policy/sequential/action_1/kernel": "action_model."
|
||||
"action_0._model.0.weight",
|
||||
"default_policy/sequential/action_1/bias": "action_model."
|
||||
"action_0._model.0.bias",
|
||||
"default_policy/sequential/action_out/kernel": "action_model."
|
||||
"action_out._model.0.weight",
|
||||
"default_policy/sequential/action_out/bias": "action_model."
|
||||
"action_out._model.0.bias",
|
||||
"default_policy/sequential_1/q_hidden_0/kernel": "q_net."
|
||||
"q_hidden_0._model.0.weight",
|
||||
"default_policy/sequential_1/q_hidden_0/bias": "q_net."
|
||||
"q_hidden_0._model.0.bias",
|
||||
"default_policy/sequential_1/q_out/kernel": "q_net."
|
||||
"q_out._model.0.weight",
|
||||
"default_policy/sequential_1/q_out/bias": "q_net."
|
||||
"q_out._model.0.bias",
|
||||
"default_policy/value_out/kernel": "_value_branch."
|
||||
# Action net.
|
||||
"default_policy/fc_1/kernel": "action_model._hidden_layers.0."
|
||||
"_model.0.weight",
|
||||
"default_policy/value_out/bias": "_value_branch."
|
||||
"default_policy/fc_1/bias": "action_model._hidden_layers.0."
|
||||
"_model.0.bias",
|
||||
"default_policy/fc_out/kernel": "action_model."
|
||||
"_logits._model.0.weight",
|
||||
"default_policy/fc_out/bias": "action_model._logits._model.0.bias",
|
||||
"default_policy/value_out/kernel": "action_model."
|
||||
"_value_branch._model.0.weight",
|
||||
"default_policy/value_out/bias": "action_model."
|
||||
"_value_branch._model.0.bias",
|
||||
# Q-net.
|
||||
"default_policy/fc_1_1/kernel": "q_net."
|
||||
"_hidden_layers.0._model.0.weight",
|
||||
"default_policy/fc_1_1/bias": "q_net."
|
||||
"_hidden_layers.0._model.0.bias",
|
||||
"default_policy/fc_out_1/kernel": "q_net._logits._model.0.weight",
|
||||
"default_policy/fc_out_1/bias": "q_net._logits._model.0.bias",
|
||||
"default_policy/value_out_1/kernel": "q_net."
|
||||
"_value_branch._model.0.weight",
|
||||
"default_policy/value_out_1/bias": "q_net."
|
||||
"_value_branch._model.0.bias",
|
||||
"default_policy/log_alpha": "log_alpha",
|
||||
# Target net.
|
||||
"default_policy/sequential_2/action_1/kernel": "action_model."
|
||||
"action_0._model.0.weight",
|
||||
"default_policy/sequential_2/action_1/bias": "action_model."
|
||||
"action_0._model.0.bias",
|
||||
"default_policy/sequential_2/action_out/kernel": "action_model."
|
||||
"action_out._model.0.weight",
|
||||
"default_policy/sequential_2/action_out/bias": "action_model."
|
||||
"action_out._model.0.bias",
|
||||
"default_policy/sequential_3/q_hidden_0/kernel": "q_net."
|
||||
"q_hidden_0._model.0.weight",
|
||||
"default_policy/sequential_3/q_hidden_0/bias": "q_net."
|
||||
"q_hidden_0._model.0.bias",
|
||||
"default_policy/sequential_3/q_out/kernel": "q_net."
|
||||
"q_out._model.0.weight",
|
||||
"default_policy/sequential_3/q_out/bias": "q_net."
|
||||
"q_out._model.0.bias",
|
||||
"default_policy/value_out_1/kernel": "_value_branch."
|
||||
"_model.0.weight",
|
||||
"default_policy/value_out_1/bias": "_value_branch."
|
||||
"_model.0.bias",
|
||||
# Target action-net.
|
||||
"default_policy/fc_1_2/kernel": "action_model."
|
||||
"_hidden_layers.0._model.0.weight",
|
||||
"default_policy/fc_1_2/bias": "action_model."
|
||||
"_hidden_layers.0._model.0.bias",
|
||||
"default_policy/fc_out_2/kernel": "action_model."
|
||||
"_logits._model.0.weight",
|
||||
"default_policy/fc_out_2/bias": "action_model."
|
||||
"_logits._model.0.bias",
|
||||
"default_policy/value_out_2/kernel": "action_model."
|
||||
"_value_branch._model.0.weight",
|
||||
"default_policy/value_out_2/bias": "action_model."
|
||||
"_value_branch._model.0.bias",
|
||||
# Target Q-net
|
||||
"default_policy/fc_1_3/kernel": "q_net."
|
||||
"_hidden_layers.0._model.0.weight",
|
||||
"default_policy/fc_1_3/bias": "q_net."
|
||||
"_hidden_layers.0._model.0.bias",
|
||||
"default_policy/fc_out_3/kernel": "q_net."
|
||||
"_logits._model.0.weight",
|
||||
"default_policy/fc_out_3/bias": "q_net."
|
||||
"_logits._model.0.bias",
|
||||
"default_policy/value_out_3/kernel": "q_net."
|
||||
"_value_branch._model.0.weight",
|
||||
"default_policy/value_out_3/bias": "q_net."
|
||||
"_value_branch._model.0.bias",
|
||||
"default_policy/log_alpha_1": "log_alpha",
|
||||
}
|
||||
|
||||
@@ -225,10 +260,12 @@ class TestSAC(unittest.TestCase):
|
||||
policy.td_error,
|
||||
policy.optimizer().compute_gradients(
|
||||
policy.critic_loss[0],
|
||||
policy.model.q_variables()),
|
||||
[v for v in policy.model.q_variables() if
|
||||
"value_" not in v.name]),
|
||||
policy.optimizer().compute_gradients(
|
||||
policy.actor_loss,
|
||||
policy.model.policy_variables()),
|
||||
[v for v in policy.model.policy_variables() if
|
||||
"value_" not in v.name]),
|
||||
policy.optimizer().compute_gradients(
|
||||
policy.alpha_loss, policy.model.log_alpha)],
|
||||
feed_dict=policy._get_loss_inputs_dict(
|
||||
@@ -261,8 +298,6 @@ class TestSAC(unittest.TestCase):
|
||||
a.backward()
|
||||
# `actor_loss` depends on Q-net vars (but these grads must
|
||||
# be ignored and overridden in critic_loss.backward!).
|
||||
assert not any(v.grad is None
|
||||
for v in policy.model.q_variables())
|
||||
assert not all(
|
||||
torch.mean(v.grad) == 0
|
||||
for v in policy.model.policy_variables())
|
||||
@@ -273,45 +308,38 @@ class TestSAC(unittest.TestCase):
|
||||
# Compare with tf ones.
|
||||
torch_a_grads = [
|
||||
v.grad for v in policy.model.policy_variables()
|
||||
if v.grad is not None
|
||||
]
|
||||
for tf_g, torch_g in zip(tf_a_grads, torch_a_grads):
|
||||
if tf_g.shape != torch_g.shape:
|
||||
check(tf_g, np.transpose(torch_g.detach().cpu()))
|
||||
else:
|
||||
check(tf_g, torch_g)
|
||||
check(tf_a_grads[2],
|
||||
np.transpose(torch_a_grads[0].detach().cpu()))
|
||||
|
||||
# Test critic gradients.
|
||||
policy.critic_optims[0].zero_grad()
|
||||
assert all(
|
||||
torch.mean(v.grad) == 0.0
|
||||
for v in policy.model.q_variables())
|
||||
for v in policy.model.q_variables() if v.grad is not None)
|
||||
assert all(
|
||||
torch.min(v.grad) == 0.0
|
||||
for v in policy.model.q_variables())
|
||||
for v in policy.model.q_variables() if v.grad is not None)
|
||||
assert policy.model.log_alpha.grad is None
|
||||
c[0].backward()
|
||||
assert not all(
|
||||
torch.mean(v.grad) == 0
|
||||
for v in policy.model.q_variables())
|
||||
for v in policy.model.q_variables() if v.grad is not None)
|
||||
assert not all(
|
||||
torch.min(v.grad) == 0 for v in policy.model.q_variables())
|
||||
torch.min(v.grad) == 0 for v in policy.model.q_variables()
|
||||
if v.grad is not None)
|
||||
assert policy.model.log_alpha.grad is None
|
||||
# Compare with tf ones.
|
||||
torch_c_grads = [v.grad for v in policy.model.q_variables()]
|
||||
for tf_g, torch_g in zip(tf_c_grads, torch_c_grads):
|
||||
if tf_g.shape != torch_g.shape:
|
||||
check(tf_g, np.transpose(torch_g.detach().cpu()))
|
||||
else:
|
||||
check(tf_g, torch_g)
|
||||
check(tf_c_grads[0],
|
||||
np.transpose(torch_c_grads[2].detach().cpu()))
|
||||
# Compare (unchanged(!) actor grads) with tf ones.
|
||||
torch_a_grads = [
|
||||
v.grad for v in policy.model.policy_variables()
|
||||
]
|
||||
for tf_g, torch_g in zip(tf_a_grads, torch_a_grads):
|
||||
if tf_g.shape != torch_g.shape:
|
||||
check(tf_g, np.transpose(torch_g.detach().cpu()))
|
||||
else:
|
||||
check(tf_g, torch_g)
|
||||
check(tf_a_grads[2],
|
||||
np.transpose(torch_a_grads[0].detach().cpu()))
|
||||
|
||||
# Test alpha gradient.
|
||||
policy.alpha_optim.zero_grad()
|
||||
@@ -336,7 +364,7 @@ class TestSAC(unittest.TestCase):
|
||||
prev_fw_loss = (c, a, e, t)
|
||||
|
||||
# Update weights from our batch (n times).
|
||||
for update_iteration in range(10):
|
||||
for update_iteration in range(5):
|
||||
print("train iteration {}".format(update_iteration))
|
||||
if fw == "tf":
|
||||
in_ = self._get_batch_helper(obs_size, actions, batch_size)
|
||||
@@ -350,10 +378,9 @@ class TestSAC(unittest.TestCase):
|
||||
# Net must have changed.
|
||||
if tf_updated_weights:
|
||||
check(
|
||||
updated_weights[
|
||||
"default_policy/sequential/action_1/kernel"],
|
||||
updated_weights["default_policy/fc_1/kernel"],
|
||||
tf_updated_weights[-1][
|
||||
"default_policy/sequential/action_1/kernel"],
|
||||
"default_policy/fc_1/kernel"],
|
||||
false=True)
|
||||
tf_updated_weights.append(updated_weights)
|
||||
|
||||
@@ -367,7 +394,9 @@ class TestSAC(unittest.TestCase):
|
||||
buf._fake_batch = in_
|
||||
trainer.train()
|
||||
# Compare updated model.
|
||||
for tf_key in sorted(tf_weights.keys())[2:10]:
|
||||
for tf_key in sorted(tf_weights.keys()):
|
||||
if re.search("_[23]|alpha", tf_key):
|
||||
continue
|
||||
tf_var = tf_weights[tf_key]
|
||||
torch_var = policy.model.state_dict()[map_[tf_key]]
|
||||
if tf_var.shape != torch_var.shape:
|
||||
@@ -381,7 +410,9 @@ class TestSAC(unittest.TestCase):
|
||||
check(policy.model.log_alpha,
|
||||
tf_weights["default_policy/log_alpha"])
|
||||
# Compare target nets.
|
||||
for tf_key in sorted(tf_weights.keys())[10:18]:
|
||||
for tf_key in sorted(tf_weights.keys()):
|
||||
if not re.search("_[23]", tf_key):
|
||||
continue
|
||||
tf_var = tf_weights[tf_key]
|
||||
torch_var = policy.target_model.state_dict()[map_[
|
||||
tf_key]]
|
||||
@@ -437,9 +468,9 @@ class TestSAC(unittest.TestCase):
|
||||
fc(
|
||||
relu(
|
||||
fc(model_out_t,
|
||||
weights[ks[3]],
|
||||
weights[ks[2]],
|
||||
framework=fw)), weights[ks[5]], weights[ks[4]]), None)
|
||||
weights[ks[1]],
|
||||
weights[ks[0]],
|
||||
framework=fw)), weights[ks[9]], weights[ks[8]]), None)
|
||||
policy_t = action_dist_t.deterministic_sample()
|
||||
log_pis_t = action_dist_t.logp(policy_t)
|
||||
if sess:
|
||||
@@ -452,9 +483,9 @@ class TestSAC(unittest.TestCase):
|
||||
fc(
|
||||
relu(
|
||||
fc(model_out_tp1,
|
||||
weights[ks[3]],
|
||||
weights[ks[2]],
|
||||
framework=fw)), weights[ks[5]], weights[ks[4]]), None)
|
||||
weights[ks[1]],
|
||||
weights[ks[0]],
|
||||
framework=fw)), weights[ks[9]], weights[ks[8]]), None)
|
||||
policy_tp1 = action_dist_tp1.deterministic_sample()
|
||||
log_pis_tp1 = action_dist_tp1.logp(policy_tp1)
|
||||
if sess:
|
||||
@@ -468,11 +499,11 @@ class TestSAC(unittest.TestCase):
|
||||
relu(
|
||||
fc(np.concatenate(
|
||||
[model_out_t, train_batch[SampleBatch.ACTIONS]], -1),
|
||||
weights[ks[7]],
|
||||
weights[ks[6]],
|
||||
weights[ks[3]],
|
||||
weights[ks[2]],
|
||||
framework=fw)),
|
||||
weights[ks[9]],
|
||||
weights[ks[8]],
|
||||
weights[ks[11]],
|
||||
weights[ks[10]],
|
||||
framework=fw)
|
||||
|
||||
# Q-values for current policy in given current state.
|
||||
@@ -480,11 +511,11 @@ class TestSAC(unittest.TestCase):
|
||||
q_t_det_policy = fc(
|
||||
relu(
|
||||
fc(np.concatenate([model_out_t, policy_t], -1),
|
||||
weights[ks[7]],
|
||||
weights[ks[6]],
|
||||
weights[ks[3]],
|
||||
weights[ks[2]],
|
||||
framework=fw)),
|
||||
weights[ks[9]],
|
||||
weights[ks[8]],
|
||||
weights[ks[11]],
|
||||
weights[ks[10]],
|
||||
framework=fw)
|
||||
|
||||
# Target q network evaluation.
|
||||
@@ -493,11 +524,11 @@ class TestSAC(unittest.TestCase):
|
||||
q_tp1 = fc(
|
||||
relu(
|
||||
fc(np.concatenate([target_model_out_tp1, policy_tp1], -1),
|
||||
weights[ks[15]],
|
||||
weights[ks[14]],
|
||||
weights[ks[7]],
|
||||
weights[ks[6]],
|
||||
framework=fw)),
|
||||
weights[ks[17]],
|
||||
weights[ks[16]],
|
||||
weights[ks[15]],
|
||||
weights[ks[14]],
|
||||
framework=fw)
|
||||
else:
|
||||
assert fw == "tfe"
|
||||
@@ -538,9 +569,9 @@ class TestSAC(unittest.TestCase):
|
||||
map_[k]: convert_to_torch_tensor(
|
||||
np.transpose(v) if re.search("kernel", k) else np.array([v])
|
||||
if re.search("log_alpha", k) else v)
|
||||
for k, v in weights_dict.items()
|
||||
if re.search("(sequential(/|_1)|value_out/|log_alpha)", k)
|
||||
for i, (k, v) in enumerate(weights_dict.items()) if i < 13
|
||||
}
|
||||
|
||||
return model_dict
|
||||
|
||||
def _translate_tfe_weights(self, weights_dict, map_):
|
||||
|
||||
@@ -32,7 +32,7 @@ from ray.rllib.policy.torch_policy import TorchPolicy
|
||||
from ray.rllib.utils import merge_dicts
|
||||
from ray.rllib.utils.annotations import DeveloperAPI
|
||||
from ray.rllib.utils.debug import summarize
|
||||
from ray.rllib.utils.deprecation import deprecation_warning
|
||||
from ray.rllib.utils.deprecation import DEPRECATED_VALUE, deprecation_warning
|
||||
from ray.rllib.utils.filter import get_filter, Filter
|
||||
from ray.rllib.utils.framework import try_import_tf, try_import_torch
|
||||
from ray.rllib.utils.sgd import do_minibatch_sgd
|
||||
@@ -396,15 +396,22 @@ class RolloutWorker(ParallelIteratorWorker):
|
||||
if clip_rewards is None:
|
||||
clip_rewards = True
|
||||
|
||||
# framestacking via trajectory view API is enabled.
|
||||
num_framestacks = model_config.get("num_framestacks", 0)
|
||||
if not policy_config["_use_trajectory_view_api"]:
|
||||
model_config["num_framestacks"] = num_framestacks = 0
|
||||
elif num_framestacks == "auto":
|
||||
model_config["num_framestacks"] = num_framestacks = 4
|
||||
framestack_traj_view = num_framestacks > 1
|
||||
# Deprecated way of framestacking is used.
|
||||
framestack = model_config.get("framestack") is True
|
||||
# framestacking via trajectory view API is enabled.
|
||||
num_framestacks = model_config.get("num_framestacks", 0)
|
||||
|
||||
# No trajectory view API: No traj. view based framestacking.
|
||||
if not policy_config["_use_trajectory_view_api"]:
|
||||
model_config["num_framestacks"] = num_framestacks = 0
|
||||
# Trajectory view API is on and num_framestacks=auto: Only
|
||||
# stack traj. view based if old `framestack=[invalid value]`.
|
||||
elif num_framestacks == "auto":
|
||||
if framestack == DEPRECATED_VALUE:
|
||||
model_config["num_framestacks"] = num_framestacks = 4
|
||||
else:
|
||||
model_config["num_framestacks"] = num_framestacks = 0
|
||||
framestack_traj_view = num_framestacks > 1
|
||||
|
||||
def wrap(env):
|
||||
env = wrap_deepmind(
|
||||
|
||||
@@ -1,218 +0,0 @@
|
||||
from gym.spaces import Discrete, Tuple
|
||||
|
||||
from ray.rllib.examples.models.impala_vision_nets import TorchImpalaVisionNet
|
||||
from ray.rllib.models.catalog import ModelCatalog
|
||||
from ray.rllib.models.modelv2 import ModelV2
|
||||
from ray.rllib.models.tf.misc import normc_initializer
|
||||
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
|
||||
from ray.rllib.models.torch.misc import normc_initializer as \
|
||||
torch_normc_initializer, SlimFC
|
||||
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
|
||||
from ray.rllib.models.utils import get_filter_config
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.framework import try_import_tf, try_import_torch
|
||||
|
||||
tf1, tf, tfv = try_import_tf()
|
||||
torch, nn = try_import_torch()
|
||||
|
||||
|
||||
# __sphinx_doc_begin__
|
||||
class CNNPlusFCConcatModel(TFModelV2):
    """TFModelV2 that runs image inputs through CNNs and concats the rest.

    Each 3D component of the (Tuple) observation space gets its own CNN.
    The CNN outputs are concatenated with all other components. If
    `num_outputs` is set, a logits head and a value head are put on top of
    the concatenated vector; otherwise the concatenated vector itself is
    the model output.
    """

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        # TODO: (sven) Support Dicts as well.
        assert isinstance(obs_space.original_space, (Tuple)), \
            "`obs_space.original_space` must be Tuple!"

        super().__init__(obs_space, action_space, num_outputs, model_config,
                         name)

        # One CNN per image component, keyed by the component's index in
        # the Tuple space. `total_width` is the size of the concat'd vector.
        self.cnns = {}
        total_width = 0
        for idx, sub_space in enumerate(obs_space.original_space):
            # Image component -> own CNN.
            if len(sub_space.shape) == 3:
                fallback_filters = get_filter_config(sub_space.shape)
                cnn_config = {
                    "conv_filters": model_config.get("conv_filters",
                                                     fallback_filters),
                    "conv_activation": model_config.get("conv_activation"),
                }
                vision_net = ModelCatalog.get_model_v2(
                    sub_space,
                    action_space,
                    num_outputs=None,
                    model_config=cnn_config,
                    framework="tf",
                    name="cnn_{}".format(idx))
                total_width += vision_net.num_outputs
                self.cnns[idx] = vision_net
                continue

            # Discrete component -> contributes `n` (one-hot) inputs.
            if isinstance(sub_space, Discrete):
                total_width += sub_space.n
                continue

            # TODO: (sven) Multidiscrete (see e.g. our auto-LSTM wrappers).
            # Anything else must be a flat (1D) Box.
            assert len(sub_space.shape) == 1, \
                "Only input Box 1D or 3D spaces allowed!"
            total_width += sub_space.shape[-1]

        self._value_out = None
        self.logits_and_value_model = None

        # No heads requested: output is the plain concat'd feature vector.
        if not num_outputs:
            self.num_outputs = total_width
            return

        # Shared input for both heads.
        head_input = tf.keras.layers.Input((total_width, ))
        # Action-distribution head.
        logits_out = tf.keras.layers.Dense(
            num_outputs,
            activation=tf.keras.activations.linear,
            name="logits")(head_input)
        # Value head.
        value_out = tf.keras.layers.Dense(
            1,
            name="value_out",
            activation=None,
            kernel_initializer=normc_initializer(0.01))(head_input)
        self.logits_and_value_model = tf.keras.models.Model(
            head_input, [logits_out, value_out])

    @override(ModelV2)
    def forward(self, input_dict, state, seq_lens):
        # Image components go through their CNN, all others pass through.
        features = []
        for idx, obs_part in enumerate(input_dict["obs"]):
            if idx not in self.cnns:
                features.append(obs_part)
                continue
            cnn_out, _ = self.cnns[idx]({"obs": obs_part})
            features.append(cnn_out)

        # Single flat vector from all CNN outputs and non-image inputs.
        concat_out = tf.concat(features, axis=1)
        if not self.logits_and_value_model:
            return concat_out, []

        # Compute both heads; remember the values for `value_function()`.
        logits, values = self.logits_and_value_model(concat_out)
        self._value_out = tf.reshape(values, [-1])
        return logits, []

    @override(ModelV2)
    def value_function(self):
        # Value output stored by the most recent `forward()` call (None if
        # no value head exists or `forward()` has not been called yet).
        return self._value_out
|
||||
|
||||
|
||||
# __sphinx_doc_end__
|
||||
|
||||
|
||||
class TorchCNNPlusFCConcatModel(TorchModelV2, nn.Module):
    """TorchModelV2 concat'ing CNN outputs to flat input(s), followed by FC(s).

    Note: This model is meant for complex observation spaces that have one
    or more image components. Only Tuple spaces are accepted so far (the
    constructor asserts this; Dict support is still a TODO).

    The CNN flavor is selected via `custom_model_config["conv_type"]`:
    "atari" (default) builds the CNNs through the ModelCatalog, any other
    value builds `TorchImpalaVisionNet`s.
    """

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        # TODO: (sven) Support Dicts as well.
        assert isinstance(obs_space.original_space, (Tuple)), \
            "`obs_space.original_space` must be Tuple!"

        nn.Module.__init__(self)
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                              model_config, name)

        # Atari type CNNs or IMPALA type CNNs (with residual layers)?
        # Tolerate a missing (or None) `custom_model_config` entry instead
        # of raising a KeyError; fall back to the "atari" default.
        custom_config = self.model_config.get("custom_model_config") or {}
        self.cnn_type = custom_config.get("conv_type", "atari")

        # Build the CNN(s) given obs_space's image components.
        # CNNs are keyed by the component's index in the Tuple space.
        self.cnns = {}
        concat_size = 0
        for i, component in enumerate(obs_space.original_space):
            # Image space.
            if len(component.shape) == 3:
                config = {
                    "conv_filters": model_config.get(
                        "conv_filters", get_filter_config(component.shape)),
                    "conv_activation": model_config.get("conv_activation"),
                }
                if self.cnn_type == "atari":
                    cnn = ModelCatalog.get_model_v2(
                        component,
                        action_space,
                        num_outputs=None,
                        model_config=config,
                        framework="torch",
                        name="cnn_{}".format(i))
                else:
                    cnn = TorchImpalaVisionNet(
                        component,
                        action_space,
                        num_outputs=None,
                        model_config=config,
                        name="cnn_{}".format(i))

                concat_size += cnn.num_outputs
                self.cnns[i] = cnn
                # `self.cnns` is a plain dict, so register each CNN
                # explicitly as a sub-module of this nn.Module.
                self.add_module("cnn_{}".format(i), cnn)
            # Discrete inputs -> One-hot encode.
            elif isinstance(component, Discrete):
                concat_size += component.n
            # TODO: (sven) Multidiscrete (see e.g. our auto-LSTM wrappers).
            # Everything else (1D Box).
            else:
                assert len(component.shape) == 1, \
                    "Only input Box 1D or 3D spaces allowed!"
                concat_size += component.shape[-1]

        self.logits_layer = None
        self.value_layer = None
        self._value_out = None

        if num_outputs:
            # Action-distribution head.
            self.logits_layer = SlimFC(
                in_size=concat_size,
                out_size=num_outputs,
                activation_fn=None,
            )
            # Create the value branch model.
            self.value_layer = SlimFC(
                in_size=concat_size,
                out_size=1,
                activation_fn=None,
                initializer=torch_normc_initializer(0.01))
        else:
            # No heads: the concat'd feature vector is the model output.
            self.num_outputs = concat_size

    @override(ModelV2)
    def forward(self, input_dict, state, seq_lens):
        # Push image observations through our CNNs.
        outs = []
        for i, component in enumerate(input_dict["obs"]):
            if i in self.cnns:
                cnn_out, _ = self.cnns[i]({"obs": component})
                outs.append(cnn_out)
            else:
                outs.append(component)
        # Concat all outputs and the non-image inputs.
        out = torch.cat(outs, dim=1)
        if self.logits_layer is None:
            return out, []

        # Value branch.
        logits, values = self.logits_layer(out), self.value_layer(out)
        # Stored for `value_function()`.
        self._value_out = torch.reshape(values, [-1])
        return logits, []

    @override(ModelV2)
    def value_function(self):
        # Value output stored by the most recent `forward()` call (None if
        # no value head exists or `forward()` has not been called yet).
        return self._value_out
|
||||
+38
-8
@@ -19,7 +19,7 @@ from ray.rllib.models.torch.torch_action_dist import TorchCategorical, \
|
||||
TorchDeterministic, TorchDiagGaussian, \
|
||||
TorchMultiActionDistribution, TorchMultiCategorical
|
||||
from ray.rllib.utils.annotations import DeveloperAPI, PublicAPI
|
||||
from ray.rllib.utils.deprecation import DEPRECATED_VALUE, deprecation_warning
|
||||
from ray.rllib.utils.deprecation import DEPRECATED_VALUE
|
||||
from ray.rllib.utils.error import UnsupportedSpaceException
|
||||
from ray.rllib.utils.framework import try_import_tf, try_import_torch
|
||||
from ray.rllib.utils.spaces.simplex import Simplex
|
||||
@@ -56,6 +56,18 @@ MODEL_DEFAULTS: ModelConfigDict = {
|
||||
# "linear" (or None).
|
||||
"conv_activation": "relu",
|
||||
|
||||
# Some default models support a final FC stack of n Dense layers with given
|
||||
# activation:
|
||||
# - Complex observation spaces: Image components are fed through
|
||||
# VisionNets, flat Boxes are left as-is, Discrete are one-hot'd, then
|
||||
# everything is concated and pushed through this final FC stack.
|
||||
# - VisionNets (CNNs), e.g. after the CNN stack, there may be
|
||||
# additional Dense layers.
|
||||
# - FullyConnectedNetworks will have this additional FCStack as well
|
||||
# (that's why it's empty by default).
|
||||
"post_fcnet_hiddens": [],
|
||||
"post_fcnet_activation": "relu",
|
||||
|
||||
# For DiagGaussian action distributions, make the second half of the model
|
||||
# outputs floating bias variables instead of state-dependent. This only
|
||||
# has an effect if using the default fully connected net.
|
||||
@@ -688,17 +700,22 @@ class ModelCatalog:
|
||||
framework: str = "tf") -> Type[ModelV2]:
|
||||
|
||||
VisionNet = None
|
||||
ComplexNet = None
|
||||
|
||||
if framework in ["tf2", "tf", "tfe"]:
|
||||
from ray.rllib.models.tf.fcnet import \
|
||||
FullyConnectedNetwork as FCNet
|
||||
from ray.rllib.models.tf.visionnet import \
|
||||
VisionNetwork as VisionNet
|
||||
from ray.rllib.models.tf.complex_input_net import \
|
||||
ComplexInputNetwork as ComplexNet
|
||||
elif framework == "torch":
|
||||
from ray.rllib.models.torch.fcnet import (FullyConnectedNetwork as
|
||||
FCNet)
|
||||
from ray.rllib.models.torch.visionnet import (VisionNetwork as
|
||||
VisionNet)
|
||||
from ray.rllib.models.torch.complex_input_net import \
|
||||
ComplexInputNetwork as ComplexNet
|
||||
elif framework == "jax":
|
||||
from ray.rllib.models.jax.fcnet import (FullyConnectedNetwork as
|
||||
FCNet)
|
||||
@@ -710,16 +727,29 @@ class ModelCatalog:
|
||||
# Discrete/1D obs-spaces or 2D obs space but traj. view framestacking
|
||||
# disabled.
|
||||
num_framestacks = model_config.get("num_framestacks", "auto")
|
||||
|
||||
# Tuple space, where at least one sub-space is image.
|
||||
# -> Complex input model.
|
||||
space_to_check = input_space if not hasattr(
|
||||
input_space, "original_space") else input_space.original_space
|
||||
if isinstance(input_space,
|
||||
Tuple) or (isinstance(space_to_check, Tuple) and any(
|
||||
isinstance(s, Box) and len(s.shape) >= 2
|
||||
for s in space_to_check.spaces)):
|
||||
return ComplexNet
|
||||
|
||||
# Single, flattenable/one-hot-able space -> Simple FCNet.
|
||||
if isinstance(input_space, (Discrete, MultiDiscrete)) or \
|
||||
len(input_space.shape) == 1 or (
|
||||
len(input_space.shape) == 2 and (
|
||||
num_framestacks == "auto" or num_framestacks <= 1)):
|
||||
return FCNet
|
||||
# Default Conv2D net.
|
||||
else:
|
||||
if framework == "jax":
|
||||
raise NotImplementedError("No Conv2D default net for JAX yet!")
|
||||
return VisionNet
|
||||
|
||||
elif framework == "jax":
|
||||
raise NotImplementedError("No non-FC default net for JAX yet!")
|
||||
|
||||
# Last resort: Conv2D stack for single image spaces.
|
||||
return VisionNet
|
||||
|
||||
@staticmethod
|
||||
def _get_multi_action_distribution(dist_class, action_space, config,
|
||||
@@ -768,8 +798,8 @@ class ModelCatalog:
|
||||
"framework=jax so far!")
|
||||
|
||||
if config.get("framestack") != DEPRECATED_VALUE:
|
||||
deprecation_warning(
|
||||
old="framestack", new="num_framestacks (int)", error=False)
|
||||
# deprecation_warning(
|
||||
# old="framestack", new="num_framestacks (int)", error=False)
|
||||
# If old behavior is desired, disable traj. view-style
|
||||
# framestacking.
|
||||
config["num_framestacks"] = 0
|
||||
|
||||
+17
-23
@@ -203,9 +203,13 @@ class ModelV2:
|
||||
restored = input_dict.copy()
|
||||
restored["obs"] = restore_original_dimensions(
|
||||
input_dict["obs"], self.obs_space, self.framework)
|
||||
if len(input_dict["obs"].shape) > 2:
|
||||
restored["obs_flat"] = flatten(input_dict["obs"], self.framework)
|
||||
else:
|
||||
try:
|
||||
if len(input_dict["obs"].shape) > 2:
|
||||
restored["obs_flat"] = flatten(input_dict["obs"],
|
||||
self.framework)
|
||||
else:
|
||||
restored["obs_flat"] = input_dict["obs"]
|
||||
except AttributeError:
|
||||
restored["obs_flat"] = input_dict["obs"]
|
||||
with self.context():
|
||||
res = self.forward(restored, state or [], seq_lens)
|
||||
@@ -216,15 +220,6 @@ class ModelV2:
|
||||
"got {}".format(res))
|
||||
outputs, state = res
|
||||
|
||||
try:
|
||||
shape = outputs.shape
|
||||
except AttributeError:
|
||||
raise ValueError("Output is not a tensor: {}".format(outputs))
|
||||
else:
|
||||
if len(shape) != 2 or int(shape[1]) != self.num_outputs:
|
||||
raise ValueError(
|
||||
"Expected output shape of [None, {}], got {}".format(
|
||||
self.num_outputs, shape))
|
||||
if not isinstance(state, list):
|
||||
raise ValueError("State output is not a list: {}".format(state))
|
||||
|
||||
@@ -418,15 +413,15 @@ def restore_original_dimensions(obs: TensorType,
|
||||
observation space.
|
||||
"""
|
||||
|
||||
if hasattr(obs_space, "original_space"):
|
||||
if tensorlib == "tf":
|
||||
tensorlib = tf
|
||||
elif tensorlib == "torch":
|
||||
assert torch is not None
|
||||
tensorlib = torch
|
||||
return _unpack_obs(obs, obs_space.original_space, tensorlib=tensorlib)
|
||||
else:
|
||||
if tensorlib == "tf":
|
||||
tensorlib = tf
|
||||
elif tensorlib == "torch":
|
||||
assert torch is not None
|
||||
tensorlib = torch
|
||||
original_space = getattr(obs_space, "original_space", obs_space)
|
||||
if original_space is obs_space:
|
||||
return obs
|
||||
return _unpack_obs(obs, original_space, tensorlib=tensorlib)
|
||||
|
||||
|
||||
# Cache of preprocessors, for if the user is calling unpack obs often.
|
||||
@@ -490,7 +485,8 @@ def _unpack_obs(obs: TensorType, space: gym.Space,
|
||||
tensorlib.reshape(obs_slice, batch_dims + list(p.shape)),
|
||||
v,
|
||||
tensorlib=tensorlib)
|
||||
elif isinstance(space, Repeated):
|
||||
# Repeated space.
|
||||
else:
|
||||
assert isinstance(prep, RepeatedValuesPreprocessor), prep
|
||||
child_size = prep.child_preprocessor.size
|
||||
# The list lengths are stored in the first slot of the flat obs.
|
||||
@@ -503,8 +499,6 @@ def _unpack_obs(obs: TensorType, space: gym.Space,
|
||||
with_repeat_dim, space.child_space, tensorlib=tensorlib)
|
||||
return RepeatedValues(
|
||||
u, lengths=lengths, max_len=prep._obs_space.max_len)
|
||||
else:
|
||||
assert False, space
|
||||
return u
|
||||
else:
|
||||
return obs
|
||||
|
||||
@@ -0,0 +1,156 @@
|
||||
from gym.spaces import Box, Discrete, Tuple
|
||||
import numpy as np
|
||||
|
||||
from ray.rllib.models.catalog import ModelCatalog
|
||||
from ray.rllib.models.modelv2 import ModelV2, restore_original_dimensions
|
||||
from ray.rllib.models.tf.misc import normc_initializer
|
||||
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
|
||||
from ray.rllib.models.utils import get_filter_config
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.framework import try_import_tf
|
||||
from ray.rllib.utils.tf_ops import one_hot
|
||||
|
||||
tf1, tf, tfv = try_import_tf()
|
||||
|
||||
|
||||
# __sphinx_doc_begin__
|
||||
class ComplexInputNetwork(TFModelV2):
    """TFModelV2 concat'ing CNN outputs to flat input(s), followed by FC(s).

    Note: This model should be used for complex Tuple observation spaces
    that have one or more image components. Dict spaces are not supported
    yet (the constructor asserts that the original space is a Tuple).

    The data flow is as follows:

    `obs` (e.g. Tuple[img0, img1, discrete0]) -> `CNN0 + CNN1 + ONE-HOT`
    `CNN0 + CNN1 + ONE-HOT` -> concat all flat outputs -> `out`
    `out` -> (optional) FC-stack -> `out2`
    `out2` -> action (logits) and value heads.
    """

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        """Builds one sub-network per Tuple component plus the output heads.

        Args:
            obs_space: The observation space. If it carries an
                `original_space` attribute, that attribute is used instead.
                Must resolve to a gym `Tuple`.
            action_space: The action space, passed on to all sub-models.
            num_outputs: Size of the logits output. If falsy, no
                logits/value heads are built and `self.num_outputs` is set
                to the output size of the post-concat FC stack.
            model_config: Model config dict. The keys read here are
                `conv_filters`, `conv_activation`, `post_fcnet_hiddens` and
                `post_fcnet_activation`.
            name: Name of this model.

        Raises:
            AssertionError: If the (original) observation space is not a
                Tuple.
        """
        # TODO: (sven) Support Dicts as well.
        self.original_space = obs_space.original_space if \
            hasattr(obs_space, "original_space") else obs_space
        assert isinstance(self.original_space, (Tuple)), \
            "`obs_space.original_space` must be Tuple!"

        super().__init__(self.original_space, action_space, num_outputs,
                         model_config, name)

        # Per-component bookkeeping, keyed by the component's index in the
        # Tuple: CNN sub-models, one-hot markers and flattened sizes.
        self.cnns = {}
        self.one_hot = {}
        self.flatten = {}
        concat_size = 0
        for i, component in enumerate(self.original_space):
            # Image space (rank-3 shape) -> own CNN.
            if len(component.shape) == 3:
                config = {
                    "conv_filters": model_config.get(
                        "conv_filters", get_filter_config(component.shape)),
                    "conv_activation": model_config.get("conv_activation"),
                    # The FC stack is applied once, after the concat.
                    "post_fcnet_hiddens": [],
                }
                cnn = ModelCatalog.get_model_v2(
                    component,
                    action_space,
                    num_outputs=None,
                    model_config=config,
                    framework="tf",
                    name="cnn_{}".format(i))
                concat_size += cnn.num_outputs
                self.cnns[i] = cnn
            # Discrete inputs -> One-hot encode.
            elif isinstance(component, Discrete):
                self.one_hot[i] = True
                concat_size += component.n
            # TODO: (sven) Multidiscrete (see e.g. our auto-LSTM wrappers).
            # Everything else (1D Box).
            else:
                # `np.prod` replaces the deprecated `np.product` alias,
                # which was removed in NumPy 2.0.
                self.flatten[i] = int(np.prod(component.shape))
                concat_size += self.flatten[i]

        # Optional post-concat FC-stack.
        post_fc_stack_config = {
            "fcnet_hiddens": model_config.get("post_fcnet_hiddens", []),
            "fcnet_activation": model_config.get("post_fcnet_activation",
                                                 "relu")
        }
        self.post_fc_stack = ModelCatalog.get_model_v2(
            Box(float("-inf"),
                float("inf"),
                shape=(concat_size, ),
                dtype=np.float32),
            self.action_space,
            None,
            post_fc_stack_config,
            framework="tf",
            name="post_fc_stack")

        # Actions and value heads.
        self.logits_and_value_model = None
        self._value_out = None
        if num_outputs:
            # Action-distribution head.
            concat_layer = tf.keras.layers.Input(
                (self.post_fc_stack.num_outputs, ))
            logits_layer = tf.keras.layers.Dense(
                num_outputs,
                activation=tf.keras.activations.linear,
                name="logits")(concat_layer)

            # Create the value branch model.
            value_layer = tf.keras.layers.Dense(
                1,
                name="value_out",
                activation=None,
                kernel_initializer=normc_initializer(0.01))(concat_layer)
            self.logits_and_value_model = tf.keras.models.Model(
                concat_layer, [logits_layer, value_layer])
        else:
            # No heads: this model's output is the FC stack's output.
            self.num_outputs = self.post_fc_stack.num_outputs

    @override(ModelV2)
    def forward(self, input_dict, state, seq_lens):
        """Runs all components through their sub-nets and the heads.

        Returns:
            Tuple of (logits, []) if heads were built, otherwise
            (FC-stack output, []). This model never returns state.
        """
        # If "obs_flat" is present, the observation under OBS is used as is
        # (presumably already unpacked by the caller -- confirm against
        # `ModelV2.__call__`); otherwise unpack it here.
        if SampleBatch.OBS in input_dict and "obs_flat" in input_dict:
            orig_obs = input_dict[SampleBatch.OBS]
        else:
            orig_obs = restore_original_dimensions(input_dict[SampleBatch.OBS],
                                                   self.obs_space, "tf")
        # Push image observations through our CNNs.
        outs = []
        for i, component in enumerate(orig_obs):
            if i in self.cnns:
                cnn_out, _ = self.cnns[i]({SampleBatch.OBS: component})
                outs.append(cnn_out)
            elif i in self.one_hot:
                # Only integer-typed inputs are one-hot encoded; any other
                # dtype is passed through unchanged.
                if component.dtype in [tf.int32, tf.int64, tf.uint8]:
                    outs.append(
                        one_hot(component, self.original_space.spaces[i]))
                else:
                    outs.append(component)
            else:
                outs.append(tf.reshape(component, [-1, self.flatten[i]]))
        # Concat all outputs and the non-image inputs.
        out = tf.concat(outs, axis=1)
        # Push through (optional) FC-stack (this may be an empty stack).
        out, _ = self.post_fc_stack({SampleBatch.OBS: out}, [], None)

        # No logits/value branches.
        if not self.logits_and_value_model:
            return out, []

        # Logits- and value branches.
        logits, values = self.logits_and_value_model(out)
        self._value_out = tf.reshape(values, [-1])
        return logits, []

    @override(ModelV2)
    def value_function(self):
        """Returns the value output cached by the last `forward()` call.

        This is None if `forward()` has not run yet or no heads were built.
        """
        return self._value_out
|
||||
|
||||
|
||||
# __sphinx_doc_end__
|
||||
@@ -19,8 +19,12 @@ class FullyConnectedNetwork(TFModelV2):
|
||||
super(FullyConnectedNetwork, self).__init__(
|
||||
obs_space, action_space, num_outputs, model_config, name)
|
||||
|
||||
activation = get_activation_fn(model_config.get("fcnet_activation"))
|
||||
hiddens = model_config.get("fcnet_hiddens", [])
|
||||
hiddens = model_config.get("fcnet_hiddens", []) + \
|
||||
model_config.get("post_fcnet_hiddens", [])
|
||||
activation = model_config.get("fcnet_activation")
|
||||
if not model_config.get("fcnet_hiddens", []):
|
||||
activation = model_config.get("post_fcnet_activation")
|
||||
activation = get_activation_fn(activation)
|
||||
no_final_linear = model_config.get("no_final_linear")
|
||||
vf_share_layers = model_config.get("vf_share_layers")
|
||||
free_log_std = model_config.get("free_log_std")
|
||||
|
||||
@@ -107,7 +107,8 @@ class TFModelV2(ModelV2):
|
||||
if isinstance(struct, tf.keras.models.Model):
|
||||
ret = {}
|
||||
for var in struct.variables:
|
||||
key = current_key + "." + re.sub("/", ".", var.name)
|
||||
name = re.sub("/", ".", var.name)
|
||||
key = current_key + "." + name
|
||||
ret[key] = var
|
||||
return ret
|
||||
# Other TFModelV2: Include its vars into ours.
|
||||
@@ -118,7 +119,7 @@ class TFModelV2(ModelV2):
|
||||
}
|
||||
# tf.Variable
|
||||
elif isinstance(struct, tf.Variable):
|
||||
return {current_key + "." + struct.name: struct}
|
||||
return {current_key: struct}
|
||||
# List/Tuple.
|
||||
elif isinstance(struct, (tuple, list)):
|
||||
ret = {}
|
||||
@@ -133,7 +134,7 @@ class TFModelV2(ModelV2):
|
||||
current_key += "_"
|
||||
ret = {}
|
||||
for key, value in struct.items():
|
||||
sub_vars = TFModelV2._find_sub_modules(current_key + key,
|
||||
sub_vars = TFModelV2._find_sub_modules(current_key + str(key),
|
||||
value)
|
||||
ret.update(sub_vars)
|
||||
return ret
|
||||
|
||||
@@ -13,7 +13,17 @@ tf1, tf, tfv = try_import_tf()
|
||||
|
||||
|
||||
class VisionNetwork(TFModelV2):
|
||||
"""Generic vision network implemented in ModelV2 API."""
|
||||
"""Generic vision network implemented in ModelV2 API.
|
||||
|
||||
An additional post-conv fully connected stack can be added and configured
|
||||
via the config keys:
|
||||
`post_fcnet_hiddens`: Dense layer sizes after the Conv2D stack.
|
||||
`post_fcnet_activation`: Activation function to use for this FC stack.
|
||||
|
||||
Examples:
|
||||
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, obs_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space, num_outputs: int,
|
||||
@@ -29,6 +39,12 @@ class VisionNetwork(TFModelV2):
|
||||
filters = self.model_config["conv_filters"]
|
||||
assert len(filters) > 0,\
|
||||
"Must provide at least 1 entry in `conv_filters`!"
|
||||
|
||||
# Post FC net config.
|
||||
post_fcnet_hiddens = model_config.get("post_fcnet_hiddens", [])
|
||||
post_fcnet_activation = get_activation_fn(
|
||||
model_config.get("post_fcnet_activation"), framework="tf")
|
||||
|
||||
no_final_linear = self.model_config.get("no_final_linear")
|
||||
vf_share_layers = self.model_config.get("vf_share_layers")
|
||||
self.traj_view_framestacking = False
|
||||
@@ -62,17 +78,29 @@ class VisionNetwork(TFModelV2):
|
||||
|
||||
out_size, kernel, stride = filters[-1]
|
||||
|
||||
# No final linear: Last layer is a Conv2D and uses num_outputs.
|
||||
# No final linear: Last layer has activation function and exits with
|
||||
# num_outputs nodes (this could be a 1x1 conv or a FC layer, depending
|
||||
# on `post_fcnet_...` settings).
|
||||
if no_final_linear and num_outputs:
|
||||
last_layer = tf.keras.layers.Conv2D(
|
||||
num_outputs,
|
||||
out_size if post_fcnet_hiddens else num_outputs,
|
||||
kernel,
|
||||
strides=(stride, stride),
|
||||
activation=activation,
|
||||
padding="valid",
|
||||
data_format="channels_last",
|
||||
name="conv_out")(last_layer)
|
||||
conv_out = last_layer
|
||||
# Add (optional) post-fc-stack after last Conv2D layer.
|
||||
layer_sizes = post_fcnet_hiddens[:-1] + ([num_outputs]
|
||||
if post_fcnet_hiddens else
|
||||
[])
|
||||
for i, out_size in enumerate(layer_sizes):
|
||||
last_layer = tf.keras.layers.Dense(
|
||||
out_size,
|
||||
name="post_fcnet_{}".format(i),
|
||||
activation=post_fcnet_activation,
|
||||
kernel_initializer=normc_initializer(1.0))(last_layer)
|
||||
|
||||
# Finish network normally (w/o overriding last layer size with
|
||||
# `num_outputs`), then add another linear one of size `num_outputs`.
|
||||
else:
|
||||
@@ -88,29 +116,56 @@ class VisionNetwork(TFModelV2):
|
||||
# num_outputs defined. Use that to create an exact
|
||||
# `num_output`-sized (1,1)-Conv2D.
|
||||
if num_outputs:
|
||||
conv_out = tf.keras.layers.Conv2D(
|
||||
num_outputs, [1, 1],
|
||||
activation=None,
|
||||
padding="same",
|
||||
data_format="channels_last",
|
||||
name="conv_out")(last_layer)
|
||||
if post_fcnet_hiddens:
|
||||
last_cnn = last_layer = tf.keras.layers.Conv2D(
|
||||
post_fcnet_hiddens[0], [1, 1],
|
||||
activation=post_fcnet_activation,
|
||||
padding="same",
|
||||
data_format="channels_last",
|
||||
name="conv_out")(last_layer)
|
||||
# Add (optional) post-fc-stack after last Conv2D layer.
|
||||
for i, out_size in enumerate(post_fcnet_hiddens[1:] +
|
||||
[num_outputs]):
|
||||
last_layer = tf.keras.layers.Dense(
|
||||
out_size,
|
||||
name="post_fcnet_{}".format(i + 1),
|
||||
activation=post_fcnet_activation
|
||||
if i < len(post_fcnet_hiddens) - 1 else None,
|
||||
kernel_initializer=normc_initializer(1.0))(
|
||||
last_layer)
|
||||
else:
|
||||
last_cnn = last_layer = tf.keras.layers.Conv2D(
|
||||
num_outputs, [1, 1],
|
||||
activation=None,
|
||||
padding="same",
|
||||
data_format="channels_last",
|
||||
name="conv_out")(last_layer)
|
||||
|
||||
if conv_out.shape[1] != 1 or conv_out.shape[2] != 1:
|
||||
if last_cnn.shape[1] != 1 or last_cnn.shape[2] != 1:
|
||||
raise ValueError(
|
||||
"Given `conv_filters` ({}) do not result in a [B, 1, "
|
||||
"1, {} (`num_outputs`)] shape (but in {})! Please "
|
||||
"adjust your Conv2D stack such that the dims 1 and 2 "
|
||||
"are both 1.".format(self.model_config["conv_filters"],
|
||||
self.num_outputs,
|
||||
list(conv_out.shape)))
|
||||
list(last_cnn.shape)))
|
||||
|
||||
# num_outputs not known -> Flatten, then set self.num_outputs
|
||||
# to the resulting number of nodes.
|
||||
else:
|
||||
self.last_layer_is_flattened = True
|
||||
conv_out = tf.keras.layers.Flatten(
|
||||
last_layer = tf.keras.layers.Flatten(
|
||||
data_format="channels_last")(last_layer)
|
||||
self.num_outputs = conv_out.shape[1]
|
||||
|
||||
# Add (optional) post-fc-stack after last Conv2D layer.
|
||||
for i, out_size in enumerate(post_fcnet_hiddens):
|
||||
last_layer = tf.keras.layers.Dense(
|
||||
out_size,
|
||||
name="post_fcnet_{}".format(i),
|
||||
activation=post_fcnet_activation,
|
||||
kernel_initializer=normc_initializer(1.0))(last_layer)
|
||||
self.num_outputs = last_layer.shape[1]
|
||||
logits_out = last_layer
|
||||
|
||||
# Build the value layers
|
||||
if vf_share_layers:
|
||||
@@ -151,7 +206,7 @@ class VisionNetwork(TFModelV2):
|
||||
value_out = tf.keras.layers.Lambda(
|
||||
lambda x: tf.squeeze(x, axis=[1, 2]))(last_layer)
|
||||
|
||||
self.base_model = tf.keras.Model(inputs, [conv_out, value_out])
|
||||
self.base_model = tf.keras.Model(inputs, [logits_out, value_out])
|
||||
|
||||
# Optional: framestacking obs/new_obs for Atari.
|
||||
if self.traj_view_framestacking:
|
||||
|
||||
@@ -0,0 +1,163 @@
|
||||
from gym.spaces import Box, Discrete, Tuple
|
||||
import numpy as np
|
||||
|
||||
# TODO (sven): add IMPALA-style option.
|
||||
# from ray.rllib.examples.models.impala_vision_nets import TorchImpalaVisionNet
|
||||
from ray.rllib.models.torch.misc import normc_initializer as \
|
||||
torch_normc_initializer, SlimFC
|
||||
from ray.rllib.models.catalog import ModelCatalog
|
||||
from ray.rllib.models.modelv2 import ModelV2
|
||||
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
|
||||
from ray.rllib.models.utils import get_filter_config
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.framework import try_import_torch
|
||||
from ray.rllib.utils.torch_ops import one_hot
|
||||
|
||||
torch, nn = try_import_torch()
|
||||
|
||||
|
||||
class ComplexInputNetwork(TorchModelV2, nn.Module):
    """TorchModelV2 concat'ing CNN outputs to flat input(s), followed by FC(s).

    Note: This model should be used for complex Tuple observation spaces
    that have one or more image components. Dict spaces are not supported
    yet (the constructor asserts that the original space is a Tuple).

    The data flow is as follows:

    `obs` (e.g. Tuple[img0, img1, discrete0]) -> `CNN0 + CNN1 + ONE-HOT`
    `CNN0 + CNN1 + ONE-HOT` -> concat all flat outputs -> `out`
    `out` -> (optional) FC-stack -> `out2`
    `out2` -> action (logits) and value heads.
    """

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        """Builds one sub-network per Tuple component plus the output heads.

        Args:
            obs_space: The observation space. If it carries an
                `original_space` attribute, that attribute is used instead.
                Must resolve to a gym `Tuple`.
            action_space: The action space, passed on to all sub-models.
            num_outputs: Size of the logits output. If falsy, no
                logits/value heads are built and `self.num_outputs` is set
                to the output size of the post-concat FC stack.
            model_config: Model config dict. The keys read here are
                `conv_filters`, `conv_activation`, `post_fcnet_hiddens` and
                `post_fcnet_activation`.
            name: Name of this model.

        Raises:
            AssertionError: If the (original) observation space is not a
                Tuple.
        """
        # TODO: (sven) Support Dicts as well.
        self.original_space = obs_space.original_space if \
            hasattr(obs_space, "original_space") else obs_space
        assert isinstance(self.original_space, (Tuple)), \
            "`obs_space.original_space` must be Tuple!"

        nn.Module.__init__(self)
        TorchModelV2.__init__(self, self.original_space, action_space,
                              num_outputs, model_config, name)

        # TODO (sven): add IMPALA-style option (residual CNNs), selectable
        #  e.g. via `custom_model_config["conv_type"]`. Only the default
        #  catalog vision net is built for now.

        # Per-component bookkeeping, keyed by the component's index in the
        # Tuple: CNN sub-models, one-hot markers and flattened sizes.
        self.cnns = {}
        self.one_hot = {}
        self.flatten = {}
        concat_size = 0
        for i, component in enumerate(self.original_space):
            # Image space (rank-3 shape) -> own CNN.
            if len(component.shape) == 3:
                config = {
                    "conv_filters": model_config.get(
                        "conv_filters", get_filter_config(component.shape)),
                    "conv_activation": model_config.get("conv_activation"),
                    # The FC stack is applied once, after the concat.
                    "post_fcnet_hiddens": [],
                }
                cnn = ModelCatalog.get_model_v2(
                    component,
                    action_space,
                    num_outputs=None,
                    model_config=config,
                    framework="torch",
                    name="cnn_{}".format(i))

                concat_size += cnn.num_outputs
                self.cnns[i] = cnn
                # `self.cnns` is a plain dict, so each CNN is registered
                # explicitly as a sub-module of this nn.Module.
                self.add_module("cnn_{}".format(i), cnn)
            # Discrete inputs -> One-hot encode.
            elif isinstance(component, Discrete):
                self.one_hot[i] = True
                concat_size += component.n
            # TODO: (sven) Multidiscrete (see e.g. our auto-LSTM wrappers).
            # Everything else (1D Box).
            else:
                # `np.prod` replaces the deprecated `np.product` alias,
                # which was removed in NumPy 2.0.
                self.flatten[i] = int(np.prod(component.shape))
                concat_size += self.flatten[i]

        # Optional post-concat FC-stack.
        post_fc_stack_config = {
            "fcnet_hiddens": model_config.get("post_fcnet_hiddens", []),
            "fcnet_activation": model_config.get("post_fcnet_activation",
                                                 "relu")
        }
        self.post_fc_stack = ModelCatalog.get_model_v2(
            Box(float("-inf"),
                float("inf"),
                shape=(concat_size, ),
                dtype=np.float32),
            self.action_space,
            None,
            post_fc_stack_config,
            framework="torch",
            name="post_fc_stack")

        # Actions and value heads.
        self.logits_layer = None
        self.value_layer = None
        self._value_out = None

        if num_outputs:
            # Action-distribution head.
            self.logits_layer = SlimFC(
                in_size=self.post_fc_stack.num_outputs,
                out_size=num_outputs,
                activation_fn=None,
            )
            # Create the value branch model.
            self.value_layer = SlimFC(
                in_size=self.post_fc_stack.num_outputs,
                out_size=1,
                activation_fn=None,
                initializer=torch_normc_initializer(0.01))
        else:
            # Without heads, `forward()` returns the FC stack's output, so
            # report that width. `concat_size` is only the stack's *input*
            # width and is wrong whenever `post_fcnet_hiddens` is non-empty.
            self.num_outputs = self.post_fc_stack.num_outputs

    @override(ModelV2)
    def forward(self, input_dict, state, seq_lens):
        """Runs all components through their sub-nets and the heads.

        Returns:
            Tuple of (logits, []) if heads were built, otherwise
            (FC-stack output, []). This model never returns state.
        """
        # NOTE(review): assumes `input_dict["obs"]` is already the unpacked
        # sequence of per-component tensors (no unpacking is done here) --
        # confirm against the caller.
        # Push image observations through our CNNs.
        outs = []
        for i, component in enumerate(input_dict["obs"]):
            if i in self.cnns:
                cnn_out, _ = self.cnns[i]({"obs": component})
                outs.append(cnn_out)
            elif i in self.one_hot:
                # Only integer-typed inputs are one-hot encoded; any other
                # dtype is passed through unchanged.
                if component.dtype in [torch.int32, torch.int64, torch.uint8]:
                    outs.append(
                        one_hot(component, self.original_space.spaces[i]))
                else:
                    outs.append(component)
            else:
                outs.append(torch.reshape(component, [-1, self.flatten[i]]))
        # Concat all outputs and the non-image inputs.
        out = torch.cat(outs, dim=1)
        # Push through (optional) FC-stack (this may be an empty stack).
        out, _ = self.post_fc_stack({"obs": out}, [], None)

        # No logits/value branches.
        if self.logits_layer is None:
            return out, []

        # Logits- and value branches.
        logits, values = self.logits_layer(out), self.value_layer(out)
        self._value_out = torch.reshape(values, [-1])
        return logits, []

    @override(ModelV2)
    def value_function(self):
        """Returns the value output cached by the last `forward()` call.

        This is None if `forward()` has not run yet or no heads were built.
        """
        return self._value_out
|
||||
@@ -24,8 +24,11 @@ class FullyConnectedNetwork(TorchModelV2, nn.Module):
|
||||
model_config, name)
|
||||
nn.Module.__init__(self)
|
||||
|
||||
hiddens = model_config.get("fcnet_hiddens", []) + \
|
||||
model_config.get("post_fcnet_hiddens", [])
|
||||
activation = model_config.get("fcnet_activation")
|
||||
hiddens = model_config.get("fcnet_hiddens", [])
|
||||
if not model_config.get("fcnet_hiddens", []):
|
||||
activation = model_config.get("post_fcnet_activation")
|
||||
no_final_linear = model_config.get("no_final_linear")
|
||||
self.vf_share_layers = model_config.get("vf_share_layers")
|
||||
self.free_log_std = model_config.get("free_log_std")
|
||||
|
||||
@@ -5,7 +5,7 @@ import gym
|
||||
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
|
||||
from ray.rllib.models.torch.misc import normc_initializer, same_padding, \
|
||||
SlimConv2d, SlimFC
|
||||
from ray.rllib.models.utils import get_filter_config
|
||||
from ray.rllib.models.utils import get_activation_fn, get_filter_config
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.policy.view_requirement import ViewRequirement
|
||||
from ray.rllib.utils.annotations import override
|
||||
@@ -33,6 +33,12 @@ class VisionNetwork(TorchModelV2, nn.Module):
|
||||
filters = self.model_config["conv_filters"]
|
||||
assert len(filters) > 0,\
|
||||
"Must provide at least 1 entry in `conv_filters`!"
|
||||
|
||||
# Post FC net config.
|
||||
post_fcnet_hiddens = model_config.get("post_fcnet_hiddens", [])
|
||||
post_fcnet_activation = get_activation_fn(
|
||||
model_config.get("post_fcnet_activation"), framework="torch")
|
||||
|
||||
no_final_linear = self.model_config.get("no_final_linear")
|
||||
vf_share_layers = self.model_config.get("vf_share_layers")
|
||||
|
||||
@@ -68,17 +74,33 @@ class VisionNetwork(TorchModelV2, nn.Module):
|
||||
|
||||
out_channels, kernel, stride = filters[-1]
|
||||
|
||||
# No final linear: Last layer is a Conv2D and uses num_outputs.
|
||||
# No final linear: Last layer has activation function and exits with
|
||||
# num_outputs nodes (this could be a 1x1 conv or a FC layer, depending
|
||||
# on `post_fcnet_...` settings).
|
||||
if no_final_linear and num_outputs:
|
||||
out_channels = out_channels if post_fcnet_hiddens else num_outputs
|
||||
layers.append(
|
||||
SlimConv2d(
|
||||
in_channels,
|
||||
num_outputs,
|
||||
out_channels,
|
||||
kernel,
|
||||
stride,
|
||||
None, # padding=valid
|
||||
activation_fn=activation))
|
||||
out_channels = num_outputs
|
||||
|
||||
# Add (optional) post-fc-stack after last Conv2D layer.
|
||||
layer_sizes = post_fcnet_hiddens[:-1] + ([num_outputs]
|
||||
if post_fcnet_hiddens else
|
||||
[])
|
||||
for i, out_size in enumerate(layer_sizes):
|
||||
layers.append(
|
||||
SlimFC(
|
||||
in_size=out_channels,
|
||||
out_size=out_size,
|
||||
activation_fn=post_fcnet_activation,
|
||||
initializer=normc_initializer(1.0)))
|
||||
out_channels = out_size
|
||||
|
||||
# Finish network normally (w/o overriding last layer size with
|
||||
# `num_outputs`), then add another linear one of size `num_outputs`.
|
||||
else:
|
||||
@@ -99,12 +121,31 @@ class VisionNetwork(TorchModelV2, nn.Module):
|
||||
np.ceil((in_size[1] - kernel[1]) / stride)
|
||||
]
|
||||
padding, _ = same_padding(in_size, [1, 1], [1, 1])
|
||||
self._logits = SlimConv2d(
|
||||
out_channels,
|
||||
num_outputs, [1, 1],
|
||||
1,
|
||||
padding,
|
||||
activation_fn=None)
|
||||
if post_fcnet_hiddens:
|
||||
layers.append(nn.Flatten())
|
||||
in_size = out_channels
|
||||
# Add (optional) post-fc-stack after last Conv2D layer.
|
||||
for i, out_size in enumerate(post_fcnet_hiddens +
|
||||
[num_outputs]):
|
||||
layers.append(
|
||||
SlimFC(
|
||||
in_size=in_size,
|
||||
out_size=out_size,
|
||||
activation_fn=post_fcnet_activation
|
||||
if i < len(post_fcnet_hiddens) - 1 else None,
|
||||
initializer=normc_initializer(1.0)))
|
||||
in_size = out_size
|
||||
# Last layer is logits layer.
|
||||
self._logits = layers.pop()
|
||||
|
||||
else:
|
||||
self._logits = SlimConv2d(
|
||||
out_channels,
|
||||
num_outputs, [1, 1],
|
||||
1,
|
||||
padding,
|
||||
activation_fn=None)
|
||||
|
||||
# num_outputs not known -> Flatten, then set self.num_outputs
|
||||
# to the resulting number of nodes.
|
||||
else:
|
||||
@@ -196,16 +237,19 @@ class VisionNetwork(TorchModelV2, nn.Module):
|
||||
if not self.last_layer_is_flattened:
|
||||
if self._logits:
|
||||
conv_out = self._logits(conv_out)
|
||||
if conv_out.shape[2] != 1 or conv_out.shape[3] != 1:
|
||||
raise ValueError(
|
||||
"Given `conv_filters` ({}) do not result in a [B, {} "
|
||||
"(`num_outputs`), 1, 1] shape (but in {})! Please adjust "
|
||||
"your Conv2D stack such that the last 2 dims are both "
|
||||
"1.".format(self.model_config["conv_filters"],
|
||||
self.num_outputs, list(conv_out.shape)))
|
||||
logits = conv_out.squeeze(3)
|
||||
logits = logits.squeeze(2)
|
||||
|
||||
if len(conv_out.shape) == 4:
|
||||
if conv_out.shape[2] != 1 or conv_out.shape[3] != 1:
|
||||
raise ValueError(
|
||||
"Given `conv_filters` ({}) do not result in a [B, {} "
|
||||
"(`num_outputs`), 1, 1] shape (but in {})! Please "
|
||||
"adjust your Conv2D stack such that the last 2 dims "
|
||||
"are both 1.".format(self.model_config["conv_filters"],
|
||||
self.num_outputs,
|
||||
list(conv_out.shape)))
|
||||
logits = conv_out.squeeze(3)
|
||||
logits = logits.squeeze(2)
|
||||
else:
|
||||
logits = conv_out
|
||||
return logits, state
|
||||
else:
|
||||
return conv_out, state
|
||||
|
||||
@@ -177,8 +177,8 @@ class TestComputeLogLikelihood(unittest.TestCase):
|
||||
config,
|
||||
prev_a,
|
||||
continuous=True,
|
||||
layer_key=("sequential/action", (2, 4),
|
||||
("action_model.action_0.", "action_model.action_out.")),
|
||||
layer_key=("fc", (0, 2), ("action_model._hidden_layers.0.",
|
||||
"action_model._logits.")),
|
||||
logp_func=logp_func)
|
||||
|
||||
def test_sac_discr(self):
|
||||
@@ -188,12 +188,7 @@ class TestComputeLogLikelihood(unittest.TestCase):
|
||||
config["policy_model"]["fcnet_activation"] = "linear"
|
||||
prev_a = np.array(0)
|
||||
|
||||
do_test_log_likelihood(
|
||||
sac.SACTrainer,
|
||||
config,
|
||||
prev_a,
|
||||
layer_key=("sequential/action", (0, 2),
|
||||
("action_model.action_0.", "action_model.action_out.")))
|
||||
do_test_log_likelihood(sac.SACTrainer, config, prev_a)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -37,6 +37,10 @@ parser.add_argument(
|
||||
"--yaml-dir",
|
||||
type=str,
|
||||
help="The directory in which to find all yamls to test.")
|
||||
parser.add_argument(
|
||||
"--local-mode",
|
||||
action="store_true",
|
||||
help="Run ray in local mode for easier debugging.")
|
||||
|
||||
# Obsoleted arg, use --framework=torch instead.
|
||||
parser.add_argument(
|
||||
@@ -92,7 +96,7 @@ if __name__ == "__main__":
|
||||
passed = False
|
||||
for i in range(3):
|
||||
try:
|
||||
ray.init(num_cpus=5)
|
||||
ray.init(num_cpus=5, local_mode=args.local_mode)
|
||||
trials = run_experiments(experiments, resume=False, verbose=2)
|
||||
finally:
|
||||
ray.shutdown()
|
||||
|
||||
@@ -333,7 +333,7 @@ class NestedSpacesTest(unittest.TestCase):
|
||||
def test_invalid_model2(self):
|
||||
ModelCatalog.register_custom_model("invalid2", InvalidModel2)
|
||||
self.assertRaisesRegexp(
|
||||
ValueError, "Expected output shape of",
|
||||
ValueError, "State output is not a list",
|
||||
lambda: PGTrainer(
|
||||
env="CartPole-v0", config={
|
||||
"model": {
|
||||
|
||||
@@ -15,7 +15,7 @@ from ray.rllib.utils.test_utils import framework_iterator
|
||||
ACTION_SPACES_TO_TEST = {
|
||||
"discrete": Discrete(5),
|
||||
"vector": Box(-1.0, 1.0, (5, ), dtype=np.float32),
|
||||
# "vector2": Box(-1.0, 1.0, (5, 5), dtype=np.float32),
|
||||
"vector2": Box(-1.0, 1.0, (5, 5), dtype=np.float32),
|
||||
"multidiscrete": MultiDiscrete([1, 2, 3, 4]),
|
||||
"tuple": Tuple(
|
||||
[Discrete(2),
|
||||
@@ -63,8 +63,6 @@ def check_support(alg, config, train=True, check_bounds=False, tfe=False):
|
||||
p_done=1.0,
|
||||
check_action_bounds=check_bounds)))
|
||||
stat = "ok"
|
||||
if alg == "SAC":
|
||||
config["use_state_preprocessor"] = o_name in ["atari", "image"]
|
||||
|
||||
try:
|
||||
a = get_agent_class(alg)(config=config, env=RandomEnv)
|
||||
|
||||
@@ -14,8 +14,6 @@ atari-sac-tf-and-torch:
|
||||
framework:
|
||||
grid_search: [tf, torch]
|
||||
gamma: 0.99
|
||||
# state-preprocessor=Our default Atari Conv2D-net.
|
||||
use_state_preprocessor: true
|
||||
Q_model:
|
||||
hidden_activation: relu
|
||||
hidden_layer_sizes: [512]
|
||||
|
||||
@@ -11,8 +11,6 @@ mspacman-sac-tf:
|
||||
# Works for both torch and tf.
|
||||
framework: tf
|
||||
gamma: 0.99
|
||||
# state-preprocessor=Our default Atari Conv2D-net.
|
||||
use_state_preprocessor: true
|
||||
Q_model:
|
||||
fcnet_hiddens: [512]
|
||||
fcnet_activation: relu
|
||||
|
||||
@@ -301,13 +301,10 @@ def check_compute_single_action(trainer,
|
||||
assert worker_set
|
||||
if isinstance(worker_set, list):
|
||||
obs_space = trainer.get_policy().observation_space
|
||||
try:
|
||||
obs_space = obs_space.original_space
|
||||
except AttributeError:
|
||||
pass
|
||||
else:
|
||||
obs_space = worker_set.local_worker().for_policy(
|
||||
lambda p: p.observation_space)
|
||||
obs_space = getattr(obs_space, "original_space", obs_space)
|
||||
else:
|
||||
method_to_test = pol.compute_single_action
|
||||
obs_space = pol.observation_space
|
||||
|
||||
@@ -22,6 +22,6 @@ def with_lock(func: Callable):
|
||||
except AttributeError:
|
||||
raise AttributeError(
|
||||
"Object {} must have a `self._lock` property (assigned to a "
|
||||
"threading.Lock() object in its constructor)!".format(self))
|
||||
"threading.RLock() object in its constructor)!".format(self))
|
||||
|
||||
return wrapper
|
||||
|
||||
Reference in New Issue
Block a user