mirror of
https://github.com/wassname/ray.git
synced 2026-06-29 08:31:42 +08:00
115 lines
4.1 KiB
Python
# coding: utf-8
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
import unittest
|
|
|
|
import ray
|
|
from ray.rllib import _register_all
|
|
from ray.tune import Trainable
|
|
from ray.tune.ray_trial_executor import RayTrialExecutor
|
|
from ray.tune.registry import _global_registry, TRAINABLE_CLASS
|
|
from ray.tune.suggest import BasicVariantGenerator
|
|
from ray.tune.trial import Trial, Checkpoint, Resources
|
|
|
|
|
|
class RayTrialExecutorTest(unittest.TestCase):
    """Lifecycle tests for ``RayTrialExecutor``.

    Each test drives a trial through the executor (start/stop, save/restore,
    pause/resume, reset) against a freshly initialised Ray instance and
    checks the resulting ``Trial.status`` / trial attributes.
    """

    def setUp(self):
        self.trial_executor = RayTrialExecutor(queue_trials=False)
        ray.init()

    def tearDown(self):
        try:
            ray.shutdown()
        finally:
            # Re-register the evicted objects even if shutdown fails, so the
            # following tests can still resolve trainables such as "__fake".
            _register_all()

    def testStartStop(self):
        """A started trial is listed as running and terminates on stop."""
        # NOTE(review): "__fake" is presumably a mock trainable registered by
        # ray.rllib._register_all -- confirm if that registration changes.
        trial = Trial("__fake")
        self.trial_executor.start_trial(trial)
        running = self.trial_executor.get_running_trials()
        self.assertEqual(1, len(running))
        self.trial_executor.stop_trial(trial)
        # Same post-stop expectation as the other lifecycle tests below.
        self.assertEqual(Trial.TERMINATED, trial.status)

    def testSaveRestore(self):
        """A running trial can be checkpointed to disk and restored."""
        trial = Trial("__fake")
        self.trial_executor.start_trial(trial)
        self.assertEqual(Trial.RUNNING, trial.status)
        self.trial_executor.save(trial, Checkpoint.DISK)
        self.trial_executor.restore(trial)
        self.trial_executor.stop_trial(trial)
        self.assertEqual(Trial.TERMINATED, trial.status)

    def testPauseResume(self):
        """Tests that pausing works for trials in flight."""
        trial = Trial("__fake")
        self.trial_executor.start_trial(trial)
        self.assertEqual(Trial.RUNNING, trial.status)
        self.trial_executor.pause_trial(trial)
        self.assertEqual(Trial.PAUSED, trial.status)
        # Starting a paused trial resumes it.
        self.trial_executor.start_trial(trial)
        self.assertEqual(Trial.RUNNING, trial.status)
        self.trial_executor.stop_trial(trial)
        self.assertEqual(Trial.TERMINATED, trial.status)

    def testStartFailure(self):
        """A trial whose registered trainable is ``None`` ends up in ERROR."""
        # NOTE(review): nothing in this class unregisters "asdf"; confirm it
        # cannot leak into other tests sharing the global registry.
        _global_registry.register(TRAINABLE_CLASS, "asdf", None)
        # Resources(1, 0) -- presumably 1 CPU and 0 GPUs; confirm.
        trial = Trial("asdf", resources=Resources(1, 0))
        self.trial_executor.start_trial(trial)
        self.assertEqual(Trial.ERROR, trial.status)

    def testPauseResume2(self):
        """Tests that pausing works for trials being processed."""
        trial = Trial("__fake")
        self.trial_executor.start_trial(trial)
        self.assertEqual(Trial.RUNNING, trial.status)
        # Unlike testPauseResume, consume one result before pausing.
        self.trial_executor.fetch_result(trial)
        self.trial_executor.pause_trial(trial)
        self.assertEqual(Trial.PAUSED, trial.status)
        self.trial_executor.start_trial(trial)
        self.assertEqual(Trial.RUNNING, trial.status)
        self.trial_executor.stop_trial(trial)
        self.assertEqual(Trial.TERMINATED, trial.status)

    def testNoResetTrial(self):
        """Tests that reset handles NotImplemented properly."""
        trial = Trial("__fake")
        self.trial_executor.start_trial(trial)
        # The "__fake" trainable is expected to decline the reset, so the
        # executor reports False and the trial keeps running.
        reset_ok = self.trial_executor.reset_trial(trial, {}, "modified_mock")
        self.assertEqual(False, reset_ok)
        self.assertEqual(Trial.RUNNING, trial.status)

    def testResetTrial(self):
        """Tests that reset works as expected."""

        class B(Trainable):
            def _train(self):
                return dict(timesteps_this_iter=1, done=True)

            def reset_config(self, config):
                # Accept the new config in place; returning True tells the
                # executor the reset succeeded.
                self.config = config
                return True

        trials = self.generate_trials({
            "run": B,
            "config": {
                "foo": 0
            },
        }, "grid_search")
        trial = trials[0]
        self.trial_executor.start_trial(trial)
        reset_ok = self.trial_executor.reset_trial(trial, {"hi": 1},
                                                   "modified_mock")
        self.assertEqual(True, reset_ok)
        # The executor must propagate both the new config and the new tag.
        self.assertEqual(1, trial.config.get("hi"))
        self.assertEqual("modified_mock", trial.experiment_tag)
        self.assertEqual(Trial.RUNNING, trial.status)

    def generate_trials(self, spec, name):
        """Expand one experiment spec into trials.

        Args:
            spec: Experiment spec dict (e.g. with "run" and "config" keys).
            name: Experiment name under which ``spec`` is registered with
                the ``BasicVariantGenerator``.

        Returns:
            Whatever ``BasicVariantGenerator.next_trials()`` yields; the
            tests index it as a sequence of ``Trial`` objects.
        """
        suggester = BasicVariantGenerator()
        suggester.add_configurations({name: spec})
        return suggester.next_trials()
|
|
|
|
|
|
if __name__ == "__main__":
    # Executing this file directly runs every TestCase defined in it,
    # printing one line per test (verbosity=2).
    unittest.main(module="__main__", verbosity=2)
|