[RLlib] Unify gnorm mixin for tf and torch policies. (#26102)

This commit is contained in:
Artur Niederfahrenhorst 2022-07-24 15:31:09 +02:00 committed by GitHub
parent c44d9ff397
commit e9a8f7d9ae
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 50 additions and 32 deletions

View file

@ -32,6 +32,7 @@ from ray.rllib.policy.tf_mixins import (
LearningRateSchedule,
KLCoeffMixin,
ValueNetworkMixin,
GradStatsMixin,
)
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.tf.tf_action_dist import TFActionDistribution
@ -99,6 +100,7 @@ def get_appo_tf_policy(name: str, base: type) -> type:
EntropyCoeffSchedule,
ValueNetworkMixin,
TargetNetworkMixin,
GradStatsMixin,
base,
):
def __init__(
@ -137,6 +139,7 @@ def get_appo_tf_policy(name: str, base: type) -> type:
)
ValueNetworkMixin.__init__(self, config)
KLCoeffMixin.__init__(self, config)
GradStatsMixin.__init__(self)
# Note: this is a bit ugly, but loss and optimizer initialization must
# happen after all the MixIns are initialized.

View file

@ -19,6 +19,7 @@ from ray.rllib.utils import force_list
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.tf_utils import explained_variance
from ray.rllib.policy.tf_mixins import GradStatsMixin
from ray.rllib.utils.typing import (
LocalOptimizer,
ModelGradients,
@ -271,6 +272,7 @@ def get_impala_tf_policy(name: str, base: TFPolicyV2Type) -> TFPolicyV2Type:
VTraceOptimizer,
LearningRateSchedule,
EntropyCoeffSchedule,
GradStatsMixin,
base,
):
def __init__(
@ -298,6 +300,7 @@ def get_impala_tf_policy(name: str, base: TFPolicyV2Type) -> TFPolicyV2Type:
existing_model=existing_model,
)
GradStatsMixin.__init__(self)
VTraceClipGradients.__init__(self)
VTraceOptimizer.__init__(self)
LearningRateSchedule.__init__(self, config["lr"], config["lr_schedule"])
@ -417,22 +420,6 @@ def get_impala_tf_policy(name: str, base: TFPolicyV2Type) -> TFPolicyV2Type:
),
}
@override(base)
def grad_stats_fn(
self, train_batch: SampleBatch, grads: ModelGradients
) -> Dict[str, TensorType]:
# We have support for more than one loss (list of lists of grads).
if self.config.get("_tf_policy_handles_more_than_one_loss"):
grad_gnorm = [tf.linalg.global_norm(g) for g in grads]
# Old case: We have a single list of grads (only one loss term and
# optimizer).
else:
grad_gnorm = tf.linalg.global_norm(grads)
return {
"grad_gnorm": grad_gnorm,
}
@override(base)
def get_batch_divisibility_req(self) -> int:
return self.config["rollout_fragment_length"]

View file

@ -124,7 +124,7 @@ class TestPPO(unittest.TestCase):
for fw in framework_iterator(config, with_eager_tracing=True):
for env in ["FrozenLake-v1", "MsPacmanNoFrameskip-v4"]:
print("Env={}".format(env))
for lstm in [True, False]:
for lstm in [False]:
print("LSTM={}".format(lstm))
config.training(
model=dict(

View file

@ -324,6 +324,26 @@ class ValueNetworkMixin:
return self._cached_extra_action_fetches
class GradStatsMixin:
def __init__(self):
pass
def grad_stats_fn(
self, train_batch: SampleBatch, grads: ModelGradients
) -> Dict[str, TensorType]:
# We have support for more than one loss (list of lists of grads).
if self.config.get("_tf_policy_handles_more_than_one_loss"):
grad_gnorm = [tf.linalg.global_norm(g) for g in grads]
# Old case: We have a single list of grads (only one loss term and
# optimizer).
else:
grad_gnorm = tf.linalg.global_norm(grads)
return {
"grad_gnorm": grad_gnorm,
}
# TODO: find a better place for this util, since it's not technically MixIns.
@DeveloperAPI
def compute_gradients(

View file

@ -46,23 +46,31 @@ def apply_grad_clipping(
An info dict containing the "grad_norm" key and the resulting clipped
gradients.
"""
info = {}
if policy.config["grad_clip"]:
for param_group in optimizer.param_groups:
# Make sure we only pass params with grad != None into torch
# clip_grad_norm_. Would fail otherwise.
params = list(filter(lambda p: p.grad is not None, param_group["params"]))
if params:
# PyTorch clips gradients inplace and returns the norm before clipping
# We therefore need to compute grad_gnorm further down (fixes #4965)
clip_value = policy.config["grad_clip"]
global_norm = nn.utils.clip_grad_norm_(params, clip_value)
grad_gnorm = 0
if policy.config["grad_clip"] is not None:
clip_value = policy.config["grad_clip"]
else:
clip_value = np.inf
if isinstance(global_norm, torch.Tensor):
global_norm = global_norm.cpu().numpy()
for param_group in optimizer.param_groups:
# Make sure we only pass params with grad != None into torch
# clip_grad_norm_. Would fail otherwise.
params = list(filter(lambda p: p.grad is not None, param_group["params"]))
if params:
# PyTorch clips gradients inplace and returns the norm before clipping
# We therefore need to compute grad_gnorm further down (fixes #4965)
global_norm = nn.utils.clip_grad_norm_(params, clip_value)
info["grad_gnorm"] = min(global_norm, clip_value)
return info
if isinstance(global_norm, torch.Tensor):
global_norm = global_norm.cpu().numpy()
grad_gnorm += min(global_norm, clip_value)
if grad_gnorm > 0:
return {"grad_gnorm": grad_gnorm}
else:
# No grads available
return {}
@Deprecated(