mirror of
https://github.com/wassname/ray.git
synced 2026-06-30 07:53:50 +08:00
[tune][minor] formatting examples, fix travis (#5869)
* formatting * formatting
This commit is contained in:
@@ -70,6 +70,7 @@ def get_data_loaders(batch_size):
|
||||
shuffle=True)
|
||||
return train_loader, test_loader
|
||||
|
||||
|
||||
#######################################################################
|
||||
# Setup: Defining the Neural Network
|
||||
# ----------------------------------
|
||||
@@ -130,6 +131,7 @@ def test(model, test_loader, device=torch.device("cpu")):
|
||||
|
||||
return correct / total
|
||||
|
||||
|
||||
#######################################################################
|
||||
# Evaluating the Hyperparameters
|
||||
# -------------------------------
|
||||
@@ -141,6 +143,7 @@ def test(model, test_loader, device=torch.device("cpu")):
|
||||
#
|
||||
# The ``@ray.remote`` decorator defines a remote process.
|
||||
|
||||
|
||||
@ray.remote
|
||||
def evaluate_hyperparameters(config):
|
||||
model = ConvNet()
|
||||
@@ -152,6 +155,7 @@ def evaluate_hyperparameters(config):
|
||||
train(model, optimizer, train_loader)
|
||||
return test(model, test_loader)
|
||||
|
||||
|
||||
#######################################################################
|
||||
# Synchronous Evaluation of Randomly Generated Hyperparameters
|
||||
# ------------------------------------------------------------
|
||||
@@ -159,7 +163,6 @@ def evaluate_hyperparameters(config):
|
||||
# We will create multiple sets of random hyperparameters for our neural
|
||||
# network that will be evaluated in parallel.
|
||||
|
||||
|
||||
# Keep track of the best hyperparameters and the best accuracy.
|
||||
best_hyperparameters = None
|
||||
best_accuracy = 0
|
||||
|
||||
Reference in New Issue
Block a user