diff --git a/python/ray/tune/ray_trial_executor.py b/python/ray/tune/ray_trial_executor.py index c557aaebb..916fa1cbf 100644 --- a/python/ray/tune/ray_trial_executor.py +++ b/python/ray/tune/ray_trial_executor.py @@ -93,7 +93,7 @@ class RayTrialExecutor(TrialExecutor): memory=trial.resources.memory, object_store_memory=trial.resources.object_store_memory, resources=trial.resources.custom_resources)( - trial._get_trainable_cls()) + trial.get_trainable_cls()) trial.init_logger() # We checkpoint metadata here to try mitigating logdir duplication @@ -622,6 +622,11 @@ class RayTrialExecutor(TrialExecutor): trial.runner.export_model.remote(trial.export_formats)) return {} + def has_gpus(self): + if self._resources_initialized: + self._update_avail_resources() + return self._avail_resources.gpu > 0 + def _to_gb(n_bytes): return round(n_bytes / (1024**3), 2) diff --git a/python/ray/tune/registry.py b/python/ray/tune/registry.py index 4c3c09b90..80d78e70f 100644 --- a/python/ray/tune/registry.py +++ b/python/ray/tune/registry.py @@ -7,6 +7,7 @@ from types import FunctionType import ray import ray.cloudpickle as pickle + from ray.experimental.internal_kv import _internal_kv_initialized, \ _internal_kv_get, _internal_kv_put @@ -23,6 +24,24 @@ KNOWN_CATEGORIES = [ logger = logging.getLogger(__name__) +def has_trainable(trainable_name): + return _global_registry.contains(TRAINABLE_CLASS, trainable_name) + + +def get_trainable_cls(trainable_name): + validate_trainable(trainable_name) + return _global_registry.get(TRAINABLE_CLASS, trainable_name) + + +def validate_trainable(trainable_name): + if not has_trainable(trainable_name): + # Make sure rllib agents are registered + from ray import rllib # noqa: F401 + from ray.tune.error import TuneError + if not has_trainable(trainable_name): + raise TuneError("Unknown trainable: " + trainable_name) + + def register_trainable(name, trainable): """Register a trainable function or class. diff --git a/python/ray/tune/trial.py b/python/ray/tune/trial.py index d31a5dfb0..840533c7e 100644 --- a/python/ray/tune/trial.py +++ b/python/ray/tune/trial.py @@ -10,13 +10,12 @@ import uuid import time import tempfile import os -import ray from ray.tune import TuneError from ray.tune.logger import pretty_print, UnifiedLogger # NOTE(rkn): We import ray.tune.registry here instead of importing the names we # need because there are cyclic imports that may cause specific names to not # have been defined yet. See https://github.com/ray-project/ray/issues/1716. -import ray.tune.registry +from ray.tune.registry import get_trainable_cls, validate_trainable from ray.tune.result import DEFAULT_RESULTS_DIR, DONE, TRAINING_ITERATION from ray.utils import binary_to_hex, hex_to_binary from ray.tune.resources import Resources, json_to_resources, resources_to_json @@ -30,11 +29,6 @@ def date_str(): return datetime.today().strftime("%Y-%m-%d_%H-%M-%S") -def has_trainable(trainable_name): - return ray.tune.registry._global_registry.contains( - ray.tune.registry.TRAINABLE_CLASS, trainable_name) - - class Checkpoint(object): """Describes a checkpoint of trial state. @@ -126,7 +120,7 @@ class Trial(object): in ray.tune.config_parser. """ - Trial._registration_check(trainable_name) + validate_trainable(trainable_name) # Trial config self.trainable_name = trainable_name self.trial_id = Trial.generate_id() if trial_id is None else trial_id @@ -136,7 +130,7 @@ class Trial(object): #: Parameters that Tune varies across searches. self.evaluated_params = evaluated_params or {} self.experiment_tag = experiment_tag - trainable_cls = self._get_trainable_cls() + trainable_cls = self.get_trainable_cls() if trainable_cls and hasattr(trainable_cls, "default_resource_request"): default_resources = trainable_cls.default_resource_request( @@ -202,14 +196,6 @@ class Trial(object): if trial_name_creator: self.custom_trial_name = trial_name_creator(self) - @classmethod - def _registration_check(cls, trainable_name): - if not has_trainable(trainable_name): - # Make sure rllib agents are registered - from ray import rllib # noqa: F401 - if not has_trainable(trainable_name): - raise TuneError("Unknown trainable: " + trainable_name) - @classmethod def generate_id(cls): return str(uuid.uuid1().hex)[:8] @@ -363,9 +349,8 @@ class Trial(object): return True return False - def _get_trainable_cls(self): - return ray.tune.registry._global_registry.get( - ray.tune.registry.TRAINABLE_CLASS, self.trainable_name) + def get_trainable_cls(self): + return get_trainable_cls(self.trainable_name) def set_verbose(self, verbose): self.verbose = verbose @@ -430,6 +415,6 @@ class Trial(object): state[key] = cloudpickle.loads(hex_to_binary(state[key])) self.__dict__.update(state) - Trial._registration_check(self.trainable_name) + validate_trainable(self.trainable_name) if logger_started: self.init_logger() diff --git a/python/ray/tune/trial_executor.py b/python/ray/tune/trial_executor.py index 9e53e7108..27de24f5a 100644 --- a/python/ray/tune/trial_executor.py +++ b/python/ray/tune/trial_executor.py @@ -226,3 +226,7 @@ class TrialExecutor(object): """ raise NotImplementedError("Subclasses of TrialExecutor must provide " "export_trial_if_needed() method") + + def has_gpus(self): + """Returns True if GPUs are detected on the cluster.""" + return None diff --git a/python/ray/tune/trial_runner.py b/python/ray/tune/trial_runner.py index 31c2c87bd..6814888e5 100644 --- a/python/ray/tune/trial_runner.py +++ b/python/ray/tune/trial_runner.py @@ -346,7 +346,7 @@ class TrialRunner(object): "up. {}").format( trial.resources.summary_string(), self.trial_executor.resource_string(), - trial._get_trainable_cls().resource_help( + trial.get_trainable_cls().resource_help( trial.config))) elif trial.status == Trial.PAUSED: raise TuneError( diff --git a/python/ray/tune/tune.py b/python/ray/tune/tune.py index 90f33deb0..893d22201 100644 --- a/python/ray/tune/tune.py +++ b/python/ray/tune/tune.py @@ -10,7 +10,9 @@ from ray.tune.experiment import convert_to_experiment_list, Experiment from ray.tune.analysis import ExperimentAnalysis from ray.tune.suggest import BasicVariantGenerator from ray.tune.trial import Trial, DEBUG_PRINT_INTERVAL +from ray.tune.trainable import Trainable from ray.tune.ray_trial_executor import RayTrialExecutor +from ray.tune.registry import get_trainable_cls from ray.tune.syncer import wait_for_sync from ray.tune.trial_runner import TrialRunner from ray.tune.progress_reporter import CLIReporter, JupyterNotebookReporter @@ -42,6 +44,13 @@ def _make_scheduler(args): args.scheduler, _SCHEDULERS.keys())) +def _check_default_resources_override(run_identifier): + trainable_cls = get_trainable_cls(run_identifier) + return hasattr(trainable_cls, "default_resource_request") and ( + trainable_cls.default_resource_request.__code__ != + Trainable.default_resource_request.__code__) + + def run(run_or_experiment, name=None, stop=None, @@ -250,6 +259,24 @@ def run(run_or_experiment, else: reporter = CLIReporter() + # User Warning for GPUs + if trial_executor.has_gpus(): + if isinstance(resources_per_trial, + dict) and "gpu" in resources_per_trial: + # "gpu" is manually set. + pass + elif _check_default_resources_override(run_identifier): + # "default_resources" is manually overriden. + pass + else: + logger.warning("Tune detects GPUs, but no trials are using GPUs. " + "To enable trials to use GPUs, set " + "tune.run(resources_per_trial={'gpu': 1}...) " + "which allows Tune to expose 1 GPU to each trial. " + "You can also override " + "`Trainable.default_resource_request` if using the " + "Trainable API.") + last_debug = 0 while not runner.is_finished(): runner.step()