mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
[RLlib] Issue 4965: Fixes PyTorch grad clipping logic and adds grad clipping to QMIX. (#25584)
This commit is contained in:
parent
a9e7836e8c
commit
9226643433
3 changed files with 34 additions and 19 deletions
|
@ -22,6 +22,7 @@ from ray.rllib.utils.metrics import (
|
|||
from ray.rllib.utils.replay_buffers.utils import sample_min_n_steps_from_buffer
|
||||
from ray.rllib.utils.typing import ResultDict, TrainerConfigDict
|
||||
from ray.rllib.utils.deprecation import DEPRECATED_VALUE
|
||||
from ray.rllib.utils.deprecation import deprecation_warning
|
||||
|
||||
|
||||
class QMixConfig(SimpleQConfig):
|
||||
|
@ -70,7 +71,7 @@ class QMixConfig(SimpleQConfig):
|
|||
self.double_q = True
|
||||
self.optim_alpha = 0.99
|
||||
self.optim_eps = 0.00001
|
||||
self.grad_norm_clipping = 10
|
||||
self.grad_clip = 10
|
||||
|
||||
# Override some of TrainerConfig's default values with QMix-specific values.
|
||||
# .training()
|
||||
|
@ -149,6 +150,7 @@ class QMixConfig(SimpleQConfig):
|
|||
optim_alpha: Optional[float] = None,
|
||||
optim_eps: Optional[float] = None,
|
||||
grad_norm_clipping: Optional[float] = None,
|
||||
grad_clip: Optional[float] = None,
|
||||
**kwargs,
|
||||
) -> "QMixConfig":
|
||||
"""Sets the training related configuration.
|
||||
|
@ -162,8 +164,9 @@ class QMixConfig(SimpleQConfig):
|
|||
replay_buffer_config:
|
||||
optim_alpha: RMSProp alpha.
|
||||
optim_eps: RMSProp epsilon.
|
||||
grad_norm_clipping: If not None, clip gradients during optimization at
|
||||
grad_clip: If not None, clip gradients during optimization at
|
||||
this value.
|
||||
grad_norm_clipping: Deprecated in favor of grad_clip
|
||||
|
||||
Returns:
|
||||
This updated TrainerConfig object.
|
||||
|
@ -171,6 +174,19 @@ class QMixConfig(SimpleQConfig):
|
|||
# Pass kwargs onto super's `training()` method.
|
||||
super().training(**kwargs)
|
||||
|
||||
if grad_norm_clipping is not None:
|
||||
deprecation_warning(
|
||||
old="grad_norm_clipping",
|
||||
new="grad_clip",
|
||||
help="Parameter `grad_norm_clipping` has been "
|
||||
"deprecated in favor of grad_clip in QMix. "
|
||||
"This is now the same parameter as in other "
|
||||
"algorithms. `grad_clip` will be overwritten by "
|
||||
"`grad_norm_clipping={}`".format(grad_norm_clipping),
|
||||
error=False,
|
||||
)
|
||||
grad_clip = grad_norm_clipping
|
||||
|
||||
if mixer is not None:
|
||||
self.mixer = mixer
|
||||
if mixing_embed_dim is not None:
|
||||
|
@ -185,8 +201,8 @@ class QMixConfig(SimpleQConfig):
|
|||
self.optim_alpha = optim_alpha
|
||||
if optim_eps is not None:
|
||||
self.optim_eps = optim_eps
|
||||
if grad_norm_clipping is not None:
|
||||
self.grad_norm_clipping = grad_norm_clipping
|
||||
if grad_clip is not None:
|
||||
self.grad_clip = grad_clip
|
||||
|
||||
return self
|
||||
|
||||
|
|
|
@ -19,6 +19,7 @@ from ray.rllib.utils.annotations import override
|
|||
from ray.rllib.utils.framework import try_import_torch
|
||||
from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY
|
||||
from ray.rllib.utils.typing import TensorType
|
||||
from ray.rllib.utils.torch_utils import apply_grad_clipping
|
||||
|
||||
# Torch must be installed.
|
||||
torch, nn = try_import_torch(error=True)
|
||||
|
@ -266,7 +267,7 @@ class QMixTorchPolicy(TorchPolicy):
|
|||
)
|
||||
from torch.optim import RMSprop
|
||||
|
||||
self.optimiser = RMSprop(
|
||||
self.rmsprop_optimizer = RMSprop(
|
||||
params=self.params,
|
||||
lr=config["lr"],
|
||||
alpha=config["optim_alpha"],
|
||||
|
@ -444,23 +445,18 @@ class QMixTorchPolicy(TorchPolicy):
|
|||
)
|
||||
|
||||
# Optimise
|
||||
self.optimiser.zero_grad()
|
||||
self.rmsprop_optimizer.zero_grad()
|
||||
loss_out.backward()
|
||||
grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
self.params, self.config["grad_norm_clipping"]
|
||||
)
|
||||
self.optimiser.step()
|
||||
self.rmsprop_optimizer.step()
|
||||
|
||||
mask_elems = mask.sum().item()
|
||||
stats = {
|
||||
"loss": loss_out.item(),
|
||||
"grad_norm": grad_norm
|
||||
if isinstance(grad_norm, float)
|
||||
else grad_norm.item(),
|
||||
"td_error_abs": masked_td_error.abs().sum().item() / mask_elems,
|
||||
"q_taken_mean": (chosen_action_qvals * mask).sum().item() / mask_elems,
|
||||
"target_mean": (targets * mask).sum().item() / mask_elems,
|
||||
}
|
||||
stats.update(apply_grad_clipping(self, self.rmsprop_optimizer, loss_out))
|
||||
return {LEARNER_STATS_KEY: stats}
|
||||
|
||||
@override(TorchPolicy)
|
||||
|
|
|
@ -51,12 +51,15 @@ def apply_grad_clipping(
|
|||
# clip_grad_norm_. Would fail otherwise.
|
||||
params = list(filter(lambda p: p.grad is not None, param_group["params"]))
|
||||
if params:
|
||||
grad_gnorm = nn.utils.clip_grad_norm_(
|
||||
params, policy.config["grad_clip"]
|
||||
)
|
||||
if isinstance(grad_gnorm, torch.Tensor):
|
||||
grad_gnorm = grad_gnorm.cpu().numpy()
|
||||
info["grad_gnorm"] = grad_gnorm
|
||||
# PyTorch clips gradients inplace and returns the norm before clipping
|
||||
# We therefore need to compute grad_gnorm further down (fixes #4965)
|
||||
clip_value = policy.config["grad_clip"]
|
||||
global_norm = nn.utils.clip_grad_norm_(params, clip_value)
|
||||
|
||||
if isinstance(global_norm, torch.Tensor):
|
||||
global_norm = global_norm.cpu().numpy()
|
||||
|
||||
info["grad_gnorm"] = min(global_norm, clip_value)
|
||||
return info
|
||||
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue