diff --git a/python/ray/tune/progress_reporter.py b/python/ray/tune/progress_reporter.py index 3584fa5a9..a41ae4b38 100644 --- a/python/ray/tune/progress_reporter.py +++ b/python/ray/tune/progress_reporter.py @@ -7,6 +7,11 @@ from ray.tune.result import (EPISODE_REWARD_MEAN, MEAN_ACCURACY, MEAN_LOSS, TRAINING_ITERATION, TIME_TOTAL_S, TIMESTEPS_TOTAL) from ray.tune.utils import flatten_dict +try: + from collections.abc import Mapping +except ImportError: + from collections import Mapping + try: from tabulate import tabulate except ImportError: @@ -101,7 +106,7 @@ class TuneReporterBase(ProgressReporter): if metric in self._metric_columns: raise ValueError("Column {} already exists.".format(metric)) - if isinstance(self._metric_columns, collections.Mapping): + if isinstance(self._metric_columns, Mapping): representation = representation or metric self._metric_columns[metric] = representation else: @@ -291,7 +296,7 @@ def trial_progress_str(trials, metric_columns, fmt="psql", max_rows=None): num_trials, ", ".join(num_trials_strs))) # Pre-process trials to figure out what columns to show. - if isinstance(metric_columns, collections.Mapping): + if isinstance(metric_columns, Mapping): keys = list(metric_columns.keys()) else: keys = metric_columns @@ -303,7 +308,7 @@ def trial_progress_str(trials, metric_columns, fmt="psql", max_rows=None): params = sorted(set().union(*[t.evaluated_params for t in trials])) trial_table = [_get_trial_info(trial, params, keys) for trial in trials] # Format column headings - if isinstance(metric_columns, collections.Mapping): + if isinstance(metric_columns, Mapping): formatted_columns = [metric_columns[k] for k in keys] else: formatted_columns = keys diff --git a/python/ray/util/sgd/torch/torch_runner.py b/python/ray/util/sgd/torch/torch_runner.py index 952b06ee7..f37ff6e93 100644 --- a/python/ray/util/sgd/torch/torch_runner.py +++ b/python/ray/util/sgd/torch/torch_runner.py @@ -1,4 +1,3 @@ -import collections from filelock import FileLock import logging import inspect @@ -17,6 +16,11 @@ from ray.util.sgd import utils logger = logging.getLogger(__name__) amp = None +try: + from collections.abc import Iterable +except ImportError: + from collections import Iterable + try: from apex import amp except ImportError: @@ -133,7 +137,7 @@ class TorchRunner: self.schedulers = self.scheduler_creator(self.given_optimizers, self.config) - if not isinstance(self.schedulers, collections.Iterable): + if not isinstance(self.schedulers, Iterable): self.schedulers = [self.schedulers] def _try_setup_apex(self): @@ -153,7 +157,7 @@ class TorchRunner: self._initialize_dataloaders() logger.debug("Creating model") self.models = self.model_creator(self.config) - if not isinstance(self.models, collections.Iterable): + if not isinstance(self.models, Iterable): self.models = [self.models] assert all(isinstance(model, nn.Module) for model in self.models), ( "All models must be PyTorch models: {}.".format(self.models)) @@ -163,7 +167,7 @@ class TorchRunner: logger.debug("Creating optimizer.") self.optimizers = self.optimizer_creator(self.given_models, self.config) - if not isinstance(self.optimizers, collections.Iterable): + if not isinstance(self.optimizers, Iterable): self.optimizers = [self.optimizers] self._create_schedulers_if_available() diff --git a/python/ray/util/sgd/torch/training_operator.py b/python/ray/util/sgd/torch/training_operator.py index 4cade28c8..7e29289b7 100644 --- a/python/ray/util/sgd/torch/training_operator.py +++ b/python/ray/util/sgd/torch/training_operator.py @@ -1,4 +1,3 @@ -import collections import torch from ray.util.sgd.utils import (TimerCollection, AverageMeterCollection, @@ -8,6 +7,11 @@ from ray.util.sgd.torch.constants import (SCHEDULER_STEP_EPOCH, NUM_STEPS, amp = None +try: + from collections.abc import Iterable +except ImportError: + from collections import Iterable + try: from apex import amp except ImportError: @@ -25,7 +29,7 @@ except ImportError: def _is_multiple(component): """Checks if a component (optimizer, model, etc) is not singular.""" - return isinstance(component, collections.Iterable) and len(component) > 1 + return isinstance(component, Iterable) and len(component) > 1 class TrainingOperator: @@ -66,19 +70,24 @@ class TrainingOperator: use_tqdm=False): # You are not expected to override this method. self._models = models # List of models - assert isinstance(models, collections.Iterable), ( - "Components need to be iterable. Got: {}".format(type(models))) + assert isinstance( + models, + Iterable), ("Components need to be iterable. Got: {}".format( + type(models))) self._optimizers = optimizers # List of optimizers - assert isinstance(optimizers, collections.Iterable), ( - "Components need to be iterable. Got: {}".format(type(optimizers))) + assert isinstance( + optimizers, + Iterable), ("Components need to be iterable. Got: {}".format( + type(optimizers))) self._train_loader = train_loader self._validation_loader = validation_loader self._world_rank = world_rank self._criterion = criterion self._schedulers = schedulers if schedulers: - assert isinstance(schedulers, collections.Iterable), ( - "Components need to be iterable. Got: {}".format( + assert isinstance( + schedulers, + Iterable), ("Components need to be iterable. Got: {}".format( type(schedulers))) self._config = config self._use_fp16 = use_fp16