Mirror of https://github.com/wassname/ray.git (synced 2026-07-03 14:19:24 +08:00)
[rllib] RLlib chooses wrong neural network model for Atari in 0.7.5 (#6087)
This commit is contained in:
@@ -504,7 +504,7 @@ class ModelCatalog(object):
|
||||
state_in=state_in,
|
||||
seq_lens=seq_lens)
|
||||
|
||||
- obs_rank = len(input_dict["obs"].shape) - 1
|
||||
+ obs_rank = len(input_dict["obs"].shape) - 1 # drops batch dim
|
||||
|
||||
if obs_rank > 2:
|
||||
return VisionNetwork(input_dict, obs_space, action_space,
|
||||
@@ -516,7 +516,7 @@ class ModelCatalog(object):
|
||||
@staticmethod
|
||||
def _get_v2_model(obs_space, options):
|
||||
options = options or MODEL_DEFAULTS
|
||||
- obs_rank = len(obs_space.shape) - 1
|
||||
+ obs_rank = len(obs_space.shape)
|
||||
|
||||
if options.get("use_lstm"):
|
||||
return None # TODO: default LSTM v2 not implemented
|
||||
|
||||
@@ -240,7 +240,8 @@ class JsonIOTest(unittest.TestCase):
|
||||
self.assertEqual(len(os.listdir(self.test_dir)), 0)
|
||||
for _ in range(100):
|
||||
writer.write(SAMPLES)
|
||||
- self.assertEqual(len(os.listdir(self.test_dir)), 12)
|
||||
+ num_files = len(os.listdir(self.test_dir))
|
||||
+ assert num_files in [12, 13], num_files
|
||||
|
||||
def testReadWrite(self):
|
||||
ioctx = IOContext(self.test_dir, {}, 0, None)
|
||||
|
||||
@@ -9,6 +9,8 @@ import sys
|
||||
|
||||
import ray
|
||||
from ray.rllib.agents.registry import get_agent_class
|
||||
+ from ray.rllib.models.tf.fcnet_v2 import FullyConnectedNetwork as FCNetV2
|
||||
+ from ray.rllib.models.tf.visionnet_v2 import VisionNetwork as VisionNetV2
|
||||
from ray.rllib.tests.test_multi_agent_env import (MultiCartpole,
|
||||
MultiMountainCar)
|
||||
from ray.rllib.utils.error import UnsupportedSpaceException
|
||||
@@ -83,6 +85,12 @@ def check_support(alg, config, stats, check_bounds=False, name=None):
|
||||
stat = "skip" # speed up tests by avoiding full grid
|
||||
else:
|
||||
a = get_agent_class(alg)(config=config, env="stub_env")
|
||||
+ if alg not in ["DDPG", "ES", "ARS"]:
|
||||
+ if o_name in ["atari", "image"]:
|
||||
+ assert isinstance(a.get_policy().model,
|
||||
+ VisionNetV2)
|
||||
+ elif o_name in ["vector", "vector2"]:
|
||||
+ assert isinstance(a.get_policy().model, FCNetV2)
|
||||
a.train()
|
||||
covered_a.add(a_name)
|
||||
covered_o.add(o_name)
|
||||
|
||||
Reference in New Issue
Block a user