mirror of
https://github.com/wassname/ray.git
synced 2026-06-29 05:01:29 +08:00
[carla] [rllib] Add support for carla nav planner and scenarios from paper (#1382)
* wip * Sat Dec 30 15:07:28 PST 2017 * log video * video doesn't work well * scenario integration * Sat Dec 30 17:30:22 PST 2017 * Sat Dec 30 17:31:05 PST 2017 * Sat Dec 30 17:31:32 PST 2017 * Sat Dec 30 17:32:16 PST 2017 * Sat Dec 30 17:34:11 PST 2017 * Sat Dec 30 17:34:50 PST 2017 * Sat Dec 30 17:35:34 PST 2017 * Sat Dec 30 17:38:49 PST 2017 * Sat Dec 30 17:40:39 PST 2017 * Sat Dec 30 17:43:00 PST 2017 * Sat Dec 30 17:43:04 PST 2017 * Sat Dec 30 17:45:56 PST 2017 * Sat Dec 30 17:46:26 PST 2017 * Sat Dec 30 17:47:02 PST 2017 * Sat Dec 30 17:51:53 PST 2017 * Sat Dec 30 17:52:54 PST 2017 * Sat Dec 30 17:56:43 PST 2017 * Sat Dec 30 18:27:07 PST 2017 * Sat Dec 30 18:27:52 PST 2017 * fix train * Sat Dec 30 18:41:51 PST 2017 * Sat Dec 30 18:54:11 PST 2017 * Sat Dec 30 18:56:22 PST 2017 * Sat Dec 30 19:05:04 PST 2017 * Sat Dec 30 19:05:23 PST 2017 * Sat Dec 30 19:11:53 PST 2017 * Sat Dec 30 19:14:31 PST 2017 * Sat Dec 30 19:16:20 PST 2017 * Sat Dec 30 19:18:05 PST 2017 * Sat Dec 30 19:18:45 PST 2017 * Sat Dec 30 19:22:44 PST 2017 * Sat Dec 30 19:24:41 PST 2017 * Sat Dec 30 19:26:57 PST 2017 * Sat Dec 30 19:40:37 PST 2017 * wip models * reward bonus * test prep * Sun Dec 31 18:45:25 PST 2017 * Sun Dec 31 18:58:28 PST 2017 * Sun Dec 31 18:59:34 PST 2017 * Sun Dec 31 19:03:33 PST 2017 * Sun Dec 31 19:05:05 PST 2017 * Sun Dec 31 19:09:25 PST 2017 * fix train * kill * add tuple preprocessor * Sun Dec 31 20:38:33 PST 2017 * Sun Dec 31 22:51:24 PST 2017 * Sun Dec 31 23:14:13 PST 2017 * Sun Dec 31 23:16:04 PST 2017 * Mon Jan 1 00:08:35 PST 2018 * Mon Jan 1 00:10:48 PST 2018 * Mon Jan 1 01:08:31 PST 2018 * Mon Jan 1 14:45:44 PST 2018 * Mon Jan 1 14:54:56 PST 2018 * Mon Jan 1 17:29:29 PST 2018 * switch to euclidean dists * Mon Jan 1 17:39:27 PST 2018 * Mon Jan 1 17:41:47 PST 2018 * Mon Jan 1 17:44:18 PST 2018 * Mon Jan 1 17:47:09 PST 2018 * Mon Jan 1 20:31:02 PST 2018 * Mon Jan 1 20:39:33 PST 2018 * Mon Jan 1 20:40:55 PST 2018 * Mon Jan 1 20:55:06 PST 2018 * Mon Jan 1 21:05:52 PST 
2018 * fix env path * merge richards fix * fix hash * Mon Jan 1 22:04:00 PST 2018 * Mon Jan 1 22:25:29 PST 2018 * Mon Jan 1 22:30:42 PST 2018 * simplified reward function * add framestack * add env configs * simplify speed reward * Tue Jan 2 17:36:15 PST 2018 * Tue Jan 2 17:49:16 PST 2018 * Tue Jan 2 18:10:38 PST 2018 * add lane keeping simple mode * Tue Jan 2 20:25:26 PST 2018 * Tue Jan 2 20:30:30 PST 2018 * Tue Jan 2 20:33:26 PST 2018 * Tue Jan 2 20:41:42 PST 2018 * ppo lane keep * simplify discrete actions * Tue Jan 2 21:41:05 PST 2018 * Tue Jan 2 21:49:03 PST 2018 * Tue Jan 2 22:12:23 PST 2018 * Tue Jan 2 22:14:42 PST 2018 * Tue Jan 2 22:20:59 PST 2018 * Tue Jan 2 22:23:43 PST 2018 * Tue Jan 2 22:26:27 PST 2018 * Tue Jan 2 22:27:20 PST 2018 * Tue Jan 2 22:44:00 PST 2018 * Tue Jan 2 22:57:58 PST 2018 * Tue Jan 2 23:08:51 PST 2018 * Tue Jan 2 23:11:32 PST 2018 * update dqn reward * Thu Jan 4 12:29:40 PST 2018 * Thu Jan 4 12:30:26 PST 2018 * Update train_dqn.py * fix
This commit is contained in:
committed by
Philipp Moritz
parent
088f01496c
commit
c60ccbad46
@@ -10,7 +10,8 @@ import ray
|
||||
from ray.rllib.agent import Agent
|
||||
from ray.rllib.optimizers import AsyncOptimizer
|
||||
from ray.rllib.utils import FilterManager
|
||||
from ray.rllib.a3c.a3c_evaluator import A3CEvaluator, RemoteA3CEvaluator
|
||||
from ray.rllib.a3c.a3c_evaluator import A3CEvaluator, RemoteA3CEvaluator, \
|
||||
GPURemoteA3CEvaluator
|
||||
from ray.tune.result import TrainingResult
|
||||
|
||||
|
||||
@@ -39,6 +40,8 @@ DEFAULT_CONFIG = {
|
||||
"vf_loss_coeff": 0.5,
|
||||
# Entropy coefficient
|
||||
"entropy_coeff": -0.01,
|
||||
# Whether to place workers on GPUs
|
||||
"use_gpu_for_workers": False,
|
||||
# Model and preprocessor options
|
||||
"model": {
|
||||
# (Image statespace) - Converts image to Channels = 1
|
||||
@@ -54,21 +57,27 @@ DEFAULT_CONFIG = {
|
||||
"optimizer": {
|
||||
# Number of gradients applied for each `train` step
|
||||
"grads_per_step": 100,
|
||||
}
|
||||
},
|
||||
# Arguments to pass to the env creator
|
||||
"env_config": {},
|
||||
}
|
||||
|
||||
|
||||
class A3CAgent(Agent):
|
||||
_agent_name = "A3C"
|
||||
_default_config = DEFAULT_CONFIG
|
||||
_allow_unknown_subkeys = ["model", "optimizer"]
|
||||
_allow_unknown_subkeys = ["model", "optimizer", "env_config"]
|
||||
|
||||
def _init(self):
|
||||
self.local_evaluator = A3CEvaluator(
|
||||
self.registry, self.env_creator, self.config, self.logdir,
|
||||
start_sampler=False)
|
||||
if self.config["use_gpu_for_workers"]:
|
||||
remote_cls = GPURemoteA3CEvaluator
|
||||
else:
|
||||
remote_cls = RemoteA3CEvaluator
|
||||
self.remote_evaluators = [
|
||||
RemoteA3CEvaluator.remote(
|
||||
remote_cls.remote(
|
||||
self.registry, self.env_creator, self.config, self.logdir)
|
||||
for i in range(self.config["num_workers"])]
|
||||
self.optimizer = AsyncOptimizer(
|
||||
|
||||
@@ -29,7 +29,7 @@ class A3CEvaluator(Evaluator):
|
||||
def __init__(
|
||||
self, registry, env_creator, config, logdir, start_sampler=True):
|
||||
env = ModelCatalog.get_preprocessor_as_wrapper(
|
||||
registry, env_creator(), config["model"])
|
||||
registry, env_creator(config["env_config"]), config["model"])
|
||||
self.env = env
|
||||
policy_cls = get_policy_cls(config)
|
||||
# TODO(rliaw): should change this to be just env.observation_space
|
||||
@@ -116,3 +116,4 @@ class A3CEvaluator(Evaluator):
|
||||
|
||||
|
||||
RemoteA3CEvaluator = ray.remote(A3CEvaluator)
|
||||
GPURemoteA3CEvaluator = ray.remote(num_gpus=1)(A3CEvaluator)
|
||||
|
||||
@@ -78,7 +78,8 @@ class TFPolicy(Policy):
|
||||
|
||||
# TODO(rliaw): Can consider exposing these parameters
|
||||
self.sess = tf.Session(graph=self.g, config=tf.ConfigProto(
|
||||
intra_op_parallelism_threads=1, inter_op_parallelism_threads=2))
|
||||
intra_op_parallelism_threads=1, inter_op_parallelism_threads=2,
|
||||
gpu_options=tf.GPUOptions(allow_growth=True)))
|
||||
self.variables = ray.experimental.TensorFlowVariables(self.loss,
|
||||
self.sess)
|
||||
self.sess.run(tf.global_variables_initializer())
|
||||
|
||||
@@ -46,8 +46,6 @@ def _deep_update(original, new_dict, new_keys_allowed, whitelist):
|
||||
if not new_keys_allowed:
|
||||
raise Exception(
|
||||
"Unknown config parameter `{}` ".format(k))
|
||||
else:
|
||||
logger.warn("`{}` not in default configuration...".format(k))
|
||||
if type(original.get(k)) is dict:
|
||||
if k in whitelist:
|
||||
_deep_update(original[k], value, True, [])
|
||||
@@ -98,9 +96,9 @@ class Agent(Trainable):
|
||||
self.env_creator = registry.get(ENV_CREATOR, env)
|
||||
else:
|
||||
import gym # soft dependency
|
||||
self.env_creator = lambda: gym.make(env)
|
||||
self.env_creator = lambda env_config: gym.make(env)
|
||||
else:
|
||||
self.env_creator = lambda: None
|
||||
self.env_creator = lambda env_config: None
|
||||
self.config = self._default_config.copy()
|
||||
self.registry = registry
|
||||
|
||||
|
||||
@@ -28,6 +28,8 @@ DEFAULT_CONFIG = dict(
|
||||
model={},
|
||||
# Discount factor for the MDP
|
||||
gamma=0.99,
|
||||
# Arguments to pass to the env creator
|
||||
env_config={},
|
||||
|
||||
# === Exploration ===
|
||||
# Max num timesteps for annealing schedules. Exploration is annealed from
|
||||
@@ -107,7 +109,8 @@ DEFAULT_CONFIG = dict(
|
||||
|
||||
class DQNAgent(Agent):
|
||||
_agent_name = "DQN"
|
||||
_allow_unknown_subkeys = ["model", "optimizer", "tf_session_args"]
|
||||
_allow_unknown_subkeys = [
|
||||
"model", "optimizer", "tf_session_args", "env_config"]
|
||||
_default_config = DEFAULT_CONFIG
|
||||
|
||||
def _init(self):
|
||||
|
||||
@@ -18,7 +18,7 @@ class DQNEvaluator(TFMultiGPUSupport):
|
||||
TODO(rliaw): Support observation/reward filters?"""
|
||||
|
||||
def __init__(self, registry, env_creator, config, logdir):
|
||||
env = env_creator()
|
||||
env = env_creator(config["env_config"])
|
||||
env = wrap_dqn(registry, env, config["model"])
|
||||
self.env = env
|
||||
self.config = config
|
||||
|
||||
@@ -37,7 +37,8 @@ DEFAULT_CONFIG = dict(
|
||||
return_proc_mode="centered_rank",
|
||||
num_workers=10,
|
||||
stepsize=0.01,
|
||||
observation_filter="MeanStdFilter")
|
||||
observation_filter="MeanStdFilter",
|
||||
env_config={})
|
||||
|
||||
|
||||
@ray.remote
|
||||
@@ -70,7 +71,7 @@ class Worker(object):
|
||||
self.policy_params = policy_params
|
||||
self.noise = SharedNoiseTable(noise)
|
||||
|
||||
self.env = env_creator()
|
||||
self.env = env_creator(config["env_config"])
|
||||
self.preprocessor = ModelCatalog.get_preprocessor(registry, self.env)
|
||||
|
||||
self.sess = utils.make_session(single_threaded=True)
|
||||
@@ -135,13 +136,14 @@ class Worker(object):
|
||||
class ESAgent(Agent):
|
||||
_agent_name = "ES"
|
||||
_default_config = DEFAULT_CONFIG
|
||||
_allow_unknown_subkeys = ["env_config"]
|
||||
|
||||
def _init(self):
|
||||
policy_params = {
|
||||
"action_noise_std": 0.01
|
||||
}
|
||||
|
||||
env = self.env_creator()
|
||||
env = self.env_creator(self.config["env_config"])
|
||||
preprocessor = ModelCatalog.get_preprocessor(self.registry, env)
|
||||
|
||||
self.sess = utils.make_session(single_threaded=False)
|
||||
|
||||
@@ -9,9 +9,7 @@ from ray.tune.registry import RLLIB_MODEL, RLLIB_PREPROCESSOR, \
|
||||
|
||||
from ray.rllib.models.action_dist import (
|
||||
Categorical, Deterministic, DiagGaussian)
|
||||
from ray.rllib.models.preprocessors import (
|
||||
NoPreprocessor, AtariRamPreprocessor, AtariPixelPreprocessor,
|
||||
OneHotPreprocessor)
|
||||
from ray.rllib.models.preprocessors import get_preprocessor
|
||||
from ray.rllib.models.fcnet import FullyConnectedNetwork
|
||||
from ray.rllib.models.visionnet import VisionNetwork
|
||||
|
||||
@@ -48,9 +46,6 @@ class ModelCatalog(object):
|
||||
>>> action = dist.sample()
|
||||
"""
|
||||
|
||||
ATARI_OBS_SHAPE = (210, 160, 3)
|
||||
ATARI_RAM_OBS_SHAPE = (128,)
|
||||
|
||||
@staticmethod
|
||||
def get_action_dist(action_space, dist_type=None):
|
||||
"""Returns action distribution class and size for the given action space.
|
||||
@@ -147,40 +142,19 @@ class ModelCatalog(object):
|
||||
preprocessor (Preprocessor): Preprocessor for the env observations.
|
||||
"""
|
||||
|
||||
# For older gym versions that don't set shape for Discrete
|
||||
if not hasattr(env.observation_space, "shape") and \
|
||||
isinstance(env.observation_space, gym.spaces.Discrete):
|
||||
env.observation_space.shape = ()
|
||||
|
||||
obs_shape = env.observation_space.shape
|
||||
|
||||
for k in options.keys():
|
||||
if k not in MODEL_CONFIGS:
|
||||
raise Exception(
|
||||
"Unknown config key `{}`, all keys: {}".format(
|
||||
k, MODEL_CONFIGS))
|
||||
|
||||
print("Observation shape is {}".format(obs_shape))
|
||||
|
||||
if "custom_preprocessor" in options:
|
||||
preprocessor = options["custom_preprocessor"]
|
||||
print("Using custom preprocessor {}".format(preprocessor))
|
||||
return registry.get(RLLIB_PREPROCESSOR, preprocessor)(
|
||||
env.observation_space, options)
|
||||
|
||||
if obs_shape == ():
|
||||
print("Using one-hot preprocessor for discrete envs.")
|
||||
preprocessor = OneHotPreprocessor
|
||||
elif obs_shape == ModelCatalog.ATARI_OBS_SHAPE:
|
||||
print("Assuming Atari pixel env, using AtariPixelPreprocessor.")
|
||||
preprocessor = AtariPixelPreprocessor
|
||||
elif obs_shape == ModelCatalog.ATARI_RAM_OBS_SHAPE:
|
||||
print("Assuming Atari ram env, using AtariRamPreprocessor.")
|
||||
preprocessor = AtariRamPreprocessor
|
||||
else:
|
||||
print("Not using any observation preprocessor.")
|
||||
preprocessor = NoPreprocessor
|
||||
|
||||
preprocessor = get_preprocessor(env.observation_space)
|
||||
return preprocessor(env.observation_space, options)
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -3,6 +3,10 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
import cv2
|
||||
import numpy as np
|
||||
import gym
|
||||
|
||||
ATARI_OBS_SHAPE = (210, 160, 3)
|
||||
ATARI_RAM_OBS_SHAPE = (128,)
|
||||
|
||||
|
||||
class Preprocessor(object):
|
||||
@@ -13,6 +17,7 @@ class Preprocessor(object):
|
||||
"""
|
||||
|
||||
def __init__(self, obs_space, options):
|
||||
legacy_patch_shapes(obs_space)
|
||||
self._obs_space = obs_space
|
||||
self._options = options
|
||||
self._init()
|
||||
@@ -40,7 +45,6 @@ class AtariPixelPreprocessor(Preprocessor):
|
||||
if self._channel_major:
|
||||
self.shape = self.shape[-1:] + self.shape[:-1]
|
||||
|
||||
# TODO(ekl) why does this need to return an extra size-1 dim (the [None])
|
||||
def transform(self, observation):
|
||||
"""Downsamples images from (210, 160, 3) by the configured factor."""
|
||||
scaled = observation[25:-25, :, :]
|
||||
@@ -64,7 +68,6 @@ class AtariPixelPreprocessor(Preprocessor):
|
||||
return scaled
|
||||
|
||||
|
||||
# TODO(rliaw): Also should include the deepmind preprocessor
|
||||
class AtariRamPreprocessor(Preprocessor):
|
||||
def _init(self):
|
||||
self.shape = (128,)
|
||||
@@ -90,3 +93,75 @@ class NoPreprocessor(Preprocessor):
|
||||
|
||||
def transform(self, observation):
|
||||
return observation
|
||||
|
||||
|
||||
class TupleFlatteningPreprocessor(Preprocessor):
    """Preprocesses each tuple element, then flattens it all into a vector.

    If desired, the vector output can be unpacked via tf.reshape() within a
    custom model to handle each component separately.

    Attributes:
        preprocessors (list): One sub-preprocessor per element of the Tuple
            observation space, in the same order as `obs_space.spaces`.
        shape (tuple): `(size,)`, where `size` is the sum of the flattened
            sizes of every sub-preprocessor's output shape.
    """

    def _init(self):
        """Builds one sub-preprocessor per tuple element and sizes the output.

        Raises:
            AssertionError: If the observation space is not a gym Tuple.
        """
        assert isinstance(self._obs_space, gym.spaces.Tuple)
        size = 0
        self.preprocessors = []
        for space in self._obs_space.spaces:
            print("Creating sub-preprocessor for", space)
            preprocessor = get_preprocessor(space)(space, self._options)
            self.preprocessors.append(preprocessor)
            # np.prod replaces the deprecated np.product alias. int() keeps
            # the size a plain integer even for an empty shape, where
            # np.prod returns the float 1.0.
            size += int(np.prod(preprocessor.shape))
        self.shape = (size,)

    def transform(self, observation):
        """Transforms each tuple element and concatenates the flat results.

        Args:
            observation: Sequence with one entry per sub-preprocessor, in the
                same order as the spaces of the Tuple observation space.

        Returns:
            A 1-D numpy array holding every transformed element, each
            flattened to its sub-preprocessor's total size, back to back.

        Raises:
            AssertionError: If the number of elements in `observation` does
                not match the number of sub-preprocessors.
        """
        assert len(observation) == len(self.preprocessors), observation
        return np.concatenate([
            np.reshape(p.transform(o), [int(np.prod(p.shape))])
            for (o, p) in zip(observation, self.preprocessors)])
|
||||
|
||||
|
||||
def get_preprocessor(space):
    """Picks the preprocessor class that fits the given observation space.

    The space's shape is patched first (for gym versions that do not set
    one), then matched in this order: empty shape, Atari pixel shape, Atari
    RAM shape, Tuple space, and finally the pass-through fallback.

    Args:
        space: The gym observation space to select a preprocessor for.

    Returns:
        The preprocessor class (not an instance) chosen for `space`.
    """
    legacy_patch_shapes(space)
    shape = space.shape
    print("Observation shape is {}".format(shape))

    if shape == ():
        print("Using one-hot preprocessor for discrete envs.")
        return OneHotPreprocessor

    if shape == ATARI_OBS_SHAPE:
        print("Assuming Atari pixel env, using AtariPixelPreprocessor.")
        return AtariPixelPreprocessor

    if shape == ATARI_RAM_OBS_SHAPE:
        print("Assuming Atari ram env, using AtariRamPreprocessor.")
        return AtariRamPreprocessor

    if isinstance(space, gym.spaces.Tuple):
        print("Using a TupleFlatteningPreprocessor")
        return TupleFlatteningPreprocessor

    print("Not using any observation preprocessor.")
    return NoPreprocessor
|
||||
|
||||
|
||||
def legacy_patch_shapes(space):
    """Ensures `space` has a `shape` attribute and returns it.

    Only older gym versions need this, since they leave `shape` unset on
    Tuple and Discrete spaces. A Discrete space is given the empty shape
    `()`; a Tuple space is given a tuple of its sub-spaces' shapes, each
    patched recursively. A space that already has a shape is left untouched.

    Args:
        space: The gym space to patch in place.

    Returns:
        The (possibly newly assigned) `space.shape`.
    """
    # Nothing to patch when the space already carries a shape.
    if hasattr(space, "shape"):
        return space.shape

    if isinstance(space, gym.spaces.Discrete):
        space.shape = ()
    elif isinstance(space, gym.spaces.Tuple):
        space.shape = tuple(legacy_patch_shapes(sub) for sub in space.spaces)

    return space.shape
|
||||
|
||||
@@ -78,12 +78,14 @@ DEFAULT_CONFIG = {
|
||||
"tf_debug_inf_or_nan": False,
|
||||
# If True, we write tensorflow logs and checkpoints
|
||||
"write_logs": True,
|
||||
# Arguments to pass to the env creator
|
||||
"env_config": {},
|
||||
}
|
||||
|
||||
|
||||
class PPOAgent(Agent):
|
||||
_agent_name = "PPO"
|
||||
_allow_unknown_subkeys = ["model", "tf_session_args"]
|
||||
_allow_unknown_subkeys = ["model", "tf_session_args", "env_config"]
|
||||
_default_config = DEFAULT_CONFIG
|
||||
|
||||
def _init(self):
|
||||
|
||||
@@ -43,7 +43,7 @@ class PPOEvaluator(Evaluator):
|
||||
self.config = config
|
||||
self.logdir = logdir
|
||||
self.env = ModelCatalog.get_preprocessor_as_wrapper(
|
||||
registry, env_creator(), config["model"])
|
||||
registry, env_creator(config["env_config"]), config["model"])
|
||||
if is_remote:
|
||||
config_proto = tf.ConfigProto()
|
||||
else:
|
||||
|
||||
@@ -2,6 +2,7 @@ import gym
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
import unittest
|
||||
from gym.spaces import Box, Discrete, Tuple
|
||||
|
||||
import ray
|
||||
from ray.tune.registry import get_registry
|
||||
@@ -34,11 +35,25 @@ class ModelCatalogTest(unittest.TestCase):
|
||||
def testGymPreprocessors(self):
|
||||
p1 = ModelCatalog.get_preprocessor(
|
||||
get_registry(), gym.make("CartPole-v0"))
|
||||
assert type(p1) == NoPreprocessor
|
||||
self.assertEqual(type(p1), NoPreprocessor)
|
||||
|
||||
p2 = ModelCatalog.get_preprocessor(
|
||||
get_registry(), gym.make("FrozenLake-v0"))
|
||||
assert type(p2) == OneHotPreprocessor
|
||||
self.assertEqual(type(p2), OneHotPreprocessor)
|
||||
|
||||
def testTuplePreprocessor(self):
|
||||
ray.init()
|
||||
|
||||
class TupleEnv(object):
|
||||
def __init__(self):
|
||||
self.observation_space = Tuple(
|
||||
[Discrete(5), Box(0, 1, shape=(3,))])
|
||||
p1 = ModelCatalog.get_preprocessor(
|
||||
get_registry(), TupleEnv())
|
||||
self.assertEqual(p1.shape, (8,))
|
||||
self.assertEqual(
|
||||
list(p1.transform((0, [1, 2, 3]))),
|
||||
[float(x) for x in [1, 0, 0, 0, 0, 1, 2, 3]])
|
||||
|
||||
def testCustomPreprocessor(self):
|
||||
ray.init()
|
||||
@@ -47,12 +62,12 @@ class ModelCatalogTest(unittest.TestCase):
|
||||
env = gym.make("CartPole-v0")
|
||||
p1 = ModelCatalog.get_preprocessor(
|
||||
get_registry(), env, {"custom_preprocessor": "foo"})
|
||||
assert type(p1) == CustomPreprocessor
|
||||
self.assertEqual(str(type(p1)), str(CustomPreprocessor))
|
||||
p2 = ModelCatalog.get_preprocessor(
|
||||
get_registry(), env, {"custom_preprocessor": "bar"})
|
||||
assert type(p2) == CustomPreprocessor2
|
||||
self.assertEqual(str(type(p2)), str(CustomPreprocessor2))
|
||||
p3 = ModelCatalog.get_preprocessor(get_registry(), env)
|
||||
assert type(p3) == NoPreprocessor
|
||||
self.assertEqual(type(p3), NoPreprocessor)
|
||||
|
||||
def testDefaultModels(self):
|
||||
ray.init()
|
||||
@@ -60,19 +75,19 @@ class ModelCatalogTest(unittest.TestCase):
|
||||
with tf.variable_scope("test1"):
|
||||
p1 = ModelCatalog.get_model(
|
||||
get_registry(), np.zeros((10, 3), dtype=np.float32), 5)
|
||||
assert type(p1) == FullyConnectedNetwork
|
||||
self.assertEqual(type(p1), FullyConnectedNetwork)
|
||||
|
||||
with tf.variable_scope("test2"):
|
||||
p2 = ModelCatalog.get_model(
|
||||
get_registry(), np.zeros((10, 80, 80, 3), dtype=np.float32), 5)
|
||||
assert type(p2) == VisionNetwork
|
||||
self.assertEqual(type(p2), VisionNetwork)
|
||||
|
||||
def testCustomModel(self):
|
||||
ray.init()
|
||||
ModelCatalog.register_custom_model("foo", CustomModel)
|
||||
p1 = ModelCatalog.get_model(
|
||||
get_registry(), 1, 5, {"custom_model": "foo"})
|
||||
assert type(p1) == CustomModel
|
||||
self.assertEqual(str(type(p1)), str(CustomModel))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -193,8 +193,12 @@ def _env_runner(env, policy, num_local_steps, horizon, obs_filter):
|
||||
terminal condition, and other fields as dictated by `policy`.
|
||||
"""
|
||||
last_observation = obs_filter(env.reset())
|
||||
horizon = horizon if horizon else env.spec.tags.get(
|
||||
"wrapper_config.TimeLimit.max_episode_steps")
|
||||
try:
|
||||
horizon = horizon if horizon else env.spec.tags.get(
|
||||
"wrapper_config.TimeLimit.max_episode_steps")
|
||||
except Exception:
|
||||
print("Warning, no horizon specified, assuming infinite")
|
||||
horizon = 999999
|
||||
assert horizon > 0
|
||||
if hasattr(policy, "get_initial_features"):
|
||||
last_features = policy.get_initial_features()
|
||||
|
||||
@@ -306,6 +306,8 @@ class Trial(object):
|
||||
|
||||
def logger_creator(config):
|
||||
# Set the working dir in the remote process, for user file writes
|
||||
if not os.path.exists(remote_logdir):
|
||||
os.makedirs(remote_logdir)
|
||||
os.chdir(remote_logdir)
|
||||
return NoopLogger(config, remote_logdir)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user