mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 20:22:39 +08:00
[tune] add type hints to tune.run(), fix abstract methods of ProgressReporter (#13684)
This commit is contained in:
@@ -57,6 +57,13 @@ class ProgressReporter:
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def set_search_properties(self, metric: Optional[str],
|
||||
mode: Optional[str]):
|
||||
return True
|
||||
|
||||
def set_total_samples(self, total_samples: int):
|
||||
pass
|
||||
|
||||
|
||||
class TuneReporterBase(ProgressReporter):
|
||||
"""Abstract base class for the default Tune reporters.
|
||||
|
||||
+79
-66
@@ -1,25 +1,35 @@
|
||||
from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Type, \
|
||||
Union
|
||||
|
||||
import datetime
|
||||
import logging
|
||||
import sys
|
||||
import time
|
||||
|
||||
from ray.tune.error import TuneError
|
||||
from ray.tune.experiment import convert_to_experiment_list, Experiment
|
||||
from ray.tune.analysis import ExperimentAnalysis
|
||||
from ray.tune.suggest import BasicVariantGenerator, SearchGenerator
|
||||
from ray.tune.callback import Callback
|
||||
from ray.tune.error import TuneError
|
||||
from ray.tune.experiment import Experiment, convert_to_experiment_list
|
||||
from ray.tune.logger import Logger
|
||||
from ray.tune.progress_reporter import CLIReporter, JupyterNotebookReporter, \
|
||||
ProgressReporter
|
||||
from ray.tune.ray_trial_executor import RayTrialExecutor
|
||||
from ray.tune.registry import get_trainable_cls
|
||||
from ray.tune.stopper import Stopper
|
||||
from ray.tune.suggest import BasicVariantGenerator, SearchAlgorithm, \
|
||||
SearchGenerator
|
||||
from ray.tune.suggest.suggestion import Searcher
|
||||
from ray.tune.suggest.variant_generator import has_unresolved_values
|
||||
from ray.tune.trial import Trial
|
||||
from ray.tune.syncer import SyncConfig, set_sync_periods, wait_for_sync
|
||||
from ray.tune.trainable import Trainable
|
||||
from ray.tune.ray_trial_executor import RayTrialExecutor
|
||||
from ray.tune.utils.callback import create_default_callbacks
|
||||
from ray.tune.registry import get_trainable_cls
|
||||
from ray.tune.syncer import wait_for_sync, set_sync_periods, \
|
||||
SyncConfig
|
||||
from ray.tune.trial import Trial
|
||||
from ray.tune.trial_runner import TrialRunner
|
||||
from ray.tune.progress_reporter import CLIReporter, JupyterNotebookReporter
|
||||
from ray.tune.schedulers import FIFOScheduler
|
||||
from ray.tune.utils.callback import create_default_callbacks
|
||||
from ray.tune.utils.log import Verbosity, has_verbosity, set_verbosity
|
||||
|
||||
# Must come last to avoid circular imports
|
||||
from ray.tune.schedulers import FIFOScheduler, TrialScheduler
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
@@ -55,50 +65,51 @@ def _report_progress(runner, reporter, done=False):
|
||||
|
||||
|
||||
def run(
|
||||
run_or_experiment,
|
||||
name=None,
|
||||
metric=None,
|
||||
mode=None,
|
||||
stop=None,
|
||||
time_budget_s=None,
|
||||
config=None,
|
||||
resources_per_trial=None,
|
||||
num_samples=1,
|
||||
local_dir=None,
|
||||
search_alg=None,
|
||||
scheduler=None,
|
||||
keep_checkpoints_num=None,
|
||||
checkpoint_score_attr=None,
|
||||
checkpoint_freq=0,
|
||||
checkpoint_at_end=False,
|
||||
verbose=Verbosity.V3_TRIAL_DETAILS,
|
||||
progress_reporter=None,
|
||||
log_to_file=False,
|
||||
trial_name_creator=None,
|
||||
trial_dirname_creator=None,
|
||||
sync_config=None,
|
||||
export_formats=None,
|
||||
max_failures=0,
|
||||
fail_fast=False,
|
||||
restore=None,
|
||||
server_port=None,
|
||||
resume=False,
|
||||
queue_trials=False,
|
||||
reuse_actors=False,
|
||||
trial_executor=None,
|
||||
raise_on_failed_trial=True,
|
||||
callbacks=None,
|
||||
run_or_experiment: Union[str, Callable, Type],
|
||||
name: Optional[str] = None,
|
||||
metric: Optional[str] = None,
|
||||
mode: Optional[str] = None,
|
||||
stop: Union[None, Mapping, Stopper, Callable[[str, Mapping],
|
||||
bool]] = None,
|
||||
time_budget_s: Union[None, int, float, datetime.timedelta] = None,
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
resources_per_trial: Optional[Mapping[str, Union[float, int]]] = None,
|
||||
num_samples: int = 1,
|
||||
local_dir: Optional[str] = None,
|
||||
search_alg: Optional[Union[Searcher, SearchAlgorithm]] = None,
|
||||
scheduler: Optional[TrialScheduler] = None,
|
||||
keep_checkpoints_num: Optional[int] = None,
|
||||
checkpoint_score_attr: Optional[str] = None,
|
||||
checkpoint_freq: int = 0,
|
||||
checkpoint_at_end: bool = False,
|
||||
verbose: Union[int, Verbosity] = Verbosity.V3_TRIAL_DETAILS,
|
||||
progress_reporter: Optional[ProgressReporter] = None,
|
||||
log_to_file: bool = False,
|
||||
trial_name_creator: Optional[Callable[[Trial], str]] = None,
|
||||
trial_dirname_creator: Optional[Callable[[Trial], str]] = None,
|
||||
sync_config: Optional[SyncConfig] = None,
|
||||
export_formats: Optional[Sequence] = None,
|
||||
max_failures: int = 0,
|
||||
fail_fast: bool = False,
|
||||
restore: Optional[str] = None,
|
||||
server_port: Optional[int] = None,
|
||||
resume: bool = False,
|
||||
queue_trials: bool = False,
|
||||
reuse_actors: bool = False,
|
||||
trial_executor: Optional[RayTrialExecutor] = None,
|
||||
raise_on_failed_trial: bool = True,
|
||||
callbacks: Optional[Sequence[Callback]] = None,
|
||||
# Deprecated args
|
||||
loggers=None,
|
||||
ray_auto_init=None,
|
||||
run_errored_only=None,
|
||||
global_checkpoint_period=None,
|
||||
with_server=None,
|
||||
upload_dir=None,
|
||||
sync_to_cloud=None,
|
||||
sync_to_driver=None,
|
||||
sync_on_checkpoint=None,
|
||||
):
|
||||
loggers: Optional[Sequence[Type[Logger]]] = None,
|
||||
ray_auto_init: Optional = None,
|
||||
run_errored_only: Optional = None,
|
||||
global_checkpoint_period: Optional = None,
|
||||
with_server: Optional = None,
|
||||
upload_dir: Optional = None,
|
||||
sync_to_cloud: Optional = None,
|
||||
sync_to_driver: Optional = None,
|
||||
sync_on_checkpoint: Optional = None,
|
||||
) -> ExperimentAnalysis:
|
||||
"""Executes training.
|
||||
|
||||
Examples:
|
||||
@@ -458,18 +469,20 @@ def run(
|
||||
default_mode=mode)
|
||||
|
||||
|
||||
def run_experiments(experiments,
|
||||
scheduler=None,
|
||||
server_port=None,
|
||||
verbose=Verbosity.V3_TRIAL_DETAILS,
|
||||
progress_reporter=None,
|
||||
resume=False,
|
||||
queue_trials=False,
|
||||
reuse_actors=False,
|
||||
trial_executor=None,
|
||||
raise_on_failed_trial=True,
|
||||
concurrent=True,
|
||||
callbacks=None):
|
||||
def run_experiments(
|
||||
experiments: Union[Experiment, Mapping, Sequence[Union[Experiment,
|
||||
Mapping]]],
|
||||
scheduler: Optional[TrialScheduler] = None,
|
||||
server_port: Optional[int] = None,
|
||||
verbose: Union[int, Verbosity] = Verbosity.V3_TRIAL_DETAILS,
|
||||
progress_reporter: Optional[ProgressReporter] = None,
|
||||
resume: bool = False,
|
||||
queue_trials: bool = False,
|
||||
reuse_actors: bool = False,
|
||||
trial_executor: Optional[RayTrialExecutor] = None,
|
||||
raise_on_failed_trial: bool = True,
|
||||
concurrent: bool = True,
|
||||
callbacks: Optional[Sequence[Callback]] = None):
|
||||
"""Runs and blocks until all trials finish.
|
||||
|
||||
Examples:
|
||||
|
||||
Reference in New Issue
Block a user