From 0c9817fa764ae01359f59fe9f1c474a934b1b82c Mon Sep 17 00:00:00 2001 From: Richard Liaw Date: Sun, 22 Oct 2017 23:04:15 -0700 Subject: [PATCH] [tune] Tune Pausing (#1136) * fix yaml bug * add ext agent * gpus * update * tuning * docs * Sun Oct 15 21:09:25 PDT 2017 * lint * update * Sun Oct 15 22:39:55 PDT 2017 * Sun Oct 15 22:40:17 PDT 2017 * Sun Oct 15 22:43:06 PDT 2017 * Sun Oct 15 22:46:06 PDT 2017 * Sun Oct 15 22:46:21 PDT 2017 * Sun Oct 15 22:48:11 PDT 2017 * Sun Oct 15 22:48:44 PDT 2017 * Sun Oct 15 22:49:23 PDT 2017 * Sun Oct 15 22:50:21 PDT 2017 * Sun Oct 15 22:53:00 PDT 2017 * Sun Oct 15 22:53:34 PDT 2017 * Sun Oct 15 22:54:33 PDT 2017 * Sun Oct 15 22:54:50 PDT 2017 * Sun Oct 15 22:55:20 PDT 2017 * Sun Oct 15 22:56:56 PDT 2017 * Sun Oct 15 22:59:03 PDT 2017 * fix * Update tune_mnist_ray.py * remove script trial * fix * reorder * fix ex * py2 support * upd * comments * comments * cleanup readme * fix trial * annotate * Update rllib.rst * init pausing * Docs, Lint * fix danglings and restore endpoint moved to trialrunner * renaming * nit * start always starts from checkpoint * smalls * nits * lint * last change --- python/ray/rllib/agent.py | 20 ++++++++++- python/ray/rllib/train.py | 5 ++- python/ray/tune/trial.py | 61 +++++++++++++++++++++++++-------- python/ray/tune/trial_runner.py | 31 ++++++++++------- test/trial_runner_test.py | 60 ++++++++++++++++++++++++++++++++ 5 files changed, 146 insertions(+), 31 deletions(-) diff --git a/python/ray/rllib/agent.py b/python/ray/rllib/agent.py index 6a75d73aa..e7ba9ddd9 100644 --- a/python/ray/rllib/agent.py +++ b/python/ray/rllib/agent.py @@ -314,13 +314,31 @@ class _MockAgent(Agent): _default_config = {} def _init(self): - pass + self.info = None def _train(self): return TrainingResult( episode_reward_mean=10, episode_len_mean=10, timesteps_this_iter=10, info={}) + def _save(self): + path = os.path.join(self.logdir, "mock_agent.pkl") + with open(path, 'wb') as f: + pickle.dump(self.info, f) + return path + + def _restore(self, checkpoint_path): + with open(checkpoint_path, 'rb') as f: + info = pickle.load(f) + self.info = info + + def set_info(self, info): + self.info = info + return info + + def get_info(self): + return self.info + def get_agent_class(alg): """Returns the class of an known agent given its name.""" diff --git a/python/ray/rllib/train.py b/python/ray/rllib/train.py index 5c3bbef2f..e32696760 100755 --- a/python/ray/rllib/train.py +++ b/python/ray/rllib/train.py @@ -57,9 +57,8 @@ def main(argv): runner.add_trial( Trial( args.env, args.alg, args.config, args.local_dir, None, - args.resources, args.stop, args.checkpoint_freq, - args.restore, args.upload_dir)) - + args.resources, args.stop, args.checkpoint_freq, args.restore, + args.upload_dir)) ray.init( redis_address=args.redis_address, num_cpus=args.num_cpus, num_gpus=args.num_gpus) diff --git a/python/ray/tune/trial.py b/python/ray/tune/trial.py index 9f5713253..f77a77b9e 100644 --- a/python/ray/tune/trial.py +++ b/python/ray/tune/trial.py @@ -26,6 +26,7 @@ class Trial(object): PENDING = "PENDING" RUNNING = "RUNNING" + PAUSED = "PAUSED" TERMINATED = "TERMINATED" ERROR = "ERROR" @@ -56,12 +57,11 @@ class Trial(object): self.resources = resources self.stopping_criterion = stopping_criterion self.checkpoint_freq = checkpoint_freq - self.restore_path = restore_path self.upload_dir = upload_dir # Local trial state that is updated during the run self.last_result = None - self.checkpoint_path = None + self._checkpoint_path = restore_path self.agent = None self.status = Trial.PENDING self.location = None @@ -73,16 +73,9 @@ class Trial(object): be thrown. """ - self.status = Trial.RUNNING - agent_cls = get_agent_class(self.alg) - cls = ray.remote( - num_cpus=self.resources.cpu, num_gpus=self.resources.gpu)( - agent_cls) - self.agent = cls.remote( - self.env_creator, self.config, self.local_dir, self.upload_dir, - experiment_tag=self.experiment_tag) - if self.restore_path: - ray.get(self.agent.restore.remote(self.restore_path)) + self._setup_agent() + if self._checkpoint_path: + self.restore_from_path(path=self._checkpoint_path) def stop(self, error=False): """Stops this trial. @@ -111,6 +104,21 @@ class Trial(object): finally: self.agent = None + def pause(self): + """We want to release resources (specifically GPUs) when pausing an + experiment. This results in a state similar to TERMINATED.""" + + assert self.status == Trial.RUNNING, self.status + self.checkpoint() + self.stop() + self.status = Trial.PAUSED + + def resume(self): + """Resume PAUSED tasks. This is a blocking call.""" + + assert self.status == Trial.PAUSED, self.status + self.start() + def train_remote(self): """Returns Ray future for one iteration of training.""" @@ -174,11 +182,36 @@ class Trial(object): """ path = ray.get(self.agent.save.remote()) - self.checkpoint_path = path + self._checkpoint_path = path print("Saved checkpoint to:", path) - return path + def restore_from_path(self, path): + """Restores agent state from specified path. + + Args: + path (str): A path where state will be restored. + """ + + if self.agent is None: + print("Unable to restore - no agent") + else: + try: + ray.get(self.agent.restore.remote(path)) + except: + print("Error restoring agent:", traceback.format_exc()) + self.status = Trial.ERROR + + def _setup_agent(self): + self.status = Trial.RUNNING + agent_cls = get_agent_class(self.alg) + cls = ray.remote( + num_cpus=self.resources.cpu, num_gpus=self.resources.gpu)( + agent_cls) + self.agent = cls.remote( + self.env_creator, self.config, self.local_dir, self.upload_dir, + experiment_tag=self.experiment_tag) + def __str__(self): identifier = '{}_{}'.format(self.alg, self.env_name) if self.experiment_tag: diff --git a/python/ray/tune/trial_runner.py b/python/ray/tune/trial_runner.py index 4a9b6ef69..361d97392 100644 --- a/python/ray/tune/trial_runner.py +++ b/python/ray/tune/trial_runner.py @@ -35,7 +35,7 @@ class TrialRunner(object): """Initializes a new TrialRunner.""" self._trials = [] - self._pending = {} + self._running = {} self._avail_resources = Resources(cpu=0, gpu=0) self._committed_resources = Resources(cpu=0, gpu=0) @@ -43,7 +43,7 @@ class TrialRunner(object): """Returns whether all trials have finished running.""" for t in self._trials: - if t.status in [Trial.PENDING, Trial.RUNNING]: + if t.status in [Trial.PENDING, Trial.RUNNING, Trial.PAUSED]: return False return True @@ -56,7 +56,7 @@ class TrialRunner(object): if self._can_launch_more(): self._launch_trial() - elif self._pending: + elif self._running: self._process_events() else: for trial in self._trials: @@ -64,6 +64,9 @@ class TrialRunner(object): assert self._has_resources(trial.resources), \ ("Insufficient cluster resources to launch trial", (trial.resources, self._avail_resources)) + elif trial.status == Trial.PAUSED: + assert False, "There are paused trials, but no more \ + pending trials with sufficient resources." assert False, "Called step when all trials finished?" def get_trials(self): @@ -110,14 +113,14 @@ class TrialRunner(object): self._commit_resources(trial.resources) try: trial.start() - self._pending[trial.train_remote()] = trial + self._running[trial.train_remote()] = trial except: print("Error starting agent, retrying:", traceback.format_exc()) time.sleep(2) trial.stop(error=True) try: trial.start() - self._pending[trial.train_remote()] = trial + self._running[trial.train_remote()] = trial except: print("Error starting agent, abort:", traceback.format_exc()) trial.stop(error=True) @@ -125,27 +128,25 @@ class TrialRunner(object): # have been lost def _process_events(self): - [result_id], _ = ray.wait(self._pending.keys()) - trial = self._pending[result_id] - del self._pending[result_id] + [result_id], _ = ray.wait(self._running.keys()) + trial = self._running[result_id] + del self._running[result_id] try: result = ray.get(result_id) print("result", result) trial.last_result = result if trial.should_stop(result): - self._return_resources(trial.resources) - trial.stop() + self._stop_trial(trial) else: # TODO(rliaw): This implements checkpoint in a blocking manner if trial.should_checkpoint(): trial.checkpoint() - self._pending[trial.train_remote()] = trial + self._running[trial.train_remote()] = trial except: print("Error processing event:", traceback.format_exc()) if trial.status == Trial.RUNNING: - self._return_resources(trial.resources) - trial.stop(error=True) + self._stop_trial(trial, error=True) def _get_runnable(self): for trial in self._trials: @@ -172,6 +173,10 @@ class TrialRunner(object): assert self._committed_resources.cpu >= 0 assert self._committed_resources.gpu >= 0 + def _stop_trial(self, trial, error=False): + self._return_resources(trial.resources) + trial.stop(error=error) + def _update_avail_resources(self): clients = ray.global_state.client_table() local_schedulers = [ diff --git a/test/trial_runner_test.py b/test/trial_runner_test.py index 3b842df44..bbf7f1f55 100644 --- a/test/trial_runner_test.py +++ b/test/trial_runner_test.py @@ -3,6 +3,7 @@ from __future__ import division from __future__ import print_function import unittest +import os import ray from ray.tune.trial import Trial, Resources @@ -195,6 +196,65 @@ class TrialRunnerTest(unittest.TestCase): self.assertEqual(trials[0].status, Trial.ERROR) self.assertEqual(trials[1].status, Trial.RUNNING) + def testCheckpointing(self): + ray.init(num_cpus=1, num_gpus=1) + runner = TrialRunner() + kwargs = { + "stopping_criterion": {"training_iteration": 1}, + "resources": Resources(cpu=1, gpu=1), + } + runner.add_trial(Trial("CartPole-v0", "__fake", **kwargs)) + trials = runner.get_trials() + + runner.step() + self.assertEqual(trials[0].status, Trial.RUNNING) + self.assertEqual(ray.get(trials[0].agent.set_info.remote(1)), 1) + + path = trials[0].checkpoint() + kwargs["restore_path"] = path + + runner.add_trial(Trial("CartPole-v0", "__fake", **kwargs)) + trials = runner.get_trials() + + runner.step() + self.assertEqual(trials[0].status, Trial.TERMINATED) + self.assertEqual(trials[1].status, Trial.PENDING) + + runner.step() + self.assertEqual(trials[0].status, Trial.TERMINATED) + self.assertEqual(trials[1].status, Trial.RUNNING) + self.assertEqual(ray.get(trials[1].agent.get_info.remote()), 1) + self.addCleanup(os.remove, path) + + def testPauseThenResume(self): + ray.init(num_cpus=1, num_gpus=1) + runner = TrialRunner() + kwargs = { + "stopping_criterion": {"training_iteration": 2}, + "resources": Resources(cpu=1, gpu=1), + } + runner.add_trial(Trial("CartPole-v0", "__fake", **kwargs)) + trials = runner.get_trials() + + runner.step() + self.assertEqual(trials[0].status, Trial.RUNNING) + self.assertEqual(ray.get(trials[0].agent.get_info.remote()), None) + + self.assertEqual(ray.get(trials[0].agent.set_info.remote(1)), 1) + + trials[0].pause() + self.assertEqual(trials[0].status, Trial.PAUSED) + + trials[0].resume() + self.assertEqual(trials[0].status, Trial.RUNNING) + + runner.step() + self.assertEqual(trials[0].status, Trial.RUNNING) + self.assertEqual(ray.get(trials[0].agent.get_info.remote()), 1) + + runner.step() + self.assertEqual(trials[0].status, Trial.TERMINATED) + if __name__ == "__main__": unittest.main(verbosity=2)