From 691c9733f95a28d9a1ceeb291eda94a2ffd93cdd Mon Sep 17 00:00:00 2001 From: Richard Liaw Date: Wed, 10 Jul 2019 18:51:11 -0700 Subject: [PATCH] =?UTF-8?q?[tune]=20Document=20trainable=20attributes=20an?= =?UTF-8?q?d=20enable=20user-checkpoint=E2=80=A6=20(#4868)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- python/ray/rllib/agents/mock.py | 8 ++- python/ray/rllib/agents/trainer.py | 6 --- python/ray/tune/result.py | 3 ++ python/ray/tune/tests/test_trial_runner.py | 23 +++++++++ python/ray/tune/trainable.py | 60 ++++++++++++++++++---- python/ray/tune/trial_runner.py | 10 ++-- 6 files changed, 90 insertions(+), 20 deletions(-) diff --git a/python/ray/rllib/agents/mock.py b/python/ray/rllib/agents/mock.py index 0b7d77c2a..62574d5bf 100644 --- a/python/ray/rllib/agents/mock.py +++ b/python/ray/rllib/agents/mock.py @@ -6,6 +6,7 @@ import os import pickle import numpy as np +from ray.tune import result as tune_result from ray.rllib.agents.trainer import Trainer, with_common_config @@ -18,6 +19,7 @@ class _MockTrainer(Trainer): "persistent_error": False, "test_variable": 1, "num_workers": 0, + "user_checkpoint_freq": 0, }) @classmethod @@ -32,11 +34,15 @@ class _MockTrainer(Trainer): if self.config["mock_error"] and self.iteration == 1 \ and (self.config["persistent_error"] or not self.restored): raise Exception("mock error") - return dict( + result = dict( episode_reward_mean=10, episode_len_mean=10, timesteps_this_iter=10, info={}) + if self.config["user_checkpoint_freq"] > 0 and self.iteration > 0: + if self.iteration % self.config["user_checkpoint_freq"] == 0: + result.update({tune_result.SHOULD_CHECKPOINT: True}) + return result def _save(self, checkpoint_dir): path = os.path.join(checkpoint_dir, "mock_agent.pkl") diff --git a/python/ray/rllib/agents/trainer.py b/python/ray/rllib/agents/trainer.py index bf84699aa..965014c32 100644 --- a/python/ray/rllib/agents/trainer.py +++ b/python/ray/rllib/agents/trainer.py @@ -579,12 +579,6 @@ class Trainer(Trainable): else: return res[0] # backwards compatibility - @property - def iteration(self): - """Current training iter, auto-incremented with each train() call.""" - - return self._iteration - @property def _name(self): """Subclasses should override this to declare their name.""" diff --git a/python/ray/tune/result.py b/python/ray/tune/result.py index 51a67d593..3ef2d2975 100644 --- a/python/ray/tune/result.py +++ b/python/ray/tune/result.py @@ -9,6 +9,9 @@ import os # (Optional/Auto-filled) training is terminated. Filled only if not provided. DONE = "done" +# (Optional) Enum for user controlled checkpoint +SHOULD_CHECKPOINT = "should_checkpoint" + # (Auto-filled) The hostname of the machine hosting the training process. HOSTNAME = "hostname" diff --git a/python/ray/tune/tests/test_trial_runner.py b/python/ray/tune/tests/test_trial_runner.py index b1b7b6633..ac3c9a07b 100644 --- a/python/ray/tune/tests/test_trial_runner.py +++ b/python/ray/tune/tests/test_trial_runner.py @@ -2237,6 +2237,29 @@ class TrialRunnerTest(unittest.TestCase): self.assertEquals(count_checkpoints(tmpdir), 2) shutil.rmtree(tmpdir) + def testUserCheckpoint(self): + ray.init(num_cpus=3) + tmpdir = tempfile.mkdtemp() + runner = TrialRunner(local_checkpoint_dir=tmpdir, checkpoint_period=0) + runner.add_trial(Trial("__fake", config={"user_checkpoint_freq": 2})) + trials = runner.get_trials() + + runner.step() + self.assertEqual(trials[0].status, Trial.RUNNING) + self.assertEqual(ray.get(trials[0].runner.set_info.remote(1)), 1) + runner.step() # 0 + self.assertFalse(trials[0].has_checkpoint()) + runner.step() # 1 + self.assertFalse(trials[0].has_checkpoint()) + runner.step() # 2 + self.assertTrue(trials[0].has_checkpoint()) + + runner2 = TrialRunner(resume="LOCAL", local_checkpoint_dir=tmpdir) + runner2.step() + trials2 = runner2.get_trials() + self.assertEqual(ray.get(trials2[0].runner.get_info.remote()), 1) + shutil.rmtree(tmpdir) + class SearchAlgorithmTest(unittest.TestCase): def testNestedSuggestion(self): diff --git a/python/ray/tune/trainable.py b/python/ray/tune/trainable.py index de7614859..1c2e6744b 100644 --- a/python/ray/tune/trainable.py +++ b/python/ray/tune/trainable.py @@ -46,6 +46,11 @@ class Trainable(object): just a ``my_train(config, reporter)`` function to the config. The function will be automatically converted to this interface (sans checkpoint functionality). + + When using Tune, Tune will convert this class into a Ray actor, which + runs on a separate process. Tune will also change the current working + directory of this process to `self.logdir`. + """ def __init__(self, config=None, logger_creator=None): @@ -70,14 +75,15 @@ class Trainable(object): if logger_creator: self._result_logger = logger_creator(self.config) - self.logdir = self._result_logger.logdir + self._logdir = self._result_logger.logdir else: logdir_prefix = datetime.today().strftime("%Y-%m-%d_%H-%M-%S") if not os.path.exists(DEFAULT_RESULTS_DIR): os.makedirs(DEFAULT_RESULTS_DIR) - self.logdir = tempfile.mkdtemp( + self._logdir = tempfile.mkdtemp( prefix=logdir_prefix, dir=DEFAULT_RESULTS_DIR) - self._result_logger = UnifiedLogger(self.config, self.logdir, None) + self._result_logger = UnifiedLogger(self.config, self._logdir, + None) self._iteration = 0 self._time_total = 0.0 @@ -131,7 +137,8 @@ class Trainable(object): across checkpoint / restore calls. `training_iteration` (int): The index of this - training iteration, e.g. call to train(). + training iteration, e.g. call to train(). This is incremented + after `_train()` is called. `pid` (str): The pid of the training process. @@ -219,8 +226,8 @@ class Trainable(object): def delete_checkpoint(self, checkpoint_dir): """Removes subdirectory within checkpoint_folder - Parameters - ---------- + + Args: checkpoint_dir : path to checkpoint """ if os.path.isfile(checkpoint_dir): @@ -275,8 +282,9 @@ class Trainable(object): return checkpoint_path def save_to_object(self): - """Saves the current model state to a Python object. It also - saves to disk but does not return the checkpoint path. + """Saves the current model state to a Python object. + + It also saves to disk but does not return the checkpoint path. Returns: Object holding checkpoint data. @@ -394,11 +402,45 @@ class Trainable(object): self._result_logger.close() self._stop() + @property + def logdir(self): + """Directory of the results and checkpoints for this Trainable. + + Tune will automatically sync this folder with the driver if execution + is distributed. + + Note that the current working directory will also be changed to this. + + """ + return self._logdir + + @property + def iteration(self): + """Current training iteration. + + This value is automatically incremented every time `train()` is called + and is automatically inserted into the training result dict. + + """ + return self._iteration + + def get_config(self): + """Returns configuration passed in by Tune.""" + return self.config + def _train(self): """Subclasses should override this to implement train(). + The return value will be automatically passed to the loggers. Users + can also return `tune.result.DONE` or `tune.result.SHOULD_CHECKPOINT` + to manually trigger termination of this trial or checkpointing of this + trial. Note that manual checkpointing only works when subclassing + Trainables. + Returns: - A dict that describes training progress.""" + A dict that describes training progress. + + """ raise NotImplementedError diff --git a/python/ray/tune/trial_runner.py b/python/ray/tune/trial_runner.py index c512c80d3..f307d0923 100644 --- a/python/ray/tune/trial_runner.py +++ b/python/ray/tune/trial_runner.py @@ -15,7 +15,8 @@ import traceback import ray.cloudpickle as cloudpickle from ray.tune import TuneError from ray.tune.ray_trial_executor import RayTrialExecutor -from ray.tune.result import TIME_THIS_ITER_S, RESULT_DUPLICATE +from ray.tune.result import (TIME_THIS_ITER_S, RESULT_DUPLICATE, + SHOULD_CHECKPOINT) from ray.tune.syncer import get_syncer from ray.tune.trial import Trial, Checkpoint from ray.tune.sample import function @@ -529,7 +530,8 @@ class TrialRunner(object): # the scheduler decision is STOP or PAUSE. Note that # PAUSE only checkpoints to memory and does not update # the global checkpoint state. - self._checkpoint_trial_if_needed(trial) + self._checkpoint_trial_if_needed( + trial, force=result.get(SHOULD_CHECKPOINT, False)) if decision == TrialScheduler.CONTINUE: self.trial_executor.continue_training(trial) @@ -554,9 +556,9 @@ class TrialRunner(object): self.trial_executor.stop_trial( trial, error=True, error_msg=error_msg) - def _checkpoint_trial_if_needed(self, trial): + def _checkpoint_trial_if_needed(self, trial, force=False): """Checkpoints trial based off trial.last_result.""" - if trial.should_checkpoint(): + if trial.should_checkpoint() or force: # Save trial runtime if possible if hasattr(trial, "runner") and trial.runner: self.trial_executor.save(trial, storage=Checkpoint.DISK)