[rllib] Allow envs to be auto-registered; add on_train_result callback with curriculum example (#3451)

* train step and docs

* debug

* doc

* doc

* fix examples

* fix code

* integration test

* fix

* ...

* space

* instance

* Update .travis.yml

* fix test
This commit is contained in:
Eric Liang
2018-12-03 23:15:43 -08:00
committed by GitHub
parent be6567e6fd
commit ce355d13d4
13 changed files with 207 additions and 258 deletions
+25 -8
View File
@@ -2,12 +2,13 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import copy
import os
import logging
import pickle
import tempfile
from datetime import datetime
import copy
import logging
import os
import pickle
import six
import tempfile
import tensorflow as tf
import ray
@@ -15,7 +16,7 @@ from ray.rllib.models import MODEL_DEFAULTS
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
from ray.rllib.utils import FilterManager, deep_update, merge_dicts
from ray.tune.registry import ENV_CREATOR, _global_registry
from ray.tune.registry import ENV_CREATOR, register_env, _global_registry
from ray.tune.trainable import Trainable
from ray.tune.trial import Resources
from ray.tune.logger import UnifiedLogger
@@ -40,6 +41,7 @@ COMMON_CONFIG = {
"on_episode_step": None, # arg: {"env": .., "episode": ...}
"on_episode_end": None, # arg: {"env": .., "episode": ...}
"on_sample_end": None, # arg: {"samples": .., "evaluator": ...}
"on_train_result": None, # arg: {"agent": ..., "result": ...}
},
# === Policy ===
@@ -277,7 +279,7 @@ class Agent(Trainable):
self.global_vars = {"timestep": 0}
# Agents allow env ids to be passed directly to the constructor.
self._env_id = env or config.get("env")
self._env_id = _register_if_needed(env or config.get("env"))
# Create a default logger creator if no logger_creator is specified
if logger_creator is None:
@@ -319,7 +321,13 @@ class Agent(Trainable):
logger.debug("synchronized filters: {}".format(
self.local_evaluator.filters))
return Trainable.train(self)
result = Trainable.train(self)
if self.config["callbacks"].get("on_train_result"):
self.config["callbacks"]["on_train_result"]({
"agent": self,
"result": result,
})
return result
def _setup(self, config):
env = self._env_id
@@ -447,6 +455,15 @@ class Agent(Trainable):
self.__setstate__(extra_data)
def _register_if_needed(env_object):
if isinstance(env_object, six.string_types):
return env_object
elif isinstance(env_object, type):
name = env_object.__name__
register_env(name, lambda config: env_object(config))
return name
def get_agent_class(alg):
"""Returns the class of a known agent given its name."""
@@ -1,51 +0,0 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import ray
from ray.tune import register_env, run_experiments
from env import CarlaEnv, ENV_CONFIG
from models import register_carla_model
from scenarios import LANE_KEEP
env_name = "carla_env"
env_config = ENV_CONFIG.copy()
env_config.update({
"verbose": False,
"x_res": 80,
"y_res": 80,
"use_depth_camera": False,
"discrete_actions": False,
"server_map": "/Game/Maps/Town02",
"reward_function": "lane_keep",
"enable_planner": False,
"scenarios": [LANE_KEEP],
})
register_env(env_name, lambda env_config: CarlaEnv(env_config))
register_carla_model()
ray.init()
run_experiments({
"carla-a3c": {
"run": "A3C",
"env": "carla_env",
"config": {
"env_config": env_config,
"model": {
"custom_model": "carla",
"custom_options": {
"image_shape": [80, 80, 6],
},
"conv_filters": [
[16, [8, 8], 4],
[32, [4, 4], 2],
[512, [10, 10], 1],
],
},
"gamma": 0.8,
"num_workers": 1,
},
},
})
@@ -1,53 +0,0 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import ray
from ray.tune import register_env, run_experiments
from env import CarlaEnv, ENV_CONFIG
from models import register_carla_model
from scenarios import LANE_KEEP
env_name = "carla_env"
env_config = ENV_CONFIG.copy()
env_config.update({
"verbose": False,
"x_res": 80,
"y_res": 80,
"use_depth_camera": False,
"discrete_actions": True,
"server_map": "/Game/Maps/Town02",
"reward_function": "lane_keep",
"enable_planner": False,
"scenarios": [LANE_KEEP],
})
register_env(env_name, lambda env_config: CarlaEnv(env_config))
register_carla_model()
ray.init()
run_experiments({
"carla-dqn": {
"run": "DQN",
"env": "carla_env",
"config": {
"env_config": env_config,
"model": {
"custom_model": "carla",
"custom_options": {
"image_shape": [80, 80, 6],
},
"conv_filters": [
[16, [8, 8], 4],
[32, [4, 4], 2],
[512, [10, 10], 1],
],
},
"timesteps_per_iteration": 100,
"learning_starts": 1000,
"schedule_max_timesteps": 100000,
"gamma": 0.8,
},
},
})
@@ -1,63 +0,0 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import ray
from ray.tune import register_env, run_experiments
from env import CarlaEnv, ENV_CONFIG
from models import register_carla_model
from scenarios import LANE_KEEP
env_name = "carla_env"
env_config = ENV_CONFIG.copy()
env_config.update({
"verbose": False,
"x_res": 80,
"y_res": 80,
"use_depth_camera": False,
"discrete_actions": False,
"server_map": "/Game/Maps/Town02",
"reward_function": "lane_keep",
"enable_planner": False,
"scenarios": [LANE_KEEP],
})
register_env(env_name, lambda env_config: CarlaEnv(env_config))
register_carla_model()
ray.init()
run_experiments({
"carla-ppo": {
"run": "PPO",
"env": "carla_env",
"config": {
"env_config": env_config,
"model": {
"custom_model": "carla",
"custom_options": {
"image_shape": [80, 80, 6],
},
"conv_filters": [
[16, [8, 8], 4],
[32, [4, 4], 2],
[512, [10, 10], 1],
],
},
"num_workers": 1,
"timesteps_per_batch": 2000,
"min_steps_per_task": 100,
"lambda": 0.95,
"clip_param": 0.2,
"num_sgd_iter": 20,
"sgd_stepsize": 0.0001,
"sgd_batchsize": 32,
"devices": ["/gpu:0"],
"tf_session_args": {
"gpu_options": {
"allow_growth": True
}
}
},
},
})
+2 -4
View File
@@ -3,13 +3,12 @@ from __future__ import division
from __future__ import print_function
import ray
from ray.tune import grid_search, register_env, run_experiments
from ray.tune import grid_search, run_experiments
from env import CarlaEnv, ENV_CONFIG
from models import register_carla_model
from scenarios import TOWN2_STRAIGHT
env_name = "carla_env"
env_config = ENV_CONFIG.copy()
env_config.update({
"verbose": False,
@@ -23,7 +22,6 @@ env_config.update({
"scenarios": TOWN2_STRAIGHT,
})
register_env(env_name, lambda env_config: CarlaEnv(env_config))
register_carla_model()
redis_address = ray.services.get_node_ip_address() + ":6379"
@@ -31,7 +29,7 @@ ray.init(redis_address=redis_address)
run_experiments({
"carla-a3c": {
"run": "A3C",
"env": "carla_env",
"env": CarlaEnv,
"config": {
"env_config": env_config,
"use_gpu_for_workers": True,
+2 -4
View File
@@ -3,13 +3,12 @@ from __future__ import division
from __future__ import print_function
import ray
from ray.tune import register_env, run_experiments
from ray.tune import run_experiments
from env import CarlaEnv, ENV_CONFIG
from models import register_carla_model
from scenarios import TOWN2_ONE_CURVE
env_name = "carla_env"
env_config = ENV_CONFIG.copy()
env_config.update({
"verbose": False,
@@ -21,7 +20,6 @@ env_config.update({
"scenarios": TOWN2_ONE_CURVE,
})
register_env(env_name, lambda env_config: CarlaEnv(env_config))
register_carla_model()
ray.init()
@@ -35,7 +33,7 @@ def shape_out(spec):
run_experiments({
"carla-dqn": {
"run": "DQN",
"env": "carla_env",
"env": CarlaEnv,
"config": {
"env_config": env_config,
"model": {
+2 -4
View File
@@ -3,13 +3,12 @@ from __future__ import division
from __future__ import print_function
import ray
from ray.tune import register_env, run_experiments
from ray.tune import run_experiments
from env import CarlaEnv, ENV_CONFIG
from models import register_carla_model
from scenarios import TOWN2_STRAIGHT
env_name = "carla_env"
env_config = ENV_CONFIG.copy()
env_config.update({
"verbose": False,
@@ -20,14 +19,13 @@ env_config.update({
"server_map": "/Game/Maps/Town02",
"scenarios": TOWN2_STRAIGHT,
})
register_env(env_name, lambda env_config: CarlaEnv(env_config))
register_carla_model()
ray.init(redirect_output=True)
run_experiments({
"carla": {
"run": "PPO",
"env": "carla_env",
"env": CarlaEnv,
"config": {
"env_config": env_config,
"model": {
+3 -4
View File
@@ -11,7 +11,6 @@ from gym.envs.registration import EnvSpec
import ray
from ray.tune import run_experiments
from ray.tune.registry import register_env
class SimpleCorridor(gym.Env):
@@ -42,13 +41,13 @@ class SimpleCorridor(gym.Env):
if __name__ == "__main__":
env_creator_name = "corridor"
register_env(env_creator_name, lambda config: SimpleCorridor(config))
# Can also register the env creator function explicitly with:
# register_env("corridor", lambda config: SimpleCorridor(config))
ray.init()
run_experiments({
"demo": {
"run": "PPO",
"env": "corridor",
"env": SimpleCorridor, # or "corridor" if registered above
"config": {
"env_config": {
"corridor_length": 5,
@@ -35,6 +35,13 @@ def on_sample_end(info):
print("returned sample batch of size {}".format(info["samples"].count))
def on_train_result(info):
print("agent.train() result: {} -> {} episodes".format(
info["agent"], info["result"]["episodes_this_iter"]))
# you can mutate the result dict to add new fields to return
info["result"]["callback_ok"] = True
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--num-iters", type=int, default=2000)
@@ -54,6 +61,7 @@ if __name__ == "__main__":
"on_episode_step": tune.function(on_episode_step),
"on_episode_end": tune.function(on_episode_end),
"on_sample_end": tune.function(on_sample_end),
"on_train_result": tune.function(on_train_result),
},
},
}
@@ -64,3 +72,4 @@ if __name__ == "__main__":
print(custom_metrics)
assert "mean_pole_angle" in custom_metrics
assert type(custom_metrics["mean_pole_angle"]) is float
assert "callback_ok" in trials[0].last_result
+4 -2
View File
@@ -314,8 +314,10 @@ class Trial(object):
def __str__(self):
"""Combines ``env`` with ``trainable_name`` and ``experiment_tag``."""
if "env" in self.config:
identifier = "{}_{}".format(self.trainable_name,
self.config["env"])
env = self.config["env"]
if isinstance(env, type):
env = env.__name__
identifier = "{}_{}".format(self.trainable_name, env)
else:
identifier = self.trainable_name
if self.experiment_tag: