[tune/sgd] DCGAN example self-contained, turn example into module (#8012)

* ok

* done

* run_benchmarks

* should_make_examples_usable
This commit is contained in:
Richard Liaw
2020-04-16 17:55:27 -07:00
committed by GitHub
parent 0c80efa2a3
commit 6545534805
6 changed files with 44 additions and 11 deletions
@@ -49,6 +49,8 @@ beta1 = 0.5
# iterations of actual training in each Trainable _train
train_iterations_per_step = 5
MODEL_PATH = os.path.expanduser("~/.ray/models/mnist_cnn.pt")
def get_data_loader():
dataset = dset.MNIST(
@@ -305,6 +307,16 @@ if __name__ == "__main__":
args, _ = parser.parse_known_args()
ray.init()
import urllib.request
# Download a pre-trained MNIST model for inception score calculation.
# This is a tiny model (<100kb).
if not os.path.exists(MODEL_PATH):
print("downloading model")
os.makedirs(os.path.dirname(MODEL_PATH), exist_ok=True)
urllib.request.urlretrieve(
"https://github.com/ray-project/ray/raw/master/python/ray/tune/"
"examples/pbt_dcgan_mnist/mnist_cnn.pt", MODEL_PATH)
dataloader = get_data_loader()
if not args.smoke_test:
# Plot some training images
@@ -322,10 +334,7 @@ if __name__ == "__main__":
# load the pretrained mnist classification model for inception_score
mnist_cnn = Net()
model_path = os.path.join(
os.path.dirname(ray.__file__),
"tune/examples/pbt_dcgan_mnist/mnist_cnn.pt")
mnist_cnn.load_state_dict(torch.load(model_path))
mnist_cnn.load_state_dict(torch.load(MODEL_PATH))
mnist_cnn.eval()
mnist_model_ref = ray.put(mnist_cnn)
@@ -16,6 +16,8 @@ from ray.util.sgd.torch import TrainingOperator
parser = argparse.ArgumentParser(
description="PyTorch Synthetic Benchmark",
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument(
"--smoke-test", action="store_true", default=False, help="finish quickly.")
parser.add_argument(
"--fp16", action="store_true", default=False, help="use fp16 training")
@@ -49,6 +51,16 @@ parser.add_argument(
help="Disables cluster training")
args = parser.parse_args()
if args.smoke_test:
args.model = "resnet18"
args.batch_size = 1
args.num_iters = 1
args.num_batches_per_iter = 2
args.num_warmup_batches = 2
args.local = True
args.no_cuda = True
args.cuda = not args.no_cuda and torch.cuda.is_available()
device = "GPU" if args.cuda else "CPU"
@@ -68,7 +80,6 @@ class Training(TrainingOperator):
self.data, self.target = data, target
def train_epoch(self, *pargs, **kwargs):
# print(self.model)
def benchmark():
self.optimizer.zero_grad()
output = self.model(self.data)
@@ -76,11 +87,11 @@ class Training(TrainingOperator):
loss.backward()
self.optimizer.step()
# print("Running warmup...")
print("Running warmup...")
if self.global_step == 0:
timeit.timeit(benchmark, number=args.num_warmup_batches)
self.global_step += 1
# print("Running benchmark...")
print("Running benchmark...")
time = timeit.timeit(benchmark, number=args.num_batches_per_iter)
img_sec = args.batch_size * args.num_batches_per_iter / time
return {"img_sec": img_sec}
@@ -99,7 +110,7 @@ if __name__ == "__main__":
model_creator=lambda cfg: getattr(models, args.model)(),
optimizer_creator=lambda model, cfg: optim.SGD(
model.parameters(), lr=0.01 * cfg.get("lr_scaler")),
data_creator=lambda cfg: LinearDataset(4, 2),
data_creator=lambda cfg: LinearDataset(4, 2), # Mock dataset.
initialization_hook=init_hook,
config=dict(
lr_scaler=num_workers),
+13 -3
View File
@@ -21,6 +21,8 @@ from ray.util.sgd import TorchTrainer
from ray.util.sgd.utils import override
from ray.util.sgd.torch import TrainingOperator
MODEL_PATH = os.path.expanduser("~/.ray/models/mnist_cnn.pt")
def data_creator(config):
dataset = datasets.MNIST(
@@ -227,9 +229,7 @@ def train_example(num_workers=1, use_gpu=False, test_mode=False):
config = {
"test_mode": test_mode,
"batch_size": 16 if test_mode else 512 // num_workers,
"classification_model_path": os.path.join(
os.path.dirname(ray.__file__),
"util/sgd/torch/examples/mnist_cnn.pt")
"classification_model_path": MODEL_PATH
}
trainer = TorchTrainer(
model_creator=model_creator,
@@ -256,6 +256,16 @@ def train_example(num_workers=1, use_gpu=False, test_mode=False):
if __name__ == "__main__":
import urllib.request
# Download a pre-trained MNIST model for inception score calculation.
# This is a tiny model (<100kb).
if not os.path.exists(MODEL_PATH):
print("downloading model")
os.makedirs(os.path.dirname(MODEL_PATH), exist_ok=True)
urllib.request.urlretrieve(
"https://github.com/ray-project/ray/raw/master/python/ray/tune/"
"examples/pbt_dcgan_mnist/mnist_cnn.pt", MODEL_PATH)
parser = argparse.ArgumentParser()
parser.add_argument(
"--smoke-test", action="store_true", help="Finish quickly for testing")
Binary file not shown.