diff --git a/doc/source/tune/_tutorials/tune-pytorch-cifar.rst b/doc/source/tune/_tutorials/tune-pytorch-cifar.rst index e8b2c5240..54a46d545 100644 --- a/doc/source/tune/_tutorials/tune-pytorch-cifar.rst +++ b/doc/source/tune/_tutorials/tune-pytorch-cifar.rst @@ -78,21 +78,19 @@ documentation `. .. code-block:: python net = Net(config["l1"], config["l2"]) + optimizer = optim.SGD(net.parameters(), lr=config["lr"], momentum=0.9) if checkpoint_dir: checkpoint = os.path.join(checkpoint_dir, "checkpoint") - net.load_state_dict(torch.load(checkpoint)) - -The learning rate of the optimizer is made configurable, too: - -.. code-block:: python - - optimizer = optim.SGD(net.parameters(), lr=config["lr"], momentum=0.9) + model_state, optimizer_state = torch.load(checkpoint) + net.load_state_dict(model_state) + optimizer.load_state_dict(optimizer_state) We also split the training data into a training and validation subset. We thus train on 80% of the data and calculate the validation loss on the remaining 20%. The batch sizes @@ -129,6 +127,8 @@ also supports :doc:`fractional GPUs ` so we can share GPUs among trials, as long as the model still fits on the GPU memory. We'll come back to that later. +.. _communicating-with-ray-tune: + Communicating with Ray Tune ~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -150,6 +150,8 @@ resources on those trials. The :ref:`checkpoint saving ` is optional. However, it is necessary if we wanted to use advanced schedulers like `Population Based Training `_. +In this case, the created checkpoint directory will be passed as the ``checkpoint_dir`` parameter +to the training function. After training, we can also restore the checkpointed models and validate them on a test set. 
@@ -162,7 +164,7 @@ The full code example looks like this: :language: python :start-after: __train_begin__ :end-before: __train_end__ - :emphasize-lines: 2,4-9,12,14-18,28,33,43,70,81-84,86 + :emphasize-lines: 2,4-9,12,14-20,30,35,45,72,83-89,91 As you can see, most of the code is adapted directly from the example. diff --git a/python/ray/tune/examples/cifar10_pytorch.py b/python/ray/tune/examples/cifar10_pytorch.py index 2e1b1cbe2..4af8c4937 100644 --- a/python/ray/tune/examples/cifar10_pytorch.py +++ b/python/ray/tune/examples/cifar10_pytorch.py @@ -72,6 +72,8 @@ def train_cifar(config, checkpoint_dir=None, data_dir=None): criterion = nn.CrossEntropyLoss() optimizer = optim.SGD(net.parameters(), lr=config["lr"], momentum=0.9) + # The `checkpoint_dir` parameter gets passed by Ray Tune when a checkpoint + # should be restored. if checkpoint_dir: checkpoint = os.path.join(checkpoint_dir, "checkpoint") model_state, optimizer_state = torch.load(checkpoint) @@ -139,6 +141,9 @@ def train_cifar(config, checkpoint_dir=None, data_dir=None): val_loss += loss.cpu().numpy() val_steps += 1 + # Here we save a checkpoint. It is automatically registered with + # Ray Tune and will potentially be passed as the `checkpoint_dir` + # parameter in future iterations. with tune.checkpoint_dir(step=epoch) as checkpoint_dir: path = os.path.join(checkpoint_dir, "checkpoint") torch.save( @@ -174,7 +179,7 @@ def test_accuracy(net, device="cpu"): # __main_begin__ def main(num_samples=10, max_num_epochs=10, gpus_per_trial=2): data_dir = os.path.abspath("./data") - load_data(data_dir) + load_data(data_dir) # Download data for all trials before starting the run config = { "l1": tune.sample_from(lambda _: 2 ** np.random.randint(2, 9)), "l2": tune.sample_from(lambda _: 2 ** np.random.randint(2, 9)),