[tune] Fix default handling for timesteps (#3293)

This PR fixes an issue where previously if timesteps_this_iter = 0,
then it would render as "None".

Closes #3057.
This commit is contained in:
Richard Liaw
2018-11-12 15:52:17 -08:00
committed by GitHub
parent 49e2085d78
commit e37891d79d
2 changed files with 23 additions and 7 deletions
+20 -4
View File
@@ -14,7 +14,8 @@ from ray.tune import register_env, register_trainable, run_experiments
from ray.tune.ray_trial_executor import RayTrialExecutor
from ray.tune.schedulers import TrialScheduler, FIFOScheduler
from ray.tune.registry import _global_registry, TRAINABLE_CLASS
from ray.tune.result import DEFAULT_RESULTS_DIR, TIMESTEPS_TOTAL, DONE
from ray.tune.result import (DEFAULT_RESULTS_DIR, TIMESTEPS_TOTAL, DONE,
EPISODES_TOTAL)
from ray.tune.util import pin_in_object_store, get_pinned_object
from ray.tune.experiment import Experiment
from ray.tune.trial import Trial, Resources
@@ -434,10 +435,25 @@ class TrainableFunctionApiTest(unittest.TestCase):
})
self.assertIsNone(trial.last_result[TIMESTEPS_TOTAL])
def train3(config, reporter):
def train2(config, reporter):
for i in range(10):
reporter(timesteps_total=5)
[trial2] = run_experiments({
"foo": {
"run": train2,
"config": {
"script_min_iter_time_s": 0,
},
}
})
self.assertEqual(trial2.last_result[TIMESTEPS_TOTAL], 5)
self.assertEqual(trial2.last_result["timesteps_this_iter"], 0)
def train3(config, reporter):
for i in range(10):
reporter(timesteps_this_iter=0, episodes_this_iter=0)
[trial3] = run_experiments({
"foo": {
"run": train3,
@@ -446,8 +462,8 @@ class TrainableFunctionApiTest(unittest.TestCase):
},
}
})
self.assertEqual(trial3.last_result[TIMESTEPS_TOTAL], 5)
self.assertEqual(trial3.last_result["timesteps_this_iter"], 0)
self.assertEqual(trial3.last_result[TIMESTEPS_TOTAL], 0)
self.assertEqual(trial3.last_result[EPISODES_TOTAL], 0)
def testCheckpointDict(self):
class TestTrain(Trainable):
+3 -3
View File
@@ -161,14 +161,14 @@ class Trainable(object):
result.setdefault(DONE, False)
# self._timesteps_total should only be tracked if increments provided
if result.get(TIMESTEPS_THIS_ITER):
if result.get(TIMESTEPS_THIS_ITER) is not None:
if self._timesteps_total is None:
self._timesteps_total = 0
self._timesteps_total += result[TIMESTEPS_THIS_ITER]
self._timesteps_since_restore += result[TIMESTEPS_THIS_ITER]
# self._timesteps_total should only be tracked if increments provided
if result.get(EPISODES_THIS_ITER):
# self._episodes_total should only be tracked if increments provided
if result.get(EPISODES_THIS_ITER) is not None:
if self._episodes_total is None:
self._episodes_total = 0
self._episodes_total += result[EPISODES_THIS_ITER]