mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 13:02:16 +08:00
[tune] Custom Logging, Trial Name (#3465)
Adds support for custom loggers, custom trial strings, and custom sync commands. Closes #3034, #2985, and #3390.
This commit is contained in:
@@ -106,6 +106,21 @@ def make_parser(parser_creator=None, **kwargs):
|
||||
default="",
|
||||
type=str,
|
||||
help="Optional URI to sync training results to (e.g. s3://bucket).")
|
||||
parser.add_argument(
|
||||
"--trial-name-creator",
|
||||
default=None,
|
||||
help="Optional creator function for the trial string, used in "
|
||||
"generating a trial directory.")
|
||||
parser.add_argument(
|
||||
"--sync-function",
|
||||
default=None,
|
||||
help="Function for syncing the local_dir to upload_dir. If string, "
|
||||
"then it must be a string template for syncer to run and needs to "
|
||||
"include replacement fields '{local_dir}' and '{remote_dir}'.")
|
||||
parser.add_argument(
|
||||
"--custom-loggers",
|
||||
default=None,
|
||||
help="List of custom logger creators to be used with each Trial.")
|
||||
parser.add_argument(
|
||||
"--checkpoint-freq",
|
||||
default=0,
|
||||
@@ -198,5 +213,9 @@ def create_trial_from_spec(spec, output_path, parser, **trial_kwargs):
|
||||
# str(None) doesn't create None
|
||||
restore_path=spec.get("restore"),
|
||||
upload_dir=args.upload_dir,
|
||||
trial_name_creator=spec.get("trial_name_creator"),
|
||||
custom_loggers=spec.get("custom_loggers"),
|
||||
# str(None) doesn't create None
|
||||
sync_function=spec.get("sync_function"),
|
||||
max_failures=args.max_failures,
|
||||
**trial_kwargs)
|
||||
|
||||
@@ -22,6 +22,8 @@ General Examples
|
||||
Example of using a Trainable class with PopulationBasedTraining scheduler.
|
||||
- `pbt_ppo_example <https://github.com/ray-project/ray/blob/master/python/ray/tune/examples/pbt_ppo_example.py>`__:
|
||||
Example of optimizing a distributed RLlib algorithm (PPO) with the PopulationBasedTraining scheduler.
|
||||
- `logging_example <https://github.com/ray-project/ray/blob/master/python/ray/tune/examples/logging_example.py>`__:
|
||||
Example of custom loggers and custom trial directory naming.
|
||||
|
||||
|
||||
Keras Examples
|
||||
|
||||
Executable
+76
@@ -0,0 +1,76 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
|
||||
import ray
|
||||
from ray import tune
|
||||
from ray.tune import Trainable, run_experiments, Experiment
|
||||
|
||||
|
||||
class TestLogger(tune.logger.Logger):
|
||||
def on_result(self, result):
|
||||
print("TestLogger", result)
|
||||
|
||||
|
||||
def trial_str_creator(trial):
|
||||
return "{}_{}_123".format(trial.trainable_name, trial.trial_id)
|
||||
|
||||
|
||||
class MyTrainableClass(Trainable):
|
||||
"""Example agent whose learning curve is a random sigmoid.
|
||||
|
||||
The dummy hyperparameters "width" and "height" determine the slope and
|
||||
maximum reward value reached.
|
||||
"""
|
||||
|
||||
def _setup(self, config):
|
||||
self.timestep = 0
|
||||
|
||||
def _train(self):
|
||||
self.timestep += 1
|
||||
v = np.tanh(float(self.timestep) / self.config["width"])
|
||||
v *= self.config["height"]
|
||||
|
||||
# Here we use `episode_reward_mean`, but you can also report other
|
||||
# objectives such as loss or accuracy.
|
||||
return {"episode_reward_mean": v}
|
||||
|
||||
def _save(self, checkpoint_dir):
|
||||
path = os.path.join(checkpoint_dir, "checkpoint")
|
||||
with open(path, "w") as f:
|
||||
f.write(json.dumps({"timestep": self.timestep}))
|
||||
return path
|
||||
|
||||
def _restore(self, checkpoint_path):
|
||||
with open(checkpoint_path) as f:
|
||||
self.timestep = json.loads(f.read())["timestep"]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--smoke-test", action="store_true", help="Finish quickly for testing")
|
||||
args, _ = parser.parse_known_args()
|
||||
ray.init()
|
||||
exp = Experiment(
|
||||
name="hyperband_test",
|
||||
run=MyTrainableClass,
|
||||
num_samples=1,
|
||||
trial_name_creator=tune.function(trial_str_creator),
|
||||
custom_loggers=[TestLogger],
|
||||
stop={"training_iteration": 1 if args.smoke_test else 99999},
|
||||
config={
|
||||
"width": lambda spec: 10 + int(90 * random.random()),
|
||||
"height": lambda spec: int(100 * random.random())
|
||||
})
|
||||
|
||||
trials = run_experiments(exp)
|
||||
@@ -7,9 +7,10 @@ import logging
|
||||
import six
|
||||
import types
|
||||
|
||||
from ray.tune.result import DEFAULT_RESULTS_DIR
|
||||
from ray.tune.error import TuneError
|
||||
from ray.tune.log_sync import validate_sync_function
|
||||
from ray.tune.registry import register_trainable
|
||||
from ray.tune.result import DEFAULT_RESULTS_DIR
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -44,6 +45,14 @@ class Experiment(object):
|
||||
Defaults to ``~/ray_results``.
|
||||
upload_dir (str): Optional URI to sync training results
|
||||
to (e.g. ``s3://bucket``).
|
||||
trial_name_creator (func): Optional function for generating
|
||||
the trial string representation.
|
||||
custom_loggers (list): List of custom logger creators to be used with
|
||||
each Trial. See `ray/tune/logger.py`.
|
||||
sync_function (func|str): Function for syncing the local_dir to
|
||||
upload_dir. If string, then it must be a string template for
|
||||
syncer to run. If not provided, the sync command defaults
|
||||
to standard S3 or gsutil sync commands.
|
||||
checkpoint_freq (int): How many training iterations between
|
||||
checkpoints. A value of 0 (default) disables checkpointing.
|
||||
checkpoint_at_end (bool): Whether to checkpoint at the end of the
|
||||
@@ -86,10 +95,16 @@ class Experiment(object):
|
||||
num_samples=1,
|
||||
local_dir=None,
|
||||
upload_dir=None,
|
||||
trial_name_creator=None,
|
||||
custom_loggers=None,
|
||||
sync_function=None,
|
||||
checkpoint_freq=0,
|
||||
checkpoint_at_end=False,
|
||||
max_failures=3,
|
||||
restore=None):
|
||||
validate_sync_function(sync_function)
|
||||
if sync_function:
|
||||
assert upload_dir, "Need `upload_dir` if sync_function given."
|
||||
spec = {
|
||||
"run": self._register_if_needed(run),
|
||||
"stop": stop or {},
|
||||
@@ -98,6 +113,9 @@ class Experiment(object):
|
||||
"num_samples": num_samples,
|
||||
"local_dir": local_dir or DEFAULT_RESULTS_DIR,
|
||||
"upload_dir": upload_dir or "", # argparse converts None to "null"
|
||||
"trial_name_creator": trial_name_creator,
|
||||
"custom_loggers": custom_loggers,
|
||||
"sync_function": sync_function or "", # See `upload_dir`.
|
||||
"checkpoint_freq": checkpoint_freq,
|
||||
"checkpoint_at_end": checkpoint_at_end,
|
||||
"max_failures": max_failures,
|
||||
|
||||
+68
-11
@@ -7,6 +7,7 @@ import logging
|
||||
import os
|
||||
import subprocess
|
||||
import time
|
||||
import types
|
||||
|
||||
try: # py3
|
||||
from shlex import quote
|
||||
@@ -17,6 +18,7 @@ import ray
|
||||
from ray.tune.cluster_info import get_ssh_key, get_ssh_user
|
||||
from ray.tune.error import TuneError
|
||||
from ray.tune.result import DEFAULT_RESULTS_DIR
|
||||
from ray.tune.suggest.variant_generator import function as tune_function
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -28,9 +30,9 @@ GCS_PREFIX = "gs://"
|
||||
ALLOWED_REMOTE_PREFIXES = (S3_PREFIX, GCS_PREFIX)
|
||||
|
||||
|
||||
def get_syncer(local_dir, remote_dir=None):
|
||||
def get_syncer(local_dir, remote_dir=None, sync_function=None):
|
||||
if remote_dir:
|
||||
if not any(
|
||||
if not sync_function and not any(
|
||||
remote_dir.startswith(prefix)
|
||||
for prefix in ALLOWED_REMOTE_PREFIXES):
|
||||
raise TuneError("Upload uri must start with one of: {}"
|
||||
@@ -53,7 +55,7 @@ def get_syncer(local_dir, remote_dir=None):
|
||||
|
||||
key = (local_dir, remote_dir)
|
||||
if key not in _syncers:
|
||||
_syncers[key] = _LogSyncer(local_dir, remote_dir)
|
||||
_syncers[key] = _LogSyncer(local_dir, remote_dir, sync_function)
|
||||
|
||||
return _syncers[key]
|
||||
|
||||
@@ -63,15 +65,47 @@ def wait_for_log_sync():
|
||||
syncer.wait()
|
||||
|
||||
|
||||
def validate_sync_function(sync_function):
    """Check that ``sync_function`` is a usable sync template or callable.

    Arguments:
        sync_function (func|str|None): ``None`` is accepted as-is. A
            string must be a command template containing both the
            ``{local_dir}`` and ``{remote_dir}`` replacement fields.
            Any other value must be a plain function or an instance of
            ``tune_function``.

    Raises:
        AssertionError: If a string template lacks a required field.
        ValueError: If the value is neither a string nor a function.
    """
    if sync_function is None:
        return

    if isinstance(sync_function, str):
        # Raised explicitly instead of via ``assert`` so the check is not
        # stripped under ``python -O``. The exception type stays
        # AssertionError for backward compatibility with existing callers.
        for field in ("{remote_dir}", "{local_dir}"):
            if field not in sync_function:
                raise AssertionError(
                    "Sync template missing '{}'.".format(field))
        return

    if not (isinstance(sync_function, types.FunctionType)
            or isinstance(sync_function, tune_function)):
        raise ValueError("Sync function {} must be string or function".format(
            sync_function))
|
||||
|
||||
|
||||
class _LogSyncer(object):
|
||||
"""Log syncer for tune.
|
||||
|
||||
This syncs files from workers to the local node, and optionally also from
|
||||
the local node to a remote directory (e.g. S3)."""
|
||||
the local node to a remote directory (e.g. S3).
|
||||
|
||||
def __init__(self, local_dir, remote_dir=None):
|
||||
Arguments:
|
||||
local_dir (str): Directory to sync from.
|
||||
remote_dir (str): Directory to sync to.
|
||||
sync_function (func|str): Function for syncing the local_dir to
|
||||
upload_dir. If string, then it must be a string template
|
||||
for syncer to run and needs to include replacement fields
|
||||
'{local_dir}' and '{remote_dir}'.
|
||||
"""
|
||||
|
||||
def __init__(self, local_dir, remote_dir=None, sync_function=None):
|
||||
self.local_dir = local_dir
|
||||
self.remote_dir = remote_dir
|
||||
|
||||
# Resolve sync_function into template or function
|
||||
self.sync_func = None
|
||||
self.sync_cmd_tmpl = None
|
||||
if isinstance(sync_function, types.FunctionType) or isinstance(
|
||||
sync_function, tune_function):
|
||||
self.sync_func = sync_function
|
||||
elif isinstance(sync_function, str):
|
||||
self.sync_cmd_tmpl = sync_function
|
||||
self.last_sync_time = 0
|
||||
self.sync_process = None
|
||||
self.local_ip = ray.services.get_node_ip_address()
|
||||
@@ -116,12 +150,14 @@ class _LogSyncer(object):
|
||||
quote(ssh_key), quote(source), quote(target)))
|
||||
|
||||
if self.remote_dir:
|
||||
if self.remote_dir.startswith(S3_PREFIX):
|
||||
local_to_remote_sync_cmd = ("aws s3 sync {} {}".format(
|
||||
quote(self.local_dir), quote(self.remote_dir)))
|
||||
elif self.remote_dir.startswith(GCS_PREFIX):
|
||||
local_to_remote_sync_cmd = ("gsutil rsync -r {} {}".format(
|
||||
quote(self.local_dir), quote(self.remote_dir)))
|
||||
if self.sync_func:
|
||||
local_to_remote_sync_cmd = None
|
||||
try:
|
||||
self.sync_func(self.local_dir, self.remote_dir)
|
||||
except Exception:
|
||||
logger.exception("Sync function failed.")
|
||||
else:
|
||||
local_to_remote_sync_cmd = self.get_remote_sync_cmd()
|
||||
else:
|
||||
local_to_remote_sync_cmd = None
|
||||
|
||||
@@ -148,3 +184,24 @@ class _LogSyncer(object):
|
||||
def wait(self):
|
||||
if self.sync_process:
|
||||
self.sync_process.wait()
|
||||
|
||||
def get_remote_sync_cmd(self):
    """Build the shell command that syncs ``local_dir`` to ``remote_dir``.

    A user-supplied template (``self.sync_cmd_tmpl``) takes precedence;
    otherwise the command is chosen from the prefix of ``remote_dir``.
    Both paths are shell-quoted before being substituted.

    Returns:
        str or None: The command string, or None when there is no remote
        directory or no sync command is known for it.
    """
    if not self.remote_dir:
        # Nothing to sync to; also avoids calling quote()/startswith()
        # on None, which would raise.
        return None

    if self.sync_cmd_tmpl:
        template = self.sync_cmd_tmpl
    elif self.remote_dir.startswith(S3_PREFIX):
        template = "aws s3 sync {local_dir} {remote_dir}"
    elif self.remote_dir.startswith(GCS_PREFIX):
        template = "gsutil rsync -r {local_dir} {remote_dir}"
    else:
        logger.warning("Remote sync unsupported, skipping.")
        return None

    return template.format(
        local_dir=quote(self.local_dir), remote_dir=quote(self.remote_dir))
|
||||
|
||||
+39
-10
@@ -25,10 +25,16 @@ except ImportError:
|
||||
|
||||
|
||||
class Logger(object):
|
||||
"""Logging interface for ray.tune; specialized implementations follow.
|
||||
"""Logging interface for ray.tune.
|
||||
|
||||
By default, the UnifiedLogger implementation is used which logs results in
|
||||
multiple formats (TensorBoard, rllab/viskit, plain json) at once.
|
||||
multiple formats (TensorBoard, rllab/viskit, plain json, custom loggers)
|
||||
at once.
|
||||
|
||||
Arguments:
|
||||
config: Configuration passed to all logger creators.
|
||||
logdir: Directory for all logger creators to log to.
|
||||
upload_uri (str): Optional URI where the logdir is sync'ed to.
|
||||
"""
|
||||
|
||||
def __init__(self, config, logdir, upload_uri=None):
|
||||
@@ -59,17 +65,40 @@ class Logger(object):
|
||||
class UnifiedLogger(Logger):
|
||||
"""Unified result logger for TensorBoard, rllab/viskit, plain json.
|
||||
|
||||
This class also periodically syncs output to the given upload uri."""
|
||||
This class also periodically syncs output to the given upload uri.
|
||||
|
||||
Arguments:
|
||||
config: Configuration passed to all logger creators.
|
||||
logdir: Directory for all logger creators to log to.
|
||||
upload_uri (str): Optional URI where the logdir is sync'ed to.
|
||||
custom_loggers (list): List of custom logger creators.
|
||||
sync_function (func|str): Optional function for syncer to run.
|
||||
See ray/python/ray/tune/log_sync.py
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
config,
|
||||
logdir,
|
||||
upload_uri=None,
|
||||
custom_loggers=None,
|
||||
sync_function=None):
|
||||
self._logger_list = [_JsonLogger, _TFLogger, _VisKitLogger]
|
||||
self._sync_function = sync_function
|
||||
if custom_loggers:
|
||||
assert isinstance(custom_loggers, list), "Improper custom loggers."
|
||||
self._logger_list += custom_loggers
|
||||
|
||||
Logger.__init__(self, config, logdir, upload_uri)
|
||||
|
||||
def _init(self):
    """Instantiate every logger in ``self._logger_list`` and the syncer.

    A logger class whose constructor raises is reported and skipped, so
    one failing logger does not prevent the others from being created.
    """
    self._loggers = []
    for cls in self._logger_list:
        try:
            self._loggers.append(cls(self.config, self.logdir, self.uri))
        except Exception:
            # Best-effort: name the logger that failed (traceback is
            # included by logger.exception) and continue with the rest.
            logger.exception(
                "Could not instantiate {} - skipping.".format(cls))
    self._log_syncer = get_syncer(
        self.logdir, self.uri, sync_function=self._sync_function)
|
||||
|
||||
def on_result(self, result):
|
||||
for logger in self._loggers:
|
||||
|
||||
@@ -10,6 +10,7 @@ import unittest
|
||||
import ray
|
||||
from ray.rllib import _register_all
|
||||
|
||||
from ray import tune
|
||||
from ray.tune import Trainable, TuneError
|
||||
from ray.tune import register_env, register_trainable, run_experiments
|
||||
from ray.tune.ray_trial_executor import RayTrialExecutor
|
||||
@@ -17,6 +18,7 @@ from ray.tune.schedulers import TrialScheduler, FIFOScheduler
|
||||
from ray.tune.registry import _global_registry, TRAINABLE_CLASS
|
||||
from ray.tune.result import (DEFAULT_RESULTS_DIR, TIMESTEPS_TOTAL, DONE,
|
||||
EPISODES_TOTAL)
|
||||
from ray.tune.logger import Logger
|
||||
from ray.tune.util import pin_in_object_store, get_pinned_object
|
||||
from ray.tune.experiment import Experiment
|
||||
from ray.tune.trial import Trial, Resources
|
||||
@@ -679,6 +681,83 @@ class RunExperimentTest(unittest.TestCase):
|
||||
self.assertEqual(trial.status, Trial.TERMINATED)
|
||||
self.assertTrue(trial.has_checkpoint())
|
||||
|
||||
def testCustomLogger(self):
|
||||
class CustomLogger(Logger):
|
||||
def on_result(self, result):
|
||||
with open(os.path.join(self.logdir, "test.log"), "w") as f:
|
||||
f.write("hi")
|
||||
|
||||
[trial] = run_experiments({
|
||||
"foo": {
|
||||
"run": "__fake",
|
||||
"stop": {
|
||||
"training_iteration": 1
|
||||
},
|
||||
"custom_loggers": [CustomLogger]
|
||||
}
|
||||
})
|
||||
self.assertTrue(os.path.exists(os.path.join(trial.logdir, "test.log")))
|
||||
|
||||
def testCustomTrialString(self):
    """A ``trial_name_creator`` in the spec determines ``str(trial)``."""
    [trial] = run_experiments({
        "foo": {
            "run": "__fake",
            "stop": {
                "training_iteration": 1
            },
            "trial_name_creator": tune.function(
                lambda t: "{}_{}_321".format(t.trainable_name, t.trial_id))
        }
    })
    # assertEqual rather than the deprecated assertEquals alias.
    self.assertEqual(
        str(trial), "{}_{}_321".format(trial.trainable_name,
                                       trial.trial_id))
|
||||
|
||||
def testSyncFunction(self):
|
||||
def fail_sync_local():
|
||||
[trial] = run_experiments({
|
||||
"foo": {
|
||||
"run": "__fake",
|
||||
"stop": {
|
||||
"training_iteration": 1
|
||||
},
|
||||
"upload_dir": "test",
|
||||
"sync_function": "ls {remote_dir}"
|
||||
}
|
||||
})
|
||||
|
||||
self.assertRaises(AssertionError, fail_sync_local)
|
||||
|
||||
def fail_sync_remote():
|
||||
[trial] = run_experiments({
|
||||
"foo": {
|
||||
"run": "__fake",
|
||||
"stop": {
|
||||
"training_iteration": 1
|
||||
},
|
||||
"upload_dir": "test",
|
||||
"sync_function": "ls {local_dir}"
|
||||
}
|
||||
})
|
||||
|
||||
self.assertRaises(AssertionError, fail_sync_remote)
|
||||
|
||||
def sync_func(local, remote):
|
||||
with open(os.path.join(local, "test.log"), "w") as f:
|
||||
f.write(remote)
|
||||
|
||||
[trial] = run_experiments({
|
||||
"foo": {
|
||||
"run": "__fake",
|
||||
"stop": {
|
||||
"training_iteration": 1
|
||||
},
|
||||
"upload_dir": "test",
|
||||
"sync_function": tune.function(sync_func)
|
||||
}
|
||||
})
|
||||
self.assertTrue(os.path.exists(os.path.join(trial.logdir, "test.log")))
|
||||
|
||||
|
||||
class VariantGeneratorTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
|
||||
@@ -574,6 +574,7 @@ class _MockTrial(Trial):
|
||||
self.trainable_name = "trial_{}".format(i)
|
||||
self.config = config
|
||||
self.experiment_tag = "tag"
|
||||
self.trial_name_creator = None
|
||||
self.logger_running = False
|
||||
self.restored_checkpoint = None
|
||||
self.resources = Resources(1, 0)
|
||||
|
||||
@@ -124,6 +124,9 @@ class Trial(object):
|
||||
checkpoint_at_end=False,
|
||||
restore_path=None,
|
||||
upload_dir=None,
|
||||
trial_name_creator=None,
|
||||
custom_loggers=None,
|
||||
sync_function=None,
|
||||
max_failures=0):
|
||||
"""Initialize a new trial.
|
||||
|
||||
@@ -146,6 +149,9 @@ class Trial(object):
|
||||
or self._get_trainable_cls().default_resource_request(self.config))
|
||||
self.stopping_criterion = stopping_criterion or {}
|
||||
self.upload_dir = upload_dir
|
||||
self.trial_name_creator = trial_name_creator
|
||||
self.custom_loggers = custom_loggers
|
||||
self.sync_function = sync_function
|
||||
self.verbose = True
|
||||
self.max_failures = max_failures
|
||||
|
||||
@@ -160,10 +166,7 @@ class Trial(object):
|
||||
self.logdir = None
|
||||
self.result_logger = None
|
||||
self.last_debug = 0
|
||||
if trial_id is not None:
|
||||
self.trial_id = trial_id
|
||||
else:
|
||||
self.trial_id = Trial.generate_id()
|
||||
self.trial_id = Trial.generate_id() if trial_id is None else trial_id
|
||||
self.error_file = None
|
||||
self.num_failures = 0
|
||||
|
||||
@@ -181,8 +184,12 @@ class Trial(object):
|
||||
prefix="{}_{}".format(
|
||||
str(self)[:MAX_LEN_IDENTIFIER], date_str()),
|
||||
dir=self.local_dir)
|
||||
self.result_logger = UnifiedLogger(self.config, self.logdir,
|
||||
self.upload_dir)
|
||||
self.result_logger = UnifiedLogger(
|
||||
self.config,
|
||||
self.logdir,
|
||||
upload_uri=self.upload_dir,
|
||||
custom_loggers=self.custom_loggers,
|
||||
sync_function=self.sync_function)
|
||||
|
||||
def close_logger(self):
|
||||
"""Close logger."""
|
||||
@@ -316,7 +323,13 @@ class Trial(object):
|
||||
return str(self)
|
||||
|
||||
def __str__(self):
|
||||
"""Combines ``env`` with ``trainable_name`` and ``experiment_tag``."""
|
||||
"""Combines ``env`` with ``trainable_name`` and ``experiment_tag``.
|
||||
|
||||
Can be overridden with a custom string creator.
|
||||
"""
|
||||
if self.trial_name_creator:
|
||||
return self.trial_name_creator(self)
|
||||
|
||||
if "env" in self.config:
|
||||
env = self.config["env"]
|
||||
if isinstance(env, type):
|
||||
|
||||
Reference in New Issue
Block a user