mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 10:33:16 +08:00
[RaySGD] Simplify Builder Process (#10321)
Co-authored-by: Richard Liaw <rliaw@berkeley.edu>
This commit is contained in:
@@ -11,8 +11,8 @@ from torch.utils.data import DataLoader
|
||||
import ray
|
||||
from ray import tune
|
||||
from ray.util.sgd.torch import TorchTrainer
|
||||
from ray.util.sgd.torch.training_operator import (_TestingOperator,
|
||||
_TestMetricsOperator)
|
||||
from ray.util.sgd.torch.training_operator import (
|
||||
get_test_operator, get_test_metrics_operator, TrainingOperator)
|
||||
from ray.util.sgd.torch.constants import SCHEDULER_STEP
|
||||
from ray.util.sgd.utils import (check_for_failure, NUM_SAMPLES, BATCH_COUNT,
|
||||
BATCH_SIZE)
|
||||
@@ -44,13 +44,12 @@ def ray_start_4_cpus():
|
||||
dist.destroy_process_group()
|
||||
|
||||
|
||||
Operator = TrainingOperator.from_creators(
|
||||
model_creator, optimizer_creator, data_creator, loss_creator=nn.MSELoss)
|
||||
|
||||
|
||||
def test_single_step(ray_start_2_cpus): # noqa: F811
|
||||
trainer = TorchTrainer(
|
||||
model_creator=model_creator,
|
||||
data_creator=data_creator,
|
||||
optimizer_creator=optimizer_creator,
|
||||
loss_creator=lambda config: nn.MSELoss(),
|
||||
num_workers=1)
|
||||
trainer = TorchTrainer(training_operator_cls=Operator, num_workers=1)
|
||||
metrics = trainer.train(num_steps=1)
|
||||
assert metrics[BATCH_COUNT] == 1
|
||||
|
||||
@@ -59,49 +58,9 @@ def test_single_step(ray_start_2_cpus): # noqa: F811
|
||||
trainer.shutdown()
|
||||
|
||||
|
||||
def test_resize(ray_start_2_cpus): # noqa: F811
|
||||
trainer = TorchTrainer(
|
||||
model_creator=model_creator,
|
||||
data_creator=data_creator,
|
||||
optimizer_creator=optimizer_creator,
|
||||
loss_creator=lambda config: nn.MSELoss(),
|
||||
num_workers=1)
|
||||
trainer.train(num_steps=1)
|
||||
trainer.max_replicas = 2
|
||||
results = trainer.train(num_steps=1, reduce_results=False)
|
||||
assert len(results) == 2
|
||||
|
||||
|
||||
def test_non_serialized_data(ray_start_2_cpus): # noqa: F811
|
||||
duration = 10
|
||||
|
||||
def slow_data(func):
|
||||
def slowed_func(*args, **kwargs):
|
||||
time.sleep(duration)
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return slowed_func
|
||||
|
||||
start = time.time()
|
||||
trainer = TorchTrainer(
|
||||
model_creator=model_creator,
|
||||
data_creator=slow_data(data_creator),
|
||||
optimizer_creator=optimizer_creator,
|
||||
serialize_data_creation=False,
|
||||
loss_creator=lambda config: nn.MSELoss(),
|
||||
num_workers=2)
|
||||
elapsed = time.time() - start
|
||||
assert elapsed < duration * 2
|
||||
trainer.shutdown()
|
||||
|
||||
|
||||
def test_dead_trainer(ray_start_2_cpus): # noqa: F811
|
||||
trainer = TorchTrainer(
|
||||
model_creator=model_creator,
|
||||
data_creator=data_creator,
|
||||
optimizer_creator=optimizer_creator,
|
||||
loss_creator=lambda config: nn.MSELoss(),
|
||||
num_workers=2)
|
||||
TestOperator = get_test_operator(Operator)
|
||||
trainer = TorchTrainer(training_operator_cls=TestOperator, num_workers=2)
|
||||
trainer.train(num_steps=1)
|
||||
trainer.shutdown()
|
||||
with pytest.raises(RuntimeError):
|
||||
@@ -111,11 +70,7 @@ def test_dead_trainer(ray_start_2_cpus): # noqa: F811
|
||||
@pytest.mark.parametrize("num_workers", [1, 2] if dist.is_available() else [1])
|
||||
def test_train(ray_start_2_cpus, num_workers): # noqa: F811
|
||||
trainer = TorchTrainer(
|
||||
model_creator=model_creator,
|
||||
data_creator=data_creator,
|
||||
optimizer_creator=optimizer_creator,
|
||||
loss_creator=lambda config: nn.MSELoss(),
|
||||
num_workers=num_workers)
|
||||
training_operator_cls=Operator, num_workers=num_workers)
|
||||
for i in range(3):
|
||||
train_loss1 = trainer.train()["train_loss"]
|
||||
validation_loss1 = trainer.validate()["val_loss"]
|
||||
@@ -165,22 +120,26 @@ def test_multi_model(ray_start_2_cpus, num_workers):
|
||||
iterator=iter(data))
|
||||
return result
|
||||
|
||||
def multi_model_creator(config):
|
||||
return nn.Linear(1, 1), nn.Linear(1, 1)
|
||||
class MultiModelOperator(TrainingOperator):
|
||||
def setup(self, config):
|
||||
models = nn.Linear(1, 1), nn.Linear(1, 1)
|
||||
opts = [
|
||||
torch.optim.SGD(model.parameters(), lr=0.0001)
|
||||
for model in models
|
||||
]
|
||||
loss = nn.MSELoss()
|
||||
train_dataloader, val_dataloader = data_creator(config)
|
||||
self.models, self.optimizers, self.criterion = self.register(
|
||||
models=models, optimizers=opts, criterion=loss)
|
||||
self.register_data(
|
||||
train_loader=train_dataloader,
|
||||
validation_loader=val_dataloader)
|
||||
|
||||
def multi_optimizer_creator(models, config):
|
||||
opts = [
|
||||
torch.optim.SGD(model.parameters(), lr=0.0001) for model in models
|
||||
]
|
||||
return opts[0], opts[1]
|
||||
TestOperator = get_test_operator(MultiModelOperator)
|
||||
|
||||
trainer1 = TorchTrainer(
|
||||
model_creator=multi_model_creator,
|
||||
data_creator=data_creator,
|
||||
optimizer_creator=multi_optimizer_creator,
|
||||
loss_creator=lambda config: nn.MSELoss(),
|
||||
config={"custom_func": train_epoch},
|
||||
training_operator_cls=_TestingOperator,
|
||||
training_operator_cls=TestOperator,
|
||||
num_workers=num_workers)
|
||||
trainer1.train()
|
||||
state = trainer1.state_dict()
|
||||
@@ -190,12 +149,8 @@ def test_multi_model(ray_start_2_cpus, num_workers):
|
||||
trainer1.shutdown()
|
||||
|
||||
trainer2 = TorchTrainer(
|
||||
model_creator=multi_model_creator,
|
||||
data_creator=data_creator,
|
||||
optimizer_creator=multi_optimizer_creator,
|
||||
loss_creator=lambda config: nn.MSELoss(),
|
||||
config={"custom_func": train_epoch},
|
||||
training_operator_cls=_TestingOperator,
|
||||
training_operator_cls=TestOperator,
|
||||
num_workers=num_workers)
|
||||
trainer2.load_state_dict(state)
|
||||
|
||||
@@ -252,17 +207,29 @@ def test_multi_model_matrix(ray_start_2_cpus, num_workers): # noqa: F811
|
||||
]
|
||||
return schedulers[0] if len(schedulers) == 1 else schedulers
|
||||
|
||||
class MultiModelOperator(TrainingOperator):
|
||||
def setup(self, config):
|
||||
models = multi_model_creator(config)
|
||||
optimizers = multi_optimizer_creator(models, config)
|
||||
schedulers = multi_scheduler_creator(optimizers, config)
|
||||
train_loader, val_loader = data_creator(config)
|
||||
loss = nn.MSELoss()
|
||||
|
||||
self.models, self.optimizers, self.criterion, self.schedulers = \
|
||||
self.register(models=models, optimizers=optimizers,
|
||||
schedulers=schedulers,
|
||||
criterion=loss)
|
||||
self.register_data(
|
||||
train_loader=train_loader, validation_loader=val_loader)
|
||||
|
||||
TestOperator = get_test_operator(MultiModelOperator)
|
||||
|
||||
for model_count in range(1, 3):
|
||||
for optimizer_count in range(1, 3):
|
||||
for scheduler_count in range(1, 3):
|
||||
trainer = TorchTrainer(
|
||||
model_creator=multi_model_creator,
|
||||
data_creator=data_creator,
|
||||
optimizer_creator=multi_optimizer_creator,
|
||||
loss_creator=nn.MSELoss,
|
||||
scheduler_creator=multi_scheduler_creator,
|
||||
scheduler_step_freq="epoch",
|
||||
training_operator_cls=_TestingOperator,
|
||||
training_operator_cls=TestOperator,
|
||||
num_workers=num_workers,
|
||||
config={
|
||||
"models": model_count,
|
||||
@@ -284,24 +251,31 @@ def test_scheduler_freq(ray_start_2_cpus, scheduler_freq): # noqa: F811
|
||||
return torch.optim.lr_scheduler.StepLR(
|
||||
optimizer, step_size=30, gamma=0.1)
|
||||
|
||||
class TestTrainingOperator(TrainingOperator):
|
||||
def setup(self, config):
|
||||
model = model_creator(config)
|
||||
optimizer = optimizer_creator(model, config)
|
||||
train_loader, val_loader = data_creator(config)
|
||||
scheduler = scheduler_creator(optimizer, config)
|
||||
loss = nn.MSELoss()
|
||||
|
||||
self.model, self.optimizer, self.criterion, self.scheduler = \
|
||||
self.register(
|
||||
models=model, optimizers=optimizer,
|
||||
criterion=loss, schedulers=scheduler)
|
||||
self.register_data(
|
||||
train_loader=train_loader, validation_loader=val_loader)
|
||||
|
||||
if scheduler_freq is None:
|
||||
with pytest.raises(ValueError):
|
||||
trainer = TorchTrainer(
|
||||
model_creator=model_creator,
|
||||
data_creator=data_creator,
|
||||
optimizer_creator=optimizer_creator,
|
||||
loss_creator=lambda config: nn.MSELoss(),
|
||||
scheduler_creator=scheduler_creator,
|
||||
config={"custom_func": train_epoch},
|
||||
training_operator_cls=TestTrainingOperator,
|
||||
scheduler_step_freq=scheduler_freq)
|
||||
else:
|
||||
trainer = TorchTrainer(
|
||||
model_creator=model_creator,
|
||||
data_creator=data_creator,
|
||||
optimizer_creator=optimizer_creator,
|
||||
loss_creator=lambda config: nn.MSELoss(),
|
||||
config={"custom_func": train_epoch},
|
||||
training_operator_cls=_TestingOperator,
|
||||
scheduler_creator=scheduler_creator,
|
||||
training_operator_cls=TestTrainingOperator,
|
||||
scheduler_step_freq=scheduler_freq)
|
||||
|
||||
for i in range(3):
|
||||
@@ -310,11 +284,7 @@ def test_scheduler_freq(ray_start_2_cpus, scheduler_freq): # noqa: F811
|
||||
|
||||
|
||||
def test_profiling(ray_start_2_cpus): # noqa: F811
|
||||
trainer = TorchTrainer(
|
||||
model_creator=model_creator,
|
||||
data_creator=data_creator,
|
||||
optimizer_creator=optimizer_creator,
|
||||
loss_creator=lambda config: nn.MSELoss())
|
||||
trainer = TorchTrainer(training_operator_cls=Operator)
|
||||
|
||||
stats = trainer.train(profile=True)
|
||||
assert "profile" in stats
|
||||
@@ -334,11 +304,14 @@ def test_dataset(ray_start_4_cpus):
|
||||
model_creator = mlp_identity.model_creator
|
||||
optimizer_creator = mlp_identity.optimizer_creator
|
||||
dataset_creator = mlp_identity.dataset_creator
|
||||
trainer = TorchTrainer(
|
||||
|
||||
DatasetOperator = TrainingOperator.from_creators(
|
||||
model_creator=model_creator,
|
||||
data_creator=None,
|
||||
optimizer_creator=optimizer_creator,
|
||||
loss_creator=torch.nn.MSELoss,
|
||||
loss_creator=nn.MSELoss)
|
||||
|
||||
trainer = TorchTrainer(
|
||||
training_operator_cls=DatasetOperator,
|
||||
num_workers=2,
|
||||
)
|
||||
|
||||
@@ -366,12 +339,13 @@ def test_split_batch(ray_start_2_cpus):
|
||||
|
||||
data_size = 600
|
||||
batch_size = 21
|
||||
|
||||
TestOperator = TrainingOperator.from_creators(
|
||||
model_creator,
|
||||
optimizer_creator,
|
||||
data_creator,
|
||||
loss_creator=lambda config: nn.MSELoss())
|
||||
trainer = TorchTrainer(
|
||||
model_creator=model_creator,
|
||||
data_creator=data_creator,
|
||||
optimizer_creator=optimizer_creator,
|
||||
loss_creator=lambda config: nn.MSELoss(),
|
||||
training_operator_cls=TestOperator,
|
||||
num_workers=2,
|
||||
config={
|
||||
BATCH_SIZE: batch_size,
|
||||
@@ -398,11 +372,13 @@ def test_reduce_result(ray_start_2_cpus):
|
||||
|
||||
data_size = 600
|
||||
|
||||
TestOperator = TrainingOperator.from_creators(
|
||||
model_creator,
|
||||
optimizer_creator,
|
||||
data_creator,
|
||||
loss_creator=lambda config: nn.MSELoss())
|
||||
trainer = TorchTrainer(
|
||||
model_creator=model_creator,
|
||||
data_creator=data_creator,
|
||||
optimizer_creator=optimizer_creator,
|
||||
loss_creator=lambda config: nn.MSELoss(),
|
||||
training_operator_cls=TestOperator,
|
||||
num_workers=2,
|
||||
config={"data_size": data_size})
|
||||
list_stats = trainer.train(reduce_results=False, profile=True)
|
||||
@@ -426,11 +402,10 @@ def test_metrics(ray_start_2_cpus, num_workers):
|
||||
|
||||
train_scores = [1] + ([0] * num_train_steps)
|
||||
val_scores = [1] + ([0] * num_val_steps)
|
||||
|
||||
TestOperator = get_test_metrics_operator(Operator)
|
||||
trainer = TorchTrainer(
|
||||
model_creator=model_creator,
|
||||
data_creator=data_creator,
|
||||
optimizer_creator=optimizer_creator,
|
||||
loss_creator=lambda config: nn.MSELoss(),
|
||||
training_operator_cls=TestOperator,
|
||||
num_workers=num_workers,
|
||||
config={
|
||||
"scores": train_scores,
|
||||
@@ -439,8 +414,7 @@ def test_metrics(ray_start_2_cpus, num_workers):
|
||||
"batch_size": batch_size,
|
||||
"data_size": data_size,
|
||||
"val_size": val_size
|
||||
},
|
||||
training_operator_cls=_TestMetricsOperator)
|
||||
})
|
||||
|
||||
stats = trainer.train(num_steps=num_train_steps)
|
||||
# Test that we output mean and last of custom metrics in an epoch
|
||||
@@ -475,11 +449,9 @@ def test_metrics_nan(ray_start_2_cpus, num_workers):
|
||||
|
||||
train_scores = [np.nan] + ([0] * num_train_steps)
|
||||
val_scores = [np.nan] + ([0] * num_val_steps)
|
||||
TestOperator = get_test_metrics_operator(Operator)
|
||||
trainer = TorchTrainer(
|
||||
model_creator=model_creator,
|
||||
data_creator=data_creator,
|
||||
optimizer_creator=optimizer_creator,
|
||||
loss_creator=lambda config: nn.MSELoss(),
|
||||
training_operator_cls=TestOperator,
|
||||
num_workers=num_workers,
|
||||
config={
|
||||
"scores": train_scores,
|
||||
@@ -488,8 +460,7 @@ def test_metrics_nan(ray_start_2_cpus, num_workers):
|
||||
"batch_size": batch_size,
|
||||
"data_size": data_size,
|
||||
"val_size": val_size
|
||||
},
|
||||
training_operator_cls=_TestMetricsOperator)
|
||||
})
|
||||
|
||||
stats = trainer.train(num_steps=num_train_steps)
|
||||
assert "score" in stats
|
||||
@@ -506,19 +477,20 @@ def test_metrics_nan(ray_start_2_cpus, num_workers):
|
||||
def test_scheduler_validate(ray_start_2_cpus): # noqa: F811
|
||||
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
||||
|
||||
trainer = TorchTrainer(
|
||||
model_creator=model_creator,
|
||||
data_creator=data_creator,
|
||||
optimizer_creator=optimizer_creator,
|
||||
loss_creator=lambda config: nn.MSELoss(),
|
||||
TestOperator = TrainingOperator.from_creators(
|
||||
model_creator,
|
||||
optimizer_creator,
|
||||
data_creator,
|
||||
scheduler_creator=lambda optimizer, cfg: ReduceLROnPlateau(optimizer),
|
||||
scheduler_step_freq="manual",
|
||||
training_operator_cls=_TestingOperator)
|
||||
loss_creator=lambda config: nn.MSELoss())
|
||||
TestOperator = get_test_operator(TestOperator)
|
||||
trainer = TorchTrainer(
|
||||
scheduler_step_freq="manual", training_operator_cls=TestOperator)
|
||||
trainer.update_scheduler(0.5)
|
||||
trainer.update_scheduler(0.5)
|
||||
assert all(
|
||||
trainer.apply_all_operators(
|
||||
lambda op: op.schedulers[0].last_epoch == 2))
|
||||
lambda op: op._schedulers[0].last_epoch == 2))
|
||||
trainer.shutdown()
|
||||
|
||||
|
||||
@@ -526,10 +498,7 @@ def test_scheduler_validate(ray_start_2_cpus): # noqa: F811
|
||||
def test_tune_train(ray_start_2_cpus, num_workers): # noqa: F811
|
||||
TorchTrainable = TorchTrainer.as_trainable(
|
||||
**{
|
||||
"model_creator": model_creator,
|
||||
"data_creator": data_creator,
|
||||
"optimizer_creator": optimizer_creator,
|
||||
"loss_creator": lambda config: nn.MSELoss(),
|
||||
"training_operator_cls": Operator,
|
||||
"num_workers": num_workers,
|
||||
"use_gpu": False,
|
||||
"backend": "gloo",
|
||||
@@ -560,11 +529,7 @@ def test_tune_train(ray_start_2_cpus, num_workers): # noqa: F811
|
||||
def test_save_and_restore(ray_start_2_cpus, num_workers,
|
||||
tmp_path): # noqa: F811
|
||||
trainer1 = TorchTrainer(
|
||||
model_creator=model_creator,
|
||||
data_creator=data_creator,
|
||||
optimizer_creator=optimizer_creator,
|
||||
loss_creator=lambda config: nn.MSELoss(),
|
||||
num_workers=num_workers)
|
||||
training_operator_cls=Operator, num_workers=num_workers)
|
||||
trainer1.train()
|
||||
checkpoint_path = os.path.join(tmp_path, "checkpoint")
|
||||
trainer1.save(checkpoint_path)
|
||||
@@ -574,11 +539,7 @@ def test_save_and_restore(ray_start_2_cpus, num_workers,
|
||||
trainer1.shutdown()
|
||||
|
||||
trainer2 = TorchTrainer(
|
||||
model_creator=model_creator,
|
||||
data_creator=data_creator,
|
||||
optimizer_creator=optimizer_creator,
|
||||
loss_creator=lambda config: nn.MSELoss(),
|
||||
num_workers=num_workers)
|
||||
training_operator_cls=Operator, num_workers=num_workers)
|
||||
trainer2.load(checkpoint_path)
|
||||
|
||||
model2 = trainer2.get_model()
|
||||
@@ -597,12 +558,7 @@ def test_wrap_ddp(ray_start_2_cpus, tmp_path): # noqa: F811
|
||||
if not dist.is_available():
|
||||
return
|
||||
trainer1 = TorchTrainer(
|
||||
model_creator=model_creator,
|
||||
data_creator=data_creator,
|
||||
optimizer_creator=optimizer_creator,
|
||||
loss_creator=lambda config: nn.MSELoss(),
|
||||
wrap_ddp=False,
|
||||
num_workers=2)
|
||||
training_operator_cls=Operator, wrap_ddp=False, num_workers=2)
|
||||
trainer1.train()
|
||||
checkpoint_path = os.path.join(tmp_path, "checkpoint")
|
||||
trainer1.save(checkpoint_path)
|
||||
@@ -613,12 +569,7 @@ def test_wrap_ddp(ray_start_2_cpus, tmp_path): # noqa: F811
|
||||
trainer1.shutdown()
|
||||
|
||||
trainer2 = TorchTrainer(
|
||||
model_creator=model_creator,
|
||||
data_creator=data_creator,
|
||||
optimizer_creator=optimizer_creator,
|
||||
loss_creator=lambda config: nn.MSELoss(),
|
||||
wrap_ddp=False,
|
||||
num_workers=2)
|
||||
training_operator_cls=Operator, wrap_ddp=False, num_workers=2)
|
||||
trainer2.load(checkpoint_path)
|
||||
|
||||
model2 = trainer2.get_model()
|
||||
@@ -672,12 +623,14 @@ def test_fail_with_recover(ray_start_2_cpus): # noqa: F811
|
||||
|
||||
step_with_fail = gen_step_with_fail(3)
|
||||
|
||||
TestOperator = TrainingOperator.from_creators(
|
||||
model_creator,
|
||||
optimizer_creator,
|
||||
single_loader,
|
||||
loss_creator=lambda config: nn.MSELoss())
|
||||
with patch.object(TorchTrainer, "_train_epoch", step_with_fail):
|
||||
trainer1 = TorchTrainer(
|
||||
model_creator=model_creator,
|
||||
data_creator=single_loader,
|
||||
optimizer_creator=optimizer_creator,
|
||||
loss_creator=lambda config: nn.MSELoss(),
|
||||
training_operator_cls=TestOperator,
|
||||
config={"batch_size": 100000},
|
||||
num_workers=2)
|
||||
|
||||
@@ -697,13 +650,15 @@ def test_resize(ray_start_2_cpus): # noqa: F811
|
||||
|
||||
step_with_fail = gen_step_with_fail(1)
|
||||
|
||||
TestOperator = TrainingOperator.from_creators(
|
||||
model_creator,
|
||||
optimizer_creator,
|
||||
single_loader,
|
||||
loss_creator=lambda config: nn.MSELoss())
|
||||
with patch.object(TorchTrainer, "_train_epoch", step_with_fail):
|
||||
trainer1 = TorchTrainer(
|
||||
model_creator=model_creator,
|
||||
data_creator=single_loader,
|
||||
optimizer_creator=optimizer_creator,
|
||||
training_operator_cls=TestOperator,
|
||||
config={"batch_size": 100000},
|
||||
loss_creator=lambda config: nn.MSELoss(),
|
||||
num_workers=2)
|
||||
|
||||
@ray.remote
|
||||
@@ -728,13 +683,16 @@ def test_fail_twice(ray_start_2_cpus): # noqa: F811
|
||||
|
||||
step_with_fail = gen_step_with_fail(2)
|
||||
|
||||
TestOperator = TrainingOperator.from_creators(
|
||||
model_creator,
|
||||
optimizer_creator,
|
||||
single_loader,
|
||||
loss_creator=lambda config: nn.MSELoss())
|
||||
|
||||
with patch.object(TorchTrainer, "_train_epoch", step_with_fail):
|
||||
trainer1 = TorchTrainer(
|
||||
model_creator=model_creator,
|
||||
data_creator=single_loader,
|
||||
optimizer_creator=optimizer_creator,
|
||||
training_operator_cls=TestOperator,
|
||||
config={"batch_size": 100000},
|
||||
loss_creator=lambda config: nn.MSELoss(),
|
||||
num_workers=2)
|
||||
|
||||
# MAX RETRIES SHOULD BE ON BY DEFAULT
|
||||
@@ -778,12 +736,13 @@ def test_multi_input_model(ray_start_2_cpus):
|
||||
)
|
||||
return train_loader, None
|
||||
|
||||
trainer = TorchTrainer(
|
||||
model_creator=model_creator,
|
||||
data_creator=data_creator,
|
||||
optimizer_creator=optimizer_creator,
|
||||
loss_creator=lambda config: nn.MSELoss(),
|
||||
num_workers=1)
|
||||
Operator = TrainingOperator.from_creators(
|
||||
model_creator,
|
||||
optimizer_creator,
|
||||
data_creator,
|
||||
loss_creator=lambda config: nn.MSELoss())
|
||||
|
||||
trainer = TorchTrainer(training_operator_cls=Operator, num_workers=1)
|
||||
|
||||
metrics = trainer.train(num_steps=1)
|
||||
assert metrics[BATCH_COUNT] == 1
|
||||
@@ -794,4 +753,5 @@ def test_multi_input_model(ray_start_2_cpus):
|
||||
if __name__ == "__main__":
|
||||
import pytest
|
||||
import sys
|
||||
|
||||
sys.exit(pytest.main(["-v", "-x", __file__]))
|
||||
|
||||
@@ -50,19 +50,22 @@ def create_dataloaders(config):
|
||||
|
||||
|
||||
class TestTorchRunner(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.Operator = TrainingOperator.from_creators(
|
||||
model_creator,
|
||||
optimizer_creator,
|
||||
create_dataloaders,
|
||||
loss_creator=loss_creator)
|
||||
|
||||
def testValidate(self):
|
||||
class MockOperator(TrainingOperator):
|
||||
class MockOperator(self.Operator):
|
||||
def setup(self, config):
|
||||
super(MockOperator, self).setup(config)
|
||||
self.train_epoch = MagicMock(returns=dict(mean_accuracy=10))
|
||||
self.validate = MagicMock(returns=dict(mean_accuracy=10))
|
||||
|
||||
runner = TorchRunner(
|
||||
model_creator,
|
||||
create_dataloaders,
|
||||
optimizer_creator,
|
||||
loss_creator,
|
||||
training_operator_cls=MockOperator)
|
||||
runner.setup()
|
||||
runner = TorchRunner(training_operator_cls=MockOperator)
|
||||
runner.setup_operator()
|
||||
runner.train_epoch()
|
||||
runner.train_epoch()
|
||||
result = runner.train_epoch()
|
||||
@@ -72,21 +75,17 @@ class TestTorchRunner(unittest.TestCase):
|
||||
self.assertEqual(result["epoch"], 3)
|
||||
|
||||
def testtrain_epoch(self):
|
||||
class MockOperator(TrainingOperator):
|
||||
class MockOperator(self.Operator):
|
||||
def setup(self, config):
|
||||
super(MockOperator, self).setup(config)
|
||||
self.count = 0
|
||||
|
||||
def train_epoch(self, *args, **kwargs):
|
||||
self.count += 1
|
||||
return {"count": self.count}
|
||||
|
||||
runner = TorchRunner(
|
||||
model_creator,
|
||||
create_dataloaders,
|
||||
optimizer_creator,
|
||||
loss_creator,
|
||||
training_operator_cls=MockOperator)
|
||||
runner.setup()
|
||||
runner = TorchRunner(training_operator_cls=MockOperator)
|
||||
runner.setup_operator()
|
||||
runner.train_epoch(num_steps=1)
|
||||
runner.train_epoch(num_steps=1)
|
||||
result = runner.train_epoch()
|
||||
@@ -95,11 +94,6 @@ class TestTorchRunner(unittest.TestCase):
|
||||
self.assertEqual(result["epoch"], 3)
|
||||
|
||||
def testGivens(self):
|
||||
class MockOperator(TrainingOperator):
|
||||
def setup(self, config):
|
||||
self.train_epoch = MagicMock(returns=dict(mean_accuracy=10))
|
||||
self.validate = MagicMock(returns=dict(mean_accuracy=10))
|
||||
|
||||
def three_model_creator(config):
|
||||
return nn.Linear(1, 1), nn.Linear(1, 1), nn.Linear(1, 1)
|
||||
|
||||
@@ -109,20 +103,27 @@ class TestTorchRunner(unittest.TestCase):
|
||||
]
|
||||
return opts[0], opts[1], opts[2]
|
||||
|
||||
runner = TorchRunner(
|
||||
three_model_creator,
|
||||
single_loader,
|
||||
three_optimizer_creator,
|
||||
loss_creator,
|
||||
training_operator_cls=MockOperator)
|
||||
runner.setup()
|
||||
class MockOperator(TrainingOperator):
|
||||
def setup(self, config):
|
||||
models = three_model_creator(config)
|
||||
optimizers = three_optimizer_creator(models, config)
|
||||
loader = single_loader(config)
|
||||
loss = loss_creator(config)
|
||||
self.models, self.optimizers, self.criterion = \
|
||||
self.register(models=models, optimizers=optimizers,
|
||||
criterion=loss)
|
||||
self.register_data(train_loader=loader, validation_loader=None)
|
||||
self.train_epoch = MagicMock(returns=dict(mean_accuracy=10))
|
||||
self.validate = MagicMock(returns=dict(mean_accuracy=10))
|
||||
|
||||
runner = TorchRunner(training_operator_cls=MockOperator)
|
||||
runner.setup_operator()
|
||||
|
||||
self.assertEqual(len(runner.given_models), 3)
|
||||
self.assertEqual(len(runner.given_optimizers), 3)
|
||||
|
||||
runner2 = TorchRunner(model_creator, single_loader, optimizer_creator,
|
||||
loss_creator)
|
||||
runner2.setup()
|
||||
runner2 = TorchRunner(training_operator_cls=self.Operator)
|
||||
runner2.setup_operator()
|
||||
|
||||
self.assertNotEqual(runner2.given_models, runner2.models)
|
||||
self.assertNotEqual(runner2.given_optimizers, runner2.optimizers)
|
||||
@@ -132,49 +133,42 @@ class TestTorchRunner(unittest.TestCase):
|
||||
return (LinearDataset(2, 5), LinearDataset(2, 5, size=400),
|
||||
LinearDataset(2, 5, size=400))
|
||||
|
||||
runner = TorchRunner(model_creator, three_data_loader,
|
||||
optimizer_creator, loss_creator)
|
||||
with self.assertRaises(ValueError):
|
||||
runner.setup()
|
||||
ThreeOperator = TrainingOperator.from_creators(
|
||||
model_creator,
|
||||
optimizer_creator,
|
||||
three_data_loader,
|
||||
loss_creator=loss_creator)
|
||||
|
||||
runner2 = TorchRunner(model_creator, three_data_loader,
|
||||
optimizer_creator, loss_creator)
|
||||
runner = TorchRunner(training_operator_cls=ThreeOperator)
|
||||
with self.assertRaises(ValueError):
|
||||
runner2.setup()
|
||||
runner.setup_operator()
|
||||
|
||||
runner2 = TorchRunner(training_operator_cls=ThreeOperator)
|
||||
with self.assertRaises(ValueError):
|
||||
runner2.setup_operator()
|
||||
|
||||
def testSingleLoader(self):
|
||||
runner = TorchRunner(model_creator, single_loader, optimizer_creator,
|
||||
loss_creator)
|
||||
runner.setup()
|
||||
SingleOperator = TrainingOperator.from_creators(
|
||||
model_creator,
|
||||
optimizer_creator,
|
||||
single_loader,
|
||||
loss_creator=loss_creator)
|
||||
runner = TorchRunner(training_operator_cls=SingleOperator)
|
||||
runner.setup_operator()
|
||||
runner.train_epoch()
|
||||
with self.assertRaises(ValueError):
|
||||
runner.validate()
|
||||
|
||||
def testNativeLoss(self):
|
||||
runner = TorchRunner(
|
||||
NativeOperator = TrainingOperator.from_creators(
|
||||
model_creator,
|
||||
single_loader,
|
||||
optimizer_creator,
|
||||
single_loader,
|
||||
loss_creator=nn.MSELoss)
|
||||
runner.setup()
|
||||
runner = TorchRunner(training_operator_cls=NativeOperator)
|
||||
runner.setup_operator()
|
||||
runner.train_epoch()
|
||||
|
||||
def testMultiModel(self):
|
||||
def multi_model_creator(config):
|
||||
return nn.Linear(1, 1), nn.Linear(1, 1), nn.Linear(1, 1)
|
||||
|
||||
def multi_optimizer_creator(models, config):
|
||||
opts = [
|
||||
torch.optim.SGD(model.parameters(), lr=0.1) for model in models
|
||||
]
|
||||
return opts[0], opts[1], opts[2]
|
||||
|
||||
runner = TorchRunner(multi_model_creator, single_loader,
|
||||
multi_optimizer_creator, loss_creator)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
runner.setup()
|
||||
|
||||
|
||||
class TestLocalDistributedRunner(unittest.TestCase):
|
||||
def setUp(self):
|
||||
|
||||
@@ -4,6 +4,7 @@ logger = logging.getLogger(__name__)
|
||||
TorchTrainer = None
|
||||
TrainingOperator = None
|
||||
BaseTorchTrainable = None
|
||||
CreatorOperator = None
|
||||
|
||||
try:
|
||||
import torch # noqa: F401
|
||||
@@ -11,9 +12,13 @@ try:
|
||||
from ray.util.sgd.torch.torch_trainer import (TorchTrainer,
|
||||
BaseTorchTrainable)
|
||||
|
||||
from ray.util.sgd.torch.training_operator import TrainingOperator
|
||||
from ray.util.sgd.torch.training_operator import (TrainingOperator,
|
||||
CreatorOperator)
|
||||
|
||||
__all__ = ["TorchTrainer", "BaseTorchTrainable", "TrainingOperator"]
|
||||
__all__ = [
|
||||
"TorchTrainer", "BaseTorchTrainable", "TrainingOperator",
|
||||
"CreatorOperator"
|
||||
]
|
||||
except ImportError as e:
|
||||
logger.warning(e)
|
||||
logger.warning("PyTorch not found. TorchTrainer will not be available")
|
||||
|
||||
@@ -4,7 +4,6 @@ import os
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.nn.parallel import DistributedDataParallel
|
||||
from torch.utils.data import DataLoader, IterableDataset
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from ray.util.sgd.torch.utils import setup_process_group
|
||||
@@ -43,9 +42,6 @@ class DistributedTorchRunner(TorchRunner):
|
||||
self.add_dist_sampler = add_dist_sampler
|
||||
self.world_rank = None
|
||||
|
||||
def setup(self):
|
||||
raise RuntimeError("Need to call setup commands separately.")
|
||||
|
||||
def setup_process_group(self, url, world_rank, world_size, timeout):
|
||||
"""Connects the distributed PyTorch backend.
|
||||
|
||||
@@ -60,7 +56,7 @@ class DistributedTorchRunner(TorchRunner):
|
||||
setup_process_group(
|
||||
url, world_rank, world_size, timeout, backend=self.backend)
|
||||
|
||||
def setup_ddp_and_operator(self):
|
||||
def setup_operator(self):
|
||||
"""Runs distributed coordination components.
|
||||
|
||||
This helps avoid timeouts due to creator functions (perhaps
|
||||
@@ -70,29 +66,18 @@ class DistributedTorchRunner(TorchRunner):
|
||||
if self.use_gpu and torch.cuda.is_available():
|
||||
device_ids = self.get_device_ids()
|
||||
|
||||
# Wrap dataloaders
|
||||
self._wrap_dataloaders()
|
||||
|
||||
training_models = self.models
|
||||
if self.wrap_ddp:
|
||||
# This needs to happen after apex
|
||||
training_models = [
|
||||
DistributedDataParallel(model, device_ids=device_ids)
|
||||
for model in self.models
|
||||
]
|
||||
self.training_operator = self.training_operator_cls(
|
||||
self.config,
|
||||
models=training_models,
|
||||
optimizers=self.optimizers,
|
||||
criterion=self.criterion,
|
||||
train_loader=self.train_loader,
|
||||
validation_loader=self.validation_loader,
|
||||
world_rank=self.world_rank,
|
||||
schedulers=self.schedulers,
|
||||
device_ids=device_ids,
|
||||
use_gpu=self.use_gpu,
|
||||
use_fp16=self.use_fp16,
|
||||
use_tqdm=self.use_tqdm)
|
||||
use_tqdm=self.use_tqdm,
|
||||
apex_args=self.apex_args,
|
||||
wrap_ddp=self.wrap_ddp,
|
||||
wrap_distributed_sampler=True,
|
||||
add_dist_sampler=self.add_dist_sampler,
|
||||
scheduler_step_freq=self.scheduler_step_freq)
|
||||
|
||||
def get_device_ids(self):
|
||||
"""Needed for SyncBatchNorm, which needs 1 GPU per process."""
|
||||
|
||||
@@ -72,6 +72,17 @@ def init_hook():
|
||||
|
||||
class Training(TrainingOperator):
|
||||
def setup(self, config):
|
||||
model = getattr(models, config.get("model"))()
|
||||
optimizer = optim.SGD(
|
||||
model.parameters(), lr=0.01 * config["lr_scaler"])
|
||||
train_data = LinearDataset(4,
|
||||
2) # Have to use dummy data for training.
|
||||
|
||||
self.model, self.optimizer = self.register(
|
||||
models=model,
|
||||
optimizers=optimizer,
|
||||
)
|
||||
self.register_data(train_loader=train_data, validation_loader=None)
|
||||
data = torch.randn(args.batch_size, 3, 224, 224)
|
||||
target = torch.LongTensor(args.batch_size).random_() % 1000
|
||||
if args.cuda:
|
||||
@@ -107,14 +118,12 @@ if __name__ == "__main__":
|
||||
print("Number of %ss: %d" % (device, num_workers))
|
||||
|
||||
trainer = TorchTrainer(
|
||||
model_creator=lambda cfg: getattr(models, args.model)(),
|
||||
optimizer_creator=lambda model, cfg: optim.SGD(
|
||||
model.parameters(), lr=0.01 * cfg.get("lr_scaler")),
|
||||
data_creator=lambda cfg: LinearDataset(4, 2), # Mock dataset.
|
||||
initialization_hook=init_hook,
|
||||
config=dict(
|
||||
lr_scaler=num_workers),
|
||||
training_operator_cls=Training,
|
||||
initialization_hook=init_hook,
|
||||
config={
|
||||
"lr_scaler": num_workers,
|
||||
"model": args.model
|
||||
},
|
||||
num_workers=num_workers,
|
||||
use_gpu=args.cuda,
|
||||
use_fp16=args.fp16,
|
||||
|
||||
@@ -2,6 +2,8 @@ import os
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import argparse
|
||||
|
||||
from filelock import FileLock
|
||||
from torch.utils.data import DataLoader, Subset
|
||||
from torchvision.datasets import CIFAR10
|
||||
import torchvision.transforms as transforms
|
||||
@@ -9,9 +11,9 @@ import torchvision.transforms as transforms
|
||||
from tqdm import trange
|
||||
|
||||
import ray
|
||||
from ray.util.sgd.torch import TorchTrainer
|
||||
from ray.util.sgd.torch import TorchTrainer, TrainingOperator
|
||||
from ray.util.sgd.torch.resnet import ResNet18
|
||||
from ray.util.sgd.utils import BATCH_SIZE
|
||||
from ray.util.sgd.utils import BATCH_SIZE, override
|
||||
|
||||
|
||||
def initialization_hook():
|
||||
@@ -24,47 +26,66 @@ def initialization_hook():
|
||||
# os.environ["NCCL_DEBUG"] = "INFO"
|
||||
|
||||
|
||||
def cifar_creator(config):
|
||||
transform_train = transforms.Compose([
|
||||
transforms.RandomCrop(32, padding=4),
|
||||
transforms.RandomHorizontalFlip(),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.4914, 0.4822, 0.4465),
|
||||
(0.2023, 0.1994, 0.2010)),
|
||||
]) # meanstd transformation
|
||||
class CifarTrainingOperator(TrainingOperator):
|
||||
@override(TrainingOperator)
|
||||
def setup(self, config):
|
||||
# Create model.
|
||||
model = ResNet18(config)
|
||||
|
||||
transform_test = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.4914, 0.4822, 0.4465),
|
||||
(0.2023, 0.1994, 0.2010)),
|
||||
])
|
||||
train_dataset = CIFAR10(
|
||||
root="~/data", train=True, download=True, transform=transform_train)
|
||||
validation_dataset = CIFAR10(
|
||||
root="~/data", train=False, download=False, transform=transform_test)
|
||||
# Create optimizer.
|
||||
optimizer = torch.optim.SGD(
|
||||
model.parameters(),
|
||||
lr=config.get("lr", 0.1),
|
||||
momentum=config.get("momentum", 0.9))
|
||||
|
||||
if config["test_mode"]:
|
||||
train_dataset = Subset(train_dataset, list(range(64)))
|
||||
validation_dataset = Subset(validation_dataset, list(range(64)))
|
||||
# Load in training and validation data.
|
||||
transform_train = transforms.Compose([
|
||||
transforms.RandomCrop(32, padding=4),
|
||||
transforms.RandomHorizontalFlip(),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.4914, 0.4822, 0.4465),
|
||||
(0.2023, 0.1994, 0.2010)),
|
||||
]) # meanstd transformation
|
||||
|
||||
train_loader = DataLoader(
|
||||
train_dataset, batch_size=config[BATCH_SIZE], num_workers=2)
|
||||
validation_loader = DataLoader(
|
||||
validation_dataset, batch_size=config[BATCH_SIZE], num_workers=2)
|
||||
return train_loader, validation_loader
|
||||
transform_test = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.4914, 0.4822, 0.4465),
|
||||
(0.2023, 0.1994, 0.2010)),
|
||||
])
|
||||
with FileLock(".ray.lock"):
|
||||
train_dataset = CIFAR10(
|
||||
root="~/data",
|
||||
train=True,
|
||||
download=True,
|
||||
transform=transform_train)
|
||||
validation_dataset = CIFAR10(
|
||||
root="~/data",
|
||||
train=False,
|
||||
download=False,
|
||||
transform=transform_test)
|
||||
|
||||
if config["test_mode"]:
|
||||
train_dataset = Subset(train_dataset, list(range(64)))
|
||||
validation_dataset = Subset(validation_dataset, list(range(64)))
|
||||
|
||||
def optimizer_creator(model, config):
|
||||
"""Returns optimizer"""
|
||||
return torch.optim.SGD(
|
||||
model.parameters(),
|
||||
lr=config.get("lr", 0.1),
|
||||
momentum=config.get("momentum", 0.9))
|
||||
train_loader = DataLoader(
|
||||
train_dataset, batch_size=config[BATCH_SIZE], num_workers=2)
|
||||
validation_loader = DataLoader(
|
||||
validation_dataset, batch_size=config[BATCH_SIZE], num_workers=2)
|
||||
|
||||
# Create scheduler.
|
||||
scheduler = torch.optim.lr_scheduler.MultiStepLR(
|
||||
optimizer, milestones=[150, 250, 350], gamma=0.1)
|
||||
|
||||
def scheduler_creator(optimizer, config):
|
||||
return torch.optim.lr_scheduler.MultiStepLR(
|
||||
optimizer, milestones=[150, 250, 350], gamma=0.1)
|
||||
# Create loss.
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
|
||||
# Register all components.
|
||||
self.model, self.optimizer, self.criterion, self.scheduler = \
|
||||
self.register(models=model, optimizers=optimizer,
|
||||
criterion=criterion, schedulers=scheduler)
|
||||
self.register_data(
|
||||
train_loader=train_loader, validation_loader=validation_loader)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -105,11 +126,7 @@ if __name__ == "__main__":
|
||||
ray.init(address=args.address, num_cpus=num_cpus, log_to_driver=True)
|
||||
|
||||
trainer1 = TorchTrainer(
|
||||
model_creator=ResNet18,
|
||||
data_creator=cifar_creator,
|
||||
optimizer_creator=optimizer_creator,
|
||||
loss_creator=nn.CrossEntropyLoss,
|
||||
scheduler_creator=scheduler_creator,
|
||||
training_operator_cls=CifarTrainingOperator,
|
||||
initialization_hook=initialization_hook,
|
||||
num_workers=args.num_workers,
|
||||
config={
|
||||
|
||||
@@ -3,6 +3,8 @@ import os
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import argparse
|
||||
|
||||
from filelock import FileLock
|
||||
from ray import tune
|
||||
from ray.tune.schedulers import PopulationBasedTraining
|
||||
from torch.utils.data import DataLoader, Subset
|
||||
@@ -11,9 +13,9 @@ import torchvision.transforms as transforms
|
||||
|
||||
import ray
|
||||
from ray.tune import CLIReporter
|
||||
from ray.util.sgd.torch import TorchTrainer
|
||||
from ray.util.sgd.torch import TorchTrainer, TrainingOperator
|
||||
from ray.util.sgd.torch.resnet import ResNet18
|
||||
from ray.util.sgd.utils import BATCH_SIZE
|
||||
from ray.util.sgd.utils import BATCH_SIZE, override
|
||||
|
||||
|
||||
def initialization_hook():
|
||||
@@ -26,42 +28,62 @@ def initialization_hook():
|
||||
# os.environ["NCCL_DEBUG"] = "INFO"
|
||||
|
||||
|
||||
def cifar_creator(config):
|
||||
transform_train = transforms.Compose([
|
||||
transforms.RandomCrop(32, padding=4),
|
||||
transforms.RandomHorizontalFlip(),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.4914, 0.4822, 0.4465),
|
||||
(0.2023, 0.1994, 0.2010)),
|
||||
]) # meanstd transformation
|
||||
class CifarTrainingOperator(TrainingOperator):
|
||||
@override(TrainingOperator)
|
||||
def setup(self, config):
|
||||
# Create model.
|
||||
model = ResNet18(config)
|
||||
|
||||
transform_test = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.4914, 0.4822, 0.4465),
|
||||
(0.2023, 0.1994, 0.2010)),
|
||||
])
|
||||
train_dataset = CIFAR10(
|
||||
root="~/data", train=True, download=True, transform=transform_train)
|
||||
validation_dataset = CIFAR10(
|
||||
root="~/data", train=False, download=False, transform=transform_test)
|
||||
# Create optimizer.
|
||||
optimizer = torch.optim.SGD(
|
||||
model.parameters(),
|
||||
lr=config.get("lr", 0.1),
|
||||
momentum=config.get("momentum", 0.9))
|
||||
|
||||
if config.get("test_mode"):
|
||||
train_dataset = Subset(train_dataset, list(range(64)))
|
||||
validation_dataset = Subset(validation_dataset, list(range(64)))
|
||||
# Load in training and validation data.
|
||||
transform_train = transforms.Compose([
|
||||
transforms.RandomCrop(32, padding=4),
|
||||
transforms.RandomHorizontalFlip(),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.4914, 0.4822, 0.4465),
|
||||
(0.2023, 0.1994, 0.2010)),
|
||||
]) # meanstd transformation
|
||||
|
||||
train_loader = DataLoader(
|
||||
train_dataset, batch_size=config[BATCH_SIZE], num_workers=2)
|
||||
validation_loader = DataLoader(
|
||||
validation_dataset, batch_size=config[BATCH_SIZE], num_workers=2)
|
||||
return train_loader, validation_loader
|
||||
transform_test = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.4914, 0.4822, 0.4465),
|
||||
(0.2023, 0.1994, 0.2010)),
|
||||
])
|
||||
|
||||
with FileLock(".ray.lock"):
|
||||
train_dataset = CIFAR10(
|
||||
root="~/data",
|
||||
train=True,
|
||||
download=True,
|
||||
transform=transform_train)
|
||||
validation_dataset = CIFAR10(
|
||||
root="~/data",
|
||||
train=False,
|
||||
download=False,
|
||||
transform=transform_test)
|
||||
|
||||
def optimizer_creator(model, config):
|
||||
"""Returns optimizer"""
|
||||
return torch.optim.SGD(
|
||||
model.parameters(),
|
||||
lr=config.get("lr", 0.1),
|
||||
momentum=config.get("momentum", 0.9))
|
||||
if config.get("test_mode"):
|
||||
train_dataset = Subset(train_dataset, list(range(64)))
|
||||
validation_dataset = Subset(validation_dataset, list(range(64)))
|
||||
|
||||
train_loader = DataLoader(
|
||||
train_dataset, batch_size=config[BATCH_SIZE], num_workers=2)
|
||||
validation_loader = DataLoader(
|
||||
validation_dataset, batch_size=config[BATCH_SIZE], num_workers=2)
|
||||
|
||||
# Create loss.
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
|
||||
self.model, self.optimizer, self.criterion = \
|
||||
self.register(models=model, optimizers=optimizer,
|
||||
criterion=criterion,)
|
||||
self.register_data(
|
||||
train_loader=train_loader, validation_loader=validation_loader)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -101,10 +123,7 @@ if __name__ == "__main__":
|
||||
ray.init(address=args.address, log_to_driver=True)
|
||||
|
||||
TorchTrainable = TorchTrainer.as_trainable(
|
||||
model_creator=ResNet18,
|
||||
data_creator=cifar_creator,
|
||||
optimizer_creator=optimizer_creator,
|
||||
loss_creator=nn.CrossEntropyLoss,
|
||||
training_operator_cls=CifarTrainingOperator,
|
||||
initialization_hook=initialization_hook,
|
||||
num_workers=args.num_workers,
|
||||
config={
|
||||
|
||||
@@ -9,6 +9,7 @@ import torch.utils.data
|
||||
import torchvision.datasets as datasets
|
||||
import torchvision.transforms as transforms
|
||||
import numpy as np
|
||||
from filelock import FileLock
|
||||
|
||||
from tqdm import trange
|
||||
|
||||
@@ -24,22 +25,6 @@ from ray.util.sgd.torch import TrainingOperator
|
||||
MODEL_PATH = os.path.expanduser("~/.ray/models/mnist_cnn.pt")
|
||||
|
||||
|
||||
def data_creator(config):
|
||||
dataset = datasets.MNIST(
|
||||
root="~/mnist/",
|
||||
download=True,
|
||||
transform=transforms.Compose([
|
||||
transforms.Resize(32),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.5, ), (0.5, )),
|
||||
]))
|
||||
if config.get("test_mode"):
|
||||
dataset = torch.utils.data.Subset(dataset, list(range(64)))
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset, batch_size=config.get("batch_size", 32))
|
||||
return dataloader
|
||||
|
||||
|
||||
class Generator(nn.Module):
|
||||
def __init__(self, latent_vector_size, features=32, num_channels=1):
|
||||
super(Generator, self).__init__()
|
||||
@@ -101,35 +86,57 @@ class LeNet(nn.Module):
|
||||
return F.log_softmax(x, dim=1)
|
||||
|
||||
|
||||
def model_creator(config):
|
||||
def weights_init(m):
|
||||
classname = m.__class__.__name__
|
||||
if classname.find("Conv") != -1:
|
||||
nn.init.normal_(m.weight.data, 0.0, 0.02)
|
||||
elif classname.find("BatchNorm") != -1:
|
||||
nn.init.normal_(m.weight.data, 1.0, 0.02)
|
||||
nn.init.constant_(m.bias.data, 0)
|
||||
|
||||
discriminator = Discriminator()
|
||||
discriminator.apply(weights_init)
|
||||
|
||||
generator = Generator(
|
||||
latent_vector_size=config.get("latent_vector_size", 100))
|
||||
generator.apply(weights_init)
|
||||
return discriminator, generator
|
||||
|
||||
|
||||
def optimizer_creator(models, config):
|
||||
net_d, net_g = models
|
||||
discriminator_opt = optim.Adam(
|
||||
net_d.parameters(), lr=config.get("lr", 0.01), betas=(0.5, 0.999))
|
||||
generator_opt = optim.Adam(
|
||||
net_g.parameters(), lr=config.get("lr", 0.01), betas=(0.5, 0.999))
|
||||
return discriminator_opt, generator_opt
|
||||
def weights_init(m):
|
||||
classname = m.__class__.__name__
|
||||
if classname.find("Conv") != -1:
|
||||
nn.init.normal_(m.weight.data, 0.0, 0.02)
|
||||
elif classname.find("BatchNorm") != -1:
|
||||
nn.init.normal_(m.weight.data, 1.0, 0.02)
|
||||
nn.init.constant_(m.bias.data, 0)
|
||||
|
||||
|
||||
class GANOperator(TrainingOperator):
|
||||
def setup(self, config):
|
||||
discriminator = Discriminator()
|
||||
discriminator.apply(weights_init)
|
||||
|
||||
generator = Generator(
|
||||
latent_vector_size=config.get("latent_vector_size", 100))
|
||||
generator.apply(weights_init)
|
||||
models = (discriminator, generator)
|
||||
|
||||
discriminator_opt = optim.Adam(
|
||||
discriminator.parameters(),
|
||||
lr=config.get("lr", 0.01),
|
||||
betas=(0.5, 0.999))
|
||||
generator_opt = optim.Adam(
|
||||
generator.parameters(),
|
||||
lr=config.get("lr", 0.01),
|
||||
betas=(0.5, 0.999))
|
||||
optimizers = (discriminator_opt, generator_opt)
|
||||
|
||||
with FileLock(".ray.lock"):
|
||||
dataset = datasets.MNIST(
|
||||
root="~/mnist/",
|
||||
download=True,
|
||||
transform=transforms.Compose([
|
||||
transforms.Resize(32),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.5, ), (0.5, )),
|
||||
]))
|
||||
if config.get("test_mode"):
|
||||
dataset = torch.utils.data.Subset(dataset, list(range(64)))
|
||||
train_dataloader = torch.utils.data.DataLoader(
|
||||
dataset, batch_size=config.get("batch_size", 32))
|
||||
|
||||
self.models, self.optimizers, self.criterion = self.register(
|
||||
models=models, optimizers=optimizers, criterion=nn.BCELoss())
|
||||
self.register_data(
|
||||
train_loader=train_dataloader, validation_loader=None)
|
||||
|
||||
self.model = self.models[0]
|
||||
self.optimizer = self.optimizers[0]
|
||||
|
||||
self.classifier = LeNet()
|
||||
self.classifier.load_state_dict(
|
||||
torch.load(config["classification_model_path"]))
|
||||
@@ -232,10 +239,6 @@ def train_example(num_workers=1, use_gpu=False, test_mode=False):
|
||||
"classification_model_path": MODEL_PATH
|
||||
}
|
||||
trainer = TorchTrainer(
|
||||
model_creator=model_creator,
|
||||
data_creator=data_creator,
|
||||
optimizer_creator=optimizer_creator,
|
||||
loss_creator=nn.BCELoss,
|
||||
training_operator_cls=GANOperator,
|
||||
num_workers=num_workers,
|
||||
config=config,
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
|
||||
from os.path import join
|
||||
|
||||
from ray.util.sgd.torch import TrainingOperator
|
||||
from tqdm import trange
|
||||
|
||||
import torch.nn as nn
|
||||
@@ -130,11 +131,13 @@ def main():
|
||||
|
||||
ray.init(address=args.ray_address)
|
||||
|
||||
trainer = TorchTrainer(
|
||||
CustomTrainingOperator = TrainingOperator.from_creators(
|
||||
model_creator=model_creator,
|
||||
data_creator=data_creator,
|
||||
optimizer_creator=optimizer_creator,
|
||||
loss_creator=loss_creator,
|
||||
data_creator=data_creator,
|
||||
loss_creator=loss_creator)
|
||||
trainer = TorchTrainer(
|
||||
training_operator_cls=CustomTrainingOperator,
|
||||
use_tqdm=True,
|
||||
use_fp16=args.amp,
|
||||
apex_args={"opt_level": "O1"},
|
||||
|
||||
@@ -7,6 +7,47 @@ in the documentation.
|
||||
"""
|
||||
# yapf: disable
|
||||
|
||||
# __torch_operator_start__
|
||||
from ray.util.sgd.torch import TrainingOperator
|
||||
|
||||
class MyTrainingOperator(TrainingOperator):
|
||||
def setup(self, config):
|
||||
# Setup all components needed for training here. This could include
|
||||
# data, models, optimizers, loss & schedulers.
|
||||
|
||||
# Setup data loaders.
|
||||
train_dataset, val_dataset = LinearDataset(2, 5), LinearDataset(2,
|
||||
5)
|
||||
train_loader = DataLoader(train_dataset,
|
||||
batch_size=config["batch_size"])
|
||||
val_loader = DataLoader(val_dataset,
|
||||
batch_size=config["batch_size"])
|
||||
|
||||
# Setup model.
|
||||
model = nn.Linear(1, 1)
|
||||
|
||||
# Setup optimizer.
|
||||
optimizer = torch.optim.SGD(model.parameters(), lr=config.get("lr", 1e-4))
|
||||
|
||||
# Setup loss.
|
||||
criterion = torch.nn.BCELoss()
|
||||
|
||||
# Setup scheduler.
|
||||
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.9)
|
||||
|
||||
# Register all of these components with Ray SGD.
|
||||
# This allows Ray SGD to do framework level setup like Cuda, DDP,
|
||||
# Distributed Sampling, FP16.
|
||||
# We also assign the return values of self.register to instance
|
||||
# attributes so we can access it in our custom training/validation
|
||||
# methods.
|
||||
self.model, self.optimizer, self.criterion, self.scheduler = \
|
||||
self.register(models=model, optimizers=optimizer,
|
||||
criterion=criterion,
|
||||
scheduler=scheduler)
|
||||
self.register_data(train_loader=train_loader, validation_loader=val_loader)
|
||||
# __torch_operator_end__
|
||||
|
||||
# __torch_model_start__
|
||||
import torch.nn as nn
|
||||
|
||||
@@ -103,6 +144,21 @@ def scheduler_creator(optimizer, config):
|
||||
|
||||
# __torch_scheduler_end__
|
||||
|
||||
# __backwards_compat__start
|
||||
from ray.util.sgd import TorchTrainer
|
||||
|
||||
MyTrainingOperator = TrainingOperator.from_creators(
|
||||
model_creator=model_creator, optimizer_creator=optimizer_creator,
|
||||
loss_creator=loss_creator, scheduler_creator=scheduler_creator,
|
||||
data_creator=data_creator)
|
||||
|
||||
trainer = TorchTrainer(
|
||||
training_operator_cls=MyTrainingOperator,
|
||||
scheduler_step_freq="epoch", # if scheduler_creator is passed in
|
||||
config={"lr": 0.001, "batch_size": 64})
|
||||
|
||||
# __backwards_compat_end
|
||||
|
||||
# __torch_ray_start__
|
||||
import ray
|
||||
|
||||
@@ -114,12 +170,8 @@ ray.init()
|
||||
from ray.util.sgd import TorchTrainer
|
||||
|
||||
trainer = TorchTrainer(
|
||||
model_creator=model_creator,
|
||||
data_creator=data_creator,
|
||||
optimizer_creator=optimizer_creator,
|
||||
loss_creator=nn.MSELoss,
|
||||
scheduler_creator=scheduler_creator,
|
||||
scheduler_step_freq="epoch", # if scheduler_creator is set
|
||||
training_operator_cls=MyTrainingOperator,
|
||||
scheduler_step_freq="epoch", # if scheduler is used
|
||||
config={"lr": 0.001, "batch_size": 64})
|
||||
|
||||
# __torch_trainer_end__
|
||||
|
||||
@@ -4,6 +4,7 @@ import time
|
||||
|
||||
import torch
|
||||
import torch.utils.data
|
||||
from filelock import FileLock
|
||||
from torch import nn
|
||||
import torchvision
|
||||
|
||||
@@ -50,28 +51,6 @@ def get_dataset(name,
|
||||
return ds, num_classes
|
||||
|
||||
|
||||
def data_creator(config):
|
||||
# Within a machine, this code runs synchronously.
|
||||
dataset, num_classes = get_dataset(
|
||||
args.dataset, "train", get_transform(train=True))
|
||||
config["num_classes"] = num_classes
|
||||
dataset_test, _ = get_dataset(
|
||||
args.dataset, "val", get_transform(train=False))
|
||||
data_loader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
batch_size=args.batch_size,
|
||||
num_workers=args.data_workers,
|
||||
collate_fn=utils.collate_fn,
|
||||
drop_last=True)
|
||||
|
||||
data_loader_test = torch.utils.data.DataLoader(
|
||||
dataset_test,
|
||||
batch_size=1,
|
||||
num_workers=args.data_workers,
|
||||
collate_fn=utils.collate_fn)
|
||||
return data_loader, data_loader_test
|
||||
|
||||
|
||||
def get_transform(train):
|
||||
base_size = 520
|
||||
crop_size = 480
|
||||
@@ -101,7 +80,75 @@ def criterion(inputs, target):
|
||||
return losses["out"] + 0.5 * losses["aux"]
|
||||
|
||||
|
||||
def get_optimizer(model, aux_loss):
|
||||
params_to_optimize = [
|
||||
{
|
||||
"params": [
|
||||
p for p in model.backbone.parameters() if p.requires_grad
|
||||
]
|
||||
},
|
||||
{
|
||||
"params": [
|
||||
p for p in model.classifier.parameters() if p.requires_grad
|
||||
]
|
||||
},
|
||||
]
|
||||
if aux_loss:
|
||||
params = [
|
||||
p for p in model.aux_classifier.parameters() if p.requires_grad
|
||||
]
|
||||
params_to_optimize.append({"params": params, "lr": args.lr * 10})
|
||||
optimizer = torch.optim.SGD(
|
||||
params_to_optimize,
|
||||
lr=args.lr,
|
||||
momentum=args.momentum,
|
||||
weight_decay=args.weight_decay)
|
||||
return optimizer
|
||||
|
||||
|
||||
class SegOperator(TrainingOperator):
|
||||
def setup(self, config):
|
||||
args = config["args"]
|
||||
# Create Data Loaders.
|
||||
with FileLock(".ray.lock"):
|
||||
# Within a machine, this code runs synchronously.
|
||||
dataset, num_classes = get_dataset(
|
||||
args.dataset, "train", get_transform(train=True))
|
||||
config["num_classes"] = num_classes
|
||||
dataset_test, _ = get_dataset(
|
||||
args.dataset, "val", get_transform(train=False))
|
||||
|
||||
data_loader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
batch_size=args.batch_size,
|
||||
num_workers=args.data_workers,
|
||||
collate_fn=utils.collate_fn,
|
||||
drop_last=True)
|
||||
|
||||
data_loader_test = torch.utils.data.DataLoader(
|
||||
dataset_test,
|
||||
batch_size=1,
|
||||
num_workers=args.data_workers,
|
||||
collate_fn=utils.collate_fn)
|
||||
|
||||
# Create model.
|
||||
model = torchvision.models.segmentation.__dict__[args.model](
|
||||
num_classes=config["num_classes"],
|
||||
aux_loss=args.aux_loss,
|
||||
pretrained=args.pretrained)
|
||||
if config["num_workers"] > 1:
|
||||
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
|
||||
|
||||
# Create optimizer.
|
||||
optimizer = get_optimizer(model, aux_loss=args.aux_loss)
|
||||
|
||||
# Register components.
|
||||
self.model, self.optimizer = self.register(
|
||||
models=model,
|
||||
optimizers=optimizer,
|
||||
train_loader=data_loader,
|
||||
validation_loader=data_loader_test)
|
||||
|
||||
def train_batch(self, batch, batch_info):
|
||||
image, target = batch
|
||||
image, target = image.to(self.device), target.to(self.device)
|
||||
@@ -132,43 +179,6 @@ class SegOperator(TrainingOperator):
|
||||
return confmat
|
||||
|
||||
|
||||
def optimizer_creator(model, config):
|
||||
args = config["args"]
|
||||
params_to_optimize = [
|
||||
{
|
||||
"params": [
|
||||
p for p in model.backbone.parameters() if p.requires_grad
|
||||
]
|
||||
},
|
||||
{
|
||||
"params": [
|
||||
p for p in model.classifier.parameters() if p.requires_grad
|
||||
]
|
||||
},
|
||||
]
|
||||
if args.aux_loss:
|
||||
params = [
|
||||
p for p in model.aux_classifier.parameters() if p.requires_grad
|
||||
]
|
||||
params_to_optimize.append({"params": params, "lr": args.lr * 10})
|
||||
return torch.optim.SGD(
|
||||
params_to_optimize,
|
||||
lr=args.lr,
|
||||
momentum=args.momentum,
|
||||
weight_decay=args.weight_decay)
|
||||
|
||||
|
||||
def model_creator(config):
|
||||
args = config["args"]
|
||||
model = torchvision.models.segmentation.__dict__[args.model](
|
||||
num_classes=config["num_classes"],
|
||||
aux_loss=args.aux_loss,
|
||||
pretrained=args.pretrained)
|
||||
if config["num_workers"] > 1:
|
||||
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
|
||||
return model
|
||||
|
||||
|
||||
def main(args):
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
@@ -176,9 +186,6 @@ def main(args):
|
||||
start_time = time.time()
|
||||
config = {"args": args, "num_workers": args.num_workers}
|
||||
trainer = TorchTrainer(
|
||||
model_creator=model_creator,
|
||||
data_creator=data_creator,
|
||||
optimizer_creator=optimizer_creator,
|
||||
training_operator_cls=SegOperator,
|
||||
use_tqdm=True,
|
||||
use_fp16=True,
|
||||
|
||||
@@ -13,6 +13,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ray.util.sgd import TorchTrainer
|
||||
from ray.util.sgd.torch import TrainingOperator
|
||||
|
||||
|
||||
class LinearDataset(torch.utils.data.Dataset):
|
||||
@@ -67,12 +68,12 @@ def data_creator(config):
|
||||
|
||||
|
||||
def train_example(num_workers=1, use_gpu=False):
|
||||
CustomTrainingOperator = TrainingOperator.from_creators(
|
||||
model_creator=model_creator, optimizer_creator=optimizer_creator,
|
||||
data_creator=data_creator, scheduler_creator=scheduler_creator,
|
||||
loss_creator=nn.MSELoss)
|
||||
trainer1 = TorchTrainer(
|
||||
model_creator=model_creator,
|
||||
data_creator=data_creator,
|
||||
optimizer_creator=optimizer_creator,
|
||||
loss_creator=nn.MSELoss,
|
||||
scheduler_creator=scheduler_creator,
|
||||
training_operator_cls=CustomTrainingOperator,
|
||||
num_workers=num_workers,
|
||||
use_gpu=use_gpu,
|
||||
config={
|
||||
|
||||
@@ -92,81 +92,76 @@ def announce_training(args, dataset_len, t_total):
|
||||
logger.info(" Total optimization steps = %d", t_total)
|
||||
|
||||
|
||||
def model_creator(config):
|
||||
with FileLock(os.path.expanduser("~/.download.lock")):
|
||||
args = config["args"]
|
||||
processor = processors[args.task_name]()
|
||||
label_list = processor.get_labels()
|
||||
num_labels = len(label_list)
|
||||
config = AutoConfig.from_pretrained(
|
||||
args.config_name if args.config_name else args.model_name_or_path,
|
||||
num_labels=num_labels,
|
||||
finetuning_task=args.task_name,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||
)
|
||||
model = AutoModelForSequenceClassification.from_pretrained(
|
||||
args.model_name_or_path,
|
||||
from_tf=bool(".ckpt" in args.model_name_or_path),
|
||||
config=config,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||
)
|
||||
return model
|
||||
|
||||
|
||||
def optimizer_creator(model, cfg):
|
||||
args = cfg["args"]
|
||||
no_decay = ["bias", "LayerNorm.weight"]
|
||||
optimizer_grouped_parameters = [
|
||||
{
|
||||
"params": [
|
||||
p for n, p in model.named_parameters()
|
||||
if not any(nd in n for nd in no_decay)
|
||||
],
|
||||
"weight_decay": args.weight_decay,
|
||||
},
|
||||
{
|
||||
"params": [
|
||||
p for n, p in model.named_parameters()
|
||||
if any(nd in n for nd in no_decay)
|
||||
],
|
||||
"weight_decay": 0.0
|
||||
},
|
||||
]
|
||||
|
||||
return AdamW(
|
||||
optimizer_grouped_parameters,
|
||||
lr=args.learning_rate,
|
||||
eps=args.adam_epsilon)
|
||||
|
||||
|
||||
def data_creator(config):
|
||||
args = config["args"]
|
||||
start = time.time()
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
args.tokenizer_name
|
||||
if args.tokenizer_name else args.model_name_or_path,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||
)
|
||||
logger.info(f"tokenizer instantiation time: {time.time() - start}")
|
||||
|
||||
train_dataset = load_and_cache_examples(
|
||||
args, args.task_name, tokenizer, evaluate=False)
|
||||
train_sampler = RandomSampler(
|
||||
train_dataset) if not dist.is_initialized() else None
|
||||
return DataLoader(
|
||||
train_dataset,
|
||||
sampler=train_sampler,
|
||||
batch_size=args.per_device_train_batch_size)
|
||||
|
||||
|
||||
class TransformerOperator(TrainingOperator):
|
||||
def setup(self, config):
|
||||
self.args = args = config["args"]
|
||||
start = time.time()
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(
|
||||
args.tokenizer_name
|
||||
if args.tokenizer_name else args.model_name_or_path,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||
)
|
||||
logger.info(f"tokenizer instantiation time: {time.time() - start}")
|
||||
|
||||
# Load data.
|
||||
train_dataset = load_and_cache_examples(
|
||||
args, args.task_name, self.tokenizer, evaluate=False)
|
||||
train_sampler = RandomSampler(
|
||||
train_dataset) if not dist.is_initialized() else None
|
||||
train_loader = DataLoader(
|
||||
train_dataset,
|
||||
sampler=train_sampler,
|
||||
batch_size=args.per_device_train_batch_size)
|
||||
|
||||
# Create model.
|
||||
with FileLock(os.path.expanduser("~/.download.lock")):
|
||||
processor = processors[args.task_name]()
|
||||
label_list = processor.get_labels()
|
||||
num_labels = len(label_list)
|
||||
model_config = AutoConfig.from_pretrained(
|
||||
args.config_name
|
||||
if args.config_name else args.model_name_or_path,
|
||||
num_labels=num_labels,
|
||||
finetuning_task=args.task_name,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||
)
|
||||
model = AutoModelForSequenceClassification.from_pretrained(
|
||||
args.model_name_or_path,
|
||||
from_tf=bool(".ckpt" in args.model_name_or_path),
|
||||
config=model_config,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||
)
|
||||
|
||||
# Create optimizer.
|
||||
no_decay = ["bias", "LayerNorm.weight"]
|
||||
optimizer_grouped_parameters = [
|
||||
{
|
||||
"params": [
|
||||
p for n, p in model.named_parameters()
|
||||
if not any(nd in n for nd in no_decay)
|
||||
],
|
||||
"weight_decay": args.weight_decay,
|
||||
},
|
||||
{
|
||||
"params": [
|
||||
p for n, p in model.named_parameters()
|
||||
if any(nd in n for nd in no_decay)
|
||||
],
|
||||
"weight_decay": 0.0
|
||||
},
|
||||
]
|
||||
|
||||
optimizer = AdamW(
|
||||
optimizer_grouped_parameters,
|
||||
lr=args.learning_rate,
|
||||
eps=args.adam_epsilon)
|
||||
|
||||
# Register components.
|
||||
self.model, self.optimizer = self.register(
|
||||
models=model,
|
||||
optimizers=optimizer,
|
||||
train_loader=train_loader,
|
||||
validation_loader=None)
|
||||
|
||||
self.train_data_len = len(self.train_loader)
|
||||
self._warmup_scheduler = get_linear_schedule_with_warmup(
|
||||
@@ -334,9 +329,6 @@ def main():
|
||||
# Training
|
||||
|
||||
trainer = TorchTrainer(
|
||||
model_creator=model_creator,
|
||||
data_creator=data_creator,
|
||||
optimizer_creator=optimizer_creator,
|
||||
training_operator_cls=TransformerOperator,
|
||||
use_fp16=args.fp16,
|
||||
apex_args={"opt_level": args.fp16_opt_level},
|
||||
|
||||
@@ -36,6 +36,7 @@ def load_and_cache_examples(args, task, tokenizer, evaluate=False):
|
||||
),
|
||||
)
|
||||
|
||||
# Use FileLock to prevent parallel writes that may corrupt data.
|
||||
with FileLock("/tmp/load_and_cache_examples.lock"):
|
||||
if os.path.exists(cached_features_file) and not args.overwrite_cache:
|
||||
logger.info("Loading features from cached file %s",
|
||||
|
||||
@@ -13,7 +13,7 @@ from torch.utils.data import DataLoader
|
||||
|
||||
import ray
|
||||
from ray import tune
|
||||
from ray.util.sgd.torch import TorchTrainer
|
||||
from ray.util.sgd.torch import TorchTrainer, TrainingOperator
|
||||
from ray.util.sgd.utils import BATCH_SIZE
|
||||
from ray.util.sgd.torch.examples.train_example import LinearDataset
|
||||
|
||||
@@ -37,12 +37,9 @@ def data_creator(config):
|
||||
|
||||
|
||||
# __torch_tune_example__
|
||||
def tune_example(num_workers=1, use_gpu=False):
|
||||
def tune_example(operator_cls, num_workers=1, use_gpu=False):
|
||||
TorchTrainable = TorchTrainer.as_trainable(
|
||||
model_creator=model_creator,
|
||||
data_creator=data_creator,
|
||||
optimizer_creator=optimizer_creator,
|
||||
loss_creator=nn.MSELoss, # Note that we specify a Loss class.
|
||||
training_operator_cls=operator_cls,
|
||||
num_workers=num_workers,
|
||||
use_gpu=use_gpu,
|
||||
config={BATCH_SIZE: 128}
|
||||
@@ -81,4 +78,8 @@ if __name__ == "__main__":
|
||||
args, _ = parser.parse_known_args()
|
||||
|
||||
ray.init(address=args.address)
|
||||
tune_example(num_workers=args.num_workers, use_gpu=args.use_gpu)
|
||||
CustomTrainingOperator = TrainingOperator.from_creators(
|
||||
model_creator=model_creator, optimizer_creator=optimizer_creator,
|
||||
data_creator=data_creator, loss_creator=nn.MSELoss)
|
||||
tune_example(CustomTrainingOperator, num_workers=args.num_workers,
|
||||
use_gpu=args.use_gpu)
|
||||
|
||||
@@ -1,26 +1,15 @@
|
||||
from filelock import FileLock
|
||||
import logging
|
||||
import inspect
|
||||
import io
|
||||
import itertools
|
||||
import os
|
||||
import tempfile
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
import ray
|
||||
from ray.util.sgd.torch.constants import USE_FP16, SCHEDULER_STEP, NUM_STEPS
|
||||
from ray.util.sgd.torch.training_operator import TrainingOperator
|
||||
from ray.util.sgd.torch.constants import USE_FP16, NUM_STEPS
|
||||
from ray.util.sgd import utils
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
amp = None
|
||||
|
||||
try:
|
||||
from collections.abc import Iterable
|
||||
except ImportError:
|
||||
from collections import Iterable
|
||||
|
||||
try:
|
||||
from apex import amp
|
||||
except ImportError:
|
||||
@@ -32,12 +21,7 @@ class TorchRunner:
|
||||
"""Manages a PyTorch model for training."""
|
||||
|
||||
def __init__(self,
|
||||
model_creator,
|
||||
data_creator,
|
||||
optimizer_creator,
|
||||
loss_creator=None,
|
||||
scheduler_creator=None,
|
||||
training_operator_cls=None,
|
||||
training_operator_cls,
|
||||
config=None,
|
||||
use_gpu=False,
|
||||
serialize_data_creation=True,
|
||||
@@ -45,22 +29,11 @@ class TorchRunner:
|
||||
use_tqdm=False,
|
||||
apex_args=None,
|
||||
scheduler_step_freq=None):
|
||||
self.model_creator = model_creator
|
||||
self.optimizer_creator = optimizer_creator
|
||||
self.loss_creator = loss_creator
|
||||
self.data_creator = data_creator
|
||||
self.scheduler_creator = scheduler_creator
|
||||
self.training_operator_cls = training_operator_cls or TrainingOperator
|
||||
self.training_operator_cls = training_operator_cls
|
||||
self.config = {} if config is None else config
|
||||
|
||||
self.timers = utils.TimerCollection()
|
||||
self.epochs = 0
|
||||
self.models = None
|
||||
self.optimizers = None
|
||||
self.criterion = None
|
||||
self.schedulers = None
|
||||
self.train_loader = None
|
||||
self.validation_loader = None
|
||||
self.training_operator = None
|
||||
self.serialize_data_creation = serialize_data_creation
|
||||
self.use_gpu = use_gpu
|
||||
@@ -73,107 +46,16 @@ class TorchRunner:
|
||||
"https://www.github.com/nvidia/apex to use fp16 training.")
|
||||
self.scheduler_step_freq = scheduler_step_freq
|
||||
|
||||
def _validate_loaders(self, loaders):
|
||||
assert loaders, "Loaders need to be returned in data_creator."
|
||||
if isinstance(loaders, (tuple, list)):
|
||||
if len(loaders) == 1:
|
||||
return loaders, None
|
||||
elif len(loaders) == 2:
|
||||
return loaders
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Number of loaders must be <= 2. Got {loaders}")
|
||||
# No great way of checking type otherwise
|
||||
return loaders, None
|
||||
|
||||
def _initialize_dataloaders(self):
|
||||
logger.debug("Instantiating dataloaders.")
|
||||
loaders = None
|
||||
if self.serialize_data_creation:
|
||||
logger.debug("Serializing the dataloading process.")
|
||||
with FileLock(
|
||||
os.path.join(tempfile.gettempdir(), ".raydata.lock")):
|
||||
loaders = self.data_creator(self.config)
|
||||
else:
|
||||
loaders = self.data_creator(self.config)
|
||||
train_loader, val_loader = self._validate_loaders(loaders)
|
||||
|
||||
self.train_loader, self.validation_loader = train_loader, val_loader
|
||||
|
||||
def _create_loss(self):
|
||||
if not self.loss_creator:
|
||||
return
|
||||
logger.debug("Creating loss.")
|
||||
if inspect.isclass(self.loss_creator) and issubclass(
|
||||
self.loss_creator, torch.nn.modules.loss._Loss):
|
||||
self.criterion = self.loss_creator()
|
||||
else:
|
||||
self.criterion = self.loss_creator(self.config)
|
||||
|
||||
if self.use_gpu and torch.cuda.is_available():
|
||||
if hasattr(self.criterion, "cuda"):
|
||||
self.criterion = self.criterion.cuda()
|
||||
|
||||
def _create_schedulers_if_available(self):
|
||||
# Learning rate schedules are optional.
|
||||
if not self.scheduler_creator:
|
||||
return
|
||||
self.schedulers = self.scheduler_creator(self.given_optimizers,
|
||||
self.config)
|
||||
|
||||
if not isinstance(self.schedulers, Iterable):
|
||||
self.schedulers = [self.schedulers]
|
||||
|
||||
def _try_setup_apex(self):
|
||||
"""Sets up the model for fp16 training via apex if available."""
|
||||
if self.use_fp16 and amp:
|
||||
self.models, self.optimizers = amp.initialize(
|
||||
self.models, self.optimizers, **self.apex_args)
|
||||
|
||||
def setup(self):
|
||||
"""Merges setup_components and setup_operator in one call."""
|
||||
self.setup_components()
|
||||
self.setup_operator()
|
||||
|
||||
def setup_components(self):
|
||||
"""Runs the creator functions without any distributed coordination."""
|
||||
logger.debug("Loading data.")
|
||||
if self.data_creator and callable(self.data_creator):
|
||||
self._initialize_dataloaders()
|
||||
|
||||
logger.debug("Creating model")
|
||||
self.models = self.model_creator(self.config)
|
||||
if not isinstance(self.models, Iterable):
|
||||
self.models = [self.models]
|
||||
assert all(isinstance(model, nn.Module) for model in self.models), (
|
||||
f"All models must be PyTorch models: {self.models}.")
|
||||
if self.use_gpu and torch.cuda.is_available():
|
||||
self.models = [model.cuda() for model in self.models]
|
||||
|
||||
logger.debug("Creating optimizer.")
|
||||
self.optimizers = self.optimizer_creator(self.given_models,
|
||||
self.config)
|
||||
if not isinstance(self.optimizers, Iterable):
|
||||
self.optimizers = [self.optimizers]
|
||||
|
||||
self._create_schedulers_if_available()
|
||||
self._try_setup_apex()
|
||||
self._create_loss()
|
||||
|
||||
def setup_operator(self):
|
||||
"""Create the training operator."""
|
||||
self.training_operator = self.training_operator_cls(
|
||||
self.config,
|
||||
models=self.models,
|
||||
optimizers=self.optimizers,
|
||||
criterion=self.criterion,
|
||||
train_loader=self.train_loader,
|
||||
validation_loader=self.validation_loader,
|
||||
world_rank=0,
|
||||
schedulers=self.schedulers,
|
||||
use_gpu=self.use_gpu,
|
||||
use_fp16=self.use_fp16,
|
||||
use_tqdm=self.use_tqdm)
|
||||
use_tqdm=self.use_tqdm,
|
||||
apex_args=self.apex_args,
|
||||
scheduler_step_freq=self.scheduler_step_freq)
|
||||
|
||||
def get_node_ip(self):
|
||||
"""Returns the IP address of the current node."""
|
||||
@@ -196,7 +78,6 @@ class TorchRunner:
|
||||
info.update({
|
||||
NUM_STEPS: num_steps,
|
||||
USE_FP16: self.use_fp16,
|
||||
SCHEDULER_STEP: self.scheduler_step_freq
|
||||
})
|
||||
with self.timers.record("train_epoch"):
|
||||
if iterator is None:
|
||||
@@ -223,15 +104,17 @@ class TorchRunner:
|
||||
def validate(self, num_steps=None, profile=False, info=None):
|
||||
"""Evaluates the model on the validation data set."""
|
||||
if self.validation_loader is None:
|
||||
raise ValueError("No validation dataloader provided.")
|
||||
raise ValueError("No validation dataloader provided. Make sure"
|
||||
"you pass in a validation_loader to "
|
||||
"TrainingOperator.register_data.")
|
||||
info = info or {}
|
||||
self._toggle_profiling(profile=profile)
|
||||
validation_loader = self.validation_loader
|
||||
|
||||
with self.timers.record("validation"):
|
||||
iterator = self.validation_loader
|
||||
iterator = validation_loader
|
||||
if num_steps:
|
||||
iterator = itertools.islice(
|
||||
iter(self.validation_loader), num_steps)
|
||||
iterator = itertools.islice(iterator, num_steps)
|
||||
validation_stats = self.training_operator.validate(
|
||||
iterator, info=info)
|
||||
if profile:
|
||||
@@ -255,32 +138,35 @@ class TorchRunner:
|
||||
"models": [model.state_dict() for model in self.models],
|
||||
"optimizers": [opt.state_dict() for opt in self.optimizers]
|
||||
}
|
||||
if self.schedulers:
|
||||
schedulers = self.schedulers
|
||||
if schedulers:
|
||||
state.update({
|
||||
"schedulers": [
|
||||
scheduler.state_dict() for scheduler in self.schedulers
|
||||
scheduler.state_dict() for scheduler in schedulers
|
||||
]
|
||||
})
|
||||
# Check if fp16 is True and if NVIDIA Apex is imported.
|
||||
if self.use_fp16 and amp:
|
||||
state.update({"amp": amp.state_dict()})
|
||||
if self.use_fp16 and self.training_operator._amp:
|
||||
state.update({"amp": self.training_operator._amp.state_dict()})
|
||||
return state
|
||||
|
||||
def load_state_dict(self, state):
|
||||
"""Sets the state of the model."""
|
||||
for model, state_dict in zip(self.models, state["models"]):
|
||||
models = self.models
|
||||
for model, state_dict in zip(models, state["models"]):
|
||||
model.load_state_dict(state_dict)
|
||||
for optimizer, state_dict in zip(self.optimizers, state["optimizers"]):
|
||||
optimizers = self.optimizers
|
||||
for optimizer, state_dict in zip(optimizers, state["optimizers"]):
|
||||
optimizer.load_state_dict(state_dict)
|
||||
if self.schedulers:
|
||||
for scheduler, state_dict in zip(self.schedulers,
|
||||
state["schedulers"]):
|
||||
schedulers = self.schedulers
|
||||
if schedulers:
|
||||
for scheduler, state_dict in zip(schedulers, state["schedulers"]):
|
||||
scheduler.load_state_dict(state_dict)
|
||||
|
||||
if self.use_fp16 and "amp" in state and amp:
|
||||
amp.load_state_dict(state["amp"])
|
||||
if self.use_fp16 and "amp" in state and self.training_operator._amp:
|
||||
self.training_operator._amp.load_state_dict(state["amp"])
|
||||
self.epochs = state["epoch"]
|
||||
self.training_operator.load_state_dict(state_dict)
|
||||
self.training_operator.load_state_dict(state["operator"])
|
||||
|
||||
def state_stream(self):
|
||||
"""Returns a bytes object for the state dict."""
|
||||
@@ -304,14 +190,67 @@ class TorchRunner:
|
||||
def shutdown(self):
|
||||
"""Attempts to shut down the worker."""
|
||||
del self.training_operator
|
||||
del self.validation_loader
|
||||
del self.train_loader
|
||||
del self.criterion
|
||||
del self.optimizers
|
||||
del self.models
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@property
|
||||
def models(self):
|
||||
if not hasattr(self.training_operator, "_original_models"):
|
||||
raise RuntimeError("Training Operator does not have any "
|
||||
"registered models. Are you calling "
|
||||
"self.register(...) inside the setup method "
|
||||
"of your Training Operator?")
|
||||
return self.training_operator._original_models
|
||||
|
||||
@property
|
||||
def optimizers(self):
|
||||
if not hasattr(self.training_operator, "_optimizers"):
|
||||
raise RuntimeError("Training Operator does not have any "
|
||||
"registered optimizers. Are you calling "
|
||||
"self.register(...) inside the setup method "
|
||||
"of your Training Operator?")
|
||||
return self.training_operator._optimizers
|
||||
|
||||
@property
|
||||
def schedulers(self):
|
||||
if not hasattr(self.training_operator, "_schedulers"):
|
||||
raise RuntimeError("Training Operator does not have any "
|
||||
"registered schedulers. Are you calling "
|
||||
"self.register(...) inside the setup method "
|
||||
"of your Training Operator?")
|
||||
return self.training_operator._schedulers
|
||||
|
||||
@property
|
||||
def train_loader(self):
|
||||
if not hasattr(self.training_operator, "_train_loader"):
|
||||
logger.warning("Training Operator does not have any "
|
||||
"registered train loader. If this is "
|
||||
"unexepected, make sure to call "
|
||||
"self.register_data(...) inside the setup method "
|
||||
"of your Training Operator.")
|
||||
return None
|
||||
return self.training_operator._train_loader
|
||||
|
||||
@property
|
||||
def validation_loader(self):
|
||||
if not hasattr(self.training_operator, "_validation_loader"):
|
||||
logger.warning("Training Operator does not have any "
|
||||
"registered validation loader. If this is "
|
||||
"unexepected, make sure to call "
|
||||
"self.register_data(...) inside the setup method "
|
||||
"of your Training Operator.")
|
||||
return None
|
||||
return self.training_operator._validation_loader
|
||||
|
||||
@property
|
||||
def criterion(self):
|
||||
if not hasattr(self.training_operator, "_criterion"):
|
||||
raise RuntimeError("Training Operator does not have any "
|
||||
"registered criterion. Are you calling "
|
||||
"self.register(...) inside the setup method "
|
||||
"of your Training Operator?")
|
||||
return self.training_operator._criterion
|
||||
|
||||
@property
|
||||
def given_models(self):
|
||||
if len(self.models) > 1:
|
||||
|
||||
@@ -13,6 +13,7 @@ from ray.exceptions import RayActorError
|
||||
from ray.tune import Trainable
|
||||
from ray.tune.resources import Resources
|
||||
from ray.tune.utils.util import merge_dicts
|
||||
from ray.util import log_once
|
||||
from ray.util.sgd.torch.distributed_torch_runner import (
|
||||
DistributedTorchRunner, LocalDistributedRunner, DeactivatedRunner)
|
||||
from ray.util.sgd.utils import check_for_failure, NUM_SAMPLES, BATCH_SIZE
|
||||
@@ -49,80 +50,44 @@ class TorchTrainer:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
def model_creator(config):
|
||||
return nn.Linear(1, 1)
|
||||
class MyTrainingOperator(TrainingOperator):
|
||||
|
||||
def setup(self, config):
|
||||
model = nn.Linear(1, 1)
|
||||
optimizer = torch.optim.SGD(
|
||||
model.parameters(), lr=config.get("lr", 1e-4))
|
||||
loss = torch.nn.MSELoss()
|
||||
|
||||
def optimizer_creator(model, config):
|
||||
return torch.optim.SGD(
|
||||
model.parameters(), lr=config.get("lr", 1e-4))
|
||||
batch_size = config["batch_size"]
|
||||
train_data, val_data = LinearDataset(2, 5), LinearDataset(2, 5)
|
||||
train_loader = DataLoader(train_data, batch_size=batch_size)
|
||||
val_loader = DataLoader(val_data, batch_size=batch_size)
|
||||
|
||||
self.model, self.optimizer = self.register(
|
||||
models=model,
|
||||
optimizers=optimizer,
|
||||
criterion=loss)
|
||||
|
||||
def data_creator(config):
|
||||
batch_size = config["batch_size"]
|
||||
train_data, val_data = LinearDataset(2, 5), LinearDataset(2, 5)
|
||||
train_loader = DataLoader(train_data, batch_size=batch_size)
|
||||
val_loader = DataLoader(val_data, batch_size=batch_size)
|
||||
return train_loader, val_loader
|
||||
self.register_data(
|
||||
train_loader=train_loader,
|
||||
validation_loader=val_loader)
|
||||
|
||||
|
||||
trainer = TorchTrainer(
|
||||
model_creator=model_creator,
|
||||
data_creator=data_creator,
|
||||
optimizer_creator=optimizer_creator,
|
||||
loss_creator=nn.MSELoss,
|
||||
training_operator_cls=MyTrainingOperator,
|
||||
config={"batch_size": 32},
|
||||
use_gpu=True
|
||||
)
|
||||
for i in range(4):
|
||||
trainer.train()
|
||||
|
||||
The creator functions will execute before distributed coordination and
|
||||
training is setup. This is so that creator functions that download
|
||||
large datasets will not trigger any timeouts.
|
||||
|
||||
The order of operations for creator functions are:
|
||||
|
||||
``data_creator`` -> ``model_creator`` -> ``optimizer_creator`` ->
|
||||
``scheduler_creator`` -> ``loss_creator``.
|
||||
|
||||
Args:
|
||||
model_creator (dict -> Model(s)): Constructor function that takes in
|
||||
config and returns the model(s) to be optimized. These must be
|
||||
``torch.nn.Module`` objects. If multiple models are returned,
|
||||
a ``training_operator_cls`` must be specified. You do not need to
|
||||
handle GPU/devices in this function; RaySGD will do that under
|
||||
the hood.
|
||||
data_creator (dict -> Iterable(s)): Constructor function
|
||||
that takes in the passed config and returns one or
|
||||
two Iterable objects. Note that even though two Iterable objects
|
||||
can be returned, only one will be used for training, and the
|
||||
other will be used for validation. If not provided, you must
|
||||
provide a custom TrainingOperator.
|
||||
optimizer_creator ((models, dict) -> optimizers): Constructor
|
||||
function that takes in the return values from
|
||||
``model_creator`` and the passed config and returns One or
|
||||
more Torch optimizer objects. You do not need to handle
|
||||
GPU/devices in this function; ``RaySGD`` will do that for you.
|
||||
loss_creator (torch.nn.*Loss class | dict -> loss): A constructor
|
||||
function for the training loss. This can be either a function that
|
||||
takes in the provided config for customization or a subclass
|
||||
of ``torch.nn.modules.loss._Loss``, which is most Pytorch
|
||||
loss classes. For example, ``loss_creator=torch.nn.BCELoss``.
|
||||
If not provided, you must provide a custom TrainingOperator.
|
||||
scheduler_creator ((optimizers, dict) -> scheduler):
|
||||
A constructor function for the torch scheduler. This is
|
||||
a function that takes in the generated optimizers (from
|
||||
``optimizer_creator``) provided config for customization.
|
||||
Be sure to set ``scheduler_step_freq`` to increment the
|
||||
scheduler correctly.
|
||||
training_operator_cls (type): Custom training operator class
|
||||
that subclasses the TrainingOperator class. This class
|
||||
will be copied onto all remote workers and used to specify
|
||||
custom training and validation operations. Defaults to
|
||||
TrainingOperator.
|
||||
training components and custom training and validation operations.
|
||||
config (dict): Custom configuration value to be passed to
|
||||
all creator and operator constructors.
|
||||
all operator constructors.
|
||||
num_workers (int): the number of workers used in distributed
|
||||
training. If 1, the worker will not be wrapped with
|
||||
DistributedDataParallel.
|
||||
@@ -134,10 +99,6 @@ class TorchTrainer:
|
||||
support "nccl", "gloo", and "auto". If "auto", RaySGD will
|
||||
automatically use "nccl" if `use_gpu` is True, and "gloo"
|
||||
otherwise.
|
||||
serialize_data_creation (bool): A filelock will be used
|
||||
to ensure no race conditions in data downloading among
|
||||
different workers on the same node (using the local file system).
|
||||
Defaults to True.
|
||||
wrap_ddp (bool): Whether to automatically wrap DistributedDataParallel
|
||||
over each model. If False, you are expected to call it yourself.
|
||||
timeout_s (float): Seconds before the torch process group
|
||||
@@ -171,12 +132,7 @@ class TorchTrainer:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
model_creator,
|
||||
data_creator,
|
||||
optimizer_creator,
|
||||
loss_creator=None,
|
||||
scheduler_creator=None,
|
||||
training_operator_cls=None,
|
||||
training_operator_cls,
|
||||
initialization_hook=None,
|
||||
config=None,
|
||||
num_workers=1,
|
||||
@@ -185,16 +141,33 @@ class TorchTrainer:
|
||||
backend="auto",
|
||||
wrap_ddp=True,
|
||||
timeout_s=NCCL_TIMEOUT_S,
|
||||
serialize_data_creation=True,
|
||||
use_fp16=False,
|
||||
use_tqdm=False,
|
||||
apex_args=None,
|
||||
add_dist_sampler=True,
|
||||
scheduler_step_freq=None,
|
||||
# Deprecated Args.
|
||||
num_replicas=None,
|
||||
batch_size=None,
|
||||
model_creator=None,
|
||||
data_creator=None,
|
||||
optimizer_creator=None,
|
||||
scheduler_creator=None,
|
||||
loss_creator=None,
|
||||
serialize_data_creation=None,
|
||||
data_loader_args=None,
|
||||
):
|
||||
if (model_creator or data_creator or optimizer_creator
|
||||
or scheduler_creator or loss_creator):
|
||||
raise DeprecationWarning(
|
||||
"Creator functions are deprecated. You should create a "
|
||||
"custom TrainingOperator, override setup, and register all "
|
||||
"training state there. See TrainingOperator for more info. "
|
||||
"If you would still like to use creator functions, you can "
|
||||
"do CustomOperator = TrainingOperator.from_creators("
|
||||
"model_creator, ...) and pass in CustomOperator into "
|
||||
"TorchTrainer.")
|
||||
|
||||
if num_workers > 1 and not dist.is_available():
|
||||
raise ValueError(
|
||||
("Distributed PyTorch is not supported on macOS. "
|
||||
@@ -202,10 +175,6 @@ class TorchTrainer:
|
||||
"For more information, see "
|
||||
"https://github.com/pytorch/examples/issues/467."))
|
||||
|
||||
if not (callable(model_creator) and callable(optimizer_creator)):
|
||||
raise ValueError(
|
||||
"Must provide a callable model_creator and optimizer_creator.")
|
||||
|
||||
if num_replicas is not None:
|
||||
raise DeprecationWarning(
|
||||
"num_replicas is deprecated. Use num_workers instead.")
|
||||
@@ -217,24 +186,23 @@ class TorchTrainer:
|
||||
"config={ray.util.sgd.utils.BATCH_SIZE: N} to specify a "
|
||||
"batch size to be used across all workers.")
|
||||
|
||||
if serialize_data_creation is True:
|
||||
if log_once("serialize_data_creation"):
|
||||
logging.warning(
|
||||
"serialize_data_creation is deprecated and will be "
|
||||
"ignored. If you require serialized data loading you "
|
||||
"should implement this in TrainingOperator.setup. "
|
||||
"You may find FileLock useful here.")
|
||||
|
||||
if data_loader_args:
|
||||
raise ValueError(
|
||||
raise DeprecationWarning(
|
||||
"data_loader_args is deprecated. You can return a "
|
||||
"torch.utils.data.DataLoader in data_creator. Ray will "
|
||||
"automatically set a DistributedSampler if a DataLoader is "
|
||||
"returned and num_workers > 1.")
|
||||
|
||||
self.model_creator = model_creator
|
||||
self.optimizer_creator = optimizer_creator
|
||||
self.loss_creator = loss_creator
|
||||
self.data_creator = data_creator
|
||||
self.scheduler_creator = scheduler_creator
|
||||
self.training_operator_cls = training_operator_cls
|
||||
|
||||
if not training_operator_cls and not loss_creator:
|
||||
raise ValueError("If a loss_creator is not provided, you must "
|
||||
"provide a custom training operator.")
|
||||
|
||||
self.initialization_hook = initialization_hook
|
||||
self.config = {} if config is None else config
|
||||
if use_gpu == "auto":
|
||||
@@ -269,7 +237,7 @@ class TorchTrainer:
|
||||
self.local_worker = DeactivatedRunner()
|
||||
self.remote_workers = []
|
||||
|
||||
if scheduler_creator:
|
||||
if scheduler_step_freq:
|
||||
_validate_scheduler_step_freq(scheduler_step_freq)
|
||||
|
||||
self.scheduler_step_freq = scheduler_step_freq
|
||||
@@ -309,11 +277,6 @@ class TorchTrainer:
|
||||
worker_config[BATCH_SIZE] = batch_size_per_worker
|
||||
|
||||
params = dict(
|
||||
model_creator=self.model_creator,
|
||||
data_creator=self.data_creator,
|
||||
optimizer_creator=self.optimizer_creator,
|
||||
loss_creator=self.loss_creator,
|
||||
scheduler_creator=self.scheduler_creator,
|
||||
training_operator_cls=self.training_operator_cls,
|
||||
config=worker_config,
|
||||
serialize_data_creation=self.serialize_data_creation,
|
||||
@@ -328,7 +291,7 @@ class TorchTrainer:
|
||||
self.local_worker = TorchRunner(**params)
|
||||
if self.initialization_hook:
|
||||
self.apply_all_workers(self.initialization_hook)
|
||||
self.local_worker.setup()
|
||||
self.local_worker.setup_operator()
|
||||
else:
|
||||
params.update(
|
||||
backend=self.backend,
|
||||
@@ -355,15 +318,6 @@ class TorchTrainer:
|
||||
# Compute URL for initializing distributed PyTorch
|
||||
address = setup_address()
|
||||
|
||||
# Runs the creator functions.
|
||||
remote_component_setup = [
|
||||
worker.setup_components.remote()
|
||||
for i, worker in enumerate(self.remote_workers)
|
||||
]
|
||||
self.local_worker.setup_components()
|
||||
# Get setup tasks in order to throw errors on failure
|
||||
ray.get(remote_component_setup)
|
||||
|
||||
# Setup the process group among all workers.
|
||||
remote_pgroup_setups = [
|
||||
worker.setup_process_group.remote(address, i + 1, num_workers,
|
||||
@@ -377,10 +331,10 @@ class TorchTrainer:
|
||||
|
||||
# Runs code that requires all creator functions to have run.
|
||||
remote_operator_setups = [
|
||||
worker.setup_ddp_and_operator.remote()
|
||||
worker.setup_operator.remote()
|
||||
for worker in self.remote_workers
|
||||
]
|
||||
self.local_worker.setup_ddp_and_operator()
|
||||
self.local_worker.setup_operator()
|
||||
# Get setup tasks in order to throw errors on failure
|
||||
ray.get(remote_operator_setups)
|
||||
|
||||
@@ -421,10 +375,10 @@ class TorchTrainer:
|
||||
|
||||
Returns:
|
||||
(dict | list) A dictionary of metrics for training.
|
||||
You can provide custom metrics by passing in a custom
|
||||
``training_operator_cls``. If ``reduce_results=False``,
|
||||
this will return a list of metric dictionaries whose
|
||||
length will be equal to ``num_workers``.
|
||||
You can provide custom metrics by implementing a custom
|
||||
training loop. If ``reduce_results=False``, this will return a
|
||||
list of metric dictionaries whose length will be equal to
|
||||
``num_workers``.
|
||||
"""
|
||||
assert max_retries >= 0, "`max_retries` must be non-negative."
|
||||
assert isinstance(dataset, Dataset) is not None \
|
||||
@@ -577,12 +531,12 @@ class TorchTrainer:
|
||||
return worker_stats
|
||||
|
||||
def update_scheduler(self, metric):
|
||||
"""Calls ``scheduler.step(metric)`` on all schedulers.
|
||||
"""Calls ``scheduler.step(metric)`` on all registered schedulers.
|
||||
|
||||
This is useful for lr_schedulers such as ``ReduceLROnPlateau``.
|
||||
"""
|
||||
self.apply_all_operators(
|
||||
lambda op: [sched.step(metric) for sched in op.schedulers])
|
||||
lambda op: [sched.step(metric) for sched in op._schedulers])
|
||||
|
||||
def get_model(self):
|
||||
"""Returns the learned model(s)."""
|
||||
@@ -729,10 +683,7 @@ class TorchTrainer:
|
||||
.. code-block:: python
|
||||
|
||||
TorchTrainable = TorchTrainer.as_trainable(
|
||||
model_creator=ResNet18,
|
||||
data_creator=cifar_creator,
|
||||
optimizer_creator=optimizer_creator,
|
||||
loss_creator=nn.CrossEntropyLoss,
|
||||
training_operator_cls=MyTrainingOperator,
|
||||
num_gpus=2
|
||||
)
|
||||
analysis = tune.run(
|
||||
@@ -781,10 +732,7 @@ class BaseTorchTrainable(Trainable):
|
||||
.. code-block:: python
|
||||
|
||||
TorchTrainable = TorchTrainer.as_trainable(
|
||||
model_creator=ResNet18,
|
||||
data_creator=cifar_creator,
|
||||
optimizer_creator=optimizer_creator,
|
||||
loss_creator=nn.CrossEntropyLoss,
|
||||
training_operator_cls=MyTrainingOperator,
|
||||
num_gpus=2
|
||||
)
|
||||
# TorchTrainable is subclass of BaseTorchTrainable.
|
||||
|
||||
@@ -1,10 +1,23 @@
|
||||
import inspect
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from filelock import FileLock
|
||||
|
||||
from ray.util.sgd.utils import (TimerCollection, AverageMeterCollection,
|
||||
NUM_SAMPLES)
|
||||
from ray.util.sgd.torch.constants import (SCHEDULER_STEP_EPOCH, NUM_STEPS,
|
||||
SCHEDULER_STEP_BATCH, SCHEDULER_STEP)
|
||||
from ray.util.sgd.torch.constants import (
|
||||
SCHEDULER_STEP_EPOCH,
|
||||
NUM_STEPS,
|
||||
SCHEDULER_STEP_BATCH,
|
||||
)
|
||||
from torch.nn.parallel import DistributedDataParallel
|
||||
from torch.utils.data import DistributedSampler, DataLoader, IterableDataset
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
amp = None
|
||||
|
||||
try:
|
||||
@@ -18,6 +31,7 @@ except ImportError:
|
||||
# Apex library is not installed, so we cannot enable mixed precision.
|
||||
# We don't log here because logging happens in the torch_runner,
|
||||
# where amp is initialized.
|
||||
logger.debug("apex is not installed.")
|
||||
pass
|
||||
|
||||
tqdm = None
|
||||
@@ -33,15 +47,58 @@ def _is_multiple(component):
|
||||
|
||||
|
||||
class TrainingOperator:
|
||||
"""Abstract class for custom training or validation loops.
|
||||
"""Abstract class to define training and validation state and logic.
|
||||
|
||||
The scheduler will only be called at a batch or epoch frequency, depending
|
||||
on the user parameter. Be sure to set ``scheduler_step_freq`` in
|
||||
``TorchTrainer`` to either "batch" or "epoch" to increment the scheduler
|
||||
correctly during training. If using a learning rate scheduler
|
||||
that depends on validation loss, you can use ``trainer.update_scheduler``.
|
||||
You must subclass this class and override the ``setup`` method to define
|
||||
your training components such as the model, optimizer, data, loss,
|
||||
and scheduler. When you pass this class to ``TorchTrainer``, a copy of
|
||||
this class will be made on each worker.
|
||||
|
||||
For both training and validation, there are two granularities that
|
||||
.. code-block:: python
|
||||
|
||||
class MyTrainingOperator(TrainingOperator):
|
||||
|
||||
def setup(self, config):
|
||||
model = nn.Linear(1, 1)
|
||||
optimizer = torch.optim.SGD(
|
||||
model.parameters(), lr=config.get("lr", 1e-4))
|
||||
loss = torch.nn.MSELoss()
|
||||
|
||||
batch_size = config["batch_size"]
|
||||
train_data, val_data = LinearDataset(2, 5), LinearDataset(2, 5)
|
||||
train_loader = DataLoader(train_data, batch_size=batch_size)
|
||||
val_loader = DataLoader(val_data, batch_size=batch_size)
|
||||
|
||||
self.model, self.optimizer = self.register(
|
||||
models=model,
|
||||
optimizers=optimizer,
|
||||
criterion=loss)
|
||||
|
||||
self.register_data(
|
||||
train_loader=train_loader,
|
||||
validation_loader=val_loader)
|
||||
|
||||
|
||||
trainer = TorchTrainer(
|
||||
training_operator_cls=MyTrainingOperator,
|
||||
config={"batch_size": 32},
|
||||
use_gpu=True
|
||||
)
|
||||
for i in range(4):
|
||||
trainer.train()
|
||||
|
||||
This class provides default implementations for training and validation.
|
||||
Set ``self.model``, ``self.optimizer``, and
|
||||
``self.criterion`` to leverage the default training and validation loops.
|
||||
If ``self.scheduler`` is set, it will only be called at a batch or epoch
|
||||
frequency, depending on the user parameter. Set
|
||||
``scheduler_step_freq`` in ``TorchTrainer`` to either "batch" or "epoch"
|
||||
to increment the scheduler correctly during training. If using a
|
||||
learning rate scheduler that depends on validation loss, you can use
|
||||
``trainer.update_scheduler``.
|
||||
|
||||
If you want to provide custom training and validation loops, you can do
|
||||
so using this class as well. There are two granularities that
|
||||
you can provide customization: per epoch or per batch.
|
||||
You do not need to override both.
|
||||
|
||||
@@ -49,41 +106,30 @@ class TrainingOperator:
|
||||
:scale: 80%
|
||||
:align: center
|
||||
|
||||
If you are using multiple models, optimizers, or schedulers, you must
|
||||
implement custom training and validation.
|
||||
|
||||
Raises:
|
||||
ValueError if multiple models/optimizers/schedulers are provided.
|
||||
You are expected to subclass this class if you wish
|
||||
to train over multiple models/optimizers/schedulers.
|
||||
ValueError
|
||||
You are expected to either set ``self.model``,
|
||||
``self.optimizer``, and ``self.criterion`` instance attributes in
|
||||
setup or implement custom training & validation.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
config,
|
||||
models,
|
||||
optimizers,
|
||||
train_loader,
|
||||
validation_loader,
|
||||
world_rank,
|
||||
criterion=None,
|
||||
schedulers=None,
|
||||
device_ids=None,
|
||||
use_gpu=False,
|
||||
use_fp16=False,
|
||||
use_tqdm=False):
|
||||
use_tqdm=False,
|
||||
apex_args=None,
|
||||
wrap_ddp=False,
|
||||
wrap_distributed_sampler=False,
|
||||
add_dist_sampler=False,
|
||||
scheduler_step_freq=None):
|
||||
# You are not expected to override this method.
|
||||
self._models = models # List of models
|
||||
assert isinstance(
|
||||
models,
|
||||
Iterable), (f"Components need to be iterable. Got: {type(models)}")
|
||||
self._optimizers = optimizers # List of optimizers
|
||||
assert isinstance(optimizers, Iterable), (
|
||||
f"Components need to be iterable. Got: {type(optimizers)}")
|
||||
self._train_loader = train_loader
|
||||
self._validation_loader = validation_loader
|
||||
self._world_rank = world_rank
|
||||
self._criterion = criterion
|
||||
self._schedulers = schedulers
|
||||
if schedulers:
|
||||
assert isinstance(schedulers, Iterable), (
|
||||
f"Components need to be iterable. Got: {type(schedulers)}")
|
||||
self._config = config
|
||||
self._use_fp16 = use_fp16
|
||||
self._device_ids = device_ids
|
||||
@@ -93,14 +139,12 @@ class TrainingOperator:
|
||||
raise ValueError("tqdm must be installed to use tqdm in training.")
|
||||
self._use_tqdm = use_tqdm
|
||||
self.global_step = 0
|
||||
self._apex_args = apex_args if apex_args else {}
|
||||
self._wrap_ddp = wrap_ddp
|
||||
self._wrap_distributed_sampler = wrap_distributed_sampler
|
||||
self._add_dist_sampler = add_dist_sampler
|
||||
self._scheduler_step_freq = scheduler_step_freq
|
||||
|
||||
if type(self) is TrainingOperator:
|
||||
for component in (models, schedulers, optimizers):
|
||||
if _is_multiple(component):
|
||||
raise ValueError(
|
||||
"Need to provide a custom operator subclassing "
|
||||
"TrainingOperator if using multi-scheduler, "
|
||||
"multi-model or multi-optimizer training/validation.")
|
||||
self.timers = TimerCollection()
|
||||
self.setup(config)
|
||||
|
||||
@@ -109,13 +153,233 @@ class TrainingOperator:
|
||||
self.timers = timers
|
||||
|
||||
def setup(self, config):
|
||||
"""Override this method to implement custom operator setup.
|
||||
"""Override this method to implement operator setup.
|
||||
|
||||
You should call self.register and self.register_data here to
|
||||
register training components and data loaders with Ray SGD.
|
||||
|
||||
Args:
|
||||
config (dict): Custom configuration value to be passed to
|
||||
all creator and operator constructors. Same as ``self.config``.
|
||||
"""
|
||||
pass
|
||||
raise NotImplementedError
|
||||
|
||||
def register(self, *, models, optimizers, criterion=None, schedulers=None):
|
||||
"""Registers parameters with Ray SGD and sets up training components.
|
||||
|
||||
By calling this method to register your models, optimizers,
|
||||
criterion, and schedulers, Ray SGD will automatically handle
|
||||
necessary setup such as GPU/devices, Distributed Data Parallel, and
|
||||
Fp16. The registered components are returned and should be set as
|
||||
instance attributes to access during training/validation.
|
||||
|
||||
If more than one model, optimizer, or scheduler is passed in,
|
||||
you should implement your own custom training loop.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
class MyTrainingOperator(TrainingOperator):
|
||||
def setup(self, config):
|
||||
model = ...
|
||||
optimizer = ...
|
||||
train_loader = ...
|
||||
val_loader = ...
|
||||
loss = ...
|
||||
|
||||
self.model, self.optimizer, self.criterion = self.register(
|
||||
models=model, optimizers=optimizer, criterion=loss)
|
||||
|
||||
# At this point DDP, Cuda, and Fp16
|
||||
# are set up for all our components. We then use
|
||||
# self.model, self.optimizer, etc. in our training loop.
|
||||
|
||||
self.register_data(train_loader=train_loader,
|
||||
validation_loader=val_loader)
|
||||
|
||||
|
||||
Args:
|
||||
models (torch.nn.Module or Iterable[nn.Module]): Pytorch model or
|
||||
multiple Pytorch models to use for training. If
|
||||
`use_gpu=True` is passed into ``TorchTrainer``, and Cuda is
|
||||
available, models will automatically be placed on GPU.
|
||||
If ``wrap_ddp=True`` is passed into ``TorchTrainer``,
|
||||
models will be wrapped in DDP. If wrap_ddp is False,
|
||||
you should handle DDP for your models in setup.
|
||||
optimizers (torch.optim.Optimizer or Iterable[
|
||||
torch.optim.Optimizer]): Pytorch optimizer or multiple Pytorch
|
||||
optimizers to use for training.
|
||||
criterion (Callable, optional): Function to return loss
|
||||
metric given features and target. If not provided,
|
||||
must implement a custom training loop.
|
||||
schedulers (torch.optim.lr_scheduler or Iterable[
|
||||
torch.optim.lr_scheduler], optional): A learning rate
|
||||
scheduler or multiple learning rate schedulers.
|
||||
|
||||
Returns:
|
||||
Tuple of model, optimizer, criterion if not None, and scheduler
|
||||
if not None.
|
||||
|
||||
"""
|
||||
return_vals = []
|
||||
logger.debug("Registering models.")
|
||||
self._original_models = models
|
||||
if not isinstance(self._original_models, Iterable):
|
||||
self._original_models = [self._original_models]
|
||||
assert all(
|
||||
isinstance(model, nn.Module) for model in self._original_models), (
|
||||
f"All models must be PyTorch models: {self._original_models}.")
|
||||
if self.use_gpu and torch.cuda.is_available():
|
||||
self._original_models = [
|
||||
model.cuda() for model in self._original_models
|
||||
]
|
||||
|
||||
logger.debug("Registering optimizers.")
|
||||
self._optimizers = optimizers
|
||||
if not isinstance(self._optimizers, Iterable):
|
||||
self._optimizers = [self._optimizers]
|
||||
|
||||
if schedulers:
|
||||
logger.debug("Registering scheduler.")
|
||||
self._schedulers = schedulers
|
||||
if not isinstance(self._schedulers, Iterable):
|
||||
self._schedulers = [self._schedulers]
|
||||
else:
|
||||
self._schedulers = None
|
||||
|
||||
if criterion:
|
||||
logger.debug("Registering loss.")
|
||||
self._criterion = criterion
|
||||
if self.use_gpu and torch.cuda.is_available():
|
||||
if hasattr(self._criterion, "cuda"):
|
||||
self._criterion = self._criterion.cuda()
|
||||
else:
|
||||
self._criterion = None
|
||||
|
||||
logger.debug("Setting up Apex.")
|
||||
if self.use_fp16 and amp:
|
||||
self._models, self._optimizers = amp.initialize(
|
||||
self._models, self._optimizers, **self._apex_args)
|
||||
self._amp = amp
|
||||
|
||||
if self._wrap_ddp:
|
||||
logging.debug("Setting up DDP for models.")
|
||||
self._models = [
|
||||
DistributedDataParallel(model, device_ids=self.device_ids)
|
||||
for model in self._original_models
|
||||
]
|
||||
else:
|
||||
self._models = self._original_models
|
||||
|
||||
if len(self._models) == 1:
|
||||
return_vals.append(self._models[0])
|
||||
else:
|
||||
return_vals.append(self._models)
|
||||
|
||||
if len(self._optimizers) == 1:
|
||||
return_vals.append(self._optimizers[0])
|
||||
else:
|
||||
return_vals.append(self._optimizers)
|
||||
|
||||
if self._criterion is not None:
|
||||
return_vals.append(self._criterion)
|
||||
|
||||
if self._schedulers is not None:
|
||||
if self.scheduler_step_freq is None:
|
||||
raise ValueError("scheduler_step_freq passed into "
|
||||
"TorchTrainer cannot be None if you "
|
||||
"are registering schedulers. Set this to "
|
||||
"'manual' if you will be manually stepping "
|
||||
"the schedulers.")
|
||||
if len(self._schedulers) == 1:
|
||||
return_vals.append(self._schedulers[0])
|
||||
else:
|
||||
return_vals.append(self._schedulers)
|
||||
|
||||
return tuple(return_vals)
|
||||
|
||||
def register_data(self, *, train_loader=None, validation_loader=None):
|
||||
"""Registers data loaders with Ray SGD.
|
||||
|
||||
Calling this method will automatically setup Distributed Sampler for
|
||||
these data loaders if add_dist_sampler=True is passed into the
|
||||
TorchTrainer. This method does not return the wrapped data loaders.
|
||||
You should use the iterators passed into train_epoch and validate
|
||||
instead.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
class MyTrainingOperator(TrainingOperator):
|
||||
def setup(self, config):
|
||||
model = ...
|
||||
optimizer = ...
|
||||
train_loader = ...
|
||||
val_loader = ...
|
||||
loss = ...
|
||||
|
||||
self.model, self.optimizer, self.criterion = self.register(
|
||||
models=model, optimizers=optimizer, criterion=loss)
|
||||
|
||||
self.register_data(train_loader=train_loader,
|
||||
validation_loader=val_loader)
|
||||
|
||||
# At this point the data loaders are registered with
|
||||
# Ray SGD and are wrapped with Distributed Samplers if
|
||||
# applicable.
|
||||
|
||||
|
||||
def train_epoch(self, iterator, info):
|
||||
# If providing custom training or validation methods,
|
||||
# the registered data loaders are passed in through the
|
||||
# iterator parameter.
|
||||
...
|
||||
|
||||
Args:
|
||||
train_loader (Iterator): An iterator for training
|
||||
data. If None is explicitly passed in, a Ray SGD Dataset
|
||||
must be passed in through TorchTrainer.train. Ray SGD will
|
||||
automatically use a Distributed Sampler if TorchTrainer(...,
|
||||
add_dist_sampler=True).
|
||||
validation_loader (Iterator): An iterator for validation
|
||||
data. Ray SGD will automatically use a Distributed Sampler
|
||||
if TorchTrainer(..., add_dist_sampler=True).
|
||||
"""
|
||||
|
||||
logger.debug("Registering data loaders..")
|
||||
self._train_loader = train_loader
|
||||
self._validation_loader = validation_loader
|
||||
|
||||
if self._wrap_distributed_sampler:
|
||||
logging.debug("Wrapping data loaders with DistributedSampler.")
|
||||
|
||||
def with_sampler(loader):
|
||||
# Automatically set the DistributedSampler
|
||||
data_loader_args = {
|
||||
"dataset": loader.dataset,
|
||||
"batch_size": loader.batch_size,
|
||||
"shuffle": False,
|
||||
"num_workers": loader.num_workers,
|
||||
"collate_fn": loader.collate_fn,
|
||||
"pin_memory": loader.pin_memory,
|
||||
"drop_last": loader.drop_last,
|
||||
"timeout": loader.timeout,
|
||||
"worker_init_fn": loader.worker_init_fn,
|
||||
"sampler": DistributedSampler(loader.dataset)
|
||||
}
|
||||
return DataLoader(**data_loader_args)
|
||||
|
||||
def should_wrap_dataloader(loader):
|
||||
return (isinstance(loader, DataLoader)
|
||||
and not isinstance(loader.dataset, IterableDataset))
|
||||
|
||||
if should_wrap_dataloader(self._train_loader):
|
||||
if self._add_dist_sampler:
|
||||
self._train_loader = with_sampler(self._train_loader)
|
||||
|
||||
if self._validation_loader is not None and should_wrap_dataloader(
|
||||
self._validation_loader):
|
||||
if self._add_dist_sampler:
|
||||
self._validation_loader = with_sampler(
|
||||
self._validation_loader)
|
||||
|
||||
def train_epoch(self, iterator, info):
|
||||
"""Runs one standard training pass over the training dataloader.
|
||||
@@ -156,6 +420,15 @@ class TrainingOperator:
|
||||
Returns:
|
||||
A dict of metrics from training.
|
||||
"""
|
||||
if not hasattr(self, "model"):
|
||||
raise RuntimeError("Either set self.model in setup function or "
|
||||
"override this method to implement a custom "
|
||||
"training loop.")
|
||||
model = self.model
|
||||
scheduler = None
|
||||
if hasattr(self, "scheduler"):
|
||||
scheduler = self.scheduler
|
||||
|
||||
if self.use_tqdm and self.world_rank == 0:
|
||||
desc = ""
|
||||
if info is not None and "epoch_idx" in info:
|
||||
@@ -163,15 +436,19 @@ class TrainingOperator:
|
||||
desc = f"{info['epoch_idx'] + 1}/{info['num_epochs']}e"
|
||||
else:
|
||||
desc = f"{info['epoch_idx'] + 1}e"
|
||||
|
||||
# TODO: Implement len for Dataset?
|
||||
total = info[NUM_STEPS]
|
||||
if total is None:
|
||||
if hasattr(iterator, "__len__"):
|
||||
total = len(iterator)
|
||||
|
||||
_progress_bar = tqdm(
|
||||
total=info[NUM_STEPS] or len(self.train_loader),
|
||||
desc=desc,
|
||||
unit="batch",
|
||||
leave=False)
|
||||
total=total, desc=desc, unit="batch", leave=False)
|
||||
|
||||
metric_meters = AverageMeterCollection()
|
||||
|
||||
self.model.train()
|
||||
model.train()
|
||||
for batch_idx, batch in enumerate(iterator):
|
||||
batch_info = {
|
||||
"batch_idx": batch_idx,
|
||||
@@ -187,15 +464,14 @@ class TrainingOperator:
|
||||
postfix.update(loss=metrics["train_loss"])
|
||||
_progress_bar.set_postfix(postfix)
|
||||
|
||||
if self.scheduler and batch_info.get(
|
||||
SCHEDULER_STEP) == SCHEDULER_STEP_BATCH:
|
||||
self.scheduler.step()
|
||||
if scheduler and self.scheduler_step_freq == SCHEDULER_STEP_BATCH:
|
||||
scheduler.step()
|
||||
|
||||
metric_meters.update(metrics, n=metrics.pop(NUM_SAMPLES, 1))
|
||||
self.global_step += 1
|
||||
|
||||
if self.scheduler and info.get(SCHEDULER_STEP) == SCHEDULER_STEP_EPOCH:
|
||||
self.scheduler.step()
|
||||
if scheduler and self.scheduler_step_freq == SCHEDULER_STEP_EPOCH:
|
||||
scheduler.step()
|
||||
|
||||
return metric_meters.summary()
|
||||
|
||||
@@ -211,9 +487,7 @@ class TrainingOperator:
|
||||
automatically.
|
||||
|
||||
You can provide custom loss metrics and training operations if you
|
||||
override this method. If overriding this method, you can access model,
|
||||
optimizer, criterion via ``self.model``, ``self.optimizer``,
|
||||
and ``self.criterion``.
|
||||
override this method.
|
||||
|
||||
You do not need to override this method if you plan to
|
||||
override ``train_epoch``.
|
||||
@@ -232,6 +506,21 @@ class TrainingOperator:
|
||||
calculate averages.
|
||||
|
||||
"""
|
||||
if not hasattr(self, "model"):
|
||||
raise RuntimeError("Either set self.model in setup function or "
|
||||
"override this method to implement a custom "
|
||||
"training loop.")
|
||||
if not hasattr(self, "optimizer"):
|
||||
raise RuntimeError("Either set self.optimizer in setup function "
|
||||
"or override this method to implement a custom "
|
||||
"training loop.")
|
||||
if not hasattr(self, "criterion"):
|
||||
raise RuntimeError("Either set self.criterion in setup function "
|
||||
"or override this method to implement a custom "
|
||||
"training loop.")
|
||||
model = self.model
|
||||
optimizer = self.optimizer
|
||||
criterion = self.criterion
|
||||
# unpack features into list to support multiple inputs model
|
||||
*features, target = batch
|
||||
# Create non_blocking tensors for distributed training
|
||||
@@ -243,21 +532,21 @@ class TrainingOperator:
|
||||
|
||||
# Compute output.
|
||||
with self.timers.record("fwd"):
|
||||
output = self.model(*features)
|
||||
loss = self.criterion(output, target)
|
||||
output = model(*features)
|
||||
loss = criterion(output, target)
|
||||
|
||||
# Compute gradients in a backward pass.
|
||||
with self.timers.record("grad"):
|
||||
self.optimizer.zero_grad()
|
||||
optimizer.zero_grad()
|
||||
if self.use_fp16:
|
||||
with amp.scale_loss(loss, self.optimizer) as scaled_loss:
|
||||
with amp.scale_loss(loss, optimizer) as scaled_loss:
|
||||
scaled_loss.backward()
|
||||
else:
|
||||
loss.backward()
|
||||
|
||||
# Call step of optimizer to update model params.
|
||||
with self.timers.record("apply"):
|
||||
self.optimizer.step()
|
||||
optimizer.step()
|
||||
|
||||
return {"train_loss": loss.item(), NUM_SAMPLES: features[0].size(0)}
|
||||
|
||||
@@ -267,9 +556,8 @@ class TrainingOperator:
|
||||
This will call ``model.eval()`` and ``torch.no_grad`` when iterating
|
||||
over the validation dataloader.
|
||||
|
||||
If overriding this method, you can access model, criterion via
|
||||
``self.model`` and ``self.criterion``. You also do not need to call
|
||||
``validate_batch`` if overriding this method.
|
||||
You also do not need to call ``validate_batch`` if overriding this
|
||||
method.
|
||||
|
||||
Args:
|
||||
val_iterator (iter): Iterable constructed from the
|
||||
@@ -284,10 +572,15 @@ class TrainingOperator:
|
||||
from ``validate_batch`` and dividing it by the sum of
|
||||
``num_samples`` from all calls to ``self.validate_batch``.
|
||||
"""
|
||||
if not hasattr(self, "model"):
|
||||
raise RuntimeError("Either set self.model in setup function or "
|
||||
"override this method to implement a custom "
|
||||
"validation loop.")
|
||||
model = self.model
|
||||
metric_meters = AverageMeterCollection()
|
||||
|
||||
# switch to evaluate mode
|
||||
self.model.eval()
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
for batch_idx, batch in enumerate(val_iterator):
|
||||
batch_info = {"batch_idx": batch_idx}
|
||||
@@ -319,6 +612,16 @@ class TrainingOperator:
|
||||
by default, ``validate`` uses "num_samples" to
|
||||
calculate averages.
|
||||
"""
|
||||
if not hasattr(self, "model"):
|
||||
raise RuntimeError("Either set self.model in setup function or "
|
||||
"override this method to implement a custom "
|
||||
"training loop.")
|
||||
if not hasattr(self, "criterion"):
|
||||
raise RuntimeError("Either set self.criterion in setup function "
|
||||
"or override this method to implement a custom "
|
||||
"training loop.")
|
||||
model = self.model
|
||||
criterion = self.criterion
|
||||
# unpack features into list to support multiple inputs model
|
||||
*features, target = batch
|
||||
if self.use_gpu:
|
||||
@@ -330,8 +633,8 @@ class TrainingOperator:
|
||||
# compute output
|
||||
|
||||
with self.timers.record("eval_fwd"):
|
||||
output = self.model(*features)
|
||||
loss = self.criterion(output, target)
|
||||
output = model(*features)
|
||||
loss = criterion(output, target)
|
||||
_, predicted = torch.max(output.data, 1)
|
||||
|
||||
num_correct = (predicted == target).sum().item()
|
||||
@@ -344,6 +647,9 @@ class TrainingOperator:
|
||||
|
||||
def state_dict(self):
|
||||
"""Override this to return a representation of the operator state.
|
||||
Any argument passed into self.register and self.register_data will
|
||||
automatically be saved.
|
||||
Use this method to save any additional state.
|
||||
|
||||
Returns:
|
||||
dict: The state dict of the operator."""
|
||||
@@ -351,11 +657,81 @@ class TrainingOperator:
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
"""Override this to load the representation of the operator state.
|
||||
|
||||
Anything passed into self.register and self.register_data will
|
||||
automatically be loaded. Use this method to load any additional state.
|
||||
Args:
|
||||
state_dict (dict): State dict as returned by the operator. """
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def from_creators(cls,
|
||||
model_creator,
|
||||
optimizer_creator,
|
||||
data_creator=None,
|
||||
loss_creator=None,
|
||||
scheduler_creator=None,
|
||||
serialize_data_creation=True):
|
||||
"""A utility method to create a custom TrainingOperator class from
|
||||
creator functions. This is useful for backwards compatibility with
|
||||
previous versions of Ray. To provide custom training and validation,
|
||||
you should subclass the class that is returned by this method instead
|
||||
of ``TrainingOperator``.
|
||||
|
||||
Args:
|
||||
model_creator (dict -> Model(s)): Constructor function that takes
|
||||
in config and returns the model(s) to be optimized. These
|
||||
must be ``torch.nn.Module`` objects. If multiple models are
|
||||
returned, a ``training_operator_cls`` must be specified.
|
||||
You do not need to handle GPU/devices in this function;
|
||||
RaySGD will do that under the hood.
|
||||
data_creator (dict -> Iterable(s)): Constructor function
|
||||
that takes in the passed config and returns one or
|
||||
two Iterable objects. Note that even though two Iterable
|
||||
objects can be returned, only one will be used for training,
|
||||
and the other will be used for validation. If not provided,
|
||||
you must pass in a Dataset to ``TorchTrainer.train``.
|
||||
optimizer_creator ((models, dict) -> optimizers): Constructor
|
||||
function that takes in the return values from
|
||||
``model_creator`` and the passed config and returns One or
|
||||
more Torch optimizer objects. You do not need to handle
|
||||
GPU/devices in this function; ``RaySGD`` will do that for you.
|
||||
loss_creator (torch.nn.*Loss class | dict -> loss): A constructor
|
||||
function for the training loss. This can be either a function
|
||||
that takes in the provided config for customization or a
|
||||
subclass of ``torch.nn.modules.loss._Loss``, which is most
|
||||
Pytorch loss classes. For example,
|
||||
``loss_creator=torch.nn.BCELoss``. If not provided, you must
|
||||
provide a custom TrainingOperator.
|
||||
scheduler_creator ((optimizers, dict) -> scheduler):
|
||||
A constructor function for the torch scheduler. This is
|
||||
a function that takes in the generated optimizers (from
|
||||
``optimizer_creator``) provided config for customization.
|
||||
Be sure to set ``scheduler_step_freq`` to increment the
|
||||
scheduler correctly.
|
||||
serialize_data_creation (bool): A filelock will be used
|
||||
to ensure no race conditions in data downloading among
|
||||
different workers on the same node (using the local file
|
||||
system). Defaults to True.
|
||||
|
||||
Returns:
|
||||
A TrainingOperator class with a ``setup`` method that utilizes
|
||||
the passed in creator functions.
|
||||
"""
|
||||
|
||||
if not (callable(model_creator) and callable(optimizer_creator)):
|
||||
raise ValueError(
|
||||
"Must provide a callable model_creator and optimizer_creator.")
|
||||
|
||||
class CustomCreatorOperator(CreatorOperator):
|
||||
_model_creator = model_creator
|
||||
_optimizer_creator = optimizer_creator
|
||||
_data_creator = data_creator
|
||||
_loss_creator = loss_creator
|
||||
_scheduler_creator = scheduler_creator
|
||||
_serialize_data_creation = serialize_data_creation
|
||||
|
||||
return CustomCreatorOperator
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
"""torch.device: The appropriate torch device, at your convenience."""
|
||||
@@ -366,58 +742,11 @@ class TrainingOperator:
|
||||
"""dict: Provided into TorchTrainer."""
|
||||
return self._config
|
||||
|
||||
@property
|
||||
def model(self):
|
||||
"""First or only model created by the provided ``model_creator``."""
|
||||
return self._models[0]
|
||||
|
||||
@property
|
||||
def models(self):
|
||||
"""List of models created by the provided ``model_creator``."""
|
||||
return self._models
|
||||
|
||||
@property
|
||||
def optimizer(self):
|
||||
"""First or only optimizer(s) created by the ``optimizer_creator``."""
|
||||
return self._optimizers[0]
|
||||
|
||||
@property
|
||||
def optimizers(self):
|
||||
"""List of optimizers created by the ``optimizer_creator``."""
|
||||
return self._optimizers
|
||||
|
||||
@property
|
||||
def train_loader(self):
|
||||
"""Iterable: 1st Dataloader from ``data_creator``.
|
||||
"""
|
||||
return self._train_loader
|
||||
|
||||
@property
|
||||
def validation_loader(self):
|
||||
"""Iterable: 2nd Dataloader from ``data_creator``."""
|
||||
return self._validation_loader
|
||||
|
||||
@property
|
||||
def world_rank(self):
|
||||
"""int: The rank of the parent runner. Always 0 if not distributed."""
|
||||
return self._world_rank
|
||||
|
||||
@property
|
||||
def criterion(self):
|
||||
"""Criterion created by the provided ``loss_creator``."""
|
||||
return self._criterion
|
||||
|
||||
@property
|
||||
def scheduler(self):
|
||||
"""First or only scheduler(s) created by the ``scheduler_creator``."""
|
||||
if self._schedulers:
|
||||
return self._schedulers[0]
|
||||
|
||||
@property
|
||||
def schedulers(self):
|
||||
"""List of schedulers created by the ``scheduler_creator``."""
|
||||
return self._schedulers
|
||||
|
||||
@property
|
||||
def use_gpu(self):
|
||||
"""Returns True if cuda is available and use_gpu is True."""
|
||||
@@ -441,31 +770,138 @@ class TrainingOperator:
|
||||
"""
|
||||
return self._device_ids
|
||||
|
||||
@property
|
||||
def scheduler_step_freq(self):
|
||||
"""Optional[str]: The ``scheduler_step_freq`` passed into
|
||||
``TorchTrainer``
|
||||
|
||||
class _TestingOperator(TrainingOperator):
|
||||
def train_epoch(self, iterator, info):
|
||||
func = self.config.get("custom_func")
|
||||
if callable(func):
|
||||
return func(self, iterator, info)
|
||||
return {"done": 1}
|
||||
This is useful to determine when to call scheduler.step.
|
||||
"""
|
||||
return self._scheduler_step_freq
|
||||
|
||||
|
||||
class _TestMetricsOperator(TrainingOperator):
|
||||
class CreatorOperator(TrainingOperator):
|
||||
"""A subclass of TrainingOperator specifically for defining training
|
||||
state using creator functions.
|
||||
"""
|
||||
|
||||
def _validate_loaders(self, loaders):
|
||||
assert loaders, "Loaders need to be returned in data_creator."
|
||||
if isinstance(loaders, (tuple, list)):
|
||||
if len(loaders) == 1:
|
||||
return loaders, None
|
||||
elif len(loaders) == 2:
|
||||
return loaders
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Number of loaders must be <= 2. Got {loaders}")
|
||||
# No great way of checking type otherwise
|
||||
return loaders, None
|
||||
|
||||
def _initialize_dataloaders(self, config):
|
||||
logger.debug("Instantiating dataloaders.")
|
||||
loaders = None
|
||||
if self._serialize_data_creation:
|
||||
logger.debug("Serializing the dataloading process.")
|
||||
with FileLock(
|
||||
os.path.join(tempfile.gettempdir(), ".raydata.lock")):
|
||||
loaders = self.__class__._data_creator(config)
|
||||
else:
|
||||
loaders = self.__class__._data_creator(config)
|
||||
train_loader, val_loader = self._validate_loaders(loaders)
|
||||
|
||||
return train_loader, val_loader
|
||||
|
||||
def setup(self, config):
|
||||
self._train_scores = config["scores"].copy()
|
||||
self._val_scores = config["val_scores"].copy()
|
||||
self.key = config["key"]
|
||||
kwargs = {}
|
||||
logger.debug("Loading data.")
|
||||
train_loader = None
|
||||
validation_loader = None
|
||||
if self.__class__._data_creator and callable(
|
||||
self.__class__._data_creator):
|
||||
train_loader, validation_loader = self._initialize_dataloaders(
|
||||
config)
|
||||
|
||||
def train_batch(self, batch, batch_info=None):
|
||||
metrics = super(_TestMetricsOperator, self).train_batch(
|
||||
batch, batch_info)
|
||||
num_samples = metrics[NUM_SAMPLES]
|
||||
metrics.update({self.key: self._train_scores.pop(0) / num_samples})
|
||||
return metrics
|
||||
logger.debug("Creating model")
|
||||
models = self.__class__._model_creator(config)
|
||||
|
||||
def validate_batch(self, batch, batch_info=None):
|
||||
metrics = super(_TestMetricsOperator, self).validate_batch(
|
||||
batch, batch_info)
|
||||
num_samples = metrics[NUM_SAMPLES]
|
||||
metrics.update({self.key: self._val_scores.pop(0) / num_samples})
|
||||
return metrics
|
||||
kwargs["models"] = models
|
||||
|
||||
logger.debug("Creating optimizer.")
|
||||
optimizers = self.__class__._optimizer_creator(models, config)
|
||||
|
||||
kwargs["optimizers"] = optimizers
|
||||
|
||||
if self.__class__._scheduler_creator:
|
||||
logger.debug("Creating scheduler.")
|
||||
schedulers = self.__class__._scheduler_creator(optimizers, config)
|
||||
kwargs["schedulers"] = schedulers
|
||||
|
||||
if self.__class__._loss_creator:
|
||||
logger.debug("Creating loss.")
|
||||
if inspect.isclass(self.__class__._loss_creator) and issubclass(
|
||||
self.__class__._loss_creator, torch.nn.modules.loss._Loss):
|
||||
criterion = self.__class__._loss_creator()
|
||||
else:
|
||||
criterion = self.__class__._loss_creator(config)
|
||||
kwargs["criterion"] = criterion
|
||||
|
||||
state = self.register(**kwargs)
|
||||
self.models, self.optimizers = state[:2]
|
||||
if isinstance(self.models, tuple):
|
||||
self.model = self.models[0]
|
||||
else:
|
||||
self.model = self.models
|
||||
|
||||
if isinstance(self.optimizers, tuple):
|
||||
self.optimizer = self.optimizers[0]
|
||||
else:
|
||||
self.optimizer = self.optimizers
|
||||
|
||||
if len(state) >= 3:
|
||||
self.criterion = state[2]
|
||||
if len(state) == 4:
|
||||
self.schedulers = state[3]
|
||||
if isinstance(self.schedulers, tuple):
|
||||
self.scheduler = self.schedulers[0]
|
||||
else:
|
||||
self.scheduler = self.schedulers
|
||||
|
||||
self.register_data(
|
||||
train_loader=train_loader, validation_loader=validation_loader)
|
||||
|
||||
|
||||
def get_test_operator(operator_cls):
|
||||
class _TestingOperator(operator_cls):
|
||||
def train_epoch(self, iterator, info):
|
||||
func = self.config.get("custom_func")
|
||||
if callable(func):
|
||||
return func(self, iterator, info)
|
||||
return {"done": 1}
|
||||
|
||||
return _TestingOperator
|
||||
|
||||
|
||||
def get_test_metrics_operator(operator_cls):
|
||||
class _TestMetricsOperator(operator_cls):
|
||||
def setup(self, config):
|
||||
super(_TestMetricsOperator, self).setup(config)
|
||||
self._train_scores = config["scores"].copy()
|
||||
self._val_scores = config["val_scores"].copy()
|
||||
self.key = config["key"]
|
||||
|
||||
def train_batch(self, batch, batch_info=None):
|
||||
metrics = super(_TestMetricsOperator, self).train_batch(
|
||||
batch, batch_info)
|
||||
num_samples = metrics[NUM_SAMPLES]
|
||||
metrics.update({self.key: self._train_scores.pop(0) / num_samples})
|
||||
return metrics
|
||||
|
||||
def validate_batch(self, batch, batch_info=None):
|
||||
metrics = super(_TestMetricsOperator, self).validate_batch(
|
||||
batch, batch_info)
|
||||
num_samples = metrics[NUM_SAMPLES]
|
||||
metrics.update({self.key: self._val_scores.pop(0) / num_samples})
|
||||
return metrics
|
||||
|
||||
return _TestMetricsOperator
|
||||
|
||||
Reference in New Issue
Block a user