mirror of
https://github.com/wassname/ray.git
synced 2026-07-02 13:23:53 +08:00
@@ -8,7 +8,7 @@ import threading
|
||||
|
||||
from ray.tune import TuneError
|
||||
from ray.tune.trainable import Trainable
|
||||
from ray.tune.result import TIMESTEPS_TOTAL
|
||||
from ray.tune.result import TIMESTEPS_TOTAL, TRAINING_ITERATION
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -28,6 +28,7 @@ class StatusReporter(object):
|
||||
self._lock = threading.Lock()
|
||||
self._error = None
|
||||
self._done = False
|
||||
self._iteration = 0
|
||||
|
||||
def __call__(self, **kwargs):
|
||||
"""Report updated training status.
|
||||
@@ -44,6 +45,7 @@ class StatusReporter(object):
|
||||
|
||||
with self._lock:
|
||||
self._latest_result = self._last_result = kwargs.copy()
|
||||
self._iteration += 1
|
||||
|
||||
def _get_and_clear_status(self):
|
||||
if self._error:
|
||||
@@ -55,10 +57,13 @@ class StatusReporter(object):
|
||||
"last result. To avoid this, include done=True "
|
||||
"upon the last reporter call.")
|
||||
self._last_result.update(done=True)
|
||||
self._last_result.setdefault(TRAINING_ITERATION, self._iteration)
|
||||
return self._last_result
|
||||
with self._lock:
|
||||
res = self._latest_result
|
||||
self._latest_result = None
|
||||
if res:
|
||||
res.setdefault(TRAINING_ITERATION, self._iteration)
|
||||
return res
|
||||
|
||||
def _stop(self):
|
||||
|
||||
@@ -96,7 +96,8 @@ class BayesOptSearch(SuggestionAlgorithm):
|
||||
self.optimizer.register(
|
||||
params=self._live_trial_mapping[trial_id],
|
||||
target=result[self._reward_attr])
|
||||
del self._live_trial_mapping[trial_id]
|
||||
|
||||
del self._live_trial_mapping[trial_id]
|
||||
|
||||
def _num_live_trials(self):
|
||||
return len(self._live_trial_mapping)
|
||||
|
||||
@@ -19,7 +19,7 @@ from ray.tune.ray_trial_executor import RayTrialExecutor
|
||||
from ray.tune.schedulers import TrialScheduler, FIFOScheduler
|
||||
from ray.tune.registry import _global_registry, TRAINABLE_CLASS
|
||||
from ray.tune.result import (DEFAULT_RESULTS_DIR, TIMESTEPS_TOTAL, DONE,
|
||||
EPISODES_TOTAL)
|
||||
EPISODES_TOTAL, TRAINING_ITERATION)
|
||||
from ray.tune.logger import Logger
|
||||
from ray.tune.util import pin_in_object_store, get_pinned_object
|
||||
from ray.tune.experiment import Experiment
|
||||
@@ -560,6 +560,28 @@ class TrainableFunctionApiTest(unittest.TestCase):
|
||||
self.assertEqual(trial.status, Trial.TERMINATED)
|
||||
self.assertTrue(trial.has_checkpoint())
|
||||
|
||||
def testIterationCounter(self):
|
||||
def train(config, reporter):
|
||||
for i in range(100):
|
||||
reporter(itr=i, done=i == 99)
|
||||
|
||||
register_trainable("exp", train)
|
||||
config = {
|
||||
"my_exp": {
|
||||
"run": "exp",
|
||||
"config": {
|
||||
"iterations": 100,
|
||||
},
|
||||
"stop": {
|
||||
"timesteps_total": 100
|
||||
},
|
||||
}
|
||||
}
|
||||
[trial] = run_experiments(config)
|
||||
self.assertEqual(trial.status, Trial.TERMINATED)
|
||||
self.assertEqual(trial.last_result[TRAINING_ITERATION], 100)
|
||||
self.assertEqual(trial.last_result["itr"], 99)
|
||||
|
||||
|
||||
class RunExperimentTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
|
||||
@@ -18,9 +18,9 @@ import uuid
|
||||
|
||||
import ray
|
||||
from ray.tune.logger import UnifiedLogger
|
||||
from ray.tune.result import (DEFAULT_RESULTS_DIR, TIME_THIS_ITER_S,
|
||||
TIMESTEPS_THIS_ITER, DONE, TIMESTEPS_TOTAL,
|
||||
EPISODES_THIS_ITER, EPISODES_TOTAL)
|
||||
from ray.tune.result import (
|
||||
DEFAULT_RESULTS_DIR, TIME_THIS_ITER_S, TIMESTEPS_THIS_ITER, DONE,
|
||||
TIMESTEPS_TOTAL, EPISODES_THIS_ITER, EPISODES_TOTAL, TRAINING_ITERATION)
|
||||
from ray.tune.trial import Resources
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -181,6 +181,7 @@ class Trainable(object):
|
||||
# self._timesteps_total should not override user-provided total
|
||||
result.setdefault(TIMESTEPS_TOTAL, self._timesteps_total)
|
||||
result.setdefault(EPISODES_TOTAL, self._episodes_total)
|
||||
result.setdefault(TRAINING_ITERATION, self._iteration)
|
||||
|
||||
# Provides auto-filled neg_mean_loss for avoiding regressions
|
||||
if result.get("mean_loss"):
|
||||
@@ -191,7 +192,6 @@ class Trainable(object):
|
||||
experiment_id=self._experiment_id,
|
||||
date=now.strftime("%Y-%m-%d_%H-%M-%S"),
|
||||
timestamp=int(time.mktime(now.timetuple())),
|
||||
training_iteration=self._iteration,
|
||||
time_this_iter_s=time_this_iter,
|
||||
time_total_s=self._time_total,
|
||||
pid=os.getpid(),
|
||||
|
||||
Reference in New Issue
Block a user