mirror of
https://github.com/wassname/ray.git
synced 2026-07-04 06:27:02 +08:00
[raysgd] Cleanup User API (#7384)
* Init fp16 * fp16 and schedulers * scheduler linking and fp16 * to fp16 * loss scaling and documentation * more documentation * add tests, refactor config * moredocs * more docs * fix logo, add test mode, add fp16 flag * fix tests * fix scheduler * fix apex * improve safety * fix tests * fix tests * remove pin memory default * rm * fix * Update doc/examples/doc_code/raysgd_torch_signatures.py * fix * migrate changes from other PR * ok thanks * pass * signatures * lint' * Update python/ray/experimental/sgd/pytorch/utils.py * Apply suggestions from code review Co-Authored-By: Edward Oakes <ed.nmi.oakes@gmail.com> * should address most comments * comments * fix this ci * first_pass * add overrides * override * fixing up operators * format * sgd * constants * rm * revert * save * failures * fixes * trainer * run test * operator * code * op * ok done * operator * sgd test fixes * ok * trainer * format * Apply suggestions from code review Co-Authored-By: Edward Oakes <ed.nmi.oakes@gmail.com> * Update doc/source/raysgd/raysgd_pytorch.rst * docstring * dcgan * doc * commits * nit * testing * revert * Start renaming pytorch to torch * Rename PyTorchTrainer to TorchTrainer * Rename PyTorch runners to Torch runners * Finish renaming API * Rename to torch in tests * Finish renaming docs + tests * Run format + fix DeprecationWarning * fix * move tests up * benchmarks * rename * remove some args * better metrics output * fix up the benchmark * benchmark-yaml * horovod-benchmark * benchmarks * Remove benchmark code for cleanups * makedatacreator * relax * metrics * autosetsampler * profile * movements * OK * smoothen * fix * nitdocs * loss * comments * fix * fix * runner_tests * codes * example * fix_test * fix * tests Co-authored-by: Edward Oakes <ed.nmi.oakes@gmail.com> Co-authored-by: Maksim Smolin <maximsmol@gmail.com>
This commit is contained in:
@@ -1,121 +0,0 @@
|
||||
# flake8: noqa
|
||||
"""
|
||||
This file holds code for the torch Trainer creator signatures.
|
||||
|
||||
It ignores yapf because yapf doesn't allow comments right after code blocks,
|
||||
but we put comments right after code blocks to prevent large white spaces
|
||||
in the documentation.
|
||||
"""
|
||||
# yapf: disable
|
||||
|
||||
# __torch_model_start__
|
||||
import torch.nn as nn
|
||||
|
||||
def model_creator(config):
|
||||
"""Constructor function for the model(s) to be optimized.
|
||||
|
||||
You will also need to provide a custom training
|
||||
function to specify the optimization procedure for multiple models.
|
||||
|
||||
Args:
|
||||
config (dict): Configuration dictionary passed into ``TorchTrainer``.
|
||||
|
||||
Returns:
|
||||
One or more torch.nn.Module objects.
|
||||
"""
|
||||
return nn.Linear(1, 1)
|
||||
# __torch_model_end__
|
||||
|
||||
|
||||
# __torch_optimizer_start__
|
||||
import torch
|
||||
|
||||
def optimizer_creator(model, config):
|
||||
"""Constructor of one or more Torch optimizers.
|
||||
|
||||
Args:
|
||||
models: The return values from ``model_creator``. This can be one
|
||||
or more torch nn modules.
|
||||
config (dict): Configuration dictionary passed into ``TorchTrainer``.
|
||||
|
||||
Returns:
|
||||
One or more Torch optimizer objects.
|
||||
"""
|
||||
return torch.optim.SGD(model.parameters(), lr=config.get("lr", 1e-4))
|
||||
# __torch_optimizer_end__
|
||||
|
||||
|
||||
# __torch_data_start__
|
||||
from ray.util.sgd.torch.examples.train_example import LinearDataset
|
||||
|
||||
def data_creator(config):
|
||||
"""Constructs torch.utils.data.Dataset objects.
|
||||
|
||||
Note that even though two Dataset objects can be returned,
|
||||
only one dataset will be used for training.
|
||||
|
||||
Args:
|
||||
config: Configuration dictionary passed into ``TorchTrainer``
|
||||
|
||||
Returns:
|
||||
One or Two Dataset objects. If only one Dataset object is provided,
|
||||
``trainer.validate()`` will throw a ValueError.
|
||||
"""
|
||||
return LinearDataset(2, 5), LinearDataset(2, 5, size=400)
|
||||
# __torch_data_end__
|
||||
|
||||
# __torch_loss_start__
|
||||
import torch
|
||||
|
||||
def loss_creator(config):
|
||||
"""Constructs the Torch Loss object.
|
||||
|
||||
Note that optionally, you can pass in a Torch Loss constructor directly
|
||||
into the TorchTrainer (i.e., ``TorchTrainer(loss_creator=nn.BCELoss, ...)``).
|
||||
|
||||
Args:
|
||||
config: Configuration dictionary passed into ``TorchTrainer``
|
||||
|
||||
Returns:
|
||||
Torch Loss object.
|
||||
"""
|
||||
return torch.nn.BCELoss()
|
||||
# __torch_loss_end__
|
||||
|
||||
# __torch_scheduler_start__
|
||||
import torch
|
||||
|
||||
def scheduler_creator(optimizer, config):
|
||||
"""Constructor of one or more Torch optimizer schedulers.
|
||||
|
||||
Args:
|
||||
optimizers: The return values from ``optimizer_creator``.
|
||||
This can be one or more torch optimizer objects.
|
||||
config: Configuration dictionary passed into ``TorchTrainer``
|
||||
|
||||
Returns:
|
||||
One or more Torch scheduler objects.
|
||||
"""
|
||||
return torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.9)
|
||||
|
||||
# __torch_scheduler_end__
|
||||
|
||||
# __torch_ray_start__
|
||||
import ray
|
||||
|
||||
ray.init()
|
||||
# or ray.init(address="auto") to connect to a running cluster.
|
||||
# __torch_ray_end__
|
||||
|
||||
# __torch_trainer_start__
|
||||
from ray.util.sgd import TorchTrainer
|
||||
|
||||
trainer = TorchTrainer(
|
||||
model_creator,
|
||||
data_creator,
|
||||
optimizer_creator,
|
||||
loss_creator=nn.MSELoss,
|
||||
scheduler_creator=scheduler_creator,
|
||||
config={"lr": 0.001})
|
||||
|
||||
# __torch_trainer_end__
|
||||
Reference in New Issue
Block a user