[tune] Use public methods for trainable (#9184)

This commit is contained in:
Richard Liaw
2020-07-01 11:00:00 -07:00
committed by GitHub
parent 1491508859
commit d35f0e40d0
40 changed files with 350 additions and 220 deletions
+1 -2
View File
@@ -11,8 +11,7 @@ class DurableTrainable(Trainable):
"""Abstract class for a remote-storage backed fault-tolerant Trainable.
Supports checkpointing to and restoring from remote storage. To use this
class, implement the same private methods as ray.tune.Trainable (`_save`,
`_train`, `_restore`, `reset_config`, `_setup`, `_stop`).
class, implement the same private methods as ray.tune.Trainable.
.. warning:: This class is currently **experimental** and may
be subject to change.
@@ -19,10 +19,10 @@ class MyTrainableClass(Trainable):
maximum reward value reached.
"""
def _setup(self, config):
def setup(self, config):
self.timestep = 0
def _train(self):
def step(self):
self.timestep += 1
v = np.tanh(float(self.timestep) / self.config.get("width", 1))
v *= self.config.get("height", 1)
@@ -31,13 +31,13 @@ class MyTrainableClass(Trainable):
# objectives such as loss or accuracy.
return {"episode_reward_mean": v}
def _save(self, checkpoint_dir):
def save_checkpoint(self, checkpoint_dir):
path = os.path.join(checkpoint_dir, "checkpoint")
with open(path, "w") as f:
f.write(json.dumps({"timestep": self.timestep}))
return path
def _restore(self, checkpoint_path):
def load_checkpoint(self, checkpoint_path):
with open(checkpoint_path) as f:
self.timestep = json.loads(f.read())["timestep"]
+4 -4
View File
@@ -18,10 +18,10 @@ class MyTrainableClass(Trainable):
maximum reward value reached.
"""
def _setup(self, config):
def setup(self, config):
self.timestep = 0
def _train(self):
def step(self):
self.timestep += 1
v = np.tanh(float(self.timestep) / self.config.get("width", 1))
v *= self.config.get("height", 1)
@@ -30,13 +30,13 @@ class MyTrainableClass(Trainable):
# objectives such as loss or accuracy.
return {"episode_reward_mean": v}
def _save(self, checkpoint_dir):
def save_checkpoint(self, checkpoint_dir):
path = os.path.join(checkpoint_dir, "checkpoint")
with open(path, "w") as f:
f.write(json.dumps({"timestep": self.timestep}))
return path
def _restore(self, checkpoint_path):
def load_checkpoint(self, checkpoint_path):
with open(checkpoint_path) as f:
self.timestep = json.loads(f.read())["timestep"]
@@ -51,7 +51,7 @@ class OptimusFn(object):
def get_optimus_trainable(parent_cls):
class OptimusTrainable(parent_cls):
def _setup(self, config):
def setup(self, config):
self.iter = 0
if config.get("seed"):
np.random.seed(config["seed"])
@@ -61,7 +61,7 @@ def get_optimus_trainable(parent_cls):
self.initial_samples_per_step = 500
self.mock_data = open("/dev/urandom", "rb").read(1024)
def _train(self):
def step(self):
self.iter += 1
new_loss = self.func.eval(self.iter)
time.sleep(0.5)
@@ -71,7 +71,7 @@ def get_optimus_trainable(parent_cls):
"samples": self.initial_samples_per_step
}
def _save(self, checkpoint_dir):
def save_checkpoint(self, checkpoint_dir):
time.sleep(0.5)
return {
"func": cloudpickle.dumps(self.func),
@@ -80,7 +80,7 @@ def get_optimus_trainable(parent_cls):
"iter": self.iter
}
def _restore(self, checkpoint):
def load_checkpoint(self, checkpoint):
self.func = cloudpickle.loads(checkpoint["func"])
self.data = checkpoint["data"]
self.iter = checkpoint["iter"]
@@ -19,10 +19,10 @@ class MyTrainableClass(Trainable):
maximum reward value reached.
"""
def _setup(self, config):
def setup(self, config):
self.timestep = 0
def _train(self):
def step(self):
self.timestep += 1
v = np.tanh(float(self.timestep) / self.config.get("width", 1))
v *= self.config.get("height", 1)
@@ -31,13 +31,13 @@ class MyTrainableClass(Trainable):
# objectives such as loss or accuracy.
return {"episode_reward_mean": v}
def _save(self, checkpoint_dir):
def save_checkpoint(self, checkpoint_dir):
path = os.path.join(checkpoint_dir, "checkpoint")
with open(path, "w") as f:
f.write(json.dumps({"timestep": self.timestep}))
return path
def _restore(self, checkpoint_path):
def load_checkpoint(self, checkpoint_path):
with open(checkpoint_path) as f:
self.timestep = json.loads(f.read())["timestep"]
+4 -4
View File
@@ -27,10 +27,10 @@ class MyTrainableClass(Trainable):
maximum reward value reached.
"""
def _setup(self, config):
def setup(self, config):
self.timestep = 0
def _train(self):
def step(self):
self.timestep += 1
v = np.tanh(float(self.timestep) / self.config.get("width", 1))
v *= self.config.get("height", 1)
@@ -39,13 +39,13 @@ class MyTrainableClass(Trainable):
# objectives such as loss or accuracy.
return {"episode_reward_mean": v}
def _save(self, checkpoint_dir):
def save_checkpoint(self, checkpoint_dir):
path = os.path.join(checkpoint_dir, "checkpoint")
with open(path, "w") as f:
f.write(json.dumps({"timestep": self.timestep}))
return path
def _restore(self, checkpoint_path):
def load_checkpoint(self, checkpoint_path):
with open(checkpoint_path) as f:
self.timestep = json.loads(f.read())["timestep"]
@@ -34,7 +34,7 @@ parser.add_argument(
# yapf: disable
# __trainable_example_begin__
class TrainMNIST(tune.Trainable):
def _setup(self, config):
def setup(self, config):
use_cuda = config.get("use_gpu") and torch.cuda.is_available()
self.device = torch.device("cuda" if use_cuda else "cpu")
self.train_loader, self.test_loader = get_data_loaders()
@@ -44,18 +44,18 @@ class TrainMNIST(tune.Trainable):
lr=config.get("lr", 0.01),
momentum=config.get("momentum", 0.9))
def _train(self):
def step(self):
train(
self.model, self.optimizer, self.train_loader, device=self.device)
acc = test(self.model, self.test_loader, self.device)
return {"mean_accuracy": acc}
def _save(self, checkpoint_dir):
def save_checkpoint(self, checkpoint_dir):
checkpoint_path = os.path.join(checkpoint_dir, "model.pth")
torch.save(self.model.state_dict(), checkpoint_path)
return checkpoint_path
def _restore(self, checkpoint_path):
def load_checkpoint(self, checkpoint_path):
self.model.load_state_dict(torch.load(checkpoint_path))
@@ -27,7 +27,7 @@ class PytorchTrainble(tune.Trainable):
changing the original training code.
"""
def _setup(self, config):
def setup(self, config):
self.train_loader, self.test_loader = get_data_loaders()
self.model = ConvNet()
self.optimizer = optim.SGD(
@@ -35,17 +35,17 @@ class PytorchTrainble(tune.Trainable):
lr=config.get("lr", 0.01),
momentum=config.get("momentum", 0.9))
def _train(self):
def step(self):
train(self.model, self.optimizer, self.train_loader)
acc = test(self.model, self.test_loader)
return {"mean_accuracy": acc}
def _save(self, checkpoint_dir):
def save_checkpoint(self, checkpoint_dir):
checkpoint_path = os.path.join(checkpoint_dir, "model.pth")
torch.save(self.model.state_dict(), checkpoint_path)
return checkpoint_path
def _restore(self, checkpoint_path):
def load_checkpoint(self, checkpoint_path):
self.model.load_state_dict(torch.load(checkpoint_path))
def _export_model(self, export_formats, export_dir):
@@ -231,7 +231,7 @@ def train(netD, netG, optimG, optimD, criterion, dataloader, iteration,
# __Trainable_begin__
class PytorchTrainable(tune.Trainable):
def _setup(self, config):
def setup(self, config):
use_cuda = config.get("use_gpu") and torch.cuda.is_available()
self.device = torch.device("cuda" if use_cuda else "cpu")
self.netD = Discriminator().to(self.device)
@@ -250,13 +250,13 @@ class PytorchTrainable(tune.Trainable):
with FileLock(os.path.expanduser("~/.data.lock")):
self.dataloader = get_data_loader()
def _train(self):
def step(self):
lossG, lossD, is_score = train(
self.netD, self.netG, self.optimizerG, self.optimizerD,
self.criterion, self.dataloader, self._iteration, self.device)
return {"lossg": lossG, "lossd": lossD, "is_score": is_score}
def _save(self, checkpoint_dir):
def save_checkpoint(self, checkpoint_dir):
path = os.path.join(checkpoint_dir, "checkpoint")
torch.save({
"netDmodel": self.netD.state_dict(),
@@ -267,7 +267,7 @@ class PytorchTrainable(tune.Trainable):
return checkpoint_dir
def _restore(self, checkpoint_dir):
def load_checkpoint(self, checkpoint_dir):
path = os.path.join(checkpoint_dir, "checkpoint")
checkpoint = torch.load(path)
self.netD.load_state_dict(checkpoint["netDmodel"])
+4 -4
View File
@@ -31,11 +31,11 @@ class PBTBenchmarkExample(Trainable):
faster convergence. Training will not converge without PBT.
"""
def _setup(self, config):
def setup(self, config):
self.lr = config["lr"]
self.accuracy = 0.0 # end = 1000
def _train(self):
def step(self):
midpoint = 100 # lr starts decreasing after acc > midpoint
q_tolerance = 3 # penalize exceeding lr by more than this multiple
noise_level = 2 # add gaussian noise to the acc increase
@@ -66,13 +66,13 @@ class PBTBenchmarkExample(Trainable):
"done": self.accuracy > midpoint * 2,
}
def _save(self, checkpoint_dir):
def save_checkpoint(self, checkpoint_dir):
return {
"accuracy": self.accuracy,
"lr": self.lr,
}
def _restore(self, checkpoint):
def load_checkpoint(self, checkpoint):
self.accuracy = checkpoint["accuracy"]
def reset_config(self, new_config):
@@ -214,7 +214,7 @@ class MemNNModel(Trainable):
model = Model([input_sequence, question], answer)
return model
def _setup(self, config):
def setup(self, config):
with FileLock(os.path.expanduser("~/.tune.lock")):
self.train_stories, self.test_stories = read_data()
model = self.build_model()
@@ -226,7 +226,7 @@ class MemNNModel(Trainable):
metrics=["accuracy"])
self.model = model
def _train(self):
def step(self):
# train
self.model.fit(
[self.inputs_train, self.queries_train],
@@ -242,12 +242,12 @@ class MemNNModel(Trainable):
verbose=0)
return {"mean_accuracy": accuracy}
def _save(self, checkpoint_dir):
def save_checkpoint(self, checkpoint_dir):
file_path = checkpoint_dir + "/model"
self.model.save(file_path)
return file_path
def _restore(self, path):
def load_checkpoint(self, path):
# See https://stackoverflow.com/a/42763323
del self.model
self.model = load_model(path)
@@ -106,7 +106,7 @@ class Cifar10Model(Trainable):
model = Model(inputs=x, outputs=y, name="model1")
return model
def _setup(self, config):
def setup(self, config):
self.train_data, self.test_data = self._read_data()
x_train = self.train_data[0]
model = self._build_model(x_train.shape[1:])
@@ -120,7 +120,7 @@ class Cifar10Model(Trainable):
metrics=["accuracy"])
self.model = model
def _train(self):
def step(self):
x_train, y_train = self.train_data
x_train, y_train = x_train[:NUM_SAMPLES], y_train[:NUM_SAMPLES]
x_test, y_test = self.test_data
@@ -161,17 +161,17 @@ class Cifar10Model(Trainable):
_, accuracy = self.model.evaluate(x_test, y_test, verbose=0)
return {"mean_accuracy": accuracy}
def _save(self, checkpoint_dir):
def save_checkpoint(self, checkpoint_dir):
file_path = checkpoint_dir + "/model"
self.model.save(file_path)
return file_path
def _restore(self, path):
def load_checkpoint(self, path):
# See https://stackoverflow.com/a/42763323
del self.model
self.model = load_model(path)
def _stop(self):
def cleanup(self):
# If need, save your model when exit.
# saved_path = self.model.save(self.logdir)
# print("save model at: ", saved_path)
+2 -2
View File
@@ -40,7 +40,7 @@ class MyModel(Model):
class MNISTTrainable(tune.Trainable):
def _setup(self, config):
def setup(self, config):
# IMPORTANT: See the above note.
import tensorflow as tf
(x_train, y_train), (x_test, y_test) = load_data()
@@ -90,7 +90,7 @@ class MNISTTrainable(tune.Trainable):
self.tf_train_step = train_step
self.tf_test_step = test_step
def _train(self):
def step(self):
self.train_loss.reset_states()
self.train_accuracy.reset_states()
self.test_loss.reset_states()
+4 -4
View File
@@ -165,7 +165,7 @@ class FunctionRunner(Trainable):
_name = "func"
def _setup(self, config):
def setup(self, config):
# Semaphore for notifying the reporter to continue with the computation
# and to generate the next result.
self._continue_semaphore = threading.Semaphore(0)
@@ -212,7 +212,7 @@ class FunctionRunner(Trainable):
# now done or has raised an exception.
pass
def _train(self):
def step(self):
"""Implements train() for a Function API.
If the RunnerThread finishes without reporting "done",
@@ -313,7 +313,7 @@ class FunctionRunner(Trainable):
out.write(data_dict)
return out.getvalue()
def _restore(self, checkpoint):
def load_checkpoint(self, checkpoint):
# This should be removed once Trainables are refactored.
if "tune_checkpoint_path" in checkpoint:
del checkpoint["tune_checkpoint_path"]
@@ -330,7 +330,7 @@ class FunctionRunner(Trainable):
checkpoint_path = TrainableUtil.create_from_pickle(obj, checkpoint_dir)
self.restore(checkpoint_path)
def _stop(self):
def cleanup(self):
# If everything stayed in synch properly, this should never happen.
if not self._results_queue.empty():
logger.warning(
+4
View File
@@ -96,6 +96,8 @@ def make_checkpoint_dir(step=None):
tune.report(hello="world", ray="tune")
.. warning:: Do not call this function within the Trainable Class API.
Args:
step (int): Current training iteration - used for setting
an index to uniquely identify the checkpoint.
@@ -138,6 +140,8 @@ def save_checkpoint(checkpoint):
analysis = tune.run(run_me)
.. warning:: Do not call this function within the Trainable Class API.
Args:
**kwargs: Any key value pair to be logged by Tune. Any of these
metrics can be used for early stopping or optimization.
+4 -4
View File
@@ -13,19 +13,19 @@ class FrequentPausesScheduler(FIFOScheduler):
def create_resettable_class():
class MyResettableClass(Trainable):
def _setup(self, config):
def setup(self, config):
self.config = config
self.num_resets = 0
self.iter = 0
def _train(self):
def step(self):
self.iter += 1
return {"num_resets": self.num_resets, "done": self.iter > 1}
def _save(self, chkpt_dir):
def save_checkpoint(self, chkpt_dir):
return {"iter": self.iter}
def _restore(self, item):
def load_checkpoint(self, item):
self.iter = item["iter"]
def reset_config(self, new_config):
+53 -18
View File
@@ -63,11 +63,11 @@ class TrainableFunctionApiTest(unittest.TestCase):
function_output.append(result)
class _WrappedTrainable(Trainable):
def _setup(self, config):
def setup(self, config):
del config
self._result_iter = copy.deepcopy(class_results)
def _train(self):
def step(self):
if sleep_per_iter:
time.sleep(sleep_per_iter)
res = self._result_iter.pop(0) # This should not fail
@@ -233,7 +233,7 @@ class TrainableFunctionApiTest(unittest.TestCase):
def default_resource_request(cls, config):
return Resources(cpu=config["cpu"], gpu=config["gpu"])
def _train(self):
def step(self):
return {"timesteps_this_iter": 1, "done": True}
register_trainable("B", B)
@@ -628,7 +628,7 @@ class TrainableFunctionApiTest(unittest.TestCase):
def testTrialInfoAccess(self):
class TestTrainable(Trainable):
def _train(self):
def step(self):
result = {"name": self.trial_name, "trial_id": self.trial_id}
print(result)
return result
@@ -659,11 +659,11 @@ class TrainableFunctionApiTest(unittest.TestCase):
@patch("ray.tune.ray_trial_executor.TRIAL_CLEANUP_THRESHOLD", 3)
def testLotsOfStops(self):
class TestTrainable(Trainable):
def _train(self):
def step(self):
result = {"name": self.trial_name, "trial_id": self.trial_id}
return result
def _stop(self):
def cleanup(self):
time.sleep(2)
open(os.path.join(self.logdir, "marker"), "a").close()
return 1
@@ -825,17 +825,17 @@ class TrainableFunctionApiTest(unittest.TestCase):
def testDurableTrainable(self):
class TestTrain(DurableTrainable):
def _setup(self, config):
def setup(self, config):
self.state = {"hi": 1, "iter": 0}
def _train(self):
def step(self):
self.state["iter"] += 1
return {"timesteps_this_iter": 1, "done": True}
def _save(self, path):
def save_checkpoint(self, path):
return self.state
def _restore(self, state):
def load_checkpoint(self, state):
self.state = state
sync_client = mock_storage_client()
@@ -853,16 +853,16 @@ class TrainableFunctionApiTest(unittest.TestCase):
def testCheckpointDict(self):
class TestTrain(Trainable):
def _setup(self, config):
def setup(self, config):
self.state = {"hi": 1}
def _train(self):
def step(self):
return {"timesteps_this_iter": 1, "done": True}
def _save(self, path):
def save_checkpoint(self, path):
return self.state
def _restore(self, state):
def load_checkpoint(self, state):
self.state = state
test_trainable = TestTrain()
@@ -883,17 +883,17 @@ class TrainableFunctionApiTest(unittest.TestCase):
def testMultipleCheckpoints(self):
class TestTrain(Trainable):
def _setup(self, config):
def setup(self, config):
self.state = {"hi": 1, "iter": 0}
def _train(self):
def step(self):
self.state["iter"] += 1
return {"timesteps_this_iter": 1, "done": True}
def _save(self, path):
def save_checkpoint(self, path):
return self.state
def _restore(self, state):
def load_checkpoint(self, state):
self.state = state
test_trainable = TestTrain()
@@ -938,6 +938,41 @@ class TrainableFunctionApiTest(unittest.TestCase):
self.assertEqual(trial.last_result[TRAINING_ITERATION], 100)
self.assertEqual(trial.last_result["itr"], 99)
def testBackwardsCompat(self):
class TestTrain(Trainable):
def _setup(self, config):
self.state = {"hi": 1, "iter": 0}
def _train(self):
self.state["iter"] += 1
return {"timesteps_this_iter": 1, "done": True}
def _save(self, path):
return self.state
def _restore(self, state):
self.state = state
test_trainable = TestTrain()
checkpoint_1 = test_trainable.save()
test_trainable.train()
checkpoint_2 = test_trainable.save()
self.assertNotEqual(checkpoint_1, checkpoint_2)
test_trainable.restore(checkpoint_2)
self.assertEqual(test_trainable.state["iter"], 1)
test_trainable.restore(checkpoint_1)
self.assertEqual(test_trainable.state["iter"], 0)
trials = run_experiments({
"foo": {
"run": TestTrain,
"checkpoint_at_end": True
}
})
for trial in trials:
self.assertEqual(trial.status, Trial.TERMINATED)
self.assertTrue(trial.has_checkpoint())
if __name__ == "__main__":
import pytest
+4 -4
View File
@@ -623,18 +623,18 @@ def test_cluster_interrupt(start_connected_cluster, tmpdir):
class _Mock(tune.Trainable):
"""Finishes on the 4th iteration."""
def _setup(self, config):
def setup(self, config):
self.state = {"hi": 0}
def _train(self):
def step(self):
self.state["hi"] += 1
time.sleep(0.5)
return {"done": self.state["hi"] >= 4}
def _save(self, path):
def save_checkpoint(self, path):
return self.state
def _restore(self, state):
def load_checkpoint(self, state):
self.state = state
# Removes indent from class.
@@ -21,19 +21,19 @@ class ExperimentAnalysisInMemorySuite(unittest.TestCase):
4: [7, 5, 5, 5, 5, 5, 5, 5, 3]
}
def _setup(self, config):
def setup(self, config):
self.id = config["id"]
self.idx = 0
def _train(self):
def step(self):
val = self.scores_dict[self.id][self.idx]
self.idx += 1
return {"score": val}
def _save(self, checkpoint_dir):
def save_checkpoint(self, checkpoint_dir):
pass
def _restore(self, checkpoint_path):
def load_checkpoint(self, checkpoint_path):
pass
self.MockTrainable = MockTrainable
@@ -145,7 +145,7 @@ class RayTrialExecutorTest(unittest.TestCase):
"""Tests that reset works as expected."""
class B(Trainable):
def _train(self):
def step(self):
return dict(timesteps_this_iter=1, done=True)
def reset_config(self, config):
+6 -6
View File
@@ -74,7 +74,7 @@ class RunExperimentTest(unittest.TestCase):
reporter(timesteps_total=i)
class B(Trainable):
def _train(self):
def step(self):
return {"timesteps_this_iter": 1, "done": True}
register_trainable("f1", train)
@@ -91,10 +91,10 @@ class RunExperimentTest(unittest.TestCase):
def testCheckpointAtEnd(self):
class train(Trainable):
def _train(self):
def step(self):
return {"timesteps_this_iter": 1, "done": True}
def _save(self, path):
def save_checkpoint(self, path):
checkpoint = os.path.join(path, "checkpoint")
with open(checkpoint, "w") as f:
f.write("OK")
@@ -112,7 +112,7 @@ class RunExperimentTest(unittest.TestCase):
def testExportFormats(self):
class train(Trainable):
def _train(self):
def step(self):
return {"timesteps_this_iter": 1, "done": True}
def _export_model(self, export_formats, export_dir):
@@ -134,7 +134,7 @@ class RunExperimentTest(unittest.TestCase):
def testInvalidExportFormats(self):
class train(Trainable):
def _train(self):
def step(self):
return {"timesteps_this_iter": 1, "done": True}
def _export_model(self, export_formats, export_dir):
@@ -156,7 +156,7 @@ class RunExperimentTest(unittest.TestCase):
ray.init(resources={"hi": 3})
class train(Trainable):
def _train(self):
def step(self):
return {"timesteps_this_iter": 1, "done": True}
trials = run_experiments({
@@ -1141,10 +1141,10 @@ class E2EPopulationBasedTestingSuite(unittest.TestCase):
pbt = self.basicSetup(perturbation_interval=2)
class train(tune.Trainable):
def _train(self):
def step(self):
return {"mean_accuracy": self.training_iteration}
def _save(self, path):
def save_checkpoint(self, path):
checkpoint = os.path.join(path, "checkpoint")
with open(checkpoint, "w") as f:
f.write("OK")
@@ -1173,16 +1173,16 @@ class E2EPopulationBasedTestingSuite(unittest.TestCase):
pbt = self.basicSetup(perturbation_interval=2)
class train_dict(tune.Trainable):
def _setup(self, config):
def setup(self, config):
self.state = {"hi": 1}
def _train(self):
def step(self):
return {"mean_accuracy": self.training_iteration}
def _save(self, path):
def save_checkpoint(self, path):
return self.state
def _restore(self, state):
def load_checkpoint(self, state):
self.state = state
trial_hyperparams = {
@@ -19,20 +19,20 @@ class SerialTuneRelativeLocalDirTest(unittest.TestCase):
class MockTrainable(Trainable):
_name = "MockTrainable"
def _setup(self, config):
def setup(self, config):
self.state = {"hi": 1}
def _train(self):
def step(self):
return {"timesteps_this_iter": 1, "done": True}
def _save(self, checkpoint_dir):
def save_checkpoint(self, checkpoint_dir):
checkpoint_path = os.path.join(
checkpoint_dir, "checkpoint-{}".format(self._iteration))
with open(checkpoint_path, "wb") as f:
pickle.dump(self.state, f)
return checkpoint_path
def _restore(self, checkpoint_path):
def load_checkpoint(self, checkpoint_path):
with open(checkpoint_path, "rb") as f:
extra_data = pickle.load(f)
self.state.update(extra_data)
@@ -154,18 +154,18 @@ class SerialTuneRelativeLocalDirTest(unittest.TestCase):
"""Tests that passing the checkpoint_dir right back works."""
class MockTrainable(Trainable):
def _setup(self, config):
def setup(self, config):
pass
def _train(self):
def step(self):
return {"score": 1}
def _save(self, checkpoint_dir):
def save_checkpoint(self, checkpoint_dir):
with open(os.path.join(checkpoint_dir, "test.txt"), "wb") as f:
pickle.dump("test", f)
return checkpoint_dir
def _restore(self, checkpoint_dir):
def load_checkpoint(self, checkpoint_dir):
with open(os.path.join(checkpoint_dir, "test.txt"), "rb") as f:
x = pickle.load(f)
+139 -46
View File
@@ -15,6 +15,7 @@ import time
import uuid
import ray
from ray.util.debug import log_once
from ray.tune.logger import UnifiedLogger
from ray.tune.result import (DEFAULT_RESULTS_DIR, TIME_THIS_ITER_S,
TIMESTEPS_THIS_ITER, DONE, TIMESTEPS_TOTAL,
@@ -174,11 +175,11 @@ class Trainable:
Calling ``save()`` should save the training state of a trainable to disk,
and ``restore(path)`` should restore a trainable to the given state.
Generally you only need to implement ``_setup``, ``_train``,
``_save``, and ``_restore`` when subclassing Trainable.
Generally you only need to implement ``build``, ``step``,
``save_checkpoint``, and ``load_checkpoint`` when subclassing Trainable.
Other implementation methods that may be helpful to override are
``_log_result``, ``reset_config``, ``_stop``, and ``_export_model``.
``log_result``, ``reset_config``, ``cleanup``, and ``_export_model``.
When using Tune, Tune will convert this class into a Ray actor, which
runs on a separate process. Tune will also change the current working
@@ -192,7 +193,7 @@ class Trainable:
Sets up logging and points ``self.logdir`` to a directory in which
training outputs should be placed.
Subclasses should prefer defining ``_setup()`` instead of overriding
Subclasses should prefer defining ``build()`` instead of overriding
``__init__()`` directly.
Args:
@@ -228,11 +229,11 @@ class Trainable:
self._trial_info = trial_info
start_time = time.time()
self._setup(copy.deepcopy(self.config))
self.setup(copy.deepcopy(self.config))
setup_time = time.time() - start_time
if setup_time > SETUP_TIME_THRESHOLD:
logger.info("_setup took {:.3f} seconds. If your trainable is "
"slow to initialize, consider setting "
logger.info("Trainable.setup took {:.3f} seconds. If your "
"trainable is slow to initialize, consider setting "
"reuse_actors=True to reduce actor creation "
"overheads.".format(setup_time))
self._local_ip = self.get_current_ip()
@@ -277,8 +278,9 @@ class Trainable:
def train(self):
"""Runs one logical iteration of training.
Subclasses should override ``_train()`` instead to return results.
This class automatically fills the following fields in the result:
Calls ``step()`` internally. Subclasses should override ``step()``
instead to return results.
This method automatically fills the following fields in the result:
`done` (bool): training is terminated. Filled only if not provided.
@@ -295,7 +297,7 @@ class Trainable:
`training_iteration` (int): The index of this
training iteration, e.g. call to train(). This is incremented
after `_train()` is called.
after `step()` is called.
`pid` (str): The pid of the training process.
@@ -314,8 +316,8 @@ class Trainable:
A dict that describes training progress.
"""
start = time.time()
result = self._train()
assert isinstance(result, dict), "_train() needs to return a dict."
result = self.step()
assert isinstance(result, dict), "step() needs to return a dict."
# We do not modify internal state nor update this result if duplicate.
if RESULT_DUPLICATE in result:
@@ -376,7 +378,7 @@ class Trainable:
if monitor_data:
result.update(monitor_data)
self._log_result(result)
self.log_result(result)
return result
@@ -404,7 +406,7 @@ class Trainable:
"""
checkpoint_dir = TrainableUtil.make_checkpoint_dir(
checkpoint_dir or self.logdir, index=self.iteration)
checkpoint = self._save(checkpoint_dir)
checkpoint = self.save_checkpoint(checkpoint_dir)
trainable_state = self.get_state()
checkpoint_path = TrainableUtil.process_checkpoint(
checkpoint,
@@ -451,9 +453,9 @@ class Trainable:
with open(checkpoint_path, "rb") as loaded_state:
checkpoint_dict = pickle.load(loaded_state)
checkpoint_dict.update(tune_checkpoint_path=checkpoint_path)
self._restore(checkpoint_dict)
self.load_checkpoint(checkpoint_dict)
else:
self._restore(checkpoint_path)
self.load_checkpoint(checkpoint_path)
self._time_since_restore = 0.0
self._timesteps_since_restore = 0
self._iterations_since_restore = 0
@@ -523,7 +525,7 @@ class Trainable:
be updated to reflect the latest parameter information in Ray logs.
Args:
new_config (dir): Updated hyperparameter configuration
new_config (dict): Updated hyperparameter configuration
for the trainable.
Returns:
@@ -532,10 +534,14 @@ class Trainable:
return False
def stop(self):
"""Releases all resources used by this trainable."""
"""Releases all resources used by this trainable.
Calls ``Trainable.cleanup`` internally. Subclasses should override
``Trainable.cleanup`` for custom cleanup procedures.
"""
self._result_logger.flush()
self._result_logger.close()
self._stop()
self.cleanup()
@property
def logdir(self):
@@ -603,7 +609,7 @@ class Trainable:
"""Returns configuration passed in by Tune."""
return self.config
def _train(self):
def step(self):
"""Subclasses should override this to implement train().
The return value will be automatically passed to the loggers. Users
@@ -612,27 +618,43 @@ class Trainable:
trial. Note that manual checkpointing only works when subclassing
Trainables.
.. versionadded:: 0.8.7
Returns:
A dict that describes training progress.
"""
result = self._train()
if self._is_overriden("_train") and log_once("_train"):
logger.warning(
"Trainable._train is deprecated and will be removed in "
"a future version of Ray. Override Trainable.step instead.")
return result
def _train(self):
"""This method is deprecated. Override 'Trainable.step' instead.
.. versionchanged:: 0.8.7
"""
raise NotImplementedError
def _save(self, tmp_checkpoint_dir):
def save_checkpoint(self, tmp_checkpoint_dir):
"""Subclasses should override this to implement ``save()``.
Warning:
Do not rely on absolute paths in the implementation of ``_save``
and ``_restore``.
Do not rely on absolute paths in the implementation of
``Trainable.save_checkpoint`` and ``Trainable.load_checkpoint``.
Use ``validate_save_restore`` to catch ``_save``/``_restore`` errors
before execution.
Use ``validate_save_restore`` to catch ``Trainable.save_checkpoint``/
``Trainable.load_checkpoint`` errors before execution.
>>> from ray.tune.utils import validate_save_restore
>>> validate_save_restore(MyTrainableClass)
>>> validate_save_restore(MyTrainableClass, use_object_store=True)
.. versionadded:: 0.8.7
Args:
tmp_checkpoint_dir (str): The directory where the checkpoint
file must be stored. In a Tune run, if the trial is paused,
@@ -641,44 +663,60 @@ class Trainable:
Returns:
A dict or string. If string, the return value is expected to be
prefixed by `tmp_checkpoint_dir`. If dict, the return value will
be automatically serialized by Tune and passed to `_restore()`.
be automatically serialized by Tune and
passed to ``Trainable.load_checkpoint()``.
Examples:
>>> print(trainable1._save("/tmp/checkpoint_1"))
>>> print(trainable1.save_checkpoint("/tmp/checkpoint_1"))
"/tmp/checkpoint_1/my_checkpoint_file"
>>> print(trainable2._save("/tmp/checkpoint_2"))
>>> print(trainable2.save_checkpoint("/tmp/checkpoint_2"))
{"some": "data"}
>>> trainable._save("/tmp/bad_example")
>>> trainable.save_checkpoint("/tmp/bad_example")
"/tmp/NEW_CHECKPOINT_PATH/my_checkpoint_file" # This will error.
"""
checkpoint = self._save(tmp_checkpoint_dir)
if self._is_overriden("_save") and log_once("_save"):
logger.warning(
"Trainable._save is deprecated and will be removed in a "
"future version of Ray. Override "
"Trainable.save_checkpoint instead.")
return checkpoint
def _save(self, tmp_checkpoint_dir):
"""This method is deprecated. Override 'save_checkpoint' instead.
.. versionchanged:: 0.8.7
"""
raise NotImplementedError
def _restore(self, checkpoint):
def load_checkpoint(self, checkpoint):
"""Subclasses should override this to implement restore().
Warning:
In this method, do not rely on absolute paths. The absolute
path of the checkpoint_dir used in ``_save`` may be changed.
path of the checkpoint_dir used in ``Trainable.save_checkpoint``
may be changed.
If ``_save`` returned a prefixed string, the prefix of the checkpoint
string returned by ``_save`` may be changed. This is because trial
pausing depends on temporary directories.
If ``Trainable.save_checkpoint`` returned a prefixed string, the
prefix of the checkpoint string returned by
``Trainable.save_checkpoint`` may be changed.
This is because trial pausing depends on temporary directories.
The directory structure under the checkpoint_dir provided to ``_save``
is preserved.
The directory structure under the checkpoint_dir provided to
``Trainable.save_checkpoint`` is preserved.
See the example below.
.. code-block:: python
class Example(Trainable):
def _save(self, checkpoint_path):
def save_checkpoint(self, checkpoint_path):
print(checkpoint_path)
return os.path.join(checkpoint_path, "my/check/point")
def _restore(self, checkpoint):
def load_checkpoint(self, checkpoint):
print(checkpoint)
>>> trainer = Example()
@@ -687,39 +725,78 @@ class Trainable:
>>> trainer.restore_from_object(obj) # Note the different prefix.
<logdir>/tmpb87b5axfrestore_from_object/checkpoint_0/my/check/point
.. versionadded:: 0.8.7
Args:
checkpoint (str|dict): If dict, the return value is as
returned by `_save`. If a string, then it is a checkpoint path
that may have a different prefix than that returned by `_save`.
The directory structure underneath the `checkpoint_dir`
`_save` is preserved.
returned by `save_checkpoint`. If a string, then it is
a checkpoint path that may have a different prefix than that
returned by `save_checkpoint`. The directory structure
underneath the `checkpoint_dir` `save_checkpoint` is preserved.
"""
self._restore(checkpoint)
if self._is_overriden("_restore") and log_once("_restore"):
logger.warning(
"Trainable._restore is deprecated and will be removed in a "
"future version of Ray. Override Trainable.load_checkpoint "
"instead.")
def _restore(self, checkpoint):
"""This method is deprecated. Override 'load_checkpoint' instead.
.. versionchanged:: 0.8.7
"""
raise NotImplementedError
def _setup(self, config):
def setup(self, config):
"""Subclasses should override this for custom initialization.
.. versionadded:: 0.8.7
Args:
config (dict): Hyperparameters and other configs given.
Copy of `self.config`.
"""
self._setup(config)
if self._is_overriden("_setup") and log_once("_setup"):
logger.warning(
"Trainable._setup is deprecated and will be removed in "
"a future version of Ray. Override Trainable.setup instead.")
def _setup(self, config):
"""This method is deprecated. Override 'setup' instead.
.. versionchanged:: 0.8.7
"""
pass
def _log_result(self, result):
def log_result(self, result):
"""Subclasses can optionally override this to customize logging.
The logging here is done on the worker process rather than
the driver. You may want to turn off driver logging via the
``loggers`` parameter in ``tune.run`` when overriding this function.
.. versionadded:: 0.8.7
Args:
result (dict): Training result returned by _train().
result (dict): Training result returned by step().
"""
self._log_result(result)
if self._is_overriden("_log_result") and log_once("_log_result"):
logger.warning(
"Trainable._log_result is deprecated and will be removed in "
"a future version of Ray. Override "
"Trainable.log_result instead.")
def _log_result(self, result):
"""This method is deprecated. Override 'log_result' instead.
.. versionchanged:: 0.8.7
"""
self._result_logger.on_result(result)
def _stop(self):
def cleanup(self):
"""Subclasses should override this for any cleanup on stop.
If any Ray actors are launched in the Trainable (i.e., with a RLlib
@@ -727,6 +804,19 @@ class Trainable:
You can kill a Ray actor by calling `actor.__ray_terminate__.remote()`
on the actor.
.. versionadded:: 0.8.7
"""
self._stop()
if self._is_overriden("_stop") and log_once("trainable.cleanup"):
logger.warning(
"Trainable._stop is deprecated and will be removed in "
"a future version of Ray. Override Trainable.cleanup instead.")
def _stop(self):
"""This method is deprecated. Override 'cleanup' instead.
.. versionchanged:: 0.8.7
"""
pass
@@ -741,3 +831,6 @@ class Trainable:
A dict that maps ExportFormats to successfully exported models.
"""
return {}
def _is_overriden(self, key):
return getattr(self, key).__code__ != getattr(Trainable, key).__code__
-1
View File
@@ -293,7 +293,6 @@ class TrialRunner:
with open(newest_ckpt_path, "r") as f:
runner_state = json.load(f, cls=_TuneFunctionDecoder)
self.checkpoint_file = newest_ckpt_path
logger.warning("".join([
"Attempting to resume experiment from {}. ".format(
self._local_checkpoint_dir), "This feature is experimental, "
+5 -5
View File
@@ -171,7 +171,7 @@ class TFTrainable(Trainable):
extra_cpu=config["num_replicas"],
extra_gpu=int(config["use_gpu"]) * config["num_replicas"])
def _setup(self, config):
def setup(self, config):
self._trainer = TFTrainer(
model_creator=config["model_creator"],
data_creator=config["data_creator"],
@@ -180,7 +180,7 @@ class TFTrainable(Trainable):
use_gpu=config["use_gpu"],
num_cpus_per_worker=config.get("num_cpus_per_worker", 1))
def _train(self):
def step(self):
train_stats = self._trainer.train()
validation_stats = self._trainer.validate()
@@ -189,11 +189,11 @@ class TFTrainable(Trainable):
return train_stats
def _save(self, checkpoint_dir):
def save_checkpoint(self, checkpoint_dir):
return self._trainer.save(os.path.join(checkpoint_dir, "model"))
def _restore(self, checkpoint_path):
def load_checkpoint(self, checkpoint_path):
return self._trainer.restore(checkpoint_path)
def _stop(self):
def cleanup(self):
self._trainer.shutdown()
+6 -6
View File
@@ -785,7 +785,7 @@ class BaseTorchTrainable(Trainable):
# TorchTrainable is subclass of BaseTorchTrainable.
class CustomTrainable(TorchTrainable):
def _train(self):
def step(self):
for i in range(5):
train_stats = self.trainer.train()
validation_stats = self.trainer.validate()
@@ -799,11 +799,11 @@ class BaseTorchTrainable(Trainable):
"""
def _setup(self, config):
def setup(self, config):
"""Constructs a TorchTrainer object as `self.trainer`."""
self._trainer = self._create_trainer(config)
def _train(self):
def step(self):
"""Calls `self.trainer.train()` and `self.trainer.validate()` once.
You may want to override this if using a custom LR scheduler.
@@ -813,20 +813,20 @@ class BaseTorchTrainable(Trainable):
stats = merge_dicts(train_stats, validation_stats)
return stats
def _save(self, checkpoint_dir):
def save_checkpoint(self, checkpoint_dir):
"""Returns a path containing the trainer state."""
checkpoint_path = os.path.join(checkpoint_dir, "trainer.checkpoint")
self.trainer.save(checkpoint_path)
return checkpoint_path
def _restore(self, checkpoint_path):
def load_checkpoint(self, checkpoint_path):
"""Restores the trainer state.
Override this if you have state external to the Trainer object.
"""
return self.trainer.load(checkpoint_path)
def _stop(self):
def cleanup(self):
"""Shuts down the trainer."""
self.trainer.shutdown()