diff --git a/python/ray/tune/test/trial_runner_test.py b/python/ray/tune/test/trial_runner_test.py index 2ffbe87fe..6b142d354 100644 --- a/python/ray/tune/test/trial_runner_test.py +++ b/python/ray/tune/test/trial_runner_test.py @@ -14,7 +14,8 @@ from ray.tune import register_env, register_trainable, run_experiments from ray.tune.ray_trial_executor import RayTrialExecutor from ray.tune.schedulers import TrialScheduler, FIFOScheduler from ray.tune.registry import _global_registry, TRAINABLE_CLASS -from ray.tune.result import DEFAULT_RESULTS_DIR, TIMESTEPS_TOTAL, DONE +from ray.tune.result import (DEFAULT_RESULTS_DIR, TIMESTEPS_TOTAL, DONE, + EPISODES_TOTAL) from ray.tune.util import pin_in_object_store, get_pinned_object from ray.tune.experiment import Experiment from ray.tune.trial import Trial, Resources @@ -434,10 +435,25 @@ class TrainableFunctionApiTest(unittest.TestCase): }) self.assertIsNone(trial.last_result[TIMESTEPS_TOTAL]) - def train3(config, reporter): + def train2(config, reporter): for i in range(10): reporter(timesteps_total=5) + [trial2] = run_experiments({ + "foo": { + "run": train2, + "config": { + "script_min_iter_time_s": 0, + }, + } + }) + self.assertEqual(trial2.last_result[TIMESTEPS_TOTAL], 5) + self.assertEqual(trial2.last_result["timesteps_this_iter"], 0) + + def train3(config, reporter): + for i in range(10): + reporter(timesteps_this_iter=0, episodes_this_iter=0) + [trial3] = run_experiments({ "foo": { "run": train3, @@ -446,8 +462,8 @@ class TrainableFunctionApiTest(unittest.TestCase): }, } }) - self.assertEqual(trial3.last_result[TIMESTEPS_TOTAL], 5) - self.assertEqual(trial3.last_result["timesteps_this_iter"], 0) + self.assertEqual(trial3.last_result[TIMESTEPS_TOTAL], 0) + self.assertEqual(trial3.last_result[EPISODES_TOTAL], 0) def testCheckpointDict(self): class TestTrain(Trainable): diff --git a/python/ray/tune/trainable.py b/python/ray/tune/trainable.py index 5d5e682e7..5824c5221 100644 --- a/python/ray/tune/trainable.py +++ b/python/ray/tune/trainable.py @@ -161,14 +161,14 @@ class Trainable(object): result.setdefault(DONE, False) # self._timesteps_total should only be tracked if increments provided - if result.get(TIMESTEPS_THIS_ITER): + if result.get(TIMESTEPS_THIS_ITER) is not None: if self._timesteps_total is None: self._timesteps_total = 0 self._timesteps_total += result[TIMESTEPS_THIS_ITER] self._timesteps_since_restore += result[TIMESTEPS_THIS_ITER] - # self._timesteps_total should only be tracked if increments provided - if result.get(EPISODES_THIS_ITER): + # self._episodes_total should only be tracked if increments provided + if result.get(EPISODES_THIS_ITER) is not None: if self._episodes_total is None: self._episodes_total = 0 self._episodes_total += result[EPISODES_THIS_ITER]