[tune] Update trainable docs and support hparams (#5558)

This commit is contained in:
Richard Liaw
2019-09-04 12:44:42 -07:00
committed by Eric Liang
parent 3ea9062419
commit 34f6d2fc5c
12 changed files with 269 additions and 113 deletions
+4 -1
View File
@@ -4,6 +4,9 @@
Note that this requires a cluster with at least 8 GPUs in order for all trials
to run concurrently, otherwise PBT will round-robin train the trials which
is less efficient (or you can set {"gpu": 0} to use CPUs for SGD instead).
Note that Tune in general does not need 8 GPUs, and this is just a more
computationally demanding example.
"""
from __future__ import absolute_import
@@ -51,9 +54,9 @@ if __name__ == "__main__":
name="pbt_humanoid_test",
scheduler=pbt,
**{
"env": "Humanoid-v1",
"num_samples": 8,
"config": {
"env": "Humanoid-v1",
"kl_coeff": 1.0,
"num_workers": 8,
"num_gpus": 1,
+30 -1
View File
@@ -136,6 +136,7 @@ class JsonLogger(Logger):
def tf2_compat_logger(config, logdir):
"""Chooses TensorBoard logger depending on imported TF version."""
global tf
if "RLLIB_TEST_NO_TF_IMPORT" in os.environ:
logger.warning("Not importing TensorFlow for test purposes")
@@ -153,6 +154,16 @@ def tf2_compat_logger(config, logdir):
class TF2Logger(Logger):
"""TensorBoard Logger for TF version >= 1.14.
Automatically flattens nested dicts to show on TensorBoard:
{"a": {"b": 1, "c": 2}} -> {"a/b": 1, "a/c": 2}
If you need to do more advanced logging, it is recommended
to use a Summary Writer in the Trainable yourself.
"""
def _init(self):
self._file_writer = None
@@ -202,6 +213,16 @@ def to_tf_values(result, path):
class TFLogger(Logger):
"""TensorBoard Logger for TF version < 1.14.
Automatically flattens nested dicts to show on TensorBoard:
{"a": {"b": 1, "c": 2}} -> {"a/b": 1, "a/c": 2}
If you need to do more advanced logging, it is recommended
to use a Summary Writer in the Trainable yourself.
"""
def _init(self):
logger.info("Initializing TFLogger instead of TF2Logger.")
self._file_writer = tf.compat.v1.summary.FileWriter(self.logdir)
@@ -232,9 +253,17 @@ class TFLogger(Logger):
class CSVLogger(Logger):
"""Logs results to progress.csv under the trial directory.
Automatically flattens nested dicts in the result dict before writing
to csv:
{"a": {"b": 1, "c": 2}} -> {"a/b": 1, "a/c": 2}
"""
def _init(self):
"""CSV outputted with Headers as first set of results."""
# Note that we assume params.json was already created by JsonLogger
progress_file = os.path.join(self.logdir, EXPR_PROGRESS_FILE)
self._continuing = os.path.exists(progress_file)
self._file = open(progress_file, "a")
+25 -5
View File
@@ -33,7 +33,12 @@ def function(func):
def uniform(*args, **kwargs):
    """Wraps tune.sample_from around ``np.random.uniform``.

    ``tune.uniform(1, 10)`` is equivalent to
    ``tune.sample_from(lambda _: np.random.uniform(1, 10))``

    Args:
        *args: Positional arguments forwarded to ``np.random.uniform``.
        **kwargs: Keyword arguments forwarded to ``np.random.uniform``.

    Returns:
        The value returned by ``sample_from`` for the sampling lambda.
    """
    # The lambda ignores its argument; sampling is deferred until it is called.
    return sample_from(lambda _: np.random.uniform(*args, **kwargs))
@@ -44,7 +49,7 @@ def loguniform(min_bound, max_bound, base=10):
min_bound (float): Lower boundary of the output interval (1e-4)
max_bound (float): Upper boundary of the output interval (1e-2)
base (float): Base of the log. Defaults to 10.
"""
"""
logmin = np.log(min_bound) / np.log(base)
logmax = np.log(max_bound) / np.log(base)
@@ -55,15 +60,30 @@ def loguniform(min_bound, max_bound, base=10):
def choice(*args, **kwargs):
    """Wraps tune.sample_from around ``np.random.choice``.

    ``tune.choice(10)`` is equivalent to
    ``tune.sample_from(lambda _: np.random.choice(10))``

    Args:
        *args: Positional arguments forwarded to ``np.random.choice``.
        **kwargs: Keyword arguments forwarded to ``np.random.choice``.

    Returns:
        The value returned by ``sample_from`` for the sampling lambda.
    """
    # The lambda ignores its argument; sampling is deferred until it is called.
    return sample_from(lambda _: np.random.choice(*args, **kwargs))
def randint(*args, **kwargs):
    """Wraps tune.sample_from around ``np.random.randint``.

    ``tune.randint(10)`` is equivalent to
    ``tune.sample_from(lambda _: np.random.randint(10))``

    Args:
        *args: Positional arguments forwarded to ``np.random.randint``.
        **kwargs: Keyword arguments forwarded to ``np.random.randint``.

    Returns:
        The value returned by ``sample_from`` for the sampling lambda.
    """
    # The lambda ignores its argument; sampling is deferred until it is called.
    return sample_from(lambda _: np.random.randint(*args, **kwargs))
def randn(*args, **kwargs):
    """Wraps tune.sample_from around ``np.random.randn``.

    ``tune.randn(10)`` is equivalent to
    ``tune.sample_from(lambda _: np.random.randn(10))``

    Args:
        *args: Positional arguments forwarded to ``np.random.randn``.
        **kwargs: Keyword arguments forwarded to ``np.random.randn``.

    Returns:
        The value returned by ``sample_from`` for the sampling lambda.
    """
    # The lambda ignores its argument; sampling is deferred until it is called.
    return sample_from(lambda _: np.random.randn(*args, **kwargs))
+1
View File
@@ -84,6 +84,7 @@ class BasicVariantGenerator(SearchAlgorithm):
spec,
output_path,
self._parser,
evaluated_params=resolved_vars,
experiment_tag=experiment_tag)
def is_finished(self):
+1
View File
@@ -98,6 +98,7 @@ class SuggestionAlgorithm(SearchAlgorithm):
spec,
output_path,
self._parser,
evaluated_params=list(suggested_config),
experiment_tag=tag,
trial_id=trial_id)
+78 -21
View File
@@ -40,14 +40,11 @@ class Trainable(object):
Calling ``save()`` should save the training state of a trainable to disk,
and ``restore(path)`` should restore a trainable to the given state.
Generally you only need to implement ``_train``, ``_save``, and
``_restore`` here when subclassing Trainable.
Generally you only need to implement ``_setup``, ``_train``,
``_save``, and ``_restore`` when subclassing Trainable.
Note that, if you don't require checkpoint/restore functionality, then
instead of implementing this class you can also get away with supplying
just a ``my_train(config, reporter)`` function to the config.
The function will be automatically converted to this interface
(sans checkpoint functionality).
Other implementation methods that may be helpful to override are
``_log_result``, ``reset_config``, ``_stop``, and ``_export_model``.
When using Tune, Tune will convert this class into a Ray actor, which
runs on a separate process. Tune will also change the current working
@@ -112,6 +109,14 @@ class Trainable(object):
This can be overridden by sub-classes to set the correct trial resource
allocation, so the user does not need to.
Example:
>>> def default_resource_request(cls, config):
return Resources(
cpu=0,
gpu=0,
extra_cpu=config["workers"],
extra_gpu=int(config["use_gpu"]) * config["workers"])
"""
return None
@@ -451,7 +456,7 @@ class Trainable(object):
The return value will be automatically passed to the loggers. Users
can also return `tune.result.DONE` or `tune.result.SHOULD_CHECKPOINT`
to manually trigger termination of this trial or checkpointing of this
as a key to manually trigger termination or checkpointing of this
trial. Note that manual checkpointing only works when subclassing
Trainables.
@@ -462,26 +467,38 @@ class Trainable(object):
raise NotImplementedError
def _save(self, checkpoint_dir):
"""Subclasses should override this to implement save().
def _save(self, tmp_checkpoint_dir):
"""Subclasses should override this to implement ``save()``.
Warning:
Do not rely on absolute paths in the implementation of ``_save``
and ``_restore``.
Use ``validate_save_restore`` to catch ``_save``/``_restore`` errors
before execution.
>>> from ray.tune.util import validate_save_restore
>>> validate_save_restore(MyTrainableClass)
>>> validate_save_restore(MyTrainableClass, use_object_store=True)
Args:
checkpoint_dir (str): The directory where the checkpoint
file must be stored. In a Tune run, this defaults to
`<self.logdir>/checkpoint_<ITER>` (which is the same as
`local_dir/exp_name/trial_name/checkpoint_<ITER>`).
tmp_checkpoint_dir (str): The directory where the checkpoint
file must be stored. In a Tune run, if the trial is paused,
the provided path may be temporary and moved.
Returns:
checkpoint (str | dict): If string, the return value is
expected to be the checkpoint path or prefix to be passed to
`_restore()`. If dict, the return value will be automatically
serialized by Tune and passed to `_restore()`.
A dict or string. If string, the return value is expected to be
prefixed by `tmp_checkpoint_dir`. If dict, the return value will
be automatically serialized by Tune and passed to `_restore()`.
Examples:
>>> print(trainable1._save("/tmp/checkpoint_1"))
"/tmp/checkpoint_1/my_checkpoint_file"
>>> print(trainable2._save("/tmp/checkpoint_2"))
{"some": "data"}
>>> trainable._save("/tmp/bad_example")
"/tmp/NEW_CHECKPOINT_PATH/my_checkpoint_file" # This will error.
"""
raise NotImplementedError
@@ -489,9 +506,42 @@ class Trainable(object):
def _restore(self, checkpoint):
"""Subclasses should override this to implement restore().
Warning:
In this method, do not rely on absolute paths. The absolute
path of the checkpoint_dir used in ``_save`` may be changed.
If ``_save`` returned a prefixed string, the prefix of the checkpoint
string returned by ``_save`` may be changed. This is because trial
pausing depends on temporary directories.
The directory structure under the checkpoint_dir provided to ``_save``
is preserved.
See the example below.
.. code-block:: python
class Example(Trainable):
def _save(self, checkpoint_path):
print(checkpoint_path)
return os.path.join(checkpoint_path, "my/check/point")
def _restore(self, checkpoint):
print(checkpoint)
>>> trainer = Example()
>>> obj = trainer.save_to_object() # This is used when PAUSED.
<logdir>/tmpc8k_c_6hsave_to_object/checkpoint_0/my/check/point
>>> trainer.restore_from_object(obj) # Note the different prefix.
<logdir>/tmpb87b5axfrestore_from_object/checkpoint_0/my/check/point
Args:
checkpoint (str | dict): Value as returned by `_save`.
If a string, then it is the checkpoint path.
checkpoint (str|dict): If dict, the return value is as
returned by `_save`. If a string, then it is a checkpoint path
that may have a different prefix than that returned by `_save`.
The directory structure underneath the `checkpoint_dir`
`_save` is preserved.
"""
raise NotImplementedError
@@ -514,7 +564,14 @@ class Trainable(object):
self._result_logger.on_result(result)
def _stop(self):
"""Subclasses should override this for any cleanup on stop."""
"""Subclasses should override this for any cleanup on stop.
If any Ray actors are launched in the Trainable (i.e., with a RLlib
trainer), be sure to kill the Ray actor process here.
You can kill a Ray actor by calling `actor.__ray_terminate__.remote()`
on the actor.
"""
pass
def _export_model(self, export_formats, export_dir):
+4
View File
@@ -108,6 +108,7 @@ class Trial(object):
config=None,
trial_id=None,
local_dir=DEFAULT_RESULTS_DIR,
evaluated_params=None,
experiment_tag="",
resources=None,
stopping_criterion=None,
@@ -133,6 +134,9 @@ class Trial(object):
self.trial_id = Trial.generate_id() if trial_id is None else trial_id
self.config = config or {}
self.local_dir = local_dir # This remains unexpanded for syncing.
#: Parameters that Tune varies across searches.
self.evaluated_params = evaluated_params or []
self.experiment_tag = experiment_tag
trainable_cls = self._get_trainable_cls()
if trainable_cls and hasattr(trainable_cls,