[tune] Dynamic Resources for Trials (#3974)

## What do these changes do?

Provides a small helper function for modifying the resource requirements of a trial.

Also implements the following:
 - Setting `last_result` to `{}` instead of `None`
 - Adding a `shuffle` option to `BasicVariantGenerator`
This commit is contained in:
Richard Liaw
2019-03-03 11:38:36 -08:00
committed by GitHub
parent ba03048254
commit fb1369d96f
5 changed files with 69 additions and 14 deletions
+10 -1
View File
@@ -3,6 +3,7 @@ from __future__ import division
from __future__ import print_function
import itertools
import random
from ray.tune.error import TuneError
from ray.tune.experiment import convert_to_experiment_list
@@ -23,11 +24,17 @@ class BasicVariantGenerator(SearchAlgorithm):
>>> searcher.is_finished == True
"""
def __init__(self, shuffle=False):
    """Initializes the Variant Generator.

    Arguments:
        shuffle (bool): Shuffles the generated list of configurations.
            Defaults to False, so callers that pass no argument get the
            same behavior as the former zero-argument constructor.
    """
    self._parser = make_parser()
    # Starts out empty; configurations are chained in later by
    # add_configurations().
    self._trial_generator = []
    self._counter = 0
    self._finished = False
    # Only stored here; it is consulted when the trial list is built.
    self._shuffle = shuffle
def add_configurations(self, experiments):
"""Chains generator given experiment specifications.
@@ -48,6 +55,8 @@ class BasicVariantGenerator(SearchAlgorithm):
trials (list): Returns a list of trials.
"""
trials = list(self._trial_generator)
if self._shuffle:
random.shuffle(trials)
self._finished = True
return trials
+3 -3
View File
@@ -153,7 +153,7 @@ def test_trial_migration(start_connected_emptyhead_cluster):
runner.add_trial(t)
runner.step() # start
runner.step() # 1 result
assert t.last_result is not None
assert t.last_result
node2 = cluster.add_node(num_cpus=1)
cluster.remove_node(node)
cluster.wait_for_nodes()
@@ -384,7 +384,7 @@ tune.run_experiments(
runner = TrialRunner.restore(metadata_checkpoint_dir)
trials = runner.get_trials()
last_res = trials[0].last_result
if last_res is not None and last_res["training_iteration"]:
if last_res and last_res.get("training_iteration"):
break
time.sleep(0.3)
@@ -476,7 +476,7 @@ tune.run_experiments(
runner = TrialRunner.restore(metadata_checkpoint_dir)
trials = runner.get_trials()
last_res = trials[0].last_result
if last_res is not None and last_res["training_iteration"] == 3:
if last_res and last_res.get("training_iteration") == 3:
break
time.sleep(0.2)
@@ -1314,6 +1314,40 @@ class TrialRunnerTest(unittest.TestCase):
self.assertEqual(trials[0].status, Trial.TERMINATED)
self.assertRaises(TuneError, runner.step)
def testChangeResources(self):
    """Checks that a trial's resource request can be changed mid-run.

    The trial is created with 1 CPU. Once it reports its first result the
    scheduler stops it, raises the request to 2 CPUs and starts it again.
    """
    ray.init(num_cpus=2)

    class ChangingScheduler(FIFOScheduler):
        """Restarts the trial with 2 CPUs when iteration 1 is reported."""

        def on_trial_result(self, trial_runner, trial, result):
            if result["training_iteration"] != 1:
                return TrialScheduler.CONTINUE
            # update_resources() rejects running trials, so the trial is
            # stopped first and restarted afterwards.
            trial_executor = trial_runner.trial_executor
            trial_executor.stop_trial(trial, stop_logger=False)
            trial.update_resources(2, 0)
            trial_executor.start_trial(trial)
            return TrialScheduler.CONTINUE

    runner = TrialRunner(
        BasicVariantGenerator(), scheduler=ChangingScheduler())
    trial = Trial(
        "__fake",
        stopping_criterion={"training_iteration": 2},
        resources=Resources(cpu=1, gpu=0))
    runner.add_trial(trial)

    runner.step()
    self.assertEqual(trial.status, Trial.RUNNING)
    self.assertEqual(runner.trial_executor._committed_resources.cpu, 1)
    # Changing resources directly on a running trial must be refused.
    with self.assertRaises(ValueError):
        trial.update_resources(2, 0)

    runner.step()
    self.assertEqual(trial.status, Trial.RUNNING)
    self.assertEqual(runner.trial_executor._committed_resources.cpu, 2)
def testErrorHandling(self):
ray.init(num_cpus=4, num_gpus=2)
runner = TrialRunner(BasicVariantGenerator())
+21 -9
View File
@@ -193,7 +193,7 @@ class Checkpoint(object):
def __init__(self, storage, value, last_result=None):
    """Stores the checkpoint's storage kind, value and last result.

    Arguments:
        storage: Stored unchanged as ``self.storage``.
        value: Stored unchanged as ``self.value``.
        last_result (dict): Result associated with this checkpoint.
            ``None`` (or any other falsy value) is replaced by a new
            empty dict, so ``self.last_result`` is never ``None``.
    """
    self.storage = storage
    self.value = value
    # Single assignment: a preceding plain `= last_result` would be
    # overwritten immediately and is therefore omitted.
    self.last_result = last_result or {}
@staticmethod
def from_object(value=None):
@@ -283,7 +283,7 @@ class Trial(object):
self.max_failures = max_failures
# Local trial state that is updated during the run
self.last_result = None
self.last_result = {}
self.last_update_time = -float("inf")
self.checkpoint_freq = checkpoint_freq
self.checkpoint_at_end = checkpoint_at_end
@@ -336,6 +336,18 @@ class Trial(object):
loggers=self.loggers,
sync_function=self.sync_function)
def update_resources(self, cpu, gpu, **kwargs):
    """EXPERIMENTAL: Updates the resource requirements.

    Should only be called when the trial is not running.

    Arguments:
        cpu: Forwarded as the first positional argument of ``Resources``.
        gpu: Forwarded as the second positional argument of ``Resources``.
        **kwargs: Forwarded unchanged to ``Resources``.

    Raises:
        ValueError if trial status is running.
    """
    # Compare by value, not identity: an equal status held in a different
    # object (presumably statuses are plain strings -- TODO confirm) must
    # still be treated as running.
    if self.status == Trial.RUNNING:
        raise ValueError("Cannot update resources while Trial is running.")
    self.resources = Resources(cpu, gpu, **kwargs)
def sync_logger_to_new_location(self, worker_ip):
"""Updates the logger location.
@@ -392,7 +404,7 @@ class Trial(object):
def progress_string(self):
"""Returns a progress message for printing out to the console."""
if self.last_result is None:
if not self.last_result:
return self._status_string()
def location_string(hostname, pid):
@@ -402,12 +414,12 @@ class Trial(object):
return '{} pid={}'.format(hostname, pid)
pieces = [
'{} [{}]'.format(
self._status_string(),
location_string(
self.last_result.get(HOSTNAME),
self.last_result.get(PID))), '{} s'.format(
int(self.last_result.get(TIME_TOTAL_S)))
'{}'.format(self._status_string()), '[{}]'.format(
self.resources.summary_string()), '[{}]'.format(
location_string(
self.last_result.get(HOSTNAME),
self.last_result.get(PID))), '{} s'.format(
int(self.last_result.get(TIME_TOTAL_S)))
]
if self.last_result.get(TRAINING_ITERATION) is not None:
+1 -1
View File
@@ -280,9 +280,9 @@ class TrialRunner(object):
trial (Trial): Trial to queue.
"""
trial.set_verbose(self._verbose)
self._trials.append(trial)
self._scheduler_alg.on_trial_add(self, trial)
self.trial_executor.try_checkpoint_metadata(trial)
self._trials.append(trial)
def debug_string(self, max_debug=MAX_DEBUG_TRIALS):
"""Returns a human readable message for printing to the console."""