mirror of
https://github.com/wassname/ray.git
synced 2026-06-29 10:11:52 +08:00
[tune] Fix default handling for timesteps (#3293)
This PR fixes an issue where previously if timesteps_this_iter = 0, then it would render as "None". Closes #3057.
This commit is contained in:
@@ -14,7 +14,8 @@ from ray.tune import register_env, register_trainable, run_experiments
|
||||
from ray.tune.ray_trial_executor import RayTrialExecutor
|
||||
from ray.tune.schedulers import TrialScheduler, FIFOScheduler
|
||||
from ray.tune.registry import _global_registry, TRAINABLE_CLASS
|
||||
from ray.tune.result import DEFAULT_RESULTS_DIR, TIMESTEPS_TOTAL, DONE
|
||||
from ray.tune.result import (DEFAULT_RESULTS_DIR, TIMESTEPS_TOTAL, DONE,
|
||||
EPISODES_TOTAL)
|
||||
from ray.tune.util import pin_in_object_store, get_pinned_object
|
||||
from ray.tune.experiment import Experiment
|
||||
from ray.tune.trial import Trial, Resources
|
||||
@@ -434,10 +435,25 @@ class TrainableFunctionApiTest(unittest.TestCase):
|
||||
})
|
||||
self.assertIsNone(trial.last_result[TIMESTEPS_TOTAL])
|
||||
|
||||
def train3(config, reporter):
|
||||
def train2(config, reporter):
|
||||
for i in range(10):
|
||||
reporter(timesteps_total=5)
|
||||
|
||||
[trial2] = run_experiments({
|
||||
"foo": {
|
||||
"run": train2,
|
||||
"config": {
|
||||
"script_min_iter_time_s": 0,
|
||||
},
|
||||
}
|
||||
})
|
||||
self.assertEqual(trial2.last_result[TIMESTEPS_TOTAL], 5)
|
||||
self.assertEqual(trial2.last_result["timesteps_this_iter"], 0)
|
||||
|
||||
def train3(config, reporter):
|
||||
for i in range(10):
|
||||
reporter(timesteps_this_iter=0, episodes_this_iter=0)
|
||||
|
||||
[trial3] = run_experiments({
|
||||
"foo": {
|
||||
"run": train3,
|
||||
@@ -446,8 +462,8 @@ class TrainableFunctionApiTest(unittest.TestCase):
|
||||
},
|
||||
}
|
||||
})
|
||||
self.assertEqual(trial3.last_result[TIMESTEPS_TOTAL], 5)
|
||||
self.assertEqual(trial3.last_result["timesteps_this_iter"], 0)
|
||||
self.assertEqual(trial3.last_result[TIMESTEPS_TOTAL], 0)
|
||||
self.assertEqual(trial3.last_result[EPISODES_TOTAL], 0)
|
||||
|
||||
def testCheckpointDict(self):
|
||||
class TestTrain(Trainable):
|
||||
|
||||
@@ -161,14 +161,14 @@ class Trainable(object):
|
||||
result.setdefault(DONE, False)
|
||||
|
||||
# self._timesteps_total should only be tracked if increments provided
|
||||
if result.get(TIMESTEPS_THIS_ITER):
|
||||
if result.get(TIMESTEPS_THIS_ITER) is not None:
|
||||
if self._timesteps_total is None:
|
||||
self._timesteps_total = 0
|
||||
self._timesteps_total += result[TIMESTEPS_THIS_ITER]
|
||||
self._timesteps_since_restore += result[TIMESTEPS_THIS_ITER]
|
||||
|
||||
# self._timesteps_total should only be tracked if increments provided
|
||||
if result.get(EPISODES_THIS_ITER):
|
||||
# self._episodes_total should only be tracked if increments provided
|
||||
if result.get(EPISODES_THIS_ITER) is not None:
|
||||
if self._episodes_total is None:
|
||||
self._episodes_total = 0
|
||||
self._episodes_total += result[EPISODES_THIS_ITER]
|
||||
|
||||
Reference in New Issue
Block a user