[tune] Tune onto Logging Module (#2882)

Moves Tune onto logging in Python. Ignores examples and tests.
This commit is contained in:
Richard Liaw
2018-09-16 12:09:36 -07:00
committed by GitHub
parent a8248e8628
commit f372f48bf3
17 changed files with 121 additions and 63 deletions
+6 -4
View File
@@ -2,6 +2,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import logging
import os
import re
import django
@@ -11,6 +12,7 @@ from django.core.management import execute_from_command_line
from common.exception import DatabaseError
root_path = os.path.dirname(os.path.abspath(__file__))
logger = logging.getLogger(__name__)
def run_board(args):
@@ -33,7 +35,7 @@ def run_board(args):
service.run()
# frontend service
print("Try to start automlboard on port %s\n" % args.port)
logger.info("Try to start automlboard on port %s\n" % args.port)
command = [
os.path.join(root_path, 'manage.py'), 'runserver',
'0.0.0.0:%s' % args.port, '--noreload'
@@ -64,12 +66,12 @@ def init_config(args):
os.environ["AUTOMLBOARD_DB_HOST"] = match.group(4)
os.environ["AUTOMLBOARD_DB_PORT"] = match.group(5)
os.environ["AUTOMLBOARD_DB_NAME"] = match.group(6)
print("Using %s as the database backend." % match.group(1))
logger.info("Using %s as the database backend." % match.group(1))
except BaseException as e:
raise DatabaseError(e)
else:
print("Using sqlite3 as the database backend, "
"information will be stored in automlboard.db")
logger.info("Using sqlite3 as the database backend, "
"information will be stored in automlboard.db")
os.environ.setdefault("DJANGO_SETTINGS_MODULE",
"ray.tune.automlboard.settings")
+7 -3
View File
@@ -3,6 +3,7 @@ from __future__ import division
from __future__ import print_function
import copy
import logging
import six
import types
@@ -10,6 +11,8 @@ from ray.tune.result import DEFAULT_RESULTS_DIR
from ray.tune.error import TuneError
from ray.tune.registry import register_trainable
logger = logging.getLogger(__name__)
class Experiment(object):
"""Tracks experiment specifications.
@@ -158,7 +161,8 @@ class Experiment(object):
return run_object
elif isinstance(run_object, types.FunctionType):
if run_object.__name__ == "<lambda>":
print("Not auto-registering lambdas - resolving as variant.")
logger.warning(
"Not auto-registering lambdas - resolving as variant.")
return run_object
else:
name = run_object.__name__
@@ -202,8 +206,8 @@ def convert_to_experiment_list(experiments):
if (type(exp_list) is list
and all(isinstance(exp, Experiment) for exp in exp_list)):
if len(exp_list) > 1:
print("Warning: All experiments will be"
" using the same Search Algorithm.")
logger.warning("All experiments will be "
"using the same SearchAlgorithm.")
else:
raise TuneError("Invalid argument: {}".format(experiments))
+4 -2
View File
@@ -2,14 +2,16 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import logging
import time
import threading
import traceback
from ray.tune import TuneError
from ray.tune.trainable import Trainable
from ray.tune.result import TIMESTEPS_TOTAL
logger = logging.getLogger(__name__)
class StatusReporter(object):
"""Object passed into your main() that you can report status through.
@@ -74,7 +76,7 @@ class _RunnerThread(threading.Thread):
self._entrypoint(*self._entrypoint_args)
except Exception as e:
self._status_reporter._error = e
print("Runner thread raised: {}".format(traceback.format_exc()))
logger.exception("Runner Thread raised error.")
raise e
finally:
self._status_reporter._done = True
+11 -7
View File
@@ -3,6 +3,7 @@ from __future__ import division
from __future__ import print_function
import distutils.spawn
import logging
import os
import subprocess
import time
@@ -17,6 +18,8 @@ from ray.tune.cluster_info import get_ssh_key, get_ssh_user
from ray.tune.error import TuneError
from ray.tune.result import DEFAULT_RESULTS_DIR
logger = logging.getLogger(__name__)
# Map from (logdir, remote_dir) -> syncer
_syncers = {}
@@ -73,7 +76,8 @@ class _LogSyncer(object):
self.sync_process = None
self.local_ip = ray.services.get_node_ip_address()
self.worker_ip = None
print("Created LogSyncer for {} -> {}".format(local_dir, remote_dir))
logger.info("Created LogSyncer for {} -> {}".format(
local_dir, remote_dir))
def set_worker_ip(self, worker_ip):
"""Set the worker ip to sync logs from."""
@@ -87,7 +91,7 @@ class _LogSyncer(object):
def sync_now(self, force=False):
self.last_sync_time = time.time()
if not self.worker_ip:
print("Worker ip unknown, skipping log sync for {}".format(
logger.info("Worker ip unknown, skipping log sync for {}".format(
self.local_dir))
return
@@ -97,11 +101,11 @@ class _LogSyncer(object):
ssh_key = get_ssh_key()
ssh_user = get_ssh_user()
if ssh_key is None or ssh_user is None:
print("Error: log sync requires cluster to be setup with "
"`ray create_or_update`.")
logger.error("Log sync requires cluster to be setup with "
"`ray create_or_update`.")
return
if not distutils.spawn.find_executable("rsync"):
print("Error: log sync requires rsync to be installed.")
logger.error("Log sync requires rsync to be installed.")
return
worker_to_local_sync_cmd = ((
"""rsync -avz -e "ssh -i {} -o ConnectTimeout=120s """
@@ -125,7 +129,7 @@ class _LogSyncer(object):
if force:
self.sync_process.kill()
else:
print("Warning: last sync is still in progress, skipping")
logger.warning("Last sync is still in progress, skipping.")
return
if worker_to_local_sync_cmd or local_to_remote_sync_cmd:
@@ -136,7 +140,7 @@ class _LogSyncer(object):
if final_cmd:
final_cmd += " && "
final_cmd += local_to_remote_sync_cmd
print("Running log sync: {}".format(final_cmd))
logger.info("Running log sync: {}".format(final_cmd))
self.sync_process = subprocess.Popen(final_cmd, shell=True)
def wait(self):
+7 -2
View File
@@ -4,6 +4,7 @@ from __future__ import print_function
import csv
import json
import logging
import numpy as np
import os
import yaml
@@ -12,11 +13,14 @@ from ray.tune.log_sync import get_syncer
from ray.tune.result import NODE_IP, TRAINING_ITERATION, TIME_TOTAL_S, \
TIMESTEPS_TOTAL
logger = logging.getLogger(__name__)
try:
import tensorflow as tf
except ImportError:
tf = None
print("Couldn't import TensorFlow - this disables TensorBoard logging.")
logger.warning("Couldn't import TensorFlow - "
"disabling TensorBoard logging.")
class Logger(object):
@@ -60,7 +64,8 @@ class UnifiedLogger(Logger):
self._loggers = []
for cls in [_JsonLogger, _TFLogger, _VisKitLogger]:
if cls is _TFLogger and tf is None:
print("TF not installed - cannot log with {}...".format(cls))
logger.info("TF not installed - "
"cannot log with {}...".format(cls))
continue
self._loggers.append(cls(self.config, self.logdir, self.uri))
self._log_syncer = get_syncer(self.logdir, self.uri)
+15 -11
View File
@@ -3,6 +3,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import logging
import os
import time
import traceback
@@ -12,6 +13,8 @@ from ray.tune.logger import NoopLogger
from ray.tune.trial import Trial, Resources, Checkpoint
from ray.tune.trial_executor import TrialExecutor
logger = logging.getLogger(__name__)
class RayTrialExecutor(TrialExecutor):
"""An implemention of TrialExecutor based on Ray."""
@@ -97,7 +100,7 @@ class RayTrialExecutor(TrialExecutor):
_, unfinished = ray.wait(
stop_tasks, num_returns=2, timeout=250)
except Exception:
print("Error stopping runner:", traceback.format_exc())
logger.exception("Error stopping runner.")
trial.status = Trial.ERROR
finally:
trial.runner = None
@@ -112,15 +115,15 @@ class RayTrialExecutor(TrialExecutor):
try:
self._start_trial(trial, checkpoint_obj)
except Exception:
logger.exception("Error stopping runner - retrying...")
error_msg = traceback.format_exc()
print("Error starting runner, retrying:", error_msg)
time.sleep(2)
self._stop_trial(trial, error=True, error_msg=error_msg)
try:
self._start_trial(trial)
except Exception:
logger.exception("Error starting runner, aborting!")
error_msg = traceback.format_exc()
print("Error starting runner, abort:", error_msg)
self._stop_trial(trial, error=True, error_msg=error_msg)
# note that we don't return the resources, since they may
# have been lost
@@ -245,12 +248,13 @@ class RayTrialExecutor(TrialExecutor):
can_overcommit = False # requested resource is already saturated
if can_overcommit:
print("WARNING:tune:allowing trial to start even though the "
"cluster does not have enough free resources. Trial actors "
"may appear to hang until enough resources are added to the "
"cluster (e.g., via autoscaling). You can disable this "
"behavior by specifying `queue_trials=False` in "
"ray.tune.run_experiments().")
logger.warning(
"Allowing trial to start even though the "
"cluster does not have enough free resources. Trial actors "
"may appear to hang until enough resources are added to the "
"cluster (e.g., via autoscaling). You can disable this "
"behavior by specifying `queue_trials=False` in "
"ray.tune.run_experiments().")
return True
return False
@@ -286,7 +290,7 @@ class RayTrialExecutor(TrialExecutor):
if checkpoint is None or checkpoint.value is None:
return True
if trial.runner is None:
print("Unable to restore - no runner")
logger.error("Unable to restore - no runner.")
trial.status = Trial.ERROR
return False
try:
@@ -298,6 +302,6 @@ class RayTrialExecutor(TrialExecutor):
ray.get(trial.runner.restore.remote(value))
return True
except Exception:
print("Error restoring runner:", traceback.format_exc())
logger.exception("Error restoring runner.")
trial.status = Trial.ERROR
return False
@@ -2,10 +2,13 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import logging
import numpy as np
from ray.tune.schedulers.trial_scheduler import FIFOScheduler, TrialScheduler
logger = logging.getLogger(__name__)
class AsyncHyperBandScheduler(FIFOScheduler):
"""Implements the Async Successive Halving.
@@ -133,8 +136,8 @@ class _Bracket():
if cutoff is not None and cur_rew < cutoff:
action = TrialScheduler.STOP
if cur_rew is None:
print("Reward attribute is None! Consider"
" reporting using a different field.")
logger.warning("Reward attribute is None! Consider"
" reporting using a different field.")
else:
recorded[trial.trial_id] = cur_rew
break
+4 -1
View File
@@ -4,10 +4,13 @@ from __future__ import print_function
import collections
import numpy as np
import logging
from ray.tune.schedulers.trial_scheduler import FIFOScheduler, TrialScheduler
from ray.tune.trial import Trial
logger = logging.getLogger(__name__)
# Implementation notes:
# This implementation contains 3 logical levels.
@@ -112,7 +115,7 @@ class HyperBandScheduler(FIFOScheduler):
s = len(cur_band)
assert s < self._s_max_1, "Current band is filled!"
if self._get_r0(s) == 0:
print("Bracket too small - Retrying...")
logger.info("Bracket too small - Retrying...")
cur_bracket = None
else:
retry = False
@@ -3,11 +3,14 @@ from __future__ import division
from __future__ import print_function
import collections
import logging
import numpy as np
from ray.tune.trial import Trial
from ray.tune.schedulers.trial_scheduler import FIFOScheduler, TrialScheduler
logger = logging.getLogger(__name__)
class MedianStoppingRule(FIFOScheduler):
"""Implements the median stopping rule as described in the Vizier paper:
@@ -67,11 +70,12 @@ class MedianStoppingRule(FIFOScheduler):
median_result = self._get_median_result(time)
best_result = self._best_result(trial)
if self._verbose:
print("Trial {} best res={} vs median res={} at t={}".format(
logger.info("Trial {} best res={} vs median res={} at t={}".format(
trial, best_result, median_result, time))
if best_result < median_result and time > self._grace_period:
if self._verbose:
print("MedianStoppingRule: early stopping {}".format(trial))
logger.info("MedianStoppingRule: "
"early stopping {}".format(trial))
self._stopped_trials.add(trial)
if self._hard_stop:
return TrialScheduler.STOP
+10 -6
View File
@@ -5,12 +5,15 @@ from __future__ import print_function
import random
import math
import copy
import logging
from ray.tune.error import TuneError
from ray.tune.trial import Trial, Checkpoint
from ray.tune.schedulers import FIFOScheduler, TrialScheduler
from ray.tune.suggest.variant_generator import format_vars
logger = logging.getLogger(__name__)
# Parameters are transferred from the top PBT_QUANTILE fraction of trials to
# the bottom PBT_QUANTILE fraction.
PBT_QUANTILE = 0.25
@@ -69,7 +72,7 @@ def explore(config, mutations, resample_probability, custom_explore_fn):
new_config = custom_explore_fn(new_config)
assert new_config is not None, \
"Custom explore fn failed to return new config"
print("[explore] perturbed config from {} -> {}".format(
logger.info("[explore] perturbed config from {} -> {}".format(
config, new_config))
return new_config
@@ -210,15 +213,16 @@ class PopulationBasedTraining(FIFOScheduler):
trial_state = self._trial_state[trial]
new_state = self._trial_state[trial_to_clone]
if not new_state.last_checkpoint:
print("[pbt] warn: no checkpoint for trial, skip exploit", trial)
logger.warning("[pbt]: no checkpoint for trial"
"skip exploit for Trial {}".format(trial))
return
new_config = explore(trial_to_clone.config, self._hyperparam_mutations,
self._resample_probability,
self._custom_explore_fn)
print("[exploit] transferring weights from trial "
"{} (score {}) -> {} (score {})".format(
trial_to_clone, new_state.last_score, trial,
trial_state.last_score))
logger.warning("[exploit] transferring weights from trial "
"{} (score {}) -> {} (score {})".format(
trial_to_clone, new_state.last_score, trial,
trial_state.last_score))
# TODO(ekl) restarting the trial is expensive. We should implement a
# lighter way reset() method that can alter the trial config.
new_tag = make_experiment_tag(trial_state.orig_tag, new_config,
+5 -1
View File
@@ -6,6 +6,7 @@ from datetime import datetime
import gzip
import io
import logging
import os
import pickle
import shutil
@@ -19,6 +20,8 @@ from ray.tune.result import (DEFAULT_RESULTS_DIR, TIME_THIS_ITER_S,
TIMESTEPS_THIS_ITER, DONE, TIMESTEPS_TOTAL)
from ray.tune.trial import Resources
logger = logging.getLogger(__name__)
class Trainable(object):
"""Abstract class for trainable models, functions, etc.
@@ -231,7 +234,8 @@ class Trainable(object):
"data": data,
})
if len(compressed) > 10e6: # getting pretty large
print("Checkpoint size is {} bytes".format(len(compressed)))
logger.info("Checkpoint size is {} bytes".format(
len(compressed)))
f.write(compressed)
shutil.rmtree(tmpdir)
+7 -4
View File
@@ -2,13 +2,14 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tempfile
from collections import namedtuple
from datetime import datetime
import logging
import time
import ray
import tempfile
import os
import ray
from ray.tune import TuneError
from ray.tune.logger import pretty_print, UnifiedLogger
# NOTE(rkn): We import ray.tune.registry here instead of importing the names we
@@ -21,6 +22,7 @@ from ray.utils import random_string, binary_to_hex
DEBUG_PRINT_INTERVAL = 5
MAX_LEN_IDENTIFIER = 130
logger = logging.getLogger(__name__)
def date_str():
@@ -274,8 +276,9 @@ class Trial(object):
result.update(done=True)
if self.verbose and (terminate or time.time() - self.last_debug >
DEBUG_PRINT_INTERVAL):
print("Result for {}:".format(self))
print(" {}".format(pretty_print(result).replace("\n", "\n ")))
logger.info("Result for {}:".format(self))
logger.info(" {}".format(
pretty_print(result).replace("\n", "\n ")))
self.last_debug = time.time()
self.last_result = result
self.result_logger.on_result(self.last_result)
+7 -3
View File
@@ -3,10 +3,13 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import logging
import traceback
from ray.tune.trial import Trial, Checkpoint
logger = logging.getLogger(__name__)
class TrialExecutor(object):
"""Manages platform-specific details such as resource handling
@@ -66,14 +69,15 @@ class TrialExecutor(object):
error_msg (str): Optional error message.
"""
try:
print("Attempting to recover trial state from last checkpoint")
logger.info(
"Attempting to recover trial state from last checkpoint")
self.stop_trial(
trial, error=True, error_msg=error_msg, stop_logger=False)
trial.result_logger.flush()
self.start_trial(trial)
except Exception:
error_msg = traceback.format_exc()
print("Error recovering trial from checkpoint, abort:", error_msg)
logger.exception("Error recovering trial from checkpoint, abort.")
self.stop_trial(trial, error=True, error_msg=error_msg)
def continue_training(self, trial):
@@ -92,7 +96,7 @@ class TrialExecutor(object):
self.stop_trial(trial, stop_logger=False)
trial.status = Trial.PAUSED
except Exception:
print("Error pausing runner:", traceback.format_exc())
logger.exception("Error pausing runner.")
trial.status = Trial.ERROR
def unpause_trial(self, trial):
+10 -6
View File
@@ -3,6 +3,7 @@ from __future__ import division
from __future__ import print_function
import collections
import logging
import os
import re
import time
@@ -17,6 +18,8 @@ from ray.tune.web_server import TuneServer
MAX_DEBUG_TRIALS = 20
logger = logging.getLogger(__name__)
def _naturalize(string):
"""Provides a natural representation for string for nice sorting."""
@@ -92,7 +95,7 @@ class TrialRunner(object):
"""Returns whether all trials have finished running."""
if self._total_time > self._global_time_limit:
print("Exceeded global time limit {} / {}".format(
logger.warning("Exceeded global time limit {} / {}".format(
self._total_time, self._global_time_limit))
return True
@@ -270,8 +273,8 @@ class TrialRunner(object):
assert False, "Invalid scheduling decision: {}".format(
decision)
except Exception:
logger.exception("Error processing event.")
error_msg = traceback.format_exc()
print("Error processing event:", error_msg)
if trial.status == Trial.RUNNING:
if trial.has_checkpoint() and \
trial.num_failures < trial.max_failures:
@@ -284,11 +287,12 @@ class TrialRunner(object):
def _try_recover(self, trial, error_msg):
try:
print("Attempting to recover trial state from last checkpoint")
logger.info("Attempting to recover"
" trial state from last checkpoint.")
self.trial_executor.restart_trial(trial, error_msg)
except Exception:
error_msg = traceback.format_exc()
print("Error recovering trial from checkpoint, abort:", error_msg)
logger.warning("Error recovering trial from checkpoint, abort.")
self.trial_executor.stop_trial(trial, True, error_msg=error_msg)
def _update_trial_queue(self, blocking=False, timeout=600):
@@ -307,7 +311,7 @@ class TrialRunner(object):
start = time.time()
while (not trials and not self.is_finished()
and time.time() - start < timeout):
print("Blocking for next trial...")
logger.info("Blocking for next trial...")
trials = self._search_alg.next_trials()
time.sleep(1)
@@ -349,7 +353,7 @@ class TrialRunner(object):
trial.trial_id, result=result)
except Exception:
error_msg = traceback.format_exc()
print("Error processing event:", error_msg)
logger.exception("Error processing event.")
self._scheduler_alg.on_trial_error(self, trial)
self._search_alg.on_trial_complete(trial.trial_id, error=True)
error = True
+6 -3
View File
@@ -2,6 +2,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import logging
import time
from ray.tune.error import TuneError
@@ -13,6 +14,8 @@ from ray.tune.schedulers import (HyperBandScheduler, AsyncHyperBandScheduler,
FIFOScheduler, MedianStoppingRule)
from ray.tune.web_server import TuneServer
logger = logging.getLogger(__name__)
_SCHEDULERS = {
"FIFO": FIFOScheduler,
"MedianStopping": MedianStoppingRule,
@@ -95,16 +98,16 @@ def run_experiments(experiments=None,
queue_trials=queue_trials,
trial_executor=trial_executor)
print(runner.debug_string(max_debug=99999))
logger.info(runner.debug_string(max_debug=99999))
last_debug = 0
while not runner.is_finished():
runner.step()
if time.time() - last_debug > DEBUG_PRINT_INTERVAL:
print(runner.debug_string())
logger.info(runner.debug_string())
last_debug = time.time()
print(runner.debug_string(max_debug=99999))
logger.info(runner.debug_string(max_debug=99999))
errored_trials = []
for trial in runner.get_trials():
+5 -3
View File
@@ -4,12 +4,14 @@ from __future__ import print_function
import pandas as pd
from pandas.api.types import is_string_dtype, is_numeric_dtype
import logging
import os
import os.path as osp
import numpy as np
import json
logger = logging.getLogger(__name__)
def _flatten_dict(dt):
while any(type(v) is dict for v in dt.values()):
@@ -35,7 +37,7 @@ def _parse_results(res_path):
pass
res_dict = _flatten_dict(json.loads(line.strip()))
except Exception as e:
print("Importing %s failed...Perhaps empty?" % res_path, e)
logger.exception("Importing %s failed...Perhaps empty?" % res_path)
return res_dict
@@ -44,7 +46,7 @@ def _parse_configs(cfg_path):
with open(cfg_path) as f:
cfg_dict = _flatten_dict(json.load(f))
except Exception as e:
print(e)
logger.exception("Config parsing failed.")
return cfg_dict
+6 -3
View File
@@ -3,6 +3,7 @@ from __future__ import division
from __future__ import print_function
import json
import logging
import sys
import threading
@@ -15,12 +16,14 @@ if sys.version_info[0] == 2:
elif sys.version_info[0] == 3:
from http.server import SimpleHTTPRequestHandler, HTTPServer
logger = logging.getLogger(__name__)
try:
import requests # `requests` is not part of stdlib.
except ImportError:
requests = None
print("Couldn't import `requests` library. Be sure to install it on"
" the client side.")
logger.exception("Couldn't import `requests` library. "
"Be sure to install it on the client side.")
class TuneClient(object):
@@ -149,7 +152,7 @@ class TuneServer(threading.Thread):
threading.Thread.__init__(self)
self._port = port if port else self.DEFAULT_PORT
address = ('localhost', self._port)
print("Starting Tune Server...")
logger.info("Starting Tune Server...")
self._server = HTTPServer(address, RunnerHandler(runner))
self.start()