mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 19:01:10 +08:00
[rllib] Allow envs to be auto-registered; add on_train_result callback with curriculum example (#3451)
* train step and docs * debug * doc * doc * fix examples * fix code * integration test * fix * ... * space * instance * Update .travis.yml * fix test
This commit is contained in:
@@ -2,12 +2,13 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import copy
|
||||
import os
|
||||
import logging
|
||||
import pickle
|
||||
import tempfile
|
||||
from datetime import datetime
|
||||
import copy
|
||||
import logging
|
||||
import os
|
||||
import pickle
|
||||
import six
|
||||
import tempfile
|
||||
import tensorflow as tf
|
||||
|
||||
import ray
|
||||
@@ -15,7 +16,7 @@ from ray.rllib.models import MODEL_DEFAULTS
|
||||
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
|
||||
from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
|
||||
from ray.rllib.utils import FilterManager, deep_update, merge_dicts
|
||||
from ray.tune.registry import ENV_CREATOR, _global_registry
|
||||
from ray.tune.registry import ENV_CREATOR, register_env, _global_registry
|
||||
from ray.tune.trainable import Trainable
|
||||
from ray.tune.trial import Resources
|
||||
from ray.tune.logger import UnifiedLogger
|
||||
@@ -40,6 +41,7 @@ COMMON_CONFIG = {
|
||||
"on_episode_step": None, # arg: {"env": .., "episode": ...}
|
||||
"on_episode_end": None, # arg: {"env": .., "episode": ...}
|
||||
"on_sample_end": None, # arg: {"samples": .., "evaluator": ...}
|
||||
"on_train_result": None, # arg: {"agent": ..., "result": ...}
|
||||
},
|
||||
|
||||
# === Policy ===
|
||||
@@ -277,7 +279,7 @@ class Agent(Trainable):
|
||||
self.global_vars = {"timestep": 0}
|
||||
|
||||
# Agents allow env ids to be passed directly to the constructor.
|
||||
self._env_id = env or config.get("env")
|
||||
self._env_id = _register_if_needed(env or config.get("env"))
|
||||
|
||||
# Create a default logger creator if no logger_creator is specified
|
||||
if logger_creator is None:
|
||||
@@ -319,7 +321,13 @@ class Agent(Trainable):
|
||||
logger.debug("synchronized filters: {}".format(
|
||||
self.local_evaluator.filters))
|
||||
|
||||
return Trainable.train(self)
|
||||
result = Trainable.train(self)
|
||||
if self.config["callbacks"].get("on_train_result"):
|
||||
self.config["callbacks"]["on_train_result"]({
|
||||
"agent": self,
|
||||
"result": result,
|
||||
})
|
||||
return result
|
||||
|
||||
def _setup(self, config):
|
||||
env = self._env_id
|
||||
@@ -447,6 +455,15 @@ class Agent(Trainable):
|
||||
self.__setstate__(extra_data)
|
||||
|
||||
|
||||
def _register_if_needed(env_object):
|
||||
if isinstance(env_object, six.string_types):
|
||||
return env_object
|
||||
elif isinstance(env_object, type):
|
||||
name = env_object.__name__
|
||||
register_env(name, lambda config: env_object(config))
|
||||
return name
|
||||
|
||||
|
||||
def get_agent_class(alg):
|
||||
"""Returns the class of a known agent given its name."""
|
||||
|
||||
|
||||
@@ -1,51 +0,0 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import ray
|
||||
from ray.tune import register_env, run_experiments
|
||||
|
||||
from env import CarlaEnv, ENV_CONFIG
|
||||
from models import register_carla_model
|
||||
from scenarios import LANE_KEEP
|
||||
|
||||
env_name = "carla_env"
|
||||
env_config = ENV_CONFIG.copy()
|
||||
env_config.update({
|
||||
"verbose": False,
|
||||
"x_res": 80,
|
||||
"y_res": 80,
|
||||
"use_depth_camera": False,
|
||||
"discrete_actions": False,
|
||||
"server_map": "/Game/Maps/Town02",
|
||||
"reward_function": "lane_keep",
|
||||
"enable_planner": False,
|
||||
"scenarios": [LANE_KEEP],
|
||||
})
|
||||
|
||||
register_env(env_name, lambda env_config: CarlaEnv(env_config))
|
||||
register_carla_model()
|
||||
|
||||
ray.init()
|
||||
run_experiments({
|
||||
"carla-a3c": {
|
||||
"run": "A3C",
|
||||
"env": "carla_env",
|
||||
"config": {
|
||||
"env_config": env_config,
|
||||
"model": {
|
||||
"custom_model": "carla",
|
||||
"custom_options": {
|
||||
"image_shape": [80, 80, 6],
|
||||
},
|
||||
"conv_filters": [
|
||||
[16, [8, 8], 4],
|
||||
[32, [4, 4], 2],
|
||||
[512, [10, 10], 1],
|
||||
],
|
||||
},
|
||||
"gamma": 0.8,
|
||||
"num_workers": 1,
|
||||
},
|
||||
},
|
||||
})
|
||||
@@ -1,53 +0,0 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import ray
|
||||
from ray.tune import register_env, run_experiments
|
||||
|
||||
from env import CarlaEnv, ENV_CONFIG
|
||||
from models import register_carla_model
|
||||
from scenarios import LANE_KEEP
|
||||
|
||||
env_name = "carla_env"
|
||||
env_config = ENV_CONFIG.copy()
|
||||
env_config.update({
|
||||
"verbose": False,
|
||||
"x_res": 80,
|
||||
"y_res": 80,
|
||||
"use_depth_camera": False,
|
||||
"discrete_actions": True,
|
||||
"server_map": "/Game/Maps/Town02",
|
||||
"reward_function": "lane_keep",
|
||||
"enable_planner": False,
|
||||
"scenarios": [LANE_KEEP],
|
||||
})
|
||||
|
||||
register_env(env_name, lambda env_config: CarlaEnv(env_config))
|
||||
register_carla_model()
|
||||
|
||||
ray.init()
|
||||
run_experiments({
|
||||
"carla-dqn": {
|
||||
"run": "DQN",
|
||||
"env": "carla_env",
|
||||
"config": {
|
||||
"env_config": env_config,
|
||||
"model": {
|
||||
"custom_model": "carla",
|
||||
"custom_options": {
|
||||
"image_shape": [80, 80, 6],
|
||||
},
|
||||
"conv_filters": [
|
||||
[16, [8, 8], 4],
|
||||
[32, [4, 4], 2],
|
||||
[512, [10, 10], 1],
|
||||
],
|
||||
},
|
||||
"timesteps_per_iteration": 100,
|
||||
"learning_starts": 1000,
|
||||
"schedule_max_timesteps": 100000,
|
||||
"gamma": 0.8,
|
||||
},
|
||||
},
|
||||
})
|
||||
@@ -1,63 +0,0 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import ray
|
||||
from ray.tune import register_env, run_experiments
|
||||
|
||||
from env import CarlaEnv, ENV_CONFIG
|
||||
from models import register_carla_model
|
||||
from scenarios import LANE_KEEP
|
||||
|
||||
env_name = "carla_env"
|
||||
env_config = ENV_CONFIG.copy()
|
||||
env_config.update({
|
||||
"verbose": False,
|
||||
"x_res": 80,
|
||||
"y_res": 80,
|
||||
"use_depth_camera": False,
|
||||
"discrete_actions": False,
|
||||
"server_map": "/Game/Maps/Town02",
|
||||
"reward_function": "lane_keep",
|
||||
"enable_planner": False,
|
||||
"scenarios": [LANE_KEEP],
|
||||
})
|
||||
|
||||
register_env(env_name, lambda env_config: CarlaEnv(env_config))
|
||||
register_carla_model()
|
||||
|
||||
ray.init()
|
||||
run_experiments({
|
||||
"carla-ppo": {
|
||||
"run": "PPO",
|
||||
"env": "carla_env",
|
||||
"config": {
|
||||
"env_config": env_config,
|
||||
"model": {
|
||||
"custom_model": "carla",
|
||||
"custom_options": {
|
||||
"image_shape": [80, 80, 6],
|
||||
},
|
||||
"conv_filters": [
|
||||
[16, [8, 8], 4],
|
||||
[32, [4, 4], 2],
|
||||
[512, [10, 10], 1],
|
||||
],
|
||||
},
|
||||
"num_workers": 1,
|
||||
"timesteps_per_batch": 2000,
|
||||
"min_steps_per_task": 100,
|
||||
"lambda": 0.95,
|
||||
"clip_param": 0.2,
|
||||
"num_sgd_iter": 20,
|
||||
"sgd_stepsize": 0.0001,
|
||||
"sgd_batchsize": 32,
|
||||
"devices": ["/gpu:0"],
|
||||
"tf_session_args": {
|
||||
"gpu_options": {
|
||||
"allow_growth": True
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
})
|
||||
@@ -3,13 +3,12 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import ray
|
||||
from ray.tune import grid_search, register_env, run_experiments
|
||||
from ray.tune import grid_search, run_experiments
|
||||
|
||||
from env import CarlaEnv, ENV_CONFIG
|
||||
from models import register_carla_model
|
||||
from scenarios import TOWN2_STRAIGHT
|
||||
|
||||
env_name = "carla_env"
|
||||
env_config = ENV_CONFIG.copy()
|
||||
env_config.update({
|
||||
"verbose": False,
|
||||
@@ -23,7 +22,6 @@ env_config.update({
|
||||
"scenarios": TOWN2_STRAIGHT,
|
||||
})
|
||||
|
||||
register_env(env_name, lambda env_config: CarlaEnv(env_config))
|
||||
register_carla_model()
|
||||
redis_address = ray.services.get_node_ip_address() + ":6379"
|
||||
|
||||
@@ -31,7 +29,7 @@ ray.init(redis_address=redis_address)
|
||||
run_experiments({
|
||||
"carla-a3c": {
|
||||
"run": "A3C",
|
||||
"env": "carla_env",
|
||||
"env": CarlaEnv,
|
||||
"config": {
|
||||
"env_config": env_config,
|
||||
"use_gpu_for_workers": True,
|
||||
|
||||
@@ -3,13 +3,12 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import ray
|
||||
from ray.tune import register_env, run_experiments
|
||||
from ray.tune import run_experiments
|
||||
|
||||
from env import CarlaEnv, ENV_CONFIG
|
||||
from models import register_carla_model
|
||||
from scenarios import TOWN2_ONE_CURVE
|
||||
|
||||
env_name = "carla_env"
|
||||
env_config = ENV_CONFIG.copy()
|
||||
env_config.update({
|
||||
"verbose": False,
|
||||
@@ -21,7 +20,6 @@ env_config.update({
|
||||
"scenarios": TOWN2_ONE_CURVE,
|
||||
})
|
||||
|
||||
register_env(env_name, lambda env_config: CarlaEnv(env_config))
|
||||
register_carla_model()
|
||||
|
||||
ray.init()
|
||||
@@ -35,7 +33,7 @@ def shape_out(spec):
|
||||
run_experiments({
|
||||
"carla-dqn": {
|
||||
"run": "DQN",
|
||||
"env": "carla_env",
|
||||
"env": CarlaEnv,
|
||||
"config": {
|
||||
"env_config": env_config,
|
||||
"model": {
|
||||
|
||||
@@ -3,13 +3,12 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import ray
|
||||
from ray.tune import register_env, run_experiments
|
||||
from ray.tune import run_experiments
|
||||
|
||||
from env import CarlaEnv, ENV_CONFIG
|
||||
from models import register_carla_model
|
||||
from scenarios import TOWN2_STRAIGHT
|
||||
|
||||
env_name = "carla_env"
|
||||
env_config = ENV_CONFIG.copy()
|
||||
env_config.update({
|
||||
"verbose": False,
|
||||
@@ -20,14 +19,13 @@ env_config.update({
|
||||
"server_map": "/Game/Maps/Town02",
|
||||
"scenarios": TOWN2_STRAIGHT,
|
||||
})
|
||||
register_env(env_name, lambda env_config: CarlaEnv(env_config))
|
||||
register_carla_model()
|
||||
|
||||
ray.init(redirect_output=True)
|
||||
run_experiments({
|
||||
"carla": {
|
||||
"run": "PPO",
|
||||
"env": "carla_env",
|
||||
"env": CarlaEnv,
|
||||
"config": {
|
||||
"env_config": env_config,
|
||||
"model": {
|
||||
|
||||
@@ -11,7 +11,6 @@ from gym.envs.registration import EnvSpec
|
||||
|
||||
import ray
|
||||
from ray.tune import run_experiments
|
||||
from ray.tune.registry import register_env
|
||||
|
||||
|
||||
class SimpleCorridor(gym.Env):
|
||||
@@ -42,13 +41,13 @@ class SimpleCorridor(gym.Env):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
env_creator_name = "corridor"
|
||||
register_env(env_creator_name, lambda config: SimpleCorridor(config))
|
||||
# Can also register the env creator function explicitly with:
|
||||
# register_env("corridor", lambda config: SimpleCorridor(config))
|
||||
ray.init()
|
||||
run_experiments({
|
||||
"demo": {
|
||||
"run": "PPO",
|
||||
"env": "corridor",
|
||||
"env": SimpleCorridor, # or "corridor" if registered above
|
||||
"config": {
|
||||
"env_config": {
|
||||
"corridor_length": 5,
|
||||
|
||||
@@ -35,6 +35,13 @@ def on_sample_end(info):
|
||||
print("returned sample batch of size {}".format(info["samples"].count))
|
||||
|
||||
|
||||
def on_train_result(info):
    """Example callback: report a train result and tag the result dict.

    Args:
        info: dict with an "agent" entry and a "result" entry; the result
            must contain "episodes_this_iter".
    """
    agent = info["agent"]
    episodes = info["result"]["episodes_this_iter"]
    message = "agent.train() result: {} -> {} episodes".format(
        agent, episodes)
    print(message)
    # The result dict is mutated in place, so fields added here are
    # visible to whoever holds the same dict afterwards.
    info["result"]["callback_ok"] = True
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--num-iters", type=int, default=2000)
|
||||
@@ -54,6 +61,7 @@ if __name__ == "__main__":
|
||||
"on_episode_step": tune.function(on_episode_step),
|
||||
"on_episode_end": tune.function(on_episode_end),
|
||||
"on_sample_end": tune.function(on_sample_end),
|
||||
"on_train_result": tune.function(on_train_result),
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -64,3 +72,4 @@ if __name__ == "__main__":
|
||||
print(custom_metrics)
|
||||
assert "mean_pole_angle" in custom_metrics
|
||||
assert type(custom_metrics["mean_pole_angle"]) is float
|
||||
assert "callback_ok" in trials[0].last_result
|
||||
|
||||
@@ -314,8 +314,10 @@ class Trial(object):
|
||||
def __str__(self):
|
||||
"""Combines ``env`` with ``trainable_name`` and ``experiment_tag``."""
|
||||
if "env" in self.config:
|
||||
identifier = "{}_{}".format(self.trainable_name,
|
||||
self.config["env"])
|
||||
env = self.config["env"]
|
||||
if isinstance(env, type):
|
||||
env = env.__name__
|
||||
identifier = "{}_{}".format(self.trainable_name, env)
|
||||
else:
|
||||
identifier = self.trainable_name
|
||||
if self.experiment_tag:
|
||||
|
||||
Reference in New Issue
Block a user