mirror of
https://github.com/wassname/ray.git
synced 2026-07-01 16:19:02 +08:00
[tune] Update trainable docs and support hparams (#5558)
This commit is contained in:
@@ -4,6 +4,9 @@
|
||||
Note that this requires a cluster with at least 8 GPUs in order for all trials
|
||||
to run concurrently, otherwise PBT will round-robin train the trials which
|
||||
is less efficient (or you can set {"gpu": 0} to use CPUs for SGD instead).
|
||||
|
||||
Note that Tune in general does not need 8 GPUs, and this is just a more
|
||||
computationally demanding example.
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
@@ -51,9 +54,9 @@ if __name__ == "__main__":
|
||||
name="pbt_humanoid_test",
|
||||
scheduler=pbt,
|
||||
**{
|
||||
"env": "Humanoid-v1",
|
||||
"num_samples": 8,
|
||||
"config": {
|
||||
"env": "Humanoid-v1",
|
||||
"kl_coeff": 1.0,
|
||||
"num_workers": 8,
|
||||
"num_gpus": 1,
|
||||
|
||||
@@ -136,6 +136,7 @@ class JsonLogger(Logger):
|
||||
|
||||
|
||||
def tf2_compat_logger(config, logdir):
|
||||
"""Chooses TensorBoard logger depending on imported TF version."""
|
||||
global tf
|
||||
if "RLLIB_TEST_NO_TF_IMPORT" in os.environ:
|
||||
logger.warning("Not importing TensorFlow for test purposes")
|
||||
@@ -153,6 +154,16 @@ def tf2_compat_logger(config, logdir):
|
||||
|
||||
|
||||
class TF2Logger(Logger):
|
||||
"""TensorBoard Logger for TF version >= 1.14.
|
||||
|
||||
Automatically flattens nested dicts to show on TensorBoard:
|
||||
|
||||
{"a": {"b": 1, "c": 2}} -> {"a/b": 1, "a/c": 2}
|
||||
|
||||
If you need to do more advanced logging, it is recommended
|
||||
to use a Summary Writer in the Trainable yourself.
|
||||
"""
|
||||
|
||||
def _init(self):
|
||||
self._file_writer = None
|
||||
|
||||
@@ -202,6 +213,16 @@ def to_tf_values(result, path):
|
||||
|
||||
|
||||
class TFLogger(Logger):
|
||||
"""TensorBoard Logger for TF version < 1.14.
|
||||
|
||||
Automatically flattens nested dicts to show on TensorBoard:
|
||||
|
||||
{"a": {"b": 1, "c": 2}} -> {"a/b": 1, "a/c": 2}
|
||||
|
||||
If you need to do more advanced logging, it is recommended
|
||||
to use a Summary Writer in the Trainable yourself.
|
||||
"""
|
||||
|
||||
def _init(self):
|
||||
logger.info("Initializing TFLogger instead of TF2Logger.")
|
||||
self._file_writer = tf.compat.v1.summary.FileWriter(self.logdir)
|
||||
@@ -232,9 +253,17 @@ class TFLogger(Logger):
|
||||
|
||||
|
||||
class CSVLogger(Logger):
|
||||
"""Logs results to progress.csv under the trial directory.
|
||||
|
||||
Automatically flattens nested dicts in the result dict before writing
|
||||
to csv:
|
||||
|
||||
{"a": {"b": 1, "c": 2}} -> {"a/b": 1, "a/c": 2}
|
||||
|
||||
"""
|
||||
|
||||
def _init(self):
|
||||
"""CSV outputted with Headers as first set of results."""
|
||||
# Note that we assume params.json was already created by JsonLogger
|
||||
progress_file = os.path.join(self.logdir, EXPR_PROGRESS_FILE)
|
||||
self._continuing = os.path.exists(progress_file)
|
||||
self._file = open(progress_file, "a")
|
||||
|
||||
@@ -33,7 +33,12 @@ def function(func):
|
||||
|
||||
|
||||
def uniform(*args, **kwargs):
|
||||
"""A wrapper around np.random.uniform."""
|
||||
"""Wraps tune.sample_from around ``np.random.uniform``.
|
||||
|
||||
``tune.uniform(1, 10)`` is equivalent to
|
||||
``tune.sample_from(lambda _: np.random.uniform(1, 10))``
|
||||
|
||||
"""
|
||||
return sample_from(lambda _: np.random.uniform(*args, **kwargs))
|
||||
|
||||
|
||||
@@ -44,7 +49,7 @@ def loguniform(min_bound, max_bound, base=10):
|
||||
min_bound (float): Lower boundary of the output interval (1e-4)
|
||||
max_bound (float): Upper boundary of the output interval (1e-2)
|
||||
base (float): Base of the log. Defaults to 10.
|
||||
"""
|
||||
"""
|
||||
logmin = np.log(min_bound) / np.log(base)
|
||||
logmax = np.log(max_bound) / np.log(base)
|
||||
|
||||
@@ -55,15 +60,30 @@ def loguniform(min_bound, max_bound, base=10):
|
||||
|
||||
|
||||
def choice(*args, **kwargs):
|
||||
"""A wrapper around np.random.choice."""
|
||||
"""Wraps tune.sample_from around ``np.random.choice``.
|
||||
|
||||
``tune.choice(10)`` is equivalent to
|
||||
``tune.sample_from(lambda _: np.random.choice(10))``
|
||||
|
||||
"""
|
||||
return sample_from(lambda _: np.random.choice(*args, **kwargs))
|
||||
|
||||
|
||||
def randint(*args, **kwargs):
|
||||
"""A wrapper around np.random.randint."""
|
||||
"""Wraps tune.sample_from around ``np.random.randint``.
|
||||
|
||||
``tune.randint(10)`` is equivalent to
|
||||
``tune.sample_from(lambda _: np.random.randint(10))``
|
||||
|
||||
"""
|
||||
return sample_from(lambda _: np.random.randint(*args, **kwargs))
|
||||
|
||||
|
||||
def randn(*args, **kwargs):
|
||||
"""A wrapper around np.random.randn."""
|
||||
"""Wraps tune.sample_from around ``np.random.randn``.
|
||||
|
||||
``tune.randn(10)`` is equivalent to
|
||||
``tune.sample_from(lambda _: np.random.randn(10))``
|
||||
|
||||
"""
|
||||
return sample_from(lambda _: np.random.randn(*args, **kwargs))
|
||||
|
||||
@@ -84,6 +84,7 @@ class BasicVariantGenerator(SearchAlgorithm):
|
||||
spec,
|
||||
output_path,
|
||||
self._parser,
|
||||
evaluated_params=resolved_vars,
|
||||
experiment_tag=experiment_tag)
|
||||
|
||||
def is_finished(self):
|
||||
|
||||
@@ -98,6 +98,7 @@ class SuggestionAlgorithm(SearchAlgorithm):
|
||||
spec,
|
||||
output_path,
|
||||
self._parser,
|
||||
evaluated_params=list(suggested_config),
|
||||
experiment_tag=tag,
|
||||
trial_id=trial_id)
|
||||
|
||||
|
||||
@@ -40,14 +40,11 @@ class Trainable(object):
|
||||
Calling ``save()`` should save the training state of a trainable to disk,
|
||||
and ``restore(path)`` should restore a trainable to the given state.
|
||||
|
||||
Generally you only need to implement ``_train``, ``_save``, and
|
||||
``_restore`` here when subclassing Trainable.
|
||||
Generally you only need to implement ``_setup``, ``_train``,
|
||||
``_save``, and ``_restore`` when subclassing Trainable.
|
||||
|
||||
Note that, if you don't require checkpoint/restore functionality, then
|
||||
instead of implementing this class you can also get away with supplying
|
||||
just a ``my_train(config, reporter)`` function to the config.
|
||||
The function will be automatically converted to this interface
|
||||
(sans checkpoint functionality).
|
||||
Other implementation methods that may be helpful to override are
|
||||
``_log_result``, ``reset_config``, ``_stop``, and ``_export_model``.
|
||||
|
||||
When using Tune, Tune will convert this class into a Ray actor, which
|
||||
runs on a separate process. Tune will also change the current working
|
||||
@@ -112,6 +109,14 @@ class Trainable(object):
|
||||
|
||||
This can be overridden by sub-classes to set the correct trial resource
|
||||
allocation, so the user does not need to.
|
||||
|
||||
Example:
|
||||
>>> def default_resource_request(cls, config):
|
||||
return Resources(
|
||||
cpu=0,
|
||||
gpu=0,
|
||||
extra_cpu=config["workers"],
|
||||
extra_gpu=int(config["use_gpu"]) * config["workers"])
|
||||
"""
|
||||
|
||||
return None
|
||||
@@ -451,7 +456,7 @@ class Trainable(object):
|
||||
|
||||
The return value will be automatically passed to the loggers. Users
|
||||
can also return `tune.result.DONE` or `tune.result.SHOULD_CHECKPOINT`
|
||||
to manually trigger termination of this trial or checkpointing of this
|
||||
as a key to manually trigger termination or checkpointing of this
|
||||
trial. Note that manual checkpointing only works when subclassing
|
||||
Trainables.
|
||||
|
||||
@@ -462,26 +467,38 @@ class Trainable(object):
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
def _save(self, checkpoint_dir):
|
||||
"""Subclasses should override this to implement save().
|
||||
def _save(self, tmp_checkpoint_dir):
|
||||
"""Subclasses should override this to implement ``save()``.
|
||||
|
||||
Warning:
|
||||
Do not rely on absolute paths in the implementation of ``_save``
|
||||
and ``_restore``.
|
||||
|
||||
Use ``validate_save_restore`` to catch ``_save``/``_restore`` errors
|
||||
before execution.
|
||||
|
||||
>>> from ray.tune.util import validate_save_restore
|
||||
>>> validate_save_restore(MyTrainableClass)
|
||||
>>> validate_save_restore(MyTrainableClass, use_object_store=True)
|
||||
|
||||
Args:
|
||||
checkpoint_dir (str): The directory where the checkpoint
|
||||
file must be stored. In a Tune run, this defaults to
|
||||
`<self.logdir>/checkpoint_<ITER>` (which is the same as
|
||||
`local_dir/exp_name/trial_name/checkpoint_<ITER>`).
|
||||
tmp_checkpoint_dir (str): The directory where the checkpoint
|
||||
file must be stored. In a Tune run, if the trial is paused,
|
||||
the provided path may be temporary and moved.
|
||||
|
||||
Returns:
|
||||
checkpoint (str | dict): If string, the return value is
|
||||
expected to be the checkpoint path or prefix to be passed to
|
||||
`_restore()`. If dict, the return value will be automatically
|
||||
serialized by Tune and passed to `_restore()`.
|
||||
A dict or string. If string, the return value is expected to be
|
||||
prefixed by `tmp_checkpoint_dir`. If dict, the return value will
|
||||
be automatically serialized by Tune and passed to `_restore()`.
|
||||
|
||||
Examples:
|
||||
>>> print(trainable1._save("/tmp/checkpoint_1"))
|
||||
"/tmp/checkpoint_1/my_checkpoint_file"
|
||||
>>> print(trainable2._save("/tmp/checkpoint_2"))
|
||||
{"some": "data"}
|
||||
|
||||
>>> trainable._save("/tmp/bad_example")
|
||||
"/tmp/NEW_CHECKPOINT_PATH/my_checkpoint_file" # This will error.
|
||||
"""
|
||||
|
||||
raise NotImplementedError
|
||||
@@ -489,9 +506,42 @@ class Trainable(object):
|
||||
def _restore(self, checkpoint):
|
||||
"""Subclasses should override this to implement restore().
|
||||
|
||||
Warning:
|
||||
In this method, do not rely on absolute paths. The absolute
|
||||
path of the checkpoint_dir used in ``_save`` may be changed.
|
||||
|
||||
If ``_save`` returned a prefixed string, the prefix of the checkpoint
|
||||
string returned by ``_save`` may be changed. This is because trial
|
||||
pausing depends on temporary directories.
|
||||
|
||||
The directory structure under the checkpoint_dir provided to ``_save``
|
||||
is preserved.
|
||||
|
||||
See the example below.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
class Example(Trainable):
|
||||
def _save(self, checkpoint_path):
|
||||
print(checkpoint_path)
|
||||
return os.path.join(checkpoint_path, "my/check/point")
|
||||
|
||||
def _restore(self, checkpoint):
|
||||
print(checkpoint)
|
||||
|
||||
>>> trainer = Example()
|
||||
>>> obj = trainer.save_to_object() # This is used when PAUSED.
|
||||
<logdir>/tmpc8k_c_6hsave_to_object/checkpoint_0/my/check/point
|
||||
>>> trainer.restore_from_object(obj) # Note the different prefix.
|
||||
<logdir>/tmpb87b5axfrestore_from_object/checkpoint_0/my/check/point
|
||||
|
||||
|
||||
Args:
|
||||
checkpoint (str | dict): Value as returned by `_save`.
|
||||
If a string, then it is the checkpoint path.
|
||||
checkpoint (str|dict): If dict, the return value is as
|
||||
returned by `_save`. If a string, then it is a checkpoint path
|
||||
that may have a different prefix than that returned by `_save`.
|
||||
The directory structure underneath the `checkpoint_dir` provided to
|
||||
`_save` is preserved.
|
||||
"""
|
||||
|
||||
raise NotImplementedError
|
||||
@@ -514,7 +564,14 @@ class Trainable(object):
|
||||
self._result_logger.on_result(result)
|
||||
|
||||
def _stop(self):
|
||||
"""Subclasses should override this for any cleanup on stop."""
|
||||
"""Subclasses should override this for any cleanup on stop.
|
||||
|
||||
If any Ray actors are launched in the Trainable (e.g., with an RLlib
|
||||
trainer), be sure to kill the Ray actor process here.
|
||||
|
||||
You can kill a Ray actor by calling `actor.__ray_terminate__.remote()`
|
||||
on the actor.
|
||||
"""
|
||||
pass
|
||||
|
||||
def _export_model(self, export_formats, export_dir):
|
||||
|
||||
@@ -108,6 +108,7 @@ class Trial(object):
|
||||
config=None,
|
||||
trial_id=None,
|
||||
local_dir=DEFAULT_RESULTS_DIR,
|
||||
evaluated_params=None,
|
||||
experiment_tag="",
|
||||
resources=None,
|
||||
stopping_criterion=None,
|
||||
@@ -133,6 +134,9 @@ class Trial(object):
|
||||
self.trial_id = Trial.generate_id() if trial_id is None else trial_id
|
||||
self.config = config or {}
|
||||
self.local_dir = local_dir # This remains unexpanded for syncing.
|
||||
|
||||
#: Parameters that Tune varies across searches.
|
||||
self.evaluated_params = evaluated_params or []
|
||||
self.experiment_tag = experiment_tag
|
||||
trainable_cls = self._get_trainable_cls()
|
||||
if trainable_cls and hasattr(trainable_cls,
|
||||
|
||||
Reference in New Issue
Block a user