mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 12:10:40 +08:00
[tune/sgd] Import ABC from collections.abc instead of collectio… (#7982)
* Import ABC from collections.abc instead of collections for Python 3 compatibility. * Fix linter errors.
This commit is contained in:
committed by
GitHub
parent
42f88ecf9d
commit
f95e18dfeb
@@ -7,6 +7,11 @@ from ray.tune.result import (EPISODE_REWARD_MEAN, MEAN_ACCURACY, MEAN_LOSS,
|
||||
TRAINING_ITERATION, TIME_TOTAL_S, TIMESTEPS_TOTAL)
|
||||
from ray.tune.utils import flatten_dict
|
||||
|
||||
try:
|
||||
from collections.abc import Mapping
|
||||
except ImportError:
|
||||
from collections import Mapping
|
||||
|
||||
try:
|
||||
from tabulate import tabulate
|
||||
except ImportError:
|
||||
@@ -101,7 +106,7 @@ class TuneReporterBase(ProgressReporter):
|
||||
if metric in self._metric_columns:
|
||||
raise ValueError("Column {} already exists.".format(metric))
|
||||
|
||||
if isinstance(self._metric_columns, collections.Mapping):
|
||||
if isinstance(self._metric_columns, Mapping):
|
||||
representation = representation or metric
|
||||
self._metric_columns[metric] = representation
|
||||
else:
|
||||
@@ -291,7 +296,7 @@ def trial_progress_str(trials, metric_columns, fmt="psql", max_rows=None):
|
||||
num_trials, ", ".join(num_trials_strs)))
|
||||
|
||||
# Pre-process trials to figure out what columns to show.
|
||||
if isinstance(metric_columns, collections.Mapping):
|
||||
if isinstance(metric_columns, Mapping):
|
||||
keys = list(metric_columns.keys())
|
||||
else:
|
||||
keys = metric_columns
|
||||
@@ -303,7 +308,7 @@ def trial_progress_str(trials, metric_columns, fmt="psql", max_rows=None):
|
||||
params = sorted(set().union(*[t.evaluated_params for t in trials]))
|
||||
trial_table = [_get_trial_info(trial, params, keys) for trial in trials]
|
||||
# Format column headings
|
||||
if isinstance(metric_columns, collections.Mapping):
|
||||
if isinstance(metric_columns, Mapping):
|
||||
formatted_columns = [metric_columns[k] for k in keys]
|
||||
else:
|
||||
formatted_columns = keys
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import collections
|
||||
from filelock import FileLock
|
||||
import logging
|
||||
import inspect
|
||||
@@ -17,6 +16,11 @@ from ray.util.sgd import utils
|
||||
logger = logging.getLogger(__name__)
|
||||
amp = None
|
||||
|
||||
try:
|
||||
from collections.abc import Iterable
|
||||
except ImportError:
|
||||
from collections import Iterable
|
||||
|
||||
try:
|
||||
from apex import amp
|
||||
except ImportError:
|
||||
@@ -133,7 +137,7 @@ class TorchRunner:
|
||||
self.schedulers = self.scheduler_creator(self.given_optimizers,
|
||||
self.config)
|
||||
|
||||
if not isinstance(self.schedulers, collections.Iterable):
|
||||
if not isinstance(self.schedulers, Iterable):
|
||||
self.schedulers = [self.schedulers]
|
||||
|
||||
def _try_setup_apex(self):
|
||||
@@ -153,7 +157,7 @@ class TorchRunner:
|
||||
self._initialize_dataloaders()
|
||||
logger.debug("Creating model")
|
||||
self.models = self.model_creator(self.config)
|
||||
if not isinstance(self.models, collections.Iterable):
|
||||
if not isinstance(self.models, Iterable):
|
||||
self.models = [self.models]
|
||||
assert all(isinstance(model, nn.Module) for model in self.models), (
|
||||
"All models must be PyTorch models: {}.".format(self.models))
|
||||
@@ -163,7 +167,7 @@ class TorchRunner:
|
||||
logger.debug("Creating optimizer.")
|
||||
self.optimizers = self.optimizer_creator(self.given_models,
|
||||
self.config)
|
||||
if not isinstance(self.optimizers, collections.Iterable):
|
||||
if not isinstance(self.optimizers, Iterable):
|
||||
self.optimizers = [self.optimizers]
|
||||
|
||||
self._create_schedulers_if_available()
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import collections
|
||||
import torch
|
||||
|
||||
from ray.util.sgd.utils import (TimerCollection, AverageMeterCollection,
|
||||
@@ -8,6 +7,11 @@ from ray.util.sgd.torch.constants import (SCHEDULER_STEP_EPOCH, NUM_STEPS,
|
||||
|
||||
amp = None
|
||||
|
||||
try:
|
||||
from collections.abc import Iterable
|
||||
except ImportError:
|
||||
from collections import Iterable
|
||||
|
||||
try:
|
||||
from apex import amp
|
||||
except ImportError:
|
||||
@@ -25,7 +29,7 @@ except ImportError:
|
||||
|
||||
def _is_multiple(component):
|
||||
"""Checks if a component (optimizer, model, etc) is not singular."""
|
||||
return isinstance(component, collections.Iterable) and len(component) > 1
|
||||
return isinstance(component, Iterable) and len(component) > 1
|
||||
|
||||
|
||||
class TrainingOperator:
|
||||
@@ -66,19 +70,24 @@ class TrainingOperator:
|
||||
use_tqdm=False):
|
||||
# You are not expected to override this method.
|
||||
self._models = models # List of models
|
||||
assert isinstance(models, collections.Iterable), (
|
||||
"Components need to be iterable. Got: {}".format(type(models)))
|
||||
assert isinstance(
|
||||
models,
|
||||
Iterable), ("Components need to be iterable. Got: {}".format(
|
||||
type(models)))
|
||||
self._optimizers = optimizers # List of optimizers
|
||||
assert isinstance(optimizers, collections.Iterable), (
|
||||
"Components need to be iterable. Got: {}".format(type(optimizers)))
|
||||
assert isinstance(
|
||||
optimizers,
|
||||
Iterable), ("Components need to be iterable. Got: {}".format(
|
||||
type(optimizers)))
|
||||
self._train_loader = train_loader
|
||||
self._validation_loader = validation_loader
|
||||
self._world_rank = world_rank
|
||||
self._criterion = criterion
|
||||
self._schedulers = schedulers
|
||||
if schedulers:
|
||||
assert isinstance(schedulers, collections.Iterable), (
|
||||
"Components need to be iterable. Got: {}".format(
|
||||
assert isinstance(
|
||||
schedulers,
|
||||
Iterable), ("Components need to be iterable. Got: {}".format(
|
||||
type(schedulers)))
|
||||
self._config = config
|
||||
self._use_fp16 = use_fp16
|
||||
|
||||
Reference in New Issue
Block a user