[tune] update params for optimizer in reset_config (#6522)

* reset config update lr

* add default

* Update pbt_dcgan_mnist.py

* Update pbt_convnet_example.py

Co-authored-by: Richard Liaw <rliaw@berkeley.edu>
Authored by Yuhao Yang on 2019-12-25 17:10:09 -08:00; committed by Richard Liaw
parent aa7b861332
commit 8707a721d9
2 changed files with 14 additions and 15 deletions
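
Both diffs make the same fix: instead of deleting the optimizer and constructing a new one on every reset_config call, the examples now write the perturbed hyperparameters onto the existing optimizer through its param_groups. A minimal sketch of the pattern outside of Tune, assuming a throwaway nn.Linear model and illustrative values:

    import torch.nn as nn
    import torch.optim as optim

    model = nn.Linear(4, 2)  # placeholder model, not from this commit
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

    # Mutate hyperparameters in place rather than rebuilding the optimizer.
    new_config = {"lr": 0.001, "momentum": 0.95}
    for param_group in optimizer.param_groups:
        if "lr" in new_config:
            param_group["lr"] = new_config["lr"]
        if "momentum" in new_config:
            param_group["momentum"] = new_config["momentum"]

    assert optimizer.param_groups[0]["lr"] == 0.001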

pbt_convnet_example.py

@@ -52,11 +52,13 @@ class PytorchTrainble(tune.Trainable):
         self.model.load_state_dict(torch.load(checkpoint_path))
 
     def reset_config(self, new_config):
-        del self.optimizer
-        self.optimizer = optim.SGD(
-            self.model.parameters(),
-            lr=new_config.get("lr", 0.01),
-            momentum=new_config.get("momentum", 0.9))
+        for param_group in self.optimizer.param_groups:
+            if "lr" in new_config:
+                param_group["lr"] = new_config["lr"]
+            if "momentum" in new_config:
+                param_group["momentum"] = new_config["momentum"]
         self.config = new_config
         return True
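
A side effect worth noting, and a likely motivation for the change: SGD's momentum buffers live in optimizer.state, keyed by parameter, so updating param_groups in place preserves them across a PBT perturbation, while del plus a fresh optim.SGD would reset them to zero. A small check of that behavior (toy model and values, not from this commit):

    import torch
    import torch.nn as nn
    import torch.optim as optim

    model = nn.Linear(2, 1)
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

    # One step populates the per-parameter momentum buffers.
    model(torch.randn(8, 2)).pow(2).mean().backward()
    optimizer.step()
    assert len(optimizer.state) > 0

    # In-place update keeps the buffers; a recreated optimizer starts empty.
    for param_group in optimizer.param_groups:
        param_group["lr"] = 0.001
    assert len(optimizer.state) > 0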

pbt_dcgan_mnist.py

@@ -275,16 +275,13 @@ class PytorchTrainable(tune.Trainable):
         self.optimizerG.load_state_dict(checkpoint["optimG"])
 
     def reset_config(self, new_config):
-        del self.optimizerD
-        del self.optimizerG
-        self.optimizerD = optim.Adam(
-            self.netD.parameters(),
-            lr=new_config.get("netD_lr"),
-            betas=(beta1, 0.999))
-        self.optimizerG = optim.Adam(
-            self.netG.parameters(),
-            lr=new_config.get("netG_lr"),
-            betas=(beta1, 0.999))
+        if "netD_lr" in new_config:
+            for param_group in self.optimizerD.param_groups:
+                param_group["lr"] = new_config["netD_lr"]
+        if "netG_lr" in new_config:
+            for param_group in self.optimizerG.param_groups:
+                param_group["lr"] = new_config["netG_lr"]
         self.config = new_config
         return True
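
reset_config returning True tells Tune the actor was reset successfully; the hook is only exercised when actor reuse is enabled, otherwise an exploited trial is restarted with a fresh Trainable. A hedged sketch of how the convnet example is typically wired up with PBT (scheduler values are illustrative, not taken from this commit; PytorchTrainble is the class from the first diff):

    import random

    from ray import tune
    from ray.tune.schedulers import PopulationBasedTraining

    pbt = PopulationBasedTraining(
        time_attr="training_iteration",
        metric="mean_accuracy",  # assumed metric name
        mode="max",
        perturbation_interval=5,
        hyperparam_mutations={
            "lr": lambda: random.uniform(0.0001, 0.02),
            "momentum": [0.8, 0.9, 0.99],
        })

    # reuse_actors=True is what routes a perturbed config through
    # reset_config() on the live actor instead of recreating the Trainable.
    tune.run(
        PytorchTrainble,
        scheduler=pbt,
        reuse_actors=True,
        num_samples=4)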