mirror of
https://github.com/wassname/ray.git
synced 2026-07-03 20:39:52 +08:00
[tune] Use public methods for trainable (#9184)
This commit is contained in:
@@ -165,7 +165,7 @@ class FunctionRunner(Trainable):
|
||||
|
||||
_name = "func"
|
||||
|
||||
def _setup(self, config):
|
||||
def setup(self, config):
|
||||
# Semaphore for notifying the reporter to continue with the computation
|
||||
# and to generate the next result.
|
||||
self._continue_semaphore = threading.Semaphore(0)
|
||||
@@ -212,7 +212,7 @@ class FunctionRunner(Trainable):
|
||||
# now done or has raised an exception.
|
||||
pass
|
||||
|
||||
def _train(self):
|
||||
def step(self):
|
||||
"""Implements train() for a Function API.
|
||||
|
||||
If the RunnerThread finishes without reporting "done",
|
||||
@@ -313,7 +313,7 @@ class FunctionRunner(Trainable):
|
||||
out.write(data_dict)
|
||||
return out.getvalue()
|
||||
|
||||
def _restore(self, checkpoint):
|
||||
def load_checkpoint(self, checkpoint):
|
||||
# This should be removed once Trainables are refactored.
|
||||
if "tune_checkpoint_path" in checkpoint:
|
||||
del checkpoint["tune_checkpoint_path"]
|
||||
@@ -330,7 +330,7 @@ class FunctionRunner(Trainable):
|
||||
checkpoint_path = TrainableUtil.create_from_pickle(obj, checkpoint_dir)
|
||||
self.restore(checkpoint_path)
|
||||
|
||||
def _stop(self):
|
||||
def cleanup(self):
|
||||
# If everything stayed in synch properly, this should never happen.
|
||||
if not self._results_queue.empty():
|
||||
logger.warning(
|
||||
|
||||
Reference in New Issue
Block a user