mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 16:46:37 +08:00
@@ -1,4 +1,5 @@
|
||||
import gym
|
||||
from gym.spaces import Box, Discrete
|
||||
import logging
|
||||
from typing import Tuple, Type
|
||||
|
||||
@@ -13,6 +14,7 @@ from ray.rllib.models.modelv2 import ModelV2
|
||||
from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper
|
||||
from ray.rllib.policy.policy import Policy
|
||||
from ray.rllib.policy.policy_template import build_policy_class
|
||||
from ray.rllib.utils.error import UnsupportedSpaceException
|
||||
from ray.rllib.utils.framework import try_import_torch
|
||||
from ray.rllib.utils.torch_ops import apply_grad_clipping
|
||||
from ray.rllib.utils.typing import TrainerConfigDict
|
||||
@@ -22,6 +24,35 @@ torch, nn = try_import_torch()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def validate_spaces(policy: Policy, observation_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space,
|
||||
config: TrainerConfigDict) -> None:
|
||||
"""Validates the observation- and action spaces used for the Policy.
|
||||
|
||||
Args:
|
||||
policy (Policy): The policy, whose spaces are being validated.
|
||||
observation_space (gym.spaces.Space): The observation space to
|
||||
validate.
|
||||
action_space (gym.spaces.Space): The action space to validate.
|
||||
config (TrainerConfigDict): The Policy's config dict.
|
||||
|
||||
Raises:
|
||||
UnsupportedSpaceException: If one of the spaces is not supported.
|
||||
"""
|
||||
# Only support single Box or single Discrete spaces.
|
||||
if not isinstance(action_space, (Box, Discrete)):
|
||||
raise UnsupportedSpaceException(
|
||||
"Action space ({}) of {} is not supported for "
|
||||
"MB-MPO. Must be [Box|Discrete].".format(action_space, policy))
|
||||
# If Box, make sure it's a 1D vector space.
|
||||
elif isinstance(action_space, Box) and len(action_space.shape) > 1:
|
||||
raise UnsupportedSpaceException(
|
||||
"Action space ({}) of {} has multiple dimensions "
|
||||
"{}. ".format(action_space, policy, action_space.shape) +
|
||||
"Consider reshaping this into a single dimension Box space "
|
||||
"or using the multi-agent API.")
|
||||
|
||||
|
||||
def make_model_and_action_dist(
|
||||
policy: Policy,
|
||||
obs_space: gym.spaces.Space,
|
||||
|
||||
@@ -136,6 +136,8 @@ class DynamicsEnsembleCustomModel(TorchModelV2, nn.Module):
|
||||
obs_space.low[0],
|
||||
obs_space.high[0],
|
||||
shape=(obs_space.shape[0] + action_space.shape[0], ))
|
||||
else:
|
||||
raise NotImplementedError
|
||||
super(DynamicsEnsembleCustomModel, self).__init__(
|
||||
input_space, action_space, num_outputs, model_config, name)
|
||||
|
||||
|
||||
@@ -652,7 +652,7 @@ def validate_spaces(policy: Policy, observation_space: gym.spaces.Space,
|
||||
Raises:
|
||||
UnsupportedSpaceException: If one of the spaces is not supported.
|
||||
"""
|
||||
# Only support single Box or single Discreete spaces.
|
||||
# Only support single Box or single Discrete spaces.
|
||||
if not isinstance(action_space, (Box, Discrete, Simplex)):
|
||||
raise UnsupportedSpaceException(
|
||||
"Action space ({}) of {} is not supported for "
|
||||
|
||||
Reference in New Issue
Block a user