mirror of
https://github.com/wassname/ray.git
synced 2026-07-06 01:07:38 +08:00
[tune] Ray Tune API cleanup (#1454)
Remove rllib dep: trainable is now a standalone abstract class that can be easily subclassed. Clean up hyperband: fix debug string and add an example. Remove YAML api / ScriptRunner: this was never really used. Move ray.init() out of run_experiments(): This provides greater flexibility and should be less confusing since there isn't an implicit init() done there. Note that this is a breaking API change for tune.
This commit is contained in:
@@ -2,15 +2,16 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
# Note: do not introduce unnecessary library dependencies here, e.g. gym
|
||||
# Note: do not introduce unnecessary library dependencies here, e.g. gym.
|
||||
# This file is imported from the tune module in order to register RLlib agents.
|
||||
from ray.tune.registry import register_trainable
|
||||
from ray.rllib.agent import get_agent_class
|
||||
|
||||
|
||||
def _register_all():
|
||||
for key in [
|
||||
"PPO", "ES", "DQN", "A3C", "BC", "__fake", "__sigmoid_fake_data"]:
|
||||
try:
|
||||
from ray.rllib.agent import get_agent_class
|
||||
register_trainable(key, get_agent_class(key))
|
||||
except ImportError as e:
|
||||
print("Warning: could not import {}: {}".format(key, e))
|
||||
|
||||
+20
-186
@@ -2,25 +2,14 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
import logging
|
||||
import numpy as np
|
||||
import io
|
||||
import os
|
||||
import gzip
|
||||
import pickle
|
||||
import shutil
|
||||
import tempfile
|
||||
import time
|
||||
import uuid
|
||||
|
||||
# Note: avoid introducing unnecessary library dependencies here, e.g. gym
|
||||
# until https://github.com/ray-project/ray/issues/1144 is resolved
|
||||
import tensorflow as tf
|
||||
from ray.tune.logger import UnifiedLogger
|
||||
from ray.tune.registry import ENV_CREATOR, get_registry
|
||||
from ray.tune.result import DEFAULT_RESULTS_DIR, TrainingResult
|
||||
from ray.tune.result import TrainingResult
|
||||
from ray.tune.trainable import Trainable
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -66,7 +55,7 @@ class Agent(Trainable):
|
||||
env_creator (func): Function that creates a new training env.
|
||||
config (obj): Algorithm-specific configuration data.
|
||||
logdir (str): Directory in which training outputs should be placed.
|
||||
registry (obj): Tune object registry, for registering user-defined
|
||||
registry (obj): Tune object registry which holds user-registered
|
||||
classes and objects by name.
|
||||
"""
|
||||
|
||||
@@ -83,183 +72,43 @@ class Agent(Trainable):
|
||||
env (str): Name of the environment to use. Note that this can also
|
||||
be specified as the `env` key in config.
|
||||
registry (obj): Object registry for user-defined envs, models, etc.
|
||||
If unspecified, it will be assumed empty.
|
||||
If unspecified, the default registry will be used.
|
||||
logger_creator (func): Function that creates a ray.tune.Logger
|
||||
object. If unspecified, a default logger is created.
|
||||
"""
|
||||
self._initialize_ok = False
|
||||
self._experiment_id = uuid.uuid4().hex
|
||||
env = env or config.get("env")
|
||||
|
||||
# Agents allow env ids to be passed directly to the constructor.
|
||||
self._env_id = env or config.get("env")
|
||||
Trainable.__init__(self, config, registry, logger_creator)
|
||||
|
||||
def _setup(self):
|
||||
env = self._env_id
|
||||
if env:
|
||||
config["env"] = env
|
||||
if registry and registry.contains(ENV_CREATOR, env):
|
||||
self.env_creator = registry.get(ENV_CREATOR, env)
|
||||
self.config["env"] = env
|
||||
if self.registry and self.registry.contains(ENV_CREATOR, env):
|
||||
self.env_creator = self.registry.get(ENV_CREATOR, env)
|
||||
else:
|
||||
import gym # soft dependency
|
||||
self.env_creator = lambda env_config: gym.make(env)
|
||||
else:
|
||||
self.env_creator = lambda env_config: None
|
||||
self.config = self._default_config.copy()
|
||||
self.registry = registry
|
||||
|
||||
self.config = _deep_update(self.config, config,
|
||||
self._allow_unknown_configs,
|
||||
self._allow_unknown_subkeys)
|
||||
|
||||
if logger_creator:
|
||||
self._result_logger = logger_creator(self.config)
|
||||
self.logdir = self._result_logger.logdir
|
||||
else:
|
||||
logdir_suffix = "{}_{}_{}".format(
|
||||
env, self._agent_name,
|
||||
datetime.today().strftime("%Y-%m-%d_%H-%M-%S"))
|
||||
if not os.path.exists(DEFAULT_RESULTS_DIR):
|
||||
os.makedirs(DEFAULT_RESULTS_DIR)
|
||||
self.logdir = tempfile.mkdtemp(
|
||||
prefix=logdir_suffix, dir=DEFAULT_RESULTS_DIR)
|
||||
self._result_logger = UnifiedLogger(self.config, self.logdir, None)
|
||||
|
||||
self._iteration = 0
|
||||
self._time_total = 0.0
|
||||
self._timesteps_total = 0
|
||||
# Merge the supplied config with the class default
|
||||
merged_config = self._default_config.copy()
|
||||
merged_config = _deep_update(merged_config, self.config,
|
||||
self._allow_unknown_configs,
|
||||
self._allow_unknown_subkeys)
|
||||
self.config = merged_config
|
||||
|
||||
# TODO(ekl) setting the graph is unnecessary for PyTorch agents
|
||||
with tf.Graph().as_default():
|
||||
self._init()
|
||||
|
||||
self._initialize_ok = True
|
||||
|
||||
def _init(self):
|
||||
"""Subclasses should override this for custom initialization."""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
def train(self):
|
||||
"""Runs one logical iteration of training.
|
||||
|
||||
Returns:
|
||||
A TrainingResult that describes training progress.
|
||||
"""
|
||||
|
||||
if not self._initialize_ok:
|
||||
raise ValueError(
|
||||
"Agent initialization failed, see previous errors")
|
||||
|
||||
start = time.time()
|
||||
result = self._train()
|
||||
self._iteration += 1
|
||||
if result.time_this_iter_s is not None:
|
||||
time_this_iter = result.time_this_iter_s
|
||||
else:
|
||||
time_this_iter = time.time() - start
|
||||
|
||||
assert result.timesteps_this_iter is not None
|
||||
|
||||
self._time_total += time_this_iter
|
||||
self._timesteps_total += result.timesteps_this_iter
|
||||
|
||||
now = datetime.today()
|
||||
result = result._replace(
|
||||
experiment_id=self._experiment_id,
|
||||
date=now.strftime("%Y-%m-%d_%H-%M-%S"),
|
||||
timestamp=int(time.mktime(now.timetuple())),
|
||||
training_iteration=self._iteration,
|
||||
timesteps_total=self._timesteps_total,
|
||||
time_this_iter_s=time_this_iter,
|
||||
time_total_s=self._time_total,
|
||||
pid=os.getpid(),
|
||||
hostname=os.uname()[1])
|
||||
|
||||
self._result_logger.on_result(result)
|
||||
|
||||
return result
|
||||
|
||||
def save(self):
|
||||
"""Saves the current model state to a checkpoint.
|
||||
|
||||
Returns:
|
||||
Checkpoint path that may be passed to restore().
|
||||
"""
|
||||
|
||||
checkpoint_path = self._save()
|
||||
pickle.dump(
|
||||
[self._experiment_id, self._iteration, self._timesteps_total,
|
||||
self._time_total],
|
||||
open(checkpoint_path + ".rllib_metadata", "wb"))
|
||||
return checkpoint_path
|
||||
|
||||
def save_to_object(self):
|
||||
"""Saves the current model state to a Python object. It also
|
||||
saves to disk but does not return the checkpoint path.
|
||||
|
||||
Returns:
|
||||
Object holding checkpoint data.
|
||||
"""
|
||||
|
||||
checkpoint_prefix = self.save()
|
||||
|
||||
data = {}
|
||||
base_dir = os.path.dirname(checkpoint_prefix)
|
||||
for path in os.listdir(base_dir):
|
||||
path = os.path.join(base_dir, path)
|
||||
if path.startswith(checkpoint_prefix):
|
||||
data[os.path.basename(path)] = open(path, "rb").read()
|
||||
|
||||
out = io.BytesIO()
|
||||
with gzip.GzipFile(fileobj=out, mode="wb") as f:
|
||||
compressed = pickle.dumps({
|
||||
"checkpoint_name": os.path.basename(checkpoint_prefix),
|
||||
"data": data,
|
||||
})
|
||||
print("Saving checkpoint to object store, {} bytes".format(
|
||||
len(compressed)))
|
||||
f.write(compressed)
|
||||
|
||||
return out.getvalue()
|
||||
|
||||
def restore(self, checkpoint_path):
|
||||
"""Restores training state from a given model checkpoint.
|
||||
|
||||
These checkpoints are returned from calls to save().
|
||||
"""
|
||||
|
||||
self._restore(checkpoint_path)
|
||||
metadata = pickle.load(open(checkpoint_path + ".rllib_metadata", "rb"))
|
||||
self._experiment_id = metadata[0]
|
||||
self._iteration = metadata[1]
|
||||
self._timesteps_total = metadata[2]
|
||||
self._time_total = metadata[3]
|
||||
|
||||
def restore_from_object(self, obj):
|
||||
"""Restores training state from a checkpoint object.
|
||||
|
||||
These checkpoints are returned from calls to save_to_object().
|
||||
"""
|
||||
|
||||
out = io.BytesIO(obj)
|
||||
info = pickle.loads(gzip.GzipFile(fileobj=out, mode="rb").read())
|
||||
data = info["data"]
|
||||
tmpdir = tempfile.mkdtemp("restore_from_object", dir=self.logdir)
|
||||
checkpoint_path = os.path.join(tmpdir, info["checkpoint_name"])
|
||||
|
||||
for file_name, file_contents in data.items():
|
||||
with open(os.path.join(tmpdir, file_name), "wb") as f:
|
||||
f.write(file_contents)
|
||||
|
||||
self.restore(checkpoint_path)
|
||||
shutil.rmtree(tmpdir)
|
||||
|
||||
def stop(self):
|
||||
"""Releases all resources used by this agent."""
|
||||
|
||||
if self._initialize_ok:
|
||||
self._result_logger.close()
|
||||
self._stop()
|
||||
|
||||
def _stop(self):
|
||||
"""Subclasses should override this for custom stopping."""
|
||||
|
||||
pass
|
||||
|
||||
def compute_action(self, observation):
|
||||
"""Computes an action using the current trained policy."""
|
||||
|
||||
@@ -283,21 +132,6 @@ class Agent(Trainable):
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
def _train(self):
|
||||
"""Subclasses should override this to implement train()."""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
def _save(self):
|
||||
"""Subclasses should override this to implement save()."""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
def _restore(self, checkpoint_path):
|
||||
"""Subclasses should override this to implement restore()."""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class _MockAgent(Agent):
|
||||
"""Mock agent for use in tests"""
|
||||
|
||||
@@ -8,6 +8,7 @@ import argparse
|
||||
import sys
|
||||
import yaml
|
||||
|
||||
import ray
|
||||
from ray.tune.config_parser import make_parser, resources_to_json
|
||||
from ray.tune.tune import _make_scheduler, run_experiments
|
||||
|
||||
@@ -76,7 +77,7 @@ if __name__ == "__main__":
|
||||
if not exp.get("env") and not exp.get("config", {}).get("env"):
|
||||
parser.error("the following arguments are required: --env")
|
||||
|
||||
run_experiments(
|
||||
experiments, scheduler=_make_scheduler(args),
|
||||
ray.init(
|
||||
redis_address=args.redis_address,
|
||||
num_cpus=args.num_cpus, num_gpus=args.num_gpus)
|
||||
run_experiments(experiments, scheduler=_make_scheduler(args))
|
||||
|
||||
@@ -6,13 +6,10 @@ from ray.tune.error import TuneError
|
||||
from ray.tune.tune import run_experiments
|
||||
from ray.tune.registry import register_env, register_trainable
|
||||
from ray.tune.result import TrainingResult
|
||||
from ray.tune.script_runner import ScriptRunner
|
||||
from ray.tune.trainable import Trainable
|
||||
from ray.tune.variant_generator import grid_search
|
||||
|
||||
|
||||
register_trainable("script", ScriptRunner)
|
||||
|
||||
__all__ = [
|
||||
"Trainable",
|
||||
"TrainingResult",
|
||||
|
||||
+70
@@ -0,0 +1,70 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
|
||||
import ray
|
||||
from ray.tune import Trainable, TrainingResult, register_trainable, \
|
||||
run_experiments
|
||||
from ray.tune.hyperband import HyperBandScheduler
|
||||
|
||||
|
||||
class MyTrainableClass(Trainable):
    """Example trainable whose reward curve is a scaled tanh of the step count.

    Two dummy hyperparameters are read from ``self.config``: "width" divides
    the step count before the tanh is taken (larger values give a slower
    rise), and "height" multiplies the result (the value the reward
    approaches).
    """

    def _setup(self):
        # Count of completed ``_train`` calls; this is the only state that
        # ``_save`` / ``_restore`` persist.
        self.timestep = 0

    def _train(self):
        """Advance one step and report the reward at the new step."""
        self.timestep += 1
        progress = float(self.timestep) / self.config["width"]
        reward = np.tanh(progress) * self.config["height"]

        # The reward is reported as `episode_reward_mean` here; other
        # objectives such as loss or accuracy can be reported instead
        # (see tune/result.py).
        return TrainingResult(
            episode_reward_mean=reward, timesteps_this_iter=1)

    def _save(self):
        """Write the step counter as JSON and return the checkpoint path."""
        checkpoint_file = os.path.join(self.logdir, "checkpoint")
        with open(checkpoint_file, "w") as out:
            json.dump({"timestep": self.timestep}, out)
        return checkpoint_file

    def _restore(self, checkpoint_path):
        """Reload the step counter from a checkpoint written by ``_save``."""
        with open(checkpoint_path) as src:
            state = json.load(src)
        self.timestep = state["timestep"]
|
||||
|
||||
|
||||
# Make the example trainable available to Tune under the name "my_class".
register_trainable("my_class", MyTrainableClass)


if __name__ == "__main__":
    ray.init()

    # HyperBand early stopping: `episode_reward_mean` is the objective and
    # `timesteps_total` is the unit in which time is measured.
    scheduler = HyperBandScheduler(
        time_attr="timesteps_total",
        reward_attr="episode_reward_mean",
        max_t=100)

    # "width" and "height" are given as lambdas drawing from `random`.
    experiment_spec = {
        "run": "my_class",
        "repeat": 100,
        "resources": {"cpu": 1, "gpu": 0},
        "config": {
            "width": lambda spec: 10 + int(90 * random.random()),
            "height": lambda spec: int(100 * random.random()),
        },
    }

    run_experiments({"hyperband_test": experiment_spec}, scheduler=scheduler)
|
||||
@@ -33,6 +33,7 @@ import sys
|
||||
import tempfile
|
||||
import time
|
||||
|
||||
import ray
|
||||
from ray.tune import grid_search, run_experiments, register_trainable
|
||||
|
||||
from tensorflow.examples.tutorials.mnist import input_data
|
||||
@@ -222,4 +223,5 @@ if __name__ == '__main__':
|
||||
if args.fast:
|
||||
mnist_spec['stop']['training_iteration'] = 2
|
||||
|
||||
ray.init()
|
||||
run_experiments({'tune_mnist_test': mnist_spec})
|
||||
|
||||
@@ -1,14 +0,0 @@
|
||||
tune_mnist:
|
||||
run: script
|
||||
repeat: 2
|
||||
resources:
|
||||
cpu: 1
|
||||
stop:
|
||||
mean_accuracy: 0.99
|
||||
time_total_s: 600
|
||||
config:
|
||||
script_file_path: examples/tune_mnist_ray.py
|
||||
script_entrypoint: train
|
||||
script_min_iter_time_s: 1
|
||||
activation:
|
||||
grid_search: ['relu', 'elu', 'tanh']
|
||||
@@ -2,15 +2,12 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import importlib
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import threading
|
||||
import traceback
|
||||
|
||||
from ray.rllib.agent import Agent
|
||||
from ray.tune import TuneError
|
||||
from ray.tune.trainable import Trainable
|
||||
from ray.tune.result import TrainingResult
|
||||
|
||||
|
||||
@@ -53,12 +50,6 @@ class StatusReporter(object):
|
||||
|
||||
|
||||
DEFAULT_CONFIG = {
|
||||
# path of the script to run
|
||||
"script_file_path": "/path/to/file.py",
|
||||
|
||||
# name of train function in the file, e.g. train(config, status_reporter)
|
||||
"script_entrypoint": "train",
|
||||
|
||||
# batch results to at least this granularity
|
||||
"script_min_iter_time_s": 1,
|
||||
}
|
||||
@@ -85,67 +76,37 @@ class _RunnerThread(threading.Thread):
|
||||
self._status_reporter._done = True
|
||||
|
||||
|
||||
def import_function(file_path, function_name):
|
||||
# strong assumption here that we're in a new process
|
||||
file_path = os.path.expanduser(file_path)
|
||||
sys.path.insert(0, os.path.dirname(file_path))
|
||||
if hasattr(importlib, "util"):
|
||||
# Python 3.4+
|
||||
spec = importlib.util.spec_from_file_location(
|
||||
"external_file", file_path)
|
||||
external_file = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(external_file)
|
||||
elif hasattr(importlib, "machinery"):
|
||||
# Python 3.3
|
||||
from importlib.machinery import SourceFileLoader
|
||||
external_file = SourceFileLoader(
|
||||
"external_file", file_path).load_module()
|
||||
else:
|
||||
# Python 2.x
|
||||
import imp
|
||||
external_file = imp.load_source("external_file", file_path)
|
||||
if not external_file:
|
||||
raise TuneError("Unable to import file at {}".format(file_path))
|
||||
return getattr(external_file, function_name)
|
||||
class FunctionRunner(Trainable):
|
||||
"""Trainable that runs a user function returning training results.
|
||||
|
||||
This mode of execution does not support checkpoint/restore."""
|
||||
|
||||
class ScriptRunner(Agent):
|
||||
"""Agent that runs a user script returning training results."""
|
||||
|
||||
_agent_name = "script"
|
||||
_name = "func"
|
||||
_default_config = DEFAULT_CONFIG
|
||||
_allow_unknown_configs = True
|
||||
|
||||
def _init(self):
|
||||
def _setup(self):
|
||||
entrypoint = self._trainable_func()
|
||||
if not entrypoint:
|
||||
entrypoint = import_function(
|
||||
self.config["script_file_path"],
|
||||
self.config["script_entrypoint"])
|
||||
self._status_reporter = StatusReporter()
|
||||
scrubbed_config = self.config.copy()
|
||||
for k in self._default_config:
|
||||
del scrubbed_config[k]
|
||||
if k in scrubbed_config:
|
||||
del scrubbed_config[k]
|
||||
self._runner = _RunnerThread(
|
||||
entrypoint, scrubbed_config, self._status_reporter)
|
||||
self._start_time = time.time()
|
||||
self._last_reported_time = self._start_time
|
||||
self._last_reported_timestep = 0
|
||||
self._runner.start()
|
||||
|
||||
# Subclasses can override this to set the trainable func
|
||||
# TODO(ekl) this isn't a very clean layering, we should refactor it
|
||||
def _trainable_func(self):
|
||||
return None
|
||||
"""Subclasses can override this to set the trainable func."""
|
||||
|
||||
def train(self):
|
||||
if not self._initialize_ok:
|
||||
raise ValueError(
|
||||
"Agent initialization failed, see previous errors")
|
||||
|
||||
now = time.time()
|
||||
time.sleep(self.config["script_min_iter_time_s"])
|
||||
raise NotImplementedError
|
||||
|
||||
def _train(self):
|
||||
time.sleep(
|
||||
self.config.get(
|
||||
"script_min_iter_time_s",
|
||||
self._default_config["script_min_iter_time_s"]))
|
||||
result = self._status_reporter._get_and_clear_status()
|
||||
while result is None:
|
||||
time.sleep(1)
|
||||
@@ -153,29 +114,10 @@ class ScriptRunner(Agent):
|
||||
if result.timesteps_total is None:
|
||||
raise TuneError("Must specify timesteps_total in result", result)
|
||||
|
||||
# Include the negative loss to use as a stopping condition
|
||||
if result.mean_loss is not None:
|
||||
neg_loss = -result.mean_loss
|
||||
else:
|
||||
neg_loss = result.neg_mean_loss
|
||||
|
||||
result = result._replace(
|
||||
experiment_id=self._experiment_id,
|
||||
neg_mean_loss=neg_loss,
|
||||
training_iteration=self.iteration,
|
||||
time_this_iter_s=now - self._last_reported_time,
|
||||
timesteps_this_iter=(
|
||||
result.timesteps_total - self._last_reported_timestep),
|
||||
time_total_s=now - self._start_time,
|
||||
pid=os.getpid(),
|
||||
hostname=os.uname()[1])
|
||||
|
||||
if result.timesteps_total:
|
||||
self._last_reported_timestep = result.timesteps_total
|
||||
self._last_reported_time = now
|
||||
self._iteration += 1
|
||||
|
||||
self._result_logger.on_result(result)
|
||||
result.timesteps_total - self._last_reported_timestep))
|
||||
self._last_reported_timestep = result.timesteps_total
|
||||
|
||||
return result
|
||||
|
||||
@@ -2,6 +2,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import numpy as np
|
||||
|
||||
from ray.tune.trial_scheduler import FIFOScheduler, TrialScheduler
|
||||
@@ -61,7 +62,8 @@ class HyperBandScheduler(FIFOScheduler):
|
||||
max_t (int): max time units per trial. Trials will be stopped after
|
||||
max_t time units (determined by time_attr) have passed.
|
||||
The HyperBand scheduler automatically tries to determine a
|
||||
reasonable number of brackets based on this.
|
||||
reasonable number of brackets based on this. The scheduler will
|
||||
terminate trials after this time has passed.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -210,7 +212,8 @@ class HyperBandScheduler(FIFOScheduler):
|
||||
|
||||
List of trials not used since all trials are tracked as state
|
||||
of scheduler. If iteration is occupied (ie, no trials to run),
|
||||
then look into next iteration."""
|
||||
then look into next iteration.
|
||||
"""
|
||||
|
||||
for hyperband in self._hyperbands:
|
||||
for bracket in sorted(hyperband,
|
||||
@@ -222,18 +225,14 @@ class HyperBandScheduler(FIFOScheduler):
|
||||
return None
|
||||
|
||||
def debug_string(self):
|
||||
# TODO(rliaw): This debug string needs work
|
||||
brackets = [
|
||||
"({0}/{1})".format(
|
||||
len(bracket._live_trials), len(bracket._all_trials))
|
||||
for band in self._hyperbands for bracket in band]
|
||||
return " ".join([
|
||||
"Using HyperBand:",
|
||||
"num_stopped={}".format(self._num_stopped),
|
||||
"total_brackets={}".format(
|
||||
sum(len(band) for band in self._hyperbands)),
|
||||
" ".join(brackets)
|
||||
])
|
||||
out = "Using HyperBand: "
|
||||
out += "num_stopped={} total_brackets={}".format(
|
||||
self._num_stopped, sum(len(band) for band in self._hyperbands))
|
||||
for i, band in enumerate(self._hyperbands):
|
||||
out += "\nRound #{}:".format(i)
|
||||
for bracket in band:
|
||||
out += "\n {}".format(bracket)
|
||||
return out
|
||||
|
||||
|
||||
class Bracket():
|
||||
@@ -370,10 +369,9 @@ class Bracket():
|
||||
status = ", ".join([
|
||||
"n={}".format(self._n),
|
||||
"r={}".format(self._r),
|
||||
"progress={}".format(self.completion_percentage())
|
||||
"completed={}%".format(int(100 * self.completion_percentage()))
|
||||
])
|
||||
return "Bracket({})".format(status)
|
||||
|
||||
def debug_string(self):
|
||||
trials = ", ".join([t.status for t in self._live_trials])
|
||||
return "{}[{}]".format(self, trials)
|
||||
counts = collections.Counter()
|
||||
for t in self._all_trials:
|
||||
counts[t.status] += 1
|
||||
return "Bracket({}): {}".format(status, dict(counts))
|
||||
|
||||
+218
-12
@@ -2,55 +2,261 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
import gzip
|
||||
import io
|
||||
import os
|
||||
import pickle
|
||||
import shutil
|
||||
import tempfile
|
||||
import time
|
||||
import uuid
|
||||
|
||||
from ray.tune import TuneError
|
||||
from ray.tune.logger import UnifiedLogger
|
||||
from ray.tune.result import DEFAULT_RESULTS_DIR
|
||||
|
||||
|
||||
class Trainable(object):
|
||||
"""Interface for trainable models, functions, etc.
|
||||
"""Abstract class for trainable models, functions, etc.
|
||||
|
||||
Implementing this interface is required to use Ray.tune's full
|
||||
functionality, though you can also get away with supplying just a
|
||||
`my_train(config, reporter)` function and calling:
|
||||
A call to ``train()`` on a trainable will execute one logical iteration of
|
||||
training. As a rule of thumb, the execution time of one train call should
|
||||
be large enough to avoid overheads (i.e. more than a few seconds), but
|
||||
short enough to report progress periodically (i.e. at most a few minutes).
|
||||
|
||||
Calling ``save()`` should save the training state of a trainable to disk,
|
||||
and ``restore(path)`` should restore a trainable to the given state.
|
||||
|
||||
Generally you only need to implement ``_train``, ``_save``, and
|
||||
``_restore`` here when subclassing Trainable.
|
||||
|
||||
Note that, if you don't require checkpoint/restore functionality, then
|
||||
instead of implementing this class you can also get away with supplying
|
||||
just a `my_train(config, reporter)` function and calling:
|
||||
|
||||
``register_trainable("my_func", train)``
|
||||
|
||||
to register it for use with tune. The function will be automatically
|
||||
converted to this interface (sans checkpoint functionality)."""
|
||||
to register it for use with Tune. The function will be automatically
|
||||
converted to this interface (sans checkpoint functionality).
|
||||
|
||||
Attributes:
|
||||
config (obj): The hyperparam configuration for this trial.
|
||||
logdir (str): Directory in which training outputs should be placed.
|
||||
registry (obj): Tune object registry which holds user-registered
|
||||
classes and objects by name.
|
||||
"""
|
||||
|
||||
def __init__(self, config={}, registry=None, logger_creator=None):
|
||||
"""Initialize a Trainable.
|
||||
|
||||
Subclasses should prefer defining ``_setup()`` instead of overriding
|
||||
``__init__()`` directly.
|
||||
|
||||
Args:
|
||||
config (dict): Trainable-specific configuration data.
|
||||
registry (obj): Object registry for user-defined envs, models, etc.
|
||||
If unspecified, the default registry will be used.
|
||||
logger_creator (func): Function that creates a ray.tune.Logger
|
||||
object. If unspecified, a default logger is created.
|
||||
"""
|
||||
|
||||
if registry is None:
|
||||
from ray.tune.registry import get_registry
|
||||
registry = get_registry()
|
||||
|
||||
self._initialize_ok = False
|
||||
self._experiment_id = uuid.uuid4().hex
|
||||
self.config = config
|
||||
self.registry = registry
|
||||
|
||||
if logger_creator:
|
||||
self._result_logger = logger_creator(self.config)
|
||||
self.logdir = self._result_logger.logdir
|
||||
else:
|
||||
logdir_prefix = datetime.today().strftime("%Y-%m-%d_%H-%M-%S")
|
||||
if not os.path.exists(DEFAULT_RESULTS_DIR):
|
||||
os.makedirs(DEFAULT_RESULTS_DIR)
|
||||
self.logdir = tempfile.mkdtemp(
|
||||
prefix=logdir_prefix, dir=DEFAULT_RESULTS_DIR)
|
||||
self._result_logger = UnifiedLogger(self.config, self.logdir, None)
|
||||
|
||||
self._iteration = 0
|
||||
self._time_total = 0.0
|
||||
self._timesteps_total = 0
|
||||
self._setup()
|
||||
self._initialize_ok = True
|
||||
|
||||
def train(self):
|
||||
"""Runs one logical iteration of training.
|
||||
|
||||
Subclasses should override ``_train()`` instead to return results.
|
||||
This method auto-fills many fields, so only ``timesteps_this_iter``
|
||||
is required to be present.
|
||||
|
||||
Returns:
|
||||
A TrainingResult that describes training progress.
|
||||
"""
|
||||
|
||||
raise NotImplementedError
|
||||
if not self._initialize_ok:
|
||||
raise ValueError(
|
||||
"Trainable initialization failed, see previous errors")
|
||||
|
||||
start = time.time()
|
||||
result = self._train()
|
||||
self._iteration += 1
|
||||
if result.time_this_iter_s is not None:
|
||||
time_this_iter = result.time_this_iter_s
|
||||
else:
|
||||
time_this_iter = time.time() - start
|
||||
|
||||
if result.timesteps_this_iter is None:
|
||||
raise TuneError(
|
||||
"Must specify timesteps_this_iter in result", result)
|
||||
|
||||
self._time_total += time_this_iter
|
||||
self._timesteps_total += result.timesteps_this_iter
|
||||
|
||||
# Include the negative loss to use as a stopping condition
|
||||
if result.mean_loss is not None:
|
||||
neg_loss = -result.mean_loss
|
||||
else:
|
||||
neg_loss = result.neg_mean_loss
|
||||
|
||||
now = datetime.today()
|
||||
result = result._replace(
|
||||
experiment_id=self._experiment_id,
|
||||
date=now.strftime("%Y-%m-%d_%H-%M-%S"),
|
||||
timestamp=int(time.mktime(now.timetuple())),
|
||||
training_iteration=self._iteration,
|
||||
timesteps_total=self._timesteps_total,
|
||||
time_this_iter_s=time_this_iter,
|
||||
time_total_s=self._time_total,
|
||||
neg_mean_loss=neg_loss,
|
||||
pid=os.getpid(),
|
||||
hostname=os.uname()[1])
|
||||
|
||||
self._result_logger.on_result(result)
|
||||
|
||||
return result
|
||||
|
||||
def save(self):
|
||||
"""Saves the current model state to a checkpoint.
|
||||
|
||||
Subclasses should override ``_save()`` instead to save state.
|
||||
This method dumps additional metadata alongside the saved path.
|
||||
|
||||
Returns:
|
||||
Checkpoint path that may be passed to restore().
|
||||
"""
|
||||
|
||||
raise NotImplementedError
|
||||
checkpoint_path = self._save()
|
||||
pickle.dump(
|
||||
[self._experiment_id, self._iteration, self._timesteps_total,
|
||||
self._time_total],
|
||||
open(checkpoint_path + ".tune_metadata", "wb"))
|
||||
return checkpoint_path
|
||||
|
||||
def save_to_object(self):
|
||||
"""Saves the current model state to a Python object. It also
|
||||
saves to disk but does not return the checkpoint path.
|
||||
|
||||
Returns:
|
||||
Object holding checkpoint data.
|
||||
"""
|
||||
|
||||
checkpoint_prefix = self.save()
|
||||
|
||||
data = {}
|
||||
base_dir = os.path.dirname(checkpoint_prefix)
|
||||
for path in os.listdir(base_dir):
|
||||
path = os.path.join(base_dir, path)
|
||||
if path.startswith(checkpoint_prefix):
|
||||
data[os.path.basename(path)] = open(path, "rb").read()
|
||||
|
||||
out = io.BytesIO()
|
||||
with gzip.GzipFile(fileobj=out, mode="wb") as f:
|
||||
compressed = pickle.dumps({
|
||||
"checkpoint_name": os.path.basename(checkpoint_prefix),
|
||||
"data": data,
|
||||
})
|
||||
print("Saving checkpoint to object store, {} bytes".format(
|
||||
len(compressed)))
|
||||
f.write(compressed)
|
||||
|
||||
return out.getvalue()
|
||||
|
||||
def restore(self, checkpoint_path):
|
||||
"""Restores training state from a given model checkpoint.
|
||||
|
||||
These checkpoints are returned from calls to save().
|
||||
|
||||
Subclasses should override ``_restore()`` instead to restore state.
|
||||
This method restores additional metadata saved with the checkpoint.
|
||||
"""
|
||||
|
||||
self._restore(checkpoint_path)
|
||||
metadata = pickle.load(open(checkpoint_path + ".tune_metadata", "rb"))
|
||||
self._experiment_id = metadata[0]
|
||||
self._iteration = metadata[1]
|
||||
self._timesteps_total = metadata[2]
|
||||
self._time_total = metadata[3]
|
||||
|
||||
def restore_from_object(self, obj):
|
||||
"""Restores training state from a checkpoint object.
|
||||
|
||||
These checkpoints are returned from calls to save_to_object().
|
||||
"""
|
||||
|
||||
out = io.BytesIO(obj)
|
||||
info = pickle.loads(gzip.GzipFile(fileobj=out, mode="rb").read())
|
||||
data = info["data"]
|
||||
tmpdir = tempfile.mkdtemp("restore_from_object", dir=self.logdir)
|
||||
checkpoint_path = os.path.join(tmpdir, info["checkpoint_name"])
|
||||
|
||||
for file_name, file_contents in data.items():
|
||||
with open(os.path.join(tmpdir, file_name), "wb") as f:
|
||||
f.write(file_contents)
|
||||
|
||||
self.restore(checkpoint_path)
|
||||
shutil.rmtree(tmpdir)
|
||||
|
||||
def stop(self):
|
||||
"""Releases all resources used by this trainable."""
|
||||
|
||||
if self._initialize_ok:
|
||||
self._result_logger.close()
|
||||
self._stop()
|
||||
|
||||
def _train(self):
|
||||
"""Subclasses should override this to implement train()."""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
def stop(self):
|
||||
"""Releases all resources used by this class."""
|
||||
def _save(self):
|
||||
"""Subclasses should override this to implement save()."""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
def _restore(self, checkpoint_path):
|
||||
"""Subclasses should override this to implement restore()."""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
def _setup(self):
|
||||
"""Subclasses should override this for custom initialization."""
|
||||
pass
|
||||
|
||||
def _stop(self):
|
||||
"""Subclasses should override this for any cleanup on stop."""
|
||||
pass
|
||||
|
||||
|
||||
def wrap_function(train_func):
|
||||
from ray.tune.script_runner import ScriptRunner
|
||||
from ray.tune.function_runner import FunctionRunner
|
||||
|
||||
class WrappedFunc(ScriptRunner):
|
||||
class WrappedFunc(FunctionRunner):
|
||||
def _trainable_func(self):
|
||||
return train_func
|
||||
|
||||
|
||||
@@ -2,18 +2,21 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from collections import namedtuple
|
||||
from datetime import datetime
|
||||
import tempfile
|
||||
import time
|
||||
import traceback
|
||||
import ray
|
||||
import os
|
||||
|
||||
from collections import namedtuple
|
||||
from ray.utils import random_string, binary_to_hex
|
||||
from ray.tune import TuneError
|
||||
from ray.tune.logger import NoopLogger, UnifiedLogger
|
||||
from ray.tune.result import TrainingResult, DEFAULT_RESULTS_DIR, pretty_print
|
||||
from ray.tune.registry import _default_registry, get_registry, TRAINABLE_CLASS
|
||||
from ray.tune.result import TrainingResult, DEFAULT_RESULTS_DIR, pretty_print
|
||||
from ray.utils import random_string, binary_to_hex
|
||||
|
||||
DEBUG_PRINT_INTERVAL = 5
|
||||
|
||||
|
||||
class Resources(
|
||||
@@ -106,6 +109,7 @@ class Trial(object):
|
||||
self.location = None
|
||||
self.logdir = None
|
||||
self.result_logger = None
|
||||
self.last_debug = 0
|
||||
self.trial_id = binary_to_hex(random_string())[:8]
|
||||
|
||||
def start(self):
|
||||
@@ -293,8 +297,10 @@ class Trial(object):
|
||||
def update_last_result(self, result, terminate=False):
|
||||
if terminate:
|
||||
result = result._replace(done=True)
|
||||
print("TrainingResult for {}:".format(self))
|
||||
print(" {}".format(pretty_print(result).replace("\n", "\n ")))
|
||||
if terminate or time.time() - self.last_debug > DEBUG_PRINT_INTERVAL:
|
||||
print("TrainingResult for {}:".format(self))
|
||||
print(" {}".format(pretty_print(result).replace("\n", "\n ")))
|
||||
self.last_debug = time.time()
|
||||
self.last_result = result
|
||||
self.result_logger.on_result(self.last_result)
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import os
|
||||
import ray
|
||||
import time
|
||||
@@ -13,6 +14,9 @@ from ray.tune.trial import Trial, Resources
|
||||
from ray.tune.trial_scheduler import FIFOScheduler, TrialScheduler
|
||||
|
||||
|
||||
MAX_DEBUG_TRIALS = 20
|
||||
|
||||
|
||||
class TrialRunner(object):
|
||||
"""A TrialRunner implements the event loop for scheduling trials on Ray.
|
||||
|
||||
@@ -127,9 +131,40 @@ class TrialRunner(object):
|
||||
self._scheduler_alg.on_trial_add(self, trial)
|
||||
self._trials.append(trial)
|
||||
|
||||
def debug_string(self):
|
||||
def debug_string(self, max_debug=MAX_DEBUG_TRIALS):
    """Returns a human readable message for printing to the console.

    The header lines come from _debug_messages(). They are followed by one
    "Result logdir" line per distinct local_dir and then by the trials,
    grouped by status.

    Args:
        max_debug: Maximum number of trials to list in total. The budget
            is handed out one slot per status per round, so every status
            that still has unlisted trials gets a share.

    Returns:
        The message as a single newline-terminated string.
    """

    messages = self._debug_messages()
    states = collections.defaultdict(set)
    for t in self._trials:
        states[t.status].add(t)

    # Show at most max_debug total, but divide the limit fairly
    limit_per_state = collections.Counter()
    remaining = max_debug
    while remaining > 0:
        start_num = remaining
        for s in states:
            if remaining <= 0:
                # Budget ran out mid-round; without this check each
                # remaining state would still take a slot and the total
                # could exceed max_debug.
                break
            if limit_per_state[s] >= len(states[s]):
                continue
            remaining -= 1
            limit_per_state[s] += 1
        if remaining == start_num:
            # No state took a slot: every trial is already covered.
            break

    for local_dir in sorted({t.local_dir for t in self._trials}):
        messages.append("Result logdir: {}".format(local_dir))
    for state, trials in sorted(states.items()):
        limit = limit_per_state[state]
        messages.append("{} trials:".format(state))
        for t in sorted(
                trials, key=lambda t: t.experiment_tag)[:limit]:
            messages.append(" - {}:\t{}".format(t, t.progress_string()))
        if len(trials) > limit:
            messages.append(" ... {} more not shown".format(
                len(trials) - limit))
    return "\n".join(messages) + "\n"
|
||||
|
||||
def _debug_messages(self):
|
||||
messages = ["== Status =="]
|
||||
messages.append(self._scheduler_alg.debug_string())
|
||||
if self._resources_initialized:
|
||||
@@ -139,13 +174,7 @@ class TrialRunner(object):
|
||||
self._avail_resources.cpu,
|
||||
self._committed_resources.gpu,
|
||||
self._avail_resources.gpu))
|
||||
for local_dir in sorted(set([t.local_dir for t in self._trials])):
|
||||
messages.append("Result logdir: {}".format(local_dir))
|
||||
for t in self._trials:
|
||||
if t.local_dir == local_dir:
|
||||
messages.append(
|
||||
" - {}:\t{}".format(t, t.progress_string()))
|
||||
return "\n".join(messages) + "\n"
|
||||
return messages
|
||||
|
||||
def has_resources(self, resources):
|
||||
"""Returns whether this runner has at least the specified resources."""
|
||||
|
||||
Executable → Regular
+14
-54
@@ -1,55 +1,19 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
|
||||
import ray
|
||||
import time
|
||||
|
||||
from ray.tune import TuneError
|
||||
from ray.tune.hyperband import HyperBandScheduler
|
||||
from ray.tune.median_stopping_rule import MedianStoppingRule
|
||||
from ray.tune.trial import Trial
|
||||
from ray.tune.trial import Trial, DEBUG_PRINT_INTERVAL
|
||||
from ray.tune.trial_runner import TrialRunner
|
||||
from ray.tune.trial_scheduler import FIFOScheduler
|
||||
from ray.tune.web_server import TuneServer
|
||||
from ray.tune.variant_generator import generate_trials
|
||||
|
||||
|
||||
EXAMPLE_USAGE = """
|
||||
MNIST tuning example:
|
||||
./tune.py -f examples/tune_mnist_ray.yaml
|
||||
"""
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
description="Tune hyperparameters with Ray.",
|
||||
epilog=EXAMPLE_USAGE)
|
||||
|
||||
# See also the base parser definition in ray/tune/config_parser.py
|
||||
parser.add_argument("--redis-address", default=None, type=str,
|
||||
help="The Redis address of the cluster.")
|
||||
parser.add_argument("--num-cpus", default=None, type=int,
|
||||
help="Number of CPUs to allocate to Ray.")
|
||||
parser.add_argument("--num-gpus", default=None, type=int,
|
||||
help="Number of GPUs to allocate to Ray.")
|
||||
parser.add_argument("--scheduler", default="FIFO", type=str,
|
||||
help="FIFO, MedianStopping, or HyperBand")
|
||||
parser.add_argument("--scheduler-config", default="{}", type=json.loads,
|
||||
help="Config options to pass to the scheduler.")
|
||||
parser.add_argument("--server", default=False, type=bool,
|
||||
help="Option to launch Tune Server")
|
||||
parser.add_argument("--server-port", default=TuneServer.DEFAULT_PORT,
|
||||
type=int, help="Option to launch Tune Server")
|
||||
parser.add_argument("-f", "--config-file", required=True, type=str,
|
||||
help="Read experiment options from this JSON/YAML file.")
|
||||
|
||||
|
||||
_SCHEDULERS = {
|
||||
"FIFO": FIFOScheduler,
|
||||
"MedianStopping": MedianStoppingRule,
|
||||
@@ -67,7 +31,11 @@ def _make_scheduler(args):
|
||||
|
||||
|
||||
def run_experiments(experiments, scheduler=None, with_server=False,
|
||||
server_port=TuneServer.DEFAULT_PORT, **ray_args):
|
||||
server_port=TuneServer.DEFAULT_PORT):
|
||||
|
||||
# Make sure rllib agents are registered
|
||||
from ray import rllib # noqa # pylint: disable=unused-import
|
||||
|
||||
if scheduler is None:
|
||||
scheduler = FIFOScheduler()
|
||||
|
||||
@@ -77,13 +45,16 @@ def run_experiments(experiments, scheduler=None, with_server=False,
|
||||
for name, spec in experiments.items():
|
||||
for trial in generate_trials(spec, name):
|
||||
runner.add_trial(trial)
|
||||
print(runner.debug_string())
|
||||
|
||||
ray.init(**ray_args)
|
||||
print(runner.debug_string(max_debug=99999))
|
||||
|
||||
last_debug = 0
|
||||
while not runner.is_finished():
|
||||
runner.step()
|
||||
print(runner.debug_string())
|
||||
if time.time() - last_debug > DEBUG_PRINT_INTERVAL:
|
||||
print(runner.debug_string())
|
||||
last_debug = time.time()
|
||||
|
||||
print(runner.debug_string(max_debug=99999))
|
||||
|
||||
for trial in runner.get_trials():
|
||||
# TODO(rliaw): What about errored?
|
||||
@@ -91,14 +62,3 @@ def run_experiments(experiments, scheduler=None, with_server=False,
|
||||
raise TuneError("Trial did not complete", trial)
|
||||
|
||||
return runner.get_trials()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import yaml
|
||||
args = parser.parse_args(sys.argv[1:])
|
||||
with open(args.config_file) as f:
|
||||
experiments = yaml.load(f)
|
||||
run_experiments(
|
||||
experiments, _make_scheduler(args), with_server=args.server,
|
||||
server_port=args.server_port, redis_address=args.redis_address,
|
||||
num_cpus=args.num_cpus, num_gpus=args.num_gpus)
|
||||
|
||||
Reference in New Issue
Block a user