mirror of
https://github.com/wassname/ray.git
synced 2026-07-02 12:31:42 +08:00
[tune] resources_per_trial from trial_resources (#3580)
Renaming variable due to user errors.
This commit is contained in:
@@ -100,9 +100,9 @@ def run(args, parser):
|
||||
"run": args.run,
|
||||
"checkpoint_freq": args.checkpoint_freq,
|
||||
"local_dir": args.local_dir,
|
||||
"trial_resources": (
|
||||
args.trial_resources and
|
||||
resources_to_json(args.trial_resources)),
|
||||
"resources_per_trial": (
|
||||
args.resources_per_trial and
|
||||
resources_to_json(args.resources_per_trial)),
|
||||
"stop": args.stop,
|
||||
"config": dict(args.config, env=args.env),
|
||||
"restore": args.restore,
|
||||
|
||||
@@ -83,7 +83,7 @@ def make_parser(parser_creator=None, **kwargs):
|
||||
help="Algorithm-specific configuration (e.g. env, hyperparams), "
|
||||
"specified in JSON.")
|
||||
parser.add_argument(
|
||||
"--trial-resources",
|
||||
"--resources-per-trial",
|
||||
default=None,
|
||||
type=json_to_resources,
|
||||
help="Override the machine resources to allocate per trial, e.g. "
|
||||
@@ -197,8 +197,9 @@ def create_trial_from_spec(spec, output_path, parser, **trial_kwargs):
|
||||
args = parser.parse_args(to_argv(spec))
|
||||
except SystemExit:
|
||||
raise TuneError("Error parsing args, see above message", spec)
|
||||
if "trial_resources" in spec:
|
||||
trial_kwargs["resources"] = json_to_resources(spec["trial_resources"])
|
||||
if "resources_per_trial" in spec:
|
||||
trial_kwargs["resources"] = json_to_resources(
|
||||
spec["resources_per_trial"])
|
||||
return Trial(
|
||||
# Submitting trial via server in py2.7 creates Unicode, which does not
|
||||
# convert to string in a straightforward manner.
|
||||
|
||||
@@ -71,7 +71,7 @@ if __name__ == "__main__":
|
||||
"training_iteration": 1 if args.smoke_test else 99999
|
||||
},
|
||||
"num_samples": 20,
|
||||
"trial_resources": {
|
||||
"resources_per_trial": {
|
||||
"cpu": 1,
|
||||
"gpu": 0
|
||||
},
|
||||
|
||||
@@ -175,7 +175,7 @@ if __name__ == '__main__':
|
||||
"mean_accuracy": 0.98,
|
||||
"training_iteration": 1 if args.smoke_test else 20
|
||||
},
|
||||
"trial_resources": {
|
||||
"resources_per_trial": {
|
||||
"cpu": 3
|
||||
},
|
||||
"run": "train_mnist",
|
||||
|
||||
@@ -187,7 +187,7 @@ if __name__ == '__main__':
|
||||
"mean_accuracy": 0.95,
|
||||
"training_iteration": 1 if args.smoke_test else 20,
|
||||
},
|
||||
"trial_resources": {
|
||||
"resources_per_trial": {
|
||||
"cpu": 3
|
||||
},
|
||||
"run": TrainMNIST,
|
||||
|
||||
@@ -181,7 +181,7 @@ if __name__ == "__main__":
|
||||
|
||||
train_spec = {
|
||||
"run": Cifar10Model,
|
||||
"trial_resources": {
|
||||
"resources_per_trial": {
|
||||
"cpu": 1,
|
||||
"gpu": 1
|
||||
},
|
||||
|
||||
@@ -187,7 +187,7 @@ if __name__ == '__main__':
|
||||
},
|
||||
"run": "train_mnist",
|
||||
"num_samples": 1 if args.smoke_test else 10,
|
||||
"trial_resources": {
|
||||
"resources_per_trial": {
|
||||
"cpu": args.threads,
|
||||
"gpu": 0.5 if args.use_gpu else 0
|
||||
},
|
||||
|
||||
@@ -15,6 +15,23 @@ from ray.tune.result import DEFAULT_RESULTS_DIR
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _raise_deprecation_note(deprecated, replacement, soft=False):
|
||||
"""User notification for deprecated parameter.
|
||||
|
||||
Arguments:
|
||||
deprecated (str): Deprecated parameter.
|
||||
replacement (str): Replacement parameter to use instead.
|
||||
soft (bool): Log a warning if True; raise DeprecationWarning if False.
|
||||
"""
|
||||
error_msg = ("`{deprecated}` is deprecated. Please use `{replacement}`. "
|
||||
"`{deprecated}` will be removed in future versions of "
|
||||
"Ray.".format(deprecated=deprecated, replacement=replacement))
|
||||
if soft:
|
||||
logger.warning(error_msg)
|
||||
else:
|
||||
raise DeprecationWarning(error_msg)
|
||||
|
||||
|
||||
class Experiment(object):
|
||||
"""Tracks experiment specifications.
|
||||
|
||||
@@ -31,12 +48,10 @@ class Experiment(object):
|
||||
config (dict): Algorithm-specific configuration for Tune variant
|
||||
generation (e.g. env, hyperparams). Defaults to empty dict.
|
||||
Custom search algorithms may ignore this.
|
||||
trial_resources (dict): Machine resources to allocate per trial,
|
||||
resources_per_trial (dict): Machine resources to allocate per trial,
|
||||
e.g. ``{"cpu": 64, "gpu": 8}``. Note that GPUs will not be
|
||||
assigned unless you specify them here. Defaults to 1 CPU and 0
|
||||
GPUs in ``Trainable.default_resource_request()``.
|
||||
repeat (int): Deprecated and will be removed in future versions of
|
||||
Ray. Use `num_samples` instead.
|
||||
num_samples (int): Number of times to sample from the
|
||||
hyperparameter space. Defaults to 1. If `grid_search` is
|
||||
provided as an argument, the grid will be repeated
|
||||
@@ -62,6 +77,10 @@ class Experiment(object):
|
||||
checkpointing is enabled. Defaults to 3.
|
||||
restore (str): Path to checkpoint. Only makes sense to set if
|
||||
running 1 trial. Defaults to None.
|
||||
repeat: Deprecated and will be removed in future versions of
|
||||
Ray. Use `num_samples` instead.
|
||||
trial_resources: Deprecated and will be removed in future versions of
|
||||
Ray. Use `resources_per_trial` instead.
|
||||
|
||||
|
||||
Examples:
|
||||
@@ -73,7 +92,7 @@ class Experiment(object):
|
||||
>>> "alpha": tune.grid_search([0.2, 0.4, 0.6]),
|
||||
>>> "beta": tune.grid_search([1, 2]),
|
||||
>>> },
|
||||
>>> trial_resources={
|
||||
>>> resources_per_trial={
|
||||
>>> "cpu": 1,
|
||||
>>> "gpu": 0
|
||||
>>> },
|
||||
@@ -90,8 +109,7 @@ class Experiment(object):
|
||||
run,
|
||||
stop=None,
|
||||
config=None,
|
||||
trial_resources=None,
|
||||
repeat=1,
|
||||
resources_per_trial=None,
|
||||
num_samples=1,
|
||||
local_dir=None,
|
||||
upload_dir=None,
|
||||
@@ -101,15 +119,25 @@ class Experiment(object):
|
||||
checkpoint_freq=0,
|
||||
checkpoint_at_end=False,
|
||||
max_failures=3,
|
||||
restore=None):
|
||||
restore=None,
|
||||
repeat=None,
|
||||
trial_resources=None):
|
||||
validate_sync_function(sync_function)
|
||||
if sync_function:
|
||||
assert upload_dir, "Need `upload_dir` if sync_function given."
|
||||
|
||||
if repeat:
|
||||
_raise_deprecation_note("repeat", "num_samples", soft=False)
|
||||
if trial_resources:
|
||||
_raise_deprecation_note(
|
||||
"trial_resources", "resources_per_trial", soft=True)
|
||||
resources_per_trial = trial_resources
|
||||
|
||||
spec = {
|
||||
"run": self._register_if_needed(run),
|
||||
"stop": stop or {},
|
||||
"config": config or {},
|
||||
"trial_resources": trial_resources,
|
||||
"resources_per_trial": resources_per_trial,
|
||||
"num_samples": num_samples,
|
||||
"local_dir": local_dir or DEFAULT_RESULTS_DIR,
|
||||
"upload_dir": upload_dir or "", # argparse converts None to "null"
|
||||
@@ -136,13 +164,6 @@ class Experiment(object):
|
||||
if "run" not in spec:
|
||||
raise TuneError("No trainable specified!")
|
||||
|
||||
if "repeat" in spec:
|
||||
raise DeprecationWarning("The parameter `repeat` is deprecated; \
|
||||
converting to `num_samples`. `repeat` will be removed in \
|
||||
future versions of Ray.")
|
||||
spec["num_samples"] = spec["repeat"]
|
||||
del spec["repeat"]
|
||||
|
||||
# Special case the `env` param for RLlib by automatically
|
||||
# moving it into the `config` section.
|
||||
if "env" in spec:
|
||||
|
||||
@@ -100,7 +100,8 @@ class UnifiedLogger(Logger):
|
||||
try:
|
||||
self._loggers.append(cls(self.config, self.logdir, self.uri))
|
||||
except Exception:
|
||||
logger.exception("Could not instantiate {} - skipping.")
|
||||
logger.exception("Could not instantiate {} - skipping.".format(
|
||||
str(cls)))
|
||||
self._log_syncer = get_syncer(
|
||||
self.logdir, self.uri, sync_function=self._sync_function)
|
||||
|
||||
|
||||
@@ -96,7 +96,7 @@ _MAX_RESOLUTION_PASSES = 20
|
||||
def format_vars(resolved_vars):
|
||||
out = []
|
||||
for path, value in sorted(resolved_vars.items()):
|
||||
if path[0] in ["run", "env", "trial_resources"]:
|
||||
if path[0] in ["run", "env", "resources_per_trial"]:
|
||||
continue # TrialRunner already has these in the experiment_tag
|
||||
pieces = []
|
||||
last_string = True
|
||||
|
||||
@@ -294,7 +294,7 @@ class TrainableFunctionApiTest(unittest.TestCase):
|
||||
run_experiments({
|
||||
"foo": {
|
||||
"run": "PPO",
|
||||
"trial_resources": {
|
||||
"resources_per_trial": {
|
||||
"asdf": 1
|
||||
}
|
||||
}
|
||||
@@ -681,6 +681,22 @@ class RunExperimentTest(unittest.TestCase):
|
||||
self.assertEqual(trial.status, Trial.TERMINATED)
|
||||
self.assertTrue(trial.has_checkpoint())
|
||||
|
||||
def testDeprecatedResources(self):
|
||||
class train(Trainable):
|
||||
def _train(self):
|
||||
return {"timesteps_this_iter": 1, "done": True}
|
||||
|
||||
trials = run_experiments({
|
||||
"foo": {
|
||||
"run": train,
|
||||
"trial_resources": {
|
||||
"cpu": 1
|
||||
}
|
||||
}
|
||||
})
|
||||
for trial in trials:
|
||||
self.assertEqual(trial.status, Trial.TERMINATED)
|
||||
|
||||
def testCustomLogger(self):
|
||||
class CustomLogger(Logger):
|
||||
def on_result(self, result):
|
||||
|
||||
@@ -65,7 +65,7 @@ class TuneServerSuite(unittest.TestCase):
|
||||
"stop": {
|
||||
"training_iteration": 3
|
||||
},
|
||||
"trial_resources": {
|
||||
"resources_per_trial": {
|
||||
'cpu': 1,
|
||||
'gpu': 1
|
||||
},
|
||||
|
||||
Reference in New Issue
Block a user