[tune] tune.track -> tune.report (#8388)

This commit is contained in:
Richard Liaw
2020-05-16 12:55:08 -07:00
committed by GitHub
parent c8cd716295
commit 67c01455fe
20 changed files with 228 additions and 395 deletions
+14 -24
View File
@@ -5,8 +5,7 @@ import threading
import traceback
from six.moves import queue
from ray.tune import track
from ray.tune import TuneError
from ray.tune import TuneError, session
from ray.tune.trainable import Trainable
from ray.tune.result import TIME_THIS_ITER_S, RESULT_DUPLICATE
@@ -158,6 +157,8 @@ class FunctionRunner(Trainable):
self._last_result = {}
config = config.copy()
session.init(self._status_reporter)
def entrypoint():
return self._trainable_func(config, self._status_reporter)
@@ -251,6 +252,8 @@ class FunctionRunner(Trainable):
# Check for any errors that might have been missed.
self._report_thread_runner_error()
session.shutdown()
def _report_thread_runner_error(self, block=False):
try:
err_tb_str = self._error_queue.get(
@@ -262,32 +265,19 @@ class FunctionRunner(Trainable):
def wrap_function(train_func):
use_track = False
try:
func_args = inspect.getfullargspec(train_func).args
use_track = ("reporter" not in func_args and len(func_args) == 1)
if use_track:
logger.debug("tune.track signature detected.")
except Exception:
logger.info(
"Function inspection failed - assuming reporter signature.")
class WrappedFunc(FunctionRunner):
class ImplicitFunc(FunctionRunner):
def _trainable_func(self, config, reporter):
output = train_func(config, reporter)
func_args = inspect.getfullargspec(train_func).args
use_track = ("reporter" not in func_args and len(func_args) == 1)
if use_track:
output = train_func(config)
else:
output = train_func(config, reporter)
# If train_func returns, we need to notify the main event loop
# of the last result while avoiding double logging. This is done
# with the keyword RESULT_DUPLICATE -- see tune/trial_runner.py.
reporter(**{RESULT_DUPLICATE: True})
return output
class WrappedTrackFunc(FunctionRunner):
def _trainable_func(self, config, reporter):
track.init(_tune_reporter=reporter)
output = train_func(config)
reporter(**{RESULT_DUPLICATE: True})
track.shutdown()
return output
return WrappedTrackFunc if use_track else WrappedFunc
return ImplicitFunc