[tune] update pt tutorial docs (#10925)

Co-authored-by: Richard Liaw <rliaw@berkeley.edu>
This commit is contained in:
Kai Fricke
2020-09-21 21:33:37 +01:00
committed by GitHub
parent d1d4743702
commit 50d63b8077
2 changed files with 17 additions and 10 deletions
+6 -1
View File
@@ -72,6 +72,8 @@ def train_cifar(config, checkpoint_dir=None, data_dir=None):
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=config["lr"], momentum=0.9)
# The `checkpoint_dir` parameter gets passed by Ray Tune when a checkpoint
# should be restored.
if checkpoint_dir:
checkpoint = os.path.join(checkpoint_dir, "checkpoint")
model_state, optimizer_state = torch.load(checkpoint)
@@ -139,6 +141,9 @@ def train_cifar(config, checkpoint_dir=None, data_dir=None):
val_loss += loss.cpu().numpy()
val_steps += 1
# Here we save a checkpoint. It is automatically registered with
# Ray Tune and may be passed back as the `checkpoint_dir` parameter
# in a future iteration when this checkpoint should be restored.
with tune.checkpoint_dir(step=epoch) as checkpoint_dir:
path = os.path.join(checkpoint_dir, "checkpoint")
torch.save(
@@ -174,7 +179,7 @@ def test_accuracy(net, device="cpu"):
# __main_begin__
def main(num_samples=10, max_num_epochs=10, gpus_per_trial=2):
data_dir = os.path.abspath("./data")
load_data(data_dir)
load_data(data_dir) # Download data for all trials before starting the run
config = {
"l1": tune.sample_from(lambda _: 2 ** np.random.randint(2, 9)),
"l2": tune.sample_from(lambda _: 2 ** np.random.randint(2, 9)),