mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 07:50:30 +08:00
[rllib] Initial work on integrating hyperparameter search tool (#1107)
* clean up train * update * update train script * add tuned examples * add agent catalog * add tune lib * update * fix * testS * remove * train docs * comments * todo * fix resource parsing * fix cr test * add test * try to fix travis test
This commit is contained in:
@@ -11,7 +11,7 @@ import os
|
||||
import ray
|
||||
from ray.rllib.a3c.runner import RunnerThread, process_rollout
|
||||
from ray.rllib.a3c.envs import create_and_wrap
|
||||
from ray.rllib.common import Agent, TrainingResult, get_tensorflow_log_dir
|
||||
from ray.rllib.common import Agent, TrainingResult
|
||||
from ray.rllib.a3c.shared_model import SharedModel
|
||||
from ray.rllib.a3c.shared_model_lstm import SharedModelLSTM
|
||||
|
||||
@@ -73,9 +73,8 @@ class Runner(object):
|
||||
return completed
|
||||
|
||||
def start(self):
|
||||
logdir = get_tensorflow_log_dir(self.logdir)
|
||||
summary_writer = tf.summary.FileWriter(
|
||||
os.path.join(logdir, "agent_%d" % self.id))
|
||||
os.path.join(self.logdir, "agent_%d" % self.id))
|
||||
self.summary_writer = summary_writer
|
||||
self.runner.start_runner(self.policy.sess, summary_writer)
|
||||
|
||||
@@ -96,6 +95,7 @@ class Runner(object):
|
||||
|
||||
class A3CAgent(Agent):
|
||||
_agent_name = "A3C"
|
||||
_default_config = DEFAULT_CONFIG
|
||||
|
||||
def _init(self):
|
||||
self.env = create_and_wrap(self.env_creator, self.config["model"])
|
||||
|
||||
@@ -0,0 +1,43 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from ray.rllib.common import Agent, TrainingResult
|
||||
|
||||
|
||||
class _MockAgent(Agent):
|
||||
"""Mock agent for use in tests"""
|
||||
|
||||
_agent_name = "MockAgent"
|
||||
_default_config = {}
|
||||
|
||||
def _init(self):
|
||||
pass
|
||||
|
||||
def _train(self):
|
||||
return TrainingResult(
|
||||
episode_reward_mean=10, episode_len_mean=10,
|
||||
timesteps_this_iter=10, info={})
|
||||
|
||||
|
||||
def get_agent_class(alg):
|
||||
"""Returns the class of an known agent given its name."""
|
||||
|
||||
if alg == "PPO":
|
||||
from ray.rllib import ppo
|
||||
return ppo.PPOAgent
|
||||
elif alg == "ES":
|
||||
from ray.rllib import es
|
||||
return es.ESAgent
|
||||
elif alg == "DQN":
|
||||
from ray.rllib import dqn
|
||||
return dqn.DQNAgent
|
||||
elif alg == "A3C":
|
||||
from ray.rllib import a3c
|
||||
return a3c.A3CAgent
|
||||
elif alg == "__fake":
|
||||
return _MockAgent
|
||||
else:
|
||||
raise Exception(
|
||||
("Unknown algorithm {}, check --alg argument. Valid choices " +
|
||||
"are PPO, ES, DQN, and A3C.").format(alg))
|
||||
+135
-81
@@ -1,3 +1,7 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from collections import namedtuple
|
||||
from datetime import datetime
|
||||
|
||||
@@ -11,8 +15,6 @@ import tempfile
|
||||
import time
|
||||
import uuid
|
||||
|
||||
import gym
|
||||
import smart_open
|
||||
import tensorflow as tf
|
||||
|
||||
if sys.version_info[0] == 2:
|
||||
@@ -24,64 +26,6 @@ logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
|
||||
def get_tensorflow_log_dir(logdir):
|
||||
if logdir.startswith("s3"):
|
||||
print("WARNING: TensorFlow logging to S3 not supported by"
|
||||
"TensorFlow, logging to /tmp/ray/ instead")
|
||||
logdir = "/tmp/ray/"
|
||||
if not os.path.exists(logdir):
|
||||
os.makedirs(logdir)
|
||||
return logdir
|
||||
|
||||
|
||||
class RLLibEncoder(json.JSONEncoder):
|
||||
|
||||
def __init__(self, nan_str="null", **kwargs):
|
||||
super(RLLibEncoder, self).__init__(**kwargs)
|
||||
self.nan_str = nan_str
|
||||
|
||||
def iterencode(self, o, _one_shot=False):
|
||||
if self.ensure_ascii:
|
||||
_encoder = json.encoder.encode_basestring_ascii
|
||||
else:
|
||||
_encoder = json.encoder.encode_basestring
|
||||
|
||||
def floatstr(o, allow_nan=self.allow_nan, nan_str=self.nan_str):
|
||||
return repr(o) if not np.isnan(o) else nan_str
|
||||
|
||||
_iterencode = json.encoder._make_iterencode(
|
||||
None, self.default, _encoder, self.indent, floatstr,
|
||||
self.key_separator, self.item_separator, self.sort_keys,
|
||||
self.skipkeys, _one_shot)
|
||||
return _iterencode(o, 0)
|
||||
|
||||
def default(self, value):
|
||||
if np.isnan(value):
|
||||
return None
|
||||
if np.issubdtype(value, float):
|
||||
return float(value)
|
||||
if np.issubdtype(value, int):
|
||||
return int(value)
|
||||
|
||||
|
||||
class RLLibLogger(object):
|
||||
"""Writing small amounts of data to S3 with real-time updates.
|
||||
"""
|
||||
|
||||
def __init__(self, uri):
|
||||
self.result_buffer = StringIO.StringIO()
|
||||
self.uri = uri
|
||||
|
||||
def write(self, b):
|
||||
# TODO(pcm): At the moment we are writing the whole results output from
|
||||
# the beginning in each iteration. This will write O(n^2) bytes where n
|
||||
# is the number of bytes printed so far. Fix this! This should at least
|
||||
# only write the last 5MBs (S3 chunksize).
|
||||
with smart_open.smart_open(self.uri, "w") as f:
|
||||
self.result_buffer.write(b)
|
||||
f.write(self.result_buffer.getvalue())
|
||||
|
||||
|
||||
TrainingResult = namedtuple("TrainingResult", [
|
||||
# Unique string identifier for this experiment. This id is preserved
|
||||
# across checkpoint / restore calls.
|
||||
@@ -127,49 +71,72 @@ class Agent(object):
|
||||
logdir (str): Directory in which training outputs should be placed.
|
||||
"""
|
||||
|
||||
def __init__(self, env_creator, config, upload_dir=None):
|
||||
def __init__(
|
||||
self, env_creator, config, local_dir='/tmp/ray',
|
||||
upload_dir=None, agent_id=None):
|
||||
"""Initialize an RLLib agent.
|
||||
|
||||
Args:
|
||||
env_creator (str|func): Name of the OpenAI gym environment to train
|
||||
against, or a function that creates such an env.
|
||||
config (obj): Algorithm-specific configuration data.
|
||||
upload_dir (str): Root directory into which the output directory
|
||||
should be placed. Can be local like file:///tmp/ray/ or on S3
|
||||
like s3://bucketname/.
|
||||
local_dir (str): Directory where results and temporary files will
|
||||
be placed.
|
||||
upload_dir (str): Optional remote URI like s3://bucketname/ where
|
||||
results will be uploaded.
|
||||
agent_id (str): Optional unique identifier for this agent, used
|
||||
to determine where to store results in the local dir.
|
||||
"""
|
||||
self._experiment_id = uuid.uuid4().hex
|
||||
upload_dir = "file:///tmp/ray" if upload_dir is None else upload_dir
|
||||
if type(env_creator) is str:
|
||||
import gym
|
||||
env_name = env_creator
|
||||
self.env_creator = lambda: gym.make(env_name)
|
||||
else:
|
||||
env_name = "custom"
|
||||
self.env_creator = env_creator
|
||||
|
||||
self.config = config
|
||||
self.config.update({"experiment_id": self._experiment_id})
|
||||
self.config.update({"env_name": env_name})
|
||||
self.config.update({"alg": self._agent_name})
|
||||
prefix = "{}_{}_{}".format(
|
||||
self.config = self._default_config.copy()
|
||||
for k in config.keys():
|
||||
if k not in self.config:
|
||||
raise Exception(
|
||||
"Unknown agent config `{}`, "
|
||||
"all agent configs: {}".format(k, self.config.keys()))
|
||||
self.config.update(config)
|
||||
self.config.update({
|
||||
"agent_id": agent_id,
|
||||
"alg": self._agent_name,
|
||||
"env_name": env_name,
|
||||
"experiment_id": self._experiment_id,
|
||||
})
|
||||
|
||||
logdir_suffix = "{}_{}_{}".format(
|
||||
env_name,
|
||||
self.__class__.__name__,
|
||||
datetime.today().strftime("%Y-%m-%d_%H-%M-%S"))
|
||||
if upload_dir.startswith("file"):
|
||||
local_dir = upload_dir[len("file://"):]
|
||||
if not os.path.exists(local_dir):
|
||||
os.makedirs(local_dir)
|
||||
self.logdir = tempfile.mkdtemp(prefix=prefix, dir=local_dir)
|
||||
agent_id or datetime.today().strftime("%Y-%m-%d_%H-%M-%S"))
|
||||
|
||||
if not os.path.exists(local_dir):
|
||||
os.makedirs(local_dir)
|
||||
|
||||
self.logdir = tempfile.mkdtemp(prefix=logdir_suffix, dir=local_dir)
|
||||
|
||||
if upload_dir:
|
||||
log_upload_uri = os.path.join(upload_dir, logdir_suffix)
|
||||
else:
|
||||
self.logdir = os.path.join(upload_dir, prefix)
|
||||
log_upload_uri = None
|
||||
|
||||
# TODO(ekl) consider inlining config into the result jsons
|
||||
log_path = os.path.join(self.logdir, "config.json")
|
||||
with smart_open.smart_open(log_path, "w") as f:
|
||||
config_out = os.path.join(self.logdir, "config.json")
|
||||
with open(config_out, "w") as f:
|
||||
json.dump(self.config, f, sort_keys=True, cls=RLLibEncoder)
|
||||
logger.info(
|
||||
"%s algorithm created with logdir '%s'",
|
||||
self.__class__.__name__, self.logdir)
|
||||
"%s algorithm created with logdir '%s' and upload uri '%s'",
|
||||
self.__class__.__name__, self.logdir, log_upload_uri)
|
||||
|
||||
self._result_logger = RLLibLogger(
|
||||
os.path.join(self.logdir, "result.json"),
|
||||
log_upload_uri and os.path.join(log_upload_uri, "result.json"))
|
||||
self._file_writer = tf.summary.FileWriter(self.logdir)
|
||||
|
||||
self._iteration = 0
|
||||
self._time_total = 0.0
|
||||
@@ -208,8 +175,29 @@ class Agent(object):
|
||||
for field in result:
|
||||
assert field is not None, result
|
||||
|
||||
self._log_result(result)
|
||||
|
||||
return result
|
||||
|
||||
def _log_result(self, result):
|
||||
"""Appends the given result to this agent's log dir."""
|
||||
|
||||
# We need to use a custom json serializer class so that NaNs get
|
||||
# encoded as null as required by Athena.
|
||||
json.dump(result._asdict(), self._result_logger, cls=RLLibEncoder)
|
||||
self._result_logger.write("\n")
|
||||
train_stats = tf.Summary(value=[
|
||||
tf.Summary.Value(
|
||||
tag="rllib/time_this_iter_s",
|
||||
simple_value=result.time_this_iter_s),
|
||||
tf.Summary.Value(
|
||||
tag="rllib/episode_reward_mean",
|
||||
simple_value=result.episode_reward_mean),
|
||||
tf.Summary.Value(
|
||||
tag="rllib/episode_len_mean",
|
||||
simple_value=result.episode_len_mean)])
|
||||
self._file_writer.add_summary(train_stats, result.training_iteration)
|
||||
|
||||
def save(self):
|
||||
"""Saves the current model state to a checkpoint.
|
||||
|
||||
@@ -237,6 +225,11 @@ class Agent(object):
|
||||
self._timesteps_total = metadata[2]
|
||||
self._time_total = metadata[3]
|
||||
|
||||
def stop(self):
|
||||
"""Releases all resources used by this agent."""
|
||||
|
||||
self._file_writer.close()
|
||||
|
||||
def compute_action(self, observation):
|
||||
"""Computes an action using the current trained policy."""
|
||||
|
||||
@@ -254,6 +247,12 @@ class Agent(object):
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def _default_config(self):
|
||||
"""Subclasses should override this to declare their default config."""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
def _train(self):
|
||||
"""Subclasses should override this to implement train()."""
|
||||
|
||||
@@ -268,3 +267,58 @@ class Agent(object):
|
||||
"""Subclasses should override this to implement restore()."""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class RLLibEncoder(json.JSONEncoder):
|
||||
|
||||
def __init__(self, nan_str="null", **kwargs):
|
||||
super(RLLibEncoder, self).__init__(**kwargs)
|
||||
self.nan_str = nan_str
|
||||
|
||||
def iterencode(self, o, _one_shot=False):
|
||||
if self.ensure_ascii:
|
||||
_encoder = json.encoder.encode_basestring_ascii
|
||||
else:
|
||||
_encoder = json.encoder.encode_basestring
|
||||
|
||||
def floatstr(o, allow_nan=self.allow_nan, nan_str=self.nan_str):
|
||||
return repr(o) if not np.isnan(o) else nan_str
|
||||
|
||||
_iterencode = json.encoder._make_iterencode(
|
||||
None, self.default, _encoder, self.indent, floatstr,
|
||||
self.key_separator, self.item_separator, self.sort_keys,
|
||||
self.skipkeys, _one_shot)
|
||||
return _iterencode(o, 0)
|
||||
|
||||
def default(self, value):
|
||||
if np.isnan(value):
|
||||
return None
|
||||
if np.issubdtype(value, float):
|
||||
return float(value)
|
||||
if np.issubdtype(value, int):
|
||||
return int(value)
|
||||
|
||||
|
||||
class RLLibLogger(object):
|
||||
"""Writing small amounts of data to S3 with real-time updates.
|
||||
"""
|
||||
|
||||
def __init__(self, local_file, uri=None):
|
||||
self.local_out = open(local_file, "w")
|
||||
self.result_buffer = StringIO.StringIO()
|
||||
self.uri = uri
|
||||
if self.uri:
|
||||
import smart_open
|
||||
self.smart_open = smart_open.smart_open
|
||||
|
||||
def write(self, b):
|
||||
self.local_out.write(b)
|
||||
self.local_out.flush()
|
||||
# TODO(pcm): At the moment we are writing the whole results output from
|
||||
# the beginning in each iteration. This will write O(n^2) bytes where n
|
||||
# is the number of bytes printed so far. Fix this! This should at least
|
||||
# only write the last 5MBs (S3 chunksize).
|
||||
if self.uri:
|
||||
with self.smart_open(self.uri, "w") as f:
|
||||
self.result_buffer.write(b)
|
||||
f.write(self.result_buffer.getvalue())
|
||||
|
||||
@@ -245,6 +245,7 @@ class RemoteActor(Actor):
|
||||
|
||||
class DQNAgent(Agent):
|
||||
_agent_name = "DQN"
|
||||
_default_config = DEFAULT_CONFIG
|
||||
|
||||
def _init(self):
|
||||
self.actor = Actor(self.env_creator, self.config, self.logdir)
|
||||
|
||||
@@ -159,6 +159,7 @@ class Worker(object):
|
||||
|
||||
class ESAgent(Agent):
|
||||
_agent_name = "ES"
|
||||
_default_config = DEFAULT_CONFIG
|
||||
|
||||
def _init(self):
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ import tensorflow as tf
|
||||
from tensorflow.python import debug as tf_debug
|
||||
|
||||
import ray
|
||||
from ray.rllib.common import Agent, TrainingResult, get_tensorflow_log_dir
|
||||
from ray.rllib.common import Agent, TrainingResult
|
||||
from ray.rllib.ppo.runner import Runner, RemoteRunner
|
||||
from ray.rllib.ppo.rollout import collect_samples
|
||||
from ray.rllib.ppo.utils import shuffle
|
||||
@@ -82,6 +82,7 @@ DEFAULT_CONFIG = {
|
||||
|
||||
class PPOAgent(Agent):
|
||||
_agent_name = "PPO"
|
||||
_default_config = DEFAULT_CONFIG
|
||||
|
||||
def _init(self):
|
||||
self.global_step = 0
|
||||
@@ -94,9 +95,8 @@ class PPOAgent(Agent):
|
||||
for _ in range(self.config["num_workers"])]
|
||||
self.start_time = time.time()
|
||||
if self.config["write_logs"]:
|
||||
logdir = get_tensorflow_log_dir(self.logdir)
|
||||
self.file_writer = tf.summary.FileWriter(
|
||||
logdir, self.model.sess.graph)
|
||||
self.logdir, self.model.sess.graph)
|
||||
else:
|
||||
self.file_writer = None
|
||||
self.saver = tf.train.Saver(max_to_keep=None)
|
||||
|
||||
@@ -8,10 +8,7 @@ import numpy as np
|
||||
import ray
|
||||
import random
|
||||
|
||||
from ray.rllib.dqn import (DQNAgent, DEFAULT_CONFIG as DQN_CONFIG)
|
||||
from ray.rllib.ppo import (PPOAgent, DEFAULT_CONFIG as PG_CONFIG)
|
||||
from ray.rllib.a3c import (A3CAgent, DEFAULT_CONFIG as A3C_CONFIG)
|
||||
# from ray.rllib.es import (ESAgent, DEFAULT_CONFIG as ES_CONFIG)
|
||||
from ray.rllib.agents import get_agent_class
|
||||
|
||||
|
||||
def get_mean_action(alg, obs):
|
||||
@@ -22,20 +19,18 @@ def get_mean_action(alg, obs):
|
||||
|
||||
|
||||
ray.init()
|
||||
for (cls, default_config) in [
|
||||
(DQNAgent, DQN_CONFIG),
|
||||
(PPOAgent, PG_CONFIG),
|
||||
(A3CAgent, A3C_CONFIG),
|
||||
# https://github.com/ray-project/ray/issues/1062
|
||||
# (ESAgent, ES_CONFIG),
|
||||
]:
|
||||
config = default_config.copy()
|
||||
config["num_sgd_iter"] = 5
|
||||
config["use_lstm"] = False # for a3c
|
||||
config["episodes_per_batch"] = 100
|
||||
config["timesteps_per_batch"] = 1000
|
||||
alg1 = cls("CartPole-v0", config)
|
||||
alg2 = cls("CartPole-v0", config)
|
||||
|
||||
CONFIGS = {
|
||||
"DQN": {},
|
||||
"PPO": {"num_sgd_iter": 5, "timesteps_per_batch": 1000},
|
||||
"A3C": {"use_lstm": False},
|
||||
}
|
||||
|
||||
# https://github.com/ray-project/ray/issues/1062 for enabling ES test as well
|
||||
for name in ["DQN", "PPO", "A3C"]:
|
||||
cls = get_agent_class(name)
|
||||
alg1 = cls("CartPole-v0", CONFIGS[name])
|
||||
alg2 = cls("CartPole-v0", CONFIGS[name])
|
||||
|
||||
for _ in range(3):
|
||||
res = alg1.train()
|
||||
|
||||
@@ -1,17 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
python train.py --env Hopper-v1 --config '{"gamma": 0.995, "kl_coeff": 1.0, "num_sgd_iter": 20, "sgd_stepsize": 1e-4, "sgd_batchsize": 32768, "devices": ["/gpu:0", "/gpu:1", "/gpu:2", "/gpu:3"], "tf_session_args": {"device_count": {"GPU": 4}, "log_device_placement": false, "allow_soft_placement": true}, "timesteps_per_batch": 160000, "num_workers": 64}' --alg PPO --upload-dir s3://bucketname/
|
||||
|
||||
python train.py --env CartPole-v1 --config '{"kl_coeff": 1.0, "num_sgd_iter": 20, "sgd_stepsize": 1e-4, "sgd_batchsize": 32768, "devices": ["/gpu:0", "/gpu:1", "/gpu:2", "/gpu:3"], "tf_session_args": {"device_count": {"GPU": 4}, "log_device_placement": false, "allow_soft_placement": true}, "timesteps_per_batch": 160000, "num_workers": 64}' --alg PPO --upload-dir s3://bucketname/
|
||||
|
||||
python train.py --env Walker2d-v1 --config '{"kl_coeff": 1.0, "num_sgd_iter": 20, "sgd_stepsize": 1e-4, "sgd_batchsize": 32768, "devices": ["/gpu:0", "/gpu:1", "/gpu:2", "/gpu:3"], "tf_session_args": {"device_count": {"GPU": 4}, "log_device_placement": false, "allow_soft_placement": true}, "timesteps_per_batch": 320000, "num_workers": 64}' --alg PPO --upload-dir s3://bucketname/
|
||||
|
||||
python train.py --env Humanoid-v1 --config '{"kl_coeff": 1.0, "num_sgd_iter": 20, "sgd_stepsize": 1e-4, "sgd_batchsize": 32768, "devices": ["/gpu:0", "/gpu:1", "/gpu:2", "/gpu:3"], "tf_session_args": {"device_count": {"GPU": 4}, "log_device_placement": false, "allow_soft_placement": true}, "timesteps_per_batch": 320000, "num_workers": 64, "model": {"free_log_std": true}, "use_gae": false}' --alg PPO --upload-dir s3://bucketname/
|
||||
|
||||
python train.py --env Humanoid-v1 --config '{"lambda": 0.95, "clip_param": 0.2, "kl_coeff": 1.0, "num_sgd_iter": 20, "sgd_stepsize": 1e-4, "sgd_batchsize": 32768, "horizon": 5000, "devices": ["/gpu:0", "/gpu:1", "/gpu:2", "/gpu:3"], "tf_session_args": {"device_count": {"GPU": 4}, "log_device_placement": false, "allow_soft_placement": true}, "timesteps_per_batch": 320000, "num_workers": 64, "model": {"free_log_std": true}, "write_logs": false}' --alg PPO --upload-dir s3://bucketname/
|
||||
|
||||
python train.py --env PongNoFrameskip-v0 --alg DQN --upload-dir s3://bucketname/
|
||||
|
||||
python train.py --env PongDeterministic-v4 --alg A3C --config '{"num_workers": 16, "num_batches_per_iteration": 1000, "batch_size": 20}' --upload-dir s3://bucketname/
|
||||
|
||||
python train.py --env Humanoid-v1 --alg EvolutionStrategies --upload-dir s3://bucketname/
|
||||
+47
-77
@@ -1,97 +1,67 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
"""The main command line interface to RLlib.
|
||||
|
||||
Arguments may either be specified on the command line or in JSON/YAML
|
||||
files. Additionally, the file-based interface supports hyperparameter
|
||||
exploration through grid or random search, though both interfaces allow
|
||||
for the concurrent execution of multiple trials on Ray.
|
||||
|
||||
Single-trial example:
|
||||
./train.py --alg=DQN --env=CartPole-v0
|
||||
|
||||
Hyperparameter grid search example:
|
||||
./train.py -f tuned_examples/cartpole-grid-search-example.yaml
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import pprint
|
||||
import sys
|
||||
import yaml
|
||||
|
||||
import ray
|
||||
import ray.rllib.ppo as ppo
|
||||
import ray.rllib.es as es
|
||||
import ray.rllib.dqn as dqn
|
||||
import ray.rllib.a3c as a3c
|
||||
from ray.tune.config_parser import make_parser, parse_to_trials
|
||||
from ray.tune.trial_runner import TrialRunner
|
||||
from ray.tune.trial import Trial
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description=("Train a reinforcement learning agent."))
|
||||
|
||||
parser = make_parser("Train a reinforcement learning agent.")
|
||||
|
||||
# Extends the base parser defined in ray/tune/config_parser, to add some
|
||||
# RLlib specific arguments. For more arguments, see the configuration
|
||||
# defined there.
|
||||
parser.add_argument("--redis-address", default=None, type=str,
|
||||
help="The Redis address of the cluster.")
|
||||
parser.add_argument("--env", required=True, type=str,
|
||||
help="The gym environment to use.")
|
||||
parser.add_argument("--alg", required=True, type=str,
|
||||
help="The reinforcement learning algorithm to use.")
|
||||
parser.add_argument("--num-iterations", default=sys.maxsize, type=int,
|
||||
help="The number of training iterations to run.")
|
||||
parser.add_argument("--config", default="{}", type=str,
|
||||
help="The configuration options of the algorithm.")
|
||||
parser.add_argument("--upload-dir", default="file:///tmp/ray", type=str,
|
||||
help="Where the traces are stored.")
|
||||
parser.add_argument("--checkpoint-freq", default=sys.maxsize, type=int,
|
||||
help="How many iterations between checkpoints.")
|
||||
parser.add_argument("--restore", default="", type=str,
|
||||
help="If specified, restores state from this checkpoint.")
|
||||
parser.add_argument("--restore", default=None, type=str,
|
||||
help="If specified, restore from this checkpoint.")
|
||||
parser.add_argument("-f", "--config-file", default=None, type=str,
|
||||
help="If specified, use config options from this file.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
json_config = json.loads(args.config)
|
||||
runner = TrialRunner()
|
||||
|
||||
if args.config_file:
|
||||
with open(args.config_file) as f:
|
||||
config = yaml.load(f)
|
||||
for trial in parse_to_trials(config):
|
||||
runner.add_trial(trial)
|
||||
else:
|
||||
runner.add_trial(
|
||||
Trial(
|
||||
args.env, args.alg, args.config, args.local_dir, None,
|
||||
args.resources, args.stop, args.checkpoint_freq,
|
||||
args.restore, args.upload_dir))
|
||||
|
||||
ray.init(redis_address=args.redis_address)
|
||||
|
||||
def _check_and_update(config, json):
|
||||
for k in json.keys():
|
||||
if k not in config:
|
||||
raise Exception(
|
||||
"Unknown model config `{}`, all model configs: {}".format(
|
||||
k, config.keys()))
|
||||
config.update(json)
|
||||
while not runner.is_finished():
|
||||
runner.step()
|
||||
print(runner.debug_string())
|
||||
|
||||
env_name = args.env
|
||||
if args.alg == "PPO":
|
||||
config = ppo.DEFAULT_CONFIG.copy()
|
||||
_check_and_update(config, json_config)
|
||||
alg = ppo.PPOAgent(
|
||||
env_name, config, upload_dir=args.upload_dir)
|
||||
elif args.alg == "ES":
|
||||
config = es.DEFAULT_CONFIG.copy()
|
||||
_check_and_update(config, json_config)
|
||||
alg = es.ESAgent(
|
||||
env_name, config, upload_dir=args.upload_dir)
|
||||
elif args.alg == "DQN":
|
||||
config = dqn.DEFAULT_CONFIG.copy()
|
||||
_check_and_update(config, json_config)
|
||||
alg = dqn.DQNAgent(
|
||||
env_name, config, upload_dir=args.upload_dir)
|
||||
elif args.alg == "A3C":
|
||||
config = a3c.DEFAULT_CONFIG.copy()
|
||||
_check_and_update(config, json_config)
|
||||
alg = a3c.A3CAgent(
|
||||
env_name, config, upload_dir=args.upload_dir)
|
||||
else:
|
||||
assert False, ("Unknown algorithm, check --alg argument. Valid "
|
||||
"choices are PPO, ES, DQN and A3C.")
|
||||
|
||||
result_logger = ray.rllib.common.RLLibLogger(
|
||||
os.path.join(alg.logdir, "result.json"))
|
||||
|
||||
if args.restore:
|
||||
alg.restore(args.restore)
|
||||
|
||||
for i in range(args.num_iterations):
|
||||
result = alg.train()
|
||||
|
||||
# We need to use a custom json serializer class so that NaNs get
|
||||
# encoded as null as required by Athena.
|
||||
json.dump(result._asdict(), result_logger,
|
||||
cls=ray.rllib.common.RLLibEncoder)
|
||||
result_logger.write("\n")
|
||||
|
||||
print("== Iteration {} ==".format(alg.iteration))
|
||||
pprint.pprint(result._asdict())
|
||||
|
||||
if (i + 1) % args.checkpoint_freq == 0:
|
||||
print("checkpoint path: {}".format(alg.save()))
|
||||
for trial in runner.get_trials():
|
||||
if trial.status != Trial.TERMINATED:
|
||||
sys.exit(1)
|
||||
|
||||
@@ -0,0 +1,15 @@
|
||||
cartpole-ppo:
|
||||
env: CartPole-v0
|
||||
alg: PPO
|
||||
num_trials: 6
|
||||
stop:
|
||||
episode_reward_mean: 200
|
||||
time_total_s: 180
|
||||
resources:
|
||||
cpu: 2
|
||||
config:
|
||||
num_workers: 2
|
||||
num_sgd_iter:
|
||||
grid_search: [1, 4]
|
||||
sgd_batchsize:
|
||||
grid_search: [128, 256, 512]
|
||||
@@ -0,0 +1,8 @@
|
||||
hopper-ppo:
|
||||
env: Hopper-v1
|
||||
alg: PPO
|
||||
num_trials: 1
|
||||
resources:
|
||||
cpu: 64
|
||||
gpu: 4
|
||||
config: {"gamma": 0.995, "kl_coeff": 1.0, "num_sgd_iter": 20, "sgd_stepsize": 1e-4, "sgd_batchsize": 32768, "devices": ["/gpu:0", "/gpu:1", "/gpu:2", "/gpu:3"], "tf_session_args": {"device_count": {"GPU": 4}, "log_device_placement": false, "allow_soft_placement": true}, "timesteps_per_batch": 160000, "num_workers": 64}
|
||||
@@ -0,0 +1,9 @@
|
||||
humanoid-es:
|
||||
env: Humanoid-v1
|
||||
alg: ES
|
||||
resources:
|
||||
cpu: 100
|
||||
stop:
|
||||
episode_reward_mean: 6000
|
||||
config:
|
||||
num_workers: 100
|
||||
@@ -0,0 +1,11 @@
|
||||
humanoid-ppo-gae:
|
||||
env: Humanoid-v1
|
||||
alg: PPO
|
||||
num_trials: 1
|
||||
stop:
|
||||
episode_reward_mean: 6000
|
||||
resources:
|
||||
cpu: 64
|
||||
gpu: 4
|
||||
config: {"lambda": 0.95, "clip_param": 0.2, "kl_coeff": 1.0, "num_sgd_iter": 20, "sgd_stepsize": 1e-4, "sgd_batchsize": 32768, "horizon": 5000, "devices": ["/gpu:0", "/gpu:1", "/gpu:2", "/gpu:3"], "tf_session_args": {"device_count": {"GPU": 4}, "log_device_placement": false, "allow_soft_placement": true}, "timesteps_per_batch": 320000, "num_workers": 64, "model": {"free_log_std": true}, "write_logs": false}
|
||||
|
||||
@@ -0,0 +1,10 @@
|
||||
humanoid-ppo:
|
||||
env: Humanoid-v1
|
||||
alg: PPO
|
||||
num_trials: 1
|
||||
stop:
|
||||
episode_reward_mean: 6000
|
||||
resources:
|
||||
cpu: 64
|
||||
gpu: 4
|
||||
config: {"kl_coeff": 1.0, "num_sgd_iter": 20, "sgd_stepsize": 1e-4, "sgd_batchsize": 32768, "devices": ["/gpu:0", "/gpu:1", "/gpu:2", "/gpu:3"], "tf_session_args": {"device_count": {"GPU": 4}, "log_device_placement": false, "allow_soft_placement": true}, "timesteps_per_batch": 320000, "num_workers": 64, "model": {"free_log_std": true}, "use_gae": false}
|
||||
@@ -0,0 +1,9 @@
|
||||
pong-a3c:
|
||||
env: PongDeterministic-v4
|
||||
alg: A3C
|
||||
resources:
|
||||
cpu: 16
|
||||
config:
|
||||
num_workers: 16
|
||||
num_batches_per_iteration: 1000
|
||||
batch_size: 20
|
||||
@@ -0,0 +1,9 @@
|
||||
pong-dqn:
|
||||
env: PongDeterministic-v4
|
||||
alg: DQN
|
||||
resources:
|
||||
cpu: 1
|
||||
gpu: 1
|
||||
stop:
|
||||
episode_reward_mean: 20
|
||||
time_total_s: 7200
|
||||
@@ -0,0 +1,8 @@
|
||||
walker2d-v1-ppo:
|
||||
env: Walker2d-v1
|
||||
alg: PPO
|
||||
num_trials: 1
|
||||
resources:
|
||||
cpu: 64
|
||||
gpu: 4
|
||||
config: {"kl_coeff": 1.0, "num_sgd_iter": 20, "sgd_stepsize": 1e-4, "sgd_batchsize": 32768, "devices": ["/gpu:0", "/gpu:1", "/gpu:2", "/gpu:3"], "tf_session_args": {"device_count": {"GPU": 4}, "log_device_placement": false, "allow_soft_placement": true}, "timesteps_per_batch": 320000, "num_workers": 64}
|
||||
@@ -0,0 +1,54 @@
|
||||
Ray.tune: Fast hyperparameter search
|
||||
====================================
|
||||
|
||||
Using ray.tune with RLlib
|
||||
-------------------------
|
||||
|
||||
One way to use ray.tune is through RLlib's train.py script. The train.py script
|
||||
supports two modes. For example, to run multiple concurrent trials of Pong:
|
||||
|
||||
- Inline args: ``./train.py --env=Pong-v0 --alg=PPO --num_trials=8 --stop '{"time_total_s": 3200}' --resources '{"cpu": 8, "gpu": 2}' --config '{"num_workers": 8, "sgd_num_iter": 10}'``
|
||||
|
||||
- File-based: ``./train.py -f tune-pong.yaml``
|
||||
|
||||
Both delegate scheduling of trials to the ray.tune TrialRunner class.
|
||||
Additionally, the file-based mode supports hyper-parameter tuning
|
||||
(currently just grid and random search).
|
||||
|
||||
To specify search parameters, variables in the `config` section may be set to
|
||||
different values for each trial. You can either specify `grid_search: <list>`
|
||||
in place of a concrete value to specify a grid search across the list of
|
||||
values, or `eval: <str>` for values to be sampled from the given Python
|
||||
expression.
|
||||
|
||||
.. code:: yaml
|
||||
|
||||
cartpole-ppo:
|
||||
env: CartPole-v0
|
||||
alg: PPO
|
||||
num_trials: 6
|
||||
stop:
|
||||
episode_reward_mean: 200
|
||||
time_total_s: 180
|
||||
resources:
|
||||
cpu: 4
|
||||
config:
|
||||
num_workers: 4
|
||||
num_sgd_iter:
|
||||
grid_search: [1, 4]
|
||||
sgd_batchsize:
|
||||
grid_search: [128, 256, 512]
|
||||
lr:
|
||||
eval: random.uniform(1e-4, 1e-3)
|
||||
|
||||
See ray/rllib/tuned_examples for more examples of configs in YAML form.
|
||||
|
||||
Using ray.tune to run custom scripts
|
||||
------------------------------------
|
||||
|
||||
TODO
|
||||
|
||||
Using ray.tune as a library
|
||||
---------------------------
|
||||
|
||||
TODO
|
||||
@@ -0,0 +1,138 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import numpy as np
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
|
||||
from ray.tune.trial import Trial, Resources
|
||||
|
||||
|
||||
def _resource_json(data):
|
||||
values = json.loads(data)
|
||||
return Resources(values.get('cpu', 0), values.get('gpu', 0))
|
||||
|
||||
|
||||
def make_parser(description):
|
||||
"""Returns a base argument parser for the ray.tune tool."""
|
||||
|
||||
parser = argparse.ArgumentParser(description=(description))
|
||||
|
||||
parser.add_argument("--alg", default="PPO", type=str,
|
||||
help="The learning algorithm to train.")
|
||||
parser.add_argument("--stop", default="{}", type=json.loads,
|
||||
help="The stopping criteria, specified in JSON.")
|
||||
parser.add_argument("--config", default="{}", type=json.loads,
|
||||
help="The config of the algorithm, specified in JSON.")
|
||||
parser.add_argument("--resources", default='{"cpu": 1}',
|
||||
type=_resource_json,
|
||||
help="Amount of resources to allocate per trial.")
|
||||
parser.add_argument("--num_trials", default=1, type=int,
|
||||
help="Number of trials to evaluate.")
|
||||
parser.add_argument("--local_dir", default="/tmp/ray", type=str,
|
||||
help="Local dir to save training results to.")
|
||||
parser.add_argument("--upload_dir", default=None, type=str,
|
||||
help="URI to upload training results to.")
|
||||
parser.add_argument("--checkpoint_freq", default=sys.maxsize, type=int,
|
||||
help="How many iterations between checkpoints.")
|
||||
|
||||
# TODO(ekl) environments are RL specific
|
||||
parser.add_argument("--env", default=None, type=str,
|
||||
help="The gym environment to use.")
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def parse_to_trials(config):
|
||||
"""Parses a json config to the number of trials specified by the config.
|
||||
|
||||
The input config is a mapping from experiment names to an argument
|
||||
dictionary describing a set of trials. These args include the parser args
|
||||
documented in make_parser().
|
||||
"""
|
||||
|
||||
def resolve(agent_cfg, resolved_vars, i):
|
||||
assert type(agent_cfg) == dict
|
||||
cfg = agent_cfg.copy()
|
||||
for p, val in cfg.items():
|
||||
if type(val) == dict and "eval" in val:
|
||||
cfg[p] = eval(val["eval"], {
|
||||
"random": random,
|
||||
"np": np,
|
||||
}, {
|
||||
"_i": i,
|
||||
})
|
||||
resolved_vars[p] = True
|
||||
return cfg, resolved_vars
|
||||
|
||||
def to_argv(config):
|
||||
argv = []
|
||||
for k, v in config.items():
|
||||
argv.append("--{}".format(k))
|
||||
if type(v) is str:
|
||||
argv.append(v)
|
||||
else:
|
||||
argv.append(json.dumps(v))
|
||||
return argv
|
||||
|
||||
def param_str(config, resolved_vars):
|
||||
return "_".join(
|
||||
[k + "=" + str(v) for k, v in sorted(config.items())
|
||||
if resolved_vars.get(k)])
|
||||
|
||||
parser = make_parser("Ray hyperparameter tuning tool")
|
||||
trials = []
|
||||
for experiment_name, exp_cfg in config.items():
|
||||
args = parser.parse_args(to_argv(exp_cfg))
|
||||
grid_search = _GridSearchGenerator(args.config)
|
||||
for i in range(args.num_trials):
|
||||
next_cfg, resolved_vars = grid_search.next()
|
||||
resolved, resolved_vars = resolve(next_cfg, resolved_vars, i)
|
||||
if resolved_vars:
|
||||
agent_id = "{}_{}".format(
|
||||
i, param_str(resolved, resolved_vars))
|
||||
else:
|
||||
agent_id = str(i)
|
||||
trials.append(Trial(
|
||||
args.env, args.alg, resolved,
|
||||
os.path.join(args.local_dir, experiment_name), agent_id,
|
||||
args.resources, args.stop, args.checkpoint_freq, None,
|
||||
args.upload_dir))
|
||||
|
||||
return trials
|
||||
|
||||
|
||||
class _GridSearchGenerator(object):
|
||||
"""Generator that implements grid search over a set of value lists."""
|
||||
|
||||
def __init__(self, agent_cfg):
|
||||
self.cfg = agent_cfg
|
||||
self.grid_values = []
|
||||
for p, val in sorted(agent_cfg.items()):
|
||||
if type(val) == dict and "grid_search" in val:
|
||||
assert type(val["grid_search"] == list)
|
||||
self.grid_values.append((p, val["grid_search"]))
|
||||
self.value_indices = [0] * len(self.grid_values)
|
||||
|
||||
def next(self):
|
||||
cfg = self.cfg.copy()
|
||||
resolved_vars = {}
|
||||
for i, (k, values) in enumerate(self.grid_values):
|
||||
idx = self.value_indices[i]
|
||||
cfg[k] = values[idx]
|
||||
resolved_vars[k] = True
|
||||
if self.grid_values:
|
||||
self._increment(0)
|
||||
return cfg, resolved_vars
|
||||
|
||||
def _increment(self, i):
|
||||
self.value_indices[i] += 1
|
||||
if self.value_indices[i] >= len(self.grid_values[i][1]):
|
||||
self.value_indices[i] = 0
|
||||
if i + 1 < len(self.value_indices):
|
||||
self._increment(i + 1)
|
||||
@@ -0,0 +1,167 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import sys
|
||||
import traceback
|
||||
import ray
|
||||
|
||||
from collections import namedtuple
|
||||
from ray.rllib.agents import get_agent_class
|
||||
|
||||
|
||||
# Ray resources required to schedule a Trial
|
||||
Resources = namedtuple("Resources", ["cpu", "gpu"])
|
||||
|
||||
|
||||
class Trial(object):
|
||||
"""A trial object holds the state for one model training run.
|
||||
|
||||
Trials are themselves managed by the TrialRunner class, which implements
|
||||
the event loop for submitting trial runs to a Ray cluster.
|
||||
|
||||
Trials start in the PENDING state, and transition to RUNNING once started.
|
||||
On error it transitions to ERROR, otherwise TERMINATED on success.
|
||||
"""
|
||||
|
||||
PENDING = 'PENDING'
|
||||
RUNNING = 'RUNNING'
|
||||
TERMINATED = 'TERMINATED'
|
||||
ERROR = 'ERROR'
|
||||
|
||||
def __init__(
|
||||
self, env_creator, alg, config={}, local_dir='/tmp/ray',
|
||||
agent_id=None, resources=Resources(cpu=1, gpu=0),
|
||||
stopping_criterion={}, checkpoint_freq=sys.maxsize,
|
||||
restore_path=None, upload_dir=None):
|
||||
"""Initialize a new trial.
|
||||
|
||||
The args here take the same meaning as the command line flags defined
|
||||
in ray.tune.config_parser.
|
||||
"""
|
||||
|
||||
# Immutable config
|
||||
self.env_creator = env_creator
|
||||
if type(env_creator) is str:
|
||||
self.env_name = env_creator
|
||||
else:
|
||||
self.env_name = "custom"
|
||||
self.alg = alg
|
||||
self.config = config
|
||||
self.local_dir = local_dir
|
||||
self.agent_id = agent_id
|
||||
self.resources = resources
|
||||
self.stopping_criterion = stopping_criterion
|
||||
self.checkpoint_freq = checkpoint_freq
|
||||
self.restore_path = restore_path
|
||||
self.upload_dir = upload_dir
|
||||
|
||||
# Local trial state that is updated during the run
|
||||
self.last_result = None
|
||||
self.checkpoint_path = None
|
||||
self.agent = None
|
||||
self.status = Trial.PENDING
|
||||
|
||||
def start(self):
|
||||
"""Starts this trial.
|
||||
|
||||
If an error is encountered when starting the trial, an exception will
|
||||
be thrown.
|
||||
"""
|
||||
|
||||
self.status = Trial.RUNNING
|
||||
agent_cls = get_agent_class(self.alg)
|
||||
cls = ray.remote(
|
||||
num_cpus=self.resources.cpu, num_gpus=self.resources.gpu)(
|
||||
agent_cls)
|
||||
self.agent = cls.remote(
|
||||
self.env_creator, self.config, self.local_dir, self.upload_dir,
|
||||
agent_id=self.agent_id)
|
||||
if self.restore_path:
|
||||
ray.get(self.agent.restore.remote(self.restore_path))
|
||||
|
||||
def stop(self, error=False):
|
||||
"""Stops this trial.
|
||||
|
||||
Stops this trial, releasing all allocating resources. If stopping the
|
||||
trial fails, the run will be marked as terminated in error, but no
|
||||
exception will be thrown.
|
||||
|
||||
Args:
|
||||
error (bool): Whether to mark this trial as terminated in error.
|
||||
"""
|
||||
|
||||
if error:
|
||||
self.status = Trial.ERROR
|
||||
else:
|
||||
self.status = Trial.TERMINATED
|
||||
|
||||
try:
|
||||
if self.agent:
|
||||
self.agent.stop.remote()
|
||||
self.agent.__ray_terminate__.remote(
|
||||
self.agent._ray_actor_id.id())
|
||||
except:
|
||||
print("Error stopping agent:", traceback.format_exc())
|
||||
self.status = Trial.ERROR
|
||||
finally:
|
||||
self.agent = None
|
||||
|
||||
def train_remote(self):
|
||||
"""Returns Ray future for one iteration of training."""
|
||||
|
||||
assert self.status == Trial.RUNNING, self.status
|
||||
return self.agent.train.remote()
|
||||
|
||||
def should_stop(self, result):
|
||||
"""Whether the given result meets this trial's stopping criteria."""
|
||||
|
||||
for criteria, stop_value in self.stopping_criterion.items():
|
||||
if getattr(result, criteria) >= stop_value:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def should_checkpoint(self):
|
||||
"""Whether this trial is due for checkpointing."""
|
||||
|
||||
if self.checkpoint_freq is None:
|
||||
return False
|
||||
|
||||
return self.last_result.training_iteration % self.checkpoint_freq == 0
|
||||
|
||||
def progress_string(self):
|
||||
"""Returns a progress message for printing out to the console."""
|
||||
|
||||
if self.last_result is None:
|
||||
return self.status
|
||||
return '{}, {} s, {} ts, {} itrs, {} rew'.format(
|
||||
self.status,
|
||||
int(self.last_result.time_total_s),
|
||||
int(self.last_result.timesteps_total),
|
||||
self.last_result.training_iteration,
|
||||
round(self.last_result.episode_reward_mean, 1))
|
||||
|
||||
def checkpoint(self):
|
||||
"""Synchronously checkpoints the state of this trial.
|
||||
|
||||
TODO(ekl): we should support a PAUSED state based on checkpointing.
|
||||
"""
|
||||
|
||||
path = ray.get(self.agent.save.remote())
|
||||
self.checkpoint_path = path
|
||||
print("Saved checkpoint to:", path)
|
||||
|
||||
return path
|
||||
|
||||
def __str__(self):
|
||||
identifier = '{}_{}'.format(self.alg, self.env_name)
|
||||
if self.agent_id:
|
||||
identifier += '_' + self.agent_id
|
||||
return identifier
|
||||
|
||||
def __eq__(self, other):
|
||||
return str(self) == str(other)
|
||||
|
||||
def __hash__(self):
|
||||
return hash(str(self))
|
||||
@@ -0,0 +1,182 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import ray
|
||||
import time
|
||||
import traceback
|
||||
|
||||
from ray.tune.trial import Trial, Resources
|
||||
|
||||
|
||||
class TrialRunner(object):
|
||||
"""A TrialRunner implements the event loop for scheduling trials on Ray.
|
||||
|
||||
Example:
|
||||
runner = TrialRunner()
|
||||
runner.add_trial(Trial(...))
|
||||
runner.add_trial(Trial(...))
|
||||
while not runner.is_finished():
|
||||
runner.step()
|
||||
print(runner.debug_string())
|
||||
|
||||
The main job of TrialRunner is scheduling trials to efficiently use cluster
|
||||
resources, without overloading the cluster.
|
||||
|
||||
While Ray itself provides resource management for tasks and actors, this is
|
||||
not sufficient when scheduling trials that may instantiate multiple actors.
|
||||
This is because if insufficient resources are available, concurrent agents
|
||||
could deadlock waiting for new resources to become available. Furthermore,
|
||||
oversubscribing the cluster could degrade training performance, leading to
|
||||
misleading benchmark results.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initializes a new TrialRunner."""
|
||||
|
||||
self._trials = []
|
||||
self._pending = {}
|
||||
self._avail_resources = Resources(cpu=0, gpu=0)
|
||||
self._committed_resources = Resources(cpu=0, gpu=0)
|
||||
|
||||
def is_finished(self):
|
||||
"""Returns whether all trials have finished running."""
|
||||
|
||||
for t in self._trials:
|
||||
if t.status in [Trial.PENDING, Trial.RUNNING]:
|
||||
return False
|
||||
return True
|
||||
|
||||
def step(self):
|
||||
"""Runs one step of the trial event loop.
|
||||
|
||||
Callers should typically run this method repeatedly in a loop. They
|
||||
may inspect or modify the runner's state in between calls to step().
|
||||
"""
|
||||
|
||||
if self._can_launch_more():
|
||||
self._launch_trial()
|
||||
elif self._pending:
|
||||
self._process_events()
|
||||
else:
|
||||
for trial in self._trials:
|
||||
if trial.status == Trial.PENDING:
|
||||
assert self._has_resources(trial.resources), \
|
||||
("Insufficient cluster resources to launch trial",
|
||||
trial.resources)
|
||||
assert False, "Called step when all trials finished?"
|
||||
|
||||
def get_trials(self):
|
||||
"""Returns the list of trials managed by this TrialRunner.
|
||||
|
||||
Note that the caller usually should not mutate trial state directly.
|
||||
"""
|
||||
|
||||
return self._trials
|
||||
|
||||
def add_trial(self, trial):
|
||||
"""Adds a new trial to this TrialRunner.
|
||||
|
||||
Trials may be added at any time.
|
||||
"""
|
||||
|
||||
self._trials.append(trial)
|
||||
|
||||
def debug_string(self):
|
||||
"""Returns a human readable message for printing to the console."""
|
||||
|
||||
messages = ["== Status =="]
|
||||
messages.append(
|
||||
"Available: {}".format(self._avail_resources))
|
||||
messages.append(
|
||||
"Committed: {}".format(self._committed_resources))
|
||||
for local_dir in sorted(set([t.local_dir for t in self._trials])):
|
||||
messages.append("Tensorboard logdir: {}".format(local_dir))
|
||||
for t in self._trials:
|
||||
if t.local_dir == local_dir:
|
||||
messages.append(
|
||||
" - {}:\t{}".format(t, t.progress_string()))
|
||||
return "\n".join(messages) + "\n"
|
||||
|
||||
def _can_launch_more(self):
|
||||
self._update_avail_resources()
|
||||
trial = self._get_runnable()
|
||||
return trial is not None
|
||||
|
||||
def _launch_trial(self):
|
||||
trial = self._get_runnable()
|
||||
self._commit_resources(trial.resources)
|
||||
try:
|
||||
trial.start()
|
||||
self._pending[trial.train_remote()] = trial
|
||||
except:
|
||||
print("Error starting agent, retrying:", traceback.format_exc())
|
||||
time.sleep(2)
|
||||
trial.stop(error=True)
|
||||
try:
|
||||
trial.start()
|
||||
self._pending[trial.train_remote()] = trial
|
||||
except:
|
||||
print("Error starting agent, abort:", traceback.format_exc())
|
||||
trial.stop(error=True)
|
||||
# note that we don't return the resources, since they may
|
||||
# have been lost
|
||||
|
||||
def _process_events(self):
|
||||
[result_id], _ = ray.wait(self._pending.keys())
|
||||
trial = self._pending[result_id]
|
||||
del self._pending[result_id]
|
||||
try:
|
||||
result = ray.get(result_id)
|
||||
print("result", result)
|
||||
trial.last_result = result
|
||||
|
||||
if trial.should_stop(result):
|
||||
self._return_resources(trial.resources)
|
||||
trial.stop()
|
||||
else:
|
||||
# TODO(rliaw): This implements checkpoint in a blocking manner
|
||||
if trial.should_checkpoint():
|
||||
trial.checkpoint()
|
||||
self._pending[trial.train_remote()] = trial
|
||||
except:
|
||||
print("Error processing event:", traceback.format_exc())
|
||||
if trial.status == Trial.RUNNING:
|
||||
self._return_resources(trial.resources)
|
||||
trial.stop(error=True)
|
||||
|
||||
def _get_runnable(self):
|
||||
for trial in self._trials:
|
||||
if (trial.status == Trial.PENDING and
|
||||
self._has_resources(trial.resources)):
|
||||
return trial
|
||||
return None
|
||||
|
||||
def _has_resources(self, resources):
|
||||
cpu_avail = self._avail_resources.cpu - self._committed_resources.cpu
|
||||
gpu_avail = self._avail_resources.gpu - self._committed_resources.gpu
|
||||
assert cpu_avail >= 0 and gpu_avail >= 0
|
||||
return resources.cpu <= cpu_avail and resources.gpu <= gpu_avail
|
||||
|
||||
def _commit_resources(self, resources):
|
||||
self._committed_resources = Resources(
|
||||
self._committed_resources.cpu + resources.cpu,
|
||||
self._committed_resources.gpu + resources.gpu)
|
||||
|
||||
def _return_resources(self, resources):
|
||||
self._committed_resources = Resources(
|
||||
self._committed_resources.cpu - resources.cpu,
|
||||
self._committed_resources.gpu - resources.gpu)
|
||||
assert self._committed_resources.cpu >= 0
|
||||
assert self._committed_resources.gpu >= 0
|
||||
|
||||
def _update_avail_resources(self):
|
||||
clients = ray.global_state.client_table()
|
||||
local_schedulers = [
|
||||
entry for client in clients.values() for entry in client
|
||||
if (entry['ClientType'] == 'local_scheduler' and not
|
||||
entry['Deleted'])
|
||||
]
|
||||
num_cpus = sum(ls['NumCPUs'] for ls in local_schedulers)
|
||||
num_gpus = sum(ls['NumGPUs'] for ls in local_schedulers)
|
||||
self._avail_resources = Resources(int(num_cpus), int(num_gpus))
|
||||
Reference in New Issue
Block a user