diff --git a/python/ray/tune/suggest/basic_variant.py b/python/ray/tune/suggest/basic_variant.py index 63bdbb7bb..30da21c16 100644 --- a/python/ray/tune/suggest/basic_variant.py +++ b/python/ray/tune/suggest/basic_variant.py @@ -3,6 +3,7 @@ from __future__ import division from __future__ import print_function import itertools +import random from ray.tune.error import TuneError from ray.tune.experiment import convert_to_experiment_list @@ -23,11 +24,17 @@ class BasicVariantGenerator(SearchAlgorithm): >>> searcher.is_finished == True """ - def __init__(self): + def __init__(self, shuffle=False): + """Initializes the Variant Generator. + + Arguments: + shuffle (bool): Shuffles the generated list of configurations. + """ self._parser = make_parser() self._trial_generator = [] self._counter = 0 self._finished = False + self._shuffle = shuffle def add_configurations(self, experiments): """Chains generator given experiment specifications. @@ -48,6 +55,8 @@ class BasicVariantGenerator(SearchAlgorithm): trials (list): Returns a list of trials. 
""" trials = list(self._trial_generator) + if self._shuffle: + random.shuffle(trials) self._finished = True return trials diff --git a/python/ray/tune/tests/test_cluster.py b/python/ray/tune/tests/test_cluster.py index d26f64ec9..4f962299d 100644 --- a/python/ray/tune/tests/test_cluster.py +++ b/python/ray/tune/tests/test_cluster.py @@ -153,7 +153,7 @@ def test_trial_migration(start_connected_emptyhead_cluster): runner.add_trial(t) runner.step() # start runner.step() # 1 result - assert t.last_result is not None + assert t.last_result node2 = cluster.add_node(num_cpus=1) cluster.remove_node(node) cluster.wait_for_nodes() @@ -384,7 +384,7 @@ tune.run_experiments( runner = TrialRunner.restore(metadata_checkpoint_dir) trials = runner.get_trials() last_res = trials[0].last_result - if last_res is not None and last_res["training_iteration"]: + if last_res and last_res.get("training_iteration"): break time.sleep(0.3) @@ -476,7 +476,7 @@ tune.run_experiments( runner = TrialRunner.restore(metadata_checkpoint_dir) trials = runner.get_trials() last_res = trials[0].last_result - if last_res is not None and last_res["training_iteration"] == 3: + if last_res and last_res.get("training_iteration") == 3: break time.sleep(0.2) diff --git a/python/ray/tune/tests/test_trial_runner.py b/python/ray/tune/tests/test_trial_runner.py index a49297c65..bb4e92795 100644 --- a/python/ray/tune/tests/test_trial_runner.py +++ b/python/ray/tune/tests/test_trial_runner.py @@ -1314,6 +1314,40 @@ class TrialRunnerTest(unittest.TestCase): self.assertEqual(trials[0].status, Trial.TERMINATED) self.assertRaises(TuneError, runner.step) + def testChangeResources(self): + """Checks that resource requirements can be changed on the fly.""" + ray.init(num_cpus=2) + + class ChangingScheduler(FIFOScheduler): + def on_trial_result(self, trial_runner, trial, result): + if result["training_iteration"] == 1: + executor = trial_runner.trial_executor + executor.stop_trial(trial, stop_logger=False) + 
trial.update_resources(2, 0) + executor.start_trial(trial) + return TrialScheduler.CONTINUE + + runner = TrialRunner( + BasicVariantGenerator(), scheduler=ChangingScheduler()) + kwargs = { + "stopping_criterion": { + "training_iteration": 2 + }, + "resources": Resources(cpu=1, gpu=0), + } + trials = [Trial("__fake", **kwargs)] + for t in trials: + runner.add_trial(t) + + runner.step() + self.assertEqual(trials[0].status, Trial.RUNNING) + self.assertEqual(runner.trial_executor._committed_resources.cpu, 1) + self.assertRaises(ValueError, lambda: trials[0].update_resources(2, 0)) + + runner.step() + self.assertEqual(trials[0].status, Trial.RUNNING) + self.assertEqual(runner.trial_executor._committed_resources.cpu, 2) + def testErrorHandling(self): ray.init(num_cpus=4, num_gpus=2) runner = TrialRunner(BasicVariantGenerator()) diff --git a/python/ray/tune/trial.py b/python/ray/tune/trial.py index 694c0519d..0989ba3f4 100644 --- a/python/ray/tune/trial.py +++ b/python/ray/tune/trial.py @@ -193,7 +193,7 @@ class Checkpoint(object): def __init__(self, storage, value, last_result=None): self.storage = storage self.value = value - self.last_result = last_result + self.last_result = last_result or {} @staticmethod def from_object(value=None): @@ -283,7 +283,7 @@ class Trial(object): self.max_failures = max_failures # Local trial state that is updated during the run - self.last_result = None + self.last_result = {} self.last_update_time = -float("inf") self.checkpoint_freq = checkpoint_freq self.checkpoint_at_end = checkpoint_at_end @@ -336,6 +336,18 @@ class Trial(object): loggers=self.loggers, sync_function=self.sync_function) + def update_resources(self, cpu, gpu, **kwargs): + """EXPERIMENTAL: Updates the resource requirements. + + Should only be called when the trial is not running. + + Raises: + ValueError if trial status is running. 
+ """ + if self.status == Trial.RUNNING: + raise ValueError("Cannot update resources while Trial is running.") + self.resources = Resources(cpu, gpu, **kwargs) + def sync_logger_to_new_location(self, worker_ip): """Updates the logger location. @@ -392,7 +404,7 @@ class Trial(object): def progress_string(self): """Returns a progress message for printing out to the console.""" - if self.last_result is None: + if not self.last_result: return self._status_string() def location_string(hostname, pid): @@ -402,12 +414,12 @@ class Trial(object): return '{} pid={}'.format(hostname, pid) pieces = [ - '{} [{}]'.format( - self._status_string(), - location_string( - self.last_result.get(HOSTNAME), - self.last_result.get(PID))), '{} s'.format( - int(self.last_result.get(TIME_TOTAL_S))) + '{}'.format(self._status_string()), '[{}]'.format( + self.resources.summary_string()), '[{}]'.format( + location_string( + self.last_result.get(HOSTNAME), + self.last_result.get(PID))), '{} s'.format( + int(self.last_result.get(TIME_TOTAL_S))) ] if self.last_result.get(TRAINING_ITERATION) is not None: diff --git a/python/ray/tune/trial_runner.py b/python/ray/tune/trial_runner.py index 8d3ef320b..afce6863a 100644 --- a/python/ray/tune/trial_runner.py +++ b/python/ray/tune/trial_runner.py @@ -280,9 +280,9 @@ class TrialRunner(object): trial (Trial): Trial to queue. """ trial.set_verbose(self._verbose) + self._trials.append(trial) self._scheduler_alg.on_trial_add(self, trial) self.trial_executor.try_checkpoint_metadata(trial) - self._trials.append(trial) def debug_string(self, max_debug=MAX_DEBUG_TRIALS): """Returns a human readable message for printing to the console."""