mirror of
https://github.com/wassname/ray.git
synced 2026-07-01 06:34:08 +08:00
[tune] Disallow setting resources_per_trial when it is already configured (#4880)
* disallow it * import fix * fix example * fix test * fix tests * Update mock.py * fix * make less convoluted * fix tests
This commit is contained in:
@@ -20,6 +20,10 @@ class _MockTrainer(Trainer):
|
||||
"num_workers": 0,
|
||||
})
|
||||
|
||||
@classmethod
|
||||
def default_resource_request(cls, config):
|
||||
return None
|
||||
|
||||
def _init(self, config, env_creator):
|
||||
self.info = None
|
||||
self.restored = False
|
||||
|
||||
@@ -40,13 +40,9 @@ def my_train_fn(config, reporter):
|
||||
|
||||
if __name__ == "__main__":
|
||||
ray.init()
|
||||
tune.run(
|
||||
my_train_fn,
|
||||
resources_per_trial={
|
||||
"cpu": 1,
|
||||
},
|
||||
config={
|
||||
"lr": 0.01,
|
||||
"num_workers": 0,
|
||||
},
|
||||
)
|
||||
config = {
|
||||
"lr": 0.01,
|
||||
"num_workers": 0,
|
||||
}
|
||||
resources = PPOTrainer.default_resource_request(config).to_json()
|
||||
tune.run(my_train_fn, resources_per_trial=resources, config=config)
|
||||
|
||||
@@ -21,7 +21,6 @@ from ray.tune.result import (DEFAULT_RESULTS_DIR, TIME_THIS_ITER_S,
|
||||
TIMESTEPS_THIS_ITER, DONE, TIMESTEPS_TOTAL,
|
||||
EPISODES_THIS_ITER, EPISODES_TOTAL,
|
||||
TRAINING_ITERATION, RESULT_DUPLICATE)
|
||||
from ray.tune.trial import Resources
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -96,7 +95,7 @@ class Trainable(object):
|
||||
allocation, so the user does not need to.
|
||||
"""
|
||||
|
||||
return Resources(cpu=1, gpu=0)
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def resource_help(cls, config):
|
||||
|
||||
@@ -139,6 +139,9 @@ class Resources(
|
||||
return Resources(cpu, gpu, extra_cpu, extra_gpu, new_custom_res,
|
||||
extra_custom_res)
|
||||
|
||||
def to_json(self):
|
||||
return resources_to_json(self)
|
||||
|
||||
|
||||
def json_to_resources(data):
|
||||
if data is None or data == "null":
|
||||
@@ -275,9 +278,20 @@ class Trial(object):
|
||||
self.config = config or {}
|
||||
self.local_dir = local_dir # This remains unexpanded for syncing.
|
||||
self.experiment_tag = experiment_tag
|
||||
self.resources = (
|
||||
resources
|
||||
or self._get_trainable_cls().default_resource_request(self.config))
|
||||
trainable_cls = self._get_trainable_cls()
|
||||
if trainable_cls and hasattr(trainable_cls,
|
||||
"default_resource_request"):
|
||||
default_resources = trainable_cls.default_resource_request(
|
||||
self.config)
|
||||
if default_resources:
|
||||
if resources:
|
||||
raise ValueError(
|
||||
"Resources for {} have been automatically set to {} "
|
||||
"by its `default_resource_request()` method. Please "
|
||||
"clear the `resources_per_trial` option.".format(
|
||||
trainable_cls, default_resources))
|
||||
resources = default_resources
|
||||
self.resources = resources or Resources(cpu=1, gpu=0)
|
||||
self.stopping_criterion = stopping_criterion or {}
|
||||
self.upload_dir = upload_dir
|
||||
self.loggers = loggers
|
||||
|
||||
Reference in New Issue
Block a user