From 8707a721d97cd4bf337eea213429bf7c61d1e376 Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Wed, 25 Dec 2019 17:10:09 -0800 Subject: [PATCH] [tune] update params for optimizer in reset_config (#6522) * reset config update lr * add default * Update pbt_dcgan_mnist.py * Update pbt_convnet_example.py Co-authored-by: Richard Liaw --- python/ray/tune/examples/pbt_convnet_example.py | 12 +++++++----- .../examples/pbt_dcgan_mnist/pbt_dcgan_mnist.py | 17 +++++++---------- 2 files changed, 14 insertions(+), 15 deletions(-) diff --git a/python/ray/tune/examples/pbt_convnet_example.py b/python/ray/tune/examples/pbt_convnet_example.py index cfe62c120..489aa3504 100644 --- a/python/ray/tune/examples/pbt_convnet_example.py +++ b/python/ray/tune/examples/pbt_convnet_example.py @@ -52,11 +52,13 @@ class PytorchTrainble(tune.Trainable): self.model.load_state_dict(torch.load(checkpoint_path)) def reset_config(self, new_config): - del self.optimizer - self.optimizer = optim.SGD( - self.model.parameters(), - lr=new_config.get("lr", 0.01), - momentum=new_config.get("momentum", 0.9)) + for param_group in self.optimizer.param_groups: + if "lr" in new_config: + param_group["lr"] = new_config["lr"] + if "momentum" in new_config: + param_group["momentum"] = new_config["momentum"] + + self.config = new_config return True diff --git a/python/ray/tune/examples/pbt_dcgan_mnist/pbt_dcgan_mnist.py b/python/ray/tune/examples/pbt_dcgan_mnist/pbt_dcgan_mnist.py index 83862a2ab..d4d7948f2 100644 --- a/python/ray/tune/examples/pbt_dcgan_mnist/pbt_dcgan_mnist.py +++ b/python/ray/tune/examples/pbt_dcgan_mnist/pbt_dcgan_mnist.py @@ -275,16 +275,13 @@ class PytorchTrainable(tune.Trainable): self.optimizerG.load_state_dict(checkpoint["optimG"]) def reset_config(self, new_config): - del self.optimizerD - del self.optimizerG - self.optimizerD = optim.Adam( - self.netD.parameters(), - lr=new_config.get("netD_lr"), - betas=(beta1, 0.999)) - self.optimizerG = optim.Adam( - self.netG.parameters(), - lr=new_config.get("netG_lr"), - betas=(beta1, 0.999)) + if "netD_lr" in new_config: + for param_group in self.optimizerD.param_groups: + param_group["lr"] = new_config["netD_lr"] + if "netG_lr" in new_config: + for param_group in self.optimizerG.param_groups: + param_group["lr"] = new_config["netG_lr"] + self.config = new_config return True