From 5d50d37f45ae30913dd28fafbee11cb5025f497f Mon Sep 17 00:00:00 2001 From: Sven Mika Date: Mon, 11 Jan 2021 13:19:46 +0100 Subject: [PATCH] [RLlib] Issue 13330: No TF installed causes crash in `ModelCatalog.get_action_shape()` (#13332) --- rllib/models/catalog.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/rllib/models/catalog.py b/rllib/models/catalog.py index eb388548d..9638ed44b 100644 --- a/rllib/models/catalog.py +++ b/rllib/models/catalog.py @@ -20,12 +20,13 @@ from ray.rllib.models.torch.torch_action_dist import TorchCategorical, \ from ray.rllib.utils.annotations import DeveloperAPI, PublicAPI from ray.rllib.utils.deprecation import DEPRECATED_VALUE from ray.rllib.utils.error import UnsupportedSpaceException -from ray.rllib.utils.framework import try_import_tf +from ray.rllib.utils.framework import try_import_tf, try_import_torch from ray.rllib.utils.spaces.simplex import Simplex from ray.rllib.utils.spaces.space_utils import flatten_space from ray.rllib.utils.typing import ModelConfigDict, TensorType tf1, tf, tfv = try_import_tf() +torch, _ = try_import_torch() logger = logging.getLogger(__name__) @@ -251,22 +252,25 @@ class ModelCatalog: @staticmethod @DeveloperAPI - def get_action_shape(action_space: gym.Space) -> (np.dtype, List[int]): + def get_action_shape(action_space: gym.Space, + framework: str = "tf") -> (np.dtype, List[int]): """Returns action tensor dtype and shape for the action space. Args: action_space (Space): Action space of the target gym env. + framework (str): The framework identifier. One of "tf" or "torch". + Returns: (dtype, shape): Dtype and shape of the actions tensor. """ + dl_lib = torch if framework == "torch" else tf if isinstance(action_space, gym.spaces.Discrete): - return (action_space.dtype, (None, )) + return action_space.dtype, (None, ) elif isinstance(action_space, (gym.spaces.Box, Simplex)): - return (tf.float32, (None, ) + action_space.shape) + return dl_lib.float32, (None, ) + action_space.shape elif isinstance(action_space, gym.spaces.MultiDiscrete): - return (tf.as_dtype(action_space.dtype), - (None, ) + action_space.shape) + return action_space.dtype, (None, ) + action_space.shape elif isinstance(action_space, (gym.spaces.Tuple, gym.spaces.Dict)): flat_action_space = flatten_space(action_space) size = 0 @@ -278,7 +282,8 @@ class ModelCatalog: all_discrete = False size += np.product(flat_action_space[i].shape) size = int(size) - return (tf.int64 if all_discrete else tf.float32, (None, size)) + return dl_lib.int64 if all_discrete else dl_lib.float32, \ + (None, size) else: raise NotImplementedError( "Action space {} not supported".format(action_space))