mirror of
https://github.com/wassname/ray.git
synced 2026-06-30 20:00:22 +08:00
[tune] Use public methods for trainable (#9184)
This commit is contained in:
@@ -11,8 +11,7 @@ class DurableTrainable(Trainable):
|
||||
"""Abstract class for a remote-storage backed fault-tolerant Trainable.
|
||||
|
||||
Supports checkpointing to and restoring from remote storage. To use this
|
||||
class, implement the same private methods as ray.tune.Trainable (`_save`,
|
||||
`_train`, `_restore`, `reset_config`, `_setup`, `_stop`).
|
||||
class, implement the same public methods as ray.tune.Trainable.
|
||||
|
||||
.. warning:: This class is currently **experimental** and may
|
||||
be subject to change.
|
||||
|
||||
@@ -19,10 +19,10 @@ class MyTrainableClass(Trainable):
|
||||
maximum reward value reached.
|
||||
"""
|
||||
|
||||
def _setup(self, config):
|
||||
def setup(self, config):
|
||||
self.timestep = 0
|
||||
|
||||
def _train(self):
|
||||
def step(self):
|
||||
self.timestep += 1
|
||||
v = np.tanh(float(self.timestep) / self.config.get("width", 1))
|
||||
v *= self.config.get("height", 1)
|
||||
@@ -31,13 +31,13 @@ class MyTrainableClass(Trainable):
|
||||
# objectives such as loss or accuracy.
|
||||
return {"episode_reward_mean": v}
|
||||
|
||||
def _save(self, checkpoint_dir):
|
||||
def save_checkpoint(self, checkpoint_dir):
|
||||
path = os.path.join(checkpoint_dir, "checkpoint")
|
||||
with open(path, "w") as f:
|
||||
f.write(json.dumps({"timestep": self.timestep}))
|
||||
return path
|
||||
|
||||
def _restore(self, checkpoint_path):
|
||||
def load_checkpoint(self, checkpoint_path):
|
||||
with open(checkpoint_path) as f:
|
||||
self.timestep = json.loads(f.read())["timestep"]
|
||||
|
||||
|
||||
@@ -18,10 +18,10 @@ class MyTrainableClass(Trainable):
|
||||
maximum reward value reached.
|
||||
"""
|
||||
|
||||
def _setup(self, config):
|
||||
def setup(self, config):
|
||||
self.timestep = 0
|
||||
|
||||
def _train(self):
|
||||
def step(self):
|
||||
self.timestep += 1
|
||||
v = np.tanh(float(self.timestep) / self.config.get("width", 1))
|
||||
v *= self.config.get("height", 1)
|
||||
@@ -30,13 +30,13 @@ class MyTrainableClass(Trainable):
|
||||
# objectives such as loss or accuracy.
|
||||
return {"episode_reward_mean": v}
|
||||
|
||||
def _save(self, checkpoint_dir):
|
||||
def save_checkpoint(self, checkpoint_dir):
|
||||
path = os.path.join(checkpoint_dir, "checkpoint")
|
||||
with open(path, "w") as f:
|
||||
f.write(json.dumps({"timestep": self.timestep}))
|
||||
return path
|
||||
|
||||
def _restore(self, checkpoint_path):
|
||||
def load_checkpoint(self, checkpoint_path):
|
||||
with open(checkpoint_path) as f:
|
||||
self.timestep = json.loads(f.read())["timestep"]
|
||||
|
||||
|
||||
@@ -51,7 +51,7 @@ class OptimusFn(object):
|
||||
|
||||
def get_optimus_trainable(parent_cls):
|
||||
class OptimusTrainable(parent_cls):
|
||||
def _setup(self, config):
|
||||
def setup(self, config):
|
||||
self.iter = 0
|
||||
if config.get("seed"):
|
||||
np.random.seed(config["seed"])
|
||||
@@ -61,7 +61,7 @@ def get_optimus_trainable(parent_cls):
|
||||
self.initial_samples_per_step = 500
|
||||
self.mock_data = open("/dev/urandom", "rb").read(1024)
|
||||
|
||||
def _train(self):
|
||||
def step(self):
|
||||
self.iter += 1
|
||||
new_loss = self.func.eval(self.iter)
|
||||
time.sleep(0.5)
|
||||
@@ -71,7 +71,7 @@ def get_optimus_trainable(parent_cls):
|
||||
"samples": self.initial_samples_per_step
|
||||
}
|
||||
|
||||
def _save(self, checkpoint_dir):
|
||||
def save_checkpoint(self, checkpoint_dir):
|
||||
time.sleep(0.5)
|
||||
return {
|
||||
"func": cloudpickle.dumps(self.func),
|
||||
@@ -80,7 +80,7 @@ def get_optimus_trainable(parent_cls):
|
||||
"iter": self.iter
|
||||
}
|
||||
|
||||
def _restore(self, checkpoint):
|
||||
def load_checkpoint(self, checkpoint):
|
||||
self.func = cloudpickle.loads(checkpoint["func"])
|
||||
self.data = checkpoint["data"]
|
||||
self.iter = checkpoint["iter"]
|
||||
|
||||
@@ -19,10 +19,10 @@ class MyTrainableClass(Trainable):
|
||||
maximum reward value reached.
|
||||
"""
|
||||
|
||||
def _setup(self, config):
|
||||
def setup(self, config):
|
||||
self.timestep = 0
|
||||
|
||||
def _train(self):
|
||||
def step(self):
|
||||
self.timestep += 1
|
||||
v = np.tanh(float(self.timestep) / self.config.get("width", 1))
|
||||
v *= self.config.get("height", 1)
|
||||
@@ -31,13 +31,13 @@ class MyTrainableClass(Trainable):
|
||||
# objectives such as loss or accuracy.
|
||||
return {"episode_reward_mean": v}
|
||||
|
||||
def _save(self, checkpoint_dir):
|
||||
def save_checkpoint(self, checkpoint_dir):
|
||||
path = os.path.join(checkpoint_dir, "checkpoint")
|
||||
with open(path, "w") as f:
|
||||
f.write(json.dumps({"timestep": self.timestep}))
|
||||
return path
|
||||
|
||||
def _restore(self, checkpoint_path):
|
||||
def load_checkpoint(self, checkpoint_path):
|
||||
with open(checkpoint_path) as f:
|
||||
self.timestep = json.loads(f.read())["timestep"]
|
||||
|
||||
|
||||
@@ -27,10 +27,10 @@ class MyTrainableClass(Trainable):
|
||||
maximum reward value reached.
|
||||
"""
|
||||
|
||||
def _setup(self, config):
|
||||
def setup(self, config):
|
||||
self.timestep = 0
|
||||
|
||||
def _train(self):
|
||||
def step(self):
|
||||
self.timestep += 1
|
||||
v = np.tanh(float(self.timestep) / self.config.get("width", 1))
|
||||
v *= self.config.get("height", 1)
|
||||
@@ -39,13 +39,13 @@ class MyTrainableClass(Trainable):
|
||||
# objectives such as loss or accuracy.
|
||||
return {"episode_reward_mean": v}
|
||||
|
||||
def _save(self, checkpoint_dir):
|
||||
def save_checkpoint(self, checkpoint_dir):
|
||||
path = os.path.join(checkpoint_dir, "checkpoint")
|
||||
with open(path, "w") as f:
|
||||
f.write(json.dumps({"timestep": self.timestep}))
|
||||
return path
|
||||
|
||||
def _restore(self, checkpoint_path):
|
||||
def load_checkpoint(self, checkpoint_path):
|
||||
with open(checkpoint_path) as f:
|
||||
self.timestep = json.loads(f.read())["timestep"]
|
||||
|
||||
|
||||
@@ -34,7 +34,7 @@ parser.add_argument(
|
||||
# yapf: disable
|
||||
# __trainable_example_begin__
|
||||
class TrainMNIST(tune.Trainable):
|
||||
def _setup(self, config):
|
||||
def setup(self, config):
|
||||
use_cuda = config.get("use_gpu") and torch.cuda.is_available()
|
||||
self.device = torch.device("cuda" if use_cuda else "cpu")
|
||||
self.train_loader, self.test_loader = get_data_loaders()
|
||||
@@ -44,18 +44,18 @@ class TrainMNIST(tune.Trainable):
|
||||
lr=config.get("lr", 0.01),
|
||||
momentum=config.get("momentum", 0.9))
|
||||
|
||||
def _train(self):
|
||||
def step(self):
|
||||
train(
|
||||
self.model, self.optimizer, self.train_loader, device=self.device)
|
||||
acc = test(self.model, self.test_loader, self.device)
|
||||
return {"mean_accuracy": acc}
|
||||
|
||||
def _save(self, checkpoint_dir):
|
||||
def save_checkpoint(self, checkpoint_dir):
|
||||
checkpoint_path = os.path.join(checkpoint_dir, "model.pth")
|
||||
torch.save(self.model.state_dict(), checkpoint_path)
|
||||
return checkpoint_path
|
||||
|
||||
def _restore(self, checkpoint_path):
|
||||
def load_checkpoint(self, checkpoint_path):
|
||||
self.model.load_state_dict(torch.load(checkpoint_path))
|
||||
|
||||
|
||||
|
||||
@@ -27,7 +27,7 @@ class PytorchTrainble(tune.Trainable):
|
||||
changing the original training code.
|
||||
"""
|
||||
|
||||
def _setup(self, config):
|
||||
def setup(self, config):
|
||||
self.train_loader, self.test_loader = get_data_loaders()
|
||||
self.model = ConvNet()
|
||||
self.optimizer = optim.SGD(
|
||||
@@ -35,17 +35,17 @@ class PytorchTrainble(tune.Trainable):
|
||||
lr=config.get("lr", 0.01),
|
||||
momentum=config.get("momentum", 0.9))
|
||||
|
||||
def _train(self):
|
||||
def step(self):
|
||||
train(self.model, self.optimizer, self.train_loader)
|
||||
acc = test(self.model, self.test_loader)
|
||||
return {"mean_accuracy": acc}
|
||||
|
||||
def _save(self, checkpoint_dir):
|
||||
def save_checkpoint(self, checkpoint_dir):
|
||||
checkpoint_path = os.path.join(checkpoint_dir, "model.pth")
|
||||
torch.save(self.model.state_dict(), checkpoint_path)
|
||||
return checkpoint_path
|
||||
|
||||
def _restore(self, checkpoint_path):
|
||||
def load_checkpoint(self, checkpoint_path):
|
||||
self.model.load_state_dict(torch.load(checkpoint_path))
|
||||
|
||||
def _export_model(self, export_formats, export_dir):
|
||||
|
||||
@@ -231,7 +231,7 @@ def train(netD, netG, optimG, optimD, criterion, dataloader, iteration,
|
||||
|
||||
# __Trainable_begin__
|
||||
class PytorchTrainable(tune.Trainable):
|
||||
def _setup(self, config):
|
||||
def setup(self, config):
|
||||
use_cuda = config.get("use_gpu") and torch.cuda.is_available()
|
||||
self.device = torch.device("cuda" if use_cuda else "cpu")
|
||||
self.netD = Discriminator().to(self.device)
|
||||
@@ -250,13 +250,13 @@ class PytorchTrainable(tune.Trainable):
|
||||
with FileLock(os.path.expanduser("~/.data.lock")):
|
||||
self.dataloader = get_data_loader()
|
||||
|
||||
def _train(self):
|
||||
def step(self):
|
||||
lossG, lossD, is_score = train(
|
||||
self.netD, self.netG, self.optimizerG, self.optimizerD,
|
||||
self.criterion, self.dataloader, self._iteration, self.device)
|
||||
return {"lossg": lossG, "lossd": lossD, "is_score": is_score}
|
||||
|
||||
def _save(self, checkpoint_dir):
|
||||
def save_checkpoint(self, checkpoint_dir):
|
||||
path = os.path.join(checkpoint_dir, "checkpoint")
|
||||
torch.save({
|
||||
"netDmodel": self.netD.state_dict(),
|
||||
@@ -267,7 +267,7 @@ class PytorchTrainable(tune.Trainable):
|
||||
|
||||
return checkpoint_dir
|
||||
|
||||
def _restore(self, checkpoint_dir):
|
||||
def load_checkpoint(self, checkpoint_dir):
|
||||
path = os.path.join(checkpoint_dir, "checkpoint")
|
||||
checkpoint = torch.load(path)
|
||||
self.netD.load_state_dict(checkpoint["netDmodel"])
|
||||
|
||||
@@ -31,11 +31,11 @@ class PBTBenchmarkExample(Trainable):
|
||||
faster convergence. Training will not converge without PBT.
|
||||
"""
|
||||
|
||||
def _setup(self, config):
|
||||
def setup(self, config):
|
||||
self.lr = config["lr"]
|
||||
self.accuracy = 0.0 # end = 1000
|
||||
|
||||
def _train(self):
|
||||
def step(self):
|
||||
midpoint = 100 # lr starts decreasing after acc > midpoint
|
||||
q_tolerance = 3 # penalize exceeding lr by more than this multiple
|
||||
noise_level = 2 # add gaussian noise to the acc increase
|
||||
@@ -66,13 +66,13 @@ class PBTBenchmarkExample(Trainable):
|
||||
"done": self.accuracy > midpoint * 2,
|
||||
}
|
||||
|
||||
def _save(self, checkpoint_dir):
|
||||
def save_checkpoint(self, checkpoint_dir):
|
||||
return {
|
||||
"accuracy": self.accuracy,
|
||||
"lr": self.lr,
|
||||
}
|
||||
|
||||
def _restore(self, checkpoint):
|
||||
def load_checkpoint(self, checkpoint):
|
||||
self.accuracy = checkpoint["accuracy"]
|
||||
|
||||
def reset_config(self, new_config):
|
||||
|
||||
@@ -214,7 +214,7 @@ class MemNNModel(Trainable):
|
||||
model = Model([input_sequence, question], answer)
|
||||
return model
|
||||
|
||||
def _setup(self, config):
|
||||
def setup(self, config):
|
||||
with FileLock(os.path.expanduser("~/.tune.lock")):
|
||||
self.train_stories, self.test_stories = read_data()
|
||||
model = self.build_model()
|
||||
@@ -226,7 +226,7 @@ class MemNNModel(Trainable):
|
||||
metrics=["accuracy"])
|
||||
self.model = model
|
||||
|
||||
def _train(self):
|
||||
def step(self):
|
||||
# train
|
||||
self.model.fit(
|
||||
[self.inputs_train, self.queries_train],
|
||||
@@ -242,12 +242,12 @@ class MemNNModel(Trainable):
|
||||
verbose=0)
|
||||
return {"mean_accuracy": accuracy}
|
||||
|
||||
def _save(self, checkpoint_dir):
|
||||
def save_checkpoint(self, checkpoint_dir):
|
||||
file_path = checkpoint_dir + "/model"
|
||||
self.model.save(file_path)
|
||||
return file_path
|
||||
|
||||
def _restore(self, path):
|
||||
def load_checkpoint(self, path):
|
||||
# See https://stackoverflow.com/a/42763323
|
||||
del self.model
|
||||
self.model = load_model(path)
|
||||
|
||||
@@ -106,7 +106,7 @@ class Cifar10Model(Trainable):
|
||||
model = Model(inputs=x, outputs=y, name="model1")
|
||||
return model
|
||||
|
||||
def _setup(self, config):
|
||||
def setup(self, config):
|
||||
self.train_data, self.test_data = self._read_data()
|
||||
x_train = self.train_data[0]
|
||||
model = self._build_model(x_train.shape[1:])
|
||||
@@ -120,7 +120,7 @@ class Cifar10Model(Trainable):
|
||||
metrics=["accuracy"])
|
||||
self.model = model
|
||||
|
||||
def _train(self):
|
||||
def step(self):
|
||||
x_train, y_train = self.train_data
|
||||
x_train, y_train = x_train[:NUM_SAMPLES], y_train[:NUM_SAMPLES]
|
||||
x_test, y_test = self.test_data
|
||||
@@ -161,17 +161,17 @@ class Cifar10Model(Trainable):
|
||||
_, accuracy = self.model.evaluate(x_test, y_test, verbose=0)
|
||||
return {"mean_accuracy": accuracy}
|
||||
|
||||
def _save(self, checkpoint_dir):
|
||||
def save_checkpoint(self, checkpoint_dir):
|
||||
file_path = checkpoint_dir + "/model"
|
||||
self.model.save(file_path)
|
||||
return file_path
|
||||
|
||||
def _restore(self, path):
|
||||
def load_checkpoint(self, path):
|
||||
# See https://stackoverflow.com/a/42763323
|
||||
del self.model
|
||||
self.model = load_model(path)
|
||||
|
||||
def _stop(self):
|
||||
def cleanup(self):
|
||||
# If needed, save your model on exit.
|
||||
# saved_path = self.model.save(self.logdir)
|
||||
# print("save model at: ", saved_path)
|
||||
|
||||
@@ -40,7 +40,7 @@ class MyModel(Model):
|
||||
|
||||
|
||||
class MNISTTrainable(tune.Trainable):
|
||||
def _setup(self, config):
|
||||
def setup(self, config):
|
||||
# IMPORTANT: See the above note.
|
||||
import tensorflow as tf
|
||||
(x_train, y_train), (x_test, y_test) = load_data()
|
||||
@@ -90,7 +90,7 @@ class MNISTTrainable(tune.Trainable):
|
||||
self.tf_train_step = train_step
|
||||
self.tf_test_step = test_step
|
||||
|
||||
def _train(self):
|
||||
def step(self):
|
||||
self.train_loss.reset_states()
|
||||
self.train_accuracy.reset_states()
|
||||
self.test_loss.reset_states()
|
||||
|
||||
@@ -165,7 +165,7 @@ class FunctionRunner(Trainable):
|
||||
|
||||
_name = "func"
|
||||
|
||||
def _setup(self, config):
|
||||
def setup(self, config):
|
||||
# Semaphore for notifying the reporter to continue with the computation
|
||||
# and to generate the next result.
|
||||
self._continue_semaphore = threading.Semaphore(0)
|
||||
@@ -212,7 +212,7 @@ class FunctionRunner(Trainable):
|
||||
# now done or has raised an exception.
|
||||
pass
|
||||
|
||||
def _train(self):
|
||||
def step(self):
|
||||
"""Implements train() for a Function API.
|
||||
|
||||
If the RunnerThread finishes without reporting "done",
|
||||
@@ -313,7 +313,7 @@ class FunctionRunner(Trainable):
|
||||
out.write(data_dict)
|
||||
return out.getvalue()
|
||||
|
||||
def _restore(self, checkpoint):
|
||||
def load_checkpoint(self, checkpoint):
|
||||
# This should be removed once Trainables are refactored.
|
||||
if "tune_checkpoint_path" in checkpoint:
|
||||
del checkpoint["tune_checkpoint_path"]
|
||||
@@ -330,7 +330,7 @@ class FunctionRunner(Trainable):
|
||||
checkpoint_path = TrainableUtil.create_from_pickle(obj, checkpoint_dir)
|
||||
self.restore(checkpoint_path)
|
||||
|
||||
def _stop(self):
|
||||
def cleanup(self):
|
||||
# If everything stayed in synch properly, this should never happen.
|
||||
if not self._results_queue.empty():
|
||||
logger.warning(
|
||||
|
||||
@@ -96,6 +96,8 @@ def make_checkpoint_dir(step=None):
|
||||
|
||||
tune.report(hello="world", ray="tune")
|
||||
|
||||
.. warning:: Do not call this function within the Trainable Class API.
|
||||
|
||||
Args:
|
||||
step (int): Current training iteration - used for setting
|
||||
an index to uniquely identify the checkpoint.
|
||||
@@ -138,6 +140,8 @@ def save_checkpoint(checkpoint):
|
||||
|
||||
analysis = tune.run(run_me)
|
||||
|
||||
.. warning:: Do not call this function within the Trainable Class API.
|
||||
|
||||
Args:
|
||||
**kwargs: Any key value pair to be logged by Tune. Any of these
|
||||
metrics can be used for early stopping or optimization.
|
||||
|
||||
@@ -13,19 +13,19 @@ class FrequentPausesScheduler(FIFOScheduler):
|
||||
|
||||
def create_resettable_class():
|
||||
class MyResettableClass(Trainable):
|
||||
def _setup(self, config):
|
||||
def setup(self, config):
|
||||
self.config = config
|
||||
self.num_resets = 0
|
||||
self.iter = 0
|
||||
|
||||
def _train(self):
|
||||
def step(self):
|
||||
self.iter += 1
|
||||
return {"num_resets": self.num_resets, "done": self.iter > 1}
|
||||
|
||||
def _save(self, chkpt_dir):
|
||||
def save_checkpoint(self, chkpt_dir):
|
||||
return {"iter": self.iter}
|
||||
|
||||
def _restore(self, item):
|
||||
def load_checkpoint(self, item):
|
||||
self.iter = item["iter"]
|
||||
|
||||
def reset_config(self, new_config):
|
||||
|
||||
@@ -63,11 +63,11 @@ class TrainableFunctionApiTest(unittest.TestCase):
|
||||
function_output.append(result)
|
||||
|
||||
class _WrappedTrainable(Trainable):
|
||||
def _setup(self, config):
|
||||
def setup(self, config):
|
||||
del config
|
||||
self._result_iter = copy.deepcopy(class_results)
|
||||
|
||||
def _train(self):
|
||||
def step(self):
|
||||
if sleep_per_iter:
|
||||
time.sleep(sleep_per_iter)
|
||||
res = self._result_iter.pop(0) # This should not fail
|
||||
@@ -233,7 +233,7 @@ class TrainableFunctionApiTest(unittest.TestCase):
|
||||
def default_resource_request(cls, config):
|
||||
return Resources(cpu=config["cpu"], gpu=config["gpu"])
|
||||
|
||||
def _train(self):
|
||||
def step(self):
|
||||
return {"timesteps_this_iter": 1, "done": True}
|
||||
|
||||
register_trainable("B", B)
|
||||
@@ -628,7 +628,7 @@ class TrainableFunctionApiTest(unittest.TestCase):
|
||||
|
||||
def testTrialInfoAccess(self):
|
||||
class TestTrainable(Trainable):
|
||||
def _train(self):
|
||||
def step(self):
|
||||
result = {"name": self.trial_name, "trial_id": self.trial_id}
|
||||
print(result)
|
||||
return result
|
||||
@@ -659,11 +659,11 @@ class TrainableFunctionApiTest(unittest.TestCase):
|
||||
@patch("ray.tune.ray_trial_executor.TRIAL_CLEANUP_THRESHOLD", 3)
|
||||
def testLotsOfStops(self):
|
||||
class TestTrainable(Trainable):
|
||||
def _train(self):
|
||||
def step(self):
|
||||
result = {"name": self.trial_name, "trial_id": self.trial_id}
|
||||
return result
|
||||
|
||||
def _stop(self):
|
||||
def cleanup(self):
|
||||
time.sleep(2)
|
||||
open(os.path.join(self.logdir, "marker"), "a").close()
|
||||
return 1
|
||||
@@ -825,17 +825,17 @@ class TrainableFunctionApiTest(unittest.TestCase):
|
||||
|
||||
def testDurableTrainable(self):
|
||||
class TestTrain(DurableTrainable):
|
||||
def _setup(self, config):
|
||||
def setup(self, config):
|
||||
self.state = {"hi": 1, "iter": 0}
|
||||
|
||||
def _train(self):
|
||||
def step(self):
|
||||
self.state["iter"] += 1
|
||||
return {"timesteps_this_iter": 1, "done": True}
|
||||
|
||||
def _save(self, path):
|
||||
def save_checkpoint(self, path):
|
||||
return self.state
|
||||
|
||||
def _restore(self, state):
|
||||
def load_checkpoint(self, state):
|
||||
self.state = state
|
||||
|
||||
sync_client = mock_storage_client()
|
||||
@@ -853,16 +853,16 @@ class TrainableFunctionApiTest(unittest.TestCase):
|
||||
|
||||
def testCheckpointDict(self):
|
||||
class TestTrain(Trainable):
|
||||
def _setup(self, config):
|
||||
def setup(self, config):
|
||||
self.state = {"hi": 1}
|
||||
|
||||
def _train(self):
|
||||
def step(self):
|
||||
return {"timesteps_this_iter": 1, "done": True}
|
||||
|
||||
def _save(self, path):
|
||||
def save_checkpoint(self, path):
|
||||
return self.state
|
||||
|
||||
def _restore(self, state):
|
||||
def load_checkpoint(self, state):
|
||||
self.state = state
|
||||
|
||||
test_trainable = TestTrain()
|
||||
@@ -883,17 +883,17 @@ class TrainableFunctionApiTest(unittest.TestCase):
|
||||
|
||||
def testMultipleCheckpoints(self):
|
||||
class TestTrain(Trainable):
|
||||
def _setup(self, config):
|
||||
def setup(self, config):
|
||||
self.state = {"hi": 1, "iter": 0}
|
||||
|
||||
def _train(self):
|
||||
def step(self):
|
||||
self.state["iter"] += 1
|
||||
return {"timesteps_this_iter": 1, "done": True}
|
||||
|
||||
def _save(self, path):
|
||||
def save_checkpoint(self, path):
|
||||
return self.state
|
||||
|
||||
def _restore(self, state):
|
||||
def load_checkpoint(self, state):
|
||||
self.state = state
|
||||
|
||||
test_trainable = TestTrain()
|
||||
@@ -938,6 +938,41 @@ class TrainableFunctionApiTest(unittest.TestCase):
|
||||
self.assertEqual(trial.last_result[TRAINING_ITERATION], 100)
|
||||
self.assertEqual(trial.last_result["itr"], 99)
|
||||
|
||||
def testBackwardsCompat(self):
|
||||
class TestTrain(Trainable):
|
||||
def _setup(self, config):
|
||||
self.state = {"hi": 1, "iter": 0}
|
||||
|
||||
def _train(self):
|
||||
self.state["iter"] += 1
|
||||
return {"timesteps_this_iter": 1, "done": True}
|
||||
|
||||
def _save(self, path):
|
||||
return self.state
|
||||
|
||||
def _restore(self, state):
|
||||
self.state = state
|
||||
|
||||
test_trainable = TestTrain()
|
||||
checkpoint_1 = test_trainable.save()
|
||||
test_trainable.train()
|
||||
checkpoint_2 = test_trainable.save()
|
||||
self.assertNotEqual(checkpoint_1, checkpoint_2)
|
||||
test_trainable.restore(checkpoint_2)
|
||||
self.assertEqual(test_trainable.state["iter"], 1)
|
||||
test_trainable.restore(checkpoint_1)
|
||||
self.assertEqual(test_trainable.state["iter"], 0)
|
||||
|
||||
trials = run_experiments({
|
||||
"foo": {
|
||||
"run": TestTrain,
|
||||
"checkpoint_at_end": True
|
||||
}
|
||||
})
|
||||
for trial in trials:
|
||||
self.assertEqual(trial.status, Trial.TERMINATED)
|
||||
self.assertTrue(trial.has_checkpoint())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import pytest
|
||||
|
||||
@@ -623,18 +623,18 @@ def test_cluster_interrupt(start_connected_cluster, tmpdir):
|
||||
class _Mock(tune.Trainable):
|
||||
"""Finishes on the 4th iteration."""
|
||||
|
||||
def _setup(self, config):
|
||||
def setup(self, config):
|
||||
self.state = {"hi": 0}
|
||||
|
||||
def _train(self):
|
||||
def step(self):
|
||||
self.state["hi"] += 1
|
||||
time.sleep(0.5)
|
||||
return {"done": self.state["hi"] >= 4}
|
||||
|
||||
def _save(self, path):
|
||||
def save_checkpoint(self, path):
|
||||
return self.state
|
||||
|
||||
def _restore(self, state):
|
||||
def load_checkpoint(self, state):
|
||||
self.state = state
|
||||
|
||||
# Removes indent from class.
|
||||
|
||||
@@ -21,19 +21,19 @@ class ExperimentAnalysisInMemorySuite(unittest.TestCase):
|
||||
4: [7, 5, 5, 5, 5, 5, 5, 5, 3]
|
||||
}
|
||||
|
||||
def _setup(self, config):
|
||||
def setup(self, config):
|
||||
self.id = config["id"]
|
||||
self.idx = 0
|
||||
|
||||
def _train(self):
|
||||
def step(self):
|
||||
val = self.scores_dict[self.id][self.idx]
|
||||
self.idx += 1
|
||||
return {"score": val}
|
||||
|
||||
def _save(self, checkpoint_dir):
|
||||
def save_checkpoint(self, checkpoint_dir):
|
||||
pass
|
||||
|
||||
def _restore(self, checkpoint_path):
|
||||
def load_checkpoint(self, checkpoint_path):
|
||||
pass
|
||||
|
||||
self.MockTrainable = MockTrainable
|
||||
|
||||
@@ -145,7 +145,7 @@ class RayTrialExecutorTest(unittest.TestCase):
|
||||
"""Tests that reset works as expected."""
|
||||
|
||||
class B(Trainable):
|
||||
def _train(self):
|
||||
def step(self):
|
||||
return dict(timesteps_this_iter=1, done=True)
|
||||
|
||||
def reset_config(self, config):
|
||||
|
||||
@@ -74,7 +74,7 @@ class RunExperimentTest(unittest.TestCase):
|
||||
reporter(timesteps_total=i)
|
||||
|
||||
class B(Trainable):
|
||||
def _train(self):
|
||||
def step(self):
|
||||
return {"timesteps_this_iter": 1, "done": True}
|
||||
|
||||
register_trainable("f1", train)
|
||||
@@ -91,10 +91,10 @@ class RunExperimentTest(unittest.TestCase):
|
||||
|
||||
def testCheckpointAtEnd(self):
|
||||
class train(Trainable):
|
||||
def _train(self):
|
||||
def step(self):
|
||||
return {"timesteps_this_iter": 1, "done": True}
|
||||
|
||||
def _save(self, path):
|
||||
def save_checkpoint(self, path):
|
||||
checkpoint = os.path.join(path, "checkpoint")
|
||||
with open(checkpoint, "w") as f:
|
||||
f.write("OK")
|
||||
@@ -112,7 +112,7 @@ class RunExperimentTest(unittest.TestCase):
|
||||
|
||||
def testExportFormats(self):
|
||||
class train(Trainable):
|
||||
def _train(self):
|
||||
def step(self):
|
||||
return {"timesteps_this_iter": 1, "done": True}
|
||||
|
||||
def _export_model(self, export_formats, export_dir):
|
||||
@@ -134,7 +134,7 @@ class RunExperimentTest(unittest.TestCase):
|
||||
|
||||
def testInvalidExportFormats(self):
|
||||
class train(Trainable):
|
||||
def _train(self):
|
||||
def step(self):
|
||||
return {"timesteps_this_iter": 1, "done": True}
|
||||
|
||||
def _export_model(self, export_formats, export_dir):
|
||||
@@ -156,7 +156,7 @@ class RunExperimentTest(unittest.TestCase):
|
||||
ray.init(resources={"hi": 3})
|
||||
|
||||
class train(Trainable):
|
||||
def _train(self):
|
||||
def step(self):
|
||||
return {"timesteps_this_iter": 1, "done": True}
|
||||
|
||||
trials = run_experiments({
|
||||
|
||||
@@ -1141,10 +1141,10 @@ class E2EPopulationBasedTestingSuite(unittest.TestCase):
|
||||
pbt = self.basicSetup(perturbation_interval=2)
|
||||
|
||||
class train(tune.Trainable):
|
||||
def _train(self):
|
||||
def step(self):
|
||||
return {"mean_accuracy": self.training_iteration}
|
||||
|
||||
def _save(self, path):
|
||||
def save_checkpoint(self, path):
|
||||
checkpoint = os.path.join(path, "checkpoint")
|
||||
with open(checkpoint, "w") as f:
|
||||
f.write("OK")
|
||||
@@ -1173,16 +1173,16 @@ class E2EPopulationBasedTestingSuite(unittest.TestCase):
|
||||
pbt = self.basicSetup(perturbation_interval=2)
|
||||
|
||||
class train_dict(tune.Trainable):
|
||||
def _setup(self, config):
|
||||
def setup(self, config):
|
||||
self.state = {"hi": 1}
|
||||
|
||||
def _train(self):
|
||||
def step(self):
|
||||
return {"mean_accuracy": self.training_iteration}
|
||||
|
||||
def _save(self, path):
|
||||
def save_checkpoint(self, path):
|
||||
return self.state
|
||||
|
||||
def _restore(self, state):
|
||||
def load_checkpoint(self, state):
|
||||
self.state = state
|
||||
|
||||
trial_hyperparams = {
|
||||
|
||||
@@ -19,20 +19,20 @@ class SerialTuneRelativeLocalDirTest(unittest.TestCase):
|
||||
class MockTrainable(Trainable):
|
||||
_name = "MockTrainable"
|
||||
|
||||
def _setup(self, config):
|
||||
def setup(self, config):
|
||||
self.state = {"hi": 1}
|
||||
|
||||
def _train(self):
|
||||
def step(self):
|
||||
return {"timesteps_this_iter": 1, "done": True}
|
||||
|
||||
def _save(self, checkpoint_dir):
|
||||
def save_checkpoint(self, checkpoint_dir):
|
||||
checkpoint_path = os.path.join(
|
||||
checkpoint_dir, "checkpoint-{}".format(self._iteration))
|
||||
with open(checkpoint_path, "wb") as f:
|
||||
pickle.dump(self.state, f)
|
||||
return checkpoint_path
|
||||
|
||||
def _restore(self, checkpoint_path):
|
||||
def load_checkpoint(self, checkpoint_path):
|
||||
with open(checkpoint_path, "rb") as f:
|
||||
extra_data = pickle.load(f)
|
||||
self.state.update(extra_data)
|
||||
@@ -154,18 +154,18 @@ class SerialTuneRelativeLocalDirTest(unittest.TestCase):
|
||||
"""Tests that passing the checkpoint_dir right back works."""
|
||||
|
||||
class MockTrainable(Trainable):
|
||||
def _setup(self, config):
|
||||
def setup(self, config):
|
||||
pass
|
||||
|
||||
def _train(self):
|
||||
def step(self):
|
||||
return {"score": 1}
|
||||
|
||||
def _save(self, checkpoint_dir):
|
||||
def save_checkpoint(self, checkpoint_dir):
|
||||
with open(os.path.join(checkpoint_dir, "test.txt"), "wb") as f:
|
||||
pickle.dump("test", f)
|
||||
return checkpoint_dir
|
||||
|
||||
def _restore(self, checkpoint_dir):
|
||||
def load_checkpoint(self, checkpoint_dir):
|
||||
with open(os.path.join(checkpoint_dir, "test.txt"), "rb") as f:
|
||||
x = pickle.load(f)
|
||||
|
||||
|
||||
+139
-46
@@ -15,6 +15,7 @@ import time
|
||||
import uuid
|
||||
|
||||
import ray
|
||||
from ray.util.debug import log_once
|
||||
from ray.tune.logger import UnifiedLogger
|
||||
from ray.tune.result import (DEFAULT_RESULTS_DIR, TIME_THIS_ITER_S,
|
||||
TIMESTEPS_THIS_ITER, DONE, TIMESTEPS_TOTAL,
|
||||
@@ -174,11 +175,11 @@ class Trainable:
|
||||
Calling ``save()`` should save the training state of a trainable to disk,
|
||||
and ``restore(path)`` should restore a trainable to the given state.
|
||||
|
||||
Generally you only need to implement ``_setup``, ``_train``,
|
||||
``_save``, and ``_restore`` when subclassing Trainable.
|
||||
Generally you only need to implement ``setup``, ``step``,
|
||||
``save_checkpoint``, and ``load_checkpoint`` when subclassing Trainable.
|
||||
|
||||
Other implementation methods that may be helpful to override are
|
||||
``_log_result``, ``reset_config``, ``_stop``, and ``_export_model``.
|
||||
``log_result``, ``reset_config``, ``cleanup``, and ``_export_model``.
|
||||
|
||||
When using Tune, Tune will convert this class into a Ray actor, which
|
||||
runs on a separate process. Tune will also change the current working
|
||||
@@ -192,7 +193,7 @@ class Trainable:
|
||||
Sets up logging and points ``self.logdir`` to a directory in which
|
||||
training outputs should be placed.
|
||||
|
||||
Subclasses should prefer defining ``_setup()`` instead of overriding
|
||||
Subclasses should prefer defining ``setup()`` instead of overriding
|
||||
``__init__()`` directly.
|
||||
|
||||
Args:
|
||||
@@ -228,11 +229,11 @@ class Trainable:
|
||||
self._trial_info = trial_info
|
||||
|
||||
start_time = time.time()
|
||||
self._setup(copy.deepcopy(self.config))
|
||||
self.setup(copy.deepcopy(self.config))
|
||||
setup_time = time.time() - start_time
|
||||
if setup_time > SETUP_TIME_THRESHOLD:
|
||||
logger.info("_setup took {:.3f} seconds. If your trainable is "
|
||||
"slow to initialize, consider setting "
|
||||
logger.info("Trainable.setup took {:.3f} seconds. If your "
|
||||
"trainable is slow to initialize, consider setting "
|
||||
"reuse_actors=True to reduce actor creation "
|
||||
"overheads.".format(setup_time))
|
||||
self._local_ip = self.get_current_ip()
|
||||
@@ -277,8 +278,9 @@ class Trainable:
|
||||
def train(self):
|
||||
"""Runs one logical iteration of training.
|
||||
|
||||
Subclasses should override ``_train()`` instead to return results.
|
||||
This class automatically fills the following fields in the result:
|
||||
Calls ``step()`` internally. Subclasses should override ``step()``
|
||||
instead to return results.
|
||||
This method automatically fills the following fields in the result:
|
||||
|
||||
`done` (bool): training is terminated. Filled only if not provided.
|
||||
|
||||
@@ -295,7 +297,7 @@ class Trainable:
|
||||
|
||||
`training_iteration` (int): The index of this
|
||||
training iteration, e.g. call to train(). This is incremented
|
||||
after `_train()` is called.
|
||||
after `step()` is called.
|
||||
|
||||
`pid` (str): The pid of the training process.
|
||||
|
||||
@@ -314,8 +316,8 @@ class Trainable:
|
||||
A dict that describes training progress.
|
||||
"""
|
||||
start = time.time()
|
||||
result = self._train()
|
||||
assert isinstance(result, dict), "_train() needs to return a dict."
|
||||
result = self.step()
|
||||
assert isinstance(result, dict), "step() needs to return a dict."
|
||||
|
||||
# We do not modify internal state nor update this result if duplicate.
|
||||
if RESULT_DUPLICATE in result:
|
||||
@@ -376,7 +378,7 @@ class Trainable:
|
||||
if monitor_data:
|
||||
result.update(monitor_data)
|
||||
|
||||
self._log_result(result)
|
||||
self.log_result(result)
|
||||
|
||||
return result
|
||||
|
||||
@@ -404,7 +406,7 @@ class Trainable:
|
||||
"""
|
||||
checkpoint_dir = TrainableUtil.make_checkpoint_dir(
|
||||
checkpoint_dir or self.logdir, index=self.iteration)
|
||||
checkpoint = self._save(checkpoint_dir)
|
||||
checkpoint = self.save_checkpoint(checkpoint_dir)
|
||||
trainable_state = self.get_state()
|
||||
checkpoint_path = TrainableUtil.process_checkpoint(
|
||||
checkpoint,
|
||||
@@ -451,9 +453,9 @@ class Trainable:
|
||||
with open(checkpoint_path, "rb") as loaded_state:
|
||||
checkpoint_dict = pickle.load(loaded_state)
|
||||
checkpoint_dict.update(tune_checkpoint_path=checkpoint_path)
|
||||
self._restore(checkpoint_dict)
|
||||
self.load_checkpoint(checkpoint_dict)
|
||||
else:
|
||||
self._restore(checkpoint_path)
|
||||
self.load_checkpoint(checkpoint_path)
|
||||
self._time_since_restore = 0.0
|
||||
self._timesteps_since_restore = 0
|
||||
self._iterations_since_restore = 0
|
||||
@@ -523,7 +525,7 @@ class Trainable:
|
||||
be updated to reflect the latest parameter information in Ray logs.
|
||||
|
||||
Args:
|
||||
new_config (dir): Updated hyperparameter configuration
|
||||
new_config (dict): Updated hyperparameter configuration
|
||||
for the trainable.
|
||||
|
||||
Returns:
|
||||
@@ -532,10 +534,14 @@ class Trainable:
|
||||
return False
|
||||
|
||||
def stop(self):
|
||||
"""Releases all resources used by this trainable."""
|
||||
"""Releases all resources used by this trainable.
|
||||
|
||||
Calls ``Trainable.cleanup`` internally. Subclasses should override
|
||||
``Trainable.cleanup`` for custom cleanup procedures.
|
||||
"""
|
||||
self._result_logger.flush()
|
||||
self._result_logger.close()
|
||||
self._stop()
|
||||
self.cleanup()
|
||||
|
||||
@property
|
||||
def logdir(self):
|
||||
@@ -603,7 +609,7 @@ class Trainable:
|
||||
"""Returns configuration passed in by Tune."""
|
||||
return self.config
|
||||
|
||||
def _train(self):
|
||||
def step(self):
|
||||
"""Subclasses should override this to implement train().
|
||||
|
||||
The return value will be automatically passed to the loggers. Users
|
||||
@@ -612,27 +618,43 @@ class Trainable:
|
||||
trial. Note that manual checkpointing only works when subclassing
|
||||
Trainables.
|
||||
|
||||
.. versionadded:: 0.8.7
|
||||
|
||||
Returns:
|
||||
A dict that describes training progress.
|
||||
|
||||
"""
|
||||
result = self._train()
|
||||
|
||||
if self._is_overriden("_train") and log_once("_train"):
|
||||
logger.warning(
|
||||
"Trainable._train is deprecated and will be removed in "
|
||||
"a future version of Ray. Override Trainable.step instead.")
|
||||
return result
|
||||
|
||||
def _train(self):
|
||||
"""This method is deprecated. Override 'Trainable.step' instead.
|
||||
|
||||
.. versionchanged:: 0.8.7
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def _save(self, tmp_checkpoint_dir):
|
||||
def save_checkpoint(self, tmp_checkpoint_dir):
|
||||
"""Subclasses should override this to implement ``save()``.
|
||||
|
||||
Warning:
|
||||
Do not rely on absolute paths in the implementation of ``_save``
|
||||
and ``_restore``.
|
||||
Do not rely on absolute paths in the implementation of
|
||||
``Trainable.save_checkpoint`` and ``Trainable.load_checkpoint``.
|
||||
|
||||
Use ``validate_save_restore`` to catch ``_save``/``_restore`` errors
|
||||
before execution.
|
||||
Use ``validate_save_restore`` to catch ``Trainable.save_checkpoint``/
|
||||
``Trainable.load_checkpoint`` errors before execution.
|
||||
|
||||
>>> from ray.tune.utils import validate_save_restore
|
||||
>>> validate_save_restore(MyTrainableClass)
|
||||
>>> validate_save_restore(MyTrainableClass, use_object_store=True)
|
||||
|
||||
.. versionadded:: 0.8.7
|
||||
|
||||
Args:
|
||||
tmp_checkpoint_dir (str): The directory where the checkpoint
|
||||
file must be stored. In a Tune run, if the trial is paused,
|
||||
@@ -641,44 +663,60 @@ class Trainable:
|
||||
Returns:
|
||||
A dict or string. If string, the return value is expected to be
|
||||
prefixed by `tmp_checkpoint_dir`. If dict, the return value will
|
||||
be automatically serialized by Tune and passed to `_restore()`.
|
||||
be automatically serialized by Tune and
|
||||
passed to ``Trainable.load_checkpoint()``.
|
||||
|
||||
Examples:
|
||||
>>> print(trainable1._save("/tmp/checkpoint_1"))
|
||||
>>> print(trainable1.save_checkpoint("/tmp/checkpoint_1"))
|
||||
"/tmp/checkpoint_1/my_checkpoint_file"
|
||||
>>> print(trainable2._save("/tmp/checkpoint_2"))
|
||||
>>> print(trainable2.save_checkpoint("/tmp/checkpoint_2"))
|
||||
{"some": "data"}
|
||||
|
||||
>>> trainable._save("/tmp/bad_example")
|
||||
>>> trainable.save_checkpoint("/tmp/bad_example")
|
||||
"/tmp/NEW_CHECKPOINT_PATH/my_checkpoint_file" # This will error.
|
||||
"""
|
||||
checkpoint = self._save(tmp_checkpoint_dir)
|
||||
|
||||
if self._is_overriden("_save") and log_once("_save"):
|
||||
logger.warning(
|
||||
"Trainable._save is deprecated and will be removed in a "
|
||||
"future version of Ray. Override "
|
||||
"Trainable.save_checkpoint instead.")
|
||||
return checkpoint
|
||||
|
||||
def _save(self, tmp_checkpoint_dir):
|
||||
"""This method is deprecated. Override 'save_checkpoint' instead.
|
||||
|
||||
.. versionchanged:: 0.8.7
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def _restore(self, checkpoint):
|
||||
def load_checkpoint(self, checkpoint):
|
||||
"""Subclasses should override this to implement restore().
|
||||
|
||||
Warning:
|
||||
In this method, do not rely on absolute paths. The absolute
|
||||
path of the checkpoint_dir used in ``_save`` may be changed.
|
||||
path of the checkpoint_dir used in ``Trainable.save_checkpoint``
|
||||
may be changed.
|
||||
|
||||
If ``_save`` returned a prefixed string, the prefix of the checkpoint
|
||||
string returned by ``_save`` may be changed. This is because trial
|
||||
pausing depends on temporary directories.
|
||||
If ``Trainable.save_checkpoint`` returned a prefixed string, the
|
||||
prefix of the checkpoint string returned by
|
||||
``Trainable.save_checkpoint`` may be changed.
|
||||
This is because trial pausing depends on temporary directories.
|
||||
|
||||
The directory structure under the checkpoint_dir provided to ``_save``
|
||||
is preserved.
|
||||
The directory structure under the checkpoint_dir provided to
|
||||
``Trainable.save_checkpoint`` is preserved.
|
||||
|
||||
See the example below.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
class Example(Trainable):
|
||||
def _save(self, checkpoint_path):
|
||||
def save_checkpoint(self, checkpoint_path):
|
||||
print(checkpoint_path)
|
||||
return os.path.join(checkpoint_path, "my/check/point")
|
||||
|
||||
def _restore(self, checkpoint):
|
||||
def load_checkpoint(self, checkpoint):
|
||||
print(checkpoint)
|
||||
|
||||
>>> trainer = Example()
|
||||
@@ -687,39 +725,78 @@ class Trainable:
|
||||
>>> trainer.restore_from_object(obj) # Note the different prefix.
|
||||
<logdir>/tmpb87b5axfrestore_from_object/checkpoint_0/my/check/point
|
||||
|
||||
.. versionadded:: 0.8.7
|
||||
|
||||
Args:
|
||||
checkpoint (str|dict): If dict, the return value is as
|
||||
returned by `_save`. If a string, then it is a checkpoint path
|
||||
that may have a different prefix than that returned by `_save`.
|
||||
The directory structure underneath the `checkpoint_dir`
|
||||
`_save` is preserved.
|
||||
returned by `save_checkpoint`. If a string, then it is
|
||||
a checkpoint path that may have a different prefix than that
|
||||
returned by `save_checkpoint`. The directory structure
|
||||
underneath the `checkpoint_dir` provided to `save_checkpoint` is preserved.
|
||||
"""
|
||||
self._restore(checkpoint)
|
||||
if self._is_overriden("_restore") and log_once("_restore"):
|
||||
logger.warning(
|
||||
"Trainable._restore is deprecated and will be removed in a "
|
||||
"future version of Ray. Override Trainable.load_checkpoint "
|
||||
"instead.")
|
||||
|
||||
def _restore(self, checkpoint):
|
||||
"""This method is deprecated. Override 'load_checkpoint' instead.
|
||||
|
||||
.. versionchanged:: 0.8.7
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def _setup(self, config):
|
||||
def setup(self, config):
|
||||
"""Subclasses should override this for custom initialization.
|
||||
|
||||
.. versionadded:: 0.8.7
|
||||
|
||||
Args:
|
||||
config (dict): Hyperparameters and other configs given.
|
||||
Copy of `self.config`.
|
||||
"""
|
||||
self._setup(config)
|
||||
if self._is_overriden("_setup") and log_once("_setup"):
|
||||
logger.warning(
|
||||
"Trainable._setup is deprecated and will be removed in "
|
||||
"a future version of Ray. Override Trainable.setup instead.")
|
||||
|
||||
def _setup(self, config):
|
||||
"""This method is deprecated. Override 'setup' instead.
|
||||
|
||||
.. versionchanged:: 0.8.7
|
||||
"""
|
||||
pass
|
||||
|
||||
def _log_result(self, result):
|
||||
def log_result(self, result):
|
||||
"""Subclasses can optionally override this to customize logging.
|
||||
|
||||
The logging here is done on the worker process rather than
|
||||
the driver. You may want to turn off driver logging via the
|
||||
``loggers`` parameter in ``tune.run`` when overriding this function.
|
||||
|
||||
.. versionadded:: 0.8.7
|
||||
|
||||
Args:
|
||||
result (dict): Training result returned by _train().
|
||||
result (dict): Training result returned by step().
|
||||
"""
|
||||
self._log_result(result)
|
||||
if self._is_overriden("_log_result") and log_once("_log_result"):
|
||||
logger.warning(
|
||||
"Trainable._log_result is deprecated and will be removed in "
|
||||
"a future version of Ray. Override "
|
||||
"Trainable.log_result instead.")
|
||||
|
||||
def _log_result(self, result):
|
||||
"""This method is deprecated. Override 'log_result' instead.
|
||||
|
||||
.. versionchanged:: 0.8.7
|
||||
"""
|
||||
self._result_logger.on_result(result)
|
||||
|
||||
def _stop(self):
|
||||
def cleanup(self):
|
||||
"""Subclasses should override this for any cleanup on stop.
|
||||
|
||||
If any Ray actors are launched in the Trainable (i.e., with a RLlib
|
||||
@@ -727,6 +804,19 @@ class Trainable:
|
||||
|
||||
You can kill a Ray actor by calling `actor.__ray_terminate__.remote()`
|
||||
on the actor.
|
||||
|
||||
.. versionadded:: 0.8.7
|
||||
"""
|
||||
self._stop()
|
||||
if self._is_overriden("_stop") and log_once("trainable.cleanup"):
|
||||
logger.warning(
|
||||
"Trainable._stop is deprecated and will be removed in "
|
||||
"a future version of Ray. Override Trainable.cleanup instead.")
|
||||
|
||||
def _stop(self):
|
||||
"""This method is deprecated. Override 'cleanup' instead.
|
||||
|
||||
.. versionchanged:: 0.8.7
|
||||
"""
|
||||
pass
|
||||
|
||||
@@ -741,3 +831,6 @@ class Trainable:
|
||||
A dict that maps ExportFormats to successfully exported models.
|
||||
"""
|
||||
return {}
|
||||
|
||||
def _is_overriden(self, key):
|
||||
return getattr(self, key).__code__ != getattr(Trainable, key).__code__
|
||||
|
||||
@@ -293,7 +293,6 @@ class TrialRunner:
|
||||
with open(newest_ckpt_path, "r") as f:
|
||||
runner_state = json.load(f, cls=_TuneFunctionDecoder)
|
||||
self.checkpoint_file = newest_ckpt_path
|
||||
|
||||
logger.warning("".join([
|
||||
"Attempting to resume experiment from {}. ".format(
|
||||
self._local_checkpoint_dir), "This feature is experimental, "
|
||||
|
||||
@@ -171,7 +171,7 @@ class TFTrainable(Trainable):
|
||||
extra_cpu=config["num_replicas"],
|
||||
extra_gpu=int(config["use_gpu"]) * config["num_replicas"])
|
||||
|
||||
def _setup(self, config):
|
||||
def setup(self, config):
|
||||
self._trainer = TFTrainer(
|
||||
model_creator=config["model_creator"],
|
||||
data_creator=config["data_creator"],
|
||||
@@ -180,7 +180,7 @@ class TFTrainable(Trainable):
|
||||
use_gpu=config["use_gpu"],
|
||||
num_cpus_per_worker=config.get("num_cpus_per_worker", 1))
|
||||
|
||||
def _train(self):
|
||||
def step(self):
|
||||
|
||||
train_stats = self._trainer.train()
|
||||
validation_stats = self._trainer.validate()
|
||||
@@ -189,11 +189,11 @@ class TFTrainable(Trainable):
|
||||
|
||||
return train_stats
|
||||
|
||||
def _save(self, checkpoint_dir):
|
||||
def save_checkpoint(self, checkpoint_dir):
|
||||
return self._trainer.save(os.path.join(checkpoint_dir, "model"))
|
||||
|
||||
def _restore(self, checkpoint_path):
|
||||
def load_checkpoint(self, checkpoint_path):
|
||||
return self._trainer.restore(checkpoint_path)
|
||||
|
||||
def _stop(self):
|
||||
def cleanup(self):
|
||||
self._trainer.shutdown()
|
||||
|
||||
@@ -785,7 +785,7 @@ class BaseTorchTrainable(Trainable):
|
||||
# TorchTrainable is subclass of BaseTorchTrainable.
|
||||
|
||||
class CustomTrainable(TorchTrainable):
|
||||
def _train(self):
|
||||
def step(self):
|
||||
for i in range(5):
|
||||
train_stats = self.trainer.train()
|
||||
validation_stats = self.trainer.validate()
|
||||
@@ -799,11 +799,11 @@ class BaseTorchTrainable(Trainable):
|
||||
|
||||
"""
|
||||
|
||||
def _setup(self, config):
|
||||
def setup(self, config):
|
||||
"""Constructs a TorchTrainer object as `self.trainer`."""
|
||||
self._trainer = self._create_trainer(config)
|
||||
|
||||
def _train(self):
|
||||
def step(self):
|
||||
"""Calls `self.trainer.train()` and `self.trainer.validate()` once.
|
||||
|
||||
You may want to override this if using a custom LR scheduler.
|
||||
@@ -813,20 +813,20 @@ class BaseTorchTrainable(Trainable):
|
||||
stats = merge_dicts(train_stats, validation_stats)
|
||||
return stats
|
||||
|
||||
def _save(self, checkpoint_dir):
|
||||
def save_checkpoint(self, checkpoint_dir):
|
||||
"""Returns a path containing the trainer state."""
|
||||
checkpoint_path = os.path.join(checkpoint_dir, "trainer.checkpoint")
|
||||
self.trainer.save(checkpoint_path)
|
||||
return checkpoint_path
|
||||
|
||||
def _restore(self, checkpoint_path):
|
||||
def load_checkpoint(self, checkpoint_path):
|
||||
"""Restores the trainer state.
|
||||
|
||||
Override this if you have state external to the Trainer object.
|
||||
"""
|
||||
return self.trainer.load(checkpoint_path)
|
||||
|
||||
def _stop(self):
|
||||
def cleanup(self):
|
||||
"""Shuts down the trainer."""
|
||||
self.trainer.shutdown()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user