[RLlib] Issue #13342: Add validate_spaces to MB-MPO. (#14038)

This commit is contained in:
Sven Mika
2021-02-11 11:36:53 +01:00
committed by GitHub
parent f6cfc44dbd
commit a2f7998026
3 changed files with 34 additions and 1 deletions
+31
View File
@@ -1,4 +1,5 @@
import gym
from gym.spaces import Box, Discrete
import logging
from typing import Tuple, Type
@@ -13,6 +14,7 @@ from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.policy_template import build_policy_class
from ray.rllib.utils.error import UnsupportedSpaceException
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.torch_ops import apply_grad_clipping
from ray.rllib.utils.typing import TrainerConfigDict
@@ -22,6 +24,35 @@ torch, nn = try_import_torch()
logger = logging.getLogger(__name__)
def validate_spaces(policy: Policy, observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
config: TrainerConfigDict) -> None:
"""Validates the observation- and action spaces used for the Policy.
Args:
policy (Policy): The policy, whose spaces are being validated.
observation_space (gym.spaces.Space): The observation space to
validate.
action_space (gym.spaces.Space): The action space to validate.
config (TrainerConfigDict): The Policy's config dict.
Raises:
UnsupportedSpaceException: If one of the spaces is not supported.
"""
# Only support single Box or single Discrete spaces.
if not isinstance(action_space, (Box, Discrete)):
raise UnsupportedSpaceException(
"Action space ({}) of {} is not supported for "
"MB-MPO. Must be [Box|Discrete].".format(action_space, policy))
# If Box, make sure it's a 1D vector space.
elif isinstance(action_space, Box) and len(action_space.shape) > 1:
raise UnsupportedSpaceException(
"Action space ({}) of {} has multiple dimensions "
"{}. ".format(action_space, policy, action_space.shape) +
"Consider reshaping this into a single dimension Box space "
"or using the multi-agent API.")
def make_model_and_action_dist(
policy: Policy,
obs_space: gym.spaces.Space,
+2
View File
@@ -136,6 +136,8 @@ class DynamicsEnsembleCustomModel(TorchModelV2, nn.Module):
obs_space.low[0],
obs_space.high[0],
shape=(obs_space.shape[0] + action_space.shape[0], ))
else:
raise NotImplementedError
super(DynamicsEnsembleCustomModel, self).__init__(
input_space, action_space, num_outputs, model_config, name)
+1 -1
View File
@@ -652,7 +652,7 @@ def validate_spaces(policy: Policy, observation_space: gym.spaces.Space,
Raises:
UnsupportedSpaceException: If one of the spaces is not supported.
"""
# Only support single Box or single Discreete spaces.
# Only support single Box or single Discrete spaces.
if not isinstance(action_space, (Box, Discrete, Simplex)):
raise UnsupportedSpaceException(
"Action space ({}) of {} is not supported for "