diff --git a/python/ray/util/sgd/tests/test_torch.py b/python/ray/util/sgd/tests/test_torch.py index 0eec95a02..c7a02df24 100644 --- a/python/ray/util/sgd/tests/test_torch.py +++ b/python/ray/util/sgd/tests/test_torch.py @@ -725,6 +725,55 @@ def test_fail_twice(ray_start_2_cpus): # noqa: F811 trainer1.shutdown() +def test_multi_input_model(ray_start_2_cpus): + def model_creator(config): + class MultiInputModel(nn.Module): + def __init__(self): + super(MultiInputModel, self).__init__() + self._fc1 = torch.nn.Linear(1, 1) + self._fc2 = torch.nn.Linear(1, 1) + + def forward(self, x, y): + return self._fc1(x) + self._fc2(y) + + return MultiInputModel() + + def data_creator(config): + class LinearDataset(torch.utils.data.Dataset): + def __init__(self, a, b, size=1000): + x = np.random.randn(size) + y = np.random.randn(size) + self.x = torch.tensor(x, dtype=torch.float32) + self.y = torch.tensor(y, dtype=torch.float32) + self.z = torch.tensor(a * (x + y) + 2 * b, dtype=torch.float32) + + def __getitem__(self, index): + return (self.x[index, None], self.y[index, None], + self.z[index, None]) + + def __len__(self): + return len(self.x) + + train_dataset = LinearDataset(3, 4) + train_loader = torch.utils.data.DataLoader( + train_dataset, + batch_size=config.get("batch_size", 32), + ) + return train_loader, None + + trainer = TorchTrainer( + model_creator=model_creator, + data_creator=data_creator, + optimizer_creator=optimizer_creator, + loss_creator=lambda config: nn.MSELoss(), + num_workers=1) + + metrics = trainer.train(num_steps=1) + assert metrics[BATCH_COUNT] == 1 + + trainer.shutdown() + + if __name__ == "__main__": import pytest import sys diff --git a/python/ray/util/sgd/torch/training_operator.py b/python/ray/util/sgd/torch/training_operator.py index 7e29289b7..493d261e4 100644 --- a/python/ray/util/sgd/torch/training_operator.py +++ b/python/ray/util/sgd/torch/training_operator.py @@ -212,8 +212,9 @@ class TrainingOperator: updating the model. 
By default, this method implementation assumes that batches - are in (features, labels) format. If using amp/fp16 - training, it will also scale the loss automatically. + are in (\\*features, labels) format, so multi-input models are + also supported. If using amp/fp16 training, it will also scale the + loss automatically. You can provide custom loss metrics and training operations if you override this method. If overriding this method, you can access model, @@ -237,15 +238,18 @@ class TrainingOperator: calculate averages. """ - features, target = batch + # Unpack features into a list to support multi-input models. + *features, target = batch # Create non_blocking tensors for distributed training if self.use_gpu: - features = features.cuda(non_blocking=True) + features = [ + feature.cuda(non_blocking=True) for feature in features + ] target = target.cuda(non_blocking=True) # Compute output. with self.timers.record("fwd"): - output = self.model(features) + output = self.model(*features) loss = self.criterion(output, target) # Compute gradients in a backward pass. @@ -261,7 +265,7 @@ class TrainingOperator: with self.timers.record("apply"): self.optimizer.step() - return {"train_loss": loss.item(), NUM_SAMPLES: features.size(0)} + return {"train_loss": loss.item(), NUM_SAMPLES: features[0].size(0)} def validate(self, val_iterator, info): """Runs one standard validation pass over the val_iterator. @@ -304,6 +308,10 @@ class TrainingOperator: You can override this method to provide arbitrary metrics. + As with ``train_batch``, this implementation assumes by default that + batches are in (\\*features, labels) format, so multi-input models + are also supported. + Args: batch: One item of the validation iterator. batch_info (dict): Contains information per batch from @@ -317,15 +325,18 @@ class TrainingOperator: by default, ``validate`` uses "num_samples" to calculate averages. 
""" - features, target = batch + # Unpack features into a list to support multi-input models. + *features, target = batch if self.use_gpu: - features = features.cuda(non_blocking=True) + features = [ + feature.cuda(non_blocking=True) for feature in features + ] target = target.cuda(non_blocking=True) # compute output with self.timers.record("eval_fwd"): - output = self.model(features) + output = self.model(*features) loss = self.criterion(output, target) _, predicted = torch.max(output.data, 1)