[tune/sgd] Import ABC from collections.abc instead of collectio… (#7982)

* Import ABC from collections.abc instead of collections for Python 3 compatibility.

* Fix linter errors.
This commit is contained in:
Karthikeyan Singaravelan
2020-04-17 03:56:49 +05:30
committed by GitHub
parent 42f88ecf9d
commit f95e18dfeb
3 changed files with 33 additions and 15 deletions
+8 -4
View File
@@ -1,4 +1,3 @@
import collections
from filelock import FileLock
import logging
import inspect
@@ -17,6 +16,11 @@ from ray.util.sgd import utils
logger = logging.getLogger(__name__)
amp = None
try:
from collections.abc import Iterable
except ImportError:
from collections import Iterable
try:
from apex import amp
except ImportError:
@@ -133,7 +137,7 @@ class TorchRunner:
self.schedulers = self.scheduler_creator(self.given_optimizers,
self.config)
if not isinstance(self.schedulers, collections.Iterable):
if not isinstance(self.schedulers, Iterable):
self.schedulers = [self.schedulers]
def _try_setup_apex(self):
@@ -153,7 +157,7 @@ class TorchRunner:
self._initialize_dataloaders()
logger.debug("Creating model")
self.models = self.model_creator(self.config)
if not isinstance(self.models, collections.Iterable):
if not isinstance(self.models, Iterable):
self.models = [self.models]
assert all(isinstance(model, nn.Module) for model in self.models), (
"All models must be PyTorch models: {}.".format(self.models))
@@ -163,7 +167,7 @@ class TorchRunner:
logger.debug("Creating optimizer.")
self.optimizers = self.optimizer_creator(self.given_models,
self.config)
if not isinstance(self.optimizers, collections.Iterable):
if not isinstance(self.optimizers, Iterable):
self.optimizers = [self.optimizers]
self._create_schedulers_if_available()
+17 -8
View File
@@ -1,4 +1,3 @@
import collections
import torch
from ray.util.sgd.utils import (TimerCollection, AverageMeterCollection,
@@ -8,6 +7,11 @@ from ray.util.sgd.torch.constants import (SCHEDULER_STEP_EPOCH, NUM_STEPS,
amp = None
try:
from collections.abc import Iterable
except ImportError:
from collections import Iterable
try:
from apex import amp
except ImportError:
@@ -25,7 +29,7 @@ except ImportError:
def _is_multiple(component):
"""Checks if a component (optimizer, model, etc) is not singular."""
return isinstance(component, collections.Iterable) and len(component) > 1
return isinstance(component, Iterable) and len(component) > 1
class TrainingOperator:
@@ -66,19 +70,24 @@ class TrainingOperator:
use_tqdm=False):
# You are not expected to override this method.
self._models = models # List of models
assert isinstance(models, collections.Iterable), (
"Components need to be iterable. Got: {}".format(type(models)))
assert isinstance(
models,
Iterable), ("Components need to be iterable. Got: {}".format(
type(models)))
self._optimizers = optimizers # List of optimizers
assert isinstance(optimizers, collections.Iterable), (
"Components need to be iterable. Got: {}".format(type(optimizers)))
assert isinstance(
optimizers,
Iterable), ("Components need to be iterable. Got: {}".format(
type(optimizers)))
self._train_loader = train_loader
self._validation_loader = validation_loader
self._world_rank = world_rank
self._criterion = criterion
self._schedulers = schedulers
if schedulers:
assert isinstance(schedulers, collections.Iterable), (
"Components need to be iterable. Got: {}".format(
assert isinstance(
schedulers,
Iterable), ("Components need to be iterable. Got: {}".format(
type(schedulers)))
self._config = config
self._use_fp16 = use_fp16