[RLlib] Unify gnorm mixin for tf and torch policies. (#26102)
This commit is contained in:
parent: c44d9ff397
commit: e9a8f7d9ae

5 changed files with 50 additions and 32 deletions
@@ -32,6 +32,7 @@ from ray.rllib.policy.tf_mixins import (
     LearningRateSchedule,
     KLCoeffMixin,
     ValueNetworkMixin,
+    GradStatsMixin,
 )
 from ray.rllib.models.modelv2 import ModelV2
 from ray.rllib.models.tf.tf_action_dist import TFActionDistribution
@@ -99,6 +100,7 @@ def get_appo_tf_policy(name: str, base: type) -> type:
         EntropyCoeffSchedule,
         ValueNetworkMixin,
         TargetNetworkMixin,
+        GradStatsMixin,
         base,
     ):
         def __init__(
@@ -137,6 +139,7 @@ def get_appo_tf_policy(name: str, base: type) -> type:
             )
             ValueNetworkMixin.__init__(self, config)
             KLCoeffMixin.__init__(self, config)
+            GradStatsMixin.__init__(self)

             # Note: this is a bit ugly, but loss and optimizer initialization must
             # happen after all the MixIns are initialized.
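Both this APPO TF policy and the IMPALA TF policy below now obtain grad_stats_fn from GradStatsMixin through multiple inheritance instead of defining it inline, with the mixin's __init__ called explicitly from the policy constructor. A minimal, self-contained sketch of that pattern follows; GradStatsMixinSketch, DummyBasePolicy, and MyPolicy are made-up names for illustration, and the statistic computed here is deliberately simplified, not RLlib's actual global-norm computation.

# Minimal sketch of the mixin pattern used in this commit (illustrative names only).
class GradStatsMixinSketch:
    def __init__(self):
        pass

    def grad_stats_fn(self, train_batch, grads):
        # Stand-in statistic; the real mixin uses tf.linalg.global_norm.
        return {"grad_gnorm": sum(abs(g) for g in grads)}


class DummyBasePolicy:
    def __init__(self, config):
        self.config = config


class MyPolicy(GradStatsMixinSketch, DummyBasePolicy):
    def __init__(self, config):
        DummyBasePolicy.__init__(self, config)
        # Mixin initializers are invoked explicitly, mirroring the
        # GradStatsMixin.__init__(self) calls added in this commit.
        GradStatsMixinSketch.__init__(self)


policy = MyPolicy(config={})
print(policy.grad_stats_fn(train_batch=None, grads=[1.0, -2.0, 3.0]))
# -> {'grad_gnorm': 6.0}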
@@ -19,6 +19,7 @@ from ray.rllib.utils import force_list
 from ray.rllib.utils.annotations import override
 from ray.rllib.utils.framework import try_import_tf
 from ray.rllib.utils.tf_utils import explained_variance
+from ray.rllib.policy.tf_mixins import GradStatsMixin
 from ray.rllib.utils.typing import (
     LocalOptimizer,
     ModelGradients,
@@ -271,6 +272,7 @@ def get_impala_tf_policy(name: str, base: TFPolicyV2Type) -> TFPolicyV2Type:
         VTraceOptimizer,
         LearningRateSchedule,
         EntropyCoeffSchedule,
+        GradStatsMixin,
         base,
     ):
         def __init__(
@@ -298,6 +300,7 @@ def get_impala_tf_policy(name: str, base: TFPolicyV2Type) -> TFPolicyV2Type:
                existing_model=existing_model,
            )

+           GradStatsMixin.__init__(self)
            VTraceClipGradients.__init__(self)
            VTraceOptimizer.__init__(self)
            LearningRateSchedule.__init__(self, config["lr"], config["lr_schedule"])
@@ -417,22 +420,6 @@ def get_impala_tf_policy(name: str, base: TFPolicyV2Type) -> TFPolicyV2Type:
                ),
            }

-           @override(base)
-           def grad_stats_fn(
-               self, train_batch: SampleBatch, grads: ModelGradients
-           ) -> Dict[str, TensorType]:
-               # We have support for more than one loss (list of lists of grads).
-               if self.config.get("_tf_policy_handles_more_than_one_loss"):
-                   grad_gnorm = [tf.linalg.global_norm(g) for g in grads]
-               # Old case: We have a single list of grads (only one loss term and
-               # optimizer).
-               else:
-                   grad_gnorm = tf.linalg.global_norm(grads)
-
-               return {
-                   "grad_gnorm": grad_gnorm,
-               }
-
            @override(base)
            def get_batch_divisibility_req(self) -> int:
                return self.config["rollout_fragment_length"]
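The grad_stats_fn removed here is not lost; it moves essentially verbatim into GradStatsMixin (see the tf_mixins hunk below). The "grad_gnorm" it reports is TensorFlow's global norm: the square root of the summed squared L2 norms of all gradient tensors, or a list of such norms when the policy handles several losses. A small illustrative check, assuming TensorFlow and NumPy are available (the fake gradients are made up for the example):

import numpy as np
import tensorflow as tf

# Two fake gradient tensors for a single loss term.
grads = [tf.constant([3.0, 4.0]), tf.constant([[1.0], [2.0]])]

# tf.linalg.global_norm(grads) == sqrt(sum of squared L2 norms of all tensors).
gnorm = tf.linalg.global_norm(grads)
manual = np.sqrt(sum(float(tf.reduce_sum(g ** 2)) for g in grads))
print(float(gnorm), manual)  # both ~= sqrt(3**2 + 4**2 + 1**2 + 2**2) ~= 5.57

# With "_tf_policy_handles_more_than_one_loss", grads is a list of lists and
# the policy reports one global norm per loss term instead of a single scalar.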
@@ -124,7 +124,7 @@ class TestPPO(unittest.TestCase):
         for fw in framework_iterator(config, with_eager_tracing=True):
             for env in ["FrozenLake-v1", "MsPacmanNoFrameskip-v4"]:
                 print("Env={}".format(env))
-                for lstm in [True, False]:
+                for lstm in [False]:
                     print("LSTM={}".format(lstm))
                     config.training(
                         model=dict(
@@ -324,6 +324,26 @@ class ValueNetworkMixin:
         return self._cached_extra_action_fetches


+class GradStatsMixin:
+    def __init__(self):
+        pass
+
+    def grad_stats_fn(
+        self, train_batch: SampleBatch, grads: ModelGradients
+    ) -> Dict[str, TensorType]:
+        # We have support for more than one loss (list of lists of grads).
+        if self.config.get("_tf_policy_handles_more_than_one_loss"):
+            grad_gnorm = [tf.linalg.global_norm(g) for g in grads]
+        # Old case: We have a single list of grads (only one loss term and
+        # optimizer).
+        else:
+            grad_gnorm = tf.linalg.global_norm(grads)
+
+        return {
+            "grad_gnorm": grad_gnorm,
+        }
+
+
 # TODO: find a better place for this util, since it's not technically MixIns.
 @DeveloperAPI
 def compute_gradients(
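GradStatsMixin only relies on self.config to pick between the single-loss and multi-loss branches, so any policy class that mixes it in gets the "grad_gnorm" stat for free. A rough usage sketch, assuming ray[rllib] with this commit plus TensorFlow are installed; DummyPolicy and the fake gradients are made up for illustration:

import tensorflow as tf
from ray.rllib.policy.tf_mixins import GradStatsMixin


class DummyPolicy(GradStatsMixin):
    """Made-up stand-in for a TF policy; only provides the `config` attribute."""

    def __init__(self, config):
        self.config = config
        GradStatsMixin.__init__(self)


# Single loss: grads is a flat list of tensors -> one scalar "grad_gnorm".
single = DummyPolicy({"_tf_policy_handles_more_than_one_loss": False})
print(single.grad_stats_fn(None, [tf.constant([3.0, 4.0])]))

# Several losses: grads is a list of lists -> one gnorm per loss term.
multi = DummyPolicy({"_tf_policy_handles_more_than_one_loss": True})
print(multi.grad_stats_fn(None, [[tf.constant([3.0, 4.0])], [tf.constant([1.0])]]))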
@@ -46,23 +46,31 @@ def apply_grad_clipping(
         An info dict containing the "grad_norm" key and the resulting clipped
         gradients.
     """
-    info = {}
-    if policy.config["grad_clip"]:
-        for param_group in optimizer.param_groups:
-            # Make sure we only pass params with grad != None into torch
-            # clip_grad_norm_. Would fail otherwise.
-            params = list(filter(lambda p: p.grad is not None, param_group["params"]))
-            if params:
-                # PyTorch clips gradients inplace and returns the norm before clipping
-                # We therefore need to compute grad_gnorm further down (fixes #4965)
-                clip_value = policy.config["grad_clip"]
-                global_norm = nn.utils.clip_grad_norm_(params, clip_value)
-
-                if isinstance(global_norm, torch.Tensor):
-                    global_norm = global_norm.cpu().numpy()
-
-                info["grad_gnorm"] = min(global_norm, clip_value)
-    return info
+    grad_gnorm = 0
+    if policy.config["grad_clip"] is not None:
+        clip_value = policy.config["grad_clip"]
+    else:
+        clip_value = np.inf
+
+    for param_group in optimizer.param_groups:
+        # Make sure we only pass params with grad != None into torch
+        # clip_grad_norm_. Would fail otherwise.
+        params = list(filter(lambda p: p.grad is not None, param_group["params"]))
+        if params:
+            # PyTorch clips gradients inplace and returns the norm before clipping
+            # We therefore need to compute grad_gnorm further down (fixes #4965)
+            global_norm = nn.utils.clip_grad_norm_(params, clip_value)
+
+            if isinstance(global_norm, torch.Tensor):
+                global_norm = global_norm.cpu().numpy()
+
+            grad_gnorm += min(global_norm, clip_value)
+
+    if grad_gnorm > 0:
+        return {"grad_gnorm": grad_gnorm}
+    else:
+        # No grads available
+        return {}


 @Deprecated(
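With this rewrite the torch helper reports "grad_gnorm" whether or not grad_clip is configured (clipping against np.inf is effectively a no-op), summed across optimizer param groups, so torch and TF policies now emit the same stat key. A minimal standalone sketch of that aggregation, assuming torch and numpy are installed; the toy model, optimizer, and grad_clip value are made up for illustration:

import numpy as np
import torch
import torch.nn as nn

model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Produce some gradients to aggregate.
loss = model(torch.randn(8, 4)).sum()
loss.backward()

grad_clip = None  # stand-in for policy.config["grad_clip"]
clip_value = grad_clip if grad_clip is not None else np.inf

grad_gnorm = 0
for param_group in optimizer.param_groups:
    # Only pass params that actually have gradients (mirrors the RLlib code).
    params = [p for p in param_group["params"] if p.grad is not None]
    if params:
        # clip_grad_norm_ clips in place and returns the pre-clipping norm.
        global_norm = nn.utils.clip_grad_norm_(params, clip_value)
        if isinstance(global_norm, torch.Tensor):
            global_norm = global_norm.cpu().numpy()
        grad_gnorm += min(global_norm, clip_value)

info = {"grad_gnorm": grad_gnorm} if grad_gnorm > 0 else {}
print(info)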