Mirror of https://github.com/vale981/ray, synced 2025-03-06 10:31:39 -05:00
[tune] update params for optimizer in reset_config (#6522)
* reset config update lr
* add default
* Update pbt_dcgan_mnist.py
* Update pbt_convnet_example.py

Co-authored-by: Richard Liaw <rliaw@berkeley.edu>
parent aa7b861332
commit 8707a721d9
2 changed files with 14 additions and 15 deletions
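For context, a minimal sketch (not part of this commit) of where reset_config sits in a Tune Trainable. The underscore-prefixed method names follow the Trainable API of roughly this Ray version, and build_model / train_one_epoch are hypothetical helpers. Tune only calls reset_config when it reuses a running actor for a new or mutated config; returning True signals that the in-place reset succeeded, so the actor does not have to be torn down and restarted.

from ray import tune
import torch.optim as optim

class MyTrainable(tune.Trainable):
    def _setup(self, config):
        self.model = build_model()  # hypothetical helper
        self.optimizer = optim.SGD(
            self.model.parameters(),
            lr=config.get("lr", 0.01),
            momentum=config.get("momentum", 0.9))

    def _train(self):
        # hypothetical helper; returns a result dict for Tune
        return {"mean_accuracy": train_one_epoch(self.model, self.optimizer)}

    def reset_config(self, new_config):
        # The pattern this commit adopts: update the existing optimizer's
        # param_groups instead of deleting and recreating the optimizer.
        for param_group in self.optimizer.param_groups:
            if "lr" in new_config:
                param_group["lr"] = new_config["lr"]
            if "momentum" in new_config:
                param_group["momentum"] = new_config["momentum"]
        self.config = new_config
        return True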
pbt_convnet_example.py
@@ -52,11 +52,13 @@ class PytorchTrainble(tune.Trainable):
         self.model.load_state_dict(torch.load(checkpoint_path))

     def reset_config(self, new_config):
-        del self.optimizer
-        self.optimizer = optim.SGD(
-            self.model.parameters(),
-            lr=new_config.get("lr", 0.01),
-            momentum=new_config.get("momentum", 0.9))
+        for param_group in self.optimizer.param_groups:
+            if "lr" in new_config:
+                param_group["lr"] = new_config["lr"]
+            if "momentum" in new_config:
+                param_group["momentum"] = new_config["momentum"]
+
         self.config = new_config
         return True

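Why update param_groups rather than rebuild the optimizer: the optimizer's accumulated state (here, SGD momentum buffers) lives in optimizer.state and survives an in-place hyperparameter change, whereas recreating the optimizer discards it. A small self-contained sketch, assuming plain PyTorch with illustrative shapes:

import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(4, 2)
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

# One training step so the optimizer accumulates momentum buffers.
loss = model(torch.randn(8, 4)).sum()
loss.backward()
optimizer.step()

# The pattern the hunk above switches to: mutate the existing param_groups.
new_config = {"lr": 0.001, "momentum": 0.95}
for param_group in optimizer.param_groups:
    if "lr" in new_config:
        param_group["lr"] = new_config["lr"]
    if "momentum" in new_config:
        param_group["momentum"] = new_config["momentum"]

# Hyperparameters changed in place; the momentum buffers were not discarded.
assert optimizer.param_groups[0]["lr"] == 0.001
assert all("momentum_buffer" in optimizer.state[p] for p in model.parameters())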
pbt_dcgan_mnist.py
@@ -275,16 +275,13 @@ class PytorchTrainable(tune.Trainable):
         self.optimizerG.load_state_dict(checkpoint["optimG"])

     def reset_config(self, new_config):
-        del self.optimizerD
-        del self.optimizerG
-        self.optimizerD = optim.Adam(
-            self.netD.parameters(),
-            lr=new_config.get("netD_lr"),
-            betas=(beta1, 0.999))
-        self.optimizerG = optim.Adam(
-            self.netG.parameters(),
-            lr=new_config.get("netG_lr"),
-            betas=(beta1, 0.999))
+        if "netD_lr" in new_config:
+            for param_group in self.optimizerD.param_groups:
+                param_group["lr"] = new_config["netD_lr"]
+        if "netG_lr" in new_config:
+            for param_group in self.optimizerG.param_groups:
+                param_group["lr"] = new_config["netG_lr"]
+
         self.config = new_config
         return True

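Both reset_config implementations above only matter when Tune reuses actors, for example under population based training. A hedged sketch of how they would be exercised (argument names reflect roughly the Tune API of this era, and the config keys follow the convnet example; details may differ between releases):

import random
from ray import tune
from ray.tune.schedulers import PopulationBasedTraining

pbt = PopulationBasedTraining(
    time_attr="training_iteration",
    metric="mean_accuracy",
    mode="max",
    perturbation_interval=5,
    hyperparam_mutations={
        # When a trial is perturbed, the mutated values reach the running
        # actor through reset_config instead of a full restart.
        "lr": lambda: random.uniform(0.0001, 0.02),
        "momentum": [0.8, 0.9, 0.99],
    })

tune.run(
    PytorchTrainble,      # the convnet Trainable from the first hunk
    scheduler=pbt,
    reuse_actors=True,    # reset_config is only invoked when actors are reused
    config={"lr": 0.01, "momentum": 0.9},
    num_samples=4)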