@pytest.mark.parametrize("use_local", [True, False])
def test_num_steps(ray_start_2_cpus, use_local):
    """Check that ``num_steps`` resumes from where the previous call stopped.

    Both loaders serve five 0s followed by five 1s with a batch size of 1,
    and the custom train function reports the mean of the batches it was
    fed. The reported average therefore reveals exactly which slice of the
    dataset a given ``train``/``validate`` call consumed.
    """

    def data_creator(config):
        train_dataset = [0] * 5 + [1] * 5
        val_dataset = [0] * 5 + [1] * 5
        train_loader = DataLoader(
            train_dataset, batch_size=config["batch_size"])
        val_loader = DataLoader(val_dataset, batch_size=config["batch_size"])
        return train_loader, val_loader

    batch_size = 1
    Operator = TrainingOperator.from_creators(model_creator, optimizer_creator,
                                              data_creator)

    def train_func(self, iterator, info=None):
        # Mean of every batch handed to this call.
        running_total = 0
        seen = 0
        for seen, batch in enumerate(iterator, start=1):
            running_total = running_total + batch
        return {"average": running_total.item() / seen}

    TestOperator = get_test_operator(Operator)
    trainer_config = {"batch_size": batch_size, "custom_func": train_func}
    trainer = TorchTrainer(
        training_operator_cls=TestOperator,
        num_workers=2,
        use_local=use_local,
        add_dist_sampler=False,
        config=trainer_config)

    # Each round is (num_steps, expected average, expected epoch).
    # ``None`` means the call is made without a ``num_steps`` argument.
    rounds = [
        # No num_steps: one full epoch, five 0s and five 1s.
        (None, 0.5, 1),
        # New epoch, first five items only: five 0s.
        (5, 0, 2),
        # Continues where the last run left off: three 1s.
        (3, 1, 2),
        # Finishes the pass and cycles to the start: two 1s, three 0s.
        (5, 0.4, 3),
        # No num_steps: just finishes the epoch, two 0s and five 1s.
        (None, 5 / 7, 3),
    ]
    for num_steps, expected_average, expected_epoch in rounds:
        step_kwargs = {} if num_steps is None else {"num_steps": num_steps}

        result = trainer.train(**step_kwargs)
        assert result["average"] == expected_average
        assert result["epoch"] == expected_epoch

        val_result = trainer.validate(**step_kwargs)
        assert val_result["average"] == expected_average

    trainer.shutdown()
def get_iterator(self, training=True):
    """Return the live iterator over the training or validation loader.

    A new iterator is created only when the matching ``_should_reset_*``
    flag is set; otherwise the iterator saved by the previous call is
    returned, so consumption resumes where it stopped. Starting a new pass
    over the training loader also increments ``self.epochs``.

    Args:
        training (bool): Select the training loader if True, otherwise
            the validation loader.

    Returns:
        The current iterator for the selected loader.
    """
    if training:
        if self._should_reset_train_loader:
            self.epochs += 1
            self.train_iterator = iter(self.train_loader)
            self._should_reset_train_loader = False
        return self.train_iterator

    if self._should_reset_val_loader:
        self.val_iterator = iter(self.validation_loader)
        self._should_reset_val_loader = False
    return self.val_iterator


def make_iterator(self, training=True, num_steps=None):
    """Yield batches, resuming from wherever the previous call stopped.

    Args:
        training (bool): Draw from the training loader if True, otherwise
            from the validation loader.
        num_steps (int | None): Maximum number of batches to yield. When
            the current pass runs out first, a new pass is started and
            iteration continues. If None, yield until the current pass is
            exhausted.

    Yields:
        Items produced by the underlying loader iterator.

    Notes:
        * The step budget is checked *before* ``get_iterator`` is called,
          so ``num_steps=0`` neither starts a new pass nor bumps
          ``self.epochs``.
        * If the saved iterator was drained exactly by the previous call
          (so no reset is pending yet), this call rolls over to a new pass
          instead of yielding nothing.
        * If a newly started pass yields no items at all, iteration stops.
          This covers an empty loader and a loader that can only be
          consumed once, and prevents an endless loop.
    """
    steps = 0
    while num_steps is None or steps < num_steps:
        # A pending reset means get_iterator() is about to hand back a
        # brand-new pass rather than the leftover of an earlier one.
        if training:
            fresh_pass = self._should_reset_train_loader
        else:
            fresh_pass = self._should_reset_val_loader

        iterator = self.get_iterator(training=training)
        try:
            item = next(iterator)
        except StopIteration:
            # The pass is over: request a new iterator on the next fetch.
            if training:
                self._should_reset_train_loader = True
            else:
                self._should_reset_val_loader = True
            if fresh_pass:
                # A brand-new pass produced nothing; cycling again would
                # never terminate.
                break
            if num_steps is None and steps > 0:
                # No step budget: stop once the current epoch is finished.
                break
            # Otherwise roll over to the next pass and keep going.
            continue

        steps += 1
        yield item
def format_batch(batch): @@ -86,11 +134,14 @@ class TorchRunner: return torch.cat(features), torch.cat(targets) iterator = map(format_batch, iterator) - if num_steps: - iterator = itertools.islice(iterator, num_steps) + if num_steps: + iterator = itertools.islice(iterator, num_steps) + self.epochs += 1 + else: + iterator = self.make_iterator( + training=True, num_steps=num_steps) train_stats = self.training_operator.train_epoch(iterator, info) - self.epochs += 1 # This is so that `epochs` is first in ordering. stats = dict(epoch=self.epochs, **train_stats) if profile: @@ -101,12 +152,9 @@ class TorchRunner: """Evaluates the model on the validation data set.""" info = info or {} self._toggle_profiling(profile=profile) - validation_loader = self.validation_loader with self.timers.record("validation"): - iterator = validation_loader - if num_steps: - iterator = itertools.islice(iterator, num_steps) + iterator = self.make_iterator(training=False, num_steps=num_steps) validation_stats = self.training_operator.validate( iterator, info=info) if profile: @@ -197,6 +245,8 @@ class TorchRunner: def shutdown(self): """Attempts to shut down the worker.""" + del self.train_iterator + del self.val_iterator del self.training_operator if torch.cuda.is_available(): torch.cuda.empty_cache() diff --git a/python/ray/util/sgd/torch/torch_trainer.py b/python/ray/util/sgd/torch/torch_trainer.py index 6efa476a5..93c1a91fb 100644 --- a/python/ray/util/sgd/torch/torch_trainer.py +++ b/python/ray/util/sgd/torch/torch_trainer.py @@ -375,9 +375,9 @@ class TorchTrainer: instance preemption. Args: - num_steps (int): Number of batches to compute update steps on. - This corresponds also to the number of times - ``TrainingOperator.train_batch`` is called. + num_steps (int): Number of batches to compute update steps on + per worker. This corresponds also to the number of times + ``TrainingOperator.train_batch`` is called per worker. profile (bool): Returns time stats for the training procedure. 
def validate(self, iterator, info):
    """Validate by reusing this operator's ``train_epoch`` code path.

    Args:
        iterator: Iterable of validation batches, forwarded unchanged.
        info: Extra information object, forwarded unchanged.

    Returns:
        Whatever ``self.train_epoch(iterator, info)`` returns.
    """
    stats = self.train_epoch(iterator, info)
    return stats