From 584adb45b87a2414ba55b0ad75504cb356a9cb30 Mon Sep 17 00:00:00 2001 From: cgraywang Date: Wed, 8 May 2019 21:39:07 -0700 Subject: [PATCH] [tune] Add MXNet Gluon example on CIFAR-10 (#4683) --- python/ray/tune/examples/README.rst | 2 + .../ray/tune/examples/tune_cifar10_gluon.py | 224 ++++++++++++++++++ 2 files changed, 226 insertions(+) create mode 100644 python/ray/tune/examples/tune_cifar10_gluon.py diff --git a/python/ray/tune/examples/README.rst b/python/ray/tune/examples/README.rst index 172bb4ef4..79fef2c05 100644 --- a/python/ray/tune/examples/README.rst +++ b/python/ray/tune/examples/README.rst @@ -60,3 +60,5 @@ Contributed Examples A contributed example of tuning a Keras model on CIFAR10 with the PopulationBasedTraining scheduler. - `genetic_example `__: Optimizing the michalewicz function using the contributed GeneticSearch search algorithm with AsyncHyperBandScheduler. +- `tune_cifar10_gluon `__: + MXNet Gluon example to use Tune with the function-based API on CIFAR-10 dataset. diff --git a/python/ray/tune/examples/tune_cifar10_gluon.py b/python/ray/tune/examples/tune_cifar10_gluon.py new file mode 100644 index 000000000..63db20186 --- /dev/null +++ b/python/ray/tune/examples/tune_cifar10_gluon.py @@ -0,0 +1,224 @@ +from __future__ import print_function + +import argparse +import random + +import mxnet as mx +import numpy as np + +from mxnet import gluon, init +from mxnet import autograd as ag +from mxnet.gluon import nn +from mxnet.gluon.data.vision import transforms +from gluoncv.model_zoo import get_model +from gluoncv.data import transforms as gcv_transforms + +# Training settings +parser = argparse.ArgumentParser(description="CIFAR-10 Example") +parser.add_argument( + "--model", + required=True, + type=str, + default="resnet50_v1b", + help="name of the pretrained model from gluoncv model zoo" + "(default: resnet50_v1b).") +parser.add_argument( + "--batch_size", + type=int, + default=64, + metavar="N", + help="input batch size for training (default: 64)") +parser.add_argument( + "--epochs", + type=int, + default=1, + metavar="N", + help="number of epochs to train (default: 1)") +parser.add_argument( + "--num_gpus", + default=0, + type=int, + help="number of gpus to use, 0 indicates cpu only (default: 0)") +parser.add_argument( + "--num_workers", + default=4, + type=int, + help="number of preprocessing workers (default: 4)") +parser.add_argument( + "--classes", + type=int, + default=10, + metavar="N", + help="number of outputs (default: 10)") +parser.add_argument( + "--lr", + default=0.001, + type=float, + help="initial learning rate (default: 0.001)") +parser.add_argument( + "--momentum", + default=0.9, + type=float, + help="initial momentum (default: 0.9)") +parser.add_argument( + "--wd", default=1e-4, type=float, help="weight decay (default: 1e-4)") +parser.add_argument( + "--expname", type=str, default="cifar10exp", help="experiments location") +parser.add_argument( + "--num_samples", + type=int, + default=20, + metavar="N", + help="number of samples (default: 20)") +parser.add_argument( + "--scheduler", + type=str, + default="fifo", + help="FIFO or AsyncHyperBandScheduler.") +parser.add_argument( + "--seed", + type=int, + default=1, + metavar="S", + help="random seed (default: 1)") +parser.add_argument( + "--smoke_test", action="store_true", help="Finish quickly for testing") +args = parser.parse_args() + + +def train_cifar10(args, config, reporter): + vars(args).update(config) + np.random.seed(args.seed) + random.seed(args.seed) + mx.random.seed(args.seed) + + # Set Hyper-params + batch_size = args.batch_size * max(args.num_gpus, 1) + ctx = [mx.gpu(i) + for i in range(args.num_gpus)] if args.num_gpus > 0 else [mx.cpu()] + + # Define DataLoader + transform_train = transforms.Compose([ + gcv_transforms.RandomCrop(32, pad=4), + transforms.RandomFlipLeftRight(), + transforms.ToTensor(), + transforms.Normalize([0.4914, 0.4822, 0.4465], + [0.2023, 0.1994, 0.2010]) + ]) + + transform_test = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize([0.4914, 0.4822, 0.4465], + [0.2023, 0.1994, 0.2010]) + ]) + + train_data = gluon.data.DataLoader( + gluon.data.vision.CIFAR10(train=True).transform_first(transform_train), + batch_size=batch_size, + shuffle=True, + last_batch="discard", + num_workers=args.num_workers) + + test_data = gluon.data.DataLoader( + gluon.data.vision.CIFAR10(train=False).transform_first(transform_test), + batch_size=batch_size, + shuffle=False, + num_workers=args.num_workers) + + # Load model architecture and Initialize the net with pretrained model + finetune_net = get_model(args.model, pretrained=True) + with finetune_net.name_scope(): + finetune_net.fc = nn.Dense(args.classes) + finetune_net.fc.initialize(init.Xavier(), ctx=ctx) + finetune_net.collect_params().reset_ctx(ctx) + finetune_net.hybridize() + + # Define trainer + trainer = gluon.Trainer(finetune_net.collect_params(), "sgd", { + "learning_rate": args.lr, + "momentum": args.momentum, + "wd": args.wd + }) + L = gluon.loss.SoftmaxCrossEntropyLoss() + metric = mx.metric.Accuracy() + + def train(epoch): + for i, batch in enumerate(train_data): + data = gluon.utils.split_and_load( + batch[0], ctx_list=ctx, batch_axis=0, even_split=False) + label = gluon.utils.split_and_load( + batch[1], ctx_list=ctx, batch_axis=0, even_split=False) + with ag.record(): + outputs = [finetune_net(X) for X in data] + loss = [L(yhat, y) for yhat, y in zip(outputs, label)] + for l in loss: + l.backward() + + trainer.step(batch_size) + mx.nd.waitall() + + def test(): + test_loss = 0 + for i, batch in enumerate(test_data): + data = gluon.utils.split_and_load( + batch[0], ctx_list=ctx, batch_axis=0, even_split=False) + label = gluon.utils.split_and_load( + batch[1], ctx_list=ctx, batch_axis=0, even_split=False) + outputs = [finetune_net(X) for X in data] + loss = [L(yhat, y) for yhat, y in zip(outputs, label)] + + test_loss += sum(l.mean().asscalar() for l in loss) / len(loss) + metric.update(label, outputs) + + _, test_acc = metric.get() + test_loss /= len(test_data) + reporter(mean_loss=test_loss, mean_accuracy=test_acc) + + for epoch in range(1, args.epochs + 1): + train(epoch) + test() + + +if __name__ == "__main__": + args = parser.parse_args() + + import ray + from ray import tune + from ray.tune.schedulers import AsyncHyperBandScheduler, FIFOScheduler + + ray.init() + if args.scheduler == "fifo": + sched = FIFOScheduler() + elif args.scheduler == "asynchyperband": + sched = AsyncHyperBandScheduler( + time_attr="training_iteration", + reward_attr="neg_mean_loss", + max_t=400, + grace_period=60) + else: + raise NotImplementedError + tune.register_trainable( + "TRAIN_FN", + lambda config, reporter: train_cifar10(args, config, reporter)) + tune.run( + "TRAIN_FN", + name=args.expname, + verbose=2, + scheduler=sched, + **{ + "stop": { + "mean_accuracy": 0.98, + "training_iteration": 1 if args.smoke_test else args.epochs + }, + "resources_per_trial": { + "cpu": int(args.num_workers), + "gpu": int(args.num_gpus) + }, + "num_samples": 1 if args.smoke_test else args.num_samples, + "config": { + "lr": tune.sample_from( + lambda spec: np.power(10.0, np.random.uniform(-4, -1))), + "momentum": tune.sample_from( + lambda spec: np.random.uniform(0.85, 0.95)), + } + })