mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 12:45:44 +08:00
[rllib] Refactor to support passing custom env_creator function (#1096)
* refactor to use env creator * doc * lint
This commit is contained in:
committed by
Philipp Moritz
parent
1837824881
commit
b1660c4edf
+12
-11
@@ -10,7 +10,7 @@ import os
|
||||
|
||||
import ray
|
||||
from ray.rllib.a3c.runner import RunnerThread, process_rollout
|
||||
from ray.rllib.a3c.envs import create_env
|
||||
from ray.rllib.a3c.envs import create_and_wrap
|
||||
from ray.rllib.common import Agent, TrainingResult, get_tensorflow_log_dir
|
||||
from ray.rllib.a3c.shared_model import SharedModel
|
||||
from ray.rllib.a3c.shared_model_lstm import SharedModelLSTM
|
||||
@@ -33,9 +33,9 @@ class Runner(object):
|
||||
|
||||
The gradient computation is also executed from this object.
|
||||
"""
|
||||
def __init__(self, env_name, policy_cls, actor_id, batch_size,
|
||||
def __init__(self, env_creator, policy_cls, actor_id, batch_size,
|
||||
preprocess_config, logdir):
|
||||
env = create_env(env_name, preprocess_config)
|
||||
env = create_and_wrap(env_creator, preprocess_config)
|
||||
self.id = actor_id
|
||||
# TODO(rliaw): should change this to be just env.observation_space
|
||||
self.policy = policy_cls(env.observation_space.shape, env.action_space)
|
||||
@@ -95,20 +95,21 @@ class Runner(object):
|
||||
|
||||
|
||||
class A3CAgent(Agent):
|
||||
def __init__(self, env_name, config, upload_dir=None):
|
||||
config.update({"alg": "A3C"})
|
||||
Agent.__init__(self, env_name, config, upload_dir=upload_dir)
|
||||
self.env = create_env(env_name, config["model"])
|
||||
if config.get("use_lstm", True):
|
||||
_agent_name = "A3C"
|
||||
|
||||
def _init(self):
|
||||
self.env = create_and_wrap(self.env_creator, self.config["model"])
|
||||
if self.config["use_lstm"]:
|
||||
policy_cls = SharedModelLSTM
|
||||
else:
|
||||
policy_cls = SharedModel
|
||||
self.policy = policy_cls(
|
||||
self.env.observation_space.shape, self.env.action_space)
|
||||
self.agents = [
|
||||
Runner.remote(env_name, policy_cls, i,
|
||||
config["batch_size"], config["model"], self.logdir)
|
||||
for i in range(config["num_workers"])]
|
||||
Runner.remote(self.env_creator, policy_cls, i,
|
||||
self.config["batch_size"],
|
||||
self.config["model"], self.logdir)
|
||||
for i in range(self.config["num_workers"])]
|
||||
self.parameters = self.policy.get_weights()
|
||||
|
||||
def _train(self):
|
||||
|
||||
@@ -13,9 +13,9 @@ logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
|
||||
def create_env(env_id, options):
|
||||
env = gym.make(env_id)
|
||||
env = RLLibPreprocessing(env_id, env, options)
|
||||
def create_and_wrap(env_creator, options):
|
||||
env = env_creator()
|
||||
env = RLLibPreprocessing(env.spec.id, env, options)
|
||||
env = Diagnostic(env)
|
||||
return env
|
||||
|
||||
|
||||
+53
-22
@@ -1,5 +1,6 @@
|
||||
from collections import namedtuple
|
||||
from datetime import datetime
|
||||
|
||||
import json
|
||||
import logging
|
||||
import numpy as np
|
||||
@@ -9,7 +10,10 @@ import sys
|
||||
import tempfile
|
||||
import time
|
||||
import uuid
|
||||
|
||||
import gym
|
||||
import smart_open
|
||||
import tensorflow as tf
|
||||
|
||||
if sys.version_info[0] == 2:
|
||||
import cStringIO as StringIO
|
||||
@@ -118,28 +122,35 @@ class Agent(object):
|
||||
you should create a new agent instance for each training session.
|
||||
|
||||
Attributes:
|
||||
env_name (str): Name of the OpenAI gym environment to train against.
|
||||
env_creator (func): Function that creates a new training env.
|
||||
config (obj): Algorithm-specific configuration data.
|
||||
logdir (str): Directory in which training outputs should be placed.
|
||||
"""
|
||||
|
||||
def __init__(self, env_name, config, upload_dir=None):
|
||||
def __init__(self, env_creator, config, upload_dir=None):
|
||||
"""Initialize an RLLib agent.
|
||||
|
||||
Args:
|
||||
env_name (str): The name of the OpenAI gym environment to use.
|
||||
env_creator (str|func): Name of the OpenAI gym environment to train
|
||||
against, or a function that creates such an env.
|
||||
config (obj): Algorithm-specific configuration data.
|
||||
upload_dir (str): Root directory into which the output directory
|
||||
should be placed. Can be local like file:///tmp/ray/ or on S3
|
||||
like s3://bucketname/.
|
||||
"""
|
||||
self._experiment_id = uuid.uuid4().hex
|
||||
upload_dir = "file:///tmp/ray" if upload_dir is None else upload_dir
|
||||
self.experiment_id = uuid.uuid4().hex
|
||||
self.env_name = env_name
|
||||
if type(env_creator) is str:
|
||||
env_name = env_creator
|
||||
self.env_creator = lambda: gym.make(env_name)
|
||||
else:
|
||||
env_name = "custom"
|
||||
self.env_creator = env_creator
|
||||
|
||||
self.config = config
|
||||
self.config.update({"experiment_id": self.experiment_id})
|
||||
self.config.update({"experiment_id": self._experiment_id})
|
||||
self.config.update({"env_name": env_name})
|
||||
self.config.update({"alg": self._agent_name})
|
||||
prefix = "{}_{}_{}".format(
|
||||
env_name,
|
||||
self.__class__.__name__,
|
||||
@@ -160,9 +171,17 @@ class Agent(object):
|
||||
"%s algorithm created with logdir '%s'",
|
||||
self.__class__.__name__, self.logdir)
|
||||
|
||||
self.iteration = 0
|
||||
self.time_total = 0.0
|
||||
self.timesteps_total = 0
|
||||
self._iteration = 0
|
||||
self._time_total = 0.0
|
||||
self._timesteps_total = 0
|
||||
|
||||
with tf.Graph().as_default():
|
||||
self._init()
|
||||
|
||||
def _init(self, config, env_creator):
|
||||
"""Subclasses should override this for custom initialization."""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
def train(self):
|
||||
"""Runs one logical iteration of training.
|
||||
@@ -173,18 +192,18 @@ class Agent(object):
|
||||
|
||||
start = time.time()
|
||||
result = self._train()
|
||||
self.iteration += 1
|
||||
self._iteration += 1
|
||||
time_this_iter = time.time() - start
|
||||
|
||||
self.time_total += time_this_iter
|
||||
self.timesteps_total += result.timesteps_this_iter
|
||||
self._time_total += time_this_iter
|
||||
self._timesteps_total += result.timesteps_this_iter
|
||||
|
||||
result = result._replace(
|
||||
experiment_id=self.experiment_id,
|
||||
training_iteration=self.iteration,
|
||||
timesteps_total=self.timesteps_total,
|
||||
experiment_id=self._experiment_id,
|
||||
training_iteration=self._iteration,
|
||||
timesteps_total=self._timesteps_total,
|
||||
time_this_iter_s=time_this_iter,
|
||||
time_total_s=self.time_total)
|
||||
time_total_s=self._time_total)
|
||||
|
||||
for field in result:
|
||||
assert field is not None, result
|
||||
@@ -200,8 +219,8 @@ class Agent(object):
|
||||
|
||||
checkpoint_path = self._save()
|
||||
pickle.dump(
|
||||
[self.experiment_id, self.iteration, self.timesteps_total,
|
||||
self.time_total],
|
||||
[self._experiment_id, self._iteration, self._timesteps_total,
|
||||
self._time_total],
|
||||
open(checkpoint_path + ".rllib_metadata", "wb"))
|
||||
return checkpoint_path
|
||||
|
||||
@@ -213,16 +232,28 @@ class Agent(object):
|
||||
|
||||
self._restore(checkpoint_path)
|
||||
metadata = pickle.load(open(checkpoint_path + ".rllib_metadata", "rb"))
|
||||
self.experiment_id = metadata[0]
|
||||
self.iteration = metadata[1]
|
||||
self.timesteps_total = metadata[2]
|
||||
self.time_total = metadata[3]
|
||||
self._experiment_id = metadata[0]
|
||||
self._iteration = metadata[1]
|
||||
self._timesteps_total = metadata[2]
|
||||
self._time_total = metadata[3]
|
||||
|
||||
def compute_action(self, observation):
|
||||
"""Computes an action using the current trained policy."""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def iteration(self):
|
||||
"""Current training iter, auto-incremented with each train() call."""
|
||||
|
||||
return self._iteration
|
||||
|
||||
@property
|
||||
def _agent_name(self):
|
||||
"""Subclasses should override this to declare their name."""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
def _train(self):
|
||||
"""Subclasses should override this to implement train()."""
|
||||
|
||||
|
||||
+11
-18
@@ -4,7 +4,6 @@ from __future__ import print_function
|
||||
|
||||
import time
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
import pickle
|
||||
import os
|
||||
@@ -103,10 +102,10 @@ DEFAULT_CONFIG = dict(
|
||||
|
||||
|
||||
class Actor(object):
|
||||
def __init__(self, env_name, config, logdir):
|
||||
env = gym.make(env_name)
|
||||
def __init__(self, env_creator, config, logdir):
|
||||
env = env_creator()
|
||||
# TODO(ekl): replace this with RLlib preprocessors
|
||||
if "NoFrameskip" in env_name:
|
||||
if "NoFrameskip" in env.spec.id:
|
||||
env = ScaledFloatFrame(wrap_dqn(env))
|
||||
self.env = env
|
||||
self.config = config
|
||||
@@ -239,27 +238,21 @@ class Actor(object):
|
||||
|
||||
@ray.remote
|
||||
class RemoteActor(Actor):
|
||||
def __init__(self, env_name, config, logdir, gpu_mask):
|
||||
def __init__(self, env_creator, config, logdir, gpu_mask):
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = gpu_mask
|
||||
Actor.__init__(self, env_name, config, logdir)
|
||||
Actor.__init__(self, env_creator, config, logdir)
|
||||
|
||||
|
||||
class DQNAgent(Agent):
|
||||
def __init__(self, env_name, config, upload_dir=None):
|
||||
config.update({"alg": "DQN"})
|
||||
_agent_name = "DQN"
|
||||
|
||||
Agent.__init__(self, env_name, config, upload_dir=upload_dir)
|
||||
|
||||
with tf.Graph().as_default():
|
||||
self._init(config, env_name)
|
||||
|
||||
def _init(self, config, env_name):
|
||||
self.actor = Actor(env_name, config, self.logdir)
|
||||
def _init(self):
|
||||
self.actor = Actor(self.env_creator, self.config, self.logdir)
|
||||
self.workers = [
|
||||
RemoteActor.remote(
|
||||
env_name, config, self.logdir,
|
||||
"{}".format(i + config["gpu_offset"]))
|
||||
for i in range(config["num_workers"])]
|
||||
self.env_creator, self.config, self.logdir,
|
||||
"{}".format(i + self.config["gpu_offset"]))
|
||||
for i in range(self.config["num_workers"])]
|
||||
|
||||
self.cur_timestep = 0
|
||||
self.num_iterations = 0
|
||||
|
||||
@@ -6,14 +6,11 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from collections import namedtuple
|
||||
import gym
|
||||
import numpy as np
|
||||
import os
|
||||
import pickle
|
||||
import time
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
import ray
|
||||
from ray.rllib.common import Agent, TrainingResult
|
||||
from ray.rllib.models import ModelCatalog
|
||||
@@ -68,16 +65,16 @@ class SharedNoiseTable(object):
|
||||
|
||||
@ray.remote
|
||||
class Worker(object):
|
||||
def __init__(self, config, policy_params, env_name, noise,
|
||||
def __init__(self, config, policy_params, env_creator, noise,
|
||||
min_task_runtime=0.2):
|
||||
self.min_task_runtime = min_task_runtime
|
||||
self.config = config
|
||||
self.policy_params = policy_params
|
||||
self.noise = SharedNoiseTable(noise)
|
||||
|
||||
self.env = gym.make(env_name)
|
||||
self.env = env_creator()
|
||||
self.preprocessor = ModelCatalog.get_preprocessor(
|
||||
env_name, self.env.observation_space.shape)
|
||||
self.env.spec.id, self.env.observation_space.shape)
|
||||
self.preprocessor_shape = self.preprocessor.transform_shape(
|
||||
self.env.observation_space.shape)
|
||||
|
||||
@@ -161,13 +158,7 @@ class Worker(object):
|
||||
|
||||
|
||||
class ESAgent(Agent):
|
||||
def __init__(self, env_name, config, upload_dir=None):
|
||||
config.update({"alg": "EvolutionStrategies"})
|
||||
|
||||
Agent.__init__(self, env_name, config, upload_dir=upload_dir)
|
||||
|
||||
with tf.Graph().as_default():
|
||||
self._init()
|
||||
_agent_name = "ES"
|
||||
|
||||
def _init(self):
|
||||
|
||||
@@ -175,9 +166,9 @@ class ESAgent(Agent):
|
||||
"ac_noise_std": 0.01
|
||||
}
|
||||
|
||||
env = gym.make(self.env_name)
|
||||
env = self.env_creator()
|
||||
preprocessor = ModelCatalog.get_preprocessor(
|
||||
self.env_name, env.observation_space.shape)
|
||||
env.spec.id, env.observation_space.shape)
|
||||
preprocessor_shape = preprocessor.transform_shape(
|
||||
env.observation_space.shape)
|
||||
|
||||
@@ -197,7 +188,8 @@ class ESAgent(Agent):
|
||||
# Create the actors.
|
||||
print("Creating actors.")
|
||||
self.workers = [
|
||||
Worker.remote(self.config, policy_params, self.env_name, noise_id)
|
||||
Worker.remote(
|
||||
self.config, policy_params, self.env_creator, noise_id)
|
||||
for _ in range(self.config["num_workers"])]
|
||||
|
||||
self.episodes_so_far = 0
|
||||
|
||||
@@ -2,7 +2,6 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
|
||||
from ray.rllib.models import ModelCatalog
|
||||
@@ -10,13 +9,14 @@ from ray.rllib.models import ModelCatalog
|
||||
|
||||
class BatchedEnv(object):
|
||||
"""This holds multiple gym envs and performs steps on all of them."""
|
||||
def __init__(self, name, batchsize, options):
|
||||
self.envs = [gym.make(name) for _ in range(batchsize)]
|
||||
def __init__(self, env_creator, batchsize, options):
|
||||
self.envs = [env_creator() for _ in range(batchsize)]
|
||||
self.observation_space = self.envs[0].observation_space
|
||||
self.action_space = self.envs[0].action_space
|
||||
self.batchsize = batchsize
|
||||
self.preprocessor = ModelCatalog.get_preprocessor(
|
||||
name, self.envs[0].observation_space.shape, options["model"])
|
||||
self.envs[0].spec.id, self.envs[0].observation_space.shape,
|
||||
options["model"])
|
||||
self.extra_frameskip = options.get("extra_frameskip", 1)
|
||||
assert self.extra_frameskip >= 1
|
||||
|
||||
|
||||
@@ -60,7 +60,7 @@ DEFAULT_CONFIG = {
|
||||
# If >1, adds frameskip
|
||||
"extra_frameskip": 1,
|
||||
# Number of timesteps collected in each outer loop
|
||||
"timesteps_per_batch": 40000,
|
||||
"timesteps_per_batch": 4000,
|
||||
# Each task performs rollouts until at least this
|
||||
# number of steps is obtained
|
||||
"min_steps_per_task": 1000,
|
||||
@@ -81,21 +81,16 @@ DEFAULT_CONFIG = {
|
||||
|
||||
|
||||
class PPOAgent(Agent):
|
||||
def __init__(self, env_name, config, upload_dir=None):
|
||||
config.update({"alg": "PPO"})
|
||||
|
||||
Agent.__init__(self, env_name, config, upload_dir=upload_dir)
|
||||
|
||||
with tf.Graph().as_default():
|
||||
self._init()
|
||||
_agent_name = "PPO"
|
||||
|
||||
def _init(self):
|
||||
self.global_step = 0
|
||||
self.kl_coeff = self.config["kl_coeff"]
|
||||
self.model = Runner(self.env_name, 1, self.config, self.logdir, False)
|
||||
self.model = Runner(
|
||||
self.env_creator, 1, self.config, self.logdir, False)
|
||||
self.agents = [
|
||||
RemoteRunner.remote(
|
||||
self.env_name, 1, self.config, self.logdir, True)
|
||||
self.env_creator, 1, self.config, self.logdir, True)
|
||||
for _ in range(self.config["num_workers"])]
|
||||
self.start_time = time.time()
|
||||
if self.config["write_logs"]:
|
||||
|
||||
@@ -37,7 +37,7 @@ class Runner(object):
|
||||
network weights. When run as a remote agent, only this graph is used.
|
||||
"""
|
||||
|
||||
def __init__(self, name, batchsize, config, logdir, is_remote):
|
||||
def __init__(self, env_creator, batchsize, config, logdir, is_remote):
|
||||
if is_remote:
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = ""
|
||||
devices = ["/cpu:0"]
|
||||
@@ -46,7 +46,7 @@ class Runner(object):
|
||||
self.devices = devices
|
||||
self.config = config
|
||||
self.logdir = logdir
|
||||
self.env = BatchedEnv(name, batchsize, config)
|
||||
self.env = BatchedEnv(env_creator, batchsize, config)
|
||||
if is_remote:
|
||||
config_proto = tf.ConfigProto()
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user