[tune] add type hints to tune.run(), fix abstract methods of ProgressReporter (#13684)

This commit is contained in:
Kai Fricke
2021-01-27 16:43:50 +01:00
committed by GitHub
parent 2664a2a8f6
commit c5b645e3da
2 changed files with 86 additions and 66 deletions
+7
View File
@@ -57,6 +57,13 @@ class ProgressReporter:
"""
raise NotImplementedError
def set_search_properties(self, metric: Optional[str],
mode: Optional[str]):
return True
def set_total_samples(self, total_samples: int):
pass
class TuneReporterBase(ProgressReporter):
"""Abstract base class for the default Tune reporters.
+79 -66
View File
@@ -1,25 +1,35 @@
from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Type, \
Union
import datetime
import logging
import sys
import time
from ray.tune.error import TuneError
from ray.tune.experiment import convert_to_experiment_list, Experiment
from ray.tune.analysis import ExperimentAnalysis
from ray.tune.suggest import BasicVariantGenerator, SearchGenerator
from ray.tune.callback import Callback
from ray.tune.error import TuneError
from ray.tune.experiment import Experiment, convert_to_experiment_list
from ray.tune.logger import Logger
from ray.tune.progress_reporter import CLIReporter, JupyterNotebookReporter, \
ProgressReporter
from ray.tune.ray_trial_executor import RayTrialExecutor
from ray.tune.registry import get_trainable_cls
from ray.tune.stopper import Stopper
from ray.tune.suggest import BasicVariantGenerator, SearchAlgorithm, \
SearchGenerator
from ray.tune.suggest.suggestion import Searcher
from ray.tune.suggest.variant_generator import has_unresolved_values
from ray.tune.trial import Trial
from ray.tune.syncer import SyncConfig, set_sync_periods, wait_for_sync
from ray.tune.trainable import Trainable
from ray.tune.ray_trial_executor import RayTrialExecutor
from ray.tune.utils.callback import create_default_callbacks
from ray.tune.registry import get_trainable_cls
from ray.tune.syncer import wait_for_sync, set_sync_periods, \
SyncConfig
from ray.tune.trial import Trial
from ray.tune.trial_runner import TrialRunner
from ray.tune.progress_reporter import CLIReporter, JupyterNotebookReporter
from ray.tune.schedulers import FIFOScheduler
from ray.tune.utils.callback import create_default_callbacks
from ray.tune.utils.log import Verbosity, has_verbosity, set_verbosity
# Must come last to avoid circular imports
from ray.tune.schedulers import FIFOScheduler, TrialScheduler
logger = logging.getLogger(__name__)
try:
@@ -55,50 +65,51 @@ def _report_progress(runner, reporter, done=False):
def run(
run_or_experiment,
name=None,
metric=None,
mode=None,
stop=None,
time_budget_s=None,
config=None,
resources_per_trial=None,
num_samples=1,
local_dir=None,
search_alg=None,
scheduler=None,
keep_checkpoints_num=None,
checkpoint_score_attr=None,
checkpoint_freq=0,
checkpoint_at_end=False,
verbose=Verbosity.V3_TRIAL_DETAILS,
progress_reporter=None,
log_to_file=False,
trial_name_creator=None,
trial_dirname_creator=None,
sync_config=None,
export_formats=None,
max_failures=0,
fail_fast=False,
restore=None,
server_port=None,
resume=False,
queue_trials=False,
reuse_actors=False,
trial_executor=None,
raise_on_failed_trial=True,
callbacks=None,
run_or_experiment: Union[str, Callable, Type],
name: Optional[str] = None,
metric: Optional[str] = None,
mode: Optional[str] = None,
stop: Union[None, Mapping, Stopper, Callable[[str, Mapping],
bool]] = None,
time_budget_s: Union[None, int, float, datetime.timedelta] = None,
config: Optional[Dict[str, Any]] = None,
resources_per_trial: Optional[Mapping[str, Union[float, int]]] = None,
num_samples: int = 1,
local_dir: Optional[str] = None,
search_alg: Optional[Union[Searcher, SearchAlgorithm]] = None,
scheduler: Optional[TrialScheduler] = None,
keep_checkpoints_num: Optional[int] = None,
checkpoint_score_attr: Optional[str] = None,
checkpoint_freq: int = 0,
checkpoint_at_end: bool = False,
verbose: Union[int, Verbosity] = Verbosity.V3_TRIAL_DETAILS,
progress_reporter: Optional[ProgressReporter] = None,
log_to_file: bool = False,
trial_name_creator: Optional[Callable[[Trial], str]] = None,
trial_dirname_creator: Optional[Callable[[Trial], str]] = None,
sync_config: Optional[SyncConfig] = None,
export_formats: Optional[Sequence] = None,
max_failures: int = 0,
fail_fast: bool = False,
restore: Optional[str] = None,
server_port: Optional[int] = None,
resume: bool = False,
queue_trials: bool = False,
reuse_actors: bool = False,
trial_executor: Optional[RayTrialExecutor] = None,
raise_on_failed_trial: bool = True,
callbacks: Optional[Sequence[Callback]] = None,
# Deprecated args
loggers=None,
ray_auto_init=None,
run_errored_only=None,
global_checkpoint_period=None,
with_server=None,
upload_dir=None,
sync_to_cloud=None,
sync_to_driver=None,
sync_on_checkpoint=None,
):
loggers: Optional[Sequence[Type[Logger]]] = None,
ray_auto_init: Optional = None,
run_errored_only: Optional = None,
global_checkpoint_period: Optional = None,
with_server: Optional = None,
upload_dir: Optional = None,
sync_to_cloud: Optional = None,
sync_to_driver: Optional = None,
sync_on_checkpoint: Optional = None,
) -> ExperimentAnalysis:
"""Executes training.
Examples:
@@ -458,18 +469,20 @@ def run(
default_mode=mode)
def run_experiments(experiments,
scheduler=None,
server_port=None,
verbose=Verbosity.V3_TRIAL_DETAILS,
progress_reporter=None,
resume=False,
queue_trials=False,
reuse_actors=False,
trial_executor=None,
raise_on_failed_trial=True,
concurrent=True,
callbacks=None):
def run_experiments(
experiments: Union[Experiment, Mapping, Sequence[Union[Experiment,
Mapping]]],
scheduler: Optional[TrialScheduler] = None,
server_port: Optional[int] = None,
verbose: Union[int, Verbosity] = Verbosity.V3_TRIAL_DETAILS,
progress_reporter: Optional[ProgressReporter] = None,
resume: bool = False,
queue_trials: bool = False,
reuse_actors: bool = False,
trial_executor: Optional[RayTrialExecutor] = None,
raise_on_failed_trial: bool = True,
concurrent: bool = True,
callbacks: Optional[Sequence[Callback]] = None):
"""Runs and blocks until all trials finish.
Examples: