mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 22:08:16 +08:00
[tune] Dynamic Resources for Trials (#3974)
## What do these changes do?
Provides a small helper method, `Trial.update_resources`, for modifying the resource requirements of a trial.
Also implements the following:
- Setting `last_result` to `{}` instead of `None`
- Adding a `shuffle` option to `BasicVariantGenerator`
This commit is contained in:
@@ -3,6 +3,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import itertools
|
||||
import random
|
||||
|
||||
from ray.tune.error import TuneError
|
||||
from ray.tune.experiment import convert_to_experiment_list
|
||||
@@ -23,11 +24,17 @@ class BasicVariantGenerator(SearchAlgorithm):
|
||||
>>> searcher.is_finished == True
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, shuffle=False):
|
||||
"""Initializes the Variant Generator.
|
||||
|
||||
Arguments:
|
||||
shuffle (bool): Shuffles the generated list of configurations.
|
||||
"""
|
||||
self._parser = make_parser()
|
||||
self._trial_generator = []
|
||||
self._counter = 0
|
||||
self._finished = False
|
||||
self._shuffle = shuffle
|
||||
|
||||
def add_configurations(self, experiments):
|
||||
"""Chains generator given experiment specifications.
|
||||
@@ -48,6 +55,8 @@ class BasicVariantGenerator(SearchAlgorithm):
|
||||
trials (list): Returns a list of trials.
|
||||
"""
|
||||
trials = list(self._trial_generator)
|
||||
if self._shuffle:
|
||||
random.shuffle(trials)
|
||||
self._finished = True
|
||||
return trials
|
||||
|
||||
|
||||
@@ -153,7 +153,7 @@ def test_trial_migration(start_connected_emptyhead_cluster):
|
||||
runner.add_trial(t)
|
||||
runner.step() # start
|
||||
runner.step() # 1 result
|
||||
assert t.last_result is not None
|
||||
assert t.last_result
|
||||
node2 = cluster.add_node(num_cpus=1)
|
||||
cluster.remove_node(node)
|
||||
cluster.wait_for_nodes()
|
||||
@@ -384,7 +384,7 @@ tune.run_experiments(
|
||||
runner = TrialRunner.restore(metadata_checkpoint_dir)
|
||||
trials = runner.get_trials()
|
||||
last_res = trials[0].last_result
|
||||
if last_res is not None and last_res["training_iteration"]:
|
||||
if last_res and last_res.get("training_iteration"):
|
||||
break
|
||||
time.sleep(0.3)
|
||||
|
||||
@@ -476,7 +476,7 @@ tune.run_experiments(
|
||||
runner = TrialRunner.restore(metadata_checkpoint_dir)
|
||||
trials = runner.get_trials()
|
||||
last_res = trials[0].last_result
|
||||
if last_res is not None and last_res["training_iteration"] == 3:
|
||||
if last_res and last_res.get("training_iteration") == 3:
|
||||
break
|
||||
time.sleep(0.2)
|
||||
|
||||
|
||||
@@ -1314,6 +1314,40 @@ class TrialRunnerTest(unittest.TestCase):
|
||||
self.assertEqual(trials[0].status, Trial.TERMINATED)
|
||||
self.assertRaises(TuneError, runner.step)
|
||||
|
||||
def testChangeResources(self):
    """Checks that resource requirements can be changed on the fly.

    A scheduler stops the trial after its first result, raises its CPU
    requirement from 1 to 2, and restarts it. The test asserts that the
    executor's committed CPU count follows that change, and that
    updating resources on a running trial raises ValueError.
    """
    ray.init(num_cpus=2)

    class ChangingScheduler(FIFOScheduler):
        def on_trial_result(self, trial_runner, trial, result):
            first_result = result["training_iteration"] == 1
            if first_result:
                # The trial must be stopped before its resources can be
                # updated; it is restarted right after with the new
                # requirement.
                trial_executor = trial_runner.trial_executor
                trial_executor.stop_trial(trial, stop_logger=False)
                trial.update_resources(2, 0)
                trial_executor.start_trial(trial)
            return TrialScheduler.CONTINUE

    runner = TrialRunner(
        BasicVariantGenerator(), scheduler=ChangingScheduler())
    trial = Trial(
        "__fake",
        stopping_criterion={"training_iteration": 2},
        resources=Resources(cpu=1, gpu=0))
    runner.add_trial(trial)

    # After the first step the trial runs with its initial 1-CPU request.
    runner.step()
    self.assertEqual(trial.status, Trial.RUNNING)
    self.assertEqual(runner.trial_executor._committed_resources.cpu, 1)
    with self.assertRaises(ValueError):
        trial.update_resources(2, 0)

    # After the second step the scheduler has swapped in the 2-CPU request.
    runner.step()
    self.assertEqual(trial.status, Trial.RUNNING)
    self.assertEqual(runner.trial_executor._committed_resources.cpu, 2)
|
||||
|
||||
def testErrorHandling(self):
|
||||
ray.init(num_cpus=4, num_gpus=2)
|
||||
runner = TrialRunner(BasicVariantGenerator())
|
||||
|
||||
@@ -193,7 +193,7 @@ class Checkpoint(object):
|
||||
def __init__(self, storage, value, last_result=None):
|
||||
self.storage = storage
|
||||
self.value = value
|
||||
self.last_result = last_result
|
||||
self.last_result = last_result or {}
|
||||
|
||||
@staticmethod
|
||||
def from_object(value=None):
|
||||
@@ -283,7 +283,7 @@ class Trial(object):
|
||||
self.max_failures = max_failures
|
||||
|
||||
# Local trial state that is updated during the run
|
||||
self.last_result = None
|
||||
self.last_result = {}
|
||||
self.last_update_time = -float("inf")
|
||||
self.checkpoint_freq = checkpoint_freq
|
||||
self.checkpoint_at_end = checkpoint_at_end
|
||||
@@ -336,6 +336,18 @@ class Trial(object):
|
||||
loggers=self.loggers,
|
||||
sync_function=self.sync_function)
|
||||
|
||||
def update_resources(self, cpu, gpu, **kwargs):
    """EXPERIMENTAL: Updates the resource requirements.

    Replaces ``self.resources`` with a new ``Resources`` object built
    from the given arguments. Should only be called when the trial is
    not running.

    Args:
        cpu: CPU requirement, passed to ``Resources`` as its first
            positional argument.
        gpu: GPU requirement, passed to ``Resources`` as its second
            positional argument.
        **kwargs: Extra keyword arguments forwarded unchanged to
            ``Resources``.

    Raises:
        ValueError: If the trial status is running.
    """
    # Compare by value rather than identity: a status that equals
    # Trial.RUNNING but is a distinct object must still be rejected,
    # and ``is`` would let it slip through.
    if self.status == Trial.RUNNING:
        raise ValueError("Cannot update resources while Trial is running.")
    self.resources = Resources(cpu, gpu, **kwargs)
|
||||
|
||||
def sync_logger_to_new_location(self, worker_ip):
|
||||
"""Updates the logger location.
|
||||
|
||||
@@ -392,7 +404,7 @@ class Trial(object):
|
||||
def progress_string(self):
|
||||
"""Returns a progress message for printing out to the console."""
|
||||
|
||||
if self.last_result is None:
|
||||
if not self.last_result:
|
||||
return self._status_string()
|
||||
|
||||
def location_string(hostname, pid):
|
||||
@@ -402,12 +414,12 @@ class Trial(object):
|
||||
return '{} pid={}'.format(hostname, pid)
|
||||
|
||||
pieces = [
|
||||
'{} [{}]'.format(
|
||||
self._status_string(),
|
||||
location_string(
|
||||
self.last_result.get(HOSTNAME),
|
||||
self.last_result.get(PID))), '{} s'.format(
|
||||
int(self.last_result.get(TIME_TOTAL_S)))
|
||||
'{}'.format(self._status_string()), '[{}]'.format(
|
||||
self.resources.summary_string()), '[{}]'.format(
|
||||
location_string(
|
||||
self.last_result.get(HOSTNAME),
|
||||
self.last_result.get(PID))), '{} s'.format(
|
||||
int(self.last_result.get(TIME_TOTAL_S)))
|
||||
]
|
||||
|
||||
if self.last_result.get(TRAINING_ITERATION) is not None:
|
||||
|
||||
@@ -280,9 +280,9 @@ class TrialRunner(object):
|
||||
trial (Trial): Trial to queue.
|
||||
"""
|
||||
trial.set_verbose(self._verbose)
|
||||
self._trials.append(trial)
|
||||
self._scheduler_alg.on_trial_add(self, trial)
|
||||
self.trial_executor.try_checkpoint_metadata(trial)
|
||||
self._trials.append(trial)
|
||||
|
||||
def debug_string(self, max_debug=MAX_DEBUG_TRIALS):
|
||||
"""Returns a human readable message for printing to the console."""
|
||||
|
||||
Reference in New Issue
Block a user