mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 20:22:39 +08:00
[RLlib] Issue 13330: No TF installed causes crash in ModelCatalog.get_action_shape() (#13332)
This commit is contained in:
+12
-7
@@ -20,12 +20,13 @@ from ray.rllib.models.torch.torch_action_dist import TorchCategorical, \
|
||||
from ray.rllib.utils.annotations import DeveloperAPI, PublicAPI
|
||||
from ray.rllib.utils.deprecation import DEPRECATED_VALUE
|
||||
from ray.rllib.utils.error import UnsupportedSpaceException
|
||||
from ray.rllib.utils.framework import try_import_tf
|
||||
from ray.rllib.utils.framework import try_import_tf, try_import_torch
|
||||
from ray.rllib.utils.spaces.simplex import Simplex
|
||||
from ray.rllib.utils.spaces.space_utils import flatten_space
|
||||
from ray.rllib.utils.typing import ModelConfigDict, TensorType
|
||||
|
||||
tf1, tf, tfv = try_import_tf()
|
||||
torch, _ = try_import_torch()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -251,22 +252,25 @@ class ModelCatalog:
|
||||
|
||||
@staticmethod
|
||||
@DeveloperAPI
|
||||
def get_action_shape(action_space: gym.Space) -> (np.dtype, List[int]):
|
||||
def get_action_shape(action_space: gym.Space,
|
||||
framework: str = "tf") -> (np.dtype, List[int]):
|
||||
"""Returns action tensor dtype and shape for the action space.
|
||||
|
||||
Args:
|
||||
action_space (Space): Action space of the target gym env.
|
||||
framework (str): The framework identifier. One of "tf" or "torch".
|
||||
|
||||
Returns:
|
||||
(dtype, shape): Dtype and shape of the actions tensor.
|
||||
"""
|
||||
dl_lib = torch if framework == "torch" else tf
|
||||
|
||||
if isinstance(action_space, gym.spaces.Discrete):
|
||||
return (action_space.dtype, (None, ))
|
||||
return action_space.dtype, (None, )
|
||||
elif isinstance(action_space, (gym.spaces.Box, Simplex)):
|
||||
return (tf.float32, (None, ) + action_space.shape)
|
||||
return dl_lib.float32, (None, ) + action_space.shape
|
||||
elif isinstance(action_space, gym.spaces.MultiDiscrete):
|
||||
return (tf.as_dtype(action_space.dtype),
|
||||
(None, ) + action_space.shape)
|
||||
return action_space.dtype, (None, ) + action_space.shape
|
||||
elif isinstance(action_space, (gym.spaces.Tuple, gym.spaces.Dict)):
|
||||
flat_action_space = flatten_space(action_space)
|
||||
size = 0
|
||||
@@ -278,7 +282,8 @@ class ModelCatalog:
|
||||
all_discrete = False
|
||||
size += np.product(flat_action_space[i].shape)
|
||||
size = int(size)
|
||||
return (tf.int64 if all_discrete else tf.float32, (None, size))
|
||||
return dl_lib.int64 if all_discrete else dl_lib.float32, \
|
||||
(None, size)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Action space {} not supported".format(action_space))
|
||||
|
||||
Reference in New Issue
Block a user