mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 21:46:57 +08:00
[tune] add accessible trial_info (#7378)
* add accessible trial_info * trial name and info * doc * fix gp * Update doc/source/tune-package-ref.rst * Apply suggestions from code review * fix * trial * fixtest * testfix
This commit is contained in:
@@ -29,10 +29,17 @@ class StatusReporter:
|
||||
>>> reporter(timesteps_this_iter=1)
|
||||
"""
|
||||
|
||||
def __init__(self, result_queue, continue_semaphore, logdir=None):
|
||||
def __init__(self,
|
||||
result_queue,
|
||||
continue_semaphore,
|
||||
trial_name=None,
|
||||
trial_id=None,
|
||||
logdir=None):
|
||||
self._queue = result_queue
|
||||
self._last_report_time = None
|
||||
self._continue_semaphore = continue_semaphore
|
||||
self._trial_name = trial_name
|
||||
self._trial_id = trial_id
|
||||
self._logdir = logdir
|
||||
|
||||
def __call__(self, **kwargs):
|
||||
@@ -78,6 +85,16 @@ class StatusReporter:
|
||||
def logdir(self):
|
||||
return self._logdir
|
||||
|
||||
@property
|
||||
def trial_name(self):
|
||||
"""Trial name for the corresponding trial of this Trainable."""
|
||||
return self._trial_name
|
||||
|
||||
@property
|
||||
def trial_id(self):
|
||||
"""Trial id for the corresponding trial of this Trainable."""
|
||||
return self._trial_id
|
||||
|
||||
|
||||
class _RunnerThread(threading.Thread):
|
||||
"""Supervisor thread that runs your script."""
|
||||
@@ -133,7 +150,11 @@ class FunctionRunner(Trainable):
|
||||
self._error_queue = queue.Queue(1)
|
||||
|
||||
self._status_reporter = StatusReporter(
|
||||
self._results_queue, self._continue_semaphore, self.logdir)
|
||||
self._results_queue,
|
||||
self._continue_semaphore,
|
||||
trial_name=self.trial_name,
|
||||
trial_id=self.trial_id,
|
||||
logdir=self.logdir)
|
||||
self._last_result = {}
|
||||
config = config.copy()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user