[tune] tune.track -> tune.report (#8388)

This commit is contained in:
Richard Liaw
2020-05-16 12:55:08 -07:00
committed by GitHub
parent c8cd716295
commit 67c01455fe
20 changed files with 228 additions and 395 deletions
+8 -23
View File
@@ -7,33 +7,18 @@ from ray.tune.registry import register_env, register_trainable
from ray.tune.trainable import Trainable
from ray.tune.durable_trainable import DurableTrainable
from ray.tune.suggest import grid_search
from ray.tune.session import (report, get_trial_dir, get_trial_name,
get_trial_id)
from ray.tune.progress_reporter import (ProgressReporter, CLIReporter,
JupyterNotebookReporter)
from ray.tune.sample import (function, sample_from, uniform, choice, randint,
randn, loguniform)
__all__ = [
"Trainable",
"DurableTrainable",
"TuneError",
"grid_search",
"register_env",
"register_trainable",
"run",
"run_experiments",
"Stopper",
"Experiment",
"function",
"sample_from",
"track",
"uniform",
"choice",
"randint",
"randn",
"loguniform",
"ExperimentAnalysis",
"Analysis",
"CLIReporter",
"JupyterNotebookReporter",
"ProgressReporter",
"Trainable", "DurableTrainable", "TuneError", "grid_search",
"register_env", "register_trainable", "run", "run_experiments", "Stopper",
"Experiment", "function", "sample_from", "track", "uniform", "choice",
"randint", "randn", "loguniform", "ExperimentAnalysis", "Analysis",
"CLIReporter", "JupyterNotebookReporter", "ProgressReporter", "report",
"get_trial_dir", "get_trial_name", "get_trial_id"
]
+2 -2
View File
@@ -10,7 +10,7 @@ from ray import tune
def LightGBMCallback(env):
"""Assumes that `valid_0` is the target validation score."""
_, metric, score, _ = env.evaluation_result_list[0]
tune.track.log(**{metric: score})
tune.report(**{metric: score})
def train_breast_cancer(config):
@@ -27,7 +27,7 @@ def train_breast_cancer(config):
callbacks=[LightGBMCallback])
preds = gbm.predict(test_x)
pred_labels = np.rint(preds)
tune.track.log(
tune.report(
mean_accuracy=sklearn.metrics.accuracy_score(test_y, pred_labels),
done=True)
+1 -2
View File
@@ -20,9 +20,8 @@ def easy_objective(config):
result = dict(
timesteps_total=i,
mean_loss=(config["height"] - 14)**2 - abs(config["width"] - 3))
tune.track.log(**result)
tune.report(**result)
time.sleep(0.02)
tune.track.log(done=True)
if __name__ == "__main__":
+5 -4
View File
@@ -12,7 +12,6 @@ from torchvision import datasets, transforms
import ray
from ray import tune
from ray.tune import track
from ray.tune.schedulers import AsyncHyperBandScheduler
# Change these values if you want the training to run quicker or slower.
@@ -33,7 +32,8 @@ class ConvNet(nn.Module):
return F.log_softmax(x, dim=1)
def train(model, optimizer, train_loader, device=torch.device("cpu")):
def train(model, optimizer, train_loader, device=None):
device = device or torch.device("cpu")
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
if batch_idx * len(data) > EPOCH_SIZE:
@@ -46,7 +46,8 @@ def train(model, optimizer, train_loader, device=torch.device("cpu")):
optimizer.step()
def test(model, data_loader, device=torch.device("cpu")):
def test(model, data_loader, device=None):
device = device or torch.device("cpu")
model.eval()
correct = 0
total = 0
@@ -99,7 +100,7 @@ def train_mnist(config):
while True:
train(model, optimizer, train_loader, device)
acc = test(model, test_loader, device)
track.log(mean_accuracy=acc)
tune.report(mean_accuracy=acc)
if __name__ == "__main__":
-60
View File
@@ -1,60 +0,0 @@
import argparse
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.datasets import mnist
from ray.tune import track
from ray.tune.integration.keras import TuneReporterCallback
parser = argparse.ArgumentParser()
parser.add_argument(
"--smoke-test", action="store_true", help="Finish quickly for testing")
parser.add_argument(
"--lr",
type=float,
default=0.01,
metavar="LR",
help="learning rate (default: 0.01)")
parser.add_argument(
"--momentum",
type=float,
default=0.5,
metavar="M",
help="SGD momentum (default: 0.5)")
parser.add_argument(
"--hidden", type=int, default=64, help="Size of hidden layer.")
args, _ = parser.parse_known_args()
def train_mnist(args):
track.init(trial_name="track-example", trial_config=vars(args))
batch_size = 128
num_classes = 10
epochs = 1 if args.smoke_test else 12
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
model = tf.keras.models.Sequential([
tf.keras.layers.Flatten(input_shape=(28, 28)),
tf.keras.layers.Dense(args.hidden, activation="relu"),
tf.keras.layers.Dropout(0.2),
tf.keras.layers.Dense(num_classes, activation="softmax")
])
model.compile(
loss="sparse_categorical_crossentropy",
optimizer=keras.optimizers.SGD(lr=args.lr, momentum=args.momentum),
metrics=["accuracy"])
model.fit(
x_train,
y_train,
batch_size=batch_size,
epochs=epochs,
validation_data=(x_test, y_test),
callbacks=[TuneReporterCallback()])
track.shutdown()
if __name__ == "__main__":
train_mnist(args)
+3 -3
View File
@@ -10,7 +10,7 @@ parser.add_argument(
args, _ = parser.parse_known_args()
def train_mnist(config, reporter):
def train_mnist(config):
# https://github.com/tensorflow/tensorflow/issues/32159
import tensorflow as tf
batch_size = 128
@@ -39,7 +39,7 @@ def train_mnist(config, reporter):
epochs=epochs,
verbose=0,
validation_data=(x_test, y_test),
callbacks=[TuneReporterCallback(reporter)])
callbacks=[TuneReporterCallback()])
if __name__ == "__main__":
@@ -48,7 +48,7 @@ if __name__ == "__main__":
from ray.tune.schedulers import AsyncHyperBandScheduler
mnist.load_data() # we do this on the driver because it's not threadsafe
ray.init(num_cpus=2 if args.smoke_test else None)
ray.init(num_cpus=4 if args.smoke_test else None)
sched = AsyncHyperBandScheduler(
time_attr="training_iteration",
metric="mean_accuracy",
+2 -2
View File
@@ -8,7 +8,7 @@ from ray import tune
def XGBCallback(env):
tune.track.log(**dict(env.evaluation_result_list))
tune.report(**dict(env.evaluation_result_list))
def train_breast_cancer(config):
@@ -21,7 +21,7 @@ def train_breast_cancer(config):
config, train_set, evals=[(test_set, "eval")], callbacks=[XGBCallback])
preds = bst.predict(test_set)
pred_labels = np.rint(preds)
tune.track.log(
tune.report(
mean_accuracy=sklearn.metrics.accuracy_score(test_y, pred_labels),
done=True)
+14 -24
View File
@@ -5,8 +5,7 @@ import threading
import traceback
from six.moves import queue
from ray.tune import track
from ray.tune import TuneError
from ray.tune import TuneError, session
from ray.tune.trainable import Trainable
from ray.tune.result import TIME_THIS_ITER_S, RESULT_DUPLICATE
@@ -158,6 +157,8 @@ class FunctionRunner(Trainable):
self._last_result = {}
config = config.copy()
session.init(self._status_reporter)
def entrypoint():
return self._trainable_func(config, self._status_reporter)
@@ -251,6 +252,8 @@ class FunctionRunner(Trainable):
# Check for any errors that might have been missed.
self._report_thread_runner_error()
session.shutdown()
def _report_thread_runner_error(self, block=False):
try:
err_tb_str = self._error_queue.get(
@@ -262,32 +265,19 @@ class FunctionRunner(Trainable):
def wrap_function(train_func):
use_track = False
try:
func_args = inspect.getfullargspec(train_func).args
use_track = ("reporter" not in func_args and len(func_args) == 1)
if use_track:
logger.debug("tune.track signature detected.")
except Exception:
logger.info(
"Function inspection failed - assuming reporter signature.")
class WrappedFunc(FunctionRunner):
class ImplicitFunc(FunctionRunner):
def _trainable_func(self, config, reporter):
output = train_func(config, reporter)
func_args = inspect.getfullargspec(train_func).args
use_track = ("reporter" not in func_args and len(func_args) == 1)
if use_track:
output = train_func(config)
else:
output = train_func(config, reporter)
# If train_func returns, we need to notify the main event loop
# of the last result while avoiding double logging. This is done
# with the keyword RESULT_DUPLICATE -- see tune/trial_runner.py.
reporter(**{RESULT_DUPLICATE: True})
return output
class WrappedTrackFunc(FunctionRunner):
def _trainable_func(self, config, reporter):
track.init(_tune_reporter=reporter)
output = train_func(config)
reporter(**{RESULT_DUPLICATE: True})
track.shutdown()
return output
return WrappedTrackFunc if use_track else WrappedFunc
return ImplicitFunc
+12 -11
View File
@@ -1,27 +1,26 @@
from tensorflow import keras
from ray.tune import track
class TuneReporterCallback(keras.callbacks.Callback):
"""Tune Callback for Keras."""
def __init__(self, reporter=None, freq="batch", logs={}):
def __init__(self, reporter=None, freq="batch", logs=None):
"""Initializer.
Args:
reporter (StatusReporter|tune.track.log|None): Tune object for
returning results.
freq (str): Sets the frequency of reporting intermediate results.
One of ["batch", "epoch"].
"""
self.reporter = reporter or track.log
self.iteration = 0
logs = logs or {}
if freq not in ["batch", "epoch"]:
raise ValueError("{} not supported as a frequency.".format(freq))
self.freq = freq
super(TuneReporterCallback, self).__init__()
def on_batch_end(self, batch, logs={}):
def on_batch_end(self, batch, logs=None):
from ray import tune
logs = logs or {}
if not self.freq == "batch":
return
self.iteration += 1
@@ -29,11 +28,13 @@ class TuneReporterCallback(keras.callbacks.Callback):
if "loss" in metric and "neg_" not in metric:
logs["neg_" + metric] = -logs[metric]
if "acc" in logs:
self.reporter(keras_info=logs, mean_accuracy=logs["acc"])
tune.report(keras_info=logs, mean_accuracy=logs["acc"])
else:
self.reporter(keras_info=logs, mean_accuracy=logs.get("accuracy"))
tune.report(keras_info=logs, mean_accuracy=logs.get("accuracy"))
def on_epoch_end(self, batch, logs={}):
def on_epoch_end(self, batch, logs=None):
from ray import tune
logs = logs or {}
if not self.freq == "epoch":
return
self.iteration += 1
@@ -41,6 +42,6 @@ class TuneReporterCallback(keras.callbacks.Callback):
if "loss" in metric and "neg_" not in metric:
logs["neg_" + metric] = -logs[metric]
if "acc" in logs:
self.reporter(keras_info=logs, mean_accuracy=logs["acc"])
tune.report(keras_info=logs, mean_accuracy=logs["acc"])
else:
self.reporter(keras_info=logs, mean_accuracy=logs.get("accuracy"))
tune.report(keras_info=logs, mean_accuracy=logs.get("accuracy"))
+122
View File
@@ -0,0 +1,122 @@
import logging
logger = logging.getLogger(__name__)
_session = None
class _ReporterSession:
def __init__(self, tune_reporter):
self.tune_reporter = tune_reporter
def report(self, **metrics):
return self.tune_reporter(**metrics)
@property
def logdir(self):
"""Trial logdir (subdir of given experiment directory)"""
return self.tune_reporter.logdir
@property
def trial_name(self):
"""Trial name for the corresponding trial of this Trainable"""
return self.tune_reporter.trial_name
@property
def trial_id(self):
"""Trial id for the corresponding trial of this Trainable"""
return self.tune_reporter.trial_id
def get_session():
global _session
if _session is None:
raise ValueError(
"Session not detected. You should not be calling this function "
"outside `tune.run` or while using the class API. ")
return _session
def init(reporter, ignore_reinit_error=True):
"""Initializes the global trial context for this process."""
global _session
if _session is not None:
# TODO(ng): would be nice to stack crawl at creation time to report
# where that initial trial was created, and that creation line
# info is helpful to keep around anyway.
reinit_msg = (
"A Tune session already exists in the current process. "
"If you are using ray.init(local_mode=True), "
"you must set ray.init(..., num_cpus=1, num_gpus=1) to limit "
"available concurrency.")
if ignore_reinit_error:
logger.warning(reinit_msg)
return
else:
raise ValueError(reinit_msg)
_session = _ReporterSession(reporter)
def shutdown():
"""Cleans up the trial and removes it from the global context."""
global _session
_session = None
def report(**kwargs):
"""Logs all keyword arguments.
.. code-block:: python
import time
from ray import tune
def run_me(config):
for iter in range(100):
time.sleep(1)
tune.report(hello="world", ray="tune")
analysis = tune.run(run_me)
Args:
**kwargs: Any key value pair to be logged by Tune. Any of these
metrics can be used for early stopping or optimization.
"""
_session = get_session()
return _session.report(**kwargs)
def get_trial_dir():
"""Returns the directory where trial results are saved.
For function API use only. Do not call this method in the Class API. Use
`self.logdir` instead.
"""
_session = get_session()
return _session.logdir
def get_trial_name():
"""Trial name for the corresponding trial of this Trainable.
For function API use only. Do not call this method in the Class API. Use
`self.trial_name` instead.
"""
_session = get_session()
return _session.trial_name
def get_trial_id():
"""Trial id for the corresponding trial of this Trainable.
For function API use only. Do not call this method in the Class API. Use
`self.trial_id` instead.
"""
_session = get_session()
return _session.trial_id
__all__ = ["report", "get_trial_dir", "get_trial_name", "get_trial_id"]
+1 -1
View File
@@ -23,7 +23,7 @@ def train_mnist(config):
for i in range(10):
train(model, optimizer, train_loader)
acc = test(model, test_loader)
tune.track.log(mean_accuracy=acc)
tune.report(mean_accuracy=acc)
analysis = tune.run(
+3 -3
View File
@@ -572,7 +572,7 @@ class TrainableFunctionApiTest(unittest.TestCase):
def testReportInfinity(self):
def train(config, reporter):
for i in range(100):
for _ in range(100):
reporter(mean_accuracy=float("inf"))
register_trainable("f1", train)
@@ -606,8 +606,8 @@ class TrainableFunctionApiTest(unittest.TestCase):
self.assertEqual(trial.last_result.get("trial_id"), trial.trial_id)
def track_train(config):
tune.track.log(
name=tune.track.trial_name(), trial_id=tune.track.trial_id())
tune.report(
name=tune.get_trial_name(), trial_id=tune.get_trial_id())
analysis = tune.run(track_train, stop={TRAINING_ITERATION: 1})
trial = analysis.trials[0]
+11 -47
View File
@@ -1,11 +1,9 @@
import os
import pandas as pd
import unittest
import ray
from ray import tune
from ray.tune import track
from ray.tune.result import EXPR_PARAM_FILE, EXPR_RESULT_FILE
from ray.tune import session
def _check_json_val(fname, key, val):
@@ -16,46 +14,24 @@ def _check_json_val(fname, key, val):
class TrackApiTest(unittest.TestCase):
def tearDown(self):
track.shutdown()
session.shutdown()
ray.shutdown()
def testSessionInitShutdown(self):
self.assertTrue(track._session is None)
self.assertTrue(session._session is None)
# Checks that the singleton _session is created/destroyed
# by track.init() and track.shutdown()
# by session.init() and session.shutdown()
for _ in range(2):
# do it twice to see that we can reopen the session
track.init(trial_name="test_init")
self.assertTrue(track._session is not None)
track.shutdown()
self.assertTrue(track._session is None)
session.init(reporter=None)
self.assertTrue(session._session is not None)
session.shutdown()
self.assertTrue(session._session is None)
def testLogCreation(self):
"""Checks that track.init() starts logger and creates log files."""
track.init(trial_name="test_init")
session = track.get_session()
self.assertTrue(session is not None)
self.assertTrue(os.path.isdir(session.logdir))
params_path = os.path.join(session.logdir, EXPR_PARAM_FILE)
result_path = os.path.join(session.logdir, EXPR_RESULT_FILE)
self.assertTrue(os.path.exists(params_path))
self.assertTrue(os.path.exists(result_path))
self.assertTrue(session.logdir == track.trial_dir())
def testMetric(self):
track.init(trial_name="test_log")
session = track.get_session()
for i in range(5):
track.log(test=i)
result_path = os.path.join(session.logdir, EXPR_RESULT_FILE)
self.assertTrue(_check_json_val(result_path, "test", i))
def testRayOutput(self):
"""Checks that local and remote log format are the same."""
def testSoftDeprecation(self):
"""Checks that tune.track.log code does not break."""
from ray.tune import track
ray.init()
def testme(config):
@@ -67,18 +43,6 @@ class TrackApiTest(unittest.TestCase):
self.assertTrue(trial_res["hi"], "test")
self.assertTrue(trial_res["training_iteration"], 5)
def testLocalMetrics(self):
"""Checks that metric state is updated correctly."""
track.init(trial_name="test_logs")
session = track.get_session()
self.assertEqual(set(session.trial_config.keys()), {"trial_id"})
result_path = os.path.join(session.logdir, EXPR_RESULT_FILE)
track.log(test=1)
self.assertTrue(_check_json_val(result_path, "test", 1))
track.log(iteration=1, test=2)
self.assertTrue(_check_json_val(result_path, "test", 2))
if __name__ == "__main__":
import pytest
+1 -2
View File
@@ -9,7 +9,6 @@ import torch.optim as optim
from torchvision import datasets
from ray import tune
from ray.tune import track
from ray.tune.schedulers import ASHAScheduler
from ray.tune.examples.mnist_pytorch import get_data_loaders, ConvNet, train, test
# __tutorial_imports_end__
@@ -26,7 +25,7 @@ def train_mnist(config):
for i in range(10):
train(model, optimizer, train_loader)
acc = test(model, test_loader)
track.log(mean_accuracy=acc)
tune.report(mean_accuracy=acc)
if i % 5 == 0:
# This saves the model to the trial directory
torch.save(model, "./model.pth")
+31 -78
View File
@@ -1,105 +1,58 @@
import logging
from ray.tune.track.session import TrackSession as _TrackSession
from ray.tune import session
logger = logging.getLogger(__name__)
_session = None
warned = False
def _deprecation_warning(call=None, alternative_call=None, soft=True):
msg = "tune.track is now deprecated."
if call and alternative_call:
msg = "tune.track.{} is now deprecated.".format(call)
msg += " Use `tune.{}` instead.".format(alternative_call)
global warned
if soft:
msg += " This warning will throw an error in a future version of Ray."
if not warned:
logger.warning(msg)
warned = True
else:
raise DeprecationWarning(msg)
def get_session():
global _session
if not _session:
raise ValueError("Session not detected. Try `track.init()`?")
return _session
_deprecation_warning(soft=False)
def init(ignore_reinit_error=True, **session_kwargs):
"""Initializes the global trial context for this process.
This creates a TrackSession object and the corresponding hooks for logging.
Examples:
>>> from ray.tune import track
>>> track.init()
"""
global _session
if _session:
# TODO(ng): would be nice to stack crawl at creation time to report
# where that initial trial was created, and that creation line
# info is helpful to keep around anyway.
reinit_msg = "A session already exists in the current context."
if ignore_reinit_error:
if not _session.is_tune_session:
logger.warning(reinit_msg)
return
else:
raise ValueError(reinit_msg)
_session = _TrackSession(**session_kwargs)
_deprecation_warning(soft=False)
def shutdown():
"""Cleans up the trial and removes it from the global context."""
global _session
if _session:
_session.close()
_session = None
_deprecation_warning(soft=False)
def log(**kwargs):
"""Logs all keyword arguments.
.. code-block:: python
import time
from ray import tune
from ray.tune import track
def run_me(config):
for iter in range(100):
time.sleep(1)
track.log(hello="world", ray="tune")
analysis = tune.run(run_me)
Args:
**kwargs: Any key value pair to be logged by Tune. Any of these
metrics can be used for early stopping or optimization.
"""
_session = get_session()
return _session.log(**kwargs)
_deprecation_warning(call="log", alternative_call="report", soft=True)
session.report(**kwargs)
def trial_dir():
"""Returns the directory where trial results are saved.
This includes json data containing the session's parameters and metrics.
"""
_session = get_session()
return _session.logdir
_deprecation_warning(
call="trial_dir", alternative_call="get_trial_dir", soft=True)
return session.get_trial_dir()
def trial_name():
"""Trial name for the corresponding trial of this Trainable.
This is not set if not using Tune.
"""
_session = get_session()
return _session.trial_name
_deprecation_warning(
call="trial_name", alternative_call="get_trial_name", soft=True)
return session.get_trial_name()
def trial_id():
"""Trial id for the corresponding trial of this Trainable.
This is not set if not using Tune.
"""
_session = get_session()
return _session.trial_id
__all__ = [
"session", "log", "trial_dir", "init", "shutdown", "trial_name", "trial_id"
]
_deprecation_warning(
call="trial_id", alternative_call="get_trial_id", soft=True)
return session.get_trial_id()
+3 -123
View File
@@ -1,123 +1,3 @@
import os
from datetime import datetime
from ray.tune.trial import Trial
from ray.tune.result import DEFAULT_RESULTS_DIR, TRAINING_ITERATION
from ray.tune.logger import UnifiedLogger, Logger
class _ReporterHook(Logger):
def __init__(self, tune_reporter):
self.tune_reporter = tune_reporter
def on_result(self, metrics):
return self.tune_reporter(**metrics)
class TrackSession:
"""Manages results for a single session.
Represents a single Trial in an experiment. This is automatically
created when using ``tune.run``.
Attributes:
trial_name (str): Custom trial name.
experiment_dir (str): Directory where results for all trials
are stored. Each session is stored into a unique directory
inside experiment_dir.
upload_dir (str): Directory to sync results to.
trial_config (dict): Parameters that will be logged to disk.
_tune_reporter (StatusReporter): For rerouting when using Tune.
Will not instantiate logging if not None.
"""
def __init__(self,
trial_name=None,
experiment_dir=None,
upload_dir=None,
trial_config=None,
_tune_reporter=None):
self._experiment_dir = None
self._logdir = None
self._upload_dir = None
self.trial_config = None
self._iteration = -1
self.is_tune_session = bool(_tune_reporter)
if self.is_tune_session:
self._logger = _ReporterHook(_tune_reporter)
self._logdir = _tune_reporter.logdir
self._trial_name = _tune_reporter.trial_name
self._trial_id = _tune_reporter.trial_id
else:
self._trial_id = Trial.generate_id()
self._trial_name = trial_name or self._trial_id
self._initialize_logging(experiment_dir, upload_dir, trial_config)
def _initialize_logging(self,
experiment_dir=None,
upload_dir=None,
trial_config=None):
if upload_dir:
raise NotImplementedError("Upload Dir is not yet implemented.")
# TODO(rliaw): In other parts of the code, this is `local_dir`.
if experiment_dir is None:
experiment_dir = os.path.join(DEFAULT_RESULTS_DIR, "default")
self._experiment_dir = os.path.expanduser(experiment_dir)
# TODO(rliaw): Refactor `logdir` to `trial_dir`.
self._logdir = Trial.create_logdir(self.trial_name,
self._experiment_dir)
self._upload_dir = upload_dir
self.trial_config = trial_config or {}
# misc metadata to save as well
self.trial_config["trial_id"] = self.trial_id
self._logger = UnifiedLogger(self.trial_config, self._logdir)
def log(self, **metrics):
"""Logs all named arguments specified in `metrics`.
This will log trial metrics locally, and they will be synchronized
with the driver periodically through ray.
Arguments:
metrics: named arguments with corresponding values to log.
"""
self._iteration += 1
# TODO: Implement a batching mechanism for multiple calls to `log`
# within the same iteration.
metrics_dict = metrics.copy()
metrics_dict.update({"trial_id": self.trial_id})
# TODO: Move Trainable autopopulation to a util function
metrics_dict.setdefault(TRAINING_ITERATION, self._iteration)
self._logger.on_result(metrics_dict)
def close(self):
"""Closes loggers.
No need to call this when using ``tune.run``.
"""
self.trial_config["trial_completed"] = True
self.trial_config["end_time"] = datetime.now().isoformat()
# TODO(rliaw): Have Tune support updated configs
self._logger.update_config(self.trial_config)
self._logger.flush()
self._logger.close()
@property
def logdir(self):
"""Trial logdir (subdir of given experiment directory)"""
return self._logdir
@property
def trial_name(self):
"""Trial name for the corresponding trial of this Trainable"""
return self._trial_name
@property
def trial_id(self):
"""Trial id for the corresponding trial of this Trainable"""
return self._trial_id
def TrackSession(*args, **kwargs):
msg = "tune.track is now deprecated. Use `tune.report` instead."
raise DeprecationWarning(msg)
-1
View File
@@ -214,7 +214,6 @@ class Trainable:
return ""
def get_current_ip(self):
logger.info("Getting current IP.")
self._local_ip = ray.services.get_node_ip_address()
return self._local_ip