[rllib] RLlib chooses wrong neural network model for Atari in 0.7.5 (#6087)

This commit is contained in:
Eric Liang
2019-11-05 11:36:29 -08:00
committed by GitHub
parent 8f6d73a93a
commit 2a0225dd25
3 changed files with 12 additions and 3 deletions
+2 -2
View File
@@ -504,7 +504,7 @@ class ModelCatalog(object):
state_in=state_in,
seq_lens=seq_lens)
obs_rank = len(input_dict["obs"].shape) - 1
obs_rank = len(input_dict["obs"].shape) - 1 # drops batch dim
if obs_rank > 2:
return VisionNetwork(input_dict, obs_space, action_space,
@@ -516,7 +516,7 @@ class ModelCatalog(object):
@staticmethod
def _get_v2_model(obs_space, options):
options = options or MODEL_DEFAULTS
obs_rank = len(obs_space.shape) - 1
obs_rank = len(obs_space.shape)
if options.get("use_lstm"):
return None # TODO: default LSTM v2 not implemented
+2 -1
View File
@@ -240,7 +240,8 @@ class JsonIOTest(unittest.TestCase):
self.assertEqual(len(os.listdir(self.test_dir)), 0)
for _ in range(100):
writer.write(SAMPLES)
self.assertEqual(len(os.listdir(self.test_dir)), 12)
num_files = len(os.listdir(self.test_dir))
assert num_files in [12, 13], num_files
def testReadWrite(self):
ioctx = IOContext(self.test_dir, {}, 0, None)
+8
View File
@@ -9,6 +9,8 @@ import sys
import ray
from ray.rllib.agents.registry import get_agent_class
from ray.rllib.models.tf.fcnet_v2 import FullyConnectedNetwork as FCNetV2
from ray.rllib.models.tf.visionnet_v2 import VisionNetwork as VisionNetV2
from ray.rllib.tests.test_multi_agent_env import (MultiCartpole,
MultiMountainCar)
from ray.rllib.utils.error import UnsupportedSpaceException
@@ -83,6 +85,12 @@ def check_support(alg, config, stats, check_bounds=False, name=None):
stat = "skip" # speed up tests by avoiding full grid
else:
a = get_agent_class(alg)(config=config, env="stub_env")
if alg not in ["DDPG", "ES", "ARS"]:
if o_name in ["atari", "image"]:
assert isinstance(a.get_policy().model,
VisionNetV2)
elif o_name in ["vector", "vector2"]:
assert isinstance(a.get_policy().model, FCNetV2)
a.train()
covered_a.add(a_name)
covered_o.add(o_name)