# File: ray/python/ray/tune/tests/test_ray_trial_executor.py
# (source listing: 115 lines, 4.1 KiB, Python)

# coding: utf-8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import unittest
import ray
from ray.rllib import _register_all
from ray.tune import Trainable
from ray.tune.ray_trial_executor import RayTrialExecutor
from ray.tune.registry import _global_registry, TRAINABLE_CLASS
from ray.tune.suggest import BasicVariantGenerator
from ray.tune.trial import Trial, Checkpoint, Resources
class RayTrialExecutorTest(unittest.TestCase):
    """Exercises RayTrialExecutor's trial lifecycle operations.

    Covers start/stop, save/restore, pause/resume (both with and without a
    fetched result), failure on start, and in-place config reset. Each test
    gets a fresh executor and a fresh Ray instance.
    """

    def setUp(self):
        """Create a new executor (trial queueing disabled) and start Ray."""
        self.trial_executor = RayTrialExecutor(queue_trials=False)
        ray.init()

    def tearDown(self):
        """Shut Ray down and re-register trainables via `_register_all`."""
        ray.shutdown()
        _register_all()  # re-register the evicted objects

    def testStartStop(self):
        """Tests that a started trial is running and terminates on stop."""
        trial = Trial("__fake")
        self.trial_executor.start_trial(trial)
        running = self.trial_executor.get_running_trials()
        self.assertEqual(1, len(running))
        self.trial_executor.stop_trial(trial)
        # Check that the stop actually took effect, as the other lifecycle
        # tests in this class do.
        self.assertEqual(Trial.TERMINATED, trial.status)

    def testSaveRestore(self):
        """Tests that a running trial can be saved to disk and restored."""
        trial = Trial("__fake")
        self.trial_executor.start_trial(trial)
        self.assertEqual(Trial.RUNNING, trial.status)
        self.trial_executor.save(trial, Checkpoint.DISK)
        self.trial_executor.restore(trial)
        self.trial_executor.stop_trial(trial)
        self.assertEqual(Trial.TERMINATED, trial.status)

    def testPauseResume(self):
        """Tests that pausing works for trials in flight."""
        trial = Trial("__fake")
        self.trial_executor.start_trial(trial)
        self.assertEqual(Trial.RUNNING, trial.status)
        self.trial_executor.pause_trial(trial)
        self.assertEqual(Trial.PAUSED, trial.status)
        # Starting a paused trial resumes it.
        self.trial_executor.start_trial(trial)
        self.assertEqual(Trial.RUNNING, trial.status)
        self.trial_executor.stop_trial(trial)
        self.assertEqual(Trial.TERMINATED, trial.status)

    def testStartFailure(self):
        """Tests that a trial whose trainable is registered as None errors.

        NOTE(review): the global "asdf" registration made here is never
        undone by this test; confirm it cannot affect other tests.
        """
        _global_registry.register(TRAINABLE_CLASS, "asdf", None)
        trial = Trial("asdf", resources=Resources(1, 0))
        self.trial_executor.start_trial(trial)
        self.assertEqual(Trial.ERROR, trial.status)

    def testPauseResume2(self):
        """Tests that pausing works for trials being processed."""
        trial = Trial("__fake")
        self.trial_executor.start_trial(trial)
        self.assertEqual(Trial.RUNNING, trial.status)
        # Fetch a result first so the pause happens after processing began.
        self.trial_executor.fetch_result(trial)
        self.trial_executor.pause_trial(trial)
        self.assertEqual(Trial.PAUSED, trial.status)
        self.trial_executor.start_trial(trial)
        self.assertEqual(Trial.RUNNING, trial.status)
        self.trial_executor.stop_trial(trial)
        self.assertEqual(Trial.TERMINATED, trial.status)

    def testNoResetTrial(self):
        """Tests that reset handles NotImplemented properly.

        For the "__fake" trainable, reset_trial is expected to report False
        and leave the trial running.
        """
        trial = Trial("__fake")
        self.trial_executor.start_trial(trial)
        exists = self.trial_executor.reset_trial(trial, {}, "modified_mock")
        self.assertEqual(exists, False)
        self.assertEqual(Trial.RUNNING, trial.status)
        # Stop the still-running trial, like the other tests do.
        self.trial_executor.stop_trial(trial)
        self.assertEqual(Trial.TERMINATED, trial.status)

    def testResetTrial(self):
        """Tests that reset works as expected.

        Uses a trainable whose reset_config accepts any config and returns
        True, then checks that reset_trial applies the new config and
        experiment tag while the trial keeps running.
        """

        class B(Trainable):
            def _train(self):
                return dict(timesteps_this_iter=1, done=True)

            def reset_config(self, config):
                self.config = config
                return True

        trials = self.generate_trials({
            "run": B,
            "config": {
                "foo": 0
            },
        }, "grid_search")
        trial = trials[0]
        self.trial_executor.start_trial(trial)
        exists = self.trial_executor.reset_trial(trial, {"hi": 1},
                                                 "modified_mock")
        self.assertEqual(exists, True)
        self.assertEqual(trial.config.get("hi"), 1)
        self.assertEqual(trial.experiment_tag, "modified_mock")
        self.assertEqual(Trial.RUNNING, trial.status)
        # Stop the still-running trial, like the other tests do.
        self.trial_executor.stop_trial(trial)
        self.assertEqual(Trial.TERMINATED, trial.status)

    def generate_trials(self, spec, name):
        """Return the next trials BasicVariantGenerator yields for `spec`.

        Args:
            spec: Experiment spec dict (e.g. with "run" and "config" keys).
            name: Experiment name under which `spec` is registered.

        Returns:
            Whatever `BasicVariantGenerator.next_trials()` returns; callers
            here index it as a list of Trial objects.
        """
        suggester = BasicVariantGenerator()
        suggester.add_configurations({name: spec})
        return suggester.next_trials()
if __name__ == "__main__":
    # Run directly as a script: execute every test in this module with
    # per-test (verbosity=2) output. "__main__" is unittest.main's default.
    unittest.main(module="__main__", verbosity=2)