mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 21:46:57 +08:00
[tune] tune.track -> tune.report (#8388)
This commit is contained in:
@@ -5,8 +5,7 @@ import threading
|
||||
import traceback
|
||||
from six.moves import queue
|
||||
|
||||
from ray.tune import track
|
||||
from ray.tune import TuneError
|
||||
from ray.tune import TuneError, session
|
||||
from ray.tune.trainable import Trainable
|
||||
from ray.tune.result import TIME_THIS_ITER_S, RESULT_DUPLICATE
|
||||
|
||||
@@ -158,6 +157,8 @@ class FunctionRunner(Trainable):
|
||||
self._last_result = {}
|
||||
config = config.copy()
|
||||
|
||||
session.init(self._status_reporter)
|
||||
|
||||
def entrypoint():
|
||||
return self._trainable_func(config, self._status_reporter)
|
||||
|
||||
@@ -251,6 +252,8 @@ class FunctionRunner(Trainable):
|
||||
# Check for any errors that might have been missed.
|
||||
self._report_thread_runner_error()
|
||||
|
||||
session.shutdown()
|
||||
|
||||
def _report_thread_runner_error(self, block=False):
|
||||
try:
|
||||
err_tb_str = self._error_queue.get(
|
||||
@@ -262,32 +265,19 @@ class FunctionRunner(Trainable):
|
||||
|
||||
|
||||
def wrap_function(train_func):
|
||||
|
||||
use_track = False
|
||||
try:
|
||||
func_args = inspect.getfullargspec(train_func).args
|
||||
use_track = ("reporter" not in func_args and len(func_args) == 1)
|
||||
if use_track:
|
||||
logger.debug("tune.track signature detected.")
|
||||
except Exception:
|
||||
logger.info(
|
||||
"Function inspection failed - assuming reporter signature.")
|
||||
|
||||
class WrappedFunc(FunctionRunner):
|
||||
class ImplicitFunc(FunctionRunner):
|
||||
def _trainable_func(self, config, reporter):
|
||||
output = train_func(config, reporter)
|
||||
func_args = inspect.getfullargspec(train_func).args
|
||||
use_track = ("reporter" not in func_args and len(func_args) == 1)
|
||||
if use_track:
|
||||
output = train_func(config)
|
||||
else:
|
||||
output = train_func(config, reporter)
|
||||
|
||||
# If train_func returns, we need to notify the main event loop
|
||||
# of the last result while avoiding double logging. This is done
|
||||
# with the keyword RESULT_DUPLICATE -- see tune/trial_runner.py.
|
||||
reporter(**{RESULT_DUPLICATE: True})
|
||||
return output
|
||||
|
||||
class WrappedTrackFunc(FunctionRunner):
|
||||
def _trainable_func(self, config, reporter):
|
||||
track.init(_tune_reporter=reporter)
|
||||
output = train_func(config)
|
||||
reporter(**{RESULT_DUPLICATE: True})
|
||||
track.shutdown()
|
||||
return output
|
||||
|
||||
return WrappedTrackFunc if use_track else WrappedFunc
|
||||
return ImplicitFunc
|
||||
|
||||
Reference in New Issue
Block a user