[tune][minor] Fixes (#1383)

This commit is contained in:
Richard Liaw
2018-01-11 18:14:20 -08:00
committed by GitHub
parent 1290072764
commit d4592382a4
3 changed files with 45 additions and 22 deletions
+9 -1
View File
@@ -11,7 +11,7 @@ import os
from collections import namedtuple
from ray.tune import TuneError
from ray.tune.logger import NoopLogger, UnifiedLogger
from ray.tune.result import TrainingResult, DEFAULT_RESULTS_DIR
from ray.tune.result import TrainingResult, DEFAULT_RESULTS_DIR, pretty_print
from ray.tune.registry import _default_registry, get_registry, TRAINABLE_CLASS
@@ -285,6 +285,14 @@ class Trial(object):
print("Error restoring runner:", traceback.format_exc())
self.status = Trial.ERROR
def update_last_result(self, result, terminate=False):
    """Record `result` as this trial's latest result, print it, and log it.

    Args:
        result: The result to record. When `terminate` is true it must
            support `_replace(done=True)` (namedtuple-style).
        terminate: If true, a copy of `result` with `done=True` is
            recorded in place of the original.
    """
    latest = result._replace(done=True) if terminate else result
    # Print the header before formatting the body, so output order is
    # unchanged even if formatting the result fails.
    print("TrainingResult for {}:".format(self))
    formatted = pretty_print(latest).replace("\n", "\n ")
    print(" {}".format(formatted))
    self.last_result = latest
    self.result_logger.on_result(self.last_result)
def _setup_runner(self):
self.status = Trial.RUNNING
trainable_cls = get_registry().get(
+18 -21
View File
@@ -8,7 +8,6 @@ import time
import traceback
from ray.tune import TuneError
from ray.tune.result import pretty_print
from ray.tune.trial import Trial, Resources
from ray.tune.trial_scheduler import FIFOScheduler, TrialScheduler
@@ -157,35 +156,33 @@ class TrialRunner(object):
# have been lost
def _process_events(self):
[result_id], _ = ray.wait(list(self._running.keys()))
trial = self._running[result_id]
del self._running[result_id]
[result_id], _ = ray.wait(list(self._running))
trial = self._running.pop(result_id)
try:
result = ray.get(result_id)
trial.result_logger.on_result(result)
print("TrainingResult for {}:".format(trial))
print(" {}".format(pretty_print(result).replace("\n", "\n ")))
trial.last_result = result
self._total_time += result.time_this_iter_s
if trial.should_stop(result):
self._scheduler_alg.on_trial_complete(self, trial, result)
self._stop_trial(trial)
decision = TrialScheduler.STOP
else:
decision = self._scheduler_alg.on_trial_result(
self, trial, result)
if decision == TrialScheduler.CONTINUE:
if trial.should_checkpoint():
# TODO(rliaw): This is a blocking call
trial.checkpoint()
self._running[trial.train_remote()] = trial
elif decision == TrialScheduler.PAUSE:
self._pause_trial(trial)
elif decision == TrialScheduler.STOP:
self._stop_trial(trial)
else:
assert False, "Invalid scheduling decision: {}".format(
decision)
trial.update_last_result(
result, terminate=(decision == TrialScheduler.STOP))
if decision == TrialScheduler.CONTINUE:
if trial.should_checkpoint():
# TODO(rliaw): This is a blocking call
trial.checkpoint()
self._running[trial.train_remote()] = trial
elif decision == TrialScheduler.PAUSE:
self._pause_trial(trial)
elif decision == TrialScheduler.STOP:
self._stop_trial(trial)
else:
assert False, "Invalid scheduling decision: {}".format(
decision)
except Exception:
print("Error processing event:", traceback.format_exc())
if trial.status == Trial.RUNNING: