[RLlib] Issue 13330: Missing TF installation causes crash in ModelCatalog.get_action_shape() (#13332)

This commit is contained in:
Sven Mika
2021-01-11 13:19:46 +01:00
committed by GitHub
parent 93006c2ba5
commit 5d50d37f45
+12 -7
View File
@@ -20,12 +20,13 @@ from ray.rllib.models.torch.torch_action_dist import TorchCategorical, \
from ray.rllib.utils.annotations import DeveloperAPI, PublicAPI
from ray.rllib.utils.deprecation import DEPRECATED_VALUE
from ray.rllib.utils.error import UnsupportedSpaceException
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.spaces.simplex import Simplex
from ray.rllib.utils.spaces.space_utils import flatten_space
from ray.rllib.utils.typing import ModelConfigDict, TensorType
tf1, tf, tfv = try_import_tf()
torch, _ = try_import_torch()
logger = logging.getLogger(__name__)
@@ -251,22 +252,25 @@ class ModelCatalog:
@staticmethod
@DeveloperAPI
def get_action_shape(action_space: gym.Space) -> (np.dtype, List[int]):
def get_action_shape(action_space: gym.Space,
framework: str = "tf") -> (np.dtype, List[int]):
"""Returns action tensor dtype and shape for the action space.
Args:
action_space (Space): Action space of the target gym env.
framework (str): The framework identifier. One of "tf" or "torch".
Returns:
(dtype, shape): Dtype and shape of the actions tensor.
"""
dl_lib = torch if framework == "torch" else tf
if isinstance(action_space, gym.spaces.Discrete):
return (action_space.dtype, (None, ))
return action_space.dtype, (None, )
elif isinstance(action_space, (gym.spaces.Box, Simplex)):
return (tf.float32, (None, ) + action_space.shape)
return dl_lib.float32, (None, ) + action_space.shape
elif isinstance(action_space, gym.spaces.MultiDiscrete):
return (tf.as_dtype(action_space.dtype),
(None, ) + action_space.shape)
return action_space.dtype, (None, ) + action_space.shape
elif isinstance(action_space, (gym.spaces.Tuple, gym.spaces.Dict)):
flat_action_space = flatten_space(action_space)
size = 0
@@ -278,7 +282,8 @@ class ModelCatalog:
all_discrete = False
size += np.product(flat_action_space[i].shape)
size = int(size)
return (tf.int64 if all_discrete else tf.float32, (None, size))
return dl_lib.int64 if all_discrete else dl_lib.float32, \
(None, size)
else:
raise NotImplementedError(
"Action space {} not supported".format(action_space))