[RaySGD] Simplify Builder Process (#10321)

Co-authored-by: Richard Liaw <rliaw@berkeley.edu>
This commit is contained in:
Amog Kamsetty
2020-09-08 15:19:40 -07:00
committed by GitHub
parent 69c1a9dd08
commit 415be78cc0
20 changed files with 1436 additions and 1113 deletions
+130 -170
View File
@@ -11,8 +11,8 @@ from torch.utils.data import DataLoader
import ray
from ray import tune
from ray.util.sgd.torch import TorchTrainer
from ray.util.sgd.torch.training_operator import (_TestingOperator,
_TestMetricsOperator)
from ray.util.sgd.torch.training_operator import (
get_test_operator, get_test_metrics_operator, TrainingOperator)
from ray.util.sgd.torch.constants import SCHEDULER_STEP
from ray.util.sgd.utils import (check_for_failure, NUM_SAMPLES, BATCH_COUNT,
BATCH_SIZE)
@@ -44,13 +44,12 @@ def ray_start_4_cpus():
dist.destroy_process_group()
Operator = TrainingOperator.from_creators(
model_creator, optimizer_creator, data_creator, loss_creator=nn.MSELoss)
def test_single_step(ray_start_2_cpus): # noqa: F811
trainer = TorchTrainer(
model_creator=model_creator,
data_creator=data_creator,
optimizer_creator=optimizer_creator,
loss_creator=lambda config: nn.MSELoss(),
num_workers=1)
trainer = TorchTrainer(training_operator_cls=Operator, num_workers=1)
metrics = trainer.train(num_steps=1)
assert metrics[BATCH_COUNT] == 1
@@ -59,49 +58,9 @@ def test_single_step(ray_start_2_cpus): # noqa: F811
trainer.shutdown()
def test_resize(ray_start_2_cpus): # noqa: F811
trainer = TorchTrainer(
model_creator=model_creator,
data_creator=data_creator,
optimizer_creator=optimizer_creator,
loss_creator=lambda config: nn.MSELoss(),
num_workers=1)
trainer.train(num_steps=1)
trainer.max_replicas = 2
results = trainer.train(num_steps=1, reduce_results=False)
assert len(results) == 2
def test_non_serialized_data(ray_start_2_cpus): # noqa: F811
duration = 10
def slow_data(func):
def slowed_func(*args, **kwargs):
time.sleep(duration)
return func(*args, **kwargs)
return slowed_func
start = time.time()
trainer = TorchTrainer(
model_creator=model_creator,
data_creator=slow_data(data_creator),
optimizer_creator=optimizer_creator,
serialize_data_creation=False,
loss_creator=lambda config: nn.MSELoss(),
num_workers=2)
elapsed = time.time() - start
assert elapsed < duration * 2
trainer.shutdown()
def test_dead_trainer(ray_start_2_cpus): # noqa: F811
trainer = TorchTrainer(
model_creator=model_creator,
data_creator=data_creator,
optimizer_creator=optimizer_creator,
loss_creator=lambda config: nn.MSELoss(),
num_workers=2)
TestOperator = get_test_operator(Operator)
trainer = TorchTrainer(training_operator_cls=TestOperator, num_workers=2)
trainer.train(num_steps=1)
trainer.shutdown()
with pytest.raises(RuntimeError):
@@ -111,11 +70,7 @@ def test_dead_trainer(ray_start_2_cpus): # noqa: F811
@pytest.mark.parametrize("num_workers", [1, 2] if dist.is_available() else [1])
def test_train(ray_start_2_cpus, num_workers): # noqa: F811
trainer = TorchTrainer(
model_creator=model_creator,
data_creator=data_creator,
optimizer_creator=optimizer_creator,
loss_creator=lambda config: nn.MSELoss(),
num_workers=num_workers)
training_operator_cls=Operator, num_workers=num_workers)
for i in range(3):
train_loss1 = trainer.train()["train_loss"]
validation_loss1 = trainer.validate()["val_loss"]
@@ -165,22 +120,26 @@ def test_multi_model(ray_start_2_cpus, num_workers):
iterator=iter(data))
return result
def multi_model_creator(config):
return nn.Linear(1, 1), nn.Linear(1, 1)
class MultiModelOperator(TrainingOperator):
def setup(self, config):
models = nn.Linear(1, 1), nn.Linear(1, 1)
opts = [
torch.optim.SGD(model.parameters(), lr=0.0001)
for model in models
]
loss = nn.MSELoss()
train_dataloader, val_dataloader = data_creator(config)
self.models, self.optimizers, self.criterion = self.register(
models=models, optimizers=opts, criterion=loss)
self.register_data(
train_loader=train_dataloader,
validation_loader=val_dataloader)
def multi_optimizer_creator(models, config):
opts = [
torch.optim.SGD(model.parameters(), lr=0.0001) for model in models
]
return opts[0], opts[1]
TestOperator = get_test_operator(MultiModelOperator)
trainer1 = TorchTrainer(
model_creator=multi_model_creator,
data_creator=data_creator,
optimizer_creator=multi_optimizer_creator,
loss_creator=lambda config: nn.MSELoss(),
config={"custom_func": train_epoch},
training_operator_cls=_TestingOperator,
training_operator_cls=TestOperator,
num_workers=num_workers)
trainer1.train()
state = trainer1.state_dict()
@@ -190,12 +149,8 @@ def test_multi_model(ray_start_2_cpus, num_workers):
trainer1.shutdown()
trainer2 = TorchTrainer(
model_creator=multi_model_creator,
data_creator=data_creator,
optimizer_creator=multi_optimizer_creator,
loss_creator=lambda config: nn.MSELoss(),
config={"custom_func": train_epoch},
training_operator_cls=_TestingOperator,
training_operator_cls=TestOperator,
num_workers=num_workers)
trainer2.load_state_dict(state)
@@ -252,17 +207,29 @@ def test_multi_model_matrix(ray_start_2_cpus, num_workers): # noqa: F811
]
return schedulers[0] if len(schedulers) == 1 else schedulers
class MultiModelOperator(TrainingOperator):
def setup(self, config):
models = multi_model_creator(config)
optimizers = multi_optimizer_creator(models, config)
schedulers = multi_scheduler_creator(optimizers, config)
train_loader, val_loader = data_creator(config)
loss = nn.MSELoss()
self.models, self.optimizers, self.criterion, self.schedulers = \
self.register(models=models, optimizers=optimizers,
schedulers=schedulers,
criterion=loss)
self.register_data(
train_loader=train_loader, validation_loader=val_loader)
TestOperator = get_test_operator(MultiModelOperator)
for model_count in range(1, 3):
for optimizer_count in range(1, 3):
for scheduler_count in range(1, 3):
trainer = TorchTrainer(
model_creator=multi_model_creator,
data_creator=data_creator,
optimizer_creator=multi_optimizer_creator,
loss_creator=nn.MSELoss,
scheduler_creator=multi_scheduler_creator,
scheduler_step_freq="epoch",
training_operator_cls=_TestingOperator,
training_operator_cls=TestOperator,
num_workers=num_workers,
config={
"models": model_count,
@@ -284,24 +251,31 @@ def test_scheduler_freq(ray_start_2_cpus, scheduler_freq): # noqa: F811
return torch.optim.lr_scheduler.StepLR(
optimizer, step_size=30, gamma=0.1)
class TestTrainingOperator(TrainingOperator):
def setup(self, config):
model = model_creator(config)
optimizer = optimizer_creator(model, config)
train_loader, val_loader = data_creator(config)
scheduler = scheduler_creator(optimizer, config)
loss = nn.MSELoss()
self.model, self.optimizer, self.criterion, self.scheduler = \
self.register(
models=model, optimizers=optimizer,
criterion=loss, schedulers=scheduler)
self.register_data(
train_loader=train_loader, validation_loader=val_loader)
if scheduler_freq is None:
with pytest.raises(ValueError):
trainer = TorchTrainer(
model_creator=model_creator,
data_creator=data_creator,
optimizer_creator=optimizer_creator,
loss_creator=lambda config: nn.MSELoss(),
scheduler_creator=scheduler_creator,
config={"custom_func": train_epoch},
training_operator_cls=TestTrainingOperator,
scheduler_step_freq=scheduler_freq)
else:
trainer = TorchTrainer(
model_creator=model_creator,
data_creator=data_creator,
optimizer_creator=optimizer_creator,
loss_creator=lambda config: nn.MSELoss(),
config={"custom_func": train_epoch},
training_operator_cls=_TestingOperator,
scheduler_creator=scheduler_creator,
training_operator_cls=TestTrainingOperator,
scheduler_step_freq=scheduler_freq)
for i in range(3):
@@ -310,11 +284,7 @@ def test_scheduler_freq(ray_start_2_cpus, scheduler_freq): # noqa: F811
def test_profiling(ray_start_2_cpus): # noqa: F811
trainer = TorchTrainer(
model_creator=model_creator,
data_creator=data_creator,
optimizer_creator=optimizer_creator,
loss_creator=lambda config: nn.MSELoss())
trainer = TorchTrainer(training_operator_cls=Operator)
stats = trainer.train(profile=True)
assert "profile" in stats
@@ -334,11 +304,14 @@ def test_dataset(ray_start_4_cpus):
model_creator = mlp_identity.model_creator
optimizer_creator = mlp_identity.optimizer_creator
dataset_creator = mlp_identity.dataset_creator
trainer = TorchTrainer(
DatasetOperator = TrainingOperator.from_creators(
model_creator=model_creator,
data_creator=None,
optimizer_creator=optimizer_creator,
loss_creator=torch.nn.MSELoss,
loss_creator=nn.MSELoss)
trainer = TorchTrainer(
training_operator_cls=DatasetOperator,
num_workers=2,
)
@@ -366,12 +339,13 @@ def test_split_batch(ray_start_2_cpus):
data_size = 600
batch_size = 21
TestOperator = TrainingOperator.from_creators(
model_creator,
optimizer_creator,
data_creator,
loss_creator=lambda config: nn.MSELoss())
trainer = TorchTrainer(
model_creator=model_creator,
data_creator=data_creator,
optimizer_creator=optimizer_creator,
loss_creator=lambda config: nn.MSELoss(),
training_operator_cls=TestOperator,
num_workers=2,
config={
BATCH_SIZE: batch_size,
@@ -398,11 +372,13 @@ def test_reduce_result(ray_start_2_cpus):
data_size = 600
TestOperator = TrainingOperator.from_creators(
model_creator,
optimizer_creator,
data_creator,
loss_creator=lambda config: nn.MSELoss())
trainer = TorchTrainer(
model_creator=model_creator,
data_creator=data_creator,
optimizer_creator=optimizer_creator,
loss_creator=lambda config: nn.MSELoss(),
training_operator_cls=TestOperator,
num_workers=2,
config={"data_size": data_size})
list_stats = trainer.train(reduce_results=False, profile=True)
@@ -426,11 +402,10 @@ def test_metrics(ray_start_2_cpus, num_workers):
train_scores = [1] + ([0] * num_train_steps)
val_scores = [1] + ([0] * num_val_steps)
TestOperator = get_test_metrics_operator(Operator)
trainer = TorchTrainer(
model_creator=model_creator,
data_creator=data_creator,
optimizer_creator=optimizer_creator,
loss_creator=lambda config: nn.MSELoss(),
training_operator_cls=TestOperator,
num_workers=num_workers,
config={
"scores": train_scores,
@@ -439,8 +414,7 @@ def test_metrics(ray_start_2_cpus, num_workers):
"batch_size": batch_size,
"data_size": data_size,
"val_size": val_size
},
training_operator_cls=_TestMetricsOperator)
})
stats = trainer.train(num_steps=num_train_steps)
# Test that we output mean and last of custom metrics in an epoch
@@ -475,11 +449,9 @@ def test_metrics_nan(ray_start_2_cpus, num_workers):
train_scores = [np.nan] + ([0] * num_train_steps)
val_scores = [np.nan] + ([0] * num_val_steps)
TestOperator = get_test_metrics_operator(Operator)
trainer = TorchTrainer(
model_creator=model_creator,
data_creator=data_creator,
optimizer_creator=optimizer_creator,
loss_creator=lambda config: nn.MSELoss(),
training_operator_cls=TestOperator,
num_workers=num_workers,
config={
"scores": train_scores,
@@ -488,8 +460,7 @@ def test_metrics_nan(ray_start_2_cpus, num_workers):
"batch_size": batch_size,
"data_size": data_size,
"val_size": val_size
},
training_operator_cls=_TestMetricsOperator)
})
stats = trainer.train(num_steps=num_train_steps)
assert "score" in stats
@@ -506,19 +477,20 @@ def test_metrics_nan(ray_start_2_cpus, num_workers):
def test_scheduler_validate(ray_start_2_cpus): # noqa: F811
from torch.optim.lr_scheduler import ReduceLROnPlateau
trainer = TorchTrainer(
model_creator=model_creator,
data_creator=data_creator,
optimizer_creator=optimizer_creator,
loss_creator=lambda config: nn.MSELoss(),
TestOperator = TrainingOperator.from_creators(
model_creator,
optimizer_creator,
data_creator,
scheduler_creator=lambda optimizer, cfg: ReduceLROnPlateau(optimizer),
scheduler_step_freq="manual",
training_operator_cls=_TestingOperator)
loss_creator=lambda config: nn.MSELoss())
TestOperator = get_test_operator(TestOperator)
trainer = TorchTrainer(
scheduler_step_freq="manual", training_operator_cls=TestOperator)
trainer.update_scheduler(0.5)
trainer.update_scheduler(0.5)
assert all(
trainer.apply_all_operators(
lambda op: op.schedulers[0].last_epoch == 2))
lambda op: op._schedulers[0].last_epoch == 2))
trainer.shutdown()
@@ -526,10 +498,7 @@ def test_scheduler_validate(ray_start_2_cpus): # noqa: F811
def test_tune_train(ray_start_2_cpus, num_workers): # noqa: F811
TorchTrainable = TorchTrainer.as_trainable(
**{
"model_creator": model_creator,
"data_creator": data_creator,
"optimizer_creator": optimizer_creator,
"loss_creator": lambda config: nn.MSELoss(),
"training_operator_cls": Operator,
"num_workers": num_workers,
"use_gpu": False,
"backend": "gloo",
@@ -560,11 +529,7 @@ def test_tune_train(ray_start_2_cpus, num_workers): # noqa: F811
def test_save_and_restore(ray_start_2_cpus, num_workers,
tmp_path): # noqa: F811
trainer1 = TorchTrainer(
model_creator=model_creator,
data_creator=data_creator,
optimizer_creator=optimizer_creator,
loss_creator=lambda config: nn.MSELoss(),
num_workers=num_workers)
training_operator_cls=Operator, num_workers=num_workers)
trainer1.train()
checkpoint_path = os.path.join(tmp_path, "checkpoint")
trainer1.save(checkpoint_path)
@@ -574,11 +539,7 @@ def test_save_and_restore(ray_start_2_cpus, num_workers,
trainer1.shutdown()
trainer2 = TorchTrainer(
model_creator=model_creator,
data_creator=data_creator,
optimizer_creator=optimizer_creator,
loss_creator=lambda config: nn.MSELoss(),
num_workers=num_workers)
training_operator_cls=Operator, num_workers=num_workers)
trainer2.load(checkpoint_path)
model2 = trainer2.get_model()
@@ -597,12 +558,7 @@ def test_wrap_ddp(ray_start_2_cpus, tmp_path): # noqa: F811
if not dist.is_available():
return
trainer1 = TorchTrainer(
model_creator=model_creator,
data_creator=data_creator,
optimizer_creator=optimizer_creator,
loss_creator=lambda config: nn.MSELoss(),
wrap_ddp=False,
num_workers=2)
training_operator_cls=Operator, wrap_ddp=False, num_workers=2)
trainer1.train()
checkpoint_path = os.path.join(tmp_path, "checkpoint")
trainer1.save(checkpoint_path)
@@ -613,12 +569,7 @@ def test_wrap_ddp(ray_start_2_cpus, tmp_path): # noqa: F811
trainer1.shutdown()
trainer2 = TorchTrainer(
model_creator=model_creator,
data_creator=data_creator,
optimizer_creator=optimizer_creator,
loss_creator=lambda config: nn.MSELoss(),
wrap_ddp=False,
num_workers=2)
training_operator_cls=Operator, wrap_ddp=False, num_workers=2)
trainer2.load(checkpoint_path)
model2 = trainer2.get_model()
@@ -672,12 +623,14 @@ def test_fail_with_recover(ray_start_2_cpus): # noqa: F811
step_with_fail = gen_step_with_fail(3)
TestOperator = TrainingOperator.from_creators(
model_creator,
optimizer_creator,
single_loader,
loss_creator=lambda config: nn.MSELoss())
with patch.object(TorchTrainer, "_train_epoch", step_with_fail):
trainer1 = TorchTrainer(
model_creator=model_creator,
data_creator=single_loader,
optimizer_creator=optimizer_creator,
loss_creator=lambda config: nn.MSELoss(),
training_operator_cls=TestOperator,
config={"batch_size": 100000},
num_workers=2)
@@ -697,13 +650,15 @@ def test_resize(ray_start_2_cpus): # noqa: F811
step_with_fail = gen_step_with_fail(1)
TestOperator = TrainingOperator.from_creators(
model_creator,
optimizer_creator,
single_loader,
loss_creator=lambda config: nn.MSELoss())
with patch.object(TorchTrainer, "_train_epoch", step_with_fail):
trainer1 = TorchTrainer(
model_creator=model_creator,
data_creator=single_loader,
optimizer_creator=optimizer_creator,
training_operator_cls=TestOperator,
config={"batch_size": 100000},
loss_creator=lambda config: nn.MSELoss(),
num_workers=2)
@ray.remote
@@ -728,13 +683,16 @@ def test_fail_twice(ray_start_2_cpus): # noqa: F811
step_with_fail = gen_step_with_fail(2)
TestOperator = TrainingOperator.from_creators(
model_creator,
optimizer_creator,
single_loader,
loss_creator=lambda config: nn.MSELoss())
with patch.object(TorchTrainer, "_train_epoch", step_with_fail):
trainer1 = TorchTrainer(
model_creator=model_creator,
data_creator=single_loader,
optimizer_creator=optimizer_creator,
training_operator_cls=TestOperator,
config={"batch_size": 100000},
loss_creator=lambda config: nn.MSELoss(),
num_workers=2)
# MAX RETRIES SHOULD BE ON BY DEFAULT
@@ -778,12 +736,13 @@ def test_multi_input_model(ray_start_2_cpus):
)
return train_loader, None
trainer = TorchTrainer(
model_creator=model_creator,
data_creator=data_creator,
optimizer_creator=optimizer_creator,
loss_creator=lambda config: nn.MSELoss(),
num_workers=1)
Operator = TrainingOperator.from_creators(
model_creator,
optimizer_creator,
data_creator,
loss_creator=lambda config: nn.MSELoss())
trainer = TorchTrainer(training_operator_cls=Operator, num_workers=1)
metrics = trainer.train(num_steps=1)
assert metrics[BATCH_COUNT] == 1
@@ -794,4 +753,5 @@ def test_multi_input_model(ray_start_2_cpus):
if __name__ == "__main__":
import pytest
import sys
sys.exit(pytest.main(["-v", "-x", __file__]))
+54 -60
View File
@@ -50,19 +50,22 @@ def create_dataloaders(config):
class TestTorchRunner(unittest.TestCase):
def setUp(self):
self.Operator = TrainingOperator.from_creators(
model_creator,
optimizer_creator,
create_dataloaders,
loss_creator=loss_creator)
def testValidate(self):
class MockOperator(TrainingOperator):
class MockOperator(self.Operator):
def setup(self, config):
super(MockOperator, self).setup(config)
self.train_epoch = MagicMock(returns=dict(mean_accuracy=10))
self.validate = MagicMock(returns=dict(mean_accuracy=10))
runner = TorchRunner(
model_creator,
create_dataloaders,
optimizer_creator,
loss_creator,
training_operator_cls=MockOperator)
runner.setup()
runner = TorchRunner(training_operator_cls=MockOperator)
runner.setup_operator()
runner.train_epoch()
runner.train_epoch()
result = runner.train_epoch()
@@ -72,21 +75,17 @@ class TestTorchRunner(unittest.TestCase):
self.assertEqual(result["epoch"], 3)
def testtrain_epoch(self):
class MockOperator(TrainingOperator):
class MockOperator(self.Operator):
def setup(self, config):
super(MockOperator, self).setup(config)
self.count = 0
def train_epoch(self, *args, **kwargs):
self.count += 1
return {"count": self.count}
runner = TorchRunner(
model_creator,
create_dataloaders,
optimizer_creator,
loss_creator,
training_operator_cls=MockOperator)
runner.setup()
runner = TorchRunner(training_operator_cls=MockOperator)
runner.setup_operator()
runner.train_epoch(num_steps=1)
runner.train_epoch(num_steps=1)
result = runner.train_epoch()
@@ -95,11 +94,6 @@ class TestTorchRunner(unittest.TestCase):
self.assertEqual(result["epoch"], 3)
def testGivens(self):
class MockOperator(TrainingOperator):
def setup(self, config):
self.train_epoch = MagicMock(returns=dict(mean_accuracy=10))
self.validate = MagicMock(returns=dict(mean_accuracy=10))
def three_model_creator(config):
return nn.Linear(1, 1), nn.Linear(1, 1), nn.Linear(1, 1)
@@ -109,20 +103,27 @@ class TestTorchRunner(unittest.TestCase):
]
return opts[0], opts[1], opts[2]
runner = TorchRunner(
three_model_creator,
single_loader,
three_optimizer_creator,
loss_creator,
training_operator_cls=MockOperator)
runner.setup()
class MockOperator(TrainingOperator):
def setup(self, config):
models = three_model_creator(config)
optimizers = three_optimizer_creator(models, config)
loader = single_loader(config)
loss = loss_creator(config)
self.models, self.optimizers, self.criterion = \
self.register(models=models, optimizers=optimizers,
criterion=loss)
self.register_data(train_loader=loader, validation_loader=None)
self.train_epoch = MagicMock(returns=dict(mean_accuracy=10))
self.validate = MagicMock(returns=dict(mean_accuracy=10))
runner = TorchRunner(training_operator_cls=MockOperator)
runner.setup_operator()
self.assertEqual(len(runner.given_models), 3)
self.assertEqual(len(runner.given_optimizers), 3)
runner2 = TorchRunner(model_creator, single_loader, optimizer_creator,
loss_creator)
runner2.setup()
runner2 = TorchRunner(training_operator_cls=self.Operator)
runner2.setup_operator()
self.assertNotEqual(runner2.given_models, runner2.models)
self.assertNotEqual(runner2.given_optimizers, runner2.optimizers)
@@ -132,49 +133,42 @@ class TestTorchRunner(unittest.TestCase):
return (LinearDataset(2, 5), LinearDataset(2, 5, size=400),
LinearDataset(2, 5, size=400))
runner = TorchRunner(model_creator, three_data_loader,
optimizer_creator, loss_creator)
with self.assertRaises(ValueError):
runner.setup()
ThreeOperator = TrainingOperator.from_creators(
model_creator,
optimizer_creator,
three_data_loader,
loss_creator=loss_creator)
runner2 = TorchRunner(model_creator, three_data_loader,
optimizer_creator, loss_creator)
runner = TorchRunner(training_operator_cls=ThreeOperator)
with self.assertRaises(ValueError):
runner2.setup()
runner.setup_operator()
runner2 = TorchRunner(training_operator_cls=ThreeOperator)
with self.assertRaises(ValueError):
runner2.setup_operator()
def testSingleLoader(self):
runner = TorchRunner(model_creator, single_loader, optimizer_creator,
loss_creator)
runner.setup()
SingleOperator = TrainingOperator.from_creators(
model_creator,
optimizer_creator,
single_loader,
loss_creator=loss_creator)
runner = TorchRunner(training_operator_cls=SingleOperator)
runner.setup_operator()
runner.train_epoch()
with self.assertRaises(ValueError):
runner.validate()
def testNativeLoss(self):
runner = TorchRunner(
NativeOperator = TrainingOperator.from_creators(
model_creator,
single_loader,
optimizer_creator,
single_loader,
loss_creator=nn.MSELoss)
runner.setup()
runner = TorchRunner(training_operator_cls=NativeOperator)
runner.setup_operator()
runner.train_epoch()
def testMultiModel(self):
def multi_model_creator(config):
return nn.Linear(1, 1), nn.Linear(1, 1), nn.Linear(1, 1)
def multi_optimizer_creator(models, config):
opts = [
torch.optim.SGD(model.parameters(), lr=0.1) for model in models
]
return opts[0], opts[1], opts[2]
runner = TorchRunner(multi_model_creator, single_loader,
multi_optimizer_creator, loss_creator)
with self.assertRaises(ValueError):
runner.setup()
class TestLocalDistributedRunner(unittest.TestCase):
def setUp(self):
+7 -2
View File
@@ -4,6 +4,7 @@ logger = logging.getLogger(__name__)
TorchTrainer = None
TrainingOperator = None
BaseTorchTrainable = None
CreatorOperator = None
try:
import torch # noqa: F401
@@ -11,9 +12,13 @@ try:
from ray.util.sgd.torch.torch_trainer import (TorchTrainer,
BaseTorchTrainable)
from ray.util.sgd.torch.training_operator import TrainingOperator
from ray.util.sgd.torch.training_operator import (TrainingOperator,
CreatorOperator)
__all__ = ["TorchTrainer", "BaseTorchTrainable", "TrainingOperator"]
__all__ = [
"TorchTrainer", "BaseTorchTrainable", "TrainingOperator",
"CreatorOperator"
]
except ImportError as e:
logger.warning(e)
logger.warning("PyTorch not found. TorchTrainer will not be available")
@@ -4,7 +4,6 @@ import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader, IterableDataset
from torch.utils.data.distributed import DistributedSampler
from ray.util.sgd.torch.utils import setup_process_group
@@ -43,9 +42,6 @@ class DistributedTorchRunner(TorchRunner):
self.add_dist_sampler = add_dist_sampler
self.world_rank = None
def setup(self):
raise RuntimeError("Need to call setup commands separately.")
def setup_process_group(self, url, world_rank, world_size, timeout):
"""Connects the distributed PyTorch backend.
@@ -60,7 +56,7 @@ class DistributedTorchRunner(TorchRunner):
setup_process_group(
url, world_rank, world_size, timeout, backend=self.backend)
def setup_ddp_and_operator(self):
def setup_operator(self):
"""Runs distributed coordination components.
This helps avoid timeouts due to creator functions (perhaps
@@ -70,29 +66,18 @@ class DistributedTorchRunner(TorchRunner):
if self.use_gpu and torch.cuda.is_available():
device_ids = self.get_device_ids()
# Wrap dataloaders
self._wrap_dataloaders()
training_models = self.models
if self.wrap_ddp:
# This needs to happen after apex
training_models = [
DistributedDataParallel(model, device_ids=device_ids)
for model in self.models
]
self.training_operator = self.training_operator_cls(
self.config,
models=training_models,
optimizers=self.optimizers,
criterion=self.criterion,
train_loader=self.train_loader,
validation_loader=self.validation_loader,
world_rank=self.world_rank,
schedulers=self.schedulers,
device_ids=device_ids,
use_gpu=self.use_gpu,
use_fp16=self.use_fp16,
use_tqdm=self.use_tqdm)
use_tqdm=self.use_tqdm,
apex_args=self.apex_args,
wrap_ddp=self.wrap_ddp,
wrap_distributed_sampler=True,
add_dist_sampler=self.add_dist_sampler,
scheduler_step_freq=self.scheduler_step_freq)
def get_device_ids(self):
"""Needed for SyncBatchNorm, which needs 1 GPU per process."""
@@ -72,6 +72,17 @@ def init_hook():
class Training(TrainingOperator):
def setup(self, config):
model = getattr(models, config.get("model"))()
optimizer = optim.SGD(
model.parameters(), lr=0.01 * config["lr_scaler"])
train_data = LinearDataset(4,
2) # Have to use dummy data for training.
self.model, self.optimizer = self.register(
models=model,
optimizers=optimizer,
)
self.register_data(train_loader=train_data, validation_loader=None)
data = torch.randn(args.batch_size, 3, 224, 224)
target = torch.LongTensor(args.batch_size).random_() % 1000
if args.cuda:
@@ -107,14 +118,12 @@ if __name__ == "__main__":
print("Number of %ss: %d" % (device, num_workers))
trainer = TorchTrainer(
model_creator=lambda cfg: getattr(models, args.model)(),
optimizer_creator=lambda model, cfg: optim.SGD(
model.parameters(), lr=0.01 * cfg.get("lr_scaler")),
data_creator=lambda cfg: LinearDataset(4, 2), # Mock dataset.
initialization_hook=init_hook,
config=dict(
lr_scaler=num_workers),
training_operator_cls=Training,
initialization_hook=init_hook,
config={
"lr_scaler": num_workers,
"model": args.model
},
num_workers=num_workers,
use_gpu=args.cuda,
use_fp16=args.fp16,
@@ -2,6 +2,8 @@ import os
import torch
import torch.nn as nn
import argparse
from filelock import FileLock
from torch.utils.data import DataLoader, Subset
from torchvision.datasets import CIFAR10
import torchvision.transforms as transforms
@@ -9,9 +11,9 @@ import torchvision.transforms as transforms
from tqdm import trange
import ray
from ray.util.sgd.torch import TorchTrainer
from ray.util.sgd.torch import TorchTrainer, TrainingOperator
from ray.util.sgd.torch.resnet import ResNet18
from ray.util.sgd.utils import BATCH_SIZE
from ray.util.sgd.utils import BATCH_SIZE, override
def initialization_hook():
@@ -24,47 +26,66 @@ def initialization_hook():
# os.environ["NCCL_DEBUG"] = "INFO"
def cifar_creator(config):
transform_train = transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465),
(0.2023, 0.1994, 0.2010)),
]) # meanstd transformation
class CifarTrainingOperator(TrainingOperator):
@override(TrainingOperator)
def setup(self, config):
# Create model.
model = ResNet18(config)
transform_test = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465),
(0.2023, 0.1994, 0.2010)),
])
train_dataset = CIFAR10(
root="~/data", train=True, download=True, transform=transform_train)
validation_dataset = CIFAR10(
root="~/data", train=False, download=False, transform=transform_test)
# Create optimizer.
optimizer = torch.optim.SGD(
model.parameters(),
lr=config.get("lr", 0.1),
momentum=config.get("momentum", 0.9))
if config["test_mode"]:
train_dataset = Subset(train_dataset, list(range(64)))
validation_dataset = Subset(validation_dataset, list(range(64)))
# Load in training and validation data.
transform_train = transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465),
(0.2023, 0.1994, 0.2010)),
]) # meanstd transformation
train_loader = DataLoader(
train_dataset, batch_size=config[BATCH_SIZE], num_workers=2)
validation_loader = DataLoader(
validation_dataset, batch_size=config[BATCH_SIZE], num_workers=2)
return train_loader, validation_loader
transform_test = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465),
(0.2023, 0.1994, 0.2010)),
])
with FileLock(".ray.lock"):
train_dataset = CIFAR10(
root="~/data",
train=True,
download=True,
transform=transform_train)
validation_dataset = CIFAR10(
root="~/data",
train=False,
download=False,
transform=transform_test)
if config["test_mode"]:
train_dataset = Subset(train_dataset, list(range(64)))
validation_dataset = Subset(validation_dataset, list(range(64)))
def optimizer_creator(model, config):
"""Returns optimizer"""
return torch.optim.SGD(
model.parameters(),
lr=config.get("lr", 0.1),
momentum=config.get("momentum", 0.9))
train_loader = DataLoader(
train_dataset, batch_size=config[BATCH_SIZE], num_workers=2)
validation_loader = DataLoader(
validation_dataset, batch_size=config[BATCH_SIZE], num_workers=2)
# Create scheduler.
scheduler = torch.optim.lr_scheduler.MultiStepLR(
optimizer, milestones=[150, 250, 350], gamma=0.1)
def scheduler_creator(optimizer, config):
return torch.optim.lr_scheduler.MultiStepLR(
optimizer, milestones=[150, 250, 350], gamma=0.1)
# Create loss.
criterion = nn.CrossEntropyLoss()
# Register all components.
self.model, self.optimizer, self.criterion, self.scheduler = \
self.register(models=model, optimizers=optimizer,
criterion=criterion, schedulers=scheduler)
self.register_data(
train_loader=train_loader, validation_loader=validation_loader)
if __name__ == "__main__":
@@ -105,11 +126,7 @@ if __name__ == "__main__":
ray.init(address=args.address, num_cpus=num_cpus, log_to_driver=True)
trainer1 = TorchTrainer(
model_creator=ResNet18,
data_creator=cifar_creator,
optimizer_creator=optimizer_creator,
loss_creator=nn.CrossEntropyLoss,
scheduler_creator=scheduler_creator,
training_operator_cls=CifarTrainingOperator,
initialization_hook=initialization_hook,
num_workers=args.num_workers,
config={
@@ -3,6 +3,8 @@ import os
import torch
import torch.nn as nn
import argparse
from filelock import FileLock
from ray import tune
from ray.tune.schedulers import PopulationBasedTraining
from torch.utils.data import DataLoader, Subset
@@ -11,9 +13,9 @@ import torchvision.transforms as transforms
import ray
from ray.tune import CLIReporter
from ray.util.sgd.torch import TorchTrainer
from ray.util.sgd.torch import TorchTrainer, TrainingOperator
from ray.util.sgd.torch.resnet import ResNet18
from ray.util.sgd.utils import BATCH_SIZE
from ray.util.sgd.utils import BATCH_SIZE, override
def initialization_hook():
@@ -26,42 +28,62 @@ def initialization_hook():
# os.environ["NCCL_DEBUG"] = "INFO"
def cifar_creator(config):
transform_train = transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465),
(0.2023, 0.1994, 0.2010)),
]) # meanstd transformation
class CifarTrainingOperator(TrainingOperator):
@override(TrainingOperator)
def setup(self, config):
# Create model.
model = ResNet18(config)
transform_test = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465),
(0.2023, 0.1994, 0.2010)),
])
train_dataset = CIFAR10(
root="~/data", train=True, download=True, transform=transform_train)
validation_dataset = CIFAR10(
root="~/data", train=False, download=False, transform=transform_test)
# Create optimizer.
optimizer = torch.optim.SGD(
model.parameters(),
lr=config.get("lr", 0.1),
momentum=config.get("momentum", 0.9))
if config.get("test_mode"):
train_dataset = Subset(train_dataset, list(range(64)))
validation_dataset = Subset(validation_dataset, list(range(64)))
# Load in training and validation data.
transform_train = transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465),
(0.2023, 0.1994, 0.2010)),
]) # meanstd transformation
train_loader = DataLoader(
train_dataset, batch_size=config[BATCH_SIZE], num_workers=2)
validation_loader = DataLoader(
validation_dataset, batch_size=config[BATCH_SIZE], num_workers=2)
return train_loader, validation_loader
transform_test = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465),
(0.2023, 0.1994, 0.2010)),
])
with FileLock(".ray.lock"):
train_dataset = CIFAR10(
root="~/data",
train=True,
download=True,
transform=transform_train)
validation_dataset = CIFAR10(
root="~/data",
train=False,
download=False,
transform=transform_test)
def optimizer_creator(model, config):
"""Returns optimizer"""
return torch.optim.SGD(
model.parameters(),
lr=config.get("lr", 0.1),
momentum=config.get("momentum", 0.9))
if config.get("test_mode"):
train_dataset = Subset(train_dataset, list(range(64)))
validation_dataset = Subset(validation_dataset, list(range(64)))
train_loader = DataLoader(
train_dataset, batch_size=config[BATCH_SIZE], num_workers=2)
validation_loader = DataLoader(
validation_dataset, batch_size=config[BATCH_SIZE], num_workers=2)
# Create loss.
criterion = nn.CrossEntropyLoss()
self.model, self.optimizer, self.criterion = \
self.register(models=model, optimizers=optimizer,
criterion=criterion,)
self.register_data(
train_loader=train_loader, validation_loader=validation_loader)
if __name__ == "__main__":
@@ -101,10 +123,7 @@ if __name__ == "__main__":
ray.init(address=args.address, log_to_driver=True)
TorchTrainable = TorchTrainer.as_trainable(
model_creator=ResNet18,
data_creator=cifar_creator,
optimizer_creator=optimizer_creator,
loss_creator=nn.CrossEntropyLoss,
training_operator_cls=CifarTrainingOperator,
initialization_hook=initialization_hook,
num_workers=args.num_workers,
config={
+48 -45
View File
@@ -9,6 +9,7 @@ import torch.utils.data
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import numpy as np
from filelock import FileLock
from tqdm import trange
@@ -24,22 +25,6 @@ from ray.util.sgd.torch import TrainingOperator
MODEL_PATH = os.path.expanduser("~/.ray/models/mnist_cnn.pt")
def data_creator(config):
dataset = datasets.MNIST(
root="~/mnist/",
download=True,
transform=transforms.Compose([
transforms.Resize(32),
transforms.ToTensor(),
transforms.Normalize((0.5, ), (0.5, )),
]))
if config.get("test_mode"):
dataset = torch.utils.data.Subset(dataset, list(range(64)))
dataloader = torch.utils.data.DataLoader(
dataset, batch_size=config.get("batch_size", 32))
return dataloader
class Generator(nn.Module):
def __init__(self, latent_vector_size, features=32, num_channels=1):
super(Generator, self).__init__()
@@ -101,35 +86,57 @@ class LeNet(nn.Module):
return F.log_softmax(x, dim=1)
def model_creator(config):
def weights_init(m):
classname = m.__class__.__name__
if classname.find("Conv") != -1:
nn.init.normal_(m.weight.data, 0.0, 0.02)
elif classname.find("BatchNorm") != -1:
nn.init.normal_(m.weight.data, 1.0, 0.02)
nn.init.constant_(m.bias.data, 0)
discriminator = Discriminator()
discriminator.apply(weights_init)
generator = Generator(
latent_vector_size=config.get("latent_vector_size", 100))
generator.apply(weights_init)
return discriminator, generator
def optimizer_creator(models, config):
net_d, net_g = models
discriminator_opt = optim.Adam(
net_d.parameters(), lr=config.get("lr", 0.01), betas=(0.5, 0.999))
generator_opt = optim.Adam(
net_g.parameters(), lr=config.get("lr", 0.01), betas=(0.5, 0.999))
return discriminator_opt, generator_opt
def weights_init(m):
classname = m.__class__.__name__
if classname.find("Conv") != -1:
nn.init.normal_(m.weight.data, 0.0, 0.02)
elif classname.find("BatchNorm") != -1:
nn.init.normal_(m.weight.data, 1.0, 0.02)
nn.init.constant_(m.bias.data, 0)
class GANOperator(TrainingOperator):
def setup(self, config):
discriminator = Discriminator()
discriminator.apply(weights_init)
generator = Generator(
latent_vector_size=config.get("latent_vector_size", 100))
generator.apply(weights_init)
models = (discriminator, generator)
discriminator_opt = optim.Adam(
discriminator.parameters(),
lr=config.get("lr", 0.01),
betas=(0.5, 0.999))
generator_opt = optim.Adam(
generator.parameters(),
lr=config.get("lr", 0.01),
betas=(0.5, 0.999))
optimizers = (discriminator_opt, generator_opt)
with FileLock(".ray.lock"):
dataset = datasets.MNIST(
root="~/mnist/",
download=True,
transform=transforms.Compose([
transforms.Resize(32),
transforms.ToTensor(),
transforms.Normalize((0.5, ), (0.5, )),
]))
if config.get("test_mode"):
dataset = torch.utils.data.Subset(dataset, list(range(64)))
train_dataloader = torch.utils.data.DataLoader(
dataset, batch_size=config.get("batch_size", 32))
self.models, self.optimizers, self.criterion = self.register(
models=models, optimizers=optimizers, criterion=nn.BCELoss())
self.register_data(
train_loader=train_dataloader, validation_loader=None)
self.model = self.models[0]
self.optimizer = self.optimizers[0]
self.classifier = LeNet()
self.classifier.load_state_dict(
torch.load(config["classification_model_path"]))
@@ -232,10 +239,6 @@ def train_example(num_workers=1, use_gpu=False, test_mode=False):
"classification_model_path": MODEL_PATH
}
trainer = TorchTrainer(
model_creator=model_creator,
data_creator=data_creator,
optimizer_creator=optimizer_creator,
loss_creator=nn.BCELoss,
training_operator_cls=GANOperator,
num_workers=num_workers,
config=config,
@@ -9,6 +9,7 @@
from os.path import join
from ray.util.sgd.torch import TrainingOperator
from tqdm import trange
import torch.nn as nn
@@ -130,11 +131,13 @@ def main():
ray.init(address=args.ray_address)
trainer = TorchTrainer(
CustomTrainingOperator = TrainingOperator.from_creators(
model_creator=model_creator,
data_creator=data_creator,
optimizer_creator=optimizer_creator,
loss_creator=loss_creator,
data_creator=data_creator,
loss_creator=loss_creator)
trainer = TorchTrainer(
training_operator_cls=CustomTrainingOperator,
use_tqdm=True,
use_fp16=args.amp,
apex_args={"opt_level": "O1"},
@@ -7,6 +7,47 @@ in the documentation.
"""
# yapf: disable
# __torch_operator_start__
from ray.util.sgd.torch import TrainingOperator
class MyTrainingOperator(TrainingOperator):
def setup(self, config):
# Setup all components needed for training here. This could include
# data, models, optimizers, loss & schedulers.
# Setup data loaders.
train_dataset, val_dataset = LinearDataset(2, 5), LinearDataset(2,
5)
train_loader = DataLoader(train_dataset,
batch_size=config["batch_size"])
val_loader = DataLoader(val_dataset,
batch_size=config["batch_size"])
# Setup model.
model = nn.Linear(1, 1)
# Setup optimizer.
optimizer = torch.optim.SGD(model.parameters(), lr=config.get("lr", 1e-4))
# Setup loss.
criterion = torch.nn.BCELoss()
# Setup scheduler.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.9)
# Register all of these components with Ray SGD.
# This allows Ray SGD to do framework level setup like Cuda, DDP,
# Distributed Sampling, FP16.
# We also assign the return values of self.register to instance
# attributes so we can access it in our custom training/validation
# methods.
self.model, self.optimizer, self.criterion, self.scheduler = \
self.register(models=model, optimizers=optimizer,
criterion=criterion,
scheduler=scheduler)
self.register_data(train_loader=train_loader, validation_loader=val_loader)
# __torch_operator_end__
# __torch_model_start__
import torch.nn as nn
@@ -103,6 +144,21 @@ def scheduler_creator(optimizer, config):
# __torch_scheduler_end__
# __backwards_compat__start
from ray.util.sgd import TorchTrainer
MyTrainingOperator = TrainingOperator.from_creators(
model_creator=model_creator, optimizer_creator=optimizer_creator,
loss_creator=loss_creator, scheduler_creator=scheduler_creator,
data_creator=data_creator)
trainer = TorchTrainer(
training_operator_cls=MyTrainingOperator,
scheduler_step_freq="epoch", # if scheduler_creator is passed in
config={"lr": 0.001, "batch_size": 64})
# __backwards_compat_end
# __torch_ray_start__
import ray
@@ -114,12 +170,8 @@ ray.init()
from ray.util.sgd import TorchTrainer
trainer = TorchTrainer(
model_creator=model_creator,
data_creator=data_creator,
optimizer_creator=optimizer_creator,
loss_creator=nn.MSELoss,
scheduler_creator=scheduler_creator,
scheduler_step_freq="epoch", # if scheduler_creator is set
training_operator_cls=MyTrainingOperator,
scheduler_step_freq="epoch", # if scheduler is used
config={"lr": 0.001, "batch_size": 64})
# __torch_trainer_end__
@@ -4,6 +4,7 @@ import time
import torch
import torch.utils.data
from filelock import FileLock
from torch import nn
import torchvision
@@ -50,28 +51,6 @@ def get_dataset(name,
return ds, num_classes
def data_creator(config):
# Within a machine, this code runs synchronously.
dataset, num_classes = get_dataset(
args.dataset, "train", get_transform(train=True))
config["num_classes"] = num_classes
dataset_test, _ = get_dataset(
args.dataset, "val", get_transform(train=False))
data_loader = torch.utils.data.DataLoader(
dataset,
batch_size=args.batch_size,
num_workers=args.data_workers,
collate_fn=utils.collate_fn,
drop_last=True)
data_loader_test = torch.utils.data.DataLoader(
dataset_test,
batch_size=1,
num_workers=args.data_workers,
collate_fn=utils.collate_fn)
return data_loader, data_loader_test
def get_transform(train):
base_size = 520
crop_size = 480
@@ -101,7 +80,75 @@ def criterion(inputs, target):
return losses["out"] + 0.5 * losses["aux"]
def get_optimizer(model, aux_loss):
params_to_optimize = [
{
"params": [
p for p in model.backbone.parameters() if p.requires_grad
]
},
{
"params": [
p for p in model.classifier.parameters() if p.requires_grad
]
},
]
if aux_loss:
params = [
p for p in model.aux_classifier.parameters() if p.requires_grad
]
params_to_optimize.append({"params": params, "lr": args.lr * 10})
optimizer = torch.optim.SGD(
params_to_optimize,
lr=args.lr,
momentum=args.momentum,
weight_decay=args.weight_decay)
return optimizer
class SegOperator(TrainingOperator):
def setup(self, config):
args = config["args"]
# Create Data Loaders.
with FileLock(".ray.lock"):
# Within a machine, this code runs synchronously.
dataset, num_classes = get_dataset(
args.dataset, "train", get_transform(train=True))
config["num_classes"] = num_classes
dataset_test, _ = get_dataset(
args.dataset, "val", get_transform(train=False))
data_loader = torch.utils.data.DataLoader(
dataset,
batch_size=args.batch_size,
num_workers=args.data_workers,
collate_fn=utils.collate_fn,
drop_last=True)
data_loader_test = torch.utils.data.DataLoader(
dataset_test,
batch_size=1,
num_workers=args.data_workers,
collate_fn=utils.collate_fn)
# Create model.
model = torchvision.models.segmentation.__dict__[args.model](
num_classes=config["num_classes"],
aux_loss=args.aux_loss,
pretrained=args.pretrained)
if config["num_workers"] > 1:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
# Create optimizer.
optimizer = get_optimizer(model, aux_loss=args.aux_loss)
# Register components.
self.model, self.optimizer = self.register(
models=model,
optimizers=optimizer,
train_loader=data_loader,
validation_loader=data_loader_test)
def train_batch(self, batch, batch_info):
image, target = batch
image, target = image.to(self.device), target.to(self.device)
@@ -132,43 +179,6 @@ class SegOperator(TrainingOperator):
return confmat
def optimizer_creator(model, config):
args = config["args"]
params_to_optimize = [
{
"params": [
p for p in model.backbone.parameters() if p.requires_grad
]
},
{
"params": [
p for p in model.classifier.parameters() if p.requires_grad
]
},
]
if args.aux_loss:
params = [
p for p in model.aux_classifier.parameters() if p.requires_grad
]
params_to_optimize.append({"params": params, "lr": args.lr * 10})
return torch.optim.SGD(
params_to_optimize,
lr=args.lr,
momentum=args.momentum,
weight_decay=args.weight_decay)
def model_creator(config):
args = config["args"]
model = torchvision.models.segmentation.__dict__[args.model](
num_classes=config["num_classes"],
aux_loss=args.aux_loss,
pretrained=args.pretrained)
if config["num_workers"] > 1:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
return model
def main(args):
os.makedirs(args.output_dir, exist_ok=True)
@@ -176,9 +186,6 @@ def main(args):
start_time = time.time()
config = {"args": args, "num_workers": args.num_workers}
trainer = TorchTrainer(
model_creator=model_creator,
data_creator=data_creator,
optimizer_creator=optimizer_creator,
training_operator_cls=SegOperator,
use_tqdm=True,
use_fp16=True,
@@ -13,6 +13,7 @@ import torch
import torch.nn as nn
from ray.util.sgd import TorchTrainer
from ray.util.sgd.torch import TrainingOperator
class LinearDataset(torch.utils.data.Dataset):
@@ -67,12 +68,12 @@ def data_creator(config):
def train_example(num_workers=1, use_gpu=False):
CustomTrainingOperator = TrainingOperator.from_creators(
model_creator=model_creator, optimizer_creator=optimizer_creator,
data_creator=data_creator, scheduler_creator=scheduler_creator,
loss_creator=nn.MSELoss)
trainer1 = TorchTrainer(
model_creator=model_creator,
data_creator=data_creator,
optimizer_creator=optimizer_creator,
loss_creator=nn.MSELoss,
scheduler_creator=scheduler_creator,
training_operator_cls=CustomTrainingOperator,
num_workers=num_workers,
use_gpu=use_gpu,
config={
@@ -92,81 +92,76 @@ def announce_training(args, dataset_len, t_total):
logger.info(" Total optimization steps = %d", t_total)
def model_creator(config):
with FileLock(os.path.expanduser("~/.download.lock")):
args = config["args"]
processor = processors[args.task_name]()
label_list = processor.get_labels()
num_labels = len(label_list)
config = AutoConfig.from_pretrained(
args.config_name if args.config_name else args.model_name_or_path,
num_labels=num_labels,
finetuning_task=args.task_name,
cache_dir=args.cache_dir if args.cache_dir else None,
)
model = AutoModelForSequenceClassification.from_pretrained(
args.model_name_or_path,
from_tf=bool(".ckpt" in args.model_name_or_path),
config=config,
cache_dir=args.cache_dir if args.cache_dir else None,
)
return model
def optimizer_creator(model, cfg):
args = cfg["args"]
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
{
"params": [
p for n, p in model.named_parameters()
if not any(nd in n for nd in no_decay)
],
"weight_decay": args.weight_decay,
},
{
"params": [
p for n, p in model.named_parameters()
if any(nd in n for nd in no_decay)
],
"weight_decay": 0.0
},
]
return AdamW(
optimizer_grouped_parameters,
lr=args.learning_rate,
eps=args.adam_epsilon)
def data_creator(config):
args = config["args"]
start = time.time()
tokenizer = AutoTokenizer.from_pretrained(
args.tokenizer_name
if args.tokenizer_name else args.model_name_or_path,
cache_dir=args.cache_dir if args.cache_dir else None,
)
logger.info(f"tokenizer instantiation time: {time.time() - start}")
train_dataset = load_and_cache_examples(
args, args.task_name, tokenizer, evaluate=False)
train_sampler = RandomSampler(
train_dataset) if not dist.is_initialized() else None
return DataLoader(
train_dataset,
sampler=train_sampler,
batch_size=args.per_device_train_batch_size)
class TransformerOperator(TrainingOperator):
def setup(self, config):
self.args = args = config["args"]
start = time.time()
self.tokenizer = AutoTokenizer.from_pretrained(
args.tokenizer_name
if args.tokenizer_name else args.model_name_or_path,
cache_dir=args.cache_dir if args.cache_dir else None,
)
logger.info(f"tokenizer instantiation time: {time.time() - start}")
# Load data.
train_dataset = load_and_cache_examples(
args, args.task_name, self.tokenizer, evaluate=False)
train_sampler = RandomSampler(
train_dataset) if not dist.is_initialized() else None
train_loader = DataLoader(
train_dataset,
sampler=train_sampler,
batch_size=args.per_device_train_batch_size)
# Create model.
with FileLock(os.path.expanduser("~/.download.lock")):
processor = processors[args.task_name]()
label_list = processor.get_labels()
num_labels = len(label_list)
model_config = AutoConfig.from_pretrained(
args.config_name
if args.config_name else args.model_name_or_path,
num_labels=num_labels,
finetuning_task=args.task_name,
cache_dir=args.cache_dir if args.cache_dir else None,
)
model = AutoModelForSequenceClassification.from_pretrained(
args.model_name_or_path,
from_tf=bool(".ckpt" in args.model_name_or_path),
config=model_config,
cache_dir=args.cache_dir if args.cache_dir else None,
)
# Create optimizer.
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
{
"params": [
p for n, p in model.named_parameters()
if not any(nd in n for nd in no_decay)
],
"weight_decay": args.weight_decay,
},
{
"params": [
p for n, p in model.named_parameters()
if any(nd in n for nd in no_decay)
],
"weight_decay": 0.0
},
]
optimizer = AdamW(
optimizer_grouped_parameters,
lr=args.learning_rate,
eps=args.adam_epsilon)
# Register components.
self.model, self.optimizer = self.register(
models=model,
optimizers=optimizer,
train_loader=train_loader,
validation_loader=None)
self.train_data_len = len(self.train_loader)
self._warmup_scheduler = get_linear_schedule_with_warmup(
@@ -334,9 +329,6 @@ def main():
# Training
trainer = TorchTrainer(
model_creator=model_creator,
data_creator=data_creator,
optimizer_creator=optimizer_creator,
training_operator_cls=TransformerOperator,
use_fp16=args.fp16,
apex_args={"opt_level": args.fp16_opt_level},
@@ -36,6 +36,7 @@ def load_and_cache_examples(args, task, tokenizer, evaluate=False):
),
)
# Use FileLock to prevent parallel writes that may corrupt data.
with FileLock("/tmp/load_and_cache_examples.lock"):
if os.path.exists(cached_features_file) and not args.overwrite_cache:
logger.info("Loading features from cached file %s",
@@ -13,7 +13,7 @@ from torch.utils.data import DataLoader
import ray
from ray import tune
from ray.util.sgd.torch import TorchTrainer
from ray.util.sgd.torch import TorchTrainer, TrainingOperator
from ray.util.sgd.utils import BATCH_SIZE
from ray.util.sgd.torch.examples.train_example import LinearDataset
@@ -37,12 +37,9 @@ def data_creator(config):
# __torch_tune_example__
def tune_example(num_workers=1, use_gpu=False):
def tune_example(operator_cls, num_workers=1, use_gpu=False):
TorchTrainable = TorchTrainer.as_trainable(
model_creator=model_creator,
data_creator=data_creator,
optimizer_creator=optimizer_creator,
loss_creator=nn.MSELoss, # Note that we specify a Loss class.
training_operator_cls=operator_cls,
num_workers=num_workers,
use_gpu=use_gpu,
config={BATCH_SIZE: 128}
@@ -81,4 +78,8 @@ if __name__ == "__main__":
args, _ = parser.parse_known_args()
ray.init(address=args.address)
tune_example(num_workers=args.num_workers, use_gpu=args.use_gpu)
CustomTrainingOperator = TrainingOperator.from_creators(
model_creator=model_creator, optimizer_creator=optimizer_creator,
data_creator=data_creator, loss_creator=nn.MSELoss)
tune_example(CustomTrainingOperator, num_workers=args.num_workers,
use_gpu=args.use_gpu)
+85 -146
View File
@@ -1,26 +1,15 @@
from filelock import FileLock
import logging
import inspect
import io
import itertools
import os
import tempfile
import torch
import torch.nn as nn
import ray
from ray.util.sgd.torch.constants import USE_FP16, SCHEDULER_STEP, NUM_STEPS
from ray.util.sgd.torch.training_operator import TrainingOperator
from ray.util.sgd.torch.constants import USE_FP16, NUM_STEPS
from ray.util.sgd import utils
logger = logging.getLogger(__name__)
amp = None
try:
from collections.abc import Iterable
except ImportError:
from collections import Iterable
try:
from apex import amp
except ImportError:
@@ -32,12 +21,7 @@ class TorchRunner:
"""Manages a PyTorch model for training."""
def __init__(self,
model_creator,
data_creator,
optimizer_creator,
loss_creator=None,
scheduler_creator=None,
training_operator_cls=None,
training_operator_cls,
config=None,
use_gpu=False,
serialize_data_creation=True,
@@ -45,22 +29,11 @@ class TorchRunner:
use_tqdm=False,
apex_args=None,
scheduler_step_freq=None):
self.model_creator = model_creator
self.optimizer_creator = optimizer_creator
self.loss_creator = loss_creator
self.data_creator = data_creator
self.scheduler_creator = scheduler_creator
self.training_operator_cls = training_operator_cls or TrainingOperator
self.training_operator_cls = training_operator_cls
self.config = {} if config is None else config
self.timers = utils.TimerCollection()
self.epochs = 0
self.models = None
self.optimizers = None
self.criterion = None
self.schedulers = None
self.train_loader = None
self.validation_loader = None
self.training_operator = None
self.serialize_data_creation = serialize_data_creation
self.use_gpu = use_gpu
@@ -73,107 +46,16 @@ class TorchRunner:
"https://www.github.com/nvidia/apex to use fp16 training.")
self.scheduler_step_freq = scheduler_step_freq
def _validate_loaders(self, loaders):
assert loaders, "Loaders need to be returned in data_creator."
if isinstance(loaders, (tuple, list)):
if len(loaders) == 1:
return loaders, None
elif len(loaders) == 2:
return loaders
else:
raise ValueError(
f"Number of loaders must be <= 2. Got {loaders}")
# No great way of checking type otherwise
return loaders, None
def _initialize_dataloaders(self):
logger.debug("Instantiating dataloaders.")
loaders = None
if self.serialize_data_creation:
logger.debug("Serializing the dataloading process.")
with FileLock(
os.path.join(tempfile.gettempdir(), ".raydata.lock")):
loaders = self.data_creator(self.config)
else:
loaders = self.data_creator(self.config)
train_loader, val_loader = self._validate_loaders(loaders)
self.train_loader, self.validation_loader = train_loader, val_loader
def _create_loss(self):
if not self.loss_creator:
return
logger.debug("Creating loss.")
if inspect.isclass(self.loss_creator) and issubclass(
self.loss_creator, torch.nn.modules.loss._Loss):
self.criterion = self.loss_creator()
else:
self.criterion = self.loss_creator(self.config)
if self.use_gpu and torch.cuda.is_available():
if hasattr(self.criterion, "cuda"):
self.criterion = self.criterion.cuda()
def _create_schedulers_if_available(self):
# Learning rate schedules are optional.
if not self.scheduler_creator:
return
self.schedulers = self.scheduler_creator(self.given_optimizers,
self.config)
if not isinstance(self.schedulers, Iterable):
self.schedulers = [self.schedulers]
def _try_setup_apex(self):
"""Sets up the model for fp16 training via apex if available."""
if self.use_fp16 and amp:
self.models, self.optimizers = amp.initialize(
self.models, self.optimizers, **self.apex_args)
def setup(self):
"""Merges setup_components and setup_operator in one call."""
self.setup_components()
self.setup_operator()
def setup_components(self):
"""Runs the creator functions without any distributed coordination."""
logger.debug("Loading data.")
if self.data_creator and callable(self.data_creator):
self._initialize_dataloaders()
logger.debug("Creating model")
self.models = self.model_creator(self.config)
if not isinstance(self.models, Iterable):
self.models = [self.models]
assert all(isinstance(model, nn.Module) for model in self.models), (
f"All models must be PyTorch models: {self.models}.")
if self.use_gpu and torch.cuda.is_available():
self.models = [model.cuda() for model in self.models]
logger.debug("Creating optimizer.")
self.optimizers = self.optimizer_creator(self.given_models,
self.config)
if not isinstance(self.optimizers, Iterable):
self.optimizers = [self.optimizers]
self._create_schedulers_if_available()
self._try_setup_apex()
self._create_loss()
def setup_operator(self):
"""Create the training operator."""
self.training_operator = self.training_operator_cls(
self.config,
models=self.models,
optimizers=self.optimizers,
criterion=self.criterion,
train_loader=self.train_loader,
validation_loader=self.validation_loader,
world_rank=0,
schedulers=self.schedulers,
use_gpu=self.use_gpu,
use_fp16=self.use_fp16,
use_tqdm=self.use_tqdm)
use_tqdm=self.use_tqdm,
apex_args=self.apex_args,
scheduler_step_freq=self.scheduler_step_freq)
def get_node_ip(self):
"""Returns the IP address of the current node."""
@@ -196,7 +78,6 @@ class TorchRunner:
info.update({
NUM_STEPS: num_steps,
USE_FP16: self.use_fp16,
SCHEDULER_STEP: self.scheduler_step_freq
})
with self.timers.record("train_epoch"):
if iterator is None:
@@ -223,15 +104,17 @@ class TorchRunner:
def validate(self, num_steps=None, profile=False, info=None):
"""Evaluates the model on the validation data set."""
if self.validation_loader is None:
raise ValueError("No validation dataloader provided.")
raise ValueError("No validation dataloader provided. Make sure"
"you pass in a validation_loader to "
"TrainingOperator.register_data.")
info = info or {}
self._toggle_profiling(profile=profile)
validation_loader = self.validation_loader
with self.timers.record("validation"):
iterator = self.validation_loader
iterator = validation_loader
if num_steps:
iterator = itertools.islice(
iter(self.validation_loader), num_steps)
iterator = itertools.islice(iterator, num_steps)
validation_stats = self.training_operator.validate(
iterator, info=info)
if profile:
@@ -255,32 +138,35 @@ class TorchRunner:
"models": [model.state_dict() for model in self.models],
"optimizers": [opt.state_dict() for opt in self.optimizers]
}
if self.schedulers:
schedulers = self.schedulers
if schedulers:
state.update({
"schedulers": [
scheduler.state_dict() for scheduler in self.schedulers
scheduler.state_dict() for scheduler in schedulers
]
})
# Check if fp16 is True and if NVIDIA Apex is imported.
if self.use_fp16 and amp:
state.update({"amp": amp.state_dict()})
if self.use_fp16 and self.training_operator._amp:
state.update({"amp": self.training_operator._amp.state_dict()})
return state
def load_state_dict(self, state):
"""Sets the state of the model."""
for model, state_dict in zip(self.models, state["models"]):
models = self.models
for model, state_dict in zip(models, state["models"]):
model.load_state_dict(state_dict)
for optimizer, state_dict in zip(self.optimizers, state["optimizers"]):
optimizers = self.optimizers
for optimizer, state_dict in zip(optimizers, state["optimizers"]):
optimizer.load_state_dict(state_dict)
if self.schedulers:
for scheduler, state_dict in zip(self.schedulers,
state["schedulers"]):
schedulers = self.schedulers
if schedulers:
for scheduler, state_dict in zip(schedulers, state["schedulers"]):
scheduler.load_state_dict(state_dict)
if self.use_fp16 and "amp" in state and amp:
amp.load_state_dict(state["amp"])
if self.use_fp16 and "amp" in state and self.training_operator._amp:
self.training_operator._amp.load_state_dict(state["amp"])
self.epochs = state["epoch"]
self.training_operator.load_state_dict(state_dict)
self.training_operator.load_state_dict(state["operator"])
def state_stream(self):
"""Returns a bytes object for the state dict."""
@@ -304,14 +190,67 @@ class TorchRunner:
def shutdown(self):
"""Attempts to shut down the worker."""
del self.training_operator
del self.validation_loader
del self.train_loader
del self.criterion
del self.optimizers
del self.models
if torch.cuda.is_available():
torch.cuda.empty_cache()
@property
def models(self):
if not hasattr(self.training_operator, "_original_models"):
raise RuntimeError("Training Operator does not have any "
"registered models. Are you calling "
"self.register(...) inside the setup method "
"of your Training Operator?")
return self.training_operator._original_models
@property
def optimizers(self):
if not hasattr(self.training_operator, "_optimizers"):
raise RuntimeError("Training Operator does not have any "
"registered optimizers. Are you calling "
"self.register(...) inside the setup method "
"of your Training Operator?")
return self.training_operator._optimizers
@property
def schedulers(self):
if not hasattr(self.training_operator, "_schedulers"):
raise RuntimeError("Training Operator does not have any "
"registered schedulers. Are you calling "
"self.register(...) inside the setup method "
"of your Training Operator?")
return self.training_operator._schedulers
@property
def train_loader(self):
if not hasattr(self.training_operator, "_train_loader"):
logger.warning("Training Operator does not have any "
"registered train loader. If this is "
"unexepected, make sure to call "
"self.register_data(...) inside the setup method "
"of your Training Operator.")
return None
return self.training_operator._train_loader
@property
def validation_loader(self):
if not hasattr(self.training_operator, "_validation_loader"):
logger.warning("Training Operator does not have any "
"registered validation loader. If this is "
"unexepected, make sure to call "
"self.register_data(...) inside the setup method "
"of your Training Operator.")
return None
return self.training_operator._validation_loader
@property
def criterion(self):
if not hasattr(self.training_operator, "_criterion"):
raise RuntimeError("Training Operator does not have any "
"registered criterion. Are you calling "
"self.register(...) inside the setup method "
"of your Training Operator?")
return self.training_operator._criterion
@property
def given_models(self):
if len(self.models) > 1:
+61 -113
View File
@@ -13,6 +13,7 @@ from ray.exceptions import RayActorError
from ray.tune import Trainable
from ray.tune.resources import Resources
from ray.tune.utils.util import merge_dicts
from ray.util import log_once
from ray.util.sgd.torch.distributed_torch_runner import (
DistributedTorchRunner, LocalDistributedRunner, DeactivatedRunner)
from ray.util.sgd.utils import check_for_failure, NUM_SAMPLES, BATCH_SIZE
@@ -49,80 +50,44 @@ class TorchTrainer:
.. code-block:: python
def model_creator(config):
return nn.Linear(1, 1)
class MyTrainingOperator(TrainingOperator):
def setup(self, config):
model = nn.Linear(1, 1)
optimizer = torch.optim.SGD(
model.parameters(), lr=config.get("lr", 1e-4))
loss = torch.nn.MSELoss()
def optimizer_creator(model, config):
return torch.optim.SGD(
model.parameters(), lr=config.get("lr", 1e-4))
batch_size = config["batch_size"]
train_data, val_data = LinearDataset(2, 5), LinearDataset(2, 5)
train_loader = DataLoader(train_data, batch_size=batch_size)
val_loader = DataLoader(val_data, batch_size=batch_size)
self.model, self.optimizer = self.register(
models=model,
optimizers=optimizer,
criterion=loss)
def data_creator(config):
batch_size = config["batch_size"]
train_data, val_data = LinearDataset(2, 5), LinearDataset(2, 5)
train_loader = DataLoader(train_data, batch_size=batch_size)
val_loader = DataLoader(val_data, batch_size=batch_size)
return train_loader, val_loader
self.register_data(
train_loader=train_loader,
validation_loader=val_loader)
trainer = TorchTrainer(
model_creator=model_creator,
data_creator=data_creator,
optimizer_creator=optimizer_creator,
loss_creator=nn.MSELoss,
training_operator_cls=MyTrainingOperator,
config={"batch_size": 32},
use_gpu=True
)
for i in range(4):
trainer.train()
The creator functions will execute before distributed coordination and
training is setup. This is so that creator functions that download
large datasets will not trigger any timeouts.
The order of operations for creator functions are:
``data_creator`` -> ``model_creator`` -> ``optimizer_creator`` ->
``scheduler_creator`` -> ``loss_creator``.
Args:
model_creator (dict -> Model(s)): Constructor function that takes in
config and returns the model(s) to be optimized. These must be
``torch.nn.Module`` objects. If multiple models are returned,
a ``training_operator_cls`` must be specified. You do not need to
handle GPU/devices in this function; RaySGD will do that under
the hood.
data_creator (dict -> Iterable(s)): Constructor function
that takes in the passed config and returns one or
two Iterable objects. Note that even though two Iterable objects
can be returned, only one will be used for training, and the
other will be used for validation. If not provided, you must
provide a custom TrainingOperator.
optimizer_creator ((models, dict) -> optimizers): Constructor
function that takes in the return values from
``model_creator`` and the passed config and returns One or
more Torch optimizer objects. You do not need to handle
GPU/devices in this function; ``RaySGD`` will do that for you.
loss_creator (torch.nn.*Loss class | dict -> loss): A constructor
function for the training loss. This can be either a function that
takes in the provided config for customization or a subclass
of ``torch.nn.modules.loss._Loss``, which is most Pytorch
loss classes. For example, ``loss_creator=torch.nn.BCELoss``.
If not provided, you must provide a custom TrainingOperator.
scheduler_creator ((optimizers, dict) -> scheduler):
A constructor function for the torch scheduler. This is
a function that takes in the generated optimizers (from
``optimizer_creator``) provided config for customization.
Be sure to set ``scheduler_step_freq`` to increment the
scheduler correctly.
training_operator_cls (type): Custom training operator class
that subclasses the TrainingOperator class. This class
will be copied onto all remote workers and used to specify
custom training and validation operations. Defaults to
TrainingOperator.
training components and custom training and validation operations.
config (dict): Custom configuration value to be passed to
all creator and operator constructors.
all operator constructors.
num_workers (int): the number of workers used in distributed
training. If 1, the worker will not be wrapped with
DistributedDataParallel.
@@ -134,10 +99,6 @@ class TorchTrainer:
support "nccl", "gloo", and "auto". If "auto", RaySGD will
automatically use "nccl" if `use_gpu` is True, and "gloo"
otherwise.
serialize_data_creation (bool): A filelock will be used
to ensure no race conditions in data downloading among
different workers on the same node (using the local file system).
Defaults to True.
wrap_ddp (bool): Whether to automatically wrap DistributedDataParallel
over each model. If False, you are expected to call it yourself.
timeout_s (float): Seconds before the torch process group
@@ -171,12 +132,7 @@ class TorchTrainer:
def __init__(
self,
*,
model_creator,
data_creator,
optimizer_creator,
loss_creator=None,
scheduler_creator=None,
training_operator_cls=None,
training_operator_cls,
initialization_hook=None,
config=None,
num_workers=1,
@@ -185,16 +141,33 @@ class TorchTrainer:
backend="auto",
wrap_ddp=True,
timeout_s=NCCL_TIMEOUT_S,
serialize_data_creation=True,
use_fp16=False,
use_tqdm=False,
apex_args=None,
add_dist_sampler=True,
scheduler_step_freq=None,
# Deprecated Args.
num_replicas=None,
batch_size=None,
model_creator=None,
data_creator=None,
optimizer_creator=None,
scheduler_creator=None,
loss_creator=None,
serialize_data_creation=None,
data_loader_args=None,
):
if (model_creator or data_creator or optimizer_creator
or scheduler_creator or loss_creator):
raise DeprecationWarning(
"Creator functions are deprecated. You should create a "
"custom TrainingOperator, override setup, and register all "
"training state there. See TrainingOperator for more info. "
"If you would still like to use creator functions, you can "
"do CustomOperator = TrainingOperator.from_creators("
"model_creator, ...) and pass in CustomOperator into "
"TorchTrainer.")
if num_workers > 1 and not dist.is_available():
raise ValueError(
("Distributed PyTorch is not supported on macOS. "
@@ -202,10 +175,6 @@ class TorchTrainer:
"For more information, see "
"https://github.com/pytorch/examples/issues/467."))
if not (callable(model_creator) and callable(optimizer_creator)):
raise ValueError(
"Must provide a callable model_creator and optimizer_creator.")
if num_replicas is not None:
raise DeprecationWarning(
"num_replicas is deprecated. Use num_workers instead.")
@@ -217,24 +186,23 @@ class TorchTrainer:
"config={ray.util.sgd.utils.BATCH_SIZE: N} to specify a "
"batch size to be used across all workers.")
if serialize_data_creation is True:
if log_once("serialize_data_creation"):
logging.warning(
"serialize_data_creation is deprecated and will be "
"ignored. If you require serialized data loading you "
"should implement this in TrainingOperator.setup. "
"You may find FileLock useful here.")
if data_loader_args:
raise ValueError(
raise DeprecationWarning(
"data_loader_args is deprecated. You can return a "
"torch.utils.data.DataLoader in data_creator. Ray will "
"automatically set a DistributedSampler if a DataLoader is "
"returned and num_workers > 1.")
self.model_creator = model_creator
self.optimizer_creator = optimizer_creator
self.loss_creator = loss_creator
self.data_creator = data_creator
self.scheduler_creator = scheduler_creator
self.training_operator_cls = training_operator_cls
if not training_operator_cls and not loss_creator:
raise ValueError("If a loss_creator is not provided, you must "
"provide a custom training operator.")
self.initialization_hook = initialization_hook
self.config = {} if config is None else config
if use_gpu == "auto":
@@ -269,7 +237,7 @@ class TorchTrainer:
self.local_worker = DeactivatedRunner()
self.remote_workers = []
if scheduler_creator:
if scheduler_step_freq:
_validate_scheduler_step_freq(scheduler_step_freq)
self.scheduler_step_freq = scheduler_step_freq
@@ -309,11 +277,6 @@ class TorchTrainer:
worker_config[BATCH_SIZE] = batch_size_per_worker
params = dict(
model_creator=self.model_creator,
data_creator=self.data_creator,
optimizer_creator=self.optimizer_creator,
loss_creator=self.loss_creator,
scheduler_creator=self.scheduler_creator,
training_operator_cls=self.training_operator_cls,
config=worker_config,
serialize_data_creation=self.serialize_data_creation,
@@ -328,7 +291,7 @@ class TorchTrainer:
self.local_worker = TorchRunner(**params)
if self.initialization_hook:
self.apply_all_workers(self.initialization_hook)
self.local_worker.setup()
self.local_worker.setup_operator()
else:
params.update(
backend=self.backend,
@@ -355,15 +318,6 @@ class TorchTrainer:
# Compute URL for initializing distributed PyTorch
address = setup_address()
# Runs the creator functions.
remote_component_setup = [
worker.setup_components.remote()
for i, worker in enumerate(self.remote_workers)
]
self.local_worker.setup_components()
# Get setup tasks in order to throw errors on failure
ray.get(remote_component_setup)
# Setup the process group among all workers.
remote_pgroup_setups = [
worker.setup_process_group.remote(address, i + 1, num_workers,
@@ -377,10 +331,10 @@ class TorchTrainer:
# Runs code that requires all creator functions to have run.
remote_operator_setups = [
worker.setup_ddp_and_operator.remote()
worker.setup_operator.remote()
for worker in self.remote_workers
]
self.local_worker.setup_ddp_and_operator()
self.local_worker.setup_operator()
# Get setup tasks in order to throw errors on failure
ray.get(remote_operator_setups)
@@ -421,10 +375,10 @@ class TorchTrainer:
Returns:
(dict | list) A dictionary of metrics for training.
You can provide custom metrics by passing in a custom
``training_operator_cls``. If ``reduce_results=False``,
this will return a list of metric dictionaries whose
length will be equal to ``num_workers``.
You can provide custom metrics by implementing a custom
training loop. If ``reduce_results=False``, this will return a
list of metric dictionaries whose length will be equal to
``num_workers``.
"""
assert max_retries >= 0, "`max_retries` must be non-negative."
assert isinstance(dataset, Dataset) is not None \
@@ -577,12 +531,12 @@ class TorchTrainer:
return worker_stats
def update_scheduler(self, metric):
"""Calls ``scheduler.step(metric)`` on all schedulers.
"""Calls ``scheduler.step(metric)`` on all registered schedulers.
This is useful for lr_schedulers such as ``ReduceLROnPlateau``.
"""
self.apply_all_operators(
lambda op: [sched.step(metric) for sched in op.schedulers])
lambda op: [sched.step(metric) for sched in op._schedulers])
def get_model(self):
"""Returns the learned model(s)."""
@@ -729,10 +683,7 @@ class TorchTrainer:
.. code-block:: python
TorchTrainable = TorchTrainer.as_trainable(
model_creator=ResNet18,
data_creator=cifar_creator,
optimizer_creator=optimizer_creator,
loss_creator=nn.CrossEntropyLoss,
training_operator_cls=MyTrainingOperator,
num_gpus=2
)
analysis = tune.run(
@@ -781,10 +732,7 @@ class BaseTorchTrainable(Trainable):
.. code-block:: python
TorchTrainable = TorchTrainer.as_trainable(
model_creator=ResNet18,
data_creator=cifar_creator,
optimizer_creator=optimizer_creator,
loss_creator=nn.CrossEntropyLoss,
training_operator_cls=MyTrainingOperator,
num_gpus=2
)
# TorchTrainable is subclass of BaseTorchTrainable.
+572 -136
View File
@@ -1,10 +1,23 @@
import inspect
import logging
import os
import tempfile
import torch
import torch.nn as nn
from filelock import FileLock
from ray.util.sgd.utils import (TimerCollection, AverageMeterCollection,
NUM_SAMPLES)
from ray.util.sgd.torch.constants import (SCHEDULER_STEP_EPOCH, NUM_STEPS,
SCHEDULER_STEP_BATCH, SCHEDULER_STEP)
from ray.util.sgd.torch.constants import (
SCHEDULER_STEP_EPOCH,
NUM_STEPS,
SCHEDULER_STEP_BATCH,
)
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DistributedSampler, DataLoader, IterableDataset
logger = logging.getLogger(__name__)
amp = None
try:
@@ -18,6 +31,7 @@ except ImportError:
# Apex library is not installed, so we cannot enable mixed precision.
# We don't log here because logging happens in the torch_runner,
# where amp is initialized.
logger.debug("apex is not installed.")
pass
tqdm = None
@@ -33,15 +47,58 @@ def _is_multiple(component):
class TrainingOperator:
"""Abstract class for custom training or validation loops.
"""Abstract class to define training and validation state and logic.
The scheduler will only be called at a batch or epoch frequency, depending
on the user parameter. Be sure to set ``scheduler_step_freq`` in
``TorchTrainer`` to either "batch" or "epoch" to increment the scheduler
correctly during training. If using a learning rate scheduler
that depends on validation loss, you can use ``trainer.update_scheduler``.
You must subclass this class and override the ``setup`` method to define
your training components such as the model, optimizer, data, loss,
and scheduler. When you pass this class to ``TorchTrainer``, a copy of
this class will be made on each worker.
For both training and validation, there are two granularities that
.. code-block:: python
class MyTrainingOperator(TrainingOperator):
def setup(self, config):
model = nn.Linear(1, 1)
optimizer = torch.optim.SGD(
model.parameters(), lr=config.get("lr", 1e-4))
loss = torch.nn.MSELoss()
batch_size = config["batch_size"]
train_data, val_data = LinearDataset(2, 5), LinearDataset(2, 5)
train_loader = DataLoader(train_data, batch_size=batch_size)
val_loader = DataLoader(val_data, batch_size=batch_size)
self.model, self.optimizer = self.register(
models=model,
optimizers=optimizer,
criterion=loss)
self.register_data(
train_loader=train_loader,
validation_loader=val_loader)
trainer = TorchTrainer(
training_operator_cls=MyTrainingOperator,
config={"batch_size": 32},
use_gpu=True
)
for i in range(4):
trainer.train()
This class provides default implementations for training and validation.
Set ``self.model``, ``self.optimizer``, and
``self.criterion`` to leverage the default training and validation loops.
If ``self.scheduler`` is set, it will only be called at a batch or epoch
frequency, depending on the user parameter. Set
``scheduler_step_freq`` in ``TorchTrainer`` to either "batch" or "epoch"
to increment the scheduler correctly during training. If using a
learning rate scheduler that depends on validation loss, you can use
``trainer.update_scheduler``.
If you want to provide custom training and validation loops, you can do
so using this class as well. There are two granularities that
you can provide customization: per epoch or per batch.
You do not need to override both.
@@ -49,41 +106,30 @@ class TrainingOperator:
:scale: 80%
:align: center
If you are using multiple models, optimizers, or schedulers, you must
implement custom training and validation.
Raises:
ValueError if multiple models/optimizers/schedulers are provided.
You are expected to subclass this class if you wish
to train over multiple models/optimizers/schedulers.
ValueError
You are expected to either set ``self.model``,
``self.optimizer``, and ``self.criterion`` instance attributes in
setup or implement custom training & validation.
"""
def __init__(self,
config,
models,
optimizers,
train_loader,
validation_loader,
world_rank,
criterion=None,
schedulers=None,
device_ids=None,
use_gpu=False,
use_fp16=False,
use_tqdm=False):
use_tqdm=False,
apex_args=None,
wrap_ddp=False,
wrap_distributed_sampler=False,
add_dist_sampler=False,
scheduler_step_freq=None):
# You are not expected to override this method.
self._models = models # List of models
assert isinstance(
models,
Iterable), (f"Components need to be iterable. Got: {type(models)}")
self._optimizers = optimizers # List of optimizers
assert isinstance(optimizers, Iterable), (
f"Components need to be iterable. Got: {type(optimizers)}")
self._train_loader = train_loader
self._validation_loader = validation_loader
self._world_rank = world_rank
self._criterion = criterion
self._schedulers = schedulers
if schedulers:
assert isinstance(schedulers, Iterable), (
f"Components need to be iterable. Got: {type(schedulers)}")
self._config = config
self._use_fp16 = use_fp16
self._device_ids = device_ids
@@ -93,14 +139,12 @@ class TrainingOperator:
raise ValueError("tqdm must be installed to use tqdm in training.")
self._use_tqdm = use_tqdm
self.global_step = 0
self._apex_args = apex_args if apex_args else {}
self._wrap_ddp = wrap_ddp
self._wrap_distributed_sampler = wrap_distributed_sampler
self._add_dist_sampler = add_dist_sampler
self._scheduler_step_freq = scheduler_step_freq
if type(self) is TrainingOperator:
for component in (models, schedulers, optimizers):
if _is_multiple(component):
raise ValueError(
"Need to provide a custom operator subclassing "
"TrainingOperator if using multi-scheduler, "
"multi-model or multi-optimizer training/validation.")
self.timers = TimerCollection()
self.setup(config)
@@ -109,13 +153,233 @@ class TrainingOperator:
self.timers = timers
def setup(self, config):
"""Override this method to implement custom operator setup.
"""Override this method to implement operator setup.
You should call self.register and self.register_data here to
register training components and data loaders with Ray SGD.
Args:
config (dict): Custom configuration value to be passed to
all creator and operator constructors. Same as ``self.config``.
"""
pass
raise NotImplementedError
def register(self, *, models, optimizers, criterion=None, schedulers=None):
"""Registers parameters with Ray SGD and sets up training components.
By calling this method to register your models, optimizers,
criterion, and schedulers, Ray SGD will automatically handle
necessary setup such as GPU/devices, Distributed Data Parallel, and
Fp16. The registered components are returned and should be set as
instance attributes to access during training/validation.
If more than one model, optimizer, or scheduler is passed in,
you should implement your own custom training loop.
.. code-block:: python
class MyTrainingOperator(TrainingOperator):
def setup(self, config):
model = ...
optimizer = ...
train_loader = ...
val_loader = ...
loss = ...
self.model, self.optimizer, self.criterion = self.register(
models=model, optimizers=optimizer, criterion=loss)
# At this point DDP, Cuda, and Fp16
# are set up for all our components. We then use
# self.model, self.optimizer, etc. in our training loop.
self.register_data(train_loader=train_loader,
validation_loader=val_loader)
Args:
models (torch.nn.Module or Iterable[nn.Module]): Pytorch model or
multiple Pytorch models to use for training. If
`use_gpu=True` is passed into ``TorchTrainer``, and Cuda is
available, models will automatically be placed on GPU.
If ``wrap_ddp=True`` is passed into ``TorchTrainer``,
models will be wrapped in DDP. If wrap_ddp is False,
you should handle DDP for your models in setup.
optimizers (torch.optim.Optimizer or Iterable[
torch.optim.Optimizer]): Pytorch optimizer or multiple Pytorch
optimizers to use for training.
criterion (Callable, optional): Function to return loss
metric given features and target. If not provided,
must implement a custom training loop.
schedulers (torch.optim.lr_scheduler or Iterable[
torch.optim.lr_scheduler], optional): A learning rate
scheduler or multiple learning rate schedulers.
Returns:
Tuple of model, optimizer, criterion if not None, and scheduler
if not None.
"""
return_vals = []
logger.debug("Registering models.")
self._original_models = models
if not isinstance(self._original_models, Iterable):
self._original_models = [self._original_models]
assert all(
isinstance(model, nn.Module) for model in self._original_models), (
f"All models must be PyTorch models: {self._original_models}.")
if self.use_gpu and torch.cuda.is_available():
self._original_models = [
model.cuda() for model in self._original_models
]
logger.debug("Registering optimizers.")
self._optimizers = optimizers
if not isinstance(self._optimizers, Iterable):
self._optimizers = [self._optimizers]
if schedulers:
logger.debug("Registering scheduler.")
self._schedulers = schedulers
if not isinstance(self._schedulers, Iterable):
self._schedulers = [self._schedulers]
else:
self._schedulers = None
if criterion:
logger.debug("Registering loss.")
self._criterion = criterion
if self.use_gpu and torch.cuda.is_available():
if hasattr(self._criterion, "cuda"):
self._criterion = self._criterion.cuda()
else:
self._criterion = None
logger.debug("Setting up Apex.")
if self.use_fp16 and amp:
self._models, self._optimizers = amp.initialize(
self._models, self._optimizers, **self._apex_args)
self._amp = amp
if self._wrap_ddp:
logging.debug("Setting up DDP for models.")
self._models = [
DistributedDataParallel(model, device_ids=self.device_ids)
for model in self._original_models
]
else:
self._models = self._original_models
if len(self._models) == 1:
return_vals.append(self._models[0])
else:
return_vals.append(self._models)
if len(self._optimizers) == 1:
return_vals.append(self._optimizers[0])
else:
return_vals.append(self._optimizers)
if self._criterion is not None:
return_vals.append(self._criterion)
if self._schedulers is not None:
if self.scheduler_step_freq is None:
raise ValueError("scheduler_step_freq passed into "
"TorchTrainer cannot be None if you "
"are registering schedulers. Set this to "
"'manual' if you will be manually stepping "
"the schedulers.")
if len(self._schedulers) == 1:
return_vals.append(self._schedulers[0])
else:
return_vals.append(self._schedulers)
return tuple(return_vals)
def register_data(self, *, train_loader=None, validation_loader=None):
"""Registers data loaders with Ray SGD.
Calling this method will automatically setup Distributed Sampler for
these data loaders if add_dist_sampler=True is passed into the
TorchTrainer. This method does not return the wrapped data loaders.
You should use the iterators passed into train_epoch and validate
instead.
.. code-block:: python
class MyTrainingOperator(TrainingOperator):
def setup(self, config):
model = ...
optimizer = ...
train_loader = ...
val_loader = ...
loss = ...
self.model, self.optimizer, self.criterion = self.register(
models=model, optimizers=optimizer, criterion=loss)
self.register_data(train_loader=train_loader,
validation_loader=val_loader)
# At this point the data loaders are registered with
# Ray SGD and are wrapped with Distributed Samplers if
# applicable.
def train_epoch(self, iterator, info):
# If providing custom training or validation methods,
# the registered data loaders are passed in through the
# iterator parameter.
...
Args:
train_loader (Iterator): An iterator for training
data. If None is explicitly passed in, a Ray SGD Dataset
must be passed in through TorchTrainer.train. Ray SGD will
automatically use a Distributed Sampler if TorchTrainer(...,
add_dist_sampler=True).
validation_loader (Iterator): An iterator for validation
data. Ray SGD will automatically use a Distributed Sampler
if TorchTrainer(..., add_dist_sampler=True).
"""
logger.debug("Registering data loaders..")
self._train_loader = train_loader
self._validation_loader = validation_loader
if self._wrap_distributed_sampler:
logging.debug("Wrapping data loaders with DistributedSampler.")
def with_sampler(loader):
# Automatically set the DistributedSampler
data_loader_args = {
"dataset": loader.dataset,
"batch_size": loader.batch_size,
"shuffle": False,
"num_workers": loader.num_workers,
"collate_fn": loader.collate_fn,
"pin_memory": loader.pin_memory,
"drop_last": loader.drop_last,
"timeout": loader.timeout,
"worker_init_fn": loader.worker_init_fn,
"sampler": DistributedSampler(loader.dataset)
}
return DataLoader(**data_loader_args)
def should_wrap_dataloader(loader):
return (isinstance(loader, DataLoader)
and not isinstance(loader.dataset, IterableDataset))
if should_wrap_dataloader(self._train_loader):
if self._add_dist_sampler:
self._train_loader = with_sampler(self._train_loader)
if self._validation_loader is not None and should_wrap_dataloader(
self._validation_loader):
if self._add_dist_sampler:
self._validation_loader = with_sampler(
self._validation_loader)
def train_epoch(self, iterator, info):
"""Runs one standard training pass over the training dataloader.
@@ -156,6 +420,15 @@ class TrainingOperator:
Returns:
A dict of metrics from training.
"""
if not hasattr(self, "model"):
raise RuntimeError("Either set self.model in setup function or "
"override this method to implement a custom "
"training loop.")
model = self.model
scheduler = None
if hasattr(self, "scheduler"):
scheduler = self.scheduler
if self.use_tqdm and self.world_rank == 0:
desc = ""
if info is not None and "epoch_idx" in info:
@@ -163,15 +436,19 @@ class TrainingOperator:
desc = f"{info['epoch_idx'] + 1}/{info['num_epochs']}e"
else:
desc = f"{info['epoch_idx'] + 1}e"
# TODO: Implement len for Dataset?
total = info[NUM_STEPS]
if total is None:
if hasattr(iterator, "__len__"):
total = len(iterator)
_progress_bar = tqdm(
total=info[NUM_STEPS] or len(self.train_loader),
desc=desc,
unit="batch",
leave=False)
total=total, desc=desc, unit="batch", leave=False)
metric_meters = AverageMeterCollection()
self.model.train()
model.train()
for batch_idx, batch in enumerate(iterator):
batch_info = {
"batch_idx": batch_idx,
@@ -187,15 +464,14 @@ class TrainingOperator:
postfix.update(loss=metrics["train_loss"])
_progress_bar.set_postfix(postfix)
if self.scheduler and batch_info.get(
SCHEDULER_STEP) == SCHEDULER_STEP_BATCH:
self.scheduler.step()
if scheduler and self.scheduler_step_freq == SCHEDULER_STEP_BATCH:
scheduler.step()
metric_meters.update(metrics, n=metrics.pop(NUM_SAMPLES, 1))
self.global_step += 1
if self.scheduler and info.get(SCHEDULER_STEP) == SCHEDULER_STEP_EPOCH:
self.scheduler.step()
if scheduler and self.scheduler_step_freq == SCHEDULER_STEP_EPOCH:
scheduler.step()
return metric_meters.summary()
@@ -211,9 +487,7 @@ class TrainingOperator:
automatically.
You can provide custom loss metrics and training operations if you
override this method. If overriding this method, you can access model,
optimizer, criterion via ``self.model``, ``self.optimizer``,
and ``self.criterion``.
override this method.
You do not need to override this method if you plan to
override ``train_epoch``.
@@ -232,6 +506,21 @@ class TrainingOperator:
calculate averages.
"""
if not hasattr(self, "model"):
raise RuntimeError("Either set self.model in setup function or "
"override this method to implement a custom "
"training loop.")
if not hasattr(self, "optimizer"):
raise RuntimeError("Either set self.optimizer in setup function "
"or override this method to implement a custom "
"training loop.")
if not hasattr(self, "criterion"):
raise RuntimeError("Either set self.criterion in setup function "
"or override this method to implement a custom "
"training loop.")
model = self.model
optimizer = self.optimizer
criterion = self.criterion
# unpack features into list to support multiple inputs model
*features, target = batch
# Create non_blocking tensors for distributed training
@@ -243,21 +532,21 @@ class TrainingOperator:
# Compute output.
with self.timers.record("fwd"):
output = self.model(*features)
loss = self.criterion(output, target)
output = model(*features)
loss = criterion(output, target)
# Compute gradients in a backward pass.
with self.timers.record("grad"):
self.optimizer.zero_grad()
optimizer.zero_grad()
if self.use_fp16:
with amp.scale_loss(loss, self.optimizer) as scaled_loss:
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
else:
loss.backward()
# Call step of optimizer to update model params.
with self.timers.record("apply"):
self.optimizer.step()
optimizer.step()
return {"train_loss": loss.item(), NUM_SAMPLES: features[0].size(0)}
@@ -267,9 +556,8 @@ class TrainingOperator:
This will call ``model.eval()`` and ``torch.no_grad`` when iterating
over the validation dataloader.
If overriding this method, you can access model, criterion via
``self.model`` and ``self.criterion``. You also do not need to call
``validate_batch`` if overriding this method.
You also do not need to call ``validate_batch`` if overriding this
method.
Args:
val_iterator (iter): Iterable constructed from the
@@ -284,10 +572,15 @@ class TrainingOperator:
from ``validate_batch`` and dividing it by the sum of
``num_samples`` from all calls to ``self.validate_batch``.
"""
if not hasattr(self, "model"):
raise RuntimeError("Either set self.model in setup function or "
"override this method to implement a custom "
"validation loop.")
model = self.model
metric_meters = AverageMeterCollection()
# switch to evaluate mode
self.model.eval()
model.eval()
with torch.no_grad():
for batch_idx, batch in enumerate(val_iterator):
batch_info = {"batch_idx": batch_idx}
@@ -319,6 +612,16 @@ class TrainingOperator:
by default, ``validate`` uses "num_samples" to
calculate averages.
"""
if not hasattr(self, "model"):
raise RuntimeError("Either set self.model in setup function or "
"override this method to implement a custom "
"training loop.")
if not hasattr(self, "criterion"):
raise RuntimeError("Either set self.criterion in setup function "
"or override this method to implement a custom "
"training loop.")
model = self.model
criterion = self.criterion
# unpack features into list to support multiple inputs model
*features, target = batch
if self.use_gpu:
@@ -330,8 +633,8 @@ class TrainingOperator:
# compute output
with self.timers.record("eval_fwd"):
output = self.model(*features)
loss = self.criterion(output, target)
output = model(*features)
loss = criterion(output, target)
_, predicted = torch.max(output.data, 1)
num_correct = (predicted == target).sum().item()
@@ -344,6 +647,9 @@ class TrainingOperator:
def state_dict(self):
"""Override this to return a representation of the operator state.
Any argument passed into self.register and self.register_data will
automatically be saved.
Use this method to save any additional state.
Returns:
dict: The state dict of the operator."""
@@ -351,11 +657,81 @@ class TrainingOperator:
def load_state_dict(self, state_dict):
"""Override this to load the representation of the operator state.
Anything passed into self.register and self.register_data will
automatically be loaded. Use this method to load any additional state.
Args:
state_dict (dict): State dict as returned by the operator. """
pass
@classmethod
def from_creators(cls,
model_creator,
optimizer_creator,
data_creator=None,
loss_creator=None,
scheduler_creator=None,
serialize_data_creation=True):
"""A utility method to create a custom TrainingOperator class from
creator functions. This is useful for backwards compatibility with
previous versions of Ray. To provide custom training and validation,
you should subclass the class that is returned by this method instead
of ``TrainingOperator``.
Args:
model_creator (dict -> Model(s)): Constructor function that takes
in config and returns the model(s) to be optimized. These
must be ``torch.nn.Module`` objects. If multiple models are
returned, a ``training_operator_cls`` must be specified.
You do not need to handle GPU/devices in this function;
RaySGD will do that under the hood.
data_creator (dict -> Iterable(s)): Constructor function
that takes in the passed config and returns one or
two Iterable objects. Note that even though two Iterable
objects can be returned, only one will be used for training,
and the other will be used for validation. If not provided,
you must pass in a Dataset to ``TorchTrainer.train``.
optimizer_creator ((models, dict) -> optimizers): Constructor
function that takes in the return values from
``model_creator`` and the passed config and returns One or
more Torch optimizer objects. You do not need to handle
GPU/devices in this function; ``RaySGD`` will do that for you.
loss_creator (torch.nn.*Loss class | dict -> loss): A constructor
function for the training loss. This can be either a function
that takes in the provided config for customization or a
subclass of ``torch.nn.modules.loss._Loss``, which is most
Pytorch loss classes. For example,
``loss_creator=torch.nn.BCELoss``. If not provided, you must
provide a custom TrainingOperator.
scheduler_creator ((optimizers, dict) -> scheduler):
A constructor function for the torch scheduler. This is
a function that takes in the generated optimizers (from
``optimizer_creator``) provided config for customization.
Be sure to set ``scheduler_step_freq`` to increment the
scheduler correctly.
serialize_data_creation (bool): A filelock will be used
to ensure no race conditions in data downloading among
different workers on the same node (using the local file
system). Defaults to True.
Returns:
A TrainingOperator class with a ``setup`` method that utilizes
the passed in creator functions.
"""
if not (callable(model_creator) and callable(optimizer_creator)):
raise ValueError(
"Must provide a callable model_creator and optimizer_creator.")
class CustomCreatorOperator(CreatorOperator):
_model_creator = model_creator
_optimizer_creator = optimizer_creator
_data_creator = data_creator
_loss_creator = loss_creator
_scheduler_creator = scheduler_creator
_serialize_data_creation = serialize_data_creation
return CustomCreatorOperator
@property
def device(self):
"""torch.device: The appropriate torch device, at your convenience."""
@@ -366,58 +742,11 @@ class TrainingOperator:
"""dict: Provided into TorchTrainer."""
return self._config
@property
def model(self):
"""First or only model created by the provided ``model_creator``."""
return self._models[0]
@property
def models(self):
"""List of models created by the provided ``model_creator``."""
return self._models
@property
def optimizer(self):
"""First or only optimizer(s) created by the ``optimizer_creator``."""
return self._optimizers[0]
@property
def optimizers(self):
"""List of optimizers created by the ``optimizer_creator``."""
return self._optimizers
@property
def train_loader(self):
"""Iterable: 1st Dataloader from ``data_creator``.
"""
return self._train_loader
@property
def validation_loader(self):
"""Iterable: 2nd Dataloader from ``data_creator``."""
return self._validation_loader
@property
def world_rank(self):
"""int: The rank of the parent runner. Always 0 if not distributed."""
return self._world_rank
@property
def criterion(self):
"""Criterion created by the provided ``loss_creator``."""
return self._criterion
@property
def scheduler(self):
"""First or only scheduler(s) created by the ``scheduler_creator``."""
if self._schedulers:
return self._schedulers[0]
@property
def schedulers(self):
"""List of schedulers created by the ``scheduler_creator``."""
return self._schedulers
@property
def use_gpu(self):
"""Returns True if cuda is available and use_gpu is True."""
@@ -441,31 +770,138 @@ class TrainingOperator:
"""
return self._device_ids
@property
def scheduler_step_freq(self):
"""Optional[str]: The ``scheduler_step_freq`` passed into
``TorchTrainer``
class _TestingOperator(TrainingOperator):
def train_epoch(self, iterator, info):
func = self.config.get("custom_func")
if callable(func):
return func(self, iterator, info)
return {"done": 1}
This is useful to determine when to call scheduler.step.
"""
return self._scheduler_step_freq
class _TestMetricsOperator(TrainingOperator):
class CreatorOperator(TrainingOperator):
"""A subclass of TrainingOperator specifically for defining training
state using creator functions.
"""
def _validate_loaders(self, loaders):
assert loaders, "Loaders need to be returned in data_creator."
if isinstance(loaders, (tuple, list)):
if len(loaders) == 1:
return loaders, None
elif len(loaders) == 2:
return loaders
else:
raise ValueError(
f"Number of loaders must be <= 2. Got {loaders}")
# No great way of checking type otherwise
return loaders, None
def _initialize_dataloaders(self, config):
logger.debug("Instantiating dataloaders.")
loaders = None
if self._serialize_data_creation:
logger.debug("Serializing the dataloading process.")
with FileLock(
os.path.join(tempfile.gettempdir(), ".raydata.lock")):
loaders = self.__class__._data_creator(config)
else:
loaders = self.__class__._data_creator(config)
train_loader, val_loader = self._validate_loaders(loaders)
return train_loader, val_loader
def setup(self, config):
self._train_scores = config["scores"].copy()
self._val_scores = config["val_scores"].copy()
self.key = config["key"]
kwargs = {}
logger.debug("Loading data.")
train_loader = None
validation_loader = None
if self.__class__._data_creator and callable(
self.__class__._data_creator):
train_loader, validation_loader = self._initialize_dataloaders(
config)
def train_batch(self, batch, batch_info=None):
metrics = super(_TestMetricsOperator, self).train_batch(
batch, batch_info)
num_samples = metrics[NUM_SAMPLES]
metrics.update({self.key: self._train_scores.pop(0) / num_samples})
return metrics
logger.debug("Creating model")
models = self.__class__._model_creator(config)
def validate_batch(self, batch, batch_info=None):
metrics = super(_TestMetricsOperator, self).validate_batch(
batch, batch_info)
num_samples = metrics[NUM_SAMPLES]
metrics.update({self.key: self._val_scores.pop(0) / num_samples})
return metrics
kwargs["models"] = models
logger.debug("Creating optimizer.")
optimizers = self.__class__._optimizer_creator(models, config)
kwargs["optimizers"] = optimizers
if self.__class__._scheduler_creator:
logger.debug("Creating scheduler.")
schedulers = self.__class__._scheduler_creator(optimizers, config)
kwargs["schedulers"] = schedulers
if self.__class__._loss_creator:
logger.debug("Creating loss.")
if inspect.isclass(self.__class__._loss_creator) and issubclass(
self.__class__._loss_creator, torch.nn.modules.loss._Loss):
criterion = self.__class__._loss_creator()
else:
criterion = self.__class__._loss_creator(config)
kwargs["criterion"] = criterion
state = self.register(**kwargs)
self.models, self.optimizers = state[:2]
if isinstance(self.models, tuple):
self.model = self.models[0]
else:
self.model = self.models
if isinstance(self.optimizers, tuple):
self.optimizer = self.optimizers[0]
else:
self.optimizer = self.optimizers
if len(state) >= 3:
self.criterion = state[2]
if len(state) == 4:
self.schedulers = state[3]
if isinstance(self.schedulers, tuple):
self.scheduler = self.schedulers[0]
else:
self.scheduler = self.schedulers
self.register_data(
train_loader=train_loader, validation_loader=validation_loader)
def get_test_operator(operator_cls):
class _TestingOperator(operator_cls):
def train_epoch(self, iterator, info):
func = self.config.get("custom_func")
if callable(func):
return func(self, iterator, info)
return {"done": 1}
return _TestingOperator
def get_test_metrics_operator(operator_cls):
class _TestMetricsOperator(operator_cls):
def setup(self, config):
super(_TestMetricsOperator, self).setup(config)
self._train_scores = config["scores"].copy()
self._val_scores = config["val_scores"].copy()
self.key = config["key"]
def train_batch(self, batch, batch_info=None):
metrics = super(_TestMetricsOperator, self).train_batch(
batch, batch_info)
num_samples = metrics[NUM_SAMPLES]
metrics.update({self.key: self._train_scores.pop(0) / num_samples})
return metrics
def validate_batch(self, batch, batch_info=None):
metrics = super(_TestMetricsOperator, self).validate_batch(
batch, batch_info)
num_samples = metrics[NUM_SAMPLES]
metrics.update({self.key: self._val_scores.pop(0) / num_samples})
return metrics
return _TestMetricsOperator