mirror of
https://github.com/wassname/ray.git
synced 2026-07-02 02:59:52 +08:00
[tune][minor] Fixes (#1383)
This commit is contained in:
@@ -11,7 +11,7 @@ import os
|
||||
from collections import namedtuple
|
||||
from ray.tune import TuneError
|
||||
from ray.tune.logger import NoopLogger, UnifiedLogger
|
||||
from ray.tune.result import TrainingResult, DEFAULT_RESULTS_DIR
|
||||
from ray.tune.result import TrainingResult, DEFAULT_RESULTS_DIR, pretty_print
|
||||
from ray.tune.registry import _default_registry, get_registry, TRAINABLE_CLASS
|
||||
|
||||
|
||||
@@ -285,6 +285,14 @@ class Trial(object):
|
||||
print("Error restoring runner:", traceback.format_exc())
|
||||
self.status = Trial.ERROR
|
||||
|
||||
def update_last_result(self, result, terminate=False):
|
||||
if terminate:
|
||||
result = result._replace(done=True)
|
||||
print("TrainingResult for {}:".format(self))
|
||||
print(" {}".format(pretty_print(result).replace("\n", "\n ")))
|
||||
self.last_result = result
|
||||
self.result_logger.on_result(self.last_result)
|
||||
|
||||
def _setup_runner(self):
|
||||
self.status = Trial.RUNNING
|
||||
trainable_cls = get_registry().get(
|
||||
|
||||
@@ -8,7 +8,6 @@ import time
|
||||
import traceback
|
||||
|
||||
from ray.tune import TuneError
|
||||
from ray.tune.result import pretty_print
|
||||
from ray.tune.trial import Trial, Resources
|
||||
from ray.tune.trial_scheduler import FIFOScheduler, TrialScheduler
|
||||
|
||||
@@ -157,35 +156,33 @@ class TrialRunner(object):
|
||||
# have been lost
|
||||
|
||||
def _process_events(self):
|
||||
[result_id], _ = ray.wait(list(self._running.keys()))
|
||||
trial = self._running[result_id]
|
||||
del self._running[result_id]
|
||||
[result_id], _ = ray.wait(list(self._running))
|
||||
trial = self._running.pop(result_id)
|
||||
try:
|
||||
result = ray.get(result_id)
|
||||
trial.result_logger.on_result(result)
|
||||
print("TrainingResult for {}:".format(trial))
|
||||
print(" {}".format(pretty_print(result).replace("\n", "\n ")))
|
||||
trial.last_result = result
|
||||
self._total_time += result.time_this_iter_s
|
||||
|
||||
if trial.should_stop(result):
|
||||
self._scheduler_alg.on_trial_complete(self, trial, result)
|
||||
self._stop_trial(trial)
|
||||
decision = TrialScheduler.STOP
|
||||
else:
|
||||
decision = self._scheduler_alg.on_trial_result(
|
||||
self, trial, result)
|
||||
if decision == TrialScheduler.CONTINUE:
|
||||
if trial.should_checkpoint():
|
||||
# TODO(rliaw): This is a blocking call
|
||||
trial.checkpoint()
|
||||
self._running[trial.train_remote()] = trial
|
||||
elif decision == TrialScheduler.PAUSE:
|
||||
self._pause_trial(trial)
|
||||
elif decision == TrialScheduler.STOP:
|
||||
self._stop_trial(trial)
|
||||
else:
|
||||
assert False, "Invalid scheduling decision: {}".format(
|
||||
decision)
|
||||
trial.update_last_result(
|
||||
result, terminate=(decision == TrialScheduler.STOP))
|
||||
|
||||
if decision == TrialScheduler.CONTINUE:
|
||||
if trial.should_checkpoint():
|
||||
# TODO(rliaw): This is a blocking call
|
||||
trial.checkpoint()
|
||||
self._running[trial.train_remote()] = trial
|
||||
elif decision == TrialScheduler.PAUSE:
|
||||
self._pause_trial(trial)
|
||||
elif decision == TrialScheduler.STOP:
|
||||
self._stop_trial(trial)
|
||||
else:
|
||||
assert False, "Invalid scheduling decision: {}".format(
|
||||
decision)
|
||||
except Exception:
|
||||
print("Error processing event:", traceback.format_exc())
|
||||
if trial.status == Trial.RUNNING:
|
||||
|
||||
Reference in New Issue
Block a user