mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 09:12:56 +08:00
fix example (#10964)
This commit is contained in:
@@ -8,7 +8,12 @@ in the documentation.
|
||||
# yapf: disable
|
||||
|
||||
# __torch_operator_start__
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from ray.util.sgd.torch import TrainingOperator
|
||||
from ray.util.sgd.torch.examples.train_example import LinearDataset
|
||||
|
||||
class MyTrainingOperator(TrainingOperator):
|
||||
def setup(self, config):
|
||||
@@ -44,10 +49,29 @@ class MyTrainingOperator(TrainingOperator):
|
||||
self.model, self.optimizer, self.criterion, self.scheduler = \
|
||||
self.register(models=model, optimizers=optimizer,
|
||||
criterion=criterion,
|
||||
scheduler=scheduler)
|
||||
schedulers=scheduler)
|
||||
self.register_data(train_loader=train_loader, validation_loader=val_loader)
|
||||
# __torch_operator_end__
|
||||
|
||||
# __torch_ray_start__
|
||||
import ray
|
||||
|
||||
ray.init()
|
||||
# or ray.init(address="auto") to connect to a running cluster.
|
||||
# __torch_ray_end__
|
||||
|
||||
# __torch_trainer_start__
|
||||
from ray.util.sgd import TorchTrainer
|
||||
|
||||
trainer = TorchTrainer(
|
||||
training_operator_cls=MyTrainingOperator,
|
||||
scheduler_step_freq="epoch", # if scheduler is used
|
||||
config={"lr": 0.001, "batch_size": 64})
|
||||
|
||||
# __torch_trainer_end__
|
||||
|
||||
trainer.shutdown()
|
||||
|
||||
# __torch_model_start__
|
||||
import torch.nn as nn
|
||||
|
||||
@@ -144,13 +168,6 @@ def scheduler_creator(optimizer, config):
|
||||
|
||||
# __torch_scheduler_end__
|
||||
|
||||
# __torch_ray_start__
|
||||
import ray
|
||||
|
||||
ray.init()
|
||||
# or ray.init(address="auto") to connect to a running cluster.
|
||||
# __torch_ray_end__
|
||||
|
||||
# __backwards_compat_start__
|
||||
from ray.util.sgd import TorchTrainer
|
||||
|
||||
@@ -167,15 +184,3 @@ trainer = TorchTrainer(
|
||||
# __backwards_compat_end__
|
||||
|
||||
trainer.shutdown()
|
||||
|
||||
# __torch_trainer_start__
|
||||
from ray.util.sgd import TorchTrainer
|
||||
|
||||
trainer = TorchTrainer(
|
||||
training_operator_cls=MyTrainingOperator,
|
||||
scheduler_step_freq="epoch", # if scheduler is used
|
||||
config={"lr": 0.001, "batch_size": 64})
|
||||
|
||||
# __torch_trainer_end__
|
||||
|
||||
trainer.shutdown()
|
||||
|
||||
Reference in New Issue
Block a user