mirror of
https://github.com/wassname/ray.git
synced 2026-06-29 07:58:26 +08:00
[tune] Distributed example + walkthrough (#5157)
This commit is contained in:
@@ -33,10 +33,11 @@ class StatusReporter(object):
|
||||
>>> reporter(timesteps_this_iter=1)
|
||||
"""
|
||||
|
||||
def __init__(self, result_queue, continue_semaphore):
|
||||
def __init__(self, result_queue, continue_semaphore, logdir=None):
|
||||
self._queue = result_queue
|
||||
self._last_report_time = None
|
||||
self._continue_semaphore = continue_semaphore
|
||||
self._logdir = logdir
|
||||
|
||||
def __call__(self, **kwargs):
|
||||
"""Report updated training status.
|
||||
@@ -77,6 +78,10 @@ class StatusReporter(object):
|
||||
def _start(self):
|
||||
self._last_report_time = time.time()
|
||||
|
||||
@property
|
||||
def logdir(self):
|
||||
return self._logdir
|
||||
|
||||
|
||||
class _RunnerThread(threading.Thread):
|
||||
"""Supervisor thread that runs your script."""
|
||||
@@ -131,8 +136,8 @@ class FunctionRunner(Trainable):
|
||||
# reporting to block until finished.
|
||||
self._error_queue = queue.Queue(1)
|
||||
|
||||
self._status_reporter = StatusReporter(self._results_queue,
|
||||
self._continue_semaphore)
|
||||
self._status_reporter = StatusReporter(
|
||||
self._results_queue, self._continue_semaphore, self.logdir)
|
||||
self._last_result = {}
|
||||
config = config.copy()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user