mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 14:31:15 +08:00
[sgd] show training result for examples (#6552)
This commit is contained in:
@@ -26,7 +26,6 @@ class LinearDataset(torch.utils.data.Dataset):
|
||||
"""y = a * x + b"""
|
||||
|
||||
def __init__(self, a, b, size=1000):
|
||||
x = np.random.random(size).astype(np.float32) * 10
|
||||
x = np.arange(0, 10, 10 / size, dtype=np.float32)
|
||||
self.x = torch.from_numpy(x)
|
||||
self.y = torch.from_numpy(a * x + b)
|
||||
@@ -44,7 +43,7 @@ def model_creator(config):
|
||||
|
||||
def optimizer_creator(model, config):
|
||||
"""Returns optimizer."""
|
||||
return torch.optim.SGD(model.parameters(), lr=1e-4)
|
||||
return torch.optim.SGD(model.parameters(), lr=1e-2)
|
||||
|
||||
|
||||
def data_creator(batch_size, config):
|
||||
@@ -81,9 +80,16 @@ def train_example(num_replicas=1, use_gpu=False):
|
||||
loss_creator=lambda config: nn.MSELoss(),
|
||||
num_replicas=num_replicas,
|
||||
use_gpu=use_gpu,
|
||||
batch_size=512,
|
||||
batch_size=num_replicas * 4,
|
||||
backend="gloo")
|
||||
trainer1.train()
|
||||
for i in range(5):
|
||||
stats = trainer1.train()
|
||||
print(stats)
|
||||
|
||||
print(trainer1.validate())
|
||||
m = trainer1.get_model()
|
||||
print("trained weight: % .2f, bias: % .2f" % (
|
||||
m.weight.item(), m.bias.item()))
|
||||
trainer1.shutdown()
|
||||
print("success!")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user