mirror of
https://github.com/wassname/ray.git
synced 2026-07-05 08:31:18 +08:00
@@ -8,7 +8,7 @@ import threading
|
||||
|
||||
from ray.tune import TuneError
|
||||
from ray.tune.trainable import Trainable
|
||||
from ray.tune.result import TIMESTEPS_TOTAL
|
||||
from ray.tune.result import TIMESTEPS_TOTAL, TRAINING_ITERATION
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -28,6 +28,7 @@ class StatusReporter(object):
|
||||
self._lock = threading.Lock()
|
||||
self._error = None
|
||||
self._done = False
|
||||
self._iteration = 0
|
||||
|
||||
def __call__(self, **kwargs):
|
||||
"""Report updated training status.
|
||||
@@ -44,6 +45,7 @@ class StatusReporter(object):
|
||||
|
||||
with self._lock:
|
||||
self._latest_result = self._last_result = kwargs.copy()
|
||||
self._iteration += 1
|
||||
|
||||
def _get_and_clear_status(self):
|
||||
if self._error:
|
||||
@@ -55,10 +57,13 @@ class StatusReporter(object):
|
||||
"last result. To avoid this, include done=True "
|
||||
"upon the last reporter call.")
|
||||
self._last_result.update(done=True)
|
||||
self._last_result.setdefault(TRAINING_ITERATION, self._iteration)
|
||||
return self._last_result
|
||||
with self._lock:
|
||||
res = self._latest_result
|
||||
self._latest_result = None
|
||||
if res:
|
||||
res.setdefault(TRAINING_ITERATION, self._iteration)
|
||||
return res
|
||||
|
||||
def _stop(self):
|
||||
|
||||
Reference in New Issue
Block a user