[tune] reuse actors for function API (#11230)

Co-authored-by: Kristian Hartikainen <kristian.hartikainen@gmail.com>
This commit is contained in:
Kai Fricke
2020-10-09 00:15:02 +01:00
committed by GitHub
parent 587319debc
commit b450cb030a
7 changed files with 154 additions and 28 deletions
+49 -1
View File
@@ -1,5 +1,6 @@
import logging
import os
import sys
import time
import inspect
import shutil
@@ -120,12 +121,21 @@ class StatusReporter:
def __init__(self,
result_queue,
continue_semaphore,
end_event,
trial_name=None,
trial_id=None,
logdir=None):
self._queue = result_queue
self._last_report_time = None
self._continue_semaphore = continue_semaphore
self._end_event = end_event
self._trial_name = trial_name
self._trial_id = trial_id
self._logdir = logdir
self._last_checkpoint = None
self._fresh_checkpoint = False
def reset(self, trial_name=None, trial_id=None, logdir=None):
self._trial_name = trial_name
self._trial_id = trial_id
self._logdir = logdir
@@ -171,6 +181,11 @@ class StatusReporter:
# resume training.
self._continue_semaphore.acquire()
# If the trial should be terminated, exit gracefully.
if self._end_event.is_set():
self._end_event.clear()
sys.exit(0)
def make_checkpoint_dir(self, step):
checkpoint_dir = TrainableUtil.make_checkpoint_dir(
self.logdir, index=step)
@@ -264,6 +279,10 @@ class FunctionRunner(Trainable):
# and to generate the next result.
self._continue_semaphore = threading.Semaphore(0)
# Event for notifying the reporter to exit gracefully, terminating
# the thread.
self._end_event = threading.Event()
# Queue for passing results between threads
self._results_queue = queue.Queue(1)
@@ -275,6 +294,7 @@ class FunctionRunner(Trainable):
self._status_reporter = StatusReporter(
self._results_queue,
self._continue_semaphore,
self._end_event,
trial_name=self.trial_name,
trial_id=self.trial_id,
logdir=self.logdir)
@@ -363,7 +383,7 @@ class FunctionRunner(Trainable):
# This keyword appears if the train_func using the Function API
# finishes without "done=True". This duplicates the last result, but
# the TrialRunner will not log this result again.
if "__duplicate__" in result:
if RESULT_DUPLICATE in result:
new_result = self._last_result.copy()
new_result.update(result)
result = new_result
@@ -441,6 +461,11 @@ class FunctionRunner(Trainable):
self.restore(checkpoint_path)
def cleanup(self):
# Trigger thread termination
self._end_event.set()
self._continue_semaphore.release()
# Do not wait for thread termination here.
# If everything stayed in synch properly, this should never happen.
if not self._results_queue.empty():
logger.warning(
@@ -457,6 +482,29 @@ class FunctionRunner(Trainable):
logger.debug("Clearing temporary checkpoint: %s",
self.temp_checkpoint_dir)
def reset_config(self, new_config):
if self._runner and self._runner.is_alive():
self._end_event.set()
self._continue_semaphore.release()
# Wait for thread termination so it is save to re-use the same
# actor.
thread_timeout = int(
os.environ.get("TUNE_FUNCTION_THREAD_TIMEOUT_S", 2))
self._runner.join(timeout=thread_timeout)
if self._runner.is_alive():
# Did not finish within timeout, reset unsuccessful.
return False
self._runner = None
self._last_result = {}
self._status_reporter.reset(
trial_name=self.trial_name,
trial_id=self.trial_id,
logdir=self.logdir)
return True
def _report_thread_runner_error(self, block=False):
try:
err_tb_str = self._error_queue.get(