mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 22:53:20 +08:00
[tune] tune.track -> tune.report (#8388)
This commit is contained in:
@@ -7,33 +7,18 @@ from ray.tune.registry import register_env, register_trainable
|
||||
from ray.tune.trainable import Trainable
|
||||
from ray.tune.durable_trainable import DurableTrainable
|
||||
from ray.tune.suggest import grid_search
|
||||
from ray.tune.session import (report, get_trial_dir, get_trial_name,
|
||||
get_trial_id)
|
||||
from ray.tune.progress_reporter import (ProgressReporter, CLIReporter,
|
||||
JupyterNotebookReporter)
|
||||
from ray.tune.sample import (function, sample_from, uniform, choice, randint,
|
||||
randn, loguniform)
|
||||
|
||||
__all__ = [
|
||||
"Trainable",
|
||||
"DurableTrainable",
|
||||
"TuneError",
|
||||
"grid_search",
|
||||
"register_env",
|
||||
"register_trainable",
|
||||
"run",
|
||||
"run_experiments",
|
||||
"Stopper",
|
||||
"Experiment",
|
||||
"function",
|
||||
"sample_from",
|
||||
"track",
|
||||
"uniform",
|
||||
"choice",
|
||||
"randint",
|
||||
"randn",
|
||||
"loguniform",
|
||||
"ExperimentAnalysis",
|
||||
"Analysis",
|
||||
"CLIReporter",
|
||||
"JupyterNotebookReporter",
|
||||
"ProgressReporter",
|
||||
"Trainable", "DurableTrainable", "TuneError", "grid_search",
|
||||
"register_env", "register_trainable", "run", "run_experiments", "Stopper",
|
||||
"Experiment", "function", "sample_from", "track", "uniform", "choice",
|
||||
"randint", "randn", "loguniform", "ExperimentAnalysis", "Analysis",
|
||||
"CLIReporter", "JupyterNotebookReporter", "ProgressReporter", "report",
|
||||
"get_trial_dir", "get_trial_name", "get_trial_id"
|
||||
]
|
||||
|
||||
@@ -10,7 +10,7 @@ from ray import tune
|
||||
def LightGBMCallback(env):
|
||||
"""Assumes that `valid_0` is the target validation score."""
|
||||
_, metric, score, _ = env.evaluation_result_list[0]
|
||||
tune.track.log(**{metric: score})
|
||||
tune.report(**{metric: score})
|
||||
|
||||
|
||||
def train_breast_cancer(config):
|
||||
@@ -27,7 +27,7 @@ def train_breast_cancer(config):
|
||||
callbacks=[LightGBMCallback])
|
||||
preds = gbm.predict(test_x)
|
||||
pred_labels = np.rint(preds)
|
||||
tune.track.log(
|
||||
tune.report(
|
||||
mean_accuracy=sklearn.metrics.accuracy_score(test_y, pred_labels),
|
||||
done=True)
|
||||
|
||||
|
||||
@@ -20,9 +20,8 @@ def easy_objective(config):
|
||||
result = dict(
|
||||
timesteps_total=i,
|
||||
mean_loss=(config["height"] - 14)**2 - abs(config["width"] - 3))
|
||||
tune.track.log(**result)
|
||||
tune.report(**result)
|
||||
time.sleep(0.02)
|
||||
tune.track.log(done=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -12,7 +12,6 @@ from torchvision import datasets, transforms
|
||||
|
||||
import ray
|
||||
from ray import tune
|
||||
from ray.tune import track
|
||||
from ray.tune.schedulers import AsyncHyperBandScheduler
|
||||
|
||||
# Change these values if you want the training to run quicker or slower.
|
||||
@@ -33,7 +32,8 @@ class ConvNet(nn.Module):
|
||||
return F.log_softmax(x, dim=1)
|
||||
|
||||
|
||||
def train(model, optimizer, train_loader, device=torch.device("cpu")):
|
||||
def train(model, optimizer, train_loader, device=None):
|
||||
device = device or torch.device("cpu")
|
||||
model.train()
|
||||
for batch_idx, (data, target) in enumerate(train_loader):
|
||||
if batch_idx * len(data) > EPOCH_SIZE:
|
||||
@@ -46,7 +46,8 @@ def train(model, optimizer, train_loader, device=torch.device("cpu")):
|
||||
optimizer.step()
|
||||
|
||||
|
||||
def test(model, data_loader, device=torch.device("cpu")):
|
||||
def test(model, data_loader, device=None):
|
||||
device = device or torch.device("cpu")
|
||||
model.eval()
|
||||
correct = 0
|
||||
total = 0
|
||||
@@ -99,7 +100,7 @@ def train_mnist(config):
|
||||
while True:
|
||||
train(model, optimizer, train_loader, device)
|
||||
acc = test(model, test_loader, device)
|
||||
track.log(mean_accuracy=acc)
|
||||
tune.report(mean_accuracy=acc)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -1,60 +0,0 @@
|
||||
import argparse
|
||||
import tensorflow as tf
|
||||
from tensorflow import keras
|
||||
from tensorflow.keras.datasets import mnist
|
||||
|
||||
from ray.tune import track
|
||||
from ray.tune.integration.keras import TuneReporterCallback
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--smoke-test", action="store_true", help="Finish quickly for testing")
|
||||
parser.add_argument(
|
||||
"--lr",
|
||||
type=float,
|
||||
default=0.01,
|
||||
metavar="LR",
|
||||
help="learning rate (default: 0.01)")
|
||||
parser.add_argument(
|
||||
"--momentum",
|
||||
type=float,
|
||||
default=0.5,
|
||||
metavar="M",
|
||||
help="SGD momentum (default: 0.5)")
|
||||
parser.add_argument(
|
||||
"--hidden", type=int, default=64, help="Size of hidden layer.")
|
||||
args, _ = parser.parse_known_args()
|
||||
|
||||
|
||||
def train_mnist(args):
|
||||
track.init(trial_name="track-example", trial_config=vars(args))
|
||||
batch_size = 128
|
||||
num_classes = 10
|
||||
epochs = 1 if args.smoke_test else 12
|
||||
|
||||
(x_train, y_train), (x_test, y_test) = mnist.load_data()
|
||||
x_train, x_test = x_train / 255.0, x_test / 255.0
|
||||
model = tf.keras.models.Sequential([
|
||||
tf.keras.layers.Flatten(input_shape=(28, 28)),
|
||||
tf.keras.layers.Dense(args.hidden, activation="relu"),
|
||||
tf.keras.layers.Dropout(0.2),
|
||||
tf.keras.layers.Dense(num_classes, activation="softmax")
|
||||
])
|
||||
|
||||
model.compile(
|
||||
loss="sparse_categorical_crossentropy",
|
||||
optimizer=keras.optimizers.SGD(lr=args.lr, momentum=args.momentum),
|
||||
metrics=["accuracy"])
|
||||
|
||||
model.fit(
|
||||
x_train,
|
||||
y_train,
|
||||
batch_size=batch_size,
|
||||
epochs=epochs,
|
||||
validation_data=(x_test, y_test),
|
||||
callbacks=[TuneReporterCallback()])
|
||||
track.shutdown()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
train_mnist(args)
|
||||
@@ -10,7 +10,7 @@ parser.add_argument(
|
||||
args, _ = parser.parse_known_args()
|
||||
|
||||
|
||||
def train_mnist(config, reporter):
|
||||
def train_mnist(config):
|
||||
# https://github.com/tensorflow/tensorflow/issues/32159
|
||||
import tensorflow as tf
|
||||
batch_size = 128
|
||||
@@ -39,7 +39,7 @@ def train_mnist(config, reporter):
|
||||
epochs=epochs,
|
||||
verbose=0,
|
||||
validation_data=(x_test, y_test),
|
||||
callbacks=[TuneReporterCallback(reporter)])
|
||||
callbacks=[TuneReporterCallback()])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -48,7 +48,7 @@ if __name__ == "__main__":
|
||||
from ray.tune.schedulers import AsyncHyperBandScheduler
|
||||
mnist.load_data() # we do this on the driver because it's not threadsafe
|
||||
|
||||
ray.init(num_cpus=2 if args.smoke_test else None)
|
||||
ray.init(num_cpus=4 if args.smoke_test else None)
|
||||
sched = AsyncHyperBandScheduler(
|
||||
time_attr="training_iteration",
|
||||
metric="mean_accuracy",
|
||||
|
||||
@@ -8,7 +8,7 @@ from ray import tune
|
||||
|
||||
|
||||
def XGBCallback(env):
|
||||
tune.track.log(**dict(env.evaluation_result_list))
|
||||
tune.report(**dict(env.evaluation_result_list))
|
||||
|
||||
|
||||
def train_breast_cancer(config):
|
||||
@@ -21,7 +21,7 @@ def train_breast_cancer(config):
|
||||
config, train_set, evals=[(test_set, "eval")], callbacks=[XGBCallback])
|
||||
preds = bst.predict(test_set)
|
||||
pred_labels = np.rint(preds)
|
||||
tune.track.log(
|
||||
tune.report(
|
||||
mean_accuracy=sklearn.metrics.accuracy_score(test_y, pred_labels),
|
||||
done=True)
|
||||
|
||||
|
||||
@@ -5,8 +5,7 @@ import threading
|
||||
import traceback
|
||||
from six.moves import queue
|
||||
|
||||
from ray.tune import track
|
||||
from ray.tune import TuneError
|
||||
from ray.tune import TuneError, session
|
||||
from ray.tune.trainable import Trainable
|
||||
from ray.tune.result import TIME_THIS_ITER_S, RESULT_DUPLICATE
|
||||
|
||||
@@ -158,6 +157,8 @@ class FunctionRunner(Trainable):
|
||||
self._last_result = {}
|
||||
config = config.copy()
|
||||
|
||||
session.init(self._status_reporter)
|
||||
|
||||
def entrypoint():
|
||||
return self._trainable_func(config, self._status_reporter)
|
||||
|
||||
@@ -251,6 +252,8 @@ class FunctionRunner(Trainable):
|
||||
# Check for any errors that might have been missed.
|
||||
self._report_thread_runner_error()
|
||||
|
||||
session.shutdown()
|
||||
|
||||
def _report_thread_runner_error(self, block=False):
|
||||
try:
|
||||
err_tb_str = self._error_queue.get(
|
||||
@@ -262,32 +265,19 @@ class FunctionRunner(Trainable):
|
||||
|
||||
|
||||
def wrap_function(train_func):
|
||||
|
||||
use_track = False
|
||||
try:
|
||||
func_args = inspect.getfullargspec(train_func).args
|
||||
use_track = ("reporter" not in func_args and len(func_args) == 1)
|
||||
if use_track:
|
||||
logger.debug("tune.track signature detected.")
|
||||
except Exception:
|
||||
logger.info(
|
||||
"Function inspection failed - assuming reporter signature.")
|
||||
|
||||
class WrappedFunc(FunctionRunner):
|
||||
class ImplicitFunc(FunctionRunner):
|
||||
def _trainable_func(self, config, reporter):
|
||||
output = train_func(config, reporter)
|
||||
func_args = inspect.getfullargspec(train_func).args
|
||||
use_track = ("reporter" not in func_args and len(func_args) == 1)
|
||||
if use_track:
|
||||
output = train_func(config)
|
||||
else:
|
||||
output = train_func(config, reporter)
|
||||
|
||||
# If train_func returns, we need to notify the main event loop
|
||||
# of the last result while avoiding double logging. This is done
|
||||
# with the keyword RESULT_DUPLICATE -- see tune/trial_runner.py.
|
||||
reporter(**{RESULT_DUPLICATE: True})
|
||||
return output
|
||||
|
||||
class WrappedTrackFunc(FunctionRunner):
|
||||
def _trainable_func(self, config, reporter):
|
||||
track.init(_tune_reporter=reporter)
|
||||
output = train_func(config)
|
||||
reporter(**{RESULT_DUPLICATE: True})
|
||||
track.shutdown()
|
||||
return output
|
||||
|
||||
return WrappedTrackFunc if use_track else WrappedFunc
|
||||
return ImplicitFunc
|
||||
|
||||
@@ -1,27 +1,26 @@
|
||||
from tensorflow import keras
|
||||
from ray.tune import track
|
||||
|
||||
|
||||
class TuneReporterCallback(keras.callbacks.Callback):
|
||||
"""Tune Callback for Keras."""
|
||||
|
||||
def __init__(self, reporter=None, freq="batch", logs={}):
|
||||
def __init__(self, reporter=None, freq="batch", logs=None):
|
||||
"""Initializer.
|
||||
|
||||
Args:
|
||||
reporter (StatusReporter|tune.track.log|None): Tune object for
|
||||
returning results.
|
||||
freq (str): Sets the frequency of reporting intermediate results.
|
||||
One of ["batch", "epoch"].
|
||||
"""
|
||||
self.reporter = reporter or track.log
|
||||
self.iteration = 0
|
||||
logs = logs or {}
|
||||
if freq not in ["batch", "epoch"]:
|
||||
raise ValueError("{} not supported as a frequency.".format(freq))
|
||||
self.freq = freq
|
||||
super(TuneReporterCallback, self).__init__()
|
||||
|
||||
def on_batch_end(self, batch, logs={}):
|
||||
def on_batch_end(self, batch, logs=None):
|
||||
from ray import tune
|
||||
logs = logs or {}
|
||||
if not self.freq == "batch":
|
||||
return
|
||||
self.iteration += 1
|
||||
@@ -29,11 +28,13 @@ class TuneReporterCallback(keras.callbacks.Callback):
|
||||
if "loss" in metric and "neg_" not in metric:
|
||||
logs["neg_" + metric] = -logs[metric]
|
||||
if "acc" in logs:
|
||||
self.reporter(keras_info=logs, mean_accuracy=logs["acc"])
|
||||
tune.report(keras_info=logs, mean_accuracy=logs["acc"])
|
||||
else:
|
||||
self.reporter(keras_info=logs, mean_accuracy=logs.get("accuracy"))
|
||||
tune.report(keras_info=logs, mean_accuracy=logs.get("accuracy"))
|
||||
|
||||
def on_epoch_end(self, batch, logs={}):
|
||||
def on_epoch_end(self, batch, logs=None):
|
||||
from ray import tune
|
||||
logs = logs or {}
|
||||
if not self.freq == "epoch":
|
||||
return
|
||||
self.iteration += 1
|
||||
@@ -41,6 +42,6 @@ class TuneReporterCallback(keras.callbacks.Callback):
|
||||
if "loss" in metric and "neg_" not in metric:
|
||||
logs["neg_" + metric] = -logs[metric]
|
||||
if "acc" in logs:
|
||||
self.reporter(keras_info=logs, mean_accuracy=logs["acc"])
|
||||
tune.report(keras_info=logs, mean_accuracy=logs["acc"])
|
||||
else:
|
||||
self.reporter(keras_info=logs, mean_accuracy=logs.get("accuracy"))
|
||||
tune.report(keras_info=logs, mean_accuracy=logs.get("accuracy"))
|
||||
|
||||
@@ -0,0 +1,122 @@
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_session = None
|
||||
|
||||
|
||||
class _ReporterSession:
|
||||
def __init__(self, tune_reporter):
|
||||
self.tune_reporter = tune_reporter
|
||||
|
||||
def report(self, **metrics):
|
||||
return self.tune_reporter(**metrics)
|
||||
|
||||
@property
|
||||
def logdir(self):
|
||||
"""Trial logdir (subdir of given experiment directory)"""
|
||||
return self.tune_reporter.logdir
|
||||
|
||||
@property
|
||||
def trial_name(self):
|
||||
"""Trial name for the corresponding trial of this Trainable"""
|
||||
return self.tune_reporter.trial_name
|
||||
|
||||
@property
|
||||
def trial_id(self):
|
||||
"""Trial id for the corresponding trial of this Trainable"""
|
||||
return self.tune_reporter.trial_id
|
||||
|
||||
|
||||
def get_session():
|
||||
global _session
|
||||
if _session is None:
|
||||
raise ValueError(
|
||||
"Session not detected. You should not be calling this function "
|
||||
"outside `tune.run` or while using the class API. ")
|
||||
return _session
|
||||
|
||||
|
||||
def init(reporter, ignore_reinit_error=True):
|
||||
"""Initializes the global trial context for this process."""
|
||||
global _session
|
||||
|
||||
if _session is not None:
|
||||
# TODO(ng): would be nice to stack crawl at creation time to report
|
||||
# where that initial trial was created, and that creation line
|
||||
# info is helpful to keep around anyway.
|
||||
reinit_msg = (
|
||||
"A Tune session already exists in the current process. "
|
||||
"If you are using ray.init(local_mode=True), "
|
||||
"you must set ray.init(..., num_cpus=1, num_gpus=1) to limit "
|
||||
"available concurrency.")
|
||||
if ignore_reinit_error:
|
||||
logger.warning(reinit_msg)
|
||||
return
|
||||
else:
|
||||
raise ValueError(reinit_msg)
|
||||
|
||||
_session = _ReporterSession(reporter)
|
||||
|
||||
|
||||
def shutdown():
|
||||
"""Cleans up the trial and removes it from the global context."""
|
||||
|
||||
global _session
|
||||
_session = None
|
||||
|
||||
|
||||
def report(**kwargs):
|
||||
"""Logs all keyword arguments.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import time
|
||||
from ray import tune
|
||||
|
||||
def run_me(config):
|
||||
for iter in range(100):
|
||||
time.sleep(1)
|
||||
tune.report(hello="world", ray="tune")
|
||||
|
||||
analysis = tune.run(run_me)
|
||||
|
||||
Args:
|
||||
**kwargs: Any key value pair to be logged by Tune. Any of these
|
||||
metrics can be used for early stopping or optimization.
|
||||
"""
|
||||
_session = get_session()
|
||||
return _session.report(**kwargs)
|
||||
|
||||
|
||||
def get_trial_dir():
|
||||
"""Returns the directory where trial results are saved.
|
||||
|
||||
For function API use only. Do not call this method in the Class API. Use
|
||||
`self.logdir` instead.
|
||||
"""
|
||||
_session = get_session()
|
||||
return _session.logdir
|
||||
|
||||
|
||||
def get_trial_name():
|
||||
"""Trial name for the corresponding trial of this Trainable.
|
||||
|
||||
For function API use only. Do not call this method in the Class API. Use
|
||||
`self.trial_name` instead.
|
||||
"""
|
||||
_session = get_session()
|
||||
return _session.trial_name
|
||||
|
||||
|
||||
def get_trial_id():
|
||||
"""Trial id for the corresponding trial of this Trainable.
|
||||
|
||||
For function API use only. Do not call this method in the Class API. Use
|
||||
`self.trial_id` instead.
|
||||
"""
|
||||
_session = get_session()
|
||||
return _session.trial_id
|
||||
|
||||
|
||||
__all__ = ["report", "get_trial_dir", "get_trial_name", "get_trial_id"]
|
||||
@@ -23,7 +23,7 @@ def train_mnist(config):
|
||||
for i in range(10):
|
||||
train(model, optimizer, train_loader)
|
||||
acc = test(model, test_loader)
|
||||
tune.track.log(mean_accuracy=acc)
|
||||
tune.report(mean_accuracy=acc)
|
||||
|
||||
|
||||
analysis = tune.run(
|
||||
|
||||
@@ -572,7 +572,7 @@ class TrainableFunctionApiTest(unittest.TestCase):
|
||||
|
||||
def testReportInfinity(self):
|
||||
def train(config, reporter):
|
||||
for i in range(100):
|
||||
for _ in range(100):
|
||||
reporter(mean_accuracy=float("inf"))
|
||||
|
||||
register_trainable("f1", train)
|
||||
@@ -606,8 +606,8 @@ class TrainableFunctionApiTest(unittest.TestCase):
|
||||
self.assertEqual(trial.last_result.get("trial_id"), trial.trial_id)
|
||||
|
||||
def track_train(config):
|
||||
tune.track.log(
|
||||
name=tune.track.trial_name(), trial_id=tune.track.trial_id())
|
||||
tune.report(
|
||||
name=tune.get_trial_name(), trial_id=tune.get_trial_id())
|
||||
|
||||
analysis = tune.run(track_train, stop={TRAINING_ITERATION: 1})
|
||||
trial = analysis.trials[0]
|
||||
|
||||
@@ -1,11 +1,9 @@
|
||||
import os
|
||||
import pandas as pd
|
||||
import unittest
|
||||
|
||||
import ray
|
||||
from ray import tune
|
||||
from ray.tune import track
|
||||
from ray.tune.result import EXPR_PARAM_FILE, EXPR_RESULT_FILE
|
||||
from ray.tune import session
|
||||
|
||||
|
||||
def _check_json_val(fname, key, val):
|
||||
@@ -16,46 +14,24 @@ def _check_json_val(fname, key, val):
|
||||
|
||||
class TrackApiTest(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
track.shutdown()
|
||||
session.shutdown()
|
||||
ray.shutdown()
|
||||
|
||||
def testSessionInitShutdown(self):
|
||||
self.assertTrue(track._session is None)
|
||||
self.assertTrue(session._session is None)
|
||||
|
||||
# Checks that the singleton _session is created/destroyed
|
||||
# by track.init() and track.shutdown()
|
||||
# by session.init() and session.shutdown()
|
||||
for _ in range(2):
|
||||
# do it twice to see that we can reopen the session
|
||||
track.init(trial_name="test_init")
|
||||
self.assertTrue(track._session is not None)
|
||||
track.shutdown()
|
||||
self.assertTrue(track._session is None)
|
||||
session.init(reporter=None)
|
||||
self.assertTrue(session._session is not None)
|
||||
session.shutdown()
|
||||
self.assertTrue(session._session is None)
|
||||
|
||||
def testLogCreation(self):
|
||||
"""Checks that track.init() starts logger and creates log files."""
|
||||
track.init(trial_name="test_init")
|
||||
session = track.get_session()
|
||||
self.assertTrue(session is not None)
|
||||
|
||||
self.assertTrue(os.path.isdir(session.logdir))
|
||||
|
||||
params_path = os.path.join(session.logdir, EXPR_PARAM_FILE)
|
||||
result_path = os.path.join(session.logdir, EXPR_RESULT_FILE)
|
||||
|
||||
self.assertTrue(os.path.exists(params_path))
|
||||
self.assertTrue(os.path.exists(result_path))
|
||||
self.assertTrue(session.logdir == track.trial_dir())
|
||||
|
||||
def testMetric(self):
|
||||
track.init(trial_name="test_log")
|
||||
session = track.get_session()
|
||||
for i in range(5):
|
||||
track.log(test=i)
|
||||
result_path = os.path.join(session.logdir, EXPR_RESULT_FILE)
|
||||
self.assertTrue(_check_json_val(result_path, "test", i))
|
||||
|
||||
def testRayOutput(self):
|
||||
"""Checks that local and remote log format are the same."""
|
||||
def testSoftDeprecation(self):
|
||||
"""Checks that tune.track.log code does not break."""
|
||||
from ray.tune import track
|
||||
ray.init()
|
||||
|
||||
def testme(config):
|
||||
@@ -67,18 +43,6 @@ class TrackApiTest(unittest.TestCase):
|
||||
self.assertTrue(trial_res["hi"], "test")
|
||||
self.assertTrue(trial_res["training_iteration"], 5)
|
||||
|
||||
def testLocalMetrics(self):
|
||||
"""Checks that metric state is updated correctly."""
|
||||
track.init(trial_name="test_logs")
|
||||
session = track.get_session()
|
||||
self.assertEqual(set(session.trial_config.keys()), {"trial_id"})
|
||||
|
||||
result_path = os.path.join(session.logdir, EXPR_RESULT_FILE)
|
||||
track.log(test=1)
|
||||
self.assertTrue(_check_json_val(result_path, "test", 1))
|
||||
track.log(iteration=1, test=2)
|
||||
self.assertTrue(_check_json_val(result_path, "test", 2))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import pytest
|
||||
|
||||
@@ -9,7 +9,6 @@ import torch.optim as optim
|
||||
from torchvision import datasets
|
||||
|
||||
from ray import tune
|
||||
from ray.tune import track
|
||||
from ray.tune.schedulers import ASHAScheduler
|
||||
from ray.tune.examples.mnist_pytorch import get_data_loaders, ConvNet, train, test
|
||||
# __tutorial_imports_end__
|
||||
@@ -26,7 +25,7 @@ def train_mnist(config):
|
||||
for i in range(10):
|
||||
train(model, optimizer, train_loader)
|
||||
acc = test(model, test_loader)
|
||||
track.log(mean_accuracy=acc)
|
||||
tune.report(mean_accuracy=acc)
|
||||
if i % 5 == 0:
|
||||
# This saves the model to the trial directory
|
||||
torch.save(model, "./model.pth")
|
||||
|
||||
@@ -1,105 +1,58 @@
|
||||
import logging
|
||||
|
||||
from ray.tune.track.session import TrackSession as _TrackSession
|
||||
from ray.tune import session
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_session = None
|
||||
warned = False
|
||||
|
||||
|
||||
def _deprecation_warning(call=None, alternative_call=None, soft=True):
|
||||
msg = "tune.track is now deprecated."
|
||||
if call and alternative_call:
|
||||
msg = "tune.track.{} is now deprecated.".format(call)
|
||||
msg += " Use `tune.{}` instead.".format(alternative_call)
|
||||
global warned
|
||||
if soft:
|
||||
msg += " This warning will throw an error in a future version of Ray."
|
||||
if not warned:
|
||||
logger.warning(msg)
|
||||
warned = True
|
||||
else:
|
||||
raise DeprecationWarning(msg)
|
||||
|
||||
|
||||
def get_session():
|
||||
global _session
|
||||
if not _session:
|
||||
raise ValueError("Session not detected. Try `track.init()`?")
|
||||
return _session
|
||||
_deprecation_warning(soft=False)
|
||||
|
||||
|
||||
def init(ignore_reinit_error=True, **session_kwargs):
|
||||
"""Initializes the global trial context for this process.
|
||||
|
||||
This creates a TrackSession object and the corresponding hooks for logging.
|
||||
|
||||
Examples:
|
||||
>>> from ray.tune import track
|
||||
>>> track.init()
|
||||
"""
|
||||
global _session
|
||||
|
||||
if _session:
|
||||
# TODO(ng): would be nice to stack crawl at creation time to report
|
||||
# where that initial trial was created, and that creation line
|
||||
# info is helpful to keep around anyway.
|
||||
reinit_msg = "A session already exists in the current context."
|
||||
if ignore_reinit_error:
|
||||
if not _session.is_tune_session:
|
||||
logger.warning(reinit_msg)
|
||||
return
|
||||
else:
|
||||
raise ValueError(reinit_msg)
|
||||
|
||||
_session = _TrackSession(**session_kwargs)
|
||||
_deprecation_warning(soft=False)
|
||||
|
||||
|
||||
def shutdown():
|
||||
"""Cleans up the trial and removes it from the global context."""
|
||||
|
||||
global _session
|
||||
if _session:
|
||||
_session.close()
|
||||
_session = None
|
||||
_deprecation_warning(soft=False)
|
||||
|
||||
|
||||
def log(**kwargs):
|
||||
"""Logs all keyword arguments.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import time
|
||||
from ray import tune
|
||||
from ray.tune import track
|
||||
|
||||
def run_me(config):
|
||||
for iter in range(100):
|
||||
time.sleep(1)
|
||||
track.log(hello="world", ray="tune")
|
||||
|
||||
analysis = tune.run(run_me)
|
||||
|
||||
Args:
|
||||
**kwargs: Any key value pair to be logged by Tune. Any of these
|
||||
metrics can be used for early stopping or optimization.
|
||||
"""
|
||||
_session = get_session()
|
||||
return _session.log(**kwargs)
|
||||
_deprecation_warning(call="log", alternative_call="report", soft=True)
|
||||
session.report(**kwargs)
|
||||
|
||||
|
||||
def trial_dir():
|
||||
"""Returns the directory where trial results are saved.
|
||||
|
||||
This includes json data containing the session's parameters and metrics.
|
||||
"""
|
||||
_session = get_session()
|
||||
return _session.logdir
|
||||
_deprecation_warning(
|
||||
call="trial_dir", alternative_call="get_trial_dir", soft=True)
|
||||
return session.get_trial_dir()
|
||||
|
||||
|
||||
def trial_name():
|
||||
"""Trial name for the corresponding trial of this Trainable.
|
||||
|
||||
This is not set if not using Tune.
|
||||
"""
|
||||
_session = get_session()
|
||||
return _session.trial_name
|
||||
_deprecation_warning(
|
||||
call="trial_name", alternative_call="get_trial_name", soft=True)
|
||||
return session.get_trial_name()
|
||||
|
||||
|
||||
def trial_id():
|
||||
"""Trial id for the corresponding trial of this Trainable.
|
||||
|
||||
This is not set if not using Tune.
|
||||
"""
|
||||
_session = get_session()
|
||||
return _session.trial_id
|
||||
|
||||
|
||||
__all__ = [
|
||||
"session", "log", "trial_dir", "init", "shutdown", "trial_name", "trial_id"
|
||||
]
|
||||
_deprecation_warning(
|
||||
call="trial_id", alternative_call="get_trial_id", soft=True)
|
||||
return session.get_trial_id()
|
||||
|
||||
@@ -1,123 +1,3 @@
|
||||
import os
|
||||
from datetime import datetime
|
||||
|
||||
from ray.tune.trial import Trial
|
||||
from ray.tune.result import DEFAULT_RESULTS_DIR, TRAINING_ITERATION
|
||||
from ray.tune.logger import UnifiedLogger, Logger
|
||||
|
||||
|
||||
class _ReporterHook(Logger):
|
||||
def __init__(self, tune_reporter):
|
||||
self.tune_reporter = tune_reporter
|
||||
|
||||
def on_result(self, metrics):
|
||||
return self.tune_reporter(**metrics)
|
||||
|
||||
|
||||
class TrackSession:
|
||||
"""Manages results for a single session.
|
||||
|
||||
Represents a single Trial in an experiment. This is automatically
|
||||
created when using ``tune.run``.
|
||||
|
||||
Attributes:
|
||||
trial_name (str): Custom trial name.
|
||||
experiment_dir (str): Directory where results for all trials
|
||||
are stored. Each session is stored into a unique directory
|
||||
inside experiment_dir.
|
||||
upload_dir (str): Directory to sync results to.
|
||||
trial_config (dict): Parameters that will be logged to disk.
|
||||
_tune_reporter (StatusReporter): For rerouting when using Tune.
|
||||
Will not instantiate logging if not None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
trial_name=None,
|
||||
experiment_dir=None,
|
||||
upload_dir=None,
|
||||
trial_config=None,
|
||||
_tune_reporter=None):
|
||||
self._experiment_dir = None
|
||||
self._logdir = None
|
||||
self._upload_dir = None
|
||||
self.trial_config = None
|
||||
self._iteration = -1
|
||||
self.is_tune_session = bool(_tune_reporter)
|
||||
if self.is_tune_session:
|
||||
self._logger = _ReporterHook(_tune_reporter)
|
||||
self._logdir = _tune_reporter.logdir
|
||||
self._trial_name = _tune_reporter.trial_name
|
||||
self._trial_id = _tune_reporter.trial_id
|
||||
else:
|
||||
self._trial_id = Trial.generate_id()
|
||||
self._trial_name = trial_name or self._trial_id
|
||||
self._initialize_logging(experiment_dir, upload_dir, trial_config)
|
||||
|
||||
def _initialize_logging(self,
|
||||
experiment_dir=None,
|
||||
upload_dir=None,
|
||||
trial_config=None):
|
||||
if upload_dir:
|
||||
raise NotImplementedError("Upload Dir is not yet implemented.")
|
||||
|
||||
# TODO(rliaw): In other parts of the code, this is `local_dir`.
|
||||
if experiment_dir is None:
|
||||
experiment_dir = os.path.join(DEFAULT_RESULTS_DIR, "default")
|
||||
|
||||
self._experiment_dir = os.path.expanduser(experiment_dir)
|
||||
|
||||
# TODO(rliaw): Refactor `logdir` to `trial_dir`.
|
||||
self._logdir = Trial.create_logdir(self.trial_name,
|
||||
self._experiment_dir)
|
||||
self._upload_dir = upload_dir
|
||||
self.trial_config = trial_config or {}
|
||||
|
||||
# misc metadata to save as well
|
||||
self.trial_config["trial_id"] = self.trial_id
|
||||
self._logger = UnifiedLogger(self.trial_config, self._logdir)
|
||||
|
||||
def log(self, **metrics):
|
||||
"""Logs all named arguments specified in `metrics`.
|
||||
|
||||
This will log trial metrics locally, and they will be synchronized
|
||||
with the driver periodically through ray.
|
||||
|
||||
Arguments:
|
||||
metrics: named arguments with corresponding values to log.
|
||||
"""
|
||||
self._iteration += 1
|
||||
# TODO: Implement a batching mechanism for multiple calls to `log`
|
||||
# within the same iteration.
|
||||
metrics_dict = metrics.copy()
|
||||
metrics_dict.update({"trial_id": self.trial_id})
|
||||
|
||||
# TODO: Move Trainable autopopulation to a util function
|
||||
metrics_dict.setdefault(TRAINING_ITERATION, self._iteration)
|
||||
self._logger.on_result(metrics_dict)
|
||||
|
||||
def close(self):
|
||||
"""Closes loggers.
|
||||
|
||||
No need to call this when using ``tune.run``.
|
||||
"""
|
||||
self.trial_config["trial_completed"] = True
|
||||
self.trial_config["end_time"] = datetime.now().isoformat()
|
||||
# TODO(rliaw): Have Tune support updated configs
|
||||
self._logger.update_config(self.trial_config)
|
||||
self._logger.flush()
|
||||
self._logger.close()
|
||||
|
||||
@property
|
||||
def logdir(self):
|
||||
"""Trial logdir (subdir of given experiment directory)"""
|
||||
return self._logdir
|
||||
|
||||
@property
|
||||
def trial_name(self):
|
||||
"""Trial name for the corresponding trial of this Trainable"""
|
||||
return self._trial_name
|
||||
|
||||
@property
|
||||
def trial_id(self):
|
||||
"""Trial id for the corresponding trial of this Trainable"""
|
||||
return self._trial_id
|
||||
def TrackSession(*args, **kwargs):
|
||||
msg = "tune.track is now deprecated. Use `tune.report` instead."
|
||||
raise DeprecationWarning(msg)
|
||||
|
||||
@@ -214,7 +214,6 @@ class Trainable:
|
||||
return ""
|
||||
|
||||
def get_current_ip(self):
|
||||
logger.info("Getting current IP.")
|
||||
self._local_ip = ray.services.get_node_ip_address()
|
||||
return self._local_ip
|
||||
|
||||
|
||||
Reference in New Issue
Block a user