From 2a0225dd25a2cb52dec60f9e5e756e3d755bc862 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Tue, 5 Nov 2019 11:36:29 -0800 Subject: [PATCH] [rllib] RLlib chooses wrong neural network model for Atari in 0.7.5 (#6087) --- rllib/models/catalog.py | 4 ++-- rllib/tests/test_io.py | 3 ++- rllib/tests/test_supported_spaces.py | 8 ++++++++ 3 files changed, 12 insertions(+), 3 deletions(-) diff --git a/rllib/models/catalog.py b/rllib/models/catalog.py index 2fe4458f8..f79a016d4 100644 --- a/rllib/models/catalog.py +++ b/rllib/models/catalog.py @@ -504,7 +504,7 @@ class ModelCatalog(object): state_in=state_in, seq_lens=seq_lens) - obs_rank = len(input_dict["obs"].shape) - 1 + obs_rank = len(input_dict["obs"].shape) - 1 # drops batch dim if obs_rank > 2: return VisionNetwork(input_dict, obs_space, action_space, @@ -516,7 +516,7 @@ class ModelCatalog(object): @staticmethod def _get_v2_model(obs_space, options): options = options or MODEL_DEFAULTS - obs_rank = len(obs_space.shape) - 1 + obs_rank = len(obs_space.shape) if options.get("use_lstm"): return None # TODO: default LSTM v2 not implemented diff --git a/rllib/tests/test_io.py b/rllib/tests/test_io.py index c98e4553d..359be1404 100644 --- a/rllib/tests/test_io.py +++ b/rllib/tests/test_io.py @@ -240,7 +240,8 @@ class JsonIOTest(unittest.TestCase): self.assertEqual(len(os.listdir(self.test_dir)), 0) for _ in range(100): writer.write(SAMPLES) - self.assertEqual(len(os.listdir(self.test_dir)), 12) + num_files = len(os.listdir(self.test_dir)) + assert num_files in [12, 13], num_files def testReadWrite(self): ioctx = IOContext(self.test_dir, {}, 0, None) diff --git a/rllib/tests/test_supported_spaces.py b/rllib/tests/test_supported_spaces.py index edd0beb7d..f4784e439 100644 --- a/rllib/tests/test_supported_spaces.py +++ b/rllib/tests/test_supported_spaces.py @@ -9,6 +9,8 @@ import sys import ray from ray.rllib.agents.registry import get_agent_class +from ray.rllib.models.tf.fcnet_v2 import FullyConnectedNetwork as FCNetV2 +from ray.rllib.models.tf.visionnet_v2 import VisionNetwork as VisionNetV2 from ray.rllib.tests.test_multi_agent_env import (MultiCartpole, MultiMountainCar) from ray.rllib.utils.error import UnsupportedSpaceException @@ -83,6 +85,12 @@ def check_support(alg, config, stats, check_bounds=False, name=None): stat = "skip" # speed up tests by avoiding full grid else: a = get_agent_class(alg)(config=config, env="stub_env") + if alg not in ["DDPG", "ES", "ARS"]: + if o_name in ["atari", "image"]: + assert isinstance(a.get_policy().model, + VisionNetV2) + elif o_name in ["vector", "vector2"]: + assert isinstance(a.get_policy().model, FCNetV2) a.train() covered_a.add(a_name) covered_o.add(o_name)