mirror of
https://github.com/wassname/ray.git
synced 2026-06-29 01:59:23 +08:00
[tune] update pt tutorial docs (#10925)
Co-authored-by: Richard Liaw <rliaw@berkeley.edu>
This commit is contained in:
@@ -72,6 +72,8 @@ def train_cifar(config, checkpoint_dir=None, data_dir=None):
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
optimizer = optim.SGD(net.parameters(), lr=config["lr"], momentum=0.9)
|
||||
|
||||
# The `checkpoint_dir` parameter gets passed by Ray Tune when a checkpoint
|
||||
# should be restored.
|
||||
if checkpoint_dir:
|
||||
checkpoint = os.path.join(checkpoint_dir, "checkpoint")
|
||||
model_state, optimizer_state = torch.load(checkpoint)
|
||||
@@ -139,6 +141,9 @@ def train_cifar(config, checkpoint_dir=None, data_dir=None):
|
||||
val_loss += loss.cpu().numpy()
|
||||
val_steps += 1
|
||||
|
||||
# Here we save a checkpoint. It is automatically registered with
|
||||
# Ray Tune and will potentially be passed as the `checkpoint_dir`
|
||||
# parameter in future iterations.
|
||||
with tune.checkpoint_dir(step=epoch) as checkpoint_dir:
|
||||
path = os.path.join(checkpoint_dir, "checkpoint")
|
||||
torch.save(
|
||||
@@ -174,7 +179,7 @@ def test_accuracy(net, device="cpu"):
|
||||
# __main_begin__
|
||||
def main(num_samples=10, max_num_epochs=10, gpus_per_trial=2):
|
||||
data_dir = os.path.abspath("./data")
|
||||
load_data(data_dir)
|
||||
load_data(data_dir) # Download data for all trials before starting the run
|
||||
config = {
|
||||
"l1": tune.sample_from(lambda _: 2 ** np.random.randint(2, 9)),
|
||||
"l2": tune.sample_from(lambda _: 2 ** np.random.randint(2, 9)),
|
||||
|
||||
Reference in New Issue
Block a user