mirror of
https://github.com/vale981/ray
synced 2025-03-04 17:41:43 -05:00
[RLlib] Unify gnorm mixin for tf and torch policies. (#26102)
This commit is contained in:
parent
c44d9ff397
commit
e9a8f7d9ae
5 changed files with 50 additions and 32 deletions
|
@ -32,6 +32,7 @@ from ray.rllib.policy.tf_mixins import (
|
|||
LearningRateSchedule,
|
||||
KLCoeffMixin,
|
||||
ValueNetworkMixin,
|
||||
GradStatsMixin,
|
||||
)
|
||||
from ray.rllib.models.modelv2 import ModelV2
|
||||
from ray.rllib.models.tf.tf_action_dist import TFActionDistribution
|
||||
|
@ -99,6 +100,7 @@ def get_appo_tf_policy(name: str, base: type) -> type:
|
|||
EntropyCoeffSchedule,
|
||||
ValueNetworkMixin,
|
||||
TargetNetworkMixin,
|
||||
GradStatsMixin,
|
||||
base,
|
||||
):
|
||||
def __init__(
|
||||
|
@ -137,6 +139,7 @@ def get_appo_tf_policy(name: str, base: type) -> type:
|
|||
)
|
||||
ValueNetworkMixin.__init__(self, config)
|
||||
KLCoeffMixin.__init__(self, config)
|
||||
GradStatsMixin.__init__(self)
|
||||
|
||||
# Note: this is a bit ugly, but loss and optimizer initialization must
|
||||
# happen after all the MixIns are initialized.
|
||||
|
|
|
@ -19,6 +19,7 @@ from ray.rllib.utils import force_list
|
|||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.framework import try_import_tf
|
||||
from ray.rllib.utils.tf_utils import explained_variance
|
||||
from ray.rllib.policy.tf_mixins import GradStatsMixin
|
||||
from ray.rllib.utils.typing import (
|
||||
LocalOptimizer,
|
||||
ModelGradients,
|
||||
|
@ -271,6 +272,7 @@ def get_impala_tf_policy(name: str, base: TFPolicyV2Type) -> TFPolicyV2Type:
|
|||
VTraceOptimizer,
|
||||
LearningRateSchedule,
|
||||
EntropyCoeffSchedule,
|
||||
GradStatsMixin,
|
||||
base,
|
||||
):
|
||||
def __init__(
|
||||
|
@ -298,6 +300,7 @@ def get_impala_tf_policy(name: str, base: TFPolicyV2Type) -> TFPolicyV2Type:
|
|||
existing_model=existing_model,
|
||||
)
|
||||
|
||||
GradStatsMixin.__init__(self)
|
||||
VTraceClipGradients.__init__(self)
|
||||
VTraceOptimizer.__init__(self)
|
||||
LearningRateSchedule.__init__(self, config["lr"], config["lr_schedule"])
|
||||
|
@ -417,22 +420,6 @@ def get_impala_tf_policy(name: str, base: TFPolicyV2Type) -> TFPolicyV2Type:
|
|||
),
|
||||
}
|
||||
|
||||
@override(base)
|
||||
def grad_stats_fn(
|
||||
self, train_batch: SampleBatch, grads: ModelGradients
|
||||
) -> Dict[str, TensorType]:
|
||||
# We have support for more than one loss (list of lists of grads).
|
||||
if self.config.get("_tf_policy_handles_more_than_one_loss"):
|
||||
grad_gnorm = [tf.linalg.global_norm(g) for g in grads]
|
||||
# Old case: We have a single list of grads (only one loss term and
|
||||
# optimizer).
|
||||
else:
|
||||
grad_gnorm = tf.linalg.global_norm(grads)
|
||||
|
||||
return {
|
||||
"grad_gnorm": grad_gnorm,
|
||||
}
|
||||
|
||||
@override(base)
|
||||
def get_batch_divisibility_req(self) -> int:
|
||||
return self.config["rollout_fragment_length"]
|
||||
|
|
|
@ -124,7 +124,7 @@ class TestPPO(unittest.TestCase):
|
|||
for fw in framework_iterator(config, with_eager_tracing=True):
|
||||
for env in ["FrozenLake-v1", "MsPacmanNoFrameskip-v4"]:
|
||||
print("Env={}".format(env))
|
||||
for lstm in [True, False]:
|
||||
for lstm in [False]:
|
||||
print("LSTM={}".format(lstm))
|
||||
config.training(
|
||||
model=dict(
|
||||
|
|
|
@ -324,6 +324,26 @@ class ValueNetworkMixin:
|
|||
return self._cached_extra_action_fetches
|
||||
|
||||
|
||||
class GradStatsMixin:
    """Mixin that adds gradient global-norm reporting to a policy.

    The host class is expected to expose a ``config`` mapping; this mixin
    reads the ``"_tf_policy_handles_more_than_one_loss"`` key from it.
    """

    def __init__(self):
        """Do nothing; the mixin keeps no state of its own."""

    def grad_stats_fn(
        self, train_batch: SampleBatch, grads: ModelGradients
    ) -> Dict[str, TensorType]:
        """Return the global norm of the given gradients.

        Args:
            train_batch: The batch the gradients were computed on. It is not
                used by this implementation.
            grads: Either a single list of gradients (one loss term and one
                optimizer), or a list of such lists when the policy is
                configured to handle more than one loss.

        Returns:
            A dict with the single key ``"grad_gnorm"``. The value is one
            global norm in the single-loss case, or a list with one global
            norm per gradient list in the multi-loss case.
        """
        multi_loss = self.config.get("_tf_policy_handles_more_than_one_loss")

        # Single loss/optimizer: `grads` is one flat list of gradients.
        if not multi_loss:
            return {"grad_gnorm": tf.linalg.global_norm(grads)}

        # Multiple losses: `grads` is a list of gradient lists, so compute
        # one norm per inner list.
        return {"grad_gnorm": [tf.linalg.global_norm(g) for g in grads]}
|
||||
|
||||
|
||||
# TODO: find a better place for this util, since it's not technically MixIns.
|
||||
@DeveloperAPI
|
||||
def compute_gradients(
|
||||
|
|
|
@ -46,23 +46,31 @@ def apply_grad_clipping(
|
|||
An info dict containing the "grad_norm" key and the resulting clipped
|
||||
gradients.
|
||||
"""
|
||||
info = {}
|
||||
if policy.config["grad_clip"]:
|
||||
for param_group in optimizer.param_groups:
|
||||
# Make sure we only pass params with grad != None into torch
|
||||
# clip_grad_norm_. Would fail otherwise.
|
||||
params = list(filter(lambda p: p.grad is not None, param_group["params"]))
|
||||
if params:
|
||||
# PyTorch clips gradients inplace and returns the norm before clipping
|
||||
# We therefore need to compute grad_gnorm further down (fixes #4965)
|
||||
clip_value = policy.config["grad_clip"]
|
||||
global_norm = nn.utils.clip_grad_norm_(params, clip_value)
|
||||
grad_gnorm = 0
|
||||
if policy.config["grad_clip"] is not None:
|
||||
clip_value = policy.config["grad_clip"]
|
||||
else:
|
||||
clip_value = np.inf
|
||||
|
||||
if isinstance(global_norm, torch.Tensor):
|
||||
global_norm = global_norm.cpu().numpy()
|
||||
for param_group in optimizer.param_groups:
|
||||
# Make sure we only pass params with grad != None into torch
|
||||
# clip_grad_norm_. Would fail otherwise.
|
||||
params = list(filter(lambda p: p.grad is not None, param_group["params"]))
|
||||
if params:
|
||||
# PyTorch clips gradients inplace and returns the norm before clipping
|
||||
# We therefore need to compute grad_gnorm further down (fixes #4965)
|
||||
global_norm = nn.utils.clip_grad_norm_(params, clip_value)
|
||||
|
||||
info["grad_gnorm"] = min(global_norm, clip_value)
|
||||
return info
|
||||
if isinstance(global_norm, torch.Tensor):
|
||||
global_norm = global_norm.cpu().numpy()
|
||||
|
||||
grad_gnorm += min(global_norm, clip_value)
|
||||
|
||||
if grad_gnorm > 0:
|
||||
return {"grad_gnorm": grad_gnorm}
|
||||
else:
|
||||
# No grads available
|
||||
return {}
|
||||
|
||||
|
||||
@Deprecated(
|
||||
|
|
Loading…
Add table
Reference in a new issue