[tune/sgd] Import ABC from collections.abc instead of collectio… (#7982)

* Import ABC from collections.abc instead of collections for Python 3 compatibility.

* Fix linter errors.
This commit is contained in:
Karthikeyan Singaravelan
2020-04-17 03:56:49 +05:30
committed by GitHub
parent 42f88ecf9d
commit f95e18dfeb
3 changed files with 33 additions and 15 deletions
+8 -3
View File
@@ -7,6 +7,11 @@ from ray.tune.result import (EPISODE_REWARD_MEAN, MEAN_ACCURACY, MEAN_LOSS,
TRAINING_ITERATION, TIME_TOTAL_S, TIMESTEPS_TOTAL)
from ray.tune.utils import flatten_dict
try:
from collections.abc import Mapping
except ImportError:
from collections import Mapping
try:
from tabulate import tabulate
except ImportError:
@@ -101,7 +106,7 @@ class TuneReporterBase(ProgressReporter):
if metric in self._metric_columns:
raise ValueError("Column {} already exists.".format(metric))
if isinstance(self._metric_columns, collections.Mapping):
if isinstance(self._metric_columns, Mapping):
representation = representation or metric
self._metric_columns[metric] = representation
else:
@@ -291,7 +296,7 @@ def trial_progress_str(trials, metric_columns, fmt="psql", max_rows=None):
num_trials, ", ".join(num_trials_strs)))
# Pre-process trials to figure out what columns to show.
if isinstance(metric_columns, collections.Mapping):
if isinstance(metric_columns, Mapping):
keys = list(metric_columns.keys())
else:
keys = metric_columns
@@ -303,7 +308,7 @@ def trial_progress_str(trials, metric_columns, fmt="psql", max_rows=None):
params = sorted(set().union(*[t.evaluated_params for t in trials]))
trial_table = [_get_trial_info(trial, params, keys) for trial in trials]
# Format column headings
if isinstance(metric_columns, collections.Mapping):
if isinstance(metric_columns, Mapping):
formatted_columns = [metric_columns[k] for k in keys]
else:
formatted_columns = keys
+8 -4
View File
@@ -1,4 +1,3 @@
import collections
from filelock import FileLock
import logging
import inspect
@@ -17,6 +16,11 @@ from ray.util.sgd import utils
logger = logging.getLogger(__name__)
amp = None
try:
from collections.abc import Iterable
except ImportError:
from collections import Iterable
try:
from apex import amp
except ImportError:
@@ -133,7 +137,7 @@ class TorchRunner:
self.schedulers = self.scheduler_creator(self.given_optimizers,
self.config)
if not isinstance(self.schedulers, collections.Iterable):
if not isinstance(self.schedulers, Iterable):
self.schedulers = [self.schedulers]
def _try_setup_apex(self):
@@ -153,7 +157,7 @@ class TorchRunner:
self._initialize_dataloaders()
logger.debug("Creating model")
self.models = self.model_creator(self.config)
if not isinstance(self.models, collections.Iterable):
if not isinstance(self.models, Iterable):
self.models = [self.models]
assert all(isinstance(model, nn.Module) for model in self.models), (
"All models must be PyTorch models: {}.".format(self.models))
@@ -163,7 +167,7 @@ class TorchRunner:
logger.debug("Creating optimizer.")
self.optimizers = self.optimizer_creator(self.given_models,
self.config)
if not isinstance(self.optimizers, collections.Iterable):
if not isinstance(self.optimizers, Iterable):
self.optimizers = [self.optimizers]
self._create_schedulers_if_available()
+17 -8
View File
@@ -1,4 +1,3 @@
import collections
import torch
from ray.util.sgd.utils import (TimerCollection, AverageMeterCollection,
@@ -8,6 +7,11 @@ from ray.util.sgd.torch.constants import (SCHEDULER_STEP_EPOCH, NUM_STEPS,
amp = None
try:
from collections.abc import Iterable
except ImportError:
from collections import Iterable
try:
from apex import amp
except ImportError:
@@ -25,7 +29,7 @@ except ImportError:
def _is_multiple(component):
"""Checks if a component (optimizer, model, etc) is not singular."""
return isinstance(component, collections.Iterable) and len(component) > 1
return isinstance(component, Iterable) and len(component) > 1
class TrainingOperator:
@@ -66,19 +70,24 @@ class TrainingOperator:
use_tqdm=False):
# You are not expected to override this method.
self._models = models # List of models
assert isinstance(models, collections.Iterable), (
"Components need to be iterable. Got: {}".format(type(models)))
assert isinstance(
models,
Iterable), ("Components need to be iterable. Got: {}".format(
type(models)))
self._optimizers = optimizers # List of optimizers
assert isinstance(optimizers, collections.Iterable), (
"Components need to be iterable. Got: {}".format(type(optimizers)))
assert isinstance(
optimizers,
Iterable), ("Components need to be iterable. Got: {}".format(
type(optimizers)))
self._train_loader = train_loader
self._validation_loader = validation_loader
self._world_rank = world_rank
self._criterion = criterion
self._schedulers = schedulers
if schedulers:
assert isinstance(schedulers, collections.Iterable), (
"Components need to be iterable. Got: {}".format(
assert isinstance(
schedulers,
Iterable), ("Components need to be iterable. Got: {}".format(
type(schedulers)))
self._config = config
self._use_fp16 = use_fp16