[tune] Ray Tune API cleanup (#1454)

Remove rllib dep: trainable is now a standalone abstract class that can be easily subclassed.

Clean up hyperband: fix debug string and add an example.

Remove YAML api / ScriptRunner: this was never really used.

Move ray.init() out of run_experiments(): This provides greater flexibility and should be less confusing since there isn't an implicit init() done there. Note that this is a breaking API change for tune.
This commit is contained in:
Eric Liang
2018-01-24 16:55:17 -08:00
committed by GitHub
parent a1b01ee7fb
commit 173f1d629a
24 changed files with 486 additions and 421 deletions
+3 -2
View File
@@ -2,15 +2,16 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# Note: do not introduce unnecessary library dependencies here, e.g. gym
# Note: do not introduce unnecessary library dependencies here, e.g. gym.
# This file is imported from the tune module in order to register RLlib agents.
from ray.tune.registry import register_trainable
from ray.rllib.agent import get_agent_class
def _register_all():
for key in [
"PPO", "ES", "DQN", "A3C", "BC", "__fake", "__sigmoid_fake_data"]:
try:
from ray.rllib.agent import get_agent_class
register_trainable(key, get_agent_class(key))
except ImportError as e:
print("Warning: could not import {}: {}".format(key, e))
+20 -186
View File
@@ -2,25 +2,14 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from datetime import datetime
import logging
import numpy as np
import io
import os
import gzip
import pickle
import shutil
import tempfile
import time
import uuid
# Note: avoid introducing unnecessary library dependencies here, e.g. gym
# until https://github.com/ray-project/ray/issues/1144 is resolved
import tensorflow as tf
from ray.tune.logger import UnifiedLogger
from ray.tune.registry import ENV_CREATOR, get_registry
from ray.tune.result import DEFAULT_RESULTS_DIR, TrainingResult
from ray.tune.result import TrainingResult
from ray.tune.trainable import Trainable
logger = logging.getLogger(__name__)
@@ -66,7 +55,7 @@ class Agent(Trainable):
env_creator (func): Function that creates a new training env.
config (obj): Algorithm-specific configuration data.
logdir (str): Directory in which training outputs should be placed.
registry (obj): Tune object registry, for registering user-defined
registry (obj): Tune object registry which holds user-registered
classes and objects by name.
"""
@@ -83,183 +72,43 @@ class Agent(Trainable):
env (str): Name of the environment to use. Note that this can also
be specified as the `env` key in config.
registry (obj): Object registry for user-defined envs, models, etc.
If unspecified, it will be assumed empty.
If unspecified, the default registry will be used.
logger_creator (func): Function that creates a ray.tune.Logger
object. If unspecified, a default logger is created.
"""
self._initialize_ok = False
self._experiment_id = uuid.uuid4().hex
env = env or config.get("env")
# Agents allow env ids to be passed directly to the constructor.
self._env_id = env or config.get("env")
Trainable.__init__(self, config, registry, logger_creator)
def _setup(self):
env = self._env_id
if env:
config["env"] = env
if registry and registry.contains(ENV_CREATOR, env):
self.env_creator = registry.get(ENV_CREATOR, env)
self.config["env"] = env
if self.registry and self.registry.contains(ENV_CREATOR, env):
self.env_creator = self.registry.get(ENV_CREATOR, env)
else:
import gym # soft dependency
self.env_creator = lambda env_config: gym.make(env)
else:
self.env_creator = lambda env_config: None
self.config = self._default_config.copy()
self.registry = registry
self.config = _deep_update(self.config, config,
self._allow_unknown_configs,
self._allow_unknown_subkeys)
if logger_creator:
self._result_logger = logger_creator(self.config)
self.logdir = self._result_logger.logdir
else:
logdir_suffix = "{}_{}_{}".format(
env, self._agent_name,
datetime.today().strftime("%Y-%m-%d_%H-%M-%S"))
if not os.path.exists(DEFAULT_RESULTS_DIR):
os.makedirs(DEFAULT_RESULTS_DIR)
self.logdir = tempfile.mkdtemp(
prefix=logdir_suffix, dir=DEFAULT_RESULTS_DIR)
self._result_logger = UnifiedLogger(self.config, self.logdir, None)
self._iteration = 0
self._time_total = 0.0
self._timesteps_total = 0
# Merge the supplied config with the class default
merged_config = self._default_config.copy()
merged_config = _deep_update(merged_config, self.config,
self._allow_unknown_configs,
self._allow_unknown_subkeys)
self.config = merged_config
# TODO(ekl) setting the graph is unnecessary for PyTorch agents
with tf.Graph().as_default():
self._init()
self._initialize_ok = True
def _init(self):
"""Subclasses should override this for custom initialization."""
raise NotImplementedError
def train(self):
"""Runs one logical iteration of training.
Returns:
A TrainingResult that describes training progress.
"""
if not self._initialize_ok:
raise ValueError(
"Agent initialization failed, see previous errors")
start = time.time()
result = self._train()
self._iteration += 1
if result.time_this_iter_s is not None:
time_this_iter = result.time_this_iter_s
else:
time_this_iter = time.time() - start
assert result.timesteps_this_iter is not None
self._time_total += time_this_iter
self._timesteps_total += result.timesteps_this_iter
now = datetime.today()
result = result._replace(
experiment_id=self._experiment_id,
date=now.strftime("%Y-%m-%d_%H-%M-%S"),
timestamp=int(time.mktime(now.timetuple())),
training_iteration=self._iteration,
timesteps_total=self._timesteps_total,
time_this_iter_s=time_this_iter,
time_total_s=self._time_total,
pid=os.getpid(),
hostname=os.uname()[1])
self._result_logger.on_result(result)
return result
def save(self):
"""Saves the current model state to a checkpoint.
Returns:
Checkpoint path that may be passed to restore().
"""
checkpoint_path = self._save()
pickle.dump(
[self._experiment_id, self._iteration, self._timesteps_total,
self._time_total],
open(checkpoint_path + ".rllib_metadata", "wb"))
return checkpoint_path
def save_to_object(self):
"""Saves the current model state to a Python object. It also
saves to disk but does not return the checkpoint path.
Returns:
Object holding checkpoint data.
"""
checkpoint_prefix = self.save()
data = {}
base_dir = os.path.dirname(checkpoint_prefix)
for path in os.listdir(base_dir):
path = os.path.join(base_dir, path)
if path.startswith(checkpoint_prefix):
data[os.path.basename(path)] = open(path, "rb").read()
out = io.BytesIO()
with gzip.GzipFile(fileobj=out, mode="wb") as f:
compressed = pickle.dumps({
"checkpoint_name": os.path.basename(checkpoint_prefix),
"data": data,
})
print("Saving checkpoint to object store, {} bytes".format(
len(compressed)))
f.write(compressed)
return out.getvalue()
def restore(self, checkpoint_path):
"""Restores training state from a given model checkpoint.
These checkpoints are returned from calls to save().
"""
self._restore(checkpoint_path)
metadata = pickle.load(open(checkpoint_path + ".rllib_metadata", "rb"))
self._experiment_id = metadata[0]
self._iteration = metadata[1]
self._timesteps_total = metadata[2]
self._time_total = metadata[3]
def restore_from_object(self, obj):
"""Restores training state from a checkpoint object.
These checkpoints are returned from calls to save_to_object().
"""
out = io.BytesIO(obj)
info = pickle.loads(gzip.GzipFile(fileobj=out, mode="rb").read())
data = info["data"]
tmpdir = tempfile.mkdtemp("restore_from_object", dir=self.logdir)
checkpoint_path = os.path.join(tmpdir, info["checkpoint_name"])
for file_name, file_contents in data.items():
with open(os.path.join(tmpdir, file_name), "wb") as f:
f.write(file_contents)
self.restore(checkpoint_path)
shutil.rmtree(tmpdir)
def stop(self):
"""Releases all resources used by this agent."""
if self._initialize_ok:
self._result_logger.close()
self._stop()
def _stop(self):
"""Subclasses should override this for custom stopping."""
pass
def compute_action(self, observation):
"""Computes an action using the current trained policy."""
@@ -283,21 +132,6 @@ class Agent(Trainable):
raise NotImplementedError
def _train(self):
"""Subclasses should override this to implement train()."""
raise NotImplementedError
def _save(self):
"""Subclasses should override this to implement save()."""
raise NotImplementedError
def _restore(self, checkpoint_path):
"""Subclasses should override this to implement restore()."""
raise NotImplementedError
class _MockAgent(Agent):
"""Mock agent for use in tests"""
+3 -2
View File
@@ -8,6 +8,7 @@ import argparse
import sys
import yaml
import ray
from ray.tune.config_parser import make_parser, resources_to_json
from ray.tune.tune import _make_scheduler, run_experiments
@@ -76,7 +77,7 @@ if __name__ == "__main__":
if not exp.get("env") and not exp.get("config", {}).get("env"):
parser.error("the following arguments are required: --env")
run_experiments(
experiments, scheduler=_make_scheduler(args),
ray.init(
redis_address=args.redis_address,
num_cpus=args.num_cpus, num_gpus=args.num_gpus)
run_experiments(experiments, scheduler=_make_scheduler(args))
-3
View File
@@ -6,13 +6,10 @@ from ray.tune.error import TuneError
from ray.tune.tune import run_experiments
from ray.tune.registry import register_env, register_trainable
from ray.tune.result import TrainingResult
from ray.tune.script_runner import ScriptRunner
from ray.tune.trainable import Trainable
from ray.tune.variant_generator import grid_search
register_trainable("script", ScriptRunner)
__all__ = [
"Trainable",
"TrainingResult",
+70
View File
@@ -0,0 +1,70 @@
#!/usr/bin/env python
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import json
import os
import random
import numpy as np
import ray
from ray.tune import Trainable, TrainingResult, register_trainable, \
run_experiments
from ray.tune.hyperband import HyperBandScheduler
class MyTrainableClass(Trainable):
"""Example agent whose learning curve is a random sigmoid.
The dummy hyperparameters "width" and "height" determine the slope and
maximum reward value reached.
"""
def _setup(self):
self.timestep = 0
def _train(self):
self.timestep += 1
v = np.tanh(float(self.timestep) / self.config["width"])
v *= self.config["height"]
# Here we use `episode_reward_mean`, but you can also report other
# objectives such as loss or accuracy (see tune/result.py).
return TrainingResult(episode_reward_mean=v, timesteps_this_iter=1)
def _save(self):
path = os.path.join(self.logdir, "checkpoint")
with open(path, "w") as f:
f.write(json.dumps({"timestep": self.timestep}))
return path
def _restore(self, checkpoint_path):
with open(checkpoint_path) as f:
self.timestep = json.loads(f.read())["timestep"]
register_trainable("my_class", MyTrainableClass)
if __name__ == "__main__":
ray.init()
# Hyperband early stopping, configured with `episode_reward_mean` as the
# objective and `timesteps_total` as the time unit.
hyperband = HyperBandScheduler(
time_attr="timesteps_total", reward_attr="episode_reward_mean",
max_t=100)
run_experiments({
"hyperband_test": {
"run": "my_class",
"repeat": 100,
"resources": {"cpu": 1, "gpu": 0},
"config": {
"width": lambda spec: 10 + int(90 * random.random()),
"height": lambda spec: int(100 * random.random()),
},
}
}, scheduler=hyperband)
@@ -33,6 +33,7 @@ import sys
import tempfile
import time
import ray
from ray.tune import grid_search, run_experiments, register_trainable
from tensorflow.examples.tutorials.mnist import input_data
@@ -222,4 +223,5 @@ if __name__ == '__main__':
if args.fast:
mnist_spec['stop']['training_iteration'] = 2
ray.init()
run_experiments({'tune_mnist_test': mnist_spec})
@@ -1,14 +0,0 @@
tune_mnist:
run: script
repeat: 2
resources:
cpu: 1
stop:
mean_accuracy: 0.99
time_total_s: 600
config:
script_file_path: examples/tune_mnist_ray.py
script_entrypoint: train
script_min_iter_time_s: 1
activation:
grid_search: ['relu', 'elu', 'tanh']
@@ -2,15 +2,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import importlib
import os
import sys
import time
import threading
import traceback
from ray.rllib.agent import Agent
from ray.tune import TuneError
from ray.tune.trainable import Trainable
from ray.tune.result import TrainingResult
@@ -53,12 +50,6 @@ class StatusReporter(object):
DEFAULT_CONFIG = {
# path of the script to run
"script_file_path": "/path/to/file.py",
# name of train function in the file, e.g. train(config, status_reporter)
"script_entrypoint": "train",
# batch results to at least this granularity
"script_min_iter_time_s": 1,
}
@@ -85,67 +76,37 @@ class _RunnerThread(threading.Thread):
self._status_reporter._done = True
def import_function(file_path, function_name):
# strong assumption here that we're in a new process
file_path = os.path.expanduser(file_path)
sys.path.insert(0, os.path.dirname(file_path))
if hasattr(importlib, "util"):
# Python 3.4+
spec = importlib.util.spec_from_file_location(
"external_file", file_path)
external_file = importlib.util.module_from_spec(spec)
spec.loader.exec_module(external_file)
elif hasattr(importlib, "machinery"):
# Python 3.3
from importlib.machinery import SourceFileLoader
external_file = SourceFileLoader(
"external_file", file_path).load_module()
else:
# Python 2.x
import imp
external_file = imp.load_source("external_file", file_path)
if not external_file:
raise TuneError("Unable to import file at {}".format(file_path))
return getattr(external_file, function_name)
class FunctionRunner(Trainable):
"""Trainable that runs a user function returning training results.
This mode of execution does not support checkpoint/restore."""
class ScriptRunner(Agent):
"""Agent that runs a user script returning training results."""
_agent_name = "script"
_name = "func"
_default_config = DEFAULT_CONFIG
_allow_unknown_configs = True
def _init(self):
def _setup(self):
entrypoint = self._trainable_func()
if not entrypoint:
entrypoint = import_function(
self.config["script_file_path"],
self.config["script_entrypoint"])
self._status_reporter = StatusReporter()
scrubbed_config = self.config.copy()
for k in self._default_config:
del scrubbed_config[k]
if k in scrubbed_config:
del scrubbed_config[k]
self._runner = _RunnerThread(
entrypoint, scrubbed_config, self._status_reporter)
self._start_time = time.time()
self._last_reported_time = self._start_time
self._last_reported_timestep = 0
self._runner.start()
# Subclasses can override this to set the trainable func
# TODO(ekl) this isn't a very clean layering, we should refactor it
def _trainable_func(self):
return None
"""Subclasses can override this to set the trainable func."""
def train(self):
if not self._initialize_ok:
raise ValueError(
"Agent initialization failed, see previous errors")
now = time.time()
time.sleep(self.config["script_min_iter_time_s"])
raise NotImplementedError
def _train(self):
time.sleep(
self.config.get(
"script_min_iter_time_s",
self._default_config["script_min_iter_time_s"]))
result = self._status_reporter._get_and_clear_status()
while result is None:
time.sleep(1)
@@ -153,29 +114,10 @@ class ScriptRunner(Agent):
if result.timesteps_total is None:
raise TuneError("Must specify timesteps_total in result", result)
# Include the negative loss to use as a stopping condition
if result.mean_loss is not None:
neg_loss = -result.mean_loss
else:
neg_loss = result.neg_mean_loss
result = result._replace(
experiment_id=self._experiment_id,
neg_mean_loss=neg_loss,
training_iteration=self.iteration,
time_this_iter_s=now - self._last_reported_time,
timesteps_this_iter=(
result.timesteps_total - self._last_reported_timestep),
time_total_s=now - self._start_time,
pid=os.getpid(),
hostname=os.uname()[1])
if result.timesteps_total:
self._last_reported_timestep = result.timesteps_total
self._last_reported_time = now
self._iteration += 1
self._result_logger.on_result(result)
result.timesteps_total - self._last_reported_timestep))
self._last_reported_timestep = result.timesteps_total
return result
+18 -20
View File
@@ -2,6 +2,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import numpy as np
from ray.tune.trial_scheduler import FIFOScheduler, TrialScheduler
@@ -61,7 +62,8 @@ class HyperBandScheduler(FIFOScheduler):
max_t (int): max time units per trial. Trials will be stopped after
max_t time units (determined by time_attr) have passed.
The HyperBand scheduler automatically tries to determine a
reasonable number of brackets based on this.
reasonable number of brackets based on this. The scheduler will
terminate trials after this time has passed.
"""
def __init__(
@@ -210,7 +212,8 @@ class HyperBandScheduler(FIFOScheduler):
List of trials not used since all trials are tracked as state
of scheduler. If iteration is occupied (ie, no trials to run),
then look into next iteration."""
then look into next iteration.
"""
for hyperband in self._hyperbands:
for bracket in sorted(hyperband,
@@ -222,18 +225,14 @@ class HyperBandScheduler(FIFOScheduler):
return None
def debug_string(self):
# TODO(rliaw): This debug string needs work
brackets = [
"({0}/{1})".format(
len(bracket._live_trials), len(bracket._all_trials))
for band in self._hyperbands for bracket in band]
return " ".join([
"Using HyperBand:",
"num_stopped={}".format(self._num_stopped),
"total_brackets={}".format(
sum(len(band) for band in self._hyperbands)),
" ".join(brackets)
])
out = "Using HyperBand: "
out += "num_stopped={} total_brackets={}".format(
self._num_stopped, sum(len(band) for band in self._hyperbands))
for i, band in enumerate(self._hyperbands):
out += "\nRound #{}:".format(i)
for bracket in band:
out += "\n {}".format(bracket)
return out
class Bracket():
@@ -370,10 +369,9 @@ class Bracket():
status = ", ".join([
"n={}".format(self._n),
"r={}".format(self._r),
"progress={}".format(self.completion_percentage())
"completed={}%".format(int(100 * self.completion_percentage()))
])
return "Bracket({})".format(status)
def debug_string(self):
trials = ", ".join([t.status for t in self._live_trials])
return "{}[{}]".format(self, trials)
counts = collections.Counter()
for t in self._all_trials:
counts[t.status] += 1
return "Bracket({}): {}".format(status, dict(counts))
+218 -12
View File
@@ -2,55 +2,261 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from datetime import datetime
import gzip
import io
import os
import pickle
import shutil
import tempfile
import time
import uuid
from ray.tune import TuneError
from ray.tune.logger import UnifiedLogger
from ray.tune.result import DEFAULT_RESULTS_DIR
class Trainable(object):
"""Interface for trainable models, functions, etc.
"""Abstract class for trainable models, functions, etc.
Implementing this interface is required to use Ray.tune's full
functionality, though you can also get away with supplying just a
`my_train(config, reporter)` function and calling:
A call to ``train()`` on a trainable will execute one logical iteration of
training. As a rule of thumb, the execution time of one train call should
be large enough to avoid overheads (i.e. more than a few seconds), but
short enough to report progress periodically (i.e. at most a few minutes).
Calling ``save()`` should save the training state of a trainable to disk,
and ``restore(path)`` should restore a trainable to the given state.
Generally you only need to implement ``_train``, ``_save``, and
``_restore`` here when subclassing Trainable.
Note that, if you don't require checkpoint/restore functionality, then
instead of implementing this class you can also get away with supplying
just a `my_train(config, reporter)` function and calling:
``register_trainable("my_func", train)``
to register it for use with tune. The function will be automatically
converted to this interface (sans checkpoint functionality)."""
to register it for use with Tune. The function will be automatically
converted to this interface (sans checkpoint functionality).
Attributes:
config (obj): The hyperparam configuration for this trial.
logdir (str): Directory in which training outputs should be placed.
registry (obj): Tune object registry which holds user-registered
classes and objects by name.
"""
def __init__(self, config={}, registry=None, logger_creator=None):
"""Initialize an Trainable.
Subclasses should prefer defining ``_setup()`` instead of overriding
``__init__()`` directly.
Args:
config (dict): Trainable-specific configuration data.
registry (obj): Object registry for user-defined envs, models, etc.
If unspecified, the default registry will be used.
logger_creator (func): Function that creates a ray.tune.Logger
object. If unspecified, a default logger is created.
"""
if registry is None:
from ray.tune.registry import get_registry
registry = get_registry()
self._initialize_ok = False
self._experiment_id = uuid.uuid4().hex
self.config = config
self.registry = registry
if logger_creator:
self._result_logger = logger_creator(self.config)
self.logdir = self._result_logger.logdir
else:
logdir_prefix = datetime.today().strftime("%Y-%m-%d_%H-%M-%S")
if not os.path.exists(DEFAULT_RESULTS_DIR):
os.makedirs(DEFAULT_RESULTS_DIR)
self.logdir = tempfile.mkdtemp(
prefix=logdir_prefix, dir=DEFAULT_RESULTS_DIR)
self._result_logger = UnifiedLogger(self.config, self.logdir, None)
self._iteration = 0
self._time_total = 0.0
self._timesteps_total = 0
self._setup()
self._initialize_ok = True
def train(self):
"""Runs one logical iteration of training.
Subclasses should override ``_train()`` instead to return results.
This method auto-fills many fields, so only ``timesteps_this_iter``
is requied to be present.
Returns:
A TrainingResult that describes training progress.
"""
raise NotImplementedError
if not self._initialize_ok:
raise ValueError(
"Trainable initialization failed, see previous errors")
start = time.time()
result = self._train()
self._iteration += 1
if result.time_this_iter_s is not None:
time_this_iter = result.time_this_iter_s
else:
time_this_iter = time.time() - start
if result.timesteps_this_iter is None:
raise TuneError(
"Must specify timesteps_this_iter in result", result)
self._time_total += time_this_iter
self._timesteps_total += result.timesteps_this_iter
# Include the negative loss to use as a stopping condition
if result.mean_loss is not None:
neg_loss = -result.mean_loss
else:
neg_loss = result.neg_mean_loss
now = datetime.today()
result = result._replace(
experiment_id=self._experiment_id,
date=now.strftime("%Y-%m-%d_%H-%M-%S"),
timestamp=int(time.mktime(now.timetuple())),
training_iteration=self._iteration,
timesteps_total=self._timesteps_total,
time_this_iter_s=time_this_iter,
time_total_s=self._time_total,
neg_mean_loss=neg_loss,
pid=os.getpid(),
hostname=os.uname()[1])
self._result_logger.on_result(result)
return result
def save(self):
"""Saves the current model state to a checkpoint.
Subclasses should override ``_save()`` instead to save state.
This method dumps additional metadata alongside the saved path.
Returns:
Checkpoint path that may be passed to restore().
"""
raise NotImplementedError
checkpoint_path = self._save()
pickle.dump(
[self._experiment_id, self._iteration, self._timesteps_total,
self._time_total],
open(checkpoint_path + ".tune_metadata", "wb"))
return checkpoint_path
def save_to_object(self):
"""Saves the current model state to a Python object. It also
saves to disk but does not return the checkpoint path.
Returns:
Object holding checkpoint data.
"""
checkpoint_prefix = self.save()
data = {}
base_dir = os.path.dirname(checkpoint_prefix)
for path in os.listdir(base_dir):
path = os.path.join(base_dir, path)
if path.startswith(checkpoint_prefix):
data[os.path.basename(path)] = open(path, "rb").read()
out = io.BytesIO()
with gzip.GzipFile(fileobj=out, mode="wb") as f:
compressed = pickle.dumps({
"checkpoint_name": os.path.basename(checkpoint_prefix),
"data": data,
})
print("Saving checkpoint to object store, {} bytes".format(
len(compressed)))
f.write(compressed)
return out.getvalue()
def restore(self, checkpoint_path):
"""Restores training state from a given model checkpoint.
These checkpoints are returned from calls to save().
Subclasses should override ``_restore()`` instead to restore state.
This method restores additional metadata saved with the checkpoint.
"""
self._restore(checkpoint_path)
metadata = pickle.load(open(checkpoint_path + ".tune_metadata", "rb"))
self._experiment_id = metadata[0]
self._iteration = metadata[1]
self._timesteps_total = metadata[2]
self._time_total = metadata[3]
def restore_from_object(self, obj):
"""Restores training state from a checkpoint object.
These checkpoints are returned from calls to save_to_object().
"""
out = io.BytesIO(obj)
info = pickle.loads(gzip.GzipFile(fileobj=out, mode="rb").read())
data = info["data"]
tmpdir = tempfile.mkdtemp("restore_from_object", dir=self.logdir)
checkpoint_path = os.path.join(tmpdir, info["checkpoint_name"])
for file_name, file_contents in data.items():
with open(os.path.join(tmpdir, file_name), "wb") as f:
f.write(file_contents)
self.restore(checkpoint_path)
shutil.rmtree(tmpdir)
def stop(self):
"""Releases all resources used by this trainable."""
if self._initialize_ok:
self._result_logger.close()
self._stop()
def _train(self):
"""Subclasses should override this to implement train()."""
raise NotImplementedError
def stop(self):
"""Releases all resources used by this class."""
def _save(self):
"""Subclasses should override this to implement save()."""
raise NotImplementedError
def _restore(self, checkpoint_path):
"""Subclasses should override this to implement restore()."""
raise NotImplementedError
def _setup(self):
"""Subclasses should override this for custom initialization."""
pass
def _stop(self):
"""Subclasses should override this for any cleanup on stop."""
pass
def wrap_function(train_func):
from ray.tune.script_runner import ScriptRunner
from ray.tune.function_runner import FunctionRunner
class WrappedFunc(ScriptRunner):
class WrappedFunc(FunctionRunner):
def _trainable_func(self):
return train_func
+11 -5
View File
@@ -2,18 +2,21 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from collections import namedtuple
from datetime import datetime
import tempfile
import time
import traceback
import ray
import os
from collections import namedtuple
from ray.utils import random_string, binary_to_hex
from ray.tune import TuneError
from ray.tune.logger import NoopLogger, UnifiedLogger
from ray.tune.result import TrainingResult, DEFAULT_RESULTS_DIR, pretty_print
from ray.tune.registry import _default_registry, get_registry, TRAINABLE_CLASS
from ray.tune.result import TrainingResult, DEFAULT_RESULTS_DIR, pretty_print
from ray.utils import random_string, binary_to_hex
DEBUG_PRINT_INTERVAL = 5
class Resources(
@@ -106,6 +109,7 @@ class Trial(object):
self.location = None
self.logdir = None
self.result_logger = None
self.last_debug = 0
self.trial_id = binary_to_hex(random_string())[:8]
def start(self):
@@ -293,8 +297,10 @@ class Trial(object):
def update_last_result(self, result, terminate=False):
if terminate:
result = result._replace(done=True)
print("TrainingResult for {}:".format(self))
print(" {}".format(pretty_print(result).replace("\n", "\n ")))
if terminate or time.time() - self.last_debug > DEBUG_PRINT_INTERVAL:
print("TrainingResult for {}:".format(self))
print(" {}".format(pretty_print(result).replace("\n", "\n ")))
self.last_debug = time.time()
self.last_result = result
self.result_logger.on_result(self.last_result)
+37 -8
View File
@@ -2,6 +2,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import os
import ray
import time
@@ -13,6 +14,9 @@ from ray.tune.trial import Trial, Resources
from ray.tune.trial_scheduler import FIFOScheduler, TrialScheduler
MAX_DEBUG_TRIALS = 20
class TrialRunner(object):
"""A TrialRunner implements the event loop for scheduling trials on Ray.
@@ -127,9 +131,40 @@ class TrialRunner(object):
self._scheduler_alg.on_trial_add(self, trial)
self._trials.append(trial)
def debug_string(self):
def debug_string(self, max_debug=MAX_DEBUG_TRIALS):
"""Returns a human readable message for printing to the console."""
messages = self._debug_messages()
states = collections.defaultdict(set)
limit_per_state = collections.Counter()
for t in self._trials:
states[t.status].add(t)
# Show at most max_debug total, but divide the limit fairly
while max_debug > 0:
start_num = max_debug
for s in states:
if limit_per_state[s] >= len(states[s]):
continue
max_debug -= 1
limit_per_state[s] += 1
if max_debug == start_num:
break
for local_dir in sorted(set([t.local_dir for t in self._trials])):
messages.append("Result logdir: {}".format(local_dir))
for state, trials in sorted(states.items()):
limit = limit_per_state[state]
messages.append("{} trials:".format(state))
for t in sorted(
trials, key=lambda t: t.experiment_tag)[:limit]:
messages.append(" - {}:\t{}".format(t, t.progress_string()))
if len(trials) > limit:
messages.append(" ... {} more not shown".format(
len(trials) - limit))
return "\n".join(messages) + "\n"
def _debug_messages(self):
messages = ["== Status =="]
messages.append(self._scheduler_alg.debug_string())
if self._resources_initialized:
@@ -139,13 +174,7 @@ class TrialRunner(object):
self._avail_resources.cpu,
self._committed_resources.gpu,
self._avail_resources.gpu))
for local_dir in sorted(set([t.local_dir for t in self._trials])):
messages.append("Result logdir: {}".format(local_dir))
for t in self._trials:
if t.local_dir == local_dir:
messages.append(
" - {}:\t{}".format(t, t.progress_string()))
return "\n".join(messages) + "\n"
return messages
def has_resources(self, resources):
"""Returns whether this runner has at least the specified resources."""
Executable → Regular
+14 -54
View File
@@ -1,55 +1,19 @@
#!/usr/bin/env python
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import json
import sys
import ray
import time
from ray.tune import TuneError
from ray.tune.hyperband import HyperBandScheduler
from ray.tune.median_stopping_rule import MedianStoppingRule
from ray.tune.trial import Trial
from ray.tune.trial import Trial, DEBUG_PRINT_INTERVAL
from ray.tune.trial_runner import TrialRunner
from ray.tune.trial_scheduler import FIFOScheduler
from ray.tune.web_server import TuneServer
from ray.tune.variant_generator import generate_trials
EXAMPLE_USAGE = """
MNIST tuning example:
./tune.py -f examples/tune_mnist_ray.yaml
"""
parser = argparse.ArgumentParser(
formatter_class=argparse.RawDescriptionHelpFormatter,
description="Tune hyperparameters with Ray.",
epilog=EXAMPLE_USAGE)
# See also the base parser definition in ray/tune/config_parser.py
parser.add_argument("--redis-address", default=None, type=str,
help="The Redis address of the cluster.")
parser.add_argument("--num-cpus", default=None, type=int,
help="Number of CPUs to allocate to Ray.")
parser.add_argument("--num-gpus", default=None, type=int,
help="Number of GPUs to allocate to Ray.")
parser.add_argument("--scheduler", default="FIFO", type=str,
help="FIFO, MedianStopping, or HyperBand")
parser.add_argument("--scheduler-config", default="{}", type=json.loads,
help="Config options to pass to the scheduler.")
parser.add_argument("--server", default=False, type=bool,
help="Option to launch Tune Server")
parser.add_argument("--server-port", default=TuneServer.DEFAULT_PORT,
type=int, help="Option to launch Tune Server")
parser.add_argument("-f", "--config-file", required=True, type=str,
help="Read experiment options from this JSON/YAML file.")
_SCHEDULERS = {
"FIFO": FIFOScheduler,
"MedianStopping": MedianStoppingRule,
@@ -67,7 +31,11 @@ def _make_scheduler(args):
def run_experiments(experiments, scheduler=None, with_server=False,
server_port=TuneServer.DEFAULT_PORT, **ray_args):
server_port=TuneServer.DEFAULT_PORT):
# Make sure rllib agents are registered
from ray import rllib # noqa # pylint: disable=unused-import
if scheduler is None:
scheduler = FIFOScheduler()
@@ -77,13 +45,16 @@ def run_experiments(experiments, scheduler=None, with_server=False,
for name, spec in experiments.items():
for trial in generate_trials(spec, name):
runner.add_trial(trial)
print(runner.debug_string())
ray.init(**ray_args)
print(runner.debug_string(max_debug=99999))
last_debug = 0
while not runner.is_finished():
runner.step()
print(runner.debug_string())
if time.time() - last_debug > DEBUG_PRINT_INTERVAL:
print(runner.debug_string())
last_debug = time.time()
print(runner.debug_string(max_debug=99999))
for trial in runner.get_trials():
# TODO(rliaw): What about errored?
@@ -91,14 +62,3 @@ def run_experiments(experiments, scheduler=None, with_server=False,
raise TuneError("Trial did not complete", trial)
return runner.get_trials()
if __name__ == "__main__":
import yaml
args = parser.parse_args(sys.argv[1:])
with open(args.config_file) as f:
experiments = yaml.load(f)
run_experiments(
experiments, _make_scheduler(args), with_server=args.server,
server_port=args.server_port, redis_address=args.redis_address,
num_cpus=args.num_cpus, num_gpus=args.num_gpus)