From d35f0e40d07bab06b41ce5493c2f50b6725a1857 Mon Sep 17 00:00:00 2001 From: Richard Liaw Date: Wed, 1 Jul 2020 11:00:00 -0700 Subject: [PATCH] [tune] Use public methods for trainable (#9184) --- doc/source/rllib-dev.rst | 4 +- .../tune/_tutorials/tune-60-seconds.rst | 4 +- .../tune/_tutorials/tune-distributed.rst | 2 +- doc/source/tune/_tutorials/tune-usage.rst | 4 +- doc/source/tune/api_docs/logging.rst | 6 +- doc/source/tune/api_docs/trainable.rst | 26 +-- python/ray/tune/durable_trainable.py | 3 +- .../tune/examples/async_hyperband_example.py | 8 +- python/ray/tune/examples/bohb_example.py | 8 +- .../examples/durable_trainable_example.py | 8 +- python/ray/tune/examples/hyperband_example.py | 8 +- python/ray/tune/examples/logging_example.py | 8 +- .../tune/examples/mnist_pytorch_trainable.py | 8 +- .../ray/tune/examples/pbt_convnet_example.py | 8 +- .../pbt_dcgan_mnist/pbt_dcgan_mnist.py | 8 +- python/ray/tune/examples/pbt_example.py | 8 +- python/ray/tune/examples/pbt_memnn_example.py | 8 +- .../examples/pbt_tune_cifar10_with_keras.py | 10 +- python/ray/tune/examples/tf_mnist_example.py | 4 +- python/ray/tune/function_runner.py | 8 +- python/ray/tune/session.py | 4 + python/ray/tune/tests/test_actor_reuse.py | 8 +- python/ray/tune/tests/test_api.py | 71 +++++-- python/ray/tune/tests/test_cluster.py | 8 +- .../tests/test_experiment_analysis_mem.py | 8 +- .../ray/tune/tests/test_ray_trial_executor.py | 2 +- python/ray/tune/tests/test_run_experiment.py | 12 +- python/ray/tune/tests/test_trial_scheduler.py | 12 +- .../ray/tune/tests/test_tune_save_restore.py | 16 +- python/ray/tune/trainable.py | 185 +++++++++++++----- python/ray/tune/trial_runner.py | 1 - python/ray/util/sgd/tf/tf_trainer.py | 10 +- python/ray/util/sgd/torch/torch_trainer.py | 12 +- rllib/__init__.py | 2 +- rllib/agents/ars/ars.py | 4 +- rllib/agents/es/es.py | 4 +- rllib/agents/mock.py | 12 +- rllib/agents/trainer.py | 14 +- rllib/agents/trainer_template.py | 32 +-- rllib/contrib/random_agent/random_agent.py | 2 +- 40 files changed, 350 insertions(+), 220 deletions(-) diff --git a/doc/source/rllib-dev.rst b/doc/source/rllib-dev.rst index 03e88f353..2c917f9e8 100644 --- a/doc/source/rllib-dev.rst +++ b/doc/source/rllib-dev.rst @@ -31,7 +31,7 @@ Contributing Algorithms These are the guidelines for merging new algorithms into RLlib: * Contributed algorithms (`rllib/contrib `__): - - must subclass Trainer and implement the ``_train()`` method + - must subclass Trainer and implement the ``step()`` method - must include a lightweight test (`example `__) to ensure the algorithm runs - should include tuned hyperparameter examples and documentation - should offer functionality not present in existing algorithms @@ -46,7 +46,7 @@ Both integrated and contributed algorithms ship with the ``ray`` PyPI package, a How to add an algorithm to ``contrib`` ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -It takes just two changes to add an algorithm to `contrib `__. A minimal example can be found `here `__. First, subclass `Trainer `__ and implement the ``_init`` and ``_train`` methods: +It takes just two changes to add an algorithm to `contrib `__. A minimal example can be found `here `__. First, subclass `Trainer `__ and implement the ``_init`` and ``step`` methods: .. literalinclude:: ../../rllib/contrib/random_agent/random_agent.py :language: python diff --git a/doc/source/tune/_tutorials/tune-60-seconds.rst b/doc/source/tune/_tutorials/tune-60-seconds.rst index 1ef6af17c..fe2c46dea 100644 --- a/doc/source/tune/_tutorials/tune-60-seconds.rst +++ b/doc/source/tune/_tutorials/tune-60-seconds.rst @@ -42,13 +42,13 @@ The other is a :ref:`class-based API `. Here's an example of spe from ray import tune class Trainable(tune.Trainable): - def _setup(self, config): + def setup(self, config): # config (dict): A dict of hyperparameters self.x = 0 self.a = config["a"] self.b = config["b"] - def _train(self): # This is called iteratively. + def step(self): # This is called iteratively. score = objective(self.x, self.a, self.b) self.x += 1 return {"score": score} diff --git a/doc/source/tune/_tutorials/tune-distributed.rst b/doc/source/tune/_tutorials/tune-distributed.rst index d2986d8e0..e69490116 100644 --- a/doc/source/tune/_tutorials/tune-distributed.rst +++ b/doc/source/tune/_tutorials/tune-distributed.rst @@ -215,7 +215,7 @@ In GCP, you can use the following configuration modification: Spot instances may be removed suddenly while trials are still running. Often times this may be difficult to deal with when using other distributed hyperparameter optimization frameworks. Tune allows users to mitigate the effects of this by preserving the progress of your model training through checkpointing. -The easiest way to do this is to subclass the pre-defined ``Trainable`` class and implement ``_save``, and ``_restore`` abstract methods, as seen in the example below: +The easiest way to do this is to subclass the pre-defined ``Trainable`` class and implement ``save_checkpoint``, and ``load_checkpoint`` abstract methods, as seen in the example below: .. literalinclude:: /../../python/ray/tune/examples/mnist_pytorch_trainable.py :language: python diff --git a/doc/source/tune/_tutorials/tune-usage.rst b/doc/source/tune/_tutorials/tune-usage.rst index d26c32ad3..cc9feb3d4 100644 --- a/doc/source/tune/_tutorials/tune-usage.rst +++ b/doc/source/tune/_tutorials/tune-usage.rst @@ -122,7 +122,7 @@ You can log arbitrary values and metrics in both training APIs: class Trainable(tune.Trainable): ... - def _train(self): # this is called iteratively + def step(self): # this is called iteratively accuracy = self.model.train() metric_1 = f(self.model) metric_2 = self.model.get_loss() @@ -223,7 +223,7 @@ Stopping Trials You can control when trials are stopped early by passing the ``stop`` argument to ``tune.run``. This argument takes either a dictionary or a function. -If a dictionary is passed in, the keys may be any field in the return result of ``tune.report`` in the Function API or ``_train()`` (including the results from ``_train`` and auto-filled metrics). +If a dictionary is passed in, the keys may be any field in the return result of ``tune.report`` in the Function API or ``step()`` (including the results from ``step`` and auto-filled metrics). In the example below, each trial will be stopped either when it completes 10 iterations OR when it reaches a mean accuracy of 0.98. These metrics are assumed to be **increasing**. diff --git a/doc/source/tune/api_docs/logging.rst b/doc/source/tune/api_docs/logging.rst index 4e985c1ce..fa9341b3f 100644 --- a/doc/source/tune/api_docs/logging.rst +++ b/doc/source/tune/api_docs/logging.rst @@ -99,7 +99,7 @@ You can do this in the trainable, as shown below: .. code-block:: python class CustomLogging(tune.Trainable) - def _setup(self, config): + def setup(self, config): trial_id = self.trial_id library.init( name=trial_id, @@ -109,10 +109,10 @@ You can do this in the trainable, as shown below: allow_val_change=True) library.set_log_path(self.logdir) - def _train(self): + def step(self): library.log_model(...) - def _log_result(self, result): + def log_result(self, result): res_dict = { str(k): v for k, v in result.items() diff --git a/doc/source/tune/api_docs/trainable.rst b/doc/source/tune/api_docs/trainable.rst index 89a8e400a..ecd11ce20 100644 --- a/doc/source/tune/api_docs/trainable.rst +++ b/doc/source/tune/api_docs/trainable.rst @@ -61,7 +61,7 @@ Many Tune features rely on checkpointing, including the usage of certain Trial S for iter in range(start, 100): time.sleep(1) - # + # checkpoint_dir = tune.make_checkpoint_dir(step=step) path = os.path.join(checkpoint_dir, "checkpoint") with open(path, "w") as f: @@ -101,13 +101,13 @@ The Trainable **class API** will require users to subclass ``ray.tune.Trainable` from ray import tune class Trainable(tune.Trainable): - def _setup(self, config): + def setup(self, config): # config (dict): A dict of hyperparameters self.x = 0 self.a = config["a"] self.b = config["b"] - def _train(self): # This is called iteratively. + def step(self): # This is called iteratively. score = objective(self.x, self.a, self.b) self.x += 1 return {"score": score} @@ -124,11 +124,11 @@ The Trainable **class API** will require users to subclass ``ray.tune.Trainable` As a subclass of ``tune.Trainable``, Tune will create a ``Trainable`` object on a separate process (using the :ref:`Ray Actor API `). - 1. ``_setup`` function is invoked once training starts. - 2. ``_train`` is invoked **multiple times**. Each time, the Trainable object executes one logical iteration of training in the tuning process, which may include one or more iterations of actual training. - 3. ``_stop`` is invoked when training is finished. + 1. ``setup`` function is invoked once training starts. + 2. ``step`` is invoked **multiple times**. Each time, the Trainable object executes one logical iteration of training in the tuning process, which may include one or more iterations of actual training. + 3. ``cleanup`` is invoked when training is finished. -.. tip:: As a rule of thumb, the execution time of ``_train`` should be large enough to avoid overheads (i.e. more than a few seconds), but short enough to report progress periodically (i.e. at most a few minutes). +.. tip:: As a rule of thumb, the execution time of ``step`` should be large enough to avoid overheads (i.e. more than a few seconds), but short enough to report progress periodically (i.e. at most a few minutes). .. _tune-trainable-save-restore: @@ -141,12 +141,12 @@ You can also implement checkpoint/restore using the Trainable Class API: .. code-block:: python class MyTrainableClass(Trainable): - def _save(self, tmp_checkpoint_dir): + def save_checkpoint(self, tmp_checkpoint_dir): checkpoint_path = os.path.join(tmp_checkpoint_dir, "model.pth") torch.save(self.model.state_dict(), checkpoint_path) return tmp_checkpoint_dir - def _restore(self, tmp_checkpoint_dir): + def load_checkpoint(self, tmp_checkpoint_dir): checkpoint_path = os.path.join(tmp_checkpoint_dir, "model.pth") self.model.load_state_dict(torch.load(checkpoint_path)) @@ -154,11 +154,11 @@ You can also implement checkpoint/restore using the Trainable Class API: You can checkpoint with three different mechanisms: manually, periodically, and at termination. -**Manual Checkpointing**: A custom Trainable can manually trigger checkpointing by returning ``should_checkpoint: True`` (or ``tune.result.SHOULD_CHECKPOINT: True``) in the result dictionary of `_train`. This can be especially helpful in spot instances: +**Manual Checkpointing**: A custom Trainable can manually trigger checkpointing by returning ``should_checkpoint: True`` (or ``tune.result.SHOULD_CHECKPOINT: True``) in the result dictionary of `step`. This can be especially helpful in spot instances: .. code-block:: python - def _train(self): + def step(self): # training code result = {"mean_accuracy": accuracy} if detect_instance_preemption(): @@ -190,7 +190,7 @@ of a trial, you can additionally set the ``checkpoint_at_end=True``: ) -Use ``validate_save_restore`` to catch ``_save``/``_restore`` errors before execution. +Use ``validate_save_restore`` to catch ``save_checkpoint``/``load_checkpoint`` errors before execution. .. code-block:: python @@ -214,7 +214,7 @@ This requires you to implement ``Trainable.reset_config``, which provides a new class PytorchTrainble(tune.Trainable): """Train a Pytorch ConvNet.""" - def _setup(self, config): + def setup(self, config): self.train_loader, self.test_loader = get_data_loaders() self.model = ConvNet() self.optimizer = optim.SGD( diff --git a/python/ray/tune/durable_trainable.py b/python/ray/tune/durable_trainable.py index fe2e7a5d7..d6a12839c 100644 --- a/python/ray/tune/durable_trainable.py +++ b/python/ray/tune/durable_trainable.py @@ -11,8 +11,7 @@ class DurableTrainable(Trainable): """Abstract class for a remote-storage backed fault-tolerant Trainable. Supports checkpointing to and restoring from remote storage. To use this - class, implement the same private methods as ray.tune.Trainable (`_save`, - `_train`, `_restore`, `reset_config`, `_setup`, `_stop`). + class, implement the same private methods as ray.tune.Trainable. .. warning:: This class is currently **experimental** and may be subject to change. diff --git a/python/ray/tune/examples/async_hyperband_example.py b/python/ray/tune/examples/async_hyperband_example.py index ebf30f66b..bd4dc4c59 100644 --- a/python/ray/tune/examples/async_hyperband_example.py +++ b/python/ray/tune/examples/async_hyperband_example.py @@ -19,10 +19,10 @@ class MyTrainableClass(Trainable): maximum reward value reached. """ - def _setup(self, config): + def setup(self, config): self.timestep = 0 - def _train(self): + def step(self): self.timestep += 1 v = np.tanh(float(self.timestep) / self.config.get("width", 1)) v *= self.config.get("height", 1) @@ -31,13 +31,13 @@ class MyTrainableClass(Trainable): # objectives such as loss or accuracy. return {"episode_reward_mean": v} - def _save(self, checkpoint_dir): + def save_checkpoint(self, checkpoint_dir): path = os.path.join(checkpoint_dir, "checkpoint") with open(path, "w") as f: f.write(json.dumps({"timestep": self.timestep})) return path - def _restore(self, checkpoint_path): + def load_checkpoint(self, checkpoint_path): with open(checkpoint_path) as f: self.timestep = json.loads(f.read())["timestep"] diff --git a/python/ray/tune/examples/bohb_example.py b/python/ray/tune/examples/bohb_example.py index a7478e5f1..908fef3c3 100644 --- a/python/ray/tune/examples/bohb_example.py +++ b/python/ray/tune/examples/bohb_example.py @@ -18,10 +18,10 @@ class MyTrainableClass(Trainable): maximum reward value reached. """ - def _setup(self, config): + def setup(self, config): self.timestep = 0 - def _train(self): + def step(self): self.timestep += 1 v = np.tanh(float(self.timestep) / self.config.get("width", 1)) v *= self.config.get("height", 1) @@ -30,13 +30,13 @@ class MyTrainableClass(Trainable): # objectives such as loss or accuracy. return {"episode_reward_mean": v} - def _save(self, checkpoint_dir): + def save_checkpoint(self, checkpoint_dir): path = os.path.join(checkpoint_dir, "checkpoint") with open(path, "w") as f: f.write(json.dumps({"timestep": self.timestep})) return path - def _restore(self, checkpoint_path): + def load_checkpoint(self, checkpoint_path): with open(checkpoint_path) as f: self.timestep = json.loads(f.read())["timestep"] diff --git a/python/ray/tune/examples/durable_trainable_example.py b/python/ray/tune/examples/durable_trainable_example.py index ab7bb13d8..62ddc0d1a 100644 --- a/python/ray/tune/examples/durable_trainable_example.py +++ b/python/ray/tune/examples/durable_trainable_example.py @@ -51,7 +51,7 @@ class OptimusFn(object): def get_optimus_trainable(parent_cls): class OptimusTrainable(parent_cls): - def _setup(self, config): + def setup(self, config): self.iter = 0 if config.get("seed"): np.random.seed(config["seed"]) @@ -61,7 +61,7 @@ def get_optimus_trainable(parent_cls): self.initial_samples_per_step = 500 self.mock_data = open("/dev/urandom", "rb").read(1024) - def _train(self): + def step(self): self.iter += 1 new_loss = self.func.eval(self.iter) time.sleep(0.5) @@ -71,7 +71,7 @@ def get_optimus_trainable(parent_cls): "samples": self.initial_samples_per_step } - def _save(self, checkpoint_dir): + def save_checkpoint(self, checkpoint_dir): time.sleep(0.5) return { "func": cloudpickle.dumps(self.func), @@ -80,7 +80,7 @@ def get_optimus_trainable(parent_cls): "iter": self.iter } - def _restore(self, checkpoint): + def load_checkpoint(self, checkpoint): self.func = cloudpickle.loads(checkpoint["func"]) self.data = checkpoint["data"] self.iter = checkpoint["iter"] diff --git a/python/ray/tune/examples/hyperband_example.py b/python/ray/tune/examples/hyperband_example.py index f8ee12162..77ec56040 100755 --- a/python/ray/tune/examples/hyperband_example.py +++ b/python/ray/tune/examples/hyperband_example.py @@ -19,10 +19,10 @@ class MyTrainableClass(Trainable): maximum reward value reached. """ - def _setup(self, config): + def setup(self, config): self.timestep = 0 - def _train(self): + def step(self): self.timestep += 1 v = np.tanh(float(self.timestep) / self.config.get("width", 1)) v *= self.config.get("height", 1) @@ -31,13 +31,13 @@ class MyTrainableClass(Trainable): # objectives such as loss or accuracy. return {"episode_reward_mean": v} - def _save(self, checkpoint_dir): + def save_checkpoint(self, checkpoint_dir): path = os.path.join(checkpoint_dir, "checkpoint") with open(path, "w") as f: f.write(json.dumps({"timestep": self.timestep})) return path - def _restore(self, checkpoint_path): + def load_checkpoint(self, checkpoint_path): with open(checkpoint_path) as f: self.timestep = json.loads(f.read())["timestep"] diff --git a/python/ray/tune/examples/logging_example.py b/python/ray/tune/examples/logging_example.py index 9643034e7..d74497113 100755 --- a/python/ray/tune/examples/logging_example.py +++ b/python/ray/tune/examples/logging_example.py @@ -27,10 +27,10 @@ class MyTrainableClass(Trainable): maximum reward value reached. """ - def _setup(self, config): + def setup(self, config): self.timestep = 0 - def _train(self): + def step(self): self.timestep += 1 v = np.tanh(float(self.timestep) / self.config.get("width", 1)) v *= self.config.get("height", 1) @@ -39,13 +39,13 @@ class MyTrainableClass(Trainable): # objectives such as loss or accuracy. return {"episode_reward_mean": v} - def _save(self, checkpoint_dir): + def save_checkpoint(self, checkpoint_dir): path = os.path.join(checkpoint_dir, "checkpoint") with open(path, "w") as f: f.write(json.dumps({"timestep": self.timestep})) return path - def _restore(self, checkpoint_path): + def load_checkpoint(self, checkpoint_path): with open(checkpoint_path) as f: self.timestep = json.loads(f.read())["timestep"] diff --git a/python/ray/tune/examples/mnist_pytorch_trainable.py b/python/ray/tune/examples/mnist_pytorch_trainable.py index 43f51cc7b..956a9c024 100644 --- a/python/ray/tune/examples/mnist_pytorch_trainable.py +++ b/python/ray/tune/examples/mnist_pytorch_trainable.py @@ -34,7 +34,7 @@ parser.add_argument( # yapf: disable # __trainable_example_begin__ class TrainMNIST(tune.Trainable): - def _setup(self, config): + def setup(self, config): use_cuda = config.get("use_gpu") and torch.cuda.is_available() self.device = torch.device("cuda" if use_cuda else "cpu") self.train_loader, self.test_loader = get_data_loaders() @@ -44,18 +44,18 @@ class TrainMNIST(tune.Trainable): lr=config.get("lr", 0.01), momentum=config.get("momentum", 0.9)) - def _train(self): + def step(self): train( self.model, self.optimizer, self.train_loader, device=self.device) acc = test(self.model, self.test_loader, self.device) return {"mean_accuracy": acc} - def _save(self, checkpoint_dir): + def save_checkpoint(self, checkpoint_dir): checkpoint_path = os.path.join(checkpoint_dir, "model.pth") torch.save(self.model.state_dict(), checkpoint_path) return checkpoint_path - def _restore(self, checkpoint_path): + def load_checkpoint(self, checkpoint_path): self.model.load_state_dict(torch.load(checkpoint_path)) diff --git a/python/ray/tune/examples/pbt_convnet_example.py b/python/ray/tune/examples/pbt_convnet_example.py index 387fcb4ac..8aa6b1d9d 100644 --- a/python/ray/tune/examples/pbt_convnet_example.py +++ b/python/ray/tune/examples/pbt_convnet_example.py @@ -27,7 +27,7 @@ class PytorchTrainble(tune.Trainable): changing the original training code. """ - def _setup(self, config): + def setup(self, config): self.train_loader, self.test_loader = get_data_loaders() self.model = ConvNet() self.optimizer = optim.SGD( @@ -35,17 +35,17 @@ class PytorchTrainble(tune.Trainable): lr=config.get("lr", 0.01), momentum=config.get("momentum", 0.9)) - def _train(self): + def step(self): train(self.model, self.optimizer, self.train_loader) acc = test(self.model, self.test_loader) return {"mean_accuracy": acc} - def _save(self, checkpoint_dir): + def save_checkpoint(self, checkpoint_dir): checkpoint_path = os.path.join(checkpoint_dir, "model.pth") torch.save(self.model.state_dict(), checkpoint_path) return checkpoint_path - def _restore(self, checkpoint_path): + def load_checkpoint(self, checkpoint_path): self.model.load_state_dict(torch.load(checkpoint_path)) def _export_model(self, export_formats, export_dir): diff --git a/python/ray/tune/examples/pbt_dcgan_mnist/pbt_dcgan_mnist.py b/python/ray/tune/examples/pbt_dcgan_mnist/pbt_dcgan_mnist.py index acbef0276..263d27dc4 100644 --- a/python/ray/tune/examples/pbt_dcgan_mnist/pbt_dcgan_mnist.py +++ b/python/ray/tune/examples/pbt_dcgan_mnist/pbt_dcgan_mnist.py @@ -231,7 +231,7 @@ def train(netD, netG, optimG, optimD, criterion, dataloader, iteration, # __Trainable_begin__ class PytorchTrainable(tune.Trainable): - def _setup(self, config): + def setup(self, config): use_cuda = config.get("use_gpu") and torch.cuda.is_available() self.device = torch.device("cuda" if use_cuda else "cpu") self.netD = Discriminator().to(self.device) @@ -250,13 +250,13 @@ class PytorchTrainable(tune.Trainable): with FileLock(os.path.expanduser("~/.data.lock")): self.dataloader = get_data_loader() - def _train(self): + def step(self): lossG, lossD, is_score = train( self.netD, self.netG, self.optimizerG, self.optimizerD, self.criterion, self.dataloader, self._iteration, self.device) return {"lossg": lossG, "lossd": lossD, "is_score": is_score} - def _save(self, checkpoint_dir): + def save_checkpoint(self, checkpoint_dir): path = os.path.join(checkpoint_dir, "checkpoint") torch.save({ "netDmodel": self.netD.state_dict(), @@ -267,7 +267,7 @@ class PytorchTrainable(tune.Trainable): return checkpoint_dir - def _restore(self, checkpoint_dir): + def load_checkpoint(self, checkpoint_dir): path = os.path.join(checkpoint_dir, "checkpoint") checkpoint = torch.load(path) self.netD.load_state_dict(checkpoint["netDmodel"]) diff --git a/python/ray/tune/examples/pbt_example.py b/python/ray/tune/examples/pbt_example.py index e0b9900a1..90901e74a 100755 --- a/python/ray/tune/examples/pbt_example.py +++ b/python/ray/tune/examples/pbt_example.py @@ -31,11 +31,11 @@ class PBTBenchmarkExample(Trainable): faster convergence. Training will not converge without PBT. """ - def _setup(self, config): + def setup(self, config): self.lr = config["lr"] self.accuracy = 0.0 # end = 1000 - def _train(self): + def step(self): midpoint = 100 # lr starts decreasing after acc > midpoint q_tolerance = 3 # penalize exceeding lr by more than this multiple noise_level = 2 # add gaussian noise to the acc increase @@ -66,13 +66,13 @@ class PBTBenchmarkExample(Trainable): "done": self.accuracy > midpoint * 2, } - def _save(self, checkpoint_dir): + def save_checkpoint(self, checkpoint_dir): return { "accuracy": self.accuracy, "lr": self.lr, } - def _restore(self, checkpoint): + def load_checkpoint(self, checkpoint): self.accuracy = checkpoint["accuracy"] def reset_config(self, new_config): diff --git a/python/ray/tune/examples/pbt_memnn_example.py b/python/ray/tune/examples/pbt_memnn_example.py index 0570ab675..bd2e246e2 100644 --- a/python/ray/tune/examples/pbt_memnn_example.py +++ b/python/ray/tune/examples/pbt_memnn_example.py @@ -214,7 +214,7 @@ class MemNNModel(Trainable): model = Model([input_sequence, question], answer) return model - def _setup(self, config): + def setup(self, config): with FileLock(os.path.expanduser("~/.tune.lock")): self.train_stories, self.test_stories = read_data() model = self.build_model() @@ -226,7 +226,7 @@ class MemNNModel(Trainable): metrics=["accuracy"]) self.model = model - def _train(self): + def step(self): # train self.model.fit( [self.inputs_train, self.queries_train], @@ -242,12 +242,12 @@ class MemNNModel(Trainable): verbose=0) return {"mean_accuracy": accuracy} - def _save(self, checkpoint_dir): + def save_checkpoint(self, checkpoint_dir): file_path = checkpoint_dir + "/model" self.model.save(file_path) return file_path - def _restore(self, path): + def load_checkpoint(self, path): # See https://stackoverflow.com/a/42763323 del self.model self.model = load_model(path) diff --git a/python/ray/tune/examples/pbt_tune_cifar10_with_keras.py b/python/ray/tune/examples/pbt_tune_cifar10_with_keras.py index 7933b06d0..e0403895d 100755 --- a/python/ray/tune/examples/pbt_tune_cifar10_with_keras.py +++ b/python/ray/tune/examples/pbt_tune_cifar10_with_keras.py @@ -106,7 +106,7 @@ class Cifar10Model(Trainable): model = Model(inputs=x, outputs=y, name="model1") return model - def _setup(self, config): + def setup(self, config): self.train_data, self.test_data = self._read_data() x_train = self.train_data[0] model = self._build_model(x_train.shape[1:]) @@ -120,7 +120,7 @@ class Cifar10Model(Trainable): metrics=["accuracy"]) self.model = model - def _train(self): + def step(self): x_train, y_train = self.train_data x_train, y_train = x_train[:NUM_SAMPLES], y_train[:NUM_SAMPLES] x_test, y_test = self.test_data @@ -161,17 +161,17 @@ class Cifar10Model(Trainable): _, accuracy = self.model.evaluate(x_test, y_test, verbose=0) return {"mean_accuracy": accuracy} - def _save(self, checkpoint_dir): + def save_checkpoint(self, checkpoint_dir): file_path = checkpoint_dir + "/model" self.model.save(file_path) return file_path - def _restore(self, path): + def load_checkpoint(self, path): # See https://stackoverflow.com/a/42763323 del self.model self.model = load_model(path) - def _stop(self): + def cleanup(self): # If need, save your model when exit. # saved_path = self.model.save(self.logdir) # print("save model at: ", saved_path) diff --git a/python/ray/tune/examples/tf_mnist_example.py b/python/ray/tune/examples/tf_mnist_example.py index 66749ec2d..054df5ceb 100644 --- a/python/ray/tune/examples/tf_mnist_example.py +++ b/python/ray/tune/examples/tf_mnist_example.py @@ -40,7 +40,7 @@ class MyModel(Model): class MNISTTrainable(tune.Trainable): - def _setup(self, config): + def setup(self, config): # IMPORTANT: See the above note. import tensorflow as tf (x_train, y_train), (x_test, y_test) = load_data() @@ -90,7 +90,7 @@ class MNISTTrainable(tune.Trainable): self.tf_train_step = train_step self.tf_test_step = test_step - def _train(self): + def step(self): self.train_loss.reset_states() self.train_accuracy.reset_states() self.test_loss.reset_states() diff --git a/python/ray/tune/function_runner.py b/python/ray/tune/function_runner.py index f82fc83f5..9849d9434 100644 --- a/python/ray/tune/function_runner.py +++ b/python/ray/tune/function_runner.py @@ -165,7 +165,7 @@ class FunctionRunner(Trainable): _name = "func" - def _setup(self, config): + def setup(self, config): # Semaphore for notifying the reporter to continue with the computation # and to generate the next result. self._continue_semaphore = threading.Semaphore(0) @@ -212,7 +212,7 @@ class FunctionRunner(Trainable): # now done or has raised an exception. pass - def _train(self): + def step(self): """Implements train() for a Function API. If the RunnerThread finishes without reporting "done", @@ -313,7 +313,7 @@ class FunctionRunner(Trainable): out.write(data_dict) return out.getvalue() - def _restore(self, checkpoint): + def load_checkpoint(self, checkpoint): # This should be removed once Trainables are refactored. if "tune_checkpoint_path" in checkpoint: del checkpoint["tune_checkpoint_path"] @@ -330,7 +330,7 @@ class FunctionRunner(Trainable): checkpoint_path = TrainableUtil.create_from_pickle(obj, checkpoint_dir) self.restore(checkpoint_path) - def _stop(self): + def cleanup(self): # If everything stayed in synch properly, this should never happen. if not self._results_queue.empty(): logger.warning( diff --git a/python/ray/tune/session.py b/python/ray/tune/session.py index 4155e9702..349f8fda8 100644 --- a/python/ray/tune/session.py +++ b/python/ray/tune/session.py @@ -96,6 +96,8 @@ def make_checkpoint_dir(step=None): tune.report(hello="world", ray="tune") + .. warning:: Do not call this function within the Trainable Class API. + Args: step (int): Current training iteration - used for setting an index to uniquely identify the checkpoint. @@ -138,6 +140,8 @@ def save_checkpoint(checkpoint): analysis = tune.run(run_me) + .. warning:: Do not call this function within the Trainable Class API. + Args: **kwargs: Any key value pair to be logged by Tune. Any of these metrics can be used for early stopping or optimization. diff --git a/python/ray/tune/tests/test_actor_reuse.py b/python/ray/tune/tests/test_actor_reuse.py index e1d8df0d3..b2038a10f 100644 --- a/python/ray/tune/tests/test_actor_reuse.py +++ b/python/ray/tune/tests/test_actor_reuse.py @@ -13,19 +13,19 @@ class FrequentPausesScheduler(FIFOScheduler): def create_resettable_class(): class MyResettableClass(Trainable): - def _setup(self, config): + def setup(self, config): self.config = config self.num_resets = 0 self.iter = 0 - def _train(self): + def step(self): self.iter += 1 return {"num_resets": self.num_resets, "done": self.iter > 1} - def _save(self, chkpt_dir): + def save_checkpoint(self, chkpt_dir): return {"iter": self.iter} - def _restore(self, item): + def load_checkpoint(self, item): self.iter = item["iter"] def reset_config(self, new_config): diff --git a/python/ray/tune/tests/test_api.py b/python/ray/tune/tests/test_api.py index 5b9dabf42..51b0c5351 100644 --- a/python/ray/tune/tests/test_api.py +++ b/python/ray/tune/tests/test_api.py @@ -63,11 +63,11 @@ class TrainableFunctionApiTest(unittest.TestCase): function_output.append(result) class _WrappedTrainable(Trainable): - def _setup(self, config): + def setup(self, config): del config self._result_iter = copy.deepcopy(class_results) - def _train(self): + def step(self): if sleep_per_iter: time.sleep(sleep_per_iter) res = self._result_iter.pop(0) # This should not fail @@ -233,7 +233,7 @@ class TrainableFunctionApiTest(unittest.TestCase): def default_resource_request(cls, config): return Resources(cpu=config["cpu"], gpu=config["gpu"]) - def _train(self): + def step(self): return {"timesteps_this_iter": 1, "done": True} register_trainable("B", B) @@ -628,7 +628,7 @@ class TrainableFunctionApiTest(unittest.TestCase): def testTrialInfoAccess(self): class TestTrainable(Trainable): - def _train(self): + def step(self): result = {"name": self.trial_name, "trial_id": self.trial_id} print(result) return result @@ -659,11 +659,11 @@ class TrainableFunctionApiTest(unittest.TestCase): @patch("ray.tune.ray_trial_executor.TRIAL_CLEANUP_THRESHOLD", 3) def testLotsOfStops(self): class TestTrainable(Trainable): - def _train(self): + def step(self): result = {"name": self.trial_name, "trial_id": self.trial_id} return result - def _stop(self): + def cleanup(self): time.sleep(2) open(os.path.join(self.logdir, "marker"), "a").close() return 1 @@ -825,17 +825,17 @@ class TrainableFunctionApiTest(unittest.TestCase): def testDurableTrainable(self): class TestTrain(DurableTrainable): - def _setup(self, config): + def setup(self, config): self.state = {"hi": 1, "iter": 0} - def _train(self): + def step(self): self.state["iter"] += 1 return {"timesteps_this_iter": 1, "done": True} - def _save(self, path): + def save_checkpoint(self, path): return self.state - def _restore(self, state): + def load_checkpoint(self, state): self.state = state sync_client = mock_storage_client() @@ -853,16 +853,16 @@ class TrainableFunctionApiTest(unittest.TestCase): def testCheckpointDict(self): class TestTrain(Trainable): - def _setup(self, config): + def setup(self, config): self.state = {"hi": 1} - def _train(self): + def step(self): return {"timesteps_this_iter": 1, "done": True} - def _save(self, path): + def save_checkpoint(self, path): return self.state - def _restore(self, state): + def load_checkpoint(self, state): self.state = state test_trainable = TestTrain() @@ -883,17 +883,17 @@ class TrainableFunctionApiTest(unittest.TestCase): def testMultipleCheckpoints(self): class TestTrain(Trainable): - def _setup(self, config): + def setup(self, config): self.state = {"hi": 1, "iter": 0} - def _train(self): + def step(self): self.state["iter"] += 1 return {"timesteps_this_iter": 1, "done": True} - def _save(self, path): + def save_checkpoint(self, path): return self.state - def _restore(self, state): + def load_checkpoint(self, state): self.state = state test_trainable = TestTrain() @@ -938,6 +938,41 @@ class TrainableFunctionApiTest(unittest.TestCase): self.assertEqual(trial.last_result[TRAINING_ITERATION], 100) self.assertEqual(trial.last_result["itr"], 99) + def testBackwardsCompat(self): + class TestTrain(Trainable): + def _setup(self, config): + self.state = {"hi": 1, "iter": 0} + + def _train(self): + self.state["iter"] += 1 + return {"timesteps_this_iter": 1, "done": True} + + def _save(self, path): + return self.state + + def _restore(self, state): + self.state = state + + test_trainable = TestTrain() + checkpoint_1 = test_trainable.save() + test_trainable.train() + checkpoint_2 = test_trainable.save() + self.assertNotEqual(checkpoint_1, checkpoint_2) + test_trainable.restore(checkpoint_2) + self.assertEqual(test_trainable.state["iter"], 1) + test_trainable.restore(checkpoint_1) + self.assertEqual(test_trainable.state["iter"], 0) + + trials = run_experiments({ + "foo": { + "run": TestTrain, + "checkpoint_at_end": True + } + }) + for trial in trials: + self.assertEqual(trial.status, Trial.TERMINATED) + self.assertTrue(trial.has_checkpoint()) + if __name__ == "__main__": import pytest diff --git a/python/ray/tune/tests/test_cluster.py b/python/ray/tune/tests/test_cluster.py index 47e4f821b..e4126d180 100644 --- a/python/ray/tune/tests/test_cluster.py +++ b/python/ray/tune/tests/test_cluster.py @@ -623,18 +623,18 @@ def test_cluster_interrupt(start_connected_cluster, tmpdir): class _Mock(tune.Trainable): """Finishes on the 4th iteration.""" - def _setup(self, config): + def setup(self, config): self.state = {"hi": 0} - def _train(self): + def step(self): self.state["hi"] += 1 time.sleep(0.5) return {"done": self.state["hi"] >= 4} - def _save(self, path): + def save_checkpoint(self, path): return self.state - def _restore(self, state): + def load_checkpoint(self, state): self.state = state # Removes indent from class. diff --git a/python/ray/tune/tests/test_experiment_analysis_mem.py b/python/ray/tune/tests/test_experiment_analysis_mem.py index 4667c87aa..a0241d226 100644 --- a/python/ray/tune/tests/test_experiment_analysis_mem.py +++ b/python/ray/tune/tests/test_experiment_analysis_mem.py @@ -21,19 +21,19 @@ class ExperimentAnalysisInMemorySuite(unittest.TestCase): 4: [7, 5, 5, 5, 5, 5, 5, 5, 3] } - def _setup(self, config): + def setup(self, config): self.id = config["id"] self.idx = 0 - def _train(self): + def step(self): val = self.scores_dict[self.id][self.idx] self.idx += 1 return {"score": val} - def _save(self, checkpoint_dir): + def save_checkpoint(self, checkpoint_dir): pass - def _restore(self, checkpoint_path): + def load_checkpoint(self, checkpoint_path): pass self.MockTrainable = MockTrainable diff --git a/python/ray/tune/tests/test_ray_trial_executor.py b/python/ray/tune/tests/test_ray_trial_executor.py index cd5c9c856..583f44594 100644 --- a/python/ray/tune/tests/test_ray_trial_executor.py +++ b/python/ray/tune/tests/test_ray_trial_executor.py @@ -145,7 +145,7 @@ class RayTrialExecutorTest(unittest.TestCase): """Tests that reset works as expected.""" class B(Trainable): - def _train(self): + def step(self): return dict(timesteps_this_iter=1, done=True) def reset_config(self, config): diff --git a/python/ray/tune/tests/test_run_experiment.py b/python/ray/tune/tests/test_run_experiment.py index 53ef55b15..bf2f4cbcb 100644 --- a/python/ray/tune/tests/test_run_experiment.py +++ b/python/ray/tune/tests/test_run_experiment.py @@ -74,7 +74,7 @@ class RunExperimentTest(unittest.TestCase): reporter(timesteps_total=i) class B(Trainable): - def _train(self): + def step(self): return {"timesteps_this_iter": 1, "done": True} register_trainable("f1", train) @@ -91,10 +91,10 @@ class RunExperimentTest(unittest.TestCase): def testCheckpointAtEnd(self): class train(Trainable): - def _train(self): + def step(self): return {"timesteps_this_iter": 1, "done": True} - def _save(self, path): + def save_checkpoint(self, path): checkpoint = os.path.join(path, "checkpoint") with open(checkpoint, "w") as f: f.write("OK") @@ -112,7 +112,7 @@ class RunExperimentTest(unittest.TestCase): def testExportFormats(self): class train(Trainable): - def _train(self): + def step(self): return {"timesteps_this_iter": 1, "done": True} def _export_model(self, export_formats, export_dir): @@ -134,7 +134,7 @@ class RunExperimentTest(unittest.TestCase): def testInvalidExportFormats(self): class train(Trainable): - def _train(self): + def step(self): return {"timesteps_this_iter": 1, "done": True} def _export_model(self, export_formats, export_dir): @@ -156,7 +156,7 @@ class RunExperimentTest(unittest.TestCase): ray.init(resources={"hi": 3}) class train(Trainable): - def _train(self): + def step(self): return {"timesteps_this_iter": 1, "done": True} trials = run_experiments({ diff --git a/python/ray/tune/tests/test_trial_scheduler.py b/python/ray/tune/tests/test_trial_scheduler.py index 20bb90632..cf44ed736 100644 --- a/python/ray/tune/tests/test_trial_scheduler.py +++ b/python/ray/tune/tests/test_trial_scheduler.py @@ -1141,10 +1141,10 @@ class E2EPopulationBasedTestingSuite(unittest.TestCase): pbt = self.basicSetup(perturbation_interval=2) class train(tune.Trainable): - def _train(self): + def step(self): return {"mean_accuracy": self.training_iteration} - def _save(self, path): + def save_checkpoint(self, path): checkpoint = os.path.join(path, "checkpoint") with open(checkpoint, "w") as f: f.write("OK") @@ -1173,16 +1173,16 @@ class E2EPopulationBasedTestingSuite(unittest.TestCase): pbt = self.basicSetup(perturbation_interval=2) class train_dict(tune.Trainable): - def _setup(self, config): + def setup(self, config): self.state = {"hi": 1} - def _train(self): + def step(self): return {"mean_accuracy": self.training_iteration} - def _save(self, path): + def save_checkpoint(self, path): return self.state - def _restore(self, state): + def load_checkpoint(self, state): self.state = state trial_hyperparams = { diff --git a/python/ray/tune/tests/test_tune_save_restore.py b/python/ray/tune/tests/test_tune_save_restore.py index dd4d8fe11..63122e859 100644 --- a/python/ray/tune/tests/test_tune_save_restore.py +++ b/python/ray/tune/tests/test_tune_save_restore.py @@ -19,20 +19,20 @@ class SerialTuneRelativeLocalDirTest(unittest.TestCase): class MockTrainable(Trainable): _name = "MockTrainable" - def _setup(self, config): + def setup(self, config): self.state = {"hi": 1} - def _train(self): + def step(self): return {"timesteps_this_iter": 1, "done": True} - def _save(self, checkpoint_dir): + def save_checkpoint(self, checkpoint_dir): checkpoint_path = os.path.join( checkpoint_dir, "checkpoint-{}".format(self._iteration)) with open(checkpoint_path, "wb") as f: pickle.dump(self.state, f) return checkpoint_path - def _restore(self, checkpoint_path): + def load_checkpoint(self, checkpoint_path): with open(checkpoint_path, "rb") as f: extra_data = pickle.load(f) self.state.update(extra_data) @@ -154,18 +154,18 @@ class SerialTuneRelativeLocalDirTest(unittest.TestCase): """Tests that passing the checkpoint_dir right back works.""" class MockTrainable(Trainable): - def _setup(self, config): + def setup(self, config): pass - def _train(self): + def step(self): return {"score": 1} - def _save(self, checkpoint_dir): + def save_checkpoint(self, checkpoint_dir): with open(os.path.join(checkpoint_dir, "test.txt"), "wb") as f: pickle.dump("test", f) return checkpoint_dir - def _restore(self, checkpoint_dir): + def load_checkpoint(self, checkpoint_dir): with open(os.path.join(checkpoint_dir, "test.txt"), "rb") as f: x = pickle.load(f) diff --git a/python/ray/tune/trainable.py b/python/ray/tune/trainable.py index 3296c3115..545f88385 100644 --- a/python/ray/tune/trainable.py +++ b/python/ray/tune/trainable.py @@ -15,6 +15,7 @@ import time import uuid import ray +from ray.util.debug import log_once from ray.tune.logger import UnifiedLogger from ray.tune.result import (DEFAULT_RESULTS_DIR, TIME_THIS_ITER_S, TIMESTEPS_THIS_ITER, DONE, TIMESTEPS_TOTAL, @@ -174,11 +175,11 @@ class Trainable: Calling ``save()`` should save the training state of a trainable to disk, and ``restore(path)`` should restore a trainable to the given state. - Generally you only need to implement ``_setup``, ``_train``, - ``_save``, and ``_restore`` when subclassing Trainable. + Generally you only need to implement ``build``, ``step``, + ``save_checkpoint``, and ``load_checkpoint`` when subclassing Trainable. Other implementation methods that may be helpful to override are - ``_log_result``, ``reset_config``, ``_stop``, and ``_export_model``. + ``log_result``, ``reset_config``, ``cleanup``, and ``_export_model``. When using Tune, Tune will convert this class into a Ray actor, which runs on a separate process. Tune will also change the current working @@ -192,7 +193,7 @@ class Trainable: Sets up logging and points ``self.logdir`` to a directory in which training outputs should be placed. - Subclasses should prefer defining ``_setup()`` instead of overriding + Subclasses should prefer defining ``build()`` instead of overriding ``__init__()`` directly. Args: @@ -228,11 +229,11 @@ class Trainable: self._trial_info = trial_info start_time = time.time() - self._setup(copy.deepcopy(self.config)) + self.setup(copy.deepcopy(self.config)) setup_time = time.time() - start_time if setup_time > SETUP_TIME_THRESHOLD: - logger.info("_setup took {:.3f} seconds. If your trainable is " - "slow to initialize, consider setting " + logger.info("Trainable.setup took {:.3f} seconds. If your " + "trainable is slow to initialize, consider setting " "reuse_actors=True to reduce actor creation " "overheads.".format(setup_time)) self._local_ip = self.get_current_ip() @@ -277,8 +278,9 @@ class Trainable: def train(self): """Runs one logical iteration of training. - Subclasses should override ``_train()`` instead to return results. - This class automatically fills the following fields in the result: + Calls ``step()`` internally. Subclasses should override ``step()`` + instead to return results. + This method automatically fills the following fields in the result: `done` (bool): training is terminated. Filled only if not provided. @@ -295,7 +297,7 @@ class Trainable: `training_iteration` (int): The index of this training iteration, e.g. call to train(). This is incremented - after `_train()` is called. + after `step()` is called. `pid` (str): The pid of the training process. @@ -314,8 +316,8 @@ class Trainable: A dict that describes training progress. """ start = time.time() - result = self._train() - assert isinstance(result, dict), "_train() needs to return a dict." + result = self.step() + assert isinstance(result, dict), "step() needs to return a dict." # We do not modify internal state nor update this result if duplicate. if RESULT_DUPLICATE in result: @@ -376,7 +378,7 @@ class Trainable: if monitor_data: result.update(monitor_data) - self._log_result(result) + self.log_result(result) return result @@ -404,7 +406,7 @@ class Trainable: """ checkpoint_dir = TrainableUtil.make_checkpoint_dir( checkpoint_dir or self.logdir, index=self.iteration) - checkpoint = self._save(checkpoint_dir) + checkpoint = self.save_checkpoint(checkpoint_dir) trainable_state = self.get_state() checkpoint_path = TrainableUtil.process_checkpoint( checkpoint, @@ -451,9 +453,9 @@ class Trainable: with open(checkpoint_path, "rb") as loaded_state: checkpoint_dict = pickle.load(loaded_state) checkpoint_dict.update(tune_checkpoint_path=checkpoint_path) - self._restore(checkpoint_dict) + self.load_checkpoint(checkpoint_dict) else: - self._restore(checkpoint_path) + self.load_checkpoint(checkpoint_path) self._time_since_restore = 0.0 self._timesteps_since_restore = 0 self._iterations_since_restore = 0 @@ -523,7 +525,7 @@ class Trainable: be updated to reflect the latest parameter information in Ray logs. Args: - new_config (dir): Updated hyperparameter configuration + new_config (dict): Updated hyperparameter configuration for the trainable. Returns: @@ -532,10 +534,14 @@ class Trainable: return False def stop(self): - """Releases all resources used by this trainable.""" + """Releases all resources used by this trainable. + + Calls ``Trainable.cleanup`` internally. Subclasses should override + ``Trainable.cleanup`` for custom cleanup procedures. + """ self._result_logger.flush() self._result_logger.close() - self._stop() + self.cleanup() @property def logdir(self): @@ -603,7 +609,7 @@ class Trainable: """Returns configuration passed in by Tune.""" return self.config - def _train(self): + def step(self): """Subclasses should override this to implement train(). The return value will be automatically passed to the loggers. Users @@ -612,27 +618,43 @@ class Trainable: trial. Note that manual checkpointing only works when subclassing Trainables. + .. versionadded:: 0.8.7 + Returns: A dict that describes training progress. """ + result = self._train() + if self._is_overriden("_train") and log_once("_train"): + logger.warning( + "Trainable._train is deprecated and will be removed in " + "a future version of Ray. Override Trainable.step instead.") + return result + + def _train(self): + """This method is deprecated. Override 'Trainable.step' instead. + + .. versionchanged:: 0.8.7 + """ raise NotImplementedError - def _save(self, tmp_checkpoint_dir): + def save_checkpoint(self, tmp_checkpoint_dir): """Subclasses should override this to implement ``save()``. Warning: - Do not rely on absolute paths in the implementation of ``_save`` - and ``_restore``. + Do not rely on absolute paths in the implementation of + ``Trainable.save_checkpoint`` and ``Trainable.load_checkpoint``. - Use ``validate_save_restore`` to catch ``_save``/``_restore`` errors - before execution. + Use ``validate_save_restore`` to catch ``Trainable.save_checkpoint``/ + ``Trainable.load_checkpoint`` errors before execution. >>> from ray.tune.utils import validate_save_restore >>> validate_save_restore(MyTrainableClass) >>> validate_save_restore(MyTrainableClass, use_object_store=True) + .. versionadded:: 0.8.7 + Args: tmp_checkpoint_dir (str): The directory where the checkpoint file must be stored. In a Tune run, if the trial is paused, @@ -641,44 +663,60 @@ class Trainable: Returns: A dict or string. If string, the return value is expected to be prefixed by `tmp_checkpoint_dir`. If dict, the return value will - be automatically serialized by Tune and passed to `_restore()`. + be automatically serialized by Tune and + passed to ``Trainable.load_checkpoint()``. Examples: - >>> print(trainable1._save("/tmp/checkpoint_1")) + >>> print(trainable1.save_checkpoint("/tmp/checkpoint_1")) "/tmp/checkpoint_1/my_checkpoint_file" - >>> print(trainable2._save("/tmp/checkpoint_2")) + >>> print(trainable2.save_checkpoint("/tmp/checkpoint_2")) {"some": "data"} - >>> trainable._save("/tmp/bad_example") + >>> trainable.save_checkpoint("/tmp/bad_example") "/tmp/NEW_CHECKPOINT_PATH/my_checkpoint_file" # This will error. """ + checkpoint = self._save(tmp_checkpoint_dir) + if self._is_overriden("_save") and log_once("_save"): + logger.warning( + "Trainable._save is deprecated and will be removed in a " + "future version of Ray. Override " + "Trainable.save_checkpoint instead.") + return checkpoint + + def _save(self, tmp_checkpoint_dir): + """This method is deprecated. Override 'save_checkpoint' instead. + + .. versionchanged:: 0.8.7 + """ raise NotImplementedError - def _restore(self, checkpoint): + def load_checkpoint(self, checkpoint): """Subclasses should override this to implement restore(). Warning: In this method, do not rely on absolute paths. The absolute - path of the checkpoint_dir used in ``_save`` may be changed. + path of the checkpoint_dir used in ``Trainable.save_checkpoint`` + may be changed. - If ``_save`` returned a prefixed string, the prefix of the checkpoint - string returned by ``_save`` may be changed. This is because trial - pausing depends on temporary directories. + If ``Trainable.save_checkpoint`` returned a prefixed string, the + prefix of the checkpoint string returned by + ``Trainable.save_checkpoint`` may be changed. + This is because trial pausing depends on temporary directories. - The directory structure under the checkpoint_dir provided to ``_save`` - is preserved. + The directory structure under the checkpoint_dir provided to + ``Trainable.save_checkpoint`` is preserved. See the example below. .. code-block:: python class Example(Trainable): - def _save(self, checkpoint_path): + def save_checkpoint(self, checkpoint_path): print(checkpoint_path) return os.path.join(checkpoint_path, "my/check/point") - def _restore(self, checkpoint): + def load_checkpoint(self, checkpoint): print(checkpoint) >>> trainer = Example() @@ -687,39 +725,78 @@ class Trainable: >>> trainer.restore_from_object(obj) # Note the different prefix. /tmpb87b5axfrestore_from_object/checkpoint_0/my/check/point + .. versionadded:: 0.8.7 Args: checkpoint (str|dict): If dict, the return value is as - returned by `_save`. If a string, then it is a checkpoint path - that may have a different prefix than that returned by `_save`. - The directory structure underneath the `checkpoint_dir` - `_save` is preserved. + returned by `save_checkpoint`. If a string, then it is + a checkpoint path that may have a different prefix than that + returned by `save_checkpoint`. The directory structure + underneath the `checkpoint_dir` `save_checkpoint` is preserved. """ + self._restore(checkpoint) + if self._is_overriden("_restore") and log_once("_restore"): + logger.warning( + "Trainable._restore is deprecated and will be removed in a " + "future version of Ray. Override Trainable.load_checkpoint " + "instead.") + def _restore(self, checkpoint): + """This method is deprecated. Override 'load_checkpoint' instead. + + .. versionchanged:: 0.8.7 + """ raise NotImplementedError - def _setup(self, config): + def setup(self, config): """Subclasses should override this for custom initialization. + .. versionadded:: 0.8.7 + Args: config (dict): Hyperparameters and other configs given. Copy of `self.config`. """ + self._setup(config) + if self._is_overriden("_setup") and log_once("_setup"): + logger.warning( + "Trainable._setup is deprecated and will be removed in " + "a future version of Ray. Override Trainable.setup instead.") + + def _setup(self, config): + """This method is deprecated. Override 'setup' instead. + + .. versionchanged:: 0.8.7 + """ pass - def _log_result(self, result): + def log_result(self, result): """Subclasses can optionally override this to customize logging. The logging here is done on the worker process rather than the driver. You may want to turn off driver logging via the ``loggers`` parameter in ``tune.run`` when overriding this function. + .. versionadded:: 0.8.7 + Args: - result (dict): Training result returned by _train(). + result (dict): Training result returned by step(). + """ + self._log_result(result) + if self._is_overriden("_log_result") and log_once("_log_result"): + logger.warning( + "Trainable._log_result is deprecated and will be removed in " + "a future version of Ray. Override " + "Trainable.log_result instead.") + + def _log_result(self, result): + """This method is deprecated. Override 'log_result' instead. + + .. versionchanged:: 0.8.7 """ self._result_logger.on_result(result) - def _stop(self): + def cleanup(self): """Subclasses should override this for any cleanup on stop. If any Ray actors are launched in the Trainable (i.e., with a RLlib @@ -727,6 +804,19 @@ class Trainable: You can kill a Ray actor by calling `actor.__ray_terminate__.remote()` on the actor. + + .. versionadded:: 0.8.7 + """ + self._stop() + if self._is_overriden("_stop") and log_once("trainable.cleanup"): + logger.warning( + "Trainable._stop is deprecated and will be removed in " + "a future version of Ray. Override Trainable.cleanup instead.") + + def _stop(self): + """This method is deprecated. Override 'cleanup' instead. + + .. versionchanged:: 0.8.7 """ pass @@ -741,3 +831,6 @@ class Trainable: A dict that maps ExportFormats to successfully exported models. """ return {} + + def _is_overriden(self, key): + return getattr(self, key).__code__ != getattr(Trainable, key).__code__ diff --git a/python/ray/tune/trial_runner.py b/python/ray/tune/trial_runner.py index f2b66e7c1..bc690e478 100644 --- a/python/ray/tune/trial_runner.py +++ b/python/ray/tune/trial_runner.py @@ -293,7 +293,6 @@ class TrialRunner: with open(newest_ckpt_path, "r") as f: runner_state = json.load(f, cls=_TuneFunctionDecoder) self.checkpoint_file = newest_ckpt_path - logger.warning("".join([ "Attempting to resume experiment from {}. ".format( self._local_checkpoint_dir), "This feature is experimental, " diff --git a/python/ray/util/sgd/tf/tf_trainer.py b/python/ray/util/sgd/tf/tf_trainer.py index b5ab0ddee..2ee7af164 100644 --- a/python/ray/util/sgd/tf/tf_trainer.py +++ b/python/ray/util/sgd/tf/tf_trainer.py @@ -171,7 +171,7 @@ class TFTrainable(Trainable): extra_cpu=config["num_replicas"], extra_gpu=int(config["use_gpu"]) * config["num_replicas"]) - def _setup(self, config): + def setup(self, config): self._trainer = TFTrainer( model_creator=config["model_creator"], data_creator=config["data_creator"], @@ -180,7 +180,7 @@ class TFTrainable(Trainable): use_gpu=config["use_gpu"], num_cpus_per_worker=config.get("num_cpus_per_worker", 1)) - def _train(self): + def step(self): train_stats = self._trainer.train() validation_stats = self._trainer.validate() @@ -189,11 +189,11 @@ class TFTrainable(Trainable): return train_stats - def _save(self, checkpoint_dir): + def save_checkpoint(self, checkpoint_dir): return self._trainer.save(os.path.join(checkpoint_dir, "model")) - def _restore(self, checkpoint_path): + def load_checkpoint(self, checkpoint_path): return self._trainer.restore(checkpoint_path) - def _stop(self): + def cleanup(self): self._trainer.shutdown() diff --git a/python/ray/util/sgd/torch/torch_trainer.py b/python/ray/util/sgd/torch/torch_trainer.py index 232289501..b7d7a2bd7 100644 --- a/python/ray/util/sgd/torch/torch_trainer.py +++ b/python/ray/util/sgd/torch/torch_trainer.py @@ -785,7 +785,7 @@ class BaseTorchTrainable(Trainable): # TorchTrainable is subclass of BaseTorchTrainable. class CustomTrainable(TorchTrainable): - def _train(self): + def step(self): for i in range(5): train_stats = self.trainer.train() validation_stats = self.trainer.validate() @@ -799,11 +799,11 @@ class BaseTorchTrainable(Trainable): """ - def _setup(self, config): + def setup(self, config): """Constructs a TorchTrainer object as `self.trainer`.""" self._trainer = self._create_trainer(config) - def _train(self): + def step(self): """Calls `self.trainer.train()` and `self.trainer.validate()` once. You may want to override this if using a custom LR scheduler. @@ -813,20 +813,20 @@ class BaseTorchTrainable(Trainable): stats = merge_dicts(train_stats, validation_stats) return stats - def _save(self, checkpoint_dir): + def save_checkpoint(self, checkpoint_dir): """Returns a path containing the trainer state.""" checkpoint_path = os.path.join(checkpoint_dir, "trainer.checkpoint") self.trainer.save(checkpoint_path) return checkpoint_path - def _restore(self, checkpoint_path): + def load_checkpoint(self, checkpoint_path): """Restores the trainer state. Override this if you have state external to the Trainer object. """ return self.trainer.load(checkpoint_path) - def _stop(self): + def cleanup(self): """Shuts down the trainer.""" self.trainer.shutdown() diff --git a/rllib/__init__.py b/rllib/__init__.py index 21207b600..9eff863f3 100644 --- a/rllib/__init__.py +++ b/rllib/__init__.py @@ -41,7 +41,7 @@ def _register_all(): _name = "SeeContrib" _default_config = with_common_config({}) - def _setup(self, config): + def setup(self, config): raise NameError( "Please run `contrib/{}` instead.".format(name)) diff --git a/rllib/agents/ars/ars.py b/rllib/agents/ars/ars.py index 87c4b7cfc..e4b2cc99c 100644 --- a/rllib/agents/ars/ars.py +++ b/rllib/agents/ars/ars.py @@ -216,7 +216,7 @@ class ARSTrainer(Trainer): return self.policy @override(Trainer) - def _train(self): + def step(self): config = self.config theta = self.policy.get_flat_weights() @@ -313,7 +313,7 @@ class ARSTrainer(Trainer): return result @override(Trainer) - def _stop(self): + def cleanup(self): # workaround for https://github.com/ray-project/ray/issues/1516 for w in self.workers: w.__ray_terminate__.remote() diff --git a/rllib/agents/es/es.py b/rllib/agents/es/es.py index 7ab320656..f94fc469a 100644 --- a/rllib/agents/es/es.py +++ b/rllib/agents/es/es.py @@ -217,7 +217,7 @@ class ESTrainer(Trainer): return self.policy @override(Trainer) - def _train(self): + def step(self): config = self.config theta = self.policy.get_flat_weights() @@ -313,7 +313,7 @@ class ESTrainer(Trainer): return action @override(Trainer) - def _stop(self): + def cleanup(self): # workaround for https://github.com/ray-project/ray/issues/1516 for w in self._workers: w.__ray_terminate__.remote() diff --git a/rllib/agents/mock.py b/rllib/agents/mock.py index cd0a8bdd8..90bfffe83 100644 --- a/rllib/agents/mock.py +++ b/rllib/agents/mock.py @@ -29,7 +29,7 @@ class _MockTrainer(Trainer): self.info = None self.restored = False - def _train(self): + def step(self): if self.config["mock_error"] and self.iteration == 1 \ and (self.config["persistent_error"] or not self.restored): raise Exception("mock error") @@ -43,13 +43,13 @@ class _MockTrainer(Trainer): result.update({tune_result.SHOULD_CHECKPOINT: True}) return result - def _save(self, checkpoint_dir): + def save_checkpoint(self, checkpoint_dir): path = os.path.join(checkpoint_dir, "mock_agent.pkl") with open(path, "wb") as f: pickle.dump(self.info, f) return path - def _restore(self, checkpoint_path): + def load_checkpoint(self, checkpoint_path): with open(checkpoint_path, "rb") as f: info = pickle.load(f) self.info = info @@ -83,7 +83,7 @@ class _SigmoidFakeData(_MockTrainer): "object_store_memory": 0, }) - def _train(self): + def step(self): i = max(0, self.iteration - self.config["offset"]) v = np.tanh(float(i) / self.config["width"]) v *= self.config["height"] @@ -109,7 +109,7 @@ class _ParameterTuningTrainer(_MockTrainer): "object_store_memory": 0, }) - def _train(self): + def step(self): return dict( episode_reward_mean=self.config["reward_amt"] * self.iteration, episode_len_mean=self.config["reward_amt"], @@ -125,7 +125,7 @@ def _agent_import_failed(trace): _name = "AgentImportFailed" _default_config = with_common_config({}) - def _setup(self, config): + def setup(self, config): raise ImportError(trace) return _AgentImportFailed diff --git a/rllib/agents/trainer.py b/rllib/agents/trainer.py index 74f272f87..35804ecc6 100644 --- a/rllib/agents/trainer.py +++ b/rllib/agents/trainer.py @@ -439,7 +439,7 @@ class Trainer(Trainable): # User provided config (this is w/o the default Trainer's # `COMMON_CONFIG` (see above)). Will get merged with COMMON_CONFIG - # in self._setup(). + # in self.setup(). config = config or {} # Vars to synchronize to workers on each train call @@ -550,14 +550,14 @@ class Trainer(Trainable): workers.local_worker().filters)) @override(Trainable) - def _log_result(self, result: ResultDict): + def log_result(self, result: ResultDict): self.callbacks.on_train_result(trainer=self, result=result) # log after the callback is invoked, so that the user has a chance # to mutate the result - Trainable._log_result(self, result) + Trainable.log_result(self, result) @override(Trainable) - def _setup(self, config: PartialTrainerConfigDict): + def setup(self, config: PartialTrainerConfigDict): env = self._env_id if env: config["env"] = env @@ -665,14 +665,14 @@ class Trainer(Trainable): self.evaluation_metrics = {} @override(Trainable) - def _stop(self): + def cleanup(self): if hasattr(self, "workers"): self.workers.stop() if hasattr(self, "optimizer") and self.optimizer: self.optimizer.stop() @override(Trainable) - def _save(self, checkpoint_dir: str) -> str: + def save_checkpoint(self, checkpoint_dir: str) -> str: checkpoint_path = os.path.join(checkpoint_dir, "checkpoint-{}".format(self.iteration)) pickle.dump(self.__getstate__(), open(checkpoint_path, "wb")) @@ -680,7 +680,7 @@ class Trainer(Trainable): return checkpoint_path @override(Trainable) - def _restore(self, checkpoint_path: str): + def load_checkpoint(self, checkpoint_path: str): extra_data = pickle.load(open(checkpoint_path, "rb")) self.__setstate__(extra_data) diff --git a/rllib/agents/trainer_template.py b/rllib/agents/trainer_template.py index 4cb8f917a..570cc32f0 100644 --- a/rllib/agents/trainer_template.py +++ b/rllib/agents/trainer_template.py @@ -34,21 +34,21 @@ def default_execution_plan(workers: WorkerSet, config: TrainerConfigDict): @DeveloperAPI def build_trainer( - name: str, - default_policy: Optional[Policy], - *, - default_config: TrainerConfigDict = None, - validate_config: Callable[[TrainerConfigDict], None] = None, - get_initial_state=None, # DEPRECATED - get_policy_class: Callable[[TrainerConfigDict], Policy] = None, - before_init: Callable[[Trainer], None] = None, - make_workers=None, # DEPRECATED - make_policy_optimizer=None, # DEPRECATED - after_init: Callable[[Trainer], None] = None, - before_train_step=None, # DEPRECATED - after_optimizer_step=None, # DEPRECATED - after_train_result=None, # DEPRECATED - collect_metrics_fn=None, # DEPRECATED + name: str, + default_policy: Optional[Policy], + *, + default_config: TrainerConfigDict = None, + validate_config: Callable[[TrainerConfigDict], None] = None, + get_initial_state=None, # DEPRECATED + get_policy_class: Callable[[TrainerConfigDict], Policy] = None, + before_init: Callable[[Trainer], None] = None, + make_workers=None, # DEPRECATED + make_policy_optimizer=None, # DEPRECATED + after_init: Callable[[Trainer], None] = None, + before_train_step=None, # DEPRECATED + after_optimizer_step=None, # DEPRECATED + after_train_result=None, # DEPRECATED + collect_metrics_fn=None, # DEPRECATED before_evaluate_fn: Callable[[Trainer], None] = None, mixins: List[type] = None, execution_plan: Callable[[WorkerSet, TrainerConfigDict], Iterable[ @@ -134,7 +134,7 @@ def build_trainer( after_init(self) @override(Trainer) - def _train(self): + def step(self): if self.train_exec_impl: return self._train_exec_impl() diff --git a/rllib/contrib/random_agent/random_agent.py b/rllib/contrib/random_agent/random_agent.py index 4e4abe7e1..2570eee95 100644 --- a/rllib/contrib/random_agent/random_agent.py +++ b/rllib/contrib/random_agent/random_agent.py @@ -20,7 +20,7 @@ class RandomAgent(Trainer): self.env = env_creator(config["env_config"]) @override(Trainer) - def _train(self): + def step(self): rewards = [] steps = 0 for _ in range(self.config["rollouts_per_iteration"]):