mirror of
https://github.com/wassname/ray.git
synced 2026-06-29 19:00:58 +08:00
[SGD] Support multiple input model (#8246)
This commit is contained in:
@@ -725,6 +725,55 @@ def test_fail_twice(ray_start_2_cpus): # noqa: F811
|
||||
trainer1.shutdown()
|
||||
|
||||
|
||||
def test_multi_input_model(ray_start_2_cpus):
|
||||
def model_creator(config):
|
||||
class MultiInputModel(nn.Module):
|
||||
def __init__(self):
|
||||
super(MultiInputModel, self).__init__()
|
||||
self._fc1 = torch.nn.Linear(1, 1)
|
||||
self._fc2 = torch.nn.Linear(1, 1)
|
||||
|
||||
def forward(self, x, y):
|
||||
return self._fc1(x) + self._fc2(y)
|
||||
|
||||
return MultiInputModel()
|
||||
|
||||
def data_creator(config):
|
||||
class LinearDataset(torch.utils.data.Dataset):
|
||||
def __init__(self, a, b, size=1000):
|
||||
x = np.random.randn(size)
|
||||
y = np.random.randn(size)
|
||||
self.x = torch.tensor(x, dtype=torch.float32)
|
||||
self.y = torch.tensor(y, dtype=torch.float32)
|
||||
self.z = torch.tensor(a * (x + y) + 2 * b, dtype=torch.float32)
|
||||
|
||||
def __getitem__(self, index):
|
||||
return (self.x[index, None], self.y[index, None],
|
||||
self.z[index, None])
|
||||
|
||||
def __len__(self):
|
||||
return len(self.x)
|
||||
|
||||
train_dataset = LinearDataset(3, 4)
|
||||
train_loader = torch.utils.data.DataLoader(
|
||||
train_dataset,
|
||||
batch_size=config.get("batch_size", 32),
|
||||
)
|
||||
return train_loader, None
|
||||
|
||||
trainer = TorchTrainer(
|
||||
model_creator=model_creator,
|
||||
data_creator=data_creator,
|
||||
optimizer_creator=optimizer_creator,
|
||||
loss_creator=lambda config: nn.MSELoss(),
|
||||
num_workers=1)
|
||||
|
||||
metrics = trainer.train(num_steps=1)
|
||||
assert metrics[BATCH_COUNT] == 1
|
||||
|
||||
trainer.shutdown()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import pytest
|
||||
import sys
|
||||
|
||||
@@ -212,8 +212,9 @@ class TrainingOperator:
|
||||
updating the model.
|
||||
|
||||
By default, this method implementation assumes that batches
|
||||
are in (features, labels) format. If using amp/fp16
|
||||
training, it will also scale the loss automatically.
|
||||
are in (\*features, labels) format. So we also support multiple inputs
|
||||
model. If using amp/fp16 training, it will also scale the loss
|
||||
automatically.
|
||||
|
||||
You can provide custom loss metrics and training operations if you
|
||||
override this method. If overriding this method, you can access model,
|
||||
@@ -237,15 +238,18 @@ class TrainingOperator:
|
||||
calculate averages.
|
||||
|
||||
"""
|
||||
features, target = batch
|
||||
# unpack features into list to support multiple inputs model
|
||||
*features, target = batch
|
||||
# Create non_blocking tensors for distributed training
|
||||
if self.use_gpu:
|
||||
features = features.cuda(non_blocking=True)
|
||||
features = [
|
||||
feature.cuda(non_blocking=True) for feature in features
|
||||
]
|
||||
target = target.cuda(non_blocking=True)
|
||||
|
||||
# Compute output.
|
||||
with self.timers.record("fwd"):
|
||||
output = self.model(features)
|
||||
output = self.model(*features)
|
||||
loss = self.criterion(output, target)
|
||||
|
||||
# Compute gradients in a backward pass.
|
||||
@@ -261,7 +265,7 @@ class TrainingOperator:
|
||||
with self.timers.record("apply"):
|
||||
self.optimizer.step()
|
||||
|
||||
return {"train_loss": loss.item(), NUM_SAMPLES: features.size(0)}
|
||||
return {"train_loss": loss.item(), NUM_SAMPLES: features[0].size(0)}
|
||||
|
||||
def validate(self, val_iterator, info):
|
||||
"""Runs one standard validation pass over the val_iterator.
|
||||
@@ -304,6 +308,10 @@ class TrainingOperator:
|
||||
|
||||
You can override this method to provide arbitrary metrics.
|
||||
|
||||
Same as ``train_batch``, this method implementation assumes that
|
||||
batches are in (\*features, labels) format by default. So we also
|
||||
support multiple inputs model.
|
||||
|
||||
Args:
|
||||
batch: One item of the validation iterator.
|
||||
batch_info (dict): Contains information per batch from
|
||||
@@ -317,15 +325,18 @@ class TrainingOperator:
|
||||
by default, ``validate`` uses "num_samples" to
|
||||
calculate averages.
|
||||
"""
|
||||
features, target = batch
|
||||
# unpack features into list to support multiple inputs model
|
||||
*features, target = batch
|
||||
if self.use_gpu:
|
||||
features = features.cuda(non_blocking=True)
|
||||
features = [
|
||||
feature.cuda(non_blocking=True) for feature in features
|
||||
]
|
||||
target = target.cuda(non_blocking=True)
|
||||
|
||||
# compute output
|
||||
|
||||
with self.timers.record("eval_fwd"):
|
||||
output = self.model(features)
|
||||
output = self.model(*features)
|
||||
loss = self.criterion(output, target)
|
||||
_, predicted = torch.max(output.data, 1)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user