mirror of
https://github.com/wassname/ray.git
synced 2026-07-06 01:59:56 +08:00
[tune] Introduce ability to turn off default logging. (#4104)
This commit is contained in:
@@ -10,13 +10,7 @@ from ray.tune.trainable import Trainable
|
||||
from ray.tune.suggest import grid_search, function, sample_from
|
||||
|
||||
__all__ = [
|
||||
"Trainable",
|
||||
"TuneError",
|
||||
"grid_search",
|
||||
"register_env",
|
||||
"register_trainable",
|
||||
"run_experiments",
|
||||
"Experiment",
|
||||
"function",
|
||||
"sample_from",
|
||||
"Trainable", "TuneError", "grid_search", "register_env",
|
||||
"register_trainable", "run_experiments", "Experiment", "function",
|
||||
"sample_from"
|
||||
]
|
||||
|
||||
@@ -88,9 +88,10 @@ def make_parser(parser_creator=None, **kwargs):
|
||||
"then it must be a string template for syncer to run and needs to "
|
||||
"include replacement fields '{local_dir}' and '{remote_dir}'.")
|
||||
parser.add_argument(
|
||||
"--custom-loggers",
|
||||
"--loggers",
|
||||
default=None,
|
||||
help="List of custom logger creators to be used with each Trial.")
|
||||
help="List of logger creators to be used with each Trial. "
|
||||
"Defaults to ray.tune.logger.DEFAULT_LOGGERS.")
|
||||
parser.add_argument(
|
||||
"--checkpoint-freq",
|
||||
default=0,
|
||||
@@ -192,7 +193,7 @@ def create_trial_from_spec(spec, output_path, parser, **trial_kwargs):
|
||||
restore_path=spec.get("restore"),
|
||||
upload_dir=args.upload_dir,
|
||||
trial_name_creator=spec.get("trial_name_creator"),
|
||||
custom_loggers=spec.get("custom_loggers"),
|
||||
loggers=spec.get("loggers"),
|
||||
# str(None) doesn't create None
|
||||
sync_function=spec.get("sync_function"),
|
||||
max_failures=args.max_failures,
|
||||
|
||||
@@ -66,7 +66,7 @@ if __name__ == "__main__":
|
||||
run=MyTrainableClass,
|
||||
num_samples=1,
|
||||
trial_name_creator=tune.function(trial_str_creator),
|
||||
custom_loggers=[TestLogger],
|
||||
loggers=[TestLogger],
|
||||
stop={"training_iteration": 1 if args.smoke_test else 99999},
|
||||
config={
|
||||
"width": tune.sample_from(
|
||||
|
||||
@@ -62,8 +62,9 @@ class Experiment(object):
|
||||
to (e.g. ``s3://bucket``).
|
||||
trial_name_creator (func): Optional function for generating
|
||||
the trial string representation.
|
||||
custom_loggers (list): List of custom logger creators to be used with
|
||||
each Trial. See `ray/tune/logger.py`.
|
||||
loggers (list): List of logger creators to be used with
|
||||
each Trial. If None, defaults to ray.tune.logger.DEFAULT_LOGGERS.
|
||||
See `ray/tune/logger.py`.
|
||||
sync_function (func|str): Function for syncing the local_dir to
|
||||
upload_dir. If string, then it must be a string template for
|
||||
syncer to run. If not provided, the sync command defaults
|
||||
@@ -84,6 +85,8 @@ class Experiment(object):
|
||||
Ray. Use `num_samples` instead.
|
||||
trial_resources: Deprecated and will be removed in future versions of
|
||||
Ray. Use `resources_per_trial` instead.
|
||||
custom_loggers: Deprecated and will be removed in future versions of
|
||||
Ray. Use `loggers` instead.
|
||||
|
||||
|
||||
Examples:
|
||||
@@ -117,6 +120,7 @@ class Experiment(object):
|
||||
local_dir=None,
|
||||
upload_dir=None,
|
||||
trial_name_creator=None,
|
||||
loggers=None,
|
||||
custom_loggers=None,
|
||||
sync_function=None,
|
||||
checkpoint_freq=0,
|
||||
@@ -135,6 +139,8 @@ class Experiment(object):
|
||||
_raise_deprecation_note(
|
||||
"trial_resources", "resources_per_trial", soft=True)
|
||||
resources_per_trial = trial_resources
|
||||
if custom_loggers:
|
||||
_raise_deprecation_note("custom_loggers", "loggers", soft=False)
|
||||
|
||||
spec = {
|
||||
"run": Experiment._register_if_needed(run),
|
||||
@@ -145,7 +151,7 @@ class Experiment(object):
|
||||
"local_dir": os.path.expanduser(local_dir or DEFAULT_RESULTS_DIR),
|
||||
"upload_dir": upload_dir or "", # argparse converts None to "null"
|
||||
"trial_name_creator": trial_name_creator,
|
||||
"custom_loggers": custom_loggers,
|
||||
"loggers": loggers,
|
||||
"sync_function": sync_function,
|
||||
"checkpoint_freq": checkpoint_freq,
|
||||
"checkpoint_at_end": checkpoint_at_end,
|
||||
|
||||
+76
-72
@@ -66,81 +66,12 @@ class Logger(object):
|
||||
pass
|
||||
|
||||
|
||||
class UnifiedLogger(Logger):
|
||||
"""Unified result logger for TensorBoard, rllab/viskit, plain json.
|
||||
|
||||
This class also periodically syncs output to the given upload uri.
|
||||
|
||||
Arguments:
|
||||
config: Configuration passed to all logger creators.
|
||||
logdir: Directory for all logger creators to log to.
|
||||
upload_uri (str): Optional URI where the logdir is sync'ed to.
|
||||
custom_loggers (list): List of custom logger creators.
|
||||
sync_function (func|str): Optional function for syncer to run.
|
||||
See ray/python/ray/tune/log_sync.py
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
config,
|
||||
logdir,
|
||||
upload_uri=None,
|
||||
custom_loggers=None,
|
||||
sync_function=None):
|
||||
self._logger_list = [_JsonLogger, _TFLogger, _CSVLogger]
|
||||
self._sync_function = sync_function
|
||||
self._log_syncer = None
|
||||
if custom_loggers:
|
||||
assert isinstance(custom_loggers, list), "Improper custom loggers."
|
||||
self._logger_list += custom_loggers
|
||||
|
||||
Logger.__init__(self, config, logdir, upload_uri)
|
||||
|
||||
def _init(self):
|
||||
self._loggers = []
|
||||
for cls in self._logger_list:
|
||||
try:
|
||||
self._loggers.append(cls(self.config, self.logdir, self.uri))
|
||||
except Exception:
|
||||
logger.warning("Could not instantiate {} - skipping.".format(
|
||||
str(cls)))
|
||||
self._log_syncer = get_syncer(
|
||||
self.logdir, self.uri, sync_function=self._sync_function)
|
||||
|
||||
def on_result(self, result):
|
||||
for _logger in self._loggers:
|
||||
_logger.on_result(result)
|
||||
self._log_syncer.set_worker_ip(result.get(NODE_IP))
|
||||
self._log_syncer.sync_if_needed()
|
||||
|
||||
def close(self):
|
||||
for _logger in self._loggers:
|
||||
_logger.close()
|
||||
self._log_syncer.sync_now(force=True)
|
||||
self._log_syncer.close()
|
||||
|
||||
def flush(self):
|
||||
for _logger in self._loggers:
|
||||
_logger.flush()
|
||||
self._log_syncer.sync_now(force=True)
|
||||
self._log_syncer.wait()
|
||||
|
||||
def sync_results_to_new_location(self, worker_ip):
|
||||
"""Sends the current log directory to the remote node.
|
||||
|
||||
Syncing will not occur if the cluster is not started
|
||||
with the Ray autoscaler.
|
||||
"""
|
||||
if worker_ip != self._log_syncer.worker_ip:
|
||||
self._log_syncer.set_worker_ip(worker_ip)
|
||||
self._log_syncer.sync_to_worker_if_possible()
|
||||
|
||||
|
||||
class NoopLogger(Logger):
|
||||
def on_result(self, result):
|
||||
pass
|
||||
|
||||
|
||||
class _JsonLogger(Logger):
|
||||
class JsonLogger(Logger):
|
||||
def _init(self):
|
||||
config_out = os.path.join(self.logdir, "params.json")
|
||||
with open(config_out, "w") as f:
|
||||
@@ -188,7 +119,7 @@ def to_tf_values(result, path):
|
||||
return values
|
||||
|
||||
|
||||
class _TFLogger(Logger):
|
||||
class TFLogger(Logger):
|
||||
def _init(self):
|
||||
self._file_writer = tf.summary.FileWriter(self.logdir)
|
||||
|
||||
@@ -217,7 +148,7 @@ class _TFLogger(Logger):
|
||||
self._file_writer.close()
|
||||
|
||||
|
||||
class _CSVLogger(Logger):
|
||||
class CSVLogger(Logger):
|
||||
def _init(self):
|
||||
"""CSV outputted with Headers as first set of results."""
|
||||
# Note that we assume params.json was already created by JsonLogger
|
||||
@@ -242,6 +173,79 @@ class _CSVLogger(Logger):
|
||||
self._file.close()
|
||||
|
||||
|
||||
DEFAULT_LOGGERS = (JsonLogger, CSVLogger, TFLogger)
|
||||
|
||||
|
||||
class UnifiedLogger(Logger):
|
||||
"""Unified result logger for TensorBoard, rllab/viskit, plain json.
|
||||
|
||||
This class also periodically syncs output to the given upload uri.
|
||||
|
||||
Arguments:
|
||||
config: Configuration passed to all logger creators.
|
||||
logdir: Directory for all logger creators to log to.
|
||||
upload_uri (str): Optional URI where the logdir is sync'ed to.
|
||||
loggers (list): List of logger creators. Defaults to CSV, Tensorboard,
|
||||
and JSON loggers.
|
||||
sync_function (func|str): Optional function for syncer to run.
|
||||
See ray/python/ray/tune/log_sync.py
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
config,
|
||||
logdir,
|
||||
upload_uri=None,
|
||||
loggers=None,
|
||||
sync_function=None):
|
||||
if loggers is None:
|
||||
self._logger_cls_list = DEFAULT_LOGGERS
|
||||
else:
|
||||
self._logger_cls_list = loggers
|
||||
self._sync_function = sync_function
|
||||
self._log_syncer = None
|
||||
|
||||
Logger.__init__(self, config, logdir, upload_uri)
|
||||
|
||||
def _init(self):
|
||||
self._loggers = []
|
||||
for cls in self._logger_cls_list:
|
||||
try:
|
||||
self._loggers.append(cls(self.config, self.logdir, self.uri))
|
||||
except Exception:
|
||||
logger.warning("Could not instantiate {} - skipping.".format(
|
||||
str(cls)))
|
||||
self._log_syncer = get_syncer(
|
||||
self.logdir, self.uri, sync_function=self._sync_function)
|
||||
|
||||
def on_result(self, result):
|
||||
for _logger in self._loggers:
|
||||
_logger.on_result(result)
|
||||
self._log_syncer.set_worker_ip(result.get(NODE_IP))
|
||||
self._log_syncer.sync_if_needed()
|
||||
|
||||
def close(self):
|
||||
for _logger in self._loggers:
|
||||
_logger.close()
|
||||
self._log_syncer.sync_now(force=True)
|
||||
self._log_syncer.close()
|
||||
|
||||
def flush(self):
|
||||
for _logger in self._loggers:
|
||||
_logger.flush()
|
||||
self._log_syncer.sync_now(force=True)
|
||||
self._log_syncer.wait()
|
||||
|
||||
def sync_results_to_new_location(self, worker_ip):
|
||||
"""Sends the current log directory to the remote node.
|
||||
|
||||
Syncing will not occur if the cluster is not started
|
||||
with the Ray autoscaler.
|
||||
"""
|
||||
if worker_ip != self._log_syncer.worker_ip:
|
||||
self._log_syncer.set_worker_ip(worker_ip)
|
||||
self._log_syncer.sync_to_worker_if_possible()
|
||||
|
||||
|
||||
class _SafeFallbackEncoder(json.JSONEncoder):
|
||||
def __init__(self, nan_str="null", **kwargs):
|
||||
super(_SafeFallbackEncoder, self).__init__(**kwargs)
|
||||
|
||||
@@ -793,10 +793,35 @@ class RunExperimentTest(unittest.TestCase):
|
||||
"stop": {
|
||||
"training_iteration": 1
|
||||
},
|
||||
"custom_loggers": [CustomLogger]
|
||||
"loggers": [CustomLogger]
|
||||
}
|
||||
})
|
||||
self.assertTrue(os.path.exists(os.path.join(trial.logdir, "test.log")))
|
||||
self.assertFalse(
|
||||
os.path.exists(os.path.join(trial.logdir, "params.json")))
|
||||
|
||||
[trial] = run_experiments({
|
||||
"foo": {
|
||||
"run": "__fake",
|
||||
"stop": {
|
||||
"training_iteration": 1
|
||||
}
|
||||
}
|
||||
})
|
||||
self.assertTrue(
|
||||
os.path.exists(os.path.join(trial.logdir, "params.json")))
|
||||
|
||||
[trial] = run_experiments({
|
||||
"foo": {
|
||||
"run": "__fake",
|
||||
"stop": {
|
||||
"training_iteration": 1
|
||||
},
|
||||
"loggers": []
|
||||
}
|
||||
})
|
||||
self.assertFalse(
|
||||
os.path.exists(os.path.join(trial.logdir, "params.json")))
|
||||
|
||||
def testCustomTrialString(self):
|
||||
[trial] = run_experiments({
|
||||
|
||||
@@ -256,7 +256,7 @@ class Trial(object):
|
||||
restore_path=None,
|
||||
upload_dir=None,
|
||||
trial_name_creator=None,
|
||||
custom_loggers=None,
|
||||
loggers=None,
|
||||
sync_function=None,
|
||||
max_failures=0):
|
||||
"""Initialize a new trial.
|
||||
@@ -276,7 +276,7 @@ class Trial(object):
|
||||
or self._get_trainable_cls().default_resource_request(self.config))
|
||||
self.stopping_criterion = stopping_criterion or {}
|
||||
self.upload_dir = upload_dir
|
||||
self.custom_loggers = custom_loggers
|
||||
self.loggers = loggers
|
||||
self.sync_function = sync_function
|
||||
validate_sync_function(sync_function)
|
||||
self.verbose = True
|
||||
@@ -333,7 +333,7 @@ class Trial(object):
|
||||
self.config,
|
||||
self.logdir,
|
||||
upload_uri=self.upload_dir,
|
||||
custom_loggers=self.custom_loggers,
|
||||
loggers=self.loggers,
|
||||
sync_function=self.sync_function)
|
||||
|
||||
def sync_logger_to_new_location(self, worker_ip):
|
||||
@@ -509,10 +509,11 @@ class Trial(object):
|
||||
state = self.__dict__.copy()
|
||||
state["resources"] = resources_to_json(self.resources)
|
||||
|
||||
# These are non-pickleable entries.
|
||||
pickle_data = {
|
||||
"_checkpoint": self._checkpoint,
|
||||
"config": self.config,
|
||||
"custom_loggers": self.custom_loggers,
|
||||
"loggers": self.loggers,
|
||||
"sync_function": self.sync_function,
|
||||
"last_result": self.last_result
|
||||
}
|
||||
@@ -535,7 +536,7 @@ class Trial(object):
|
||||
logger_started = state.pop("__logger_started__")
|
||||
state["resources"] = json_to_resources(state["resources"])
|
||||
for key in [
|
||||
"_checkpoint", "config", "custom_loggers", "sync_function",
|
||||
"_checkpoint", "config", "loggers", "sync_function",
|
||||
"last_result"
|
||||
]:
|
||||
state[key] = cloudpickle.loads(hex_to_binary(state[key]))
|
||||
|
||||
Reference in New Issue
Block a user