[RLlib] Issue 4965: Fixes PyTorch grad clipping logic and adds grad clipping to QMIX. (#25584)

Artur Niederfahrenhorst, 2022-06-08 19:40:57 +02:00, committed by GitHub
parent a9e7836e8c
commit 9226643433
3 changed files with 34 additions and 19 deletions
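At the heart of the fix is a PyTorch subtlety: torch.nn.utils.clip_grad_norm_ rescales gradients in place but returns the total norm computed *before* clipping, so logging its return value over-reports the gradient norm whenever clipping actually triggers (issue 4965). A minimal standalone sketch of that behavior:

import torch

# One parameter with a deliberately oversized gradient.
p = torch.nn.Parameter(torch.ones(4))
p.grad = torch.full((4,), 100.0)

# clip_grad_norm_ rescales the gradient in place, but its return value
# is the total norm measured *before* clipping.
returned_norm = torch.nn.utils.clip_grad_norm_([p], max_norm=1.0)

print(float(returned_norm))   # 200.0 -- the pre-clipping norm
print(float(p.grad.norm(2)))  # ~1.0  -- the gradient actually applied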

Changed file: qmix.py (QMixConfig)

@@ -22,6 +22,7 @@ from ray.rllib.utils.metrics import (
 from ray.rllib.utils.replay_buffers.utils import sample_min_n_steps_from_buffer
 from ray.rllib.utils.typing import ResultDict, TrainerConfigDict
 from ray.rllib.utils.deprecation import DEPRECATED_VALUE
+from ray.rllib.utils.deprecation import deprecation_warning


 class QMixConfig(SimpleQConfig):
@@ -70,7 +71,7 @@ class QMixConfig(SimpleQConfig):
         self.double_q = True
         self.optim_alpha = 0.99
         self.optim_eps = 0.00001
-        self.grad_norm_clipping = 10
+        self.grad_clip = 10

         # Override some of TrainerConfig's default values with QMix-specific values.
         # .training()
@@ -149,6 +150,7 @@ class QMixConfig(SimpleQConfig):
         optim_alpha: Optional[float] = None,
         optim_eps: Optional[float] = None,
         grad_norm_clipping: Optional[float] = None,
+        grad_clip: Optional[float] = None,
         **kwargs,
     ) -> "QMixConfig":
         """Sets the training related configuration.
@@ -162,8 +164,9 @@ class QMixConfig(SimpleQConfig):
             replay_buffer_config:
             optim_alpha: RMSProp alpha.
             optim_eps: RMSProp epsilon.
-            grad_norm_clipping: If not None, clip gradients during optimization at
+            grad_clip: If not None, clip gradients during optimization at
                 this value.
+            grad_norm_clipping: Deprecated in favor of `grad_clip`.

         Returns:
             This updated TrainerConfig object.
@@ -171,6 +174,19 @@ class QMixConfig(SimpleQConfig):
         # Pass kwargs onto super's `training()` method.
         super().training(**kwargs)

+        if grad_norm_clipping is not None:
+            deprecation_warning(
+                old="grad_norm_clipping",
+                new="grad_clip",
+                help="Parameter `grad_norm_clipping` has been "
+                "deprecated in favor of grad_clip in QMix. "
+                "This is now the same parameter as in other "
+                "algorithms. `grad_clip` will be overwritten by "
+                "`grad_norm_clipping={}`".format(grad_norm_clipping),
+                error=False,
+            )
+            grad_clip = grad_norm_clipping
+
         if mixer is not None:
             self.mixer = mixer
         if mixing_embed_dim is not None:
@@ -185,8 +201,8 @@ class QMixConfig(SimpleQConfig):
             self.optim_alpha = optim_alpha
         if optim_eps is not None:
             self.optim_eps = optim_eps
-        if grad_norm_clipping is not None:
-            self.grad_norm_clipping = grad_norm_clipping
+        if grad_clip is not None:
+            self.grad_clip = grad_clip
         return self
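A hedged usage sketch of the resulting config surface; the import path is an assumption based on RLlib's package layout around this release and may differ in your version:

# Assumed import path; adjust for your RLlib version.
from ray.rllib.agents.qmix import QMixConfig

config = QMixConfig()

# Old kwarg: still accepted, but it emits a deprecation warning and is
# mapped onto the new `grad_clip` attribute.
config.training(grad_norm_clipping=10.0)
assert config.grad_clip == 10.0

# New, preferred kwarg, consistent with other RLlib algorithms.
config.training(grad_clip=20.0)
assert config.grad_clip == 20.0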

Changed file: qmix_policy.py (QMixTorchPolicy)

@@ -19,6 +19,7 @@ from ray.rllib.utils.annotations import override
 from ray.rllib.utils.framework import try_import_torch
 from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY
 from ray.rllib.utils.typing import TensorType
+from ray.rllib.utils.torch_utils import apply_grad_clipping

 # Torch must be installed.
 torch, nn = try_import_torch(error=True)
@@ -266,7 +267,7 @@ class QMixTorchPolicy(TorchPolicy):
         )
         from torch.optim import RMSprop

-        self.optimiser = RMSprop(
+        self.rmsprop_optimizer = RMSprop(
             params=self.params,
             lr=config["lr"],
             alpha=config["optim_alpha"],
@@ -444,23 +445,18 @@ class QMixTorchPolicy(TorchPolicy):
         )

         # Optimise
-        self.optimiser.zero_grad()
+        self.rmsprop_optimizer.zero_grad()
         loss_out.backward()
-        grad_norm = torch.nn.utils.clip_grad_norm_(
-            self.params, self.config["grad_norm_clipping"]
-        )
-        self.optimiser.step()
+        self.rmsprop_optimizer.step()

         mask_elems = mask.sum().item()
         stats = {
             "loss": loss_out.item(),
-            "grad_norm": grad_norm
-            if isinstance(grad_norm, float)
-            else grad_norm.item(),
             "td_error_abs": masked_td_error.abs().sum().item() / mask_elems,
             "q_taken_mean": (chosen_action_qvals * mask).sum().item() / mask_elems,
             "target_mean": (targets * mask).sum().item() / mask_elems,
         }
+        stats.update(apply_grad_clipping(self, self.rmsprop_optimizer, loss_out))
         return {LEARNER_STATS_KEY: stats}

     @override(TorchPolicy)
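For orientation, a self-contained sketch of the update pattern the policy now follows, with the clipping-and-reporting logic of apply_grad_clipping inlined; model, optimizer, and grad_clip are stand-ins here, not RLlib code:

import torch

model = torch.nn.Linear(4, 2)  # stand-in for the QMIX networks
optimizer = torch.optim.RMSprop(
    model.parameters(), lr=5e-4, alpha=0.99, eps=1e-5
)
grad_clip = 10.0  # what policy.config["grad_clip"] would hold

loss = model(torch.randn(8, 4)).pow(2).mean()
optimizer.zero_grad()
loss.backward()

stats = {"loss": loss.item()}
# What apply_grad_clipping(policy, optimizer, loss) does per param group:
for group in optimizer.param_groups:
    params = [p for p in group["params"] if p.grad is not None]
    if params:
        gnorm = torch.nn.utils.clip_grad_norm_(params, grad_clip)
        stats["grad_gnorm"] = min(float(gnorm), grad_clip)
optimizer.step()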

Changed file: torch_utils.py (apply_grad_clipping)

@@ -51,12 +51,15 @@ def apply_grad_clipping(
             # clip_grad_norm_. Would fail otherwise.
             params = list(filter(lambda p: p.grad is not None, param_group["params"]))
             if params:
-                grad_gnorm = nn.utils.clip_grad_norm_(
-                    params, policy.config["grad_clip"]
-                )
-                if isinstance(grad_gnorm, torch.Tensor):
-                    grad_gnorm = grad_gnorm.cpu().numpy()
-                info["grad_gnorm"] = grad_gnorm
+                # PyTorch clips gradients in place and returns the norm before
+                # clipping. We therefore need to compute grad_gnorm further down
+                # (fixes #4965).
+                clip_value = policy.config["grad_clip"]
+                global_norm = nn.utils.clip_grad_norm_(params, clip_value)
+                if isinstance(global_norm, torch.Tensor):
+                    global_norm = global_norm.cpu().numpy()
+                info["grad_gnorm"] = min(global_norm, clip_value)
     return info
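A quick way to exercise the fixed helper directly; FakePolicy is a hypothetical stand-in that provides only the config dict the helper reads:

import torch
from ray.rllib.utils.torch_utils import apply_grad_clipping

class FakePolicy:
    # Minimal stand-in: apply_grad_clipping only reads policy.config.
    config = {"grad_clip": 5.0}

model = torch.nn.Linear(3, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss = (100.0 * model(torch.ones(1, 3))).pow(2).mean()
loss.backward()

info = apply_grad_clipping(FakePolicy(), optimizer, loss)
# "grad_gnorm" is now capped at grad_clip rather than reporting the
# raw pre-clipping norm.
assert info["grad_gnorm"] <= 5.0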