From be23b3ac415e75db3ac100b630f31d562904e20c Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Wed, 25 Dec 2019 17:15:43 -0800 Subject: [PATCH] [sgd] show training result for examples (#6552) --- .../ray/experimental/sgd/examples/train_example.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/python/ray/experimental/sgd/examples/train_example.py b/python/ray/experimental/sgd/examples/train_example.py index de64905da..d44dc374e 100644 --- a/python/ray/experimental/sgd/examples/train_example.py +++ b/python/ray/experimental/sgd/examples/train_example.py @@ -26,7 +26,6 @@ class LinearDataset(torch.utils.data.Dataset): """y = a * x + b""" def __init__(self, a, b, size=1000): - x = np.random.random(size).astype(np.float32) * 10 x = np.arange(0, 10, 10 / size, dtype=np.float32) self.x = torch.from_numpy(x) self.y = torch.from_numpy(a * x + b) @@ -44,7 +43,7 @@ def model_creator(config): def optimizer_creator(model, config): """Returns optimizer.""" - return torch.optim.SGD(model.parameters(), lr=1e-4) + return torch.optim.SGD(model.parameters(), lr=1e-2) def data_creator(batch_size, config): @@ -81,9 +80,16 @@ def train_example(num_replicas=1, use_gpu=False): loss_creator=lambda config: nn.MSELoss(), num_replicas=num_replicas, use_gpu=use_gpu, - batch_size=512, + batch_size=num_replicas * 4, backend="gloo") - trainer1.train() + for i in range(5): + stats = trainer1.train() + print(stats) + + print(trainer1.validate()) + m = trainer1.get_model() + print("trained weight: % .2f, bias: % .2f" % ( + m.weight.item(), m.bias.item())) trainer1.shutdown() print("success!")