mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 13:54:27 +08:00
[tune/sgd] DCGAN example self-contained, turn example into modu… (#8012)
* ok * done * run_benchmarks * should_make_examples_usable
This commit is contained in:
@@ -49,6 +49,8 @@ beta1 = 0.5
|
||||
# iterations of actual training in each Trainable _train
|
||||
train_iterations_per_step = 5
|
||||
|
||||
MODEL_PATH = os.path.expanduser("~/.ray/models/mnist_cnn.pt")
|
||||
|
||||
|
||||
def get_data_loader():
|
||||
dataset = dset.MNIST(
|
||||
@@ -305,6 +307,16 @@ if __name__ == "__main__":
|
||||
args, _ = parser.parse_known_args()
|
||||
ray.init()
|
||||
|
||||
import urllib.request
|
||||
# Download a pre-trained MNIST model for inception score calculation.
|
||||
# This is a tiny model (<100kb).
|
||||
if not os.path.exists(MODEL_PATH):
|
||||
print("downloading model")
|
||||
os.makedirs(os.path.dirname(MODEL_PATH), exist_ok=True)
|
||||
urllib.request.urlretrieve(
|
||||
"https://github.com/ray-project/ray/raw/master/python/ray/tune/"
|
||||
"examples/pbt_dcgan_mnist/mnist_cnn.pt", MODEL_PATH)
|
||||
|
||||
dataloader = get_data_loader()
|
||||
if not args.smoke_test:
|
||||
# Plot some training images
|
||||
@@ -322,10 +334,7 @@ if __name__ == "__main__":
|
||||
|
||||
# load the pretrained mnist classification model for inception_score
|
||||
mnist_cnn = Net()
|
||||
model_path = os.path.join(
|
||||
os.path.dirname(ray.__file__),
|
||||
"tune/examples/pbt_dcgan_mnist/mnist_cnn.pt")
|
||||
mnist_cnn.load_state_dict(torch.load(model_path))
|
||||
mnist_cnn.load_state_dict(torch.load(MODEL_PATH))
|
||||
mnist_cnn.eval()
|
||||
mnist_model_ref = ray.put(mnist_cnn)
|
||||
|
||||
|
||||
@@ -16,6 +16,8 @@ from ray.util.sgd.torch import TrainingOperator
|
||||
parser = argparse.ArgumentParser(
|
||||
description="PyTorch Synthetic Benchmark",
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||
parser.add_argument(
|
||||
"--smoke-test", action="store_true", default=False, help="finish quickly.")
|
||||
parser.add_argument(
|
||||
"--fp16", action="store_true", default=False, help="use fp16 training")
|
||||
|
||||
@@ -49,6 +51,16 @@ parser.add_argument(
|
||||
help="Disables cluster training")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.smoke_test:
|
||||
args.model = "resnet18"
|
||||
args.batch_size = 1
|
||||
args.num_iters = 1
|
||||
args.num_batches_per_iter = 2
|
||||
args.num_warmup_batches = 2
|
||||
args.local = True
|
||||
args.no_cuda = True
|
||||
|
||||
args.cuda = not args.no_cuda and torch.cuda.is_available()
|
||||
device = "GPU" if args.cuda else "CPU"
|
||||
|
||||
@@ -68,7 +80,6 @@ class Training(TrainingOperator):
|
||||
self.data, self.target = data, target
|
||||
|
||||
def train_epoch(self, *pargs, **kwargs):
|
||||
# print(self.model)
|
||||
def benchmark():
|
||||
self.optimizer.zero_grad()
|
||||
output = self.model(self.data)
|
||||
@@ -76,11 +87,11 @@ class Training(TrainingOperator):
|
||||
loss.backward()
|
||||
self.optimizer.step()
|
||||
|
||||
# print("Running warmup...")
|
||||
print("Running warmup...")
|
||||
if self.global_step == 0:
|
||||
timeit.timeit(benchmark, number=args.num_warmup_batches)
|
||||
self.global_step += 1
|
||||
# print("Running benchmark...")
|
||||
print("Running benchmark...")
|
||||
time = timeit.timeit(benchmark, number=args.num_batches_per_iter)
|
||||
img_sec = args.batch_size * args.num_batches_per_iter / time
|
||||
return {"img_sec": img_sec}
|
||||
@@ -99,7 +110,7 @@ if __name__ == "__main__":
|
||||
model_creator=lambda cfg: getattr(models, args.model)(),
|
||||
optimizer_creator=lambda model, cfg: optim.SGD(
|
||||
model.parameters(), lr=0.01 * cfg.get("lr_scaler")),
|
||||
data_creator=lambda cfg: LinearDataset(4, 2),
|
||||
data_creator=lambda cfg: LinearDataset(4, 2), # Mock dataset.
|
||||
initialization_hook=init_hook,
|
||||
config=dict(
|
||||
lr_scaler=num_workers),
|
||||
|
||||
@@ -21,6 +21,8 @@ from ray.util.sgd import TorchTrainer
|
||||
from ray.util.sgd.utils import override
|
||||
from ray.util.sgd.torch import TrainingOperator
|
||||
|
||||
MODEL_PATH = os.path.expanduser("~/.ray/models/mnist_cnn.pt")
|
||||
|
||||
|
||||
def data_creator(config):
|
||||
dataset = datasets.MNIST(
|
||||
@@ -227,9 +229,7 @@ def train_example(num_workers=1, use_gpu=False, test_mode=False):
|
||||
config = {
|
||||
"test_mode": test_mode,
|
||||
"batch_size": 16 if test_mode else 512 // num_workers,
|
||||
"classification_model_path": os.path.join(
|
||||
os.path.dirname(ray.__file__),
|
||||
"util/sgd/torch/examples/mnist_cnn.pt")
|
||||
"classification_model_path": MODEL_PATH
|
||||
}
|
||||
trainer = TorchTrainer(
|
||||
model_creator=model_creator,
|
||||
@@ -256,6 +256,16 @@ def train_example(num_workers=1, use_gpu=False, test_mode=False):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import urllib.request
|
||||
# Download a pre-trained MNIST model for inception score calculation.
|
||||
# This is a tiny model (<100kb).
|
||||
if not os.path.exists(MODEL_PATH):
|
||||
print("downloading model")
|
||||
os.makedirs(os.path.dirname(MODEL_PATH), exist_ok=True)
|
||||
urllib.request.urlretrieve(
|
||||
"https://github.com/ray-project/ray/raw/master/python/ray/tune/"
|
||||
"examples/pbt_dcgan_mnist/mnist_cnn.pt", MODEL_PATH)
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--smoke-test", action="store_true", help="Finish quickly for testing")
|
||||
|
||||
Binary file not shown.
Reference in New Issue
Block a user