diff --git a/ci/jenkins_tests/run_tune_tests.sh b/ci/jenkins_tests/run_tune_tests.sh index 35810a8ed..41602008f 100755 --- a/ci/jenkins_tests/run_tune_tests.sh +++ b/ci/jenkins_tests/run_tune_tests.sh @@ -141,6 +141,14 @@ $SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} python /ray/python/ray/tune/examples/pbt_convnet_example.py \ --smoke-test +$SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} --memory-swap=-1 $DOCKER_SHA \ + python /ray/python/ray/tune/examples/hyperband_function_example.py \ + --smoke-test + +$SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} --memory-swap=-1 $DOCKER_SHA \ + python /ray/python/ray/tune/examples/pbt_function.py \ + --smoke-test + $SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} --memory-swap=-1 $DOCKER_SHA \ python /ray/python/ray/tune/examples/pbt_dcgan_mnist/pbt_dcgan_mnist.py \ --smoke-test diff --git a/doc/source/tune/_tutorials/tune-60-seconds.rst b/doc/source/tune/_tutorials/tune-60-seconds.rst index 08b1887a7..1ef6af17c 100644 --- a/doc/source/tune/_tutorials/tune-60-seconds.rst +++ b/doc/source/tune/_tutorials/tune-60-seconds.rst @@ -35,7 +35,7 @@ Here's an example of specifying the objective function using :ref:`the function- Now, there's two Trainable APIs - one being the :ref:`function-based API ` that we demonstrated above. -The other is a :ref:`class-based API ` that enables :ref:`checkpointing and pausing `. Here's an example of specifying the objective function using the :ref:`class-based API `: +The other is a :ref:`class-based API `. Here's an example of specifying the objective function using the :ref:`class-based API `: .. 
code-block:: python diff --git a/doc/source/tune/_tutorials/tune-usage.rst b/doc/source/tune/_tutorials/tune-usage.rst index 726252fa8..d26c32ad3 100644 --- a/doc/source/tune/_tutorials/tune-usage.rst +++ b/doc/source/tune/_tutorials/tune-usage.rst @@ -147,46 +147,41 @@ When running a hyperparameter search, Tune can automatically and periodically sa * fault-tolerance when using pre-emptible machines. * Pausing trials when using Trial Schedulers such as HyperBand and PBT. -To enable checkpointing, you must implement a :ref:`Trainable class ` (the function-based API are not checkpointable, since they never return control back to their caller). +Checkpointing assumes that the model state will be saved to disk on whichever node the Trainable is running on. -Checkpointing assumes that the model state will be saved to disk on whichever node the Trainable is running on. You can checkpoint with three different mechanisms: manually, periodically, and at termination. - -**Manual Checkpointing**: A custom Trainable can manually trigger checkpointing by returning ``should_checkpoint: True`` (or ``tune.result.SHOULD_CHECKPOINT: True``) in the result dictionary of `_train`. This can be especially helpful in spot instances: +To use Tune's checkpointing features, you must expose a ``checkpoint`` argument in the function signature, and call ``tune.make_checkpoint_dir`` and ``tune.save_checkpoint``: .. code-block:: python - def _train(self): - # training code - result = {"mean_accuracy": accuracy} - if detect_instance_preemption(): - result.update(should_checkpoint=True) - return result + import json + import os + import time + from ray import tune + def train_func(config, checkpoint=None): + start = 0 + if checkpoint: + with open(checkpoint) as f: + state = json.loads(f.read()) + start = state["step"] + 1 -**Periodic Checkpointing**: periodic checkpointing can be used to provide fault-tolerance for experiments. 
This can be enabled by setting ``checkpoint_freq=`` and ``max_failures=`` to checkpoint trials every *N* iterations and recover from up to *M* crashes per trial, e.g.: + for step in range(start, 100): + time.sleep(1) -.. code-block:: python + # Obtain a checkpoint directory + checkpoint_dir = tune.make_checkpoint_dir(step=step) + path = os.path.join(checkpoint_dir, "checkpoint") + with open(path, "w") as f: + f.write(json.dumps({"step": step})) + tune.save_checkpoint(path) - tune.run( - my_trainable, - checkpoint_freq=10, - max_failures=5, - ) + tune.report(hello="world", ray="tune") -**Checkpointing at Termination**: The checkpoint_freq may not coincide with the exact end of an experiment. If you want a checkpoint to be created at the end -of a trial, you can additionally set the ``checkpoint_at_end=True``: + tune.run(train_func) -.. code-block:: python - :emphasize-lines: 5 +In this example, checkpoints will be saved by training iteration to ``local_dir/exp_name/trial_name/checkpoint_``. - tune.run( - my_trainable, - checkpoint_freq=10, - checkpoint_at_end=True, - max_failures=5, - ) - -The checkpoint will be saved at a path that looks like ``local_dir/exp_name/trial_name/checkpoint_x/``, where the x is the number of iterations so far when the checkpoint is saved. To restore the checkpoint, you can use the ``restore`` argument and specify a checkpoint file. By doing this, you can change whatever experiments' configuration such as the experiment's name, the training iteration or so: +You can restore a single trial checkpoint by using ``tune.run(restore=)``. When restoring, you can also change parts of the experiment's configuration, such as the experiment's name: .. 
code-block:: python diff --git a/doc/source/tune/api_docs/trainable.rst b/doc/source/tune/api_docs/trainable.rst index 929e61972..89a8e400a 100644 --- a/doc/source/tune/api_docs/trainable.rst +++ b/doc/source/tune/api_docs/trainable.rst @@ -3,9 +3,7 @@ Training (tune.Trainable, tune.report) ====================================== -Training can be done with either a **Class API** (``tune.Trainable``) or **function-based API** (``tune.report``). - -You can use the **function-based API** for fast prototyping. On the other hand, the ``tune.Trainable`` interface supports checkpoint/restore functionality and provides more control for advanced algorithms. +Training can be done with either a **Class API** (``tune.Trainable``) or **function API** (``tune.report``). For the sake of example, let's maximize this objective function: @@ -16,8 +14,10 @@ For the sake of example, let's maximize this objective function: .. _tune-function-api: -Function-based API ------------------- +Function API +------------ + +Here is a simple example of using the function API. You can report intermediate metrics by simply calling ``tune.report`` within the provided function. .. code-block:: python @@ -25,31 +25,74 @@ Function-based API # config (dict): A dict of hyperparameters. for x in range(20): - score = objective(x, config["a"], config["b"]) + intermediate_score = objective(x, config["a"], config["b"]) - tune.report(score=score) # This sends the score to Tune. + tune.report(score=intermediate_score) # This sends the score to Tune. analysis = tune.run( trainable, - config={ - "a": 2, - "b": 4 - }) + config={"a": 2, "b": 4} + ) print("best config: ", analysis.get_best_config(metric="score", mode="max")) -.. tip:: Do not use ``tune.track.log`` within a ``Trainable`` class. +.. tip:: Do not use ``tune.report`` within a ``Trainable`` class. -Tune will run this function on a separate thread in a Ray actor process. 
Note that this API is not checkpointable, since the thread will never return control back to its caller. +Tune will run this function on a separate thread in a Ray actor process. -.. note:: If you want to pass in a Python lambda, you will need to first register the function: ``tune.register_trainable("lambda_id", lambda x: ...)``. You can then use ``lambda_id`` in place of ``my_trainable``. + +Function API Checkpointing +~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Many Tune features rely on checkpointing, including the usage of certain Trial Schedulers and fault tolerance. To use Tune's checkpointing features, you must expose a ``checkpoint`` argument in the function signature, and call ``tune.make_checkpoint_dir`` and ``tune.save_checkpoint``: + +.. code-block:: python + + import time + from ray import tune + + def train_func(config, checkpoint=None): + start = 0 + if checkpoint: + with open(checkpoint) as f: + state = json.loads(f.read()) + start = state["step"] + 1 + + for step in range(start, 100): + time.sleep(1) + + # Obtain a checkpoint directory + checkpoint_dir = tune.make_checkpoint_dir(step=step) + path = os.path.join(checkpoint_dir, "checkpoint") + with open(path, "w") as f: + f.write(json.dumps({"step": step})) + tune.save_checkpoint(path) + + tune.report(hello="world", ray="tune") + + tune.run(train_func) + +In this example, checkpoints will be saved by training iteration to ``local_dir/exp_name/trial_name/checkpoint_``. You can restore a single trial checkpoint by using ``tune.run(restore=)``: + +.. code-block:: python + + trials = tune.run( + train, + config={ + "max_iter": 5 + }, + ).trials + last_ckpt = trials[0].checkpoint.value + analysis = tune.run(train, config={"max_iter": 10}, restore=last_ckpt) + +Tune may also copy or move checkpoints during the course of tuning. For this reason, it is important not to depend on absolute paths in the implementation of ``save``. .. _tune-class-api: Trainable Class API ------------------- -.. 
caution:: Do not use ``tune.track.log`` within a ``Trainable`` class. +.. caution:: Do not use ``tune.report`` within a ``Trainable`` class. The Trainable **class API** will require users to subclass ``ray.tune.Trainable``. Here's a naive example of this API: @@ -87,14 +130,13 @@ As a subclass of ``tune.Trainable``, Tune will create a ``Trainable`` object on .. tip:: As a rule of thumb, the execution time of ``_train`` should be large enough to avoid overheads (i.e. more than a few seconds), but short enough to report progress periodically (i.e. at most a few minutes). -In this example, we only implemented the ``_setup`` and ``_train`` methods for simplification. Next, we'll implement ``_save`` and ``_restore`` for checkpoint and fault tolerance. .. _tune-trainable-save-restore: -Save and Restore -~~~~~~~~~~~~~~~~ +Class API Checkpointing +~~~~~~~~~~~~~~~~~~~~~~~ -Many Tune features rely on ``_save``, and ``_restore``, including the usage of certain Trial Schedulers, fault tolerance, and checkpointing. +You can also implement checkpoint/restore using the Trainable Class API: .. code-block:: python @@ -108,9 +150,45 @@ Many Tune features rely on ``_save``, and ``_restore``, including the usage of c checkpoint_path = os.path.join(tmp_checkpoint_dir, "model.pth") self.model.load_state_dict(torch.load(checkpoint_path)) -Checkpoints will be saved by training iteration to ``local_dir/exp_name/trial_name/checkpoint_``. You can restore a single trial checkpoint by using ``tune.run(restore=)``. + tune.run(MyTrainableClass, checkpoint_freq=2) + +You can checkpoint with three different mechanisms: manually, periodically, and at termination. + +**Manual Checkpointing**: A custom Trainable can manually trigger checkpointing by returning ``should_checkpoint: True`` (or ``tune.result.SHOULD_CHECKPOINT: True``) in the result dictionary of `_train`. This can be especially helpful in spot instances: + +.. 
code-block:: python + + def _train(self): + # training code + result = {"mean_accuracy": accuracy} + if detect_instance_preemption(): + result.update(should_checkpoint=True) + return result + + +**Periodic Checkpointing**: periodic checkpointing can be used to provide fault-tolerance for experiments. This can be enabled by setting ``checkpoint_freq=`` and ``max_failures=`` to checkpoint trials every *N* iterations and recover from up to *M* crashes per trial, e.g.: + +.. code-block:: python + + tune.run( + my_trainable, + checkpoint_freq=10, + max_failures=5, + ) + +**Checkpointing at Termination**: The checkpoint_freq may not coincide with the exact end of an experiment. If you want a checkpoint to be created at the end +of a trial, you can additionally set the ``checkpoint_at_end=True``: + +.. code-block:: python + :emphasize-lines: 5 + + tune.run( + my_trainable, + checkpoint_freq=10, + checkpoint_at_end=True, + max_failures=5, + ) -Tune also generates temporary checkpoints for pausing and switching between trials. For this purpose, it is important not to depend on absolute paths in the implementation of ``save``. Use ``validate_save_restore`` to catch ``_save``/``_restore`` errors before execution. @@ -122,31 +200,11 @@ Use ``validate_save_restore`` to catch ``_save``/``_restore`` errors before exec validate_save_restore(MyTrainableClass) validate_save_restore(MyTrainableClass, use_object_store=True) - -Advanced Resource Allocation ----------------------------- - -Trainables can themselves be distributed. If your trainable function / class creates further Ray actors or tasks that also consume CPU / GPU resources, you will want to set ``extra_cpu`` or ``extra_gpu`` inside ``tune.run`` to reserve extra resource slots. For example, if a trainable class requires 1 GPU itself, but also launches 4 actors, each using another GPU, then you should set ``"gpu": 1, "extra_gpu": 4``. - -.. 
code-block:: python - :emphasize-lines: 4-8 - - tune.run( - my_trainable, - name="my_trainable", - resources_per_trial={ - "cpu": 1, - "gpu": 1, - "extra_gpu": 4 - } - ) - -The ``Trainable`` also provides the ``default_resource_requests`` interface to automatically declare the ``resources_per_trial`` based on the given configuration. - - Advanced: Reusing Actors ~~~~~~~~~~~~~~~~~~~~~~~~ +.. note:: This feature is only for the Trainable Class API. + Your Trainable can often take a long time to start. To avoid this, you can do ``tune.run(reuse_actors=True)`` to reuse the same Trainable Python process and object for multiple hyperparameters. This requires you to implement ``Trainable.reset_config``, which provides a new set of hyperparameters. It is up to the user to correctly update the hyperparameters of your trainable. @@ -176,8 +234,47 @@ This requires you to implement ``Trainable.reset_config``, which provides a new return True -tune.Trainable --------------- +Advanced Resource Allocation +---------------------------- + +Trainables can themselves be distributed. If your trainable function / class creates further Ray actors or tasks that also consume CPU / GPU resources, you will want to set ``extra_cpu`` or ``extra_gpu`` inside ``tune.run`` to reserve extra resource slots. For example, if a trainable class requires 1 GPU itself, but also launches 4 actors, each using another GPU, then you should set ``"gpu": 1, "extra_gpu": 4``. + +.. code-block:: python + :emphasize-lines: 4-8 + + tune.run( + my_trainable, + name="my_trainable", + resources_per_trial={ + "cpu": 1, + "gpu": 1, + "extra_gpu": 4 + } + ) + +The ``Trainable`` also provides the ``default_resource_requests`` interface to automatically declare the ``resources_per_trial`` based on the given configuration. + + + +.. _track-docstring: + +tune.report / tune.checkpoint (Function API) +-------------------------------------------- + +.. autofunction:: ray.tune.report + +.. 
autofunction:: ray.tune.make_checkpoint_dir + +.. autofunction:: ray.tune.save_checkpoint + +.. autofunction:: ray.tune.get_trial_dir + +.. autofunction:: ray.tune.get_trial_name + +.. autofunction:: ray.tune.get_trial_id + +tune.Trainable (Class API) +-------------------------- .. autoclass:: ray.tune.Trainable @@ -190,21 +287,6 @@ tune.DurableTrainable .. autoclass:: ray.tune.DurableTrainable -.. _track-docstring: - -tune.track ----------- - -.. automodule:: ray.tune.track - :members: - :exclude-members: init, - -KerasCallback -------------- - -.. automodule:: ray.tune.integration.keras - :members: - StatusReporter -------------- diff --git a/python/ray/tune/BUILD b/python/ray/tune/BUILD index 3e6dd4388..60805c250 100644 --- a/python/ray/tune/BUILD +++ b/python/ray/tune/BUILD @@ -149,6 +149,14 @@ py_test( tags = ["exclusive"], ) +py_test( + name = "test_function_api", + size = "medium", + srcs = ["tests/test_function_api.py"], + deps = [":tune_lib"], + tags = ["exclusive"], +) + py_test( name = "test_sync", size = "medium", diff --git a/python/ray/tune/__init__.py b/python/ray/tune/__init__.py index 58753f98f..7fd38c142 100644 --- a/python/ray/tune/__init__.py +++ b/python/ray/tune/__init__.py @@ -8,7 +8,8 @@ from ray.tune.trainable import Trainable from ray.tune.durable_trainable import DurableTrainable from ray.tune.suggest import grid_search from ray.tune.session import (report, get_trial_dir, get_trial_name, - get_trial_id) + get_trial_id, make_checkpoint_dir, + save_checkpoint) from ray.tune.progress_reporter import (ProgressReporter, CLIReporter, JupyterNotebookReporter) from ray.tune.sample import (function, sample_from, uniform, choice, randint, @@ -21,5 +22,5 @@ __all__ = [ "uniform", "choice", "randint", "randn", "loguniform", "ExperimentAnalysis", "Analysis", "CLIReporter", "JupyterNotebookReporter", "ProgressReporter", "report", "get_trial_dir", "get_trial_name", - "get_trial_id" + "get_trial_id", "make_checkpoint_dir", "save_checkpoint" ] diff --git 
a/python/ray/tune/durable_trainable.py b/python/ray/tune/durable_trainable.py index ab25cfa02..fe2e7a5d7 100644 --- a/python/ray/tune/durable_trainable.py +++ b/python/ray/tune/durable_trainable.py @@ -57,7 +57,7 @@ class DurableTrainable(Trainable): Checkpoint path or prefix that may be passed to restore(). """ if checkpoint_dir: - if checkpoint_dir.starts_with(os.path.abspath(self.logdir)): + if not checkpoint_dir.startswith(os.path.abspath(self.logdir)): raise ValueError("`checkpoint_dir` must be `self.logdir`, or " "a sub-directory.") checkpoint_path = super(DurableTrainable, self).save(checkpoint_dir) diff --git a/python/ray/tune/examples/hyperband_function_example.py b/python/ray/tune/examples/hyperband_function_example.py new file mode 100644 index 000000000..6d4d3dc4e --- /dev/null +++ b/python/ray/tune/examples/hyperband_function_example.py @@ -0,0 +1,59 @@ +#!/usr/bin/env python + +import argparse +import json +import os + +import numpy as np + +import ray +from ray import tune +from ray.tune.schedulers import HyperBandScheduler + + +def train(config, checkpoint=None): + step = 0 + if checkpoint: + with open(checkpoint) as f: + step = json.loads(f.read())["timestep"] + + for timestep in range(step, 100): + v = np.tanh(float(timestep) / config.get("width", 1)) + v *= config.get("height", 1) + + if timestep % 3 == 0: + checkpoint_dir = tune.make_checkpoint_dir(step=timestep) + path = os.path.join(checkpoint_dir, "checkpoint") + with open(path, "w") as f: + f.write(json.dumps({"timestep": timestep})) + tune.save_checkpoint(path) + + # Here we use `episode_reward_mean`, but you can also report other + # objectives such as loss or accuracy. 
+ tune.report(episode_reward_mean=v) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--smoke-test", action="store_true", help="Finish quickly for testing") + args, _ = parser.parse_known_args() + ray.init(num_cpus=4 if args.smoke_test else None) + + # Hyperband early stopping, configured with `episode_reward_mean` as the + # objective and `training_iteration` as the time unit, + # which is automatically filled by Tune. + hyperband = HyperBandScheduler( + time_attr="training_iteration", + metric="episode_reward_mean", + mode="max", + max_t=200) + + tune.run( + train, + name="hyperband_test", + num_samples=20, + stop={"training_iteration": 10 if args.smoke_test else 99999}, + config={"height": tune.uniform(0, 100)}, + scheduler=hyperband, + fail_fast=True) diff --git a/python/ray/tune/examples/pbt_function.py b/python/ray/tune/examples/pbt_function.py new file mode 100644 index 000000000..8e75240ce --- /dev/null +++ b/python/ray/tune/examples/pbt_function.py @@ -0,0 +1,119 @@ +#!/usr/bin/env python + +import numpy as np +import argparse +import json +import os +import random + +import ray +from ray import tune +from ray.tune.schedulers import PopulationBasedTraining + + +def pbt_function(config, checkpoint=None): + """Toy PBT problem for benchmarking adaptive learning rate. + + The goal is to optimize this trainable's accuracy. The accuracy increases + fastest at the optimal lr, which is a function of the current accuracy. + + The optimal lr schedule for this problem is the triangle wave as follows. + Note that many lr schedules for real models also follow this shape: + + best lr + ^ + | /\ + | / \ + | / \ + | / \ + ------------> accuracy + + In this problem, using PBT with a population of 2-4 is sufficient to + roughly approximate this lr schedule. Higher population sizes will yield + faster convergence. Training will not converge without PBT. 
+ """ + lr = config["lr"] + accuracy = 0.0 # end = 1000 + start = 0 + if checkpoint: + with open(checkpoint) as f: + state = json.loads(f.read()) + accuracy = state["acc"] + start = state["step"] + + midpoint = 100 # lr starts decreasing after acc > midpoint + q_tolerance = 3 # penalize exceeding lr by more than this multiple + noise_level = 2 # add gaussian noise to the acc increase + # triangle wave: + # - start at 0.001 @ t=0, + # - peak at 0.01 @ t=midpoint, + # - end at 0.001 @ t=midpoint * 2, + for step in range(start, 100): + if accuracy < midpoint: + optimal_lr = 0.01 * accuracy / midpoint + else: + optimal_lr = 0.01 - 0.01 * (accuracy - midpoint) / midpoint + optimal_lr = min(0.01, max(0.001, optimal_lr)) + + # compute accuracy increase + q_err = max(lr, optimal_lr) / min(lr, optimal_lr) + if q_err < q_tolerance: + accuracy += (1.0 / q_err) * random.random() + elif lr > optimal_lr: + accuracy -= (q_err - q_tolerance) * random.random() + accuracy += noise_level * np.random.normal() + accuracy = max(0, accuracy) + + if step % 3 == 0: + checkpoint_dir = tune.make_checkpoint_dir(step=step) + path = os.path.join(checkpoint_dir, "checkpoint") + with open(path, "w") as f: + f.write(json.dumps({"acc": accuracy, "step": step})) + tune.save_checkpoint(path) + + tune.report( + mean_accuracy=accuracy, + cur_lr=lr, + optimal_lr=optimal_lr, # for debugging + q_err=q_err, # for debugging + done=accuracy > midpoint * 2) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--smoke-test", action="store_true", help="Finish quickly for testing") + args, _ = parser.parse_known_args() + if args.smoke_test: + ray.init(num_cpus=2) # force pausing to happen for test + else: + ray.init() + + pbt = PopulationBasedTraining( + time_attr="training_iteration", + metric="mean_accuracy", + mode="max", + perturbation_interval=4, + hyperparam_mutations={ + # distribution for resampling + "lr": lambda: random.uniform(0.0001, 0.02), + # allow 
perturbations within this set of categorical values + "some_other_factor": [1, 2], + }) + + tune.run( + pbt_function, + name="pbt_test", + scheduler=pbt, + verbose=False, + stop={ + "training_iteration": 30, + }, + num_samples=8, + fail_fast=True, + config={ + "lr": 0.0001, + # note: this parameter is perturbed but has no effect on + # the model training in this example + "some_other_factor": 1, + }) diff --git a/python/ray/tune/experiment.py b/python/ray/tune/experiment.py index 78c85ce1b..c401a6ddb 100644 --- a/python/ray/tune/experiment.py +++ b/python/ray/tune/experiment.py @@ -3,6 +3,7 @@ import logging import os from ray.tune.error import TuneError +from ray.tune.function_runner import detect_checkpoint_function from ray.tune.registry import register_trainable, get_trainable_cls from ray.tune.result import DEFAULT_RESULTS_DIR from ray.tune.sample import sample_from @@ -92,6 +93,18 @@ class Experiment: restore=None): config = config or {} + + if callable(run) and detect_checkpoint_function(run): + if checkpoint_at_end: + raise ValueError( + "'checkpoint_at_end' cannot be used with a " + "checkpointable function. You can specify and register " + "checkpoints within your trainable function.") + if checkpoint_freq: + raise ValueError( + "'checkpoint_freq' cannot be used with a " + "checkpointable function. 
You can specify checkpoints " + "within your trainable function.") self._run_identifier = Experiment.register_if_needed(run) self.name = name or self._run_identifier if upload_dir: diff --git a/python/ray/tune/function_runner.py b/python/ray/tune/function_runner.py index 3002adf52..f82fc83f5 100644 --- a/python/ray/tune/function_runner.py +++ b/python/ray/tune/function_runner.py @@ -1,13 +1,18 @@ import logging +import os +import io import time import inspect +import shutil import threading import traceback + from six.moves import queue from ray.tune import TuneError, session -from ray.tune.trainable import Trainable -from ray.tune.result import TIME_THIS_ITER_S, RESULT_DUPLICATE +from ray.tune.trainable import Trainable, TrainableUtil +from ray.tune.result import (TIME_THIS_ITER_S, RESULT_DUPLICATE, + SHOULD_CHECKPOINT) logger = logging.getLogger(__name__) @@ -40,6 +45,8 @@ class StatusReporter: self._trial_name = trial_name self._trial_id = trial_id self._logdir = logdir + self._last_checkpoint = {} + self._fresh_checkpoint = False def __call__(self, **kwargs): """Report updated training status. @@ -77,6 +84,29 @@ class StatusReporter: # resume training. 
self._continue_semaphore.acquire() + def make_checkpoint_dir(self, step=None): + checkpoint_dir = TrainableUtil.make_checkpoint_dir( + self.logdir, index=step) + return checkpoint_dir + + def save_checkpoint(self, checkpoint): + if isinstance(checkpoint, str): + try: + TrainableUtil.find_checkpoint_dir(checkpoint) + except FileNotFoundError: + logger.error("Checkpoint must be created with path given from " + "make_checkpoint_dir.") + raise + self._last_checkpoint = checkpoint + self._fresh_checkpoint = True + + def has_new_checkpoint(self): + return self._fresh_checkpoint + + def get_checkpoint(self): + self._fresh_checkpoint = False + return self._last_checkpoint + def _start(self): self._last_report_time = time.time() @@ -155,21 +185,33 @@ class FunctionRunner(Trainable): trial_id=self.trial_id, logdir=self.logdir) self._last_result = {} - config = config.copy() session.init(self._status_reporter) - - def entrypoint(): - return self._trainable_func(config, self._status_reporter) - - # the runner thread is not started until the first call to _train - self._runner = _RunnerThread(entrypoint, self._error_queue) + self._runner = None + self._restore_tmpdir = None + self.default_checkpoint_dir = None def _trainable_func(self): """Subclasses can override this to set the trainable func.""" raise NotImplementedError + def _start(self): + def entrypoint(): + return self._trainable_func(self.config, self._status_reporter, + self._status_reporter.get_checkpoint()) + + # the runner thread is not started until the first call to _train + self._runner = _RunnerThread(entrypoint, self._error_queue) + # if not alive, try to start + self._status_reporter._start() + try: + self._runner.start() + except RuntimeError: + # If this is reached, it means the thread was started and is + # now done or has raised an exception. + pass + def _train(self): """Implements train() for a Function API. @@ -178,19 +220,12 @@ class FunctionRunner(Trainable): along with a result with "done=True". 
The TrialRunner will handle the result accordingly (see tune/trial_runner.py). """ - if self._runner.is_alive(): + if self._runner and self._runner.is_alive(): # if started and alive, inform the reporter to continue and # generate the next result self._continue_semaphore.release() else: - # if not alive, try to start - self._status_reporter._start() - try: - self._runner.start() - except RuntimeError: - # If this is reached, it means the thread was started and is - # now done or has raised an exception. - pass + self._start() result = None while result is None and self._runner.is_alive(): @@ -240,8 +275,61 @@ class FunctionRunner(Trainable): result = new_result self._last_result = result + if self._status_reporter.has_new_checkpoint(): + result[SHOULD_CHECKPOINT] = True return result + def create_default_checkpoint_dir(self): + self.default_checkpoint_dir = TrainableUtil.make_checkpoint_dir( + self.logdir, index="default") + return self.default_checkpoint_dir + + def save(self, checkpoint_path=None): + if checkpoint_path: + raise ValueError( + "Checkpoint path should not be used with function API.") + + checkpoint = self._status_reporter.get_checkpoint() + state = self.get_state() + + if not checkpoint: + state.update(iteration=0, timesteps_total=0, episodes_total=0) + parent_dir = self.create_default_checkpoint_dir() + elif isinstance(checkpoint, dict): + parent_dir = TrainableUtil.make_checkpoint_dir( + self.logdir, index=self.training_iteration) + else: + parent_dir = TrainableUtil.find_checkpoint_dir(checkpoint) + checkpoint_path = TrainableUtil.process_checkpoint( + checkpoint, parent_dir, state) + return checkpoint_path + + def save_to_object(self): + checkpoint_path = self.save() + data_dict = TrainableUtil.pickle_checkpoint(checkpoint_path) + out = io.BytesIO() + if len(data_dict) > 10e6: # getting pretty large + logger.info("Checkpoint size is {} bytes".format(len(data_dict))) + out.write(data_dict) + return out.getvalue() + + def _restore(self, 
checkpoint): + # This should be removed once Trainables are refactored. + if "tune_checkpoint_path" in checkpoint: + del checkpoint["tune_checkpoint_path"] + self._status_reporter.save_checkpoint(checkpoint) + + def restore_from_object(self, obj): + if self.default_checkpoint_dir is not None and os.path.exists( + self.default_checkpoint_dir): + shutil.rmtree(self.default_checkpoint_dir) + logger.debug("Clearing default checkpoint: %s", + self.default_checkpoint_dir) + + checkpoint_dir = self.create_default_checkpoint_dir() + checkpoint_path = TrainableUtil.create_from_pickle(obj, checkpoint_dir) + self.restore(checkpoint_path) + def _stop(self): # If everything stayed in synch properly, this should never happen. if not self._results_queue.empty(): @@ -251,7 +339,6 @@ class FunctionRunner(Trainable): # Check for any errors that might have been missed. self._report_thread_runner_error() - session.shutdown() def _report_thread_runner_error(self, block=False): @@ -264,13 +351,35 @@ class FunctionRunner(Trainable): pass +def detect_checkpoint_function(train_func): + func_args = inspect.getfullargspec(train_func).args + use_checkpoint = "checkpoint" in func_args + return use_checkpoint + + def wrap_function(train_func): class ImplicitFunc(FunctionRunner): - def _trainable_func(self, config, reporter): + def _trainable_func(self, config, reporter, checkpoint): func_args = inspect.getfullargspec(train_func).args - use_track = ("reporter" not in func_args and len(func_args) == 1) - if use_track: + if len(func_args) > 1: # more arguments than just the config + if "reporter" not in func_args and ( + "checkpoint" not in func_args): + raise ValueError( + "Unknown argument found in the Trainable function. " + "Arguments other than the 'config' arg must be one " + "of ['reporter', 'checkpoint']. 
Found: {}".format( + func_args)) + use_reporter = "reporter" in func_args + use_checkpoint = "checkpoint" in func_args + if not use_checkpoint and not use_reporter: + logger.warning( + "Function checkpointing is disabled. This may result in " + "unexpected behavior when using checkpointing features or " + "certain schedulers. To enable, set the train function " + "arguments to be `func(config, checkpoint)`.") output = train_func(config) + elif use_checkpoint: + output = train_func(config, checkpoint=checkpoint) else: output = train_func(config, reporter) diff --git a/python/ray/tune/ray_trial_executor.py b/python/ray/tune/ray_trial_executor.py index c6f8adc1b..b2019ff03 100644 --- a/python/ray/tune/ray_trial_executor.py +++ b/python/ray/tune/ray_trial_executor.py @@ -669,8 +669,7 @@ class RayTrialExecutor(TrialExecutor): elif trial.sync_on_checkpoint: # This provides FT backwards compatibility in the # case where a DurableTrainable is not provided. - logger.warning("Trial %s: Reading checkpoint into memory.", - trial) + logger.debug("Trial %s: Reading checkpoint into memory", trial) data_dict = TrainableUtil.pickle_checkpoint(value) with self._change_working_directory(trial): remote = trial.runner.restore_from_object.remote(data_dict) diff --git a/python/ray/tune/schedulers/hyperband.py b/python/ray/tune/schedulers/hyperband.py index 654cde12d..94897b71d 100644 --- a/python/ray/tune/schedulers/hyperband.py +++ b/python/ray/tune/schedulers/hyperband.py @@ -381,10 +381,14 @@ class Bracket: assert trial in self._live_trials assert self._get_result_time(result) >= 0 + observed_time = self._get_result_time(result) + last_observed = self._get_result_time(self._live_trials[trial]) - delta = self._get_result_time(result) - \ - self._get_result_time(self._live_trials[trial]) - assert delta >= 0, (result, self._live_trials[trial]) + delta = observed_time - last_observed + if delta <= 0: + logger.info("Restoring from a previous point in time. 
" + "Previous={}; Now={}".format(last_observed, + observed_time)) self._completed_progress += delta self._live_trials[trial] = result @@ -424,7 +428,7 @@ class Bracket: def _calculate_total_work(self, n, r, s): work = 0 cumulative_r = r - for i in range(s + 1): + for _ in range(s + 1): work += int(n) * int(r) n /= self._eta n = int(np.ceil(n)) diff --git a/python/ray/tune/session.py b/python/ray/tune/session.py index 9b228b68e..4155e9702 100644 --- a/python/ray/tune/session.py +++ b/python/ray/tune/session.py @@ -5,29 +5,6 @@ logger = logging.getLogger(__name__) _session = None -class _ReporterSession: - def __init__(self, tune_reporter): - self.tune_reporter = tune_reporter - - def report(self, **metrics): - return self.tune_reporter(**metrics) - - @property - def logdir(self): - """Trial logdir (subdir of given experiment directory)""" - return self.tune_reporter.logdir - - @property - def trial_name(self): - """Trial name for the corresponding trial of this Trainable""" - return self.tune_reporter.trial_name - - @property - def trial_id(self): - """Trial id for the corresponding trial of this Trainable""" - return self.tune_reporter.trial_id - - def get_session(): global _session if _session is None: @@ -56,7 +33,11 @@ def init(reporter, ignore_reinit_error=True): else: raise ValueError(reinit_msg) - _session = _ReporterSession(reporter) + if reporter is None: + logger.warning("You are using a Tune session outside of Tune. " + "Most session commands will have no effect.") + + _session = reporter def shutdown(): @@ -86,34 +67,109 @@ def report(**kwargs): metrics can be used for early stopping or optimization. """ _session = get_session() - return _session.report(**kwargs) + return _session(**kwargs) + + +def make_checkpoint_dir(step=None): + """Gets the next checkpoint dir. + + .. 
code-block:: python + + import json + import os + import time + from ray import tune + + def func(config, checkpoint=None): + start = 0 + if checkpoint: + with open(checkpoint) as f: + state = json.loads(f.read()) + start = state["step"] + 1 + + for iter in range(start, 100): + time.sleep(1) + + checkpoint_dir = tune.make_checkpoint_dir(step=iter) + path = os.path.join(checkpoint_dir, "checkpoint") + with open(path, "w") as f: + f.write(json.dumps({"step": iter})) + tune.save_checkpoint(path) + + tune.report(hello="world", ray="tune") + + Args: + step (int): Current training iteration - used for setting + an index to uniquely identify the checkpoint. + + .. versionadded:: 0.8.6 + + """ + _session = get_session() + return _session.make_checkpoint_dir(step=step) + + +def save_checkpoint(checkpoint): + """Register the given checkpoint. + + .. code-block:: python + + import os + import json + import time + from ray import tune + + def func(config, checkpoint=None): + start = 0 + if checkpoint: + with open(checkpoint) as f: + state = json.loads(f.read()) + start = state["step"] + 1 + + for iter in range(start, 10): + time.sleep(1) + + checkpoint_dir = tune.make_checkpoint_dir(step=iter) + path = os.path.join(checkpoint_dir, "checkpoint") + with open(path, "w") as f: + f.write(json.dumps({"step": iter})) + tune.save_checkpoint(path) + + tune.report(hello="world", ray="tune") + + analysis = tune.run(func) + + Args: + checkpoint: The checkpoint to register, such as the file path + written in the example above. + + .. versionadded:: 0.8.6 + """ + _session = get_session() + return _session.save_checkpoint(checkpoint) def get_trial_dir(): """Returns the directory where trial results are saved. - For function API use only. Do not call this method in the Class API. Use - `self.logdir` instead. + For function API use only. 
""" _session = get_session() return _session.logdir def get_trial_name(): - """Trial name for the corresponding trial of this Trainable. + """Trial name for the corresponding trial. - For function API use only. Do not call this method in the Class API. Use - `self.trial_name` instead. + For function API use only. """ _session = get_session() return _session.trial_name def get_trial_id(): - """Trial id for the corresponding trial of this Trainable. + """Trial id for the corresponding trial. - For function API use only. Do not call this method in the Class API. Use - `self.trial_id` instead. + For function API use only. """ _session = get_session() return _session.trial_id diff --git a/python/ray/tune/suggest/ax.py b/python/ray/tune/suggest/ax.py index 6fa26703a..9e58c4cb1 100644 --- a/python/ray/tune/suggest/ax.py +++ b/python/ray/tune/suggest/ax.py @@ -58,7 +58,7 @@ class AxSearch(Searcher): def easy_objective(config): for i in range(100): intermediate_result = config["x1"] + config["x2"] * i - tune.track.log(score=intermediate_result) + tune.report(score=intermediate_result) client = AxClient(enforce_sequential_optimization=False) client.create_experiment(parameters=parameters, objective_name="score") diff --git a/python/ray/tune/tests/test_cluster.py b/python/ray/tune/tests/test_cluster.py index 6c1e8fc56..47e4f821b 100644 --- a/python/ray/tune/tests/test_cluster.py +++ b/python/ray/tune/tests/test_cluster.py @@ -394,10 +394,28 @@ def test_migration_checkpoint_removal(start_connected_emptyhead_cluster, } # The following patches only affect __fake_remote. 
- find_checkpoint_dir = TrainableUtil.find_checkpoint_dir - with patch("ray.tune.logger.get_node_syncer") as mock_get_node_syncer: - trainable_util = "ray.tune.ray_trial_executor.TrainableUtil" - with patch(trainable_util + ".find_checkpoint_dir") as mock_find_dir: + def hide_remote_path(path_function): + def hidden_path_func(checkpoint_path): + """Converts back to local path first.""" + if MOCK_REMOTE_DIR in checkpoint_path: + checkpoint_path = checkpoint_path[len(MOCK_REMOTE_DIR):] + checkpoint_path = os.path.join("/", checkpoint_path) + return path_function(checkpoint_path) + + return hidden_path_func + + trainable_util = "ray.tune.ray_trial_executor.TrainableUtil" + _find_ckpt = trainable_util + ".find_checkpoint_dir" + find_func = TrainableUtil.find_checkpoint_dir + _pickle_ckpt = trainable_util + ".pickle_checkpoint" + pickle_func = TrainableUtil.pickle_checkpoint + + with patch(_find_ckpt) as mock_find, patch(_pickle_ckpt) as mock_pkl_ckpt: + # __fake_remote trainables save to a separate "remote" directory. + # TrainableUtil will not check this path unless we mock it. + mock_find.side_effect = hide_remote_path(find_func) + mock_pkl_ckpt.side_effect = hide_remote_path(pickle_func) + with patch("ray.tune.logger.get_node_syncer") as mock_get_node_syncer: def mock_get_syncer_fn(local_dir, remote_dir, sync_function): client = mock_storage_client() @@ -405,16 +423,6 @@ def test_migration_checkpoint_removal(start_connected_emptyhead_cluster, mock_get_node_syncer.side_effect = mock_get_syncer_fn - def mock_find_dir_fn(checkpoint_path): - """Converts back to local path first.""" - checkpoint_path = checkpoint_path[len(MOCK_REMOTE_DIR):] - checkpoint_path = os.path.join("/", checkpoint_path) - return find_checkpoint_dir(checkpoint_path) - - # __fake_remote trainables save to a separate "remote" directory. - # TrainableUtil will not check this path unless we mock it. 
- mock_find_dir.side_effect = mock_find_dir_fn - # Test recovery of trial that has been checkpointed t1 = Trial(trainable_id, **kwargs) runner.add_trial(t1) @@ -428,7 +436,6 @@ def test_migration_checkpoint_removal(start_connected_emptyhead_cluster, cluster.remove_node(node) cluster.wait_for_nodes() shutil.rmtree(os.path.dirname(t1.checkpoint.value)) - runner.step() # Collect result 3, kick off + fail result 4 runner.step() # Dispatch restore runner.step() # Process restore + step 4 diff --git a/python/ray/tune/tests/test_function_api.py b/python/ray/tune/tests/test_function_api.py new file mode 100644 index 000000000..8c390799f --- /dev/null +++ b/python/ray/tune/tests/test_function_api.py @@ -0,0 +1,169 @@ +import json +import os +import unittest + +import ray +from ray.rllib import _register_all + +from ray import tune +from ray.tune.function_runner import wrap_function +from ray.tune.result import TRAINING_ITERATION + + +class FunctionApiTest(unittest.TestCase): + def setUp(self): + ray.init(num_cpus=4, num_gpus=0, object_store_memory=150 * 1024 * 1024) + + def tearDown(self): + ray.shutdown() + _register_all() # re-register the evicted objects + + def testFunctionNoCheckpointing(self): + def train(config, checkpoint=None): + for i in range(10): + tune.report(test=i) + + wrapped = wrap_function(train) + + new_trainable = wrapped() + result = new_trainable.train() + checkpoint = new_trainable.save() + new_trainable.stop() + + new_trainable2 = wrapped() + new_trainable2.restore(checkpoint) + result = new_trainable2.train() + self.assertEquals(result[TRAINING_ITERATION], 1) + checkpoint = new_trainable2.save() + new_trainable2.stop() + + def testFunctionRecurringSave(self): + """This tests that save and restore are commutative.""" + + def train(config, checkpoint=None): + for step in range(10): + if step % 3 == 0: + checkpoint_dir = tune.make_checkpoint_dir(step=step) + path = os.path.join(checkpoint_dir, "checkpoint") + with open(path, "w") as f: + 
f.write(json.dumps({"step": step})) + tune.save_checkpoint(path) + tune.report(test=step) + + wrapped = wrap_function(train) + + new_trainable = wrapped() + new_trainable.train() + checkpoint_obj = new_trainable.save_to_object() + new_trainable.restore_from_object(checkpoint_obj) + checkpoint = new_trainable.save() + new_trainable.stop() + + new_trainable2 = wrapped() + new_trainable2.restore(checkpoint) + new_trainable2.train() + new_trainable2.stop() + + def testCheckpointFunctionAtEnd(self): + def train(config, checkpoint=False): + for i in range(10): + tune.report(test=i) + checkpoint_dir = tune.make_checkpoint_dir(step=10) + checkpoint_path = os.path.join(checkpoint_dir, "hello") + with open(checkpoint_path, "w") as f: + f.write("hello") + tune.save_checkpoint(checkpoint_path) + + [trial] = tune.run(train).trials + assert "hello" in trial.checkpoint.value + + def testVariousCheckpointFunctionAtEnd(self): + def train(config, checkpoint=False): + for i in range(10): + checkpoint_dir = tune.make_checkpoint_dir() + checkpoint_path = os.path.join(checkpoint_dir, "hello") + with open(checkpoint_path, "w") as f: + f.write("hello") + tune.save_checkpoint(checkpoint_path) + tune.report(test=i) + checkpoint_dir = tune.make_checkpoint_dir() + checkpoint_path = os.path.join(checkpoint_dir, "goodbye") + with open(checkpoint_path, "w") as f: + f.write("goodbye") + tune.save_checkpoint(checkpoint_path) + + [trial] = tune.run(train, keep_checkpoints_num=3).trials + assert "goodbye" in trial.checkpoint.value + + def testReuseCheckpoint(self): + def train(config, checkpoint=False): + itr = 0 + if checkpoint: + with open(checkpoint, "r") as f: + itr = int(f.read()) + 1 + + for i in range(itr, config["max_iter"]): + checkpoint_dir = tune.make_checkpoint_dir(step=i) + checkpoint_path = os.path.join(checkpoint_dir, "goodbye") + with open(checkpoint_path, "w") as f: + f.write(str(i)) + tune.save_checkpoint(checkpoint_path) + tune.report(test=i, training_iteration=i) + + [trial] = 
tune.run( + train, + config={ + "max_iter": 5 + }, + ).trials + last_ckpt = trial.checkpoint.value + assert "goodbye" in last_ckpt + analysis = tune.run(train, config={"max_iter": 10}, restore=last_ckpt) + trial_dfs = list(analysis.trial_dataframes.values()) + assert len(trial_dfs[0]["training_iteration"]) == 5 + + def testRetry(self): + def train(config, checkpoint=None): + restored = bool(checkpoint) + itr = 0 + if checkpoint: + with open(checkpoint, "r") as f: + itr = int(f.read()) + 1 + + for i in range(itr, 10): + if i == 5 and not restored: + raise Exception("try to fail me") + checkpoint_dir = tune.make_checkpoint_dir(step=i) + checkpoint_path = os.path.join(checkpoint_dir, "goodbye") + with open(checkpoint_path, "w") as f: + f.write(str(i)) + tune.save_checkpoint(checkpoint_path) + tune.report(test=i, training_iteration=i) + + analysis = tune.run(train, max_failures=3) + last_ckpt = analysis.trials[0].checkpoint.value + assert "goodbye" in last_ckpt + trial_dfs = list(analysis.trial_dataframes.values()) + assert len(trial_dfs[0]["training_iteration"]) == 10 + + def testBlankCheckpoint(self): + def train(config, checkpoint=None): + restored = bool(checkpoint) + itr = 0 + if checkpoint: + with open(checkpoint, "r") as f: + itr = int(f.read()) + 1 + + for i in range(itr, 10): + if i == 5 and not restored: + raise Exception("try to fail me") + checkpoint_dir = tune.make_checkpoint_dir() + checkpoint_path = os.path.join(checkpoint_dir, "goodbye") + with open(checkpoint_path, "w") as f: + f.write(str(i)) + tune.save_checkpoint(checkpoint_path) + tune.report(test=i, training_iteration=i) + + analysis = tune.run(train, max_failures=3) + trial_dfs = list(analysis.trial_dataframes.values()) + assert len(trial_dfs[0]["training_iteration"]) == 10 diff --git a/python/ray/tune/tests/test_track.py b/python/ray/tune/tests/test_track.py index 5ad39c67f..7abb8df4e 100644 --- a/python/ray/tune/tests/test_track.py +++ b/python/ray/tune/tests/test_track.py @@ -17,18 +17,6 @@ 
class TrackApiTest(unittest.TestCase): session.shutdown() ray.shutdown() - def testSessionInitShutdown(self): - self.assertTrue(session._session is None) - - # Checks that the singleton _session is created/destroyed - # by session.init() and session.shutdown() - for _ in range(2): - # do it twice to see that we can reopen the session - session.init(reporter=None) - self.assertTrue(session._session is not None) - session.shutdown() - self.assertTrue(session._session is None) - def testSoftDeprecation(self): """Checks that tune.track.log code does not break.""" from ray.tune import track diff --git a/python/ray/tune/trainable.py b/python/ray/tune/trainable.py index 330c5adae..3296c3115 100644 --- a/python/ray/tune/trainable.py +++ b/python/ray/tune/trainable.py @@ -28,6 +28,34 @@ SETUP_TIME_THRESHOLD = 10 class TrainableUtil: + @staticmethod + def process_checkpoint(checkpoint, parent_dir, trainable_state): + saved_as_dict = False + if isinstance(checkpoint, string_types): + if not checkpoint.startswith(parent_dir): + raise ValueError( + "The returned checkpoint path must be within the " + "given checkpoint dir {}: {}".format( + parent_dir, checkpoint)) + checkpoint_path = checkpoint + if os.path.isdir(checkpoint_path): + # Add trailing slash to prevent tune metadata from + # being written outside the directory. + checkpoint_path = os.path.join(checkpoint_path, "") + elif isinstance(checkpoint, dict): + saved_as_dict = True + checkpoint_path = os.path.join(parent_dir, "checkpoint") + with open(checkpoint_path, "wb") as f: + pickle.dump(checkpoint, f) + else: + raise ValueError("Returned unexpected type {}. 
" + "Expected str or dict.".format(type(checkpoint))) + + with open(checkpoint_path + ".tune_metadata", "wb") as f: + trainable_state["saved_as_dict"] = saved_as_dict + pickle.dump(trainable_state, f) + return checkpoint_path + @staticmethod def pickle_checkpoint(checkpoint_path): """Pickles checkpoint data.""" @@ -39,7 +67,8 @@ class TrainableUtil: with open(path, "rb") as f: data[os.path.relpath(path, checkpoint_dir)] = f.read() # Use normpath so that a directory path isn't mapped to empty string. - name = os.path.basename(os.path.normpath(checkpoint_path)) + name = os.path.relpath( + os.path.normpath(checkpoint_path), checkpoint_dir) name += os.path.sep if os.path.isdir(checkpoint_path) else "" data_dict = pickle.dumps({ "checkpoint_name": name, @@ -70,11 +99,38 @@ class TrainableUtil: return checkpoint_dir @staticmethod - def make_checkpoint_dir(checkpoint_dir): - """Creates a checkpoint directory at the provided path.""" + def make_checkpoint_dir(checkpoint_dir, index): + """Creates a checkpoint directory within the provided path. + + Args: + checkpoint_dir (str): Path to checkpoint directory. + index (str): A subdirectory will be created + at the checkpoint directory named 'checkpoint_{index}'. + """ + suffix = "checkpoint" + if index is not None: + suffix += "_{}".format(index) + checkpoint_dir = os.path.join(checkpoint_dir, suffix) + os.makedirs(checkpoint_dir, exist_ok=True) # Drop marker in directory to identify it as a checkpoint dir. 
open(os.path.join(checkpoint_dir, ".is_checkpoint"), "a").close() + return checkpoint_dir + + @staticmethod + def create_from_pickle(obj, tmpdir): + info = pickle.loads(obj) + data = info["data"] + checkpoint_path = os.path.join(tmpdir, info["checkpoint_name"]) + + for relpath_name, file_contents in data.items(): + path = os.path.join(tmpdir, relpath_name) + + # This may be a subdirectory, hence not just using tmpdir + os.makedirs(os.path.dirname(path), exist_ok=True) + with open(path, "wb") as f: + f.write(file_contents) + return checkpoint_path @staticmethod def get_checkpoints_paths(logdir): @@ -324,6 +380,16 @@ class Trainable: return result + def get_state(self): + return { + "experiment_id": self._experiment_id, + "iteration": self._iteration, + "timesteps_total": self._timesteps_total, + "time_total": self._time_total, + "episodes_total": self._episodes_total, + "ray_version": ray.__version__, + } + def save(self, checkpoint_dir=None): """Saves the current model state to a checkpoint. @@ -336,41 +402,14 @@ class Trainable: Returns: str: Checkpoint path or prefix that may be passed to restore(). """ - checkpoint_dir = os.path.join(checkpoint_dir or self.logdir, - "checkpoint_{}".format(self._iteration)) - TrainableUtil.make_checkpoint_dir(checkpoint_dir) + checkpoint_dir = TrainableUtil.make_checkpoint_dir( + checkpoint_dir or self.logdir, index=self.iteration) checkpoint = self._save(checkpoint_dir) - saved_as_dict = False - if isinstance(checkpoint, string_types): - if not checkpoint.startswith(checkpoint_dir): - raise ValueError( - "The returned checkpoint path must be within the " - "given checkpoint dir {}: {}".format( - checkpoint_dir, checkpoint)) - checkpoint_path = checkpoint - if os.path.isdir(checkpoint_path): - # Add trailing slash to prevent tune metadata from - # being written outside the directory. 
- checkpoint_path = os.path.join(checkpoint_path, "") - elif isinstance(checkpoint, dict): - saved_as_dict = True - checkpoint_path = os.path.join(checkpoint_dir, "checkpoint") - with open(checkpoint_path, "wb") as f: - pickle.dump(checkpoint, f) - else: - raise ValueError("Returned unexpected type {}. " - "Expected str or dict.".format(type(checkpoint))) - - with open(checkpoint_path + ".tune_metadata", "wb") as f: - pickle.dump({ - "experiment_id": self._experiment_id, - "iteration": self._iteration, - "timesteps_total": self._timesteps_total, - "time_total": self._time_total, - "episodes_total": self._episodes_total, - "saved_as_dict": saved_as_dict, - "ray_version": ray.__version__, - }, f) + trainable_state = self.get_state() + checkpoint_path = TrainableUtil.process_checkpoint( + checkpoint, + parent_dir=checkpoint_dir, + trainable_state=trainable_state) return checkpoint_path def save_to_object(self): @@ -434,19 +473,8 @@ class Trainable: These checkpoints are returned from calls to save_to_object(). 
""" - info = pickle.loads(obj) - data = info["data"] tmpdir = tempfile.mkdtemp("restore_from_object", dir=self.logdir) - checkpoint_path = os.path.join(tmpdir, info["checkpoint_name"]) - - for relpath_name, file_contents in data.items(): - path = os.path.join(tmpdir, relpath_name) - - # This may be a subdirectory, hence not just using tmpdir - os.makedirs(os.path.dirname(path), exist_ok=True) - with open(path, "wb") as f: - f.write(file_contents) - + checkpoint_path = TrainableUtil.create_from_pickle(obj, tmpdir) self.restore(checkpoint_path) shutil.rmtree(tmpdir) @@ -531,7 +559,10 @@ class Trainable: name = self.trial_name """ - return self._trial_info.trial_name + if self._trial_info: + return self._trial_info.trial_name + else: + return "default" @property def trial_id(self): @@ -543,7 +574,10 @@ class Trainable: trial_id = self.trial_id """ - return self._trial_info.trial_id + if self._trial_info: + return self._trial_info.trial_id + else: + return "default" @property def iteration(self): diff --git a/python/ray/tune/trial_runner.py b/python/ray/tune/trial_runner.py index a85476afc..e27678d32 100644 --- a/python/ray/tune/trial_runner.py +++ b/python/ray/tune/trial_runner.py @@ -385,10 +385,14 @@ class TrialRunner: self.trial_executor.try_checkpoint_metadata(trial) def debug_string(self, delim="\n"): + result_keys = [ + list(t.last_result) for t in self.get_trials() if t.last_result + ] + metrics = set().union(*result_keys) messages = [ self._scheduler_alg.debug_string(), self.trial_executor.debug_string(), - trial_progress_str(self.get_trials()), + trial_progress_str(self.get_trials(), metrics), ] return delim.join(messages) @@ -468,6 +472,7 @@ class TrialRunner: result = self.trial_executor.fetch_result(trial) is_duplicate = RESULT_DUPLICATE in result + force_checkpoint = result.get(SHOULD_CHECKPOINT, False) # TrialScheduler and SearchAlgorithm still receive a # notification because there may be special handling for # the `on_trial_complete` hook. 
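Reviewer note: the `restore_from_object` change in the trainable.py hunks above delegates file reconstruction to `TrainableUtil.create_from_pickle`. As a rough illustration of that round-trip outside Ray, here is a minimal standard-library sketch. The names `pack_dir` and `unpack_dir` are hypothetical stand-ins rather than Ray functions, and the fixed `"checkpoint"` file name is an assumption made for the example.

```python
import os
import pickle
import tempfile


def pack_dir(checkpoint_dir):
    """Read every file under checkpoint_dir into a single pickled blob.

    Hypothetical helper: it mirrors the idea of pickling a checkpoint
    directory, not Ray's actual implementation.
    """
    data = {}
    for root, _, files in os.walk(checkpoint_dir):
        for name in files:
            path = os.path.join(root, name)
            with open(path, "rb") as f:
                # Keys are paths relative to the checkpoint directory.
                data[os.path.relpath(path, checkpoint_dir)] = f.read()
    # Assumption: the main checkpoint file is always named "checkpoint".
    return pickle.dumps({"checkpoint_name": "checkpoint", "data": data})


def unpack_dir(obj, target_dir):
    """Recreate the files from a blob produced by pack_dir."""
    info = pickle.loads(obj)
    for relpath, contents in info["data"].items():
        path = os.path.join(target_dir, relpath)
        # Files may live in subdirectories, so create parents first.
        os.makedirs(os.path.dirname(path), exist_ok=True)
        with open(path, "wb") as f:
            f.write(contents)
    return os.path.join(target_dir, info["checkpoint_name"])
```

Calling `unpack_dir(pack_dir(src), dst)` should reproduce the files of `src` under `dst` and return the path of the restored `checkpoint` file, which is the value a caller would then hand to `restore()`.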
@@ -506,8 +511,7 @@ class TrialRunner: # the scheduler decision is STOP or PAUSE. Note that # PAUSE only checkpoints to memory and does not update # the global checkpoint state. - self._checkpoint_trial_if_needed( - trial, force=result.get(SHOULD_CHECKPOINT, False)) + self._checkpoint_trial_if_needed(trial, force=force_checkpoint) if trial.is_saving: # Cache decision to execute on after the save is processed. diff --git a/python/ray/tune/tune.py b/python/ray/tune/tune.py index bf5716d91..4c2229d95 100644 --- a/python/ray/tune/tune.py +++ b/python/ray/tune/tune.py @@ -110,7 +110,10 @@ def run(run_or_experiment, function or class, or the string identifier of a trainable function or class registered in the tune registry. If Experiment, then Tune will execute training based on - Experiment.spec. + Experiment.spec. If you want to pass in a Python lambda, you + will need to first register the function: + ``tune.register_trainable("lambda_id", lambda x: ...)``. You can + then use ``tune.run("lambda_id")``. name (str): Name of experiment. stop (dict | callable | :class:`Stopper`): Stopping criteria. If dict, the keys may be any field in the return result of 'train()', @@ -154,8 +157,10 @@ def run(run_or_experiment, syncing to driver is disabled. checkpoint_freq (int): How many training iterations between checkpoints. A value of 0 (default) disables checkpointing. + This has no effect when using the Functional Training API. checkpoint_at_end (bool): Whether to checkpoint at the end of the experiment regardless of the checkpoint_freq. Default is False. + This has no effect when using the Functional Training API. sync_on_checkpoint (bool): Force sync-down of trial checkpoint to driver. If set to False, checkpoint syncing from worker to driver is asynchronous and best-effort. This does not affect persistent @@ -214,6 +219,8 @@ def run(run_or_experiment, if using a RayTrialExecutor (which is the default) and if Ray is not initialized. Defaults to True. 
+ + Returns: ExperimentAnalysis: Object for experiment analysis.