[tune] Trial reporter fix (#3951)

Fixes #3949.
This commit is contained in:
Andrew Tan
2019-02-13 01:03:54 -08:00
committed by Richard Liaw
parent 3a7fb182cc
commit 57dcd3033e
4 changed files with 35 additions and 7 deletions
+6 -1
View File
@@ -8,7 +8,7 @@ import threading
from ray.tune import TuneError
from ray.tune.trainable import Trainable
from ray.tune.result import TIMESTEPS_TOTAL
from ray.tune.result import TIMESTEPS_TOTAL, TRAINING_ITERATION
logger = logging.getLogger(__name__)
@@ -28,6 +28,7 @@ class StatusReporter(object):
self._lock = threading.Lock()
self._error = None
self._done = False
self._iteration = 0
def __call__(self, **kwargs):
"""Report updated training status.
@@ -44,6 +45,7 @@ class StatusReporter(object):
with self._lock:
self._latest_result = self._last_result = kwargs.copy()
self._iteration += 1
def _get_and_clear_status(self):
if self._error:
@@ -55,10 +57,13 @@ class StatusReporter(object):
"last result. To avoid this, include done=True "
"upon the last reporter call.")
self._last_result.update(done=True)
self._last_result.setdefault(TRAINING_ITERATION, self._iteration)
return self._last_result
with self._lock:
res = self._latest_result
self._latest_result = None
if res:
res.setdefault(TRAINING_ITERATION, self._iteration)
return res
def _stop(self):