mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 19:16:19 +08:00
[tune] Fix github readme (#9365)
Co-authored-by: Amog Kamsetty <amogkam@users.noreply.github.com>
This commit is contained in:
+20
-15
@@ -87,34 +87,39 @@ To run this example, you will need to install the following:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
$ pip install ray[tune] torch torchvision filelock
|
||||
$ pip install ray[tune]
|
||||
|
||||
|
||||
This example runs a parallel grid search to train a Convolutional Neural Network using PyTorch.
|
||||
This example runs a parallel grid search to optimize an example objective function.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
|
||||
import torch.optim as optim
|
||||
from ray import tune
|
||||
from ray.tune.examples.mnist_pytorch import (
|
||||
get_data_loaders, ConvNet, train, test)
|
||||
|
||||
|
||||
def train_mnist(config):
|
||||
train_loader, test_loader = get_data_loaders()
|
||||
model = ConvNet()
|
||||
optimizer = optim.SGD(model.parameters(), lr=config["lr"])
|
||||
for i in range(10):
|
||||
train(model, optimizer, train_loader)
|
||||
acc = test(model, test_loader)
|
||||
tune.track.log(mean_accuracy=acc)
|
||||
def objective(step, alpha, beta):
|
||||
return (0.1 + alpha * step / 100)**(-1) + beta * 0.1
|
||||
|
||||
|
||||
def training_function(config):
|
||||
# Hyperparameters
|
||||
alpha, beta = config["alpha"], config["beta"]
|
||||
for step in range(10):
|
||||
# Iterative training function - can be any arbitrary training procedure.
|
||||
intermediate_score = objective(step, alpha, beta)
|
||||
# Feed the score back back to Tune.
|
||||
tune.report(mean_loss=intermediate_score)
|
||||
|
||||
|
||||
analysis = tune.run(
|
||||
train_mnist, config={"lr": tune.grid_search([0.001, 0.01, 0.1])})
|
||||
training_function,
|
||||
config={
|
||||
"alpha": tune.grid_search([0.001, 0.01, 0.1]),
|
||||
"beta": tune.choice([1, 2, 3])
|
||||
})
|
||||
|
||||
print("Best config: ", analysis.get_best_config(metric="mean_accuracy"))
|
||||
print("Best config: ", analysis.get_best_config(metric="mean_loss"))
|
||||
|
||||
# Get a dataframe for analyzing trial results.
|
||||
df = analysis.dataframe()
|
||||
|
||||
Reference in New Issue
Block a user