diff --git a/rllib/models/catalog.py b/rllib/models/catalog.py
index 5cf40408e..2fe4458f8 100644
--- a/rllib/models/catalog.py
+++ b/rllib/models/catalog.py
@@ -450,7 +450,7 @@ class ModelCatalog(object):
         else:
             obs_rank = len(obs_space.shape)
 
-        if obs_rank > 1:
+        if obs_rank > 2:
             return PyTorchVisionNet(obs_space, action_space, num_outputs,
                                     model_config, name)
 
@@ -506,7 +506,7 @@ class ModelCatalog(object):
 
         obs_rank = len(input_dict["obs"].shape) - 1
 
-        if obs_rank > 1:
+        if obs_rank > 2:
             return VisionNetwork(input_dict, obs_space, action_space,
                                  num_outputs, options)
 
@@ -521,7 +521,7 @@ class ModelCatalog(object):
         if options.get("use_lstm"):
             return None  # TODO: default LSTM v2 not implemented
 
-        if obs_rank > 1:
+        if obs_rank > 2:
             return VisionNetV2
         return FCNetV2
 
diff --git a/rllib/models/model.py b/rllib/models/model.py
index f60a5a82f..309caa460 100644
--- a/rllib/models/model.py
+++ b/rllib/models/model.py
@@ -190,6 +190,31 @@ class Model(object):
                     self._num_outputs, shape))
 
 
+@DeveloperAPI
+def flatten(obs, framework):
+    """Flatten the given tensor, keeping its leading dimension.
+
+    Arguments:
+        obs: Tensor to flatten; expected to have a leading batch dimension.
+        framework (str): Either "tf" or "torch".
+
+    Returns:
+        The tensor with all dimensions after the first merged into one.
+
+    Raises:
+        NotImplementedError: If `framework` is neither "tf" nor "torch".
+    """
+    if framework == "tf":
+        return tf.layers.flatten(obs)
+    elif framework == "torch":
+        # Local import: torch is only needed on this branch.
+        import torch
+        return torch.flatten(obs, start_dim=1)
+    else:
+        raise NotImplementedError(
+            "flatten() does not support framework {!r}".format(framework))
+
+
 @DeveloperAPI
 def restore_original_dimensions(obs, obs_space, tensorlib=tf):
     """Unpacks Dict and Tuple space observations into their original form.
diff --git a/rllib/models/modelv2.py b/rllib/models/modelv2.py
index 9954e01e9..568963da0 100644
--- a/rllib/models/modelv2.py
+++ b/rllib/models/modelv2.py
@@ -3,7 +3,7 @@ from __future__ import division
 from __future__ import print_function
 
 from ray.rllib.policy.sample_batch import SampleBatch
-from ray.rllib.models.model import restore_original_dimensions
+from ray.rllib.models.model import restore_original_dimensions, flatten
 from ray.rllib.utils.annotations import PublicAPI
 
 
@@ -146,7 +146,10 @@ class ModelV2(object):
         restored = input_dict.copy()
         restored["obs"] = restore_original_dimensions(
             input_dict["obs"], self.obs_space, self.framework)
-        restored["obs_flat"] = input_dict["obs"]
+        if len(input_dict["obs"].shape) > 2:
+            restored["obs_flat"] = flatten(input_dict["obs"], self.framework)
+        else:
+            restored["obs_flat"] = input_dict["obs"]
         with self.context():
             res = self.forward(restored, state or [], seq_lens)
         if ((not isinstance(res, list) and not isinstance(res, tuple))
diff --git a/rllib/models/tf/fcnet_v1.py b/rllib/models/tf/fcnet_v1.py
index 7663ba64e..d4f645ab0 100644
--- a/rllib/models/tf/fcnet_v1.py
+++ b/rllib/models/tf/fcnet_v1.py
@@ -25,6 +25,9 @@ class FullyConnectedNetwork(Model):
         hiddens = options.get("fcnet_hiddens")
         activation = get_activation_fn(options.get("fcnet_activation"))
 
+        if len(inputs.shape) > 2:
+            inputs = tf.layers.flatten(inputs)
+
         with tf.name_scope("fc_net"):
             i = 1
             last_layer = inputs
diff --git a/rllib/models/tf/fcnet_v2.py b/rllib/models/tf/fcnet_v2.py
index 1201fa858..2932659c5 100644
--- a/rllib/models/tf/fcnet_v2.py
+++ b/rllib/models/tf/fcnet_v2.py
@@ -2,6 +2,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import numpy as np
+
 from ray.rllib.models.tf.tf_modelv2 import TFModelV2
 from ray.rllib.models.tf.misc import normc_initializer, get_activation_fn
 from ray.rllib.utils import try_import_tf
@@ -22,8 +24,9 @@ class FullyConnectedNetwork(TFModelV2):
         no_final_linear = model_config.get("no_final_linear")
         vf_share_layers = model_config.get("vf_share_layers")
 
+        # The input layer takes the flattened shape; int() gives a plain int.
         inputs = tf.keras.layers.Input(
-            shape=obs_space.shape, name="observations")
+            shape=(int(np.prod(obs_space.shape)), ), name="observations")
         last_layer = inputs
         i = 1
 
diff --git a/rllib/tests/test_supported_spaces.py b/rllib/tests/test_supported_spaces.py
index b007679fb..edd0beb7d 100644
--- a/rllib/tests/test_supported_spaces.py
+++ b/rllib/tests/test_supported_spaces.py
@@ -31,6 +31,7 @@ ACTION_SPACES_TO_TEST = {
 OBSERVATION_SPACES_TO_TEST = {
     "discrete": Discrete(5),
     "vector": Box(-1.0, 1.0, (5, ), dtype=np.float32),
+    "vector2": Box(-1.0, 1.0, (5, 5), dtype=np.float32),
     "image": Box(-1.0, 1.0, (84, 84, 1), dtype=np.float32),
     "atari": Box(-1.0, 1.0, (210, 160, 3), dtype=np.float32),
     "tuple": Tuple([Discrete(10),
@@ -106,6 +107,7 @@ def check_support(alg, config, stats, check_bounds=False, name=None):
 def check_support_multiagent(alg, config):
     register_env("multi_mountaincar", lambda _: MultiMountainCar(2))
     register_env("multi_cartpole", lambda _: MultiCartpole(2))
+    config["log_level"] = "ERROR"
     if "DDPG" in alg:
         a = get_agent_class(alg)(config=config, env="multi_mountaincar")
     else: