mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 02:01:24 +08:00
[tune] Cross-Framework Compatibility (#2646)
This commit is a first pass at restructuring the Trial execution logic to support running on multiple frameworks.
This commit is contained in:
@@ -0,0 +1,274 @@
|
||||
# coding: utf-8
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import time
|
||||
import traceback
|
||||
import ray
|
||||
from ray.tune.logger import NoopLogger
|
||||
from ray.tune.trial import Trial, Resources, Checkpoint
|
||||
from ray.tune.trial_executor import TrialExecutor
|
||||
|
||||
|
||||
class RayTrialExecutor(TrialExecutor):
|
||||
"""An implemention of TrialExecutor based on Ray."""
|
||||
|
||||
def __init__(self, queue_trials=False):
|
||||
super(RayTrialExecutor, self).__init__(queue_trials)
|
||||
self._running = {} # TODO
|
||||
# Since trial resume after paused should not run
|
||||
# trial.train.remote(), thus no more new remote object id generated.
|
||||
# We use self._paused to store paused trials here.
|
||||
self._paused = {}
|
||||
self._avail_resources = Resources(cpu=0, gpu=0)
|
||||
self._committed_resources = Resources(cpu=0, gpu=0)
|
||||
self._resources_initialized = False
|
||||
|
||||
def _setup_runner(self, trial):
|
||||
cls = ray.remote(
|
||||
num_cpus=trial.resources.cpu,
|
||||
num_gpus=trial.resources.gpu)(trial._get_trainable_cls())
|
||||
|
||||
trial.init_logger()
|
||||
remote_logdir = trial.logdir
|
||||
|
||||
def logger_creator(config):
|
||||
# Set the working dir in the remote process, for user file writes
|
||||
if not os.path.exists(remote_logdir):
|
||||
os.makedirs(remote_logdir)
|
||||
os.chdir(remote_logdir)
|
||||
return NoopLogger(config, remote_logdir)
|
||||
|
||||
# Logging for trials is handled centrally by TrialRunner, so
|
||||
# configure the remote runner to use a noop-logger.
|
||||
return cls.remote(config=trial.config, logger_creator=logger_creator)
|
||||
|
||||
def _train(self, trial):
|
||||
"""Start one iteration of training and save remote id."""
|
||||
|
||||
assert trial.status == Trial.RUNNING, trial.status
|
||||
remote = trial.runner.train.remote()
|
||||
self._running[remote] = trial
|
||||
|
||||
def _start_trial(self, trial, checkpoint=None):
|
||||
prior_status = trial.status
|
||||
trial.status = Trial.RUNNING
|
||||
trial.runner = self._setup_runner(trial)
|
||||
if not self.restore(trial, checkpoint):
|
||||
return
|
||||
if prior_status == Trial.PAUSED:
|
||||
# If prev status is PAUSED, self._paused stores its remote_id.
|
||||
remote_id = self._find_item(self._paused, trial)[0]
|
||||
self._paused.pop(remote_id)
|
||||
self._running[remote_id] = trial
|
||||
else:
|
||||
self._train(trial)
|
||||
|
||||
def _stop_trial(self, trial, error=False, error_msg=None,
|
||||
stop_logger=True):
|
||||
"""Stops this trial.
|
||||
|
||||
Stops this trial, releasing all allocating resources. If stopping the
|
||||
trial fails, the run will be marked as terminated in error, but no
|
||||
exception will be thrown.
|
||||
|
||||
Args:
|
||||
error (bool): Whether to mark this trial as terminated in error.
|
||||
error_msg (str): Optional error message.
|
||||
stop_logger (bool): Whether to shut down the trial logger.
|
||||
"""
|
||||
|
||||
if error:
|
||||
trial.status = Trial.ERROR
|
||||
else:
|
||||
trial.status = Trial.TERMINATED
|
||||
|
||||
try:
|
||||
trial.write_error_log(error_msg)
|
||||
if hasattr(trial, 'runner') and trial.runner:
|
||||
stop_tasks = []
|
||||
stop_tasks.append(trial.runner.stop.remote())
|
||||
stop_tasks.append(trial.runner.__ray_terminate__.remote())
|
||||
# TODO(ekl) seems like wait hangs when killing actors
|
||||
_, unfinished = ray.wait(
|
||||
stop_tasks, num_returns=2, timeout=250)
|
||||
except Exception:
|
||||
print("Error stopping runner:", traceback.format_exc())
|
||||
trial.status = Trial.ERROR
|
||||
finally:
|
||||
trial.runner = None
|
||||
|
||||
if stop_logger:
|
||||
trial.close_logger()
|
||||
|
||||
def start_trial(self, trial, checkpoint_obj=None):
|
||||
"""Starts the trial."""
|
||||
|
||||
self._commit_resources(trial.resources)
|
||||
try:
|
||||
self._start_trial(trial, checkpoint_obj)
|
||||
except Exception:
|
||||
error_msg = traceback.format_exc()
|
||||
print("Error starting runner, retrying:", error_msg)
|
||||
time.sleep(2)
|
||||
self._stop_trial(trial, error=True, error_msg=error_msg)
|
||||
try:
|
||||
self._start_trial(trial)
|
||||
except Exception:
|
||||
error_msg = traceback.format_exc()
|
||||
print("Error starting runner, abort:", error_msg)
|
||||
self._stop_trial(trial, error=True, error_msg=error_msg)
|
||||
# note that we don't return the resources, since they may
|
||||
# have been lost
|
||||
|
||||
def _find_item(self, dictionary, item):
|
||||
out = [rid for rid, t in dictionary.items() if t is item]
|
||||
return out
|
||||
|
||||
def stop_trial(self, trial, error=False, error_msg=None, stop_logger=True):
|
||||
"""Only returns resources if resources allocated."""
|
||||
prior_status = trial.status
|
||||
self._stop_trial(
|
||||
trial, error=error, error_msg=error_msg, stop_logger=stop_logger)
|
||||
if prior_status == Trial.RUNNING:
|
||||
self._return_resources(trial.resources)
|
||||
out = self._find_item(self._running, trial)
|
||||
for result_id in out:
|
||||
self._running.pop(result_id)
|
||||
|
||||
def continue_training(self, trial):
|
||||
"""Continues the training of this trial."""
|
||||
|
||||
self._train(trial)
|
||||
|
||||
def pause_trial(self, trial):
|
||||
"""Pauses the trial."""
|
||||
|
||||
remote_id = self._find_item(self._running, trial)[0]
|
||||
self._paused[remote_id] = trial
|
||||
super(RayTrialExecutor, self).pause_trial(trial)
|
||||
|
||||
def get_running_trials(self):
|
||||
"""Returns the running trials."""
|
||||
|
||||
return list(self._running.values())
|
||||
|
||||
def fetch_one_result(self):
|
||||
"""Fetches one result of the running trials."""
|
||||
|
||||
[result_id], _ = ray.wait(list(self._running))
|
||||
trial = self._running.pop(result_id)
|
||||
result = None
|
||||
try:
|
||||
result = ray.get(result_id)
|
||||
except Exception as e:
|
||||
print("fetch_one_result failed:", traceback.format_exc())
|
||||
|
||||
return trial, result
|
||||
|
||||
def _commit_resources(self, resources):
|
||||
self._committed_resources = Resources(
|
||||
self._committed_resources.cpu + resources.cpu_total(),
|
||||
self._committed_resources.gpu + resources.gpu_total())
|
||||
|
||||
def _return_resources(self, resources):
|
||||
self._committed_resources = Resources(
|
||||
self._committed_resources.cpu - resources.cpu_total(),
|
||||
self._committed_resources.gpu - resources.gpu_total())
|
||||
assert self._committed_resources.cpu >= 0
|
||||
assert self._committed_resources.gpu >= 0
|
||||
|
||||
def _update_avail_resources(self):
|
||||
clients = ray.global_state.client_table()
|
||||
if ray.worker.global_worker.use_raylet:
|
||||
# TODO(rliaw): Remove once raylet flag is swapped
|
||||
num_cpus = sum(cl['Resources']['CPU'] for cl in clients)
|
||||
num_gpus = sum(cl['Resources'].get('GPU', 0) for cl in clients)
|
||||
else:
|
||||
local_schedulers = [
|
||||
entry for client in clients.values() for entry in client
|
||||
if (entry['ClientType'] == 'local_scheduler'
|
||||
and not entry['Deleted'])
|
||||
]
|
||||
num_cpus = sum(ls['CPU'] for ls in local_schedulers)
|
||||
num_gpus = sum(ls.get('GPU', 0) for ls in local_schedulers)
|
||||
self._avail_resources = Resources(int(num_cpus), int(num_gpus))
|
||||
self._resources_initialized = True
|
||||
|
||||
def has_resources(self, resources):
|
||||
"""Returns whether this runner has at least the specified resources."""
|
||||
|
||||
cpu_avail = self._avail_resources.cpu - self._committed_resources.cpu
|
||||
gpu_avail = self._avail_resources.gpu - self._committed_resources.gpu
|
||||
|
||||
have_space = (resources.cpu_total() <= cpu_avail
|
||||
and resources.gpu_total() <= gpu_avail)
|
||||
|
||||
if have_space:
|
||||
return True
|
||||
|
||||
can_overcommit = self._queue_trials
|
||||
|
||||
if (resources.cpu_total() > 0 and cpu_avail <= 0) or \
|
||||
(resources.gpu_total() > 0 and gpu_avail <= 0):
|
||||
can_overcommit = False # requested resource is already saturated
|
||||
|
||||
if can_overcommit:
|
||||
print("WARNING:tune:allowing trial to start even though the "
|
||||
"cluster does not have enough free resources. Trial actors "
|
||||
"may appear to hang until enough resources are added to the "
|
||||
"cluster (e.g., via autoscaling). You can disable this "
|
||||
"behavior by specifying `queue_trials=False` in "
|
||||
"ray.tune.run_experiments().")
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def debug_string(self):
|
||||
"""Returns a human readable message for printing to the console."""
|
||||
|
||||
if self._resources_initialized:
|
||||
return "Resources requested: {}/{} CPUs, {}/{} GPUs".format(
|
||||
self._committed_resources.cpu, self._avail_resources.cpu,
|
||||
self._committed_resources.gpu, self._avail_resources.gpu)
|
||||
else:
|
||||
return ""
|
||||
|
||||
def on_step_begin(self):
|
||||
"""Before step() called, update the available resources."""
|
||||
|
||||
self._update_avail_resources()
|
||||
|
||||
def save(self, trial, storage=Checkpoint.DISK):
|
||||
"""Saves the trial's state to a checkpoint."""
|
||||
trial._checkpoint.storage = storage
|
||||
if storage == Checkpoint.MEMORY:
|
||||
trial._checkpoint.value = trial.runner.save_to_object.remote()
|
||||
else:
|
||||
trial._checkpoint.value = ray.get(trial.runner.save.remote())
|
||||
return trial._checkpoint.value
|
||||
|
||||
def restore(self, trial, checkpoint=None):
|
||||
"""Restores training state from a given model checkpoint."""
|
||||
if checkpoint is None or checkpoint.value is None:
|
||||
checkpoint = trial._checkpoint
|
||||
if checkpoint is None or checkpoint.value is None:
|
||||
return True
|
||||
if trial.runner is None:
|
||||
print("Unable to restore - no runner")
|
||||
trial.status = Trial.ERROR
|
||||
return False
|
||||
try:
|
||||
value = checkpoint.value
|
||||
if checkpoint.storage == Checkpoint.MEMORY:
|
||||
assert type(value) != Checkpoint, type(value)
|
||||
ray.get(trial.runner.restore_from_object.remote(value))
|
||||
else:
|
||||
ray.get(trial.runner.restore.remote(value))
|
||||
return True
|
||||
except Exception:
|
||||
print("Error restoring runner:", traceback.format_exc())
|
||||
trial.status = Trial.ERROR
|
||||
return False
|
||||
@@ -185,7 +185,7 @@ class HyperBandScheduler(FIFOScheduler):
|
||||
raise Exception("Trial with unexpected status encountered")
|
||||
if bracket.continue_trial(t):
|
||||
if t.status == Trial.PAUSED:
|
||||
t.unpause()
|
||||
trial_runner.trial_executor.unpause_trial(t)
|
||||
elif t.status == Trial.RUNNING:
|
||||
action = TrialScheduler.CONTINUE
|
||||
return action
|
||||
|
||||
@@ -7,8 +7,8 @@ import math
|
||||
import copy
|
||||
|
||||
from ray.tune.error import TuneError
|
||||
from ray.tune.trial import Trial
|
||||
from ray.tune.schedulers.trial_scheduler import FIFOScheduler, TrialScheduler
|
||||
from ray.tune.trial import Trial, Checkpoint
|
||||
from ray.tune.schedulers import FIFOScheduler, TrialScheduler
|
||||
from ray.tune.suggest.variant_generator import format_vars
|
||||
|
||||
# Parameters are transferred from the top PBT_QUANTILE fraction of trials to
|
||||
@@ -187,7 +187,8 @@ class PopulationBasedTraining(FIFOScheduler):
|
||||
lower_quantile, upper_quantile = self._quantiles()
|
||||
|
||||
if trial in upper_quantile:
|
||||
state.last_checkpoint = trial.checkpoint(to_object_store=True)
|
||||
state.last_checkpoint = trial_runner.trial_executor.save(
|
||||
trial, Checkpoint.MEMORY)
|
||||
self._num_checkpoints += 1
|
||||
else:
|
||||
state.last_checkpoint = None # not a top trial
|
||||
@@ -195,7 +196,7 @@ class PopulationBasedTraining(FIFOScheduler):
|
||||
if trial in lower_quantile:
|
||||
trial_to_clone = random.choice(upper_quantile)
|
||||
assert trial is not trial_to_clone
|
||||
self._exploit(trial, trial_to_clone)
|
||||
self._exploit(trial_runner.trial_executor, trial, trial_to_clone)
|
||||
|
||||
for trial in trial_runner.get_trials():
|
||||
if trial.status in [Trial.PENDING, Trial.PAUSED]:
|
||||
@@ -203,7 +204,7 @@ class PopulationBasedTraining(FIFOScheduler):
|
||||
|
||||
return TrialScheduler.CONTINUE
|
||||
|
||||
def _exploit(self, trial, trial_to_clone):
|
||||
def _exploit(self, trial_executor, trial, trial_to_clone):
|
||||
"""Transfers perturbed state from trial_to_clone -> trial."""
|
||||
|
||||
trial_state = self._trial_state[trial]
|
||||
@@ -220,11 +221,12 @@ class PopulationBasedTraining(FIFOScheduler):
|
||||
trial_state.last_score))
|
||||
# TODO(ekl) restarting the trial is expensive. We should implement a
|
||||
# lighter way reset() method that can alter the trial config.
|
||||
trial.stop(stop_logger=False)
|
||||
trial_executor.stop_trial(trial, stop_logger=False)
|
||||
trial.config = new_config
|
||||
trial.experiment_tag = make_experiment_tag(
|
||||
trial_state.orig_tag, new_config, self._hyperparam_mutations)
|
||||
trial.start(new_state.last_checkpoint)
|
||||
trial_executor.start_trial(
|
||||
trial, Checkpoint.from_object(new_state.last_checkpoint))
|
||||
self._num_perturbations += 1
|
||||
# Transfer over the last perturbation time as well
|
||||
trial_state.last_perturbation_time = new_state.last_perturbation_time
|
||||
|
||||
@@ -0,0 +1,60 @@
|
||||
# coding: utf-8
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
|
||||
import ray
|
||||
from ray.rllib import _register_all
|
||||
from ray.tune.ray_trial_executor import RayTrialExecutor
|
||||
from ray.tune.suggest import BasicVariantGenerator
|
||||
from ray.tune.trial import Trial, Checkpoint
|
||||
|
||||
|
||||
class RayTrialExecutorTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.trial_executor = RayTrialExecutor(queue_trials=False)
|
||||
ray.init()
|
||||
|
||||
def tearDown(self):
|
||||
ray.shutdown()
|
||||
_register_all() # re-register the evicted objects
|
||||
|
||||
def _get_trials(self):
|
||||
trials = self.generate_trials({
|
||||
"run": "PPO",
|
||||
"config": {
|
||||
"bar": {
|
||||
"grid_search": [True, False]
|
||||
},
|
||||
"foo": {
|
||||
"grid_search": [1, 2, 3]
|
||||
},
|
||||
},
|
||||
}, "grid_search")
|
||||
return list(trials)
|
||||
|
||||
def testStartStop(self):
|
||||
trial = Trial("__fake")
|
||||
self.trial_executor.start_trial(trial)
|
||||
running = self.trial_executor.get_running_trials()
|
||||
self.assertEqual(1, len(running))
|
||||
self.trial_executor.stop_trial(trial)
|
||||
|
||||
def testSaveRestore(self):
|
||||
trial = Trial("__fake")
|
||||
self.trial_executor.start_trial(trial)
|
||||
self.assertEqual(Trial.RUNNING, trial.status)
|
||||
self.trial_executor.save(trial, Checkpoint.DISK)
|
||||
self.trial_executor.restore(trial)
|
||||
self.trial_executor.stop_trial(trial)
|
||||
self.assertEqual(Trial.TERMINATED, trial.status)
|
||||
|
||||
def generate_trials(self, spec, name):
|
||||
suggester = BasicVariantGenerator({name: spec})
|
||||
return suggester.next_trials()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main(verbosity=2)
|
||||
@@ -11,6 +11,7 @@ from ray.rllib import _register_all
|
||||
|
||||
from ray.tune import Trainable, TuneError
|
||||
from ray.tune import register_env, register_trainable, run_experiments
|
||||
from ray.tune.ray_trial_executor import RayTrialExecutor
|
||||
from ray.tune.schedulers import TrialScheduler, FIFOScheduler
|
||||
from ray.tune.registry import _global_registry, TRAINABLE_CLASS
|
||||
from ray.tune.result import DEFAULT_RESULTS_DIR, TIMESTEPS_TOTAL, DONE
|
||||
@@ -667,12 +668,13 @@ class TrialRunnerTest(unittest.TestCase):
|
||||
def testTrialStatus(self):
|
||||
ray.init()
|
||||
trial = Trial("__fake")
|
||||
trial_executor = RayTrialExecutor()
|
||||
self.assertEqual(trial.status, Trial.PENDING)
|
||||
trial.start()
|
||||
trial_executor.start_trial(trial)
|
||||
self.assertEqual(trial.status, Trial.RUNNING)
|
||||
trial.stop()
|
||||
trial_executor.stop_trial(trial)
|
||||
self.assertEqual(trial.status, Trial.TERMINATED)
|
||||
trial.stop(error=True)
|
||||
trial_executor.stop_trial(trial, error=True)
|
||||
self.assertEqual(trial.status, Trial.ERROR)
|
||||
|
||||
def testExperimentTagTruncation(self):
|
||||
@@ -681,6 +683,7 @@ class TrialRunnerTest(unittest.TestCase):
|
||||
def train(config, reporter):
|
||||
reporter(timesteps_total=1)
|
||||
|
||||
trial_executor = RayTrialExecutor()
|
||||
register_trainable("f1", train)
|
||||
|
||||
experiments = {
|
||||
@@ -697,16 +700,17 @@ class TrialRunnerTest(unittest.TestCase):
|
||||
trial_generator = BasicVariantGenerator()
|
||||
trial_generator.add_configurations({name: spec})
|
||||
for trial in trial_generator.next_trials():
|
||||
trial.start()
|
||||
trial_executor.start_trial(trial)
|
||||
self.assertLessEqual(len(trial.logdir), 200)
|
||||
trial.stop()
|
||||
trial_executor.stop_trial(trial)
|
||||
|
||||
def testTrialErrorOnStart(self):
|
||||
ray.init()
|
||||
trial_executor = RayTrialExecutor()
|
||||
_global_registry.register(TRAINABLE_CLASS, "asdf", None)
|
||||
trial = Trial("asdf", resources=Resources(1, 0))
|
||||
try:
|
||||
trial.start()
|
||||
trial_executor.start_trial(trial)
|
||||
except Exception as e:
|
||||
self.assertIn("a class", str(e))
|
||||
|
||||
@@ -901,8 +905,7 @@ class TrialRunnerTest(unittest.TestCase):
|
||||
runner.step()
|
||||
self.assertEqual(trials[0].status, Trial.RUNNING)
|
||||
self.assertEqual(ray.get(trials[0].runner.set_info.remote(1)), 1)
|
||||
|
||||
path = trials[0].checkpoint()
|
||||
path = runner.trial_executor.save(trials[0])
|
||||
kwargs["restore_path"] = path
|
||||
|
||||
runner.add_trial(Trial("__fake", **kwargs))
|
||||
@@ -956,10 +959,10 @@ class TrialRunnerTest(unittest.TestCase):
|
||||
|
||||
self.assertEqual(ray.get(trials[0].runner.set_info.remote(1)), 1)
|
||||
|
||||
trials[0].pause()
|
||||
runner.trial_executor.pause_trial(trials[0])
|
||||
self.assertEqual(trials[0].status, Trial.PAUSED)
|
||||
|
||||
trials[0].resume()
|
||||
runner.trial_executor.resume_trial(trials[0])
|
||||
self.assertEqual(trials[0].status, Trial.RUNNING)
|
||||
|
||||
runner.step()
|
||||
@@ -969,6 +972,36 @@ class TrialRunnerTest(unittest.TestCase):
|
||||
runner.step()
|
||||
self.assertEqual(trials[0].status, Trial.TERMINATED)
|
||||
|
||||
def testStepHook(self):
|
||||
ray.init(num_cpus=4, num_gpus=2)
|
||||
runner = TrialRunner(BasicVariantGenerator())
|
||||
|
||||
def on_step_begin(self):
|
||||
self._update_avail_resources()
|
||||
cnt = self.pre_step if hasattr(self, 'pre_step') else 0
|
||||
setattr(self, 'pre_step', cnt + 1)
|
||||
|
||||
def on_step_end(self):
|
||||
cnt = self.pre_step if hasattr(self, 'post_step') else 0
|
||||
setattr(self, 'post_step', 1 + cnt)
|
||||
|
||||
import types
|
||||
runner.trial_executor.on_step_begin = types.MethodType(
|
||||
on_step_begin, runner.trial_executor)
|
||||
runner.trial_executor.on_step_end = types.MethodType(
|
||||
on_step_end, runner.trial_executor)
|
||||
|
||||
kwargs = {
|
||||
"stopping_criterion": {
|
||||
"training_iteration": 5
|
||||
},
|
||||
"resources": Resources(cpu=1, gpu=1),
|
||||
}
|
||||
runner.add_trial(Trial("__fake", **kwargs))
|
||||
runner.step()
|
||||
self.assertEqual(runner.trial_executor.pre_step, 1)
|
||||
self.assertEqual(runner.trial_executor.post_step, 1)
|
||||
|
||||
def testStopTrial(self):
|
||||
ray.init(num_cpus=4, num_gpus=2)
|
||||
runner = TrialRunner(BasicVariantGenerator())
|
||||
|
||||
@@ -11,7 +11,8 @@ from ray.tune.schedulers import (HyperBandScheduler, AsyncHyperBandScheduler,
|
||||
PopulationBasedTraining, MedianStoppingRule,
|
||||
TrialScheduler)
|
||||
from ray.tune.schedulers.pbt import explore
|
||||
from ray.tune.trial import Trial, Resources
|
||||
from ray.tune.trial import Trial, Resources, Checkpoint
|
||||
from ray.tune.trial_executor import TrialExecutor
|
||||
|
||||
from ray.rllib import _register_all
|
||||
_register_all()
|
||||
@@ -150,10 +151,29 @@ class EarlyStoppingSuite(unittest.TestCase):
|
||||
TrialScheduler.CONTINUE)
|
||||
|
||||
|
||||
class _MockTrialExecutor(TrialExecutor):
|
||||
def start_trial(self, trial, checkpoint_obj=None):
|
||||
trial.logger_running = True
|
||||
trial.restored_checkpoint = checkpoint_obj.value
|
||||
trial.status = Trial.RUNNING
|
||||
|
||||
def stop_trial(self, trial, error=False, error_msg=None, stop_logger=True):
|
||||
trial.status = Trial.ERROR if error else Trial.TERMINATED
|
||||
if stop_logger:
|
||||
trial.logger_running = False
|
||||
|
||||
def restore(self, trial, checkpoint=None):
|
||||
pass
|
||||
|
||||
def save(self, trial, type=Checkpoint.DISK):
|
||||
return trial.trainable_name
|
||||
|
||||
|
||||
class _MockTrialRunner():
|
||||
def __init__(self, scheduler):
|
||||
self._scheduler_alg = scheduler
|
||||
self.trials = []
|
||||
self.trial_executor = _MockTrialExecutor()
|
||||
|
||||
def process_action(self, trial, action):
|
||||
if action == TrialScheduler.CONTINUE:
|
||||
@@ -161,7 +181,7 @@ class _MockTrialRunner():
|
||||
elif action == TrialScheduler.PAUSE:
|
||||
self._pause_trial(trial)
|
||||
elif action == TrialScheduler.STOP:
|
||||
trial.stop()
|
||||
self.trial_executor.stop_trial(trial)
|
||||
|
||||
def stop_trial(self, trial):
|
||||
if trial.status in [Trial.ERROR, Trial.TERMINATED]:
|
||||
@@ -169,7 +189,6 @@ class _MockTrialRunner():
|
||||
elif trial.status in [Trial.PENDING, Trial.PAUSED]:
|
||||
self._scheduler_alg.on_trial_remove(self, trial)
|
||||
else:
|
||||
|
||||
self._scheduler_alg.on_trial_complete(self, trial, result(100, 10))
|
||||
|
||||
def add_trial(self, trial):
|
||||
@@ -198,7 +217,7 @@ class HyperbandSuite(unittest.TestCase):
|
||||
_register_all() # re-register the evicted objects
|
||||
|
||||
def schedulerSetup(self, num_trials):
|
||||
"""Setup a scheduler and Runner with max Iter = 9
|
||||
"""Setup a scheduler and Runner with max Iter = 9.
|
||||
|
||||
Bracketing is placed as follows:
|
||||
(5, 81);
|
||||
@@ -214,7 +233,7 @@ class HyperbandSuite(unittest.TestCase):
|
||||
return sched, runner
|
||||
|
||||
def default_statistics(self):
|
||||
"""Default statistics for HyperBand"""
|
||||
"""Default statistics for HyperBand."""
|
||||
sched = HyperBandScheduler()
|
||||
res = {
|
||||
str(s): {
|
||||
@@ -232,8 +251,8 @@ class HyperbandSuite(unittest.TestCase):
|
||||
return int(np.ceil(n / sched._eta))
|
||||
|
||||
def basicSetup(self):
|
||||
"""Setup and verify full band.
|
||||
"""
|
||||
"""Setup and verify full band."""
|
||||
|
||||
stats = self.default_statistics()
|
||||
sched, _ = self.schedulerSetup(stats["max_trials"])
|
||||
|
||||
@@ -299,6 +318,7 @@ class HyperbandSuite(unittest.TestCase):
|
||||
def testSuccessiveHalving(self):
|
||||
"""Setup full band, then iterate through last bracket (n=81)
|
||||
to make sure successive halving is correct."""
|
||||
|
||||
stats = self.default_statistics()
|
||||
sched, mock_runner = self.schedulerSetup(stats["max_trials"])
|
||||
big_bracket = sched._state["bracket"]
|
||||
@@ -360,6 +380,7 @@ class HyperbandSuite(unittest.TestCase):
|
||||
|
||||
def testTrialErrored(self):
|
||||
"""If a trial errored, make sure successive halving still happens"""
|
||||
|
||||
stats = self.default_statistics()
|
||||
trial_count = stats[str(0)]["n"] + 3
|
||||
sched, mock_runner = self.schedulerSetup(trial_count)
|
||||
@@ -510,7 +531,7 @@ class HyperbandSuite(unittest.TestCase):
|
||||
self.assertLess(current_length, 27)
|
||||
|
||||
def testRemove(self):
|
||||
"""Test with 4: start 1, remove 1 pending, add 2, remove 1 pending"""
|
||||
"""Test with 4: start 1, remove 1 pending, add 2, remove 1 pending."""
|
||||
sched, runner = self.schedulerSetup(4)
|
||||
trials = sorted(list(sched._trial_info), key=lambda t: t.trial_id)
|
||||
runner._launch_trial(trials[0])
|
||||
@@ -542,17 +563,6 @@ class _MockTrial(Trial):
|
||||
self.restored_checkpoint = None
|
||||
self.resources = Resources(1, 0)
|
||||
|
||||
def checkpoint(self, to_object_store=False):
|
||||
return self.trainable_name
|
||||
|
||||
def start(self, checkpoint=None):
|
||||
self.logger_running = True
|
||||
self.restored_checkpoint = checkpoint
|
||||
|
||||
def stop(self, stop_logger=False):
|
||||
if stop_logger:
|
||||
self.logger_running = False
|
||||
|
||||
|
||||
class PopulationBasedTestingSuite(unittest.TestCase):
|
||||
def setUp(self):
|
||||
|
||||
@@ -88,7 +88,7 @@ class TuneServerSuite(unittest.TestCase):
|
||||
self.assertEqual(len(all_trials), 2)
|
||||
|
||||
def testStopTrial(self):
|
||||
"""Check if Stop Trial works"""
|
||||
"""Check if Stop Trial works."""
|
||||
runner, client = self.basicSetup()
|
||||
for i in range(2):
|
||||
runner.step()
|
||||
|
||||
+54
-170
@@ -2,16 +2,15 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import tempfile
|
||||
from collections import namedtuple
|
||||
from datetime import datetime
|
||||
import tempfile
|
||||
import time
|
||||
import traceback
|
||||
import ray
|
||||
import os
|
||||
|
||||
from ray.tune import TuneError
|
||||
from ray.tune.logger import NoopLogger, UnifiedLogger, pretty_print
|
||||
from ray.tune.logger import pretty_print, UnifiedLogger
|
||||
# NOTE(rkn): We import ray.tune.registry here instead of importing the names we
|
||||
# need because there are cyclic imports that may cause specific names to not
|
||||
# have been defined yet. See https://github.com/ray-project/ray/issues/1716.
|
||||
@@ -39,7 +38,9 @@ class Resources(
|
||||
launch additional Ray actors that use CPUs.
|
||||
extra_gpu (int): Extra GPUs to reserve in case the trial needs to
|
||||
launch additional Ray actors that use GPUs.
|
||||
|
||||
"""
|
||||
|
||||
__slots__ = ()
|
||||
|
||||
def __new__(cls, cpu, gpu, extra_cpu=0, extra_gpu=0):
|
||||
@@ -62,6 +63,30 @@ def has_trainable(trainable_name):
|
||||
ray.tune.registry.TRAINABLE_CLASS, trainable_name)
|
||||
|
||||
|
||||
class Checkpoint(object):
|
||||
"""Describes a checkpoint of trial state.
|
||||
|
||||
Checkpoint may be saved in different storage.
|
||||
|
||||
Attributes:
|
||||
storage (str): Storage type.
|
||||
value (str): If storage==MEMORY,value is a Python object.
|
||||
If storage==DISK,value is a path points to the checkpoint in disk.
|
||||
"""
|
||||
|
||||
MEMORY = "memory"
|
||||
DISK = "disk"
|
||||
|
||||
def __init__(self, storage, value):
|
||||
self.storage = storage
|
||||
self.value = value
|
||||
|
||||
@staticmethod
|
||||
def from_object(value=None):
|
||||
"""Creates a checkpoint from a Python object."""
|
||||
return Checkpoint(Checkpoint.MEMORY, value)
|
||||
|
||||
|
||||
class Trial(object):
|
||||
"""A trial object holds the state for one model training run.
|
||||
|
||||
@@ -110,16 +135,15 @@ class Trial(object):
|
||||
resources
|
||||
or self._get_trainable_cls().default_resource_request(self.config))
|
||||
self.stopping_criterion = stopping_criterion or {}
|
||||
self.checkpoint_freq = checkpoint_freq
|
||||
self.upload_dir = upload_dir
|
||||
self.verbose = True
|
||||
self.max_failures = max_failures
|
||||
|
||||
# Local trial state that is updated during the run
|
||||
self.last_result = None
|
||||
self._checkpoint_path = restore_path
|
||||
self._checkpoint_obj = None
|
||||
self.runner = None
|
||||
self.checkpoint_freq = checkpoint_freq
|
||||
self._checkpoint = Checkpoint(
|
||||
storage=Checkpoint.DISK, value=restore_path)
|
||||
self.status = Trial.PENDING
|
||||
self.location = None
|
||||
self.logdir = None
|
||||
@@ -136,96 +160,34 @@ class Trial(object):
|
||||
def generate_id(cls):
|
||||
return binary_to_hex(random_string())[:8]
|
||||
|
||||
def start(self, checkpoint_obj=None):
|
||||
"""Starts this trial.
|
||||
def init_logger(self):
|
||||
"""Init logger."""
|
||||
|
||||
If an error is encountered when starting the trial, an exception will
|
||||
be thrown.
|
||||
if not self.result_logger:
|
||||
if not os.path.exists(self.local_dir):
|
||||
os.makedirs(self.local_dir)
|
||||
self.logdir = tempfile.mkdtemp(
|
||||
prefix="{}_{}".format(
|
||||
str(self)[:MAX_LEN_IDENTIFIER], date_str()),
|
||||
dir=self.local_dir)
|
||||
self.result_logger = UnifiedLogger(self.config, self.logdir,
|
||||
self.upload_dir)
|
||||
|
||||
Args:
|
||||
checkpoint_obj (obj): Optional checkpoint to resume from.
|
||||
"""
|
||||
def close_logger(self):
|
||||
"""Close logger."""
|
||||
|
||||
self._setup_runner()
|
||||
if checkpoint_obj:
|
||||
self.restore_from_obj(checkpoint_obj)
|
||||
elif self._checkpoint_path:
|
||||
self.restore_from_path(self._checkpoint_path)
|
||||
elif self._checkpoint_obj:
|
||||
self.restore_from_obj(self._checkpoint_obj)
|
||||
|
||||
def stop(self, error=False, error_msg=None, stop_logger=True):
|
||||
"""Stops this trial.
|
||||
|
||||
Stops this trial, releasing all allocating resources. If stopping the
|
||||
trial fails, the run will be marked as terminated in error, but no
|
||||
exception will be thrown.
|
||||
|
||||
Args:
|
||||
error (bool): Whether to mark this trial as terminated in error.
|
||||
error_msg (str): Optional error message.
|
||||
stop_logger (bool): Whether to shut down the trial logger.
|
||||
"""
|
||||
|
||||
if error:
|
||||
self.status = Trial.ERROR
|
||||
else:
|
||||
self.status = Trial.TERMINATED
|
||||
|
||||
try:
|
||||
if error_msg and self.logdir:
|
||||
self.num_failures += 1
|
||||
error_file = os.path.join(self.logdir,
|
||||
"error_{}.txt".format(date_str()))
|
||||
with open(error_file, "w") as f:
|
||||
f.write(error_msg)
|
||||
self.error_file = error_file
|
||||
if self.runner:
|
||||
stop_tasks = []
|
||||
stop_tasks.append(self.runner.stop.remote())
|
||||
stop_tasks.append(self.runner.__ray_terminate__.remote())
|
||||
# TODO(ekl) seems like wait hangs when killing actors
|
||||
_, unfinished = ray.wait(
|
||||
stop_tasks, num_returns=2, timeout=250)
|
||||
except Exception:
|
||||
print("Error stopping runner:", traceback.format_exc())
|
||||
self.status = Trial.ERROR
|
||||
finally:
|
||||
self.runner = None
|
||||
|
||||
if stop_logger and self.result_logger:
|
||||
if self.result_logger:
|
||||
self.result_logger.close()
|
||||
self.result_logger = None
|
||||
|
||||
def pause(self):
|
||||
"""We want to release resources (specifically GPUs) when pausing an
|
||||
experiment. This results in a state similar to TERMINATED."""
|
||||
|
||||
assert self.status == Trial.RUNNING, self.status
|
||||
try:
|
||||
self.checkpoint(to_object_store=True)
|
||||
self.stop(stop_logger=False)
|
||||
self.status = Trial.PAUSED
|
||||
except Exception:
|
||||
print("Error pausing runner:", traceback.format_exc())
|
||||
self.status = Trial.ERROR
|
||||
|
||||
def unpause(self):
|
||||
"""Sets PAUSED trial to pending to allow scheduler to start."""
|
||||
assert self.status == Trial.PAUSED, self.status
|
||||
self.status = Trial.PENDING
|
||||
|
||||
def resume(self):
|
||||
"""Resume PAUSED trials. This is a blocking call."""
|
||||
|
||||
assert self.status == Trial.PAUSED, self.status
|
||||
self.start()
|
||||
|
||||
def train_remote(self):
|
||||
"""Returns Ray future for one iteration of training."""
|
||||
|
||||
assert self.status == Trial.RUNNING, self.status
|
||||
return self.runner.train.remote()
|
||||
def write_error_log(self, error_msg):
|
||||
if error_msg and self.logdir:
|
||||
self.num_failures += 1 # may be moved to outer scope?
|
||||
error_file = os.path.join(self.logdir,
|
||||
"error_{}.txt".format(date_str()))
|
||||
with open(error_file, "w") as f:
|
||||
f.write(error_msg)
|
||||
self.error_file = error_file
|
||||
|
||||
def should_stop(self, result):
|
||||
"""Whether the given result meets this trial's stopping criteria."""
|
||||
@@ -294,57 +256,7 @@ class Trial(object):
|
||||
if self.error_file else "")
|
||||
|
||||
def has_checkpoint(self):
|
||||
return self._checkpoint_path is not None or \
|
||||
self._checkpoint_obj is not None
|
||||
|
||||
def checkpoint(self, to_object_store=False):
|
||||
"""Checkpoints the state of this trial.
|
||||
|
||||
Args:
|
||||
to_object_store (bool): Whether to save to the Ray object store
|
||||
(async) vs a path on local disk (sync).
|
||||
"""
|
||||
|
||||
obj = None
|
||||
path = None
|
||||
if to_object_store:
|
||||
obj = self.runner.save_to_object.remote()
|
||||
else:
|
||||
path = ray.get(self.runner.save.remote())
|
||||
self._checkpoint_path = path
|
||||
self._checkpoint_obj = obj
|
||||
|
||||
if self.verbose:
|
||||
print("Saved checkpoint for {} to {}".format(self, path or obj))
|
||||
return path or obj
|
||||
|
||||
def restore_from_path(self, path):
|
||||
"""Restores runner state from specified path.
|
||||
|
||||
Args:
|
||||
path (str): A path where state will be restored.
|
||||
"""
|
||||
|
||||
if self.runner is None:
|
||||
print("Unable to restore - no runner")
|
||||
else:
|
||||
try:
|
||||
ray.get(self.runner.restore.remote(path))
|
||||
except Exception:
|
||||
print("Error restoring runner:", traceback.format_exc())
|
||||
self.status = Trial.ERROR
|
||||
|
||||
def restore_from_obj(self, obj):
|
||||
"""Restores runner state from the specified object."""
|
||||
|
||||
if self.runner is None:
|
||||
print("Unable to restore - no runner")
|
||||
else:
|
||||
try:
|
||||
ray.get(self.runner.restore_from_object.remote(obj))
|
||||
except Exception:
|
||||
print("Error restoring runner:", traceback.format_exc())
|
||||
self.status = Trial.ERROR
|
||||
return self._checkpoint.value is not None
|
||||
|
||||
def update_last_result(self, result, terminate=False):
|
||||
if terminate:
|
||||
@@ -357,34 +269,6 @@ class Trial(object):
|
||||
self.last_result = result
|
||||
self.result_logger.on_result(self.last_result)
|
||||
|
||||
def _setup_runner(self):
|
||||
self.status = Trial.RUNNING
|
||||
cls = ray.remote(
|
||||
num_cpus=self.resources.cpu,
|
||||
num_gpus=self.resources.gpu)(self._get_trainable_cls())
|
||||
if not self.result_logger:
|
||||
if not os.path.exists(self.local_dir):
|
||||
os.makedirs(self.local_dir)
|
||||
self.logdir = tempfile.mkdtemp(
|
||||
prefix="{}_{}".format(
|
||||
str(self)[:MAX_LEN_IDENTIFIER], date_str()),
|
||||
dir=self.local_dir)
|
||||
self.result_logger = UnifiedLogger(self.config, self.logdir,
|
||||
self.upload_dir)
|
||||
remote_logdir = self.logdir
|
||||
|
||||
def logger_creator(config):
|
||||
# Set the working dir in the remote process, for user file writes
|
||||
if not os.path.exists(remote_logdir):
|
||||
os.makedirs(remote_logdir)
|
||||
os.chdir(remote_logdir)
|
||||
return NoopLogger(config, remote_logdir)
|
||||
|
||||
# Logging for trials is handled centrally by TrialRunner, so
|
||||
# configure the remote runner to use a noop-logger.
|
||||
self.runner = cls.remote(
|
||||
config=self.config, logger_creator=logger_creator)
|
||||
|
||||
def _get_trainable_cls(self):
|
||||
return ray.tune.registry._global_registry.get(
|
||||
ray.tune.registry.TRAINABLE_CLASS, self.trainable_name)
|
||||
|
||||
@@ -0,0 +1,166 @@
|
||||
# coding: utf-8
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import traceback
|
||||
|
||||
from ray.tune.trial import Trial, Checkpoint
|
||||
|
||||
|
||||
class TrialExecutor(object):
|
||||
"""Manages platform-specific details such as resource handling
|
||||
and starting/stopping trials.
|
||||
"""
|
||||
|
||||
def __init__(self, queue_trials=False):
|
||||
"""Initializes a new TrialExecutor.
|
||||
|
||||
Args:
|
||||
queue_trials (bool): Whether to queue trials when the cluster does
|
||||
not currently have enough resources to launch one. This should
|
||||
be set to True when running on an autoscaling cluster to enable
|
||||
automatic scale-up.
|
||||
"""
|
||||
self._queue_trials = queue_trials
|
||||
|
||||
def has_resources(self, resources):
|
||||
"""Returns whether this runner has at least the specified resources."""
|
||||
raise NotImplementedError("Subclasses of TrialExecutor must provide "
|
||||
"has_resources() method")
|
||||
|
||||
def start_trial(self, trial, checkpoint=None):
|
||||
"""Starts the trial restoring from checkpoint if checkpoint != None.
|
||||
|
||||
If an error is encountered when starting the trial, an exception will
|
||||
be thrown.
|
||||
|
||||
Args:
|
||||
checkpoint(Checkpoint): A Python object or path storing the state
|
||||
of trial.
|
||||
"""
|
||||
raise NotImplementedError("Subclasses of TrialExecutor must provide "
|
||||
"start_trial() method")
|
||||
|
||||
def stop_trial(self, trial, error=False, error_msg=None, stop_logger=True):
|
||||
"""Stops the trial.
|
||||
|
||||
Stops this trial, releasing all allocating resources.
|
||||
If stopping the trial fails, the run will be marked as terminated
|
||||
in error, but no exception will be thrown.
|
||||
|
||||
Args:
|
||||
error (bool): Whether to mark this trial as terminated in error.
|
||||
error_msg (str): Optional error message.
|
||||
stop_logger (bool): Whether to shut down the trial logger.
|
||||
"""
|
||||
raise NotImplementedError("Subclasses of TrialExecutor must provide "
|
||||
"stop_trial() method")
|
||||
|
||||
def restart_trial(self, trial, error_msg=None):
|
||||
"""Restarts the trial.
|
||||
|
||||
The state of the trial should restore from the last checkpoint.
|
||||
|
||||
Args:
|
||||
error_msg (str): Optional error message.
|
||||
"""
|
||||
try:
|
||||
print("Attempting to recover trial state from last checkpoint")
|
||||
self.stop_trial(
|
||||
trial, error=True, error_msg=error_msg, stop_logger=False)
|
||||
trial.result_logger.flush()
|
||||
self.start_trial(trial)
|
||||
except Exception:
|
||||
error_msg = traceback.format_exc()
|
||||
print("Error recovering trial from checkpoint, abort:", error_msg)
|
||||
self.stop_trial(trial, error=True, error_msg=error_msg)
|
||||
|
||||
def continue_training(self, trial):
|
||||
"""Continues the training of this trial."""
|
||||
pass
|
||||
|
||||
def pause_trial(self, trial):
|
||||
"""Pauses the trial.
|
||||
|
||||
We want to release resources (specifically GPUs) when pausing an
|
||||
experiment. This results in PAUSED state that similar to TERMINATED.
|
||||
"""
|
||||
assert trial.status == Trial.RUNNING, trial.status
|
||||
try:
|
||||
self.save(trial, Checkpoint.MEMORY)
|
||||
self.stop_trial(trial, stop_logger=False)
|
||||
trial.status = Trial.PAUSED
|
||||
except Exception:
|
||||
print("Error pausing runner:", traceback.format_exc())
|
||||
trial.status = Trial.ERROR
|
||||
|
||||
def unpause_trial(self, trial):
|
||||
"""Sets PAUSED trial to pending to allow scheduler to start."""
|
||||
assert trial.status == Trial.PAUSED, trial.status
|
||||
trial.status = Trial.PENDING
|
||||
|
||||
def resume_trial(self, trial):
|
||||
"""Resumes PAUSED trials. This is a blocking call."""
|
||||
|
||||
assert trial.status == Trial.PAUSED, trial.status
|
||||
self.start_trial(trial)
|
||||
|
||||
def get_running_trials(self):
|
||||
"""Returns all running trials."""
|
||||
raise NotImplementedError("Subclasses of TrialExecutor must provide "
|
||||
"get_running_trials() method")
|
||||
|
||||
def on_step_begin(self):
|
||||
"""A hook called before running one step of the trial event loop."""
|
||||
pass
|
||||
|
||||
def on_step_end(self):
|
||||
"""A hook called after running one step of the trial event loop."""
|
||||
pass
|
||||
|
||||
def fetch_one_result(self):
|
||||
"""Fetches one result from running trials.
|
||||
|
||||
It's a blocking call waits until one result is ready.
|
||||
|
||||
Return:
|
||||
A tuple of (trial, result). If fetch result failed,
|
||||
return (trial, None) other than raise Exception.
|
||||
"""
|
||||
raise NotImplementedError("Subclasses of TrialExecutor must provide "
|
||||
"fetch_one_result() method")
|
||||
|
||||
def debug_string(self):
|
||||
"""Returns a human readable message for printing to the console."""
|
||||
pass
|
||||
|
||||
def restore(self, trial, checkpoint=None):
|
||||
"""Restores training state from a checkpoint.
|
||||
|
||||
If checkpoint is None, try to restore from trial._checkpoint.
|
||||
If restoring fails, the trial status will be set to ERROR.
|
||||
|
||||
Args:
|
||||
trial (Trial): Trial to be restored.
|
||||
checkpoint (Checkpoint): Checkpoint to restore from.
|
||||
|
||||
Return:
|
||||
False if error occurred, otherwise return True.
|
||||
"""
|
||||
raise NotImplementedError("Subclasses of TrialExecutor must provide "
|
||||
"restore() method")
|
||||
|
||||
def save(self, trial, storage=Checkpoint.DISK):
|
||||
"""Saves training state of this trial to a checkpoint.
|
||||
|
||||
Args:
|
||||
trial (Trial): The state of this trial to be saved.
|
||||
storage (str): Where to store the checkpoint. Defaults to DISK.
|
||||
|
||||
Return:
|
||||
A Python object if storage==Checkpoint.MEMORY otherwise
|
||||
a path to the checkpoint.
|
||||
"""
|
||||
raise NotImplementedError("Subclasses of TrialExecutor must provide "
|
||||
"save() method")
|
||||
+34
-130
@@ -4,15 +4,15 @@ from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import os
|
||||
import ray
|
||||
import time
|
||||
import traceback
|
||||
|
||||
from ray.tune import TuneError
|
||||
from ray.tune.ray_trial_executor import RayTrialExecutor
|
||||
from ray.tune.result import TIME_THIS_ITER_S
|
||||
from ray.tune.web_server import TuneServer
|
||||
from ray.tune.trial import Trial, Resources
|
||||
from ray.tune.trial import Trial
|
||||
from ray.tune.schedulers import FIFOScheduler, TrialScheduler
|
||||
from ray.tune.web_server import TuneServer
|
||||
|
||||
MAX_DEBUG_TRIALS = 20
|
||||
|
||||
@@ -45,7 +45,8 @@ class TrialRunner(object):
|
||||
launch_web_server=False,
|
||||
server_port=TuneServer.DEFAULT_PORT,
|
||||
verbose=True,
|
||||
queue_trials=False):
|
||||
queue_trials=False,
|
||||
trial_executor=None):
|
||||
"""Initializes a new TrialRunner.
|
||||
|
||||
Args:
|
||||
@@ -60,14 +61,13 @@ class TrialRunner(object):
|
||||
not currently have enough resources to launch one. This should
|
||||
be set to True when running on an autoscaling cluster to enable
|
||||
automatic scale-up.
|
||||
trial_executor (TrialExecutor): Defaults to RayTrialExecutor.
|
||||
"""
|
||||
self._search_alg = search_alg
|
||||
self._scheduler_alg = scheduler or FIFOScheduler()
|
||||
self._trials = []
|
||||
self._running = {}
|
||||
self._avail_resources = Resources(cpu=0, gpu=0)
|
||||
self._committed_resources = Resources(cpu=0, gpu=0)
|
||||
self._resources_initialized = False
|
||||
self.trial_executor = trial_executor or \
|
||||
RayTrialExecutor(queue_trials=queue_trials)
|
||||
|
||||
# For debugging, it may be useful to halt trials after some time has
|
||||
# elapsed. TODO(ekl) consider exposing this in the API.
|
||||
@@ -98,10 +98,11 @@ class TrialRunner(object):
|
||||
Callers should typically run this method repeatedly in a loop. They
|
||||
may inspect or modify the runner's state in between calls to step().
|
||||
"""
|
||||
self.trial_executor.on_step_begin()
|
||||
next_trial = self._get_next_trial()
|
||||
if next_trial is not None:
|
||||
self._launch_trial(next_trial)
|
||||
elif self._running:
|
||||
self.trial_executor.start_trial(next_trial)
|
||||
elif self.trial_executor.get_running_trials():
|
||||
self._process_events()
|
||||
else:
|
||||
for trial in self._trials:
|
||||
@@ -109,13 +110,13 @@ class TrialRunner(object):
|
||||
if not self.has_resources(trial.resources):
|
||||
raise TuneError(
|
||||
("Insufficient cluster resources to launch trial: "
|
||||
"trial requested {} but the cluster only has {} "
|
||||
"available. Pass `queue_trials=True` in "
|
||||
"trial requested {} but the cluster summary: {} "
|
||||
"Pass `queue_trials=True` in "
|
||||
"ray.tune.run_experiments() or on the command "
|
||||
"line to queue trials until the cluster scales "
|
||||
"up. {}").format(
|
||||
trial.resources.summary_string(),
|
||||
self._avail_resources.summary_string(),
|
||||
self.trial_executor.debug_string(),
|
||||
trial._get_trainable_cls().resource_help(
|
||||
trial.config)))
|
||||
elif trial.status == Trial.PAUSED:
|
||||
@@ -129,6 +130,7 @@ class TrialRunner(object):
|
||||
|
||||
if self.is_finished():
|
||||
self._server.shutdown()
|
||||
self.trial_executor.on_step_end()
|
||||
|
||||
def get_trial(self, tid):
|
||||
trial = [t for t in self._trials if t.trial_id == tid]
|
||||
@@ -189,41 +191,12 @@ class TrialRunner(object):
|
||||
def _debug_messages(self):
|
||||
messages = ["== Status =="]
|
||||
messages.append(self._scheduler_alg.debug_string())
|
||||
if self._resources_initialized:
|
||||
messages.append(
|
||||
"Resources requested: {}/{} CPUs, {}/{} GPUs".format(
|
||||
self._committed_resources.cpu, self._avail_resources.cpu,
|
||||
self._committed_resources.gpu, self._avail_resources.gpu))
|
||||
messages.append(self.trial_executor.debug_string())
|
||||
return messages
|
||||
|
||||
def has_resources(self, resources):
|
||||
"""Returns whether this runner has at least the specified resources."""
|
||||
|
||||
cpu_avail = self._avail_resources.cpu - self._committed_resources.cpu
|
||||
gpu_avail = self._avail_resources.gpu - self._committed_resources.gpu
|
||||
|
||||
have_space = (resources.cpu_total() <= cpu_avail
|
||||
and resources.gpu_total() <= gpu_avail)
|
||||
|
||||
if have_space:
|
||||
return True
|
||||
|
||||
can_overcommit = self._queue_trials
|
||||
|
||||
if ((resources.cpu_total() > 0 and cpu_avail <= 0)
|
||||
or (resources.gpu_total() > 0 and gpu_avail <= 0)):
|
||||
can_overcommit = False # requested resource is already saturated
|
||||
|
||||
if can_overcommit:
|
||||
print("WARNING:tune:allowing trial to start even though the "
|
||||
"cluster does not have enough free resources. Trial actors "
|
||||
"may appear to hang until enough resources are added to the "
|
||||
"cluster (e.g., via autoscaling). You can disable this "
|
||||
"behavior by specifying `queue_trials=False` in "
|
||||
"ray.tune.run_experiments().")
|
||||
return True
|
||||
|
||||
return False
|
||||
return self.trial_executor.has_resources(resources)
|
||||
|
||||
def _get_next_trial(self):
|
||||
"""Replenishes queue.
|
||||
@@ -231,38 +204,17 @@ class TrialRunner(object):
|
||||
Blocks if all trials queued have finished, but search algorithm is
|
||||
still not finished.
|
||||
"""
|
||||
self._update_avail_resources()
|
||||
trials_done = all(trial.is_finished() for trial in self._trials)
|
||||
wait_for_trial = trials_done and not self._search_alg.is_finished()
|
||||
self._update_trial_queue(blocking=wait_for_trial)
|
||||
trial = self._scheduler_alg.choose_trial_to_run(self)
|
||||
return trial
|
||||
|
||||
def _launch_trial(self, trial):
|
||||
self._commit_resources(trial.resources)
|
||||
try:
|
||||
trial.start()
|
||||
self._running[trial.train_remote()] = trial
|
||||
except Exception:
|
||||
error_msg = traceback.format_exc()
|
||||
print("Error starting runner, retrying:", error_msg)
|
||||
time.sleep(2)
|
||||
trial.stop(error=True, error_msg=error_msg)
|
||||
try:
|
||||
trial.start()
|
||||
self._running[trial.train_remote()] = trial
|
||||
except Exception:
|
||||
error_msg = traceback.format_exc()
|
||||
print("Error starting runner, abort:", error_msg)
|
||||
trial.stop(error=True, error_msg=error_msg)
|
||||
# note that we don't return the resources, since they may
|
||||
# have been lost
|
||||
|
||||
def _process_events(self):
|
||||
[result_id], _ = ray.wait(list(self._running))
|
||||
trial = self._running.pop(result_id)
|
||||
trial, result = self.trial_executor.fetch_one_result()
|
||||
try:
|
||||
result = ray.get(result_id)
|
||||
if result is None:
|
||||
raise ValueError("fetch_one_result failed")
|
||||
self._total_time += result[TIME_THIS_ITER_S]
|
||||
|
||||
if trial.should_stop(result):
|
||||
@@ -284,12 +236,12 @@ class TrialRunner(object):
|
||||
if decision == TrialScheduler.CONTINUE:
|
||||
if trial.should_checkpoint():
|
||||
# TODO(rliaw): This is a blocking call
|
||||
trial.checkpoint()
|
||||
self._running[trial.train_remote()] = trial
|
||||
self.trial_executor.save(trial)
|
||||
self.trial_executor.continue_training(trial)
|
||||
elif decision == TrialScheduler.PAUSE:
|
||||
self._pause_trial(trial)
|
||||
self.trial_executor.pause_trial(trial)
|
||||
elif decision == TrialScheduler.STOP:
|
||||
self._stop_trial(trial)
|
||||
self.trial_executor.stop_trial(trial)
|
||||
else:
|
||||
assert False, "Invalid scheduling decision: {}".format(
|
||||
decision)
|
||||
@@ -304,30 +256,28 @@ class TrialRunner(object):
|
||||
self._scheduler_alg.on_trial_error(self, trial)
|
||||
self._search_alg.on_trial_complete(
|
||||
trial.trial_id, error=True)
|
||||
self._stop_trial(trial, error=True, error_msg=error_msg)
|
||||
self.trial_executor.stop_trial(trial, True, error_msg)
|
||||
|
||||
def _try_recover(self, trial, error_msg):
|
||||
try:
|
||||
print("Attempting to recover trial state from last checkpoint")
|
||||
trial.stop(error=True, error_msg=error_msg, stop_logger=False)
|
||||
trial.result_logger.flush() # make sure checkpoint is synced
|
||||
trial.start()
|
||||
self._running[trial.train_remote()] = trial
|
||||
self.trial_executor.restart_trial(trial, error_msg)
|
||||
except Exception:
|
||||
error_msg = traceback.format_exc()
|
||||
print("Error recovering trial from checkpoint, abort:", error_msg)
|
||||
self._stop_trial(trial, error=True, error_msg=error_msg)
|
||||
self.trial_executor.stop_trial(trial, True, error_msg=error_msg)
|
||||
|
||||
def _update_trial_queue(self, blocking=False, timeout=600):
|
||||
"""Adds next trials to queue if possible.
|
||||
|
||||
Note that the timeout is currently unexposed to the user.
|
||||
|
||||
Arguments:
|
||||
Args:
|
||||
blocking (bool): Blocks until either a trial is available
|
||||
or the Runner finishes (i.e., timeout or search algorithm
|
||||
finishes).
|
||||
timeout (int): Seconds before blocking times out."""
|
||||
timeout (int): Seconds before blocking times out.
|
||||
"""
|
||||
trials = self._search_alg.next_trials()
|
||||
if blocking and not trials:
|
||||
start = time.time()
|
||||
@@ -340,18 +290,6 @@ class TrialRunner(object):
|
||||
for trial in trials:
|
||||
self.add_trial(trial)
|
||||
|
||||
def _commit_resources(self, resources):
|
||||
self._committed_resources = Resources(
|
||||
self._committed_resources.cpu + resources.cpu_total(),
|
||||
self._committed_resources.gpu + resources.gpu_total())
|
||||
|
||||
def _return_resources(self, resources):
|
||||
self._committed_resources = Resources(
|
||||
self._committed_resources.cpu - resources.cpu_total(),
|
||||
self._committed_resources.gpu - resources.gpu_total())
|
||||
assert self._committed_resources.cpu >= 0
|
||||
assert self._committed_resources.gpu >= 0
|
||||
|
||||
def request_stop_trial(self, trial):
|
||||
self._stop_queue.append(trial)
|
||||
|
||||
@@ -379,13 +317,10 @@ class TrialRunner(object):
|
||||
self._search_alg.on_trial_complete(
|
||||
trial.trial_id, early_terminated=True)
|
||||
elif trial.status is Trial.RUNNING:
|
||||
# NOTE: There should only be one...
|
||||
result_id = [
|
||||
rid for rid, t in self._running.items() if t is trial
|
||||
][0]
|
||||
self._running.pop(result_id)
|
||||
try:
|
||||
result = ray.get(result_id)
|
||||
_, result = self.trial_executor.fetch_one_result()
|
||||
if result is None:
|
||||
raise ValueError("fetch_one_result failed")
|
||||
trial.update_last_result(result, terminate=True)
|
||||
self._scheduler_alg.on_trial_complete(self, trial, result)
|
||||
self._search_alg.on_trial_complete(
|
||||
@@ -397,35 +332,4 @@ class TrialRunner(object):
|
||||
self._search_alg.on_trial_complete(trial.trial_id, error=True)
|
||||
error = True
|
||||
|
||||
self._stop_trial(trial, error=error, error_msg=error_msg)
|
||||
|
||||
def _stop_trial(self, trial, error=False, error_msg=None):
|
||||
"""Only returns resources if resources allocated."""
|
||||
prior_status = trial.status
|
||||
trial.stop(error=error, error_msg=error_msg)
|
||||
if prior_status == Trial.RUNNING:
|
||||
self._return_resources(trial.resources)
|
||||
|
||||
def _pause_trial(self, trial):
|
||||
"""Only returns resources if resources allocated."""
|
||||
prior_status = trial.status
|
||||
trial.pause()
|
||||
if prior_status == Trial.RUNNING:
|
||||
self._return_resources(trial.resources)
|
||||
|
||||
def _update_avail_resources(self):
|
||||
clients = ray.global_state.client_table()
|
||||
if ray.worker.global_worker.use_raylet:
|
||||
# TODO(rliaw): Remove once raylet flag is swapped
|
||||
num_cpus = sum(cl['Resources']['CPU'] for cl in clients)
|
||||
num_gpus = sum(cl['Resources'].get('GPU', 0) for cl in clients)
|
||||
else:
|
||||
local_schedulers = [
|
||||
entry for client in clients.values() for entry in client
|
||||
if (entry['ClientType'] == 'local_scheduler'
|
||||
and not entry['Deleted'])
|
||||
]
|
||||
num_cpus = sum(ls['CPU'] for ls in local_schedulers)
|
||||
num_gpus = sum(ls.get('GPU', 0) for ls in local_schedulers)
|
||||
self._avail_resources = Resources(int(num_cpus), int(num_gpus))
|
||||
self._resources_initialized = True
|
||||
self.trial_executor.stop_trial(trial, error=error, error_msg=error_msg)
|
||||
|
||||
@@ -35,7 +35,8 @@ def run_experiments(experiments=None,
|
||||
with_server=False,
|
||||
server_port=TuneServer.DEFAULT_PORT,
|
||||
verbose=True,
|
||||
queue_trials=False):
|
||||
queue_trials=False,
|
||||
trial_executor=None):
|
||||
"""Runs and blocks until all trials finish.
|
||||
|
||||
Args:
|
||||
@@ -54,6 +55,7 @@ def run_experiments(experiments=None,
|
||||
not currently have enough resources to launch one. This should
|
||||
be set to True when running on an autoscaling cluster to enable
|
||||
automatic scale-up.
|
||||
trial_executor (TrialExecutor): Manage the execution of trials.
|
||||
|
||||
Examples:
|
||||
>>> experiment_spec = Experiment("experiment", my_func)
|
||||
@@ -73,7 +75,9 @@ def run_experiments(experiments=None,
|
||||
|
||||
Returns:
|
||||
List of Trial objects, holding data for each executed trial.
|
||||
|
||||
"""
|
||||
|
||||
if scheduler is None:
|
||||
scheduler = FIFOScheduler()
|
||||
|
||||
@@ -88,7 +92,8 @@ def run_experiments(experiments=None,
|
||||
launch_web_server=with_server,
|
||||
server_port=server_port,
|
||||
verbose=verbose,
|
||||
queue_trials=queue_trials)
|
||||
queue_trials=queue_trials,
|
||||
trial_executor=trial_executor)
|
||||
|
||||
print(runner.debug_string(max_debug=99999))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user