diff --git a/python/ray/tune/examples/pbt_convnet_example.py b/python/ray/tune/examples/pbt_convnet_example.py index cfe62c120..489aa3504 100644 --- a/python/ray/tune/examples/pbt_convnet_example.py +++ b/python/ray/tune/examples/pbt_convnet_example.py @@ -52,11 +52,13 @@ class PytorchTrainble(tune.Trainable): self.model.load_state_dict(torch.load(checkpoint_path)) def reset_config(self, new_config): - del self.optimizer - self.optimizer = optim.SGD( - self.model.parameters(), - lr=new_config.get("lr", 0.01), - momentum=new_config.get("momentum", 0.9)) + for param_group in self.optimizer.param_groups: + if "lr" in new_config: + param_group["lr"] = new_config["lr"] + if "momentum" in new_config: + param_group["momentum"] = new_config["momentum"] + + self.config = new_config return True diff --git a/python/ray/tune/examples/pbt_dcgan_mnist/pbt_dcgan_mnist.py b/python/ray/tune/examples/pbt_dcgan_mnist/pbt_dcgan_mnist.py index 83862a2ab..d4d7948f2 100644 --- a/python/ray/tune/examples/pbt_dcgan_mnist/pbt_dcgan_mnist.py +++ b/python/ray/tune/examples/pbt_dcgan_mnist/pbt_dcgan_mnist.py @@ -275,16 +275,13 @@ class PytorchTrainable(tune.Trainable): self.optimizerG.load_state_dict(checkpoint["optimG"]) def reset_config(self, new_config): - del self.optimizerD - del self.optimizerG - self.optimizerD = optim.Adam( - self.netD.parameters(), - lr=new_config.get("netD_lr"), - betas=(beta1, 0.999)) - self.optimizerG = optim.Adam( - self.netG.parameters(), - lr=new_config.get("netG_lr"), - betas=(beta1, 0.999)) + if "netD_lr" in new_config: + for param_group in self.optimizerD.param_groups: + param_group["lr"] = new_config["netD_lr"] + if "netG_lr" in new_config: + for param_group in self.optimizerG.param_groups: + param_group["lr"] = new_config["netG_lr"] + self.config = new_config return True