[tune] Trial reporter fix (#3951)

Fixes #3949.
This commit is contained in:
Andrew Tan
2019-02-13 01:03:54 -08:00
committed by Richard Liaw
parent 3a7fb182cc
commit 57dcd3033e
4 changed files with 35 additions and 7 deletions
+6 -1
View File
@@ -8,7 +8,7 @@ import threading
from ray.tune import TuneError
from ray.tune.trainable import Trainable
from ray.tune.result import TIMESTEPS_TOTAL
from ray.tune.result import TIMESTEPS_TOTAL, TRAINING_ITERATION
logger = logging.getLogger(__name__)
@@ -28,6 +28,7 @@ class StatusReporter(object):
self._lock = threading.Lock()
self._error = None
self._done = False
self._iteration = 0
def __call__(self, **kwargs):
"""Report updated training status.
@@ -44,6 +45,7 @@ class StatusReporter(object):
with self._lock:
self._latest_result = self._last_result = kwargs.copy()
self._iteration += 1
def _get_and_clear_status(self):
if self._error:
@@ -55,10 +57,13 @@ class StatusReporter(object):
"last result. To avoid this, include done=True "
"upon the last reporter call.")
self._last_result.update(done=True)
self._last_result.setdefault(TRAINING_ITERATION, self._iteration)
return self._last_result
with self._lock:
res = self._latest_result
self._latest_result = None
if res:
res.setdefault(TRAINING_ITERATION, self._iteration)
return res
def _stop(self):
+2 -1
View File
@@ -96,7 +96,8 @@ class BayesOptSearch(SuggestionAlgorithm):
self.optimizer.register(
params=self._live_trial_mapping[trial_id],
target=result[self._reward_attr])
del self._live_trial_mapping[trial_id]
del self._live_trial_mapping[trial_id]
def _num_live_trials(self):
    """Return the number of entries currently in ``_live_trial_mapping``."""
    live_trials = self._live_trial_mapping
    return len(live_trials)
+23 -1
View File
@@ -19,7 +19,7 @@ from ray.tune.ray_trial_executor import RayTrialExecutor
from ray.tune.schedulers import TrialScheduler, FIFOScheduler
from ray.tune.registry import _global_registry, TRAINABLE_CLASS
from ray.tune.result import (DEFAULT_RESULTS_DIR, TIMESTEPS_TOTAL, DONE,
EPISODES_TOTAL)
EPISODES_TOTAL, TRAINING_ITERATION)
from ray.tune.logger import Logger
from ray.tune.util import pin_in_object_store, get_pinned_object
from ray.tune.experiment import Experiment
@@ -560,6 +560,28 @@ class TrainableFunctionApiTest(unittest.TestCase):
self.assertEqual(trial.status, Trial.TERMINATED)
self.assertTrue(trial.has_checkpoint())
def testIterationCounter(self):
    """Check the iteration count recorded for a function-based trainable.

    The trainable calls ``reporter`` 100 times and never passes a
    training-iteration value itself; it only reports ``itr`` and sets
    ``done=True`` on the final call. The finished trial is expected to be
    TERMINATED, to carry ``TRAINING_ITERATION == 100`` in its last result,
    and to keep the last user-reported ``itr`` of 99.
    """

    def train(config, reporter):
        # One reporter call per step; the last call flags completion.
        for step in range(100):
            reporter(itr=step, done=step == 99)

    register_trainable("exp", train)

    experiment_spec = {
        "run": "exp",
        "config": {
            "iterations": 100,
        },
        "stop": {
            "timesteps_total": 100
        },
    }
    [trial] = run_experiments({"my_exp": experiment_spec})

    final_result = trial.last_result
    self.assertEqual(trial.status, Trial.TERMINATED)
    self.assertEqual(final_result[TRAINING_ITERATION], 100)
    self.assertEqual(final_result["itr"], 99)
class RunExperimentTest(unittest.TestCase):
def setUp(self):
+4 -4
View File
@@ -18,9 +18,9 @@ import uuid
import ray
from ray.tune.logger import UnifiedLogger
from ray.tune.result import (DEFAULT_RESULTS_DIR, TIME_THIS_ITER_S,
TIMESTEPS_THIS_ITER, DONE, TIMESTEPS_TOTAL,
EPISODES_THIS_ITER, EPISODES_TOTAL)
from ray.tune.result import (
DEFAULT_RESULTS_DIR, TIME_THIS_ITER_S, TIMESTEPS_THIS_ITER, DONE,
TIMESTEPS_TOTAL, EPISODES_THIS_ITER, EPISODES_TOTAL, TRAINING_ITERATION)
from ray.tune.trial import Resources
logger = logging.getLogger(__name__)
@@ -181,6 +181,7 @@ class Trainable(object):
# self._timesteps_total should not override user-provided total
result.setdefault(TIMESTEPS_TOTAL, self._timesteps_total)
result.setdefault(EPISODES_TOTAL, self._episodes_total)
result.setdefault(TRAINING_ITERATION, self._iteration)
# Provides auto-filled neg_mean_loss for avoiding regressions
if result.get("mean_loss"):
@@ -191,7 +192,6 @@ class Trainable(object):
experiment_id=self._experiment_id,
date=now.strftime("%Y-%m-%d_%H-%M-%S"),
timestamp=int(time.mktime(now.timetuple())),
training_iteration=self._iteration,
time_this_iter_s=time_this_iter,
time_total_s=self._time_total,
pid=os.getpid(),