[tune] Use public methods for trainable (#9184)

This commit is contained in:
Richard Liaw
2020-07-01 11:00:00 -07:00
committed by GitHub
parent 1491508859
commit d35f0e40d0
40 changed files with 350 additions and 220 deletions
+4 -4
View File
@@ -165,7 +165,7 @@ class FunctionRunner(Trainable):
_name = "func"
def _setup(self, config):
def setup(self, config):
# Semaphore for notifying the reporter to continue with the computation
# and to generate the next result.
self._continue_semaphore = threading.Semaphore(0)
@@ -212,7 +212,7 @@ class FunctionRunner(Trainable):
# now done or has raised an exception.
pass
def _train(self):
def step(self):
"""Implements train() for a Function API.
If the RunnerThread finishes without reporting "done",
@@ -313,7 +313,7 @@ class FunctionRunner(Trainable):
out.write(data_dict)
return out.getvalue()
def _restore(self, checkpoint):
def load_checkpoint(self, checkpoint):
# This should be removed once Trainables are refactored.
if "tune_checkpoint_path" in checkpoint:
del checkpoint["tune_checkpoint_path"]
@@ -330,7 +330,7 @@ class FunctionRunner(Trainable):
checkpoint_path = TrainableUtil.create_from_pickle(obj, checkpoint_dir)
self.restore(checkpoint_path)
def _stop(self):
def cleanup(self):
# If everything stayed in synch properly, this should never happen.
if not self._results_queue.empty():
logger.warning(