[tune] resources_per_trial from trial_resources (#3580)

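Renames the `trial_resources` option to `resources_per_trial` because the old name caused user errors; `trial_resources` is still accepted but logs a deprecation warning.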
This commit is contained in:
Richard Liaw
2018-12-20 19:00:47 -08:00
committed by GitHub
parent a174a46e02
commit e046a5c767
14 changed files with 73 additions and 34 deletions
+3 -3
View File
@@ -100,9 +100,9 @@ def run(args, parser):
"run": args.run,
"checkpoint_freq": args.checkpoint_freq,
"local_dir": args.local_dir,
"trial_resources": (
args.trial_resources and
resources_to_json(args.trial_resources)),
"resources_per_trial": (
args.resources_per_trial and
resources_to_json(args.resources_per_trial)),
"stop": args.stop,
"config": dict(args.config, env=args.env),
"restore": args.restore,
+4 -3
View File
@@ -83,7 +83,7 @@ def make_parser(parser_creator=None, **kwargs):
help="Algorithm-specific configuration (e.g. env, hyperparams), "
"specified in JSON.")
parser.add_argument(
"--trial-resources",
"--resources-per-trial",
default=None,
type=json_to_resources,
help="Override the machine resources to allocate per trial, e.g. "
@@ -197,8 +197,9 @@ def create_trial_from_spec(spec, output_path, parser, **trial_kwargs):
args = parser.parse_args(to_argv(spec))
except SystemExit:
raise TuneError("Error parsing args, see above message", spec)
if "trial_resources" in spec:
trial_kwargs["resources"] = json_to_resources(spec["trial_resources"])
if "resources_per_trial" in spec:
trial_kwargs["resources"] = json_to_resources(
spec["resources_per_trial"])
return Trial(
# Submitting trial via server in py2.7 creates Unicode, which does not
# convert to string in a straightforward manner.
@@ -71,7 +71,7 @@ if __name__ == "__main__":
"training_iteration": 1 if args.smoke_test else 99999
},
"num_samples": 20,
"trial_resources": {
"resources_per_trial": {
"cpu": 1,
"gpu": 0
},
+1 -1
View File
@@ -175,7 +175,7 @@ if __name__ == '__main__':
"mean_accuracy": 0.98,
"training_iteration": 1 if args.smoke_test else 20
},
"trial_resources": {
"resources_per_trial": {
"cpu": 3
},
"run": "train_mnist",
@@ -187,7 +187,7 @@ if __name__ == '__main__':
"mean_accuracy": 0.95,
"training_iteration": 1 if args.smoke_test else 20,
},
"trial_resources": {
"resources_per_trial": {
"cpu": 3
},
"run": TrainMNIST,
@@ -181,7 +181,7 @@ if __name__ == "__main__":
train_spec = {
"run": Cifar10Model,
"trial_resources": {
"resources_per_trial": {
"cpu": 1,
"gpu": 1
},
+1 -1
View File
@@ -187,7 +187,7 @@ if __name__ == '__main__':
},
"run": "train_mnist",
"num_samples": 1 if args.smoke_test else 10,
"trial_resources": {
"resources_per_trial": {
"cpu": args.threads,
"gpu": 0.5 if args.use_gpu else 0
},
+36 -15
View File
@@ -15,6 +15,23 @@ from ray.tune.result import DEFAULT_RESULTS_DIR
logger = logging.getLogger(__name__)
def _raise_deprecation_note(deprecated, replacement, soft=False):
    """Notify the user that a parameter is deprecated.

    Arguments:
        deprecated (str): Deprecated parameter.
        replacement (str): Replacement parameter to use instead.
        soft (bool): If True, the note is only logged as a warning and
            execution continues. If False (the default), the deprecation
            is fatal and an exception is raised.

    Raises:
        DeprecationWarning: If `soft` is False.
    """
    error_msg = ("`{deprecated}` is deprecated. Please use `{replacement}`. "
                 "`{deprecated}` will be removed in future versions of "
                 "Ray.".format(deprecated=deprecated, replacement=replacement))
    if soft:
        # Soft deprecation: the caller keeps honoring the old parameter.
        logger.warning(error_msg)
    else:
        # Hard deprecation: refuse to continue with the old parameter.
        raise DeprecationWarning(error_msg)
class Experiment(object):
"""Tracks experiment specifications.
@@ -31,12 +48,10 @@ class Experiment(object):
config (dict): Algorithm-specific configuration for Tune variant
generation (e.g. env, hyperparams). Defaults to empty dict.
Custom search algorithms may ignore this.
trial_resources (dict): Machine resources to allocate per trial,
resources_per_trial (dict): Machine resources to allocate per trial,
e.g. ``{"cpu": 64, "gpu": 8}``. Note that GPUs will not be
assigned unless you specify them here. Defaults to 1 CPU and 0
GPUs in ``Trainable.default_resource_request()``.
repeat (int): Deprecated and will be removed in future versions of
Ray. Use `num_samples` instead.
num_samples (int): Number of times to sample from the
hyperparameter space. Defaults to 1. If `grid_search` is
provided as an argument, the grid will be repeated
@@ -62,6 +77,10 @@ class Experiment(object):
checkpointing is enabled. Defaults to 3.
restore (str): Path to checkpoint. Only makes sense to set if
running 1 trial. Defaults to None.
repeat: Deprecated and will be removed in future versions of
Ray. Use `num_samples` instead.
trial_resources: Deprecated and will be removed in future versions of
Ray. Use `resources_per_trial` instead.
Examples:
@@ -73,7 +92,7 @@ class Experiment(object):
>>> "alpha": tune.grid_search([0.2, 0.4, 0.6]),
>>> "beta": tune.grid_search([1, 2]),
>>> },
>>> trial_resources={
>>> resources_per_trial={
>>> "cpu": 1,
>>> "gpu": 0
>>> },
@@ -90,8 +109,7 @@ class Experiment(object):
run,
stop=None,
config=None,
trial_resources=None,
repeat=1,
resources_per_trial=None,
num_samples=1,
local_dir=None,
upload_dir=None,
@@ -101,15 +119,25 @@ class Experiment(object):
checkpoint_freq=0,
checkpoint_at_end=False,
max_failures=3,
restore=None):
restore=None,
repeat=None,
trial_resources=None):
validate_sync_function(sync_function)
if sync_function:
assert upload_dir, "Need `upload_dir` if sync_function given."
if repeat:
_raise_deprecation_note("repeat", "num_samples", soft=False)
if trial_resources:
_raise_deprecation_note(
"trial_resources", "resources_per_trial", soft=True)
resources_per_trial = trial_resources
spec = {
"run": self._register_if_needed(run),
"stop": stop or {},
"config": config or {},
"trial_resources": trial_resources,
"resources_per_trial": resources_per_trial,
"num_samples": num_samples,
"local_dir": local_dir or DEFAULT_RESULTS_DIR,
"upload_dir": upload_dir or "", # argparse converts None to "null"
@@ -136,13 +164,6 @@ class Experiment(object):
if "run" not in spec:
raise TuneError("No trainable specified!")
if "repeat" in spec:
raise DeprecationWarning("The parameter `repeat` is deprecated; \
converting to `num_samples`. `repeat` will be removed in \
future versions of Ray.")
spec["num_samples"] = spec["repeat"]
del spec["repeat"]
# Special case the `env` param for RLlib by automatically
# moving it into the `config` section.
if "env" in spec:
+2 -1
View File
@@ -100,7 +100,8 @@ class UnifiedLogger(Logger):
try:
self._loggers.append(cls(self.config, self.logdir, self.uri))
except Exception:
logger.exception("Could not instantiate {} - skipping.")
logger.exception("Could not instantiate {} - skipping.".format(
str(cls)))
self._log_syncer = get_syncer(
self.logdir, self.uri, sync_function=self._sync_function)
+1 -1
View File
@@ -96,7 +96,7 @@ _MAX_RESOLUTION_PASSES = 20
def format_vars(resolved_vars):
out = []
for path, value in sorted(resolved_vars.items()):
if path[0] in ["run", "env", "trial_resources"]:
if path[0] in ["run", "env", "resources_per_trial"]:
continue # TrialRunner already has these in the experiment_tag
pieces = []
last_string = True
+17 -1
View File
@@ -294,7 +294,7 @@ class TrainableFunctionApiTest(unittest.TestCase):
run_experiments({
"foo": {
"run": "PPO",
"trial_resources": {
"resources_per_trial": {
"asdf": 1
}
}
@@ -681,6 +681,22 @@ class RunExperimentTest(unittest.TestCase):
self.assertEqual(trial.status, Trial.TERMINATED)
self.assertTrue(trial.has_checkpoint())
def testDeprecatedResources(self):
    """Run a spec using the deprecated `trial_resources` key.

    Every trial returned by `run_experiments` must end up TERMINATED.
    """

    class train(Trainable):
        def _train(self):
            # Report completion on the first iteration.
            return {"timesteps_this_iter": 1, "done": True}

    experiment_spec = {
        "foo": {
            "run": train,
            "trial_resources": {
                "cpu": 1
            }
        }
    }
    for trial in run_experiments(experiment_spec):
        self.assertEqual(trial.status, Trial.TERMINATED)
def testCustomLogger(self):
class CustomLogger(Logger):
def on_result(self, result):
+1 -1
View File
@@ -65,7 +65,7 @@ class TuneServerSuite(unittest.TestCase):
"stop": {
"training_iteration": 3
},
"trial_resources": {
"resources_per_trial": {
'cpu': 1,
'gpu': 1
},