diff --git a/rllib/agents/mbmpo/mbmpo_torch_policy.py b/rllib/agents/mbmpo/mbmpo_torch_policy.py index 06e65042e..5dc03435c 100644 --- a/rllib/agents/mbmpo/mbmpo_torch_policy.py +++ b/rllib/agents/mbmpo/mbmpo_torch_policy.py @@ -1,4 +1,5 @@ import gym +from gym.spaces import Box, Discrete import logging from typing import Tuple, Type @@ -13,6 +14,7 @@ from ray.rllib.models.modelv2 import ModelV2 from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper from ray.rllib.policy.policy import Policy from ray.rllib.policy.policy_template import build_policy_class +from ray.rllib.utils.error import UnsupportedSpaceException from ray.rllib.utils.framework import try_import_torch from ray.rllib.utils.torch_ops import apply_grad_clipping from ray.rllib.utils.typing import TrainerConfigDict @@ -22,6 +24,35 @@ torch, nn = try_import_torch() logger = logging.getLogger(__name__) +def validate_spaces(policy: Policy, observation_space: gym.spaces.Space, + action_space: gym.spaces.Space, + config: TrainerConfigDict) -> None: + """Validates the observation- and action spaces used for the Policy. + + Args: + policy (Policy): The policy, whose spaces are being validated. + observation_space (gym.spaces.Space): The observation space to + validate. + action_space (gym.spaces.Space): The action space to validate. + config (TrainerConfigDict): The Policy's config dict. + + Raises: + UnsupportedSpaceException: If one of the spaces is not supported. + """ + # Only support single Box or single Discrete spaces. + if not isinstance(action_space, (Box, Discrete)): + raise UnsupportedSpaceException( + "Action space ({}) of {} is not supported for " + "MB-MPO. Must be [Box|Discrete].".format(action_space, policy)) + # If Box, make sure it's a 1D vector space. + elif isinstance(action_space, Box) and len(action_space.shape) > 1: + raise UnsupportedSpaceException( + "Action space ({}) of {} has multiple dimensions " + "{}. ".format(action_space, policy, action_space.shape) + + "Consider reshaping this into a single dimension Box space " + "or using the multi-agent API.") + + def make_model_and_action_dist( policy: Policy, obs_space: gym.spaces.Space, diff --git a/rllib/agents/mbmpo/model_ensemble.py b/rllib/agents/mbmpo/model_ensemble.py index 2bb9513da..f7cb35b6f 100644 --- a/rllib/agents/mbmpo/model_ensemble.py +++ b/rllib/agents/mbmpo/model_ensemble.py @@ -136,6 +136,8 @@ class DynamicsEnsembleCustomModel(TorchModelV2, nn.Module): obs_space.low[0], obs_space.high[0], shape=(obs_space.shape[0] + action_space.shape[0], )) + else: + raise NotImplementedError super(DynamicsEnsembleCustomModel, self).__init__( input_space, action_space, num_outputs, model_config, name) diff --git a/rllib/agents/sac/sac_tf_policy.py b/rllib/agents/sac/sac_tf_policy.py index 83fa076ed..e4cc080af 100644 --- a/rllib/agents/sac/sac_tf_policy.py +++ b/rllib/agents/sac/sac_tf_policy.py @@ -652,7 +652,7 @@ def validate_spaces(policy: Policy, observation_space: gym.spaces.Space, Raises: UnsupportedSpaceException: If one of the spaces is not supported. """ - # Only support single Box or single Discreete spaces. + # Only support single Box or single Discrete spaces. if not isinstance(action_space, (Box, Discrete, Simplex)): raise UnsupportedSpaceException( "Action space ({}) of {} is not supported for "