mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
Remove redundant scaler of l2 reg (#5172)
* remove redundant scaler of l2 reg * lint formatted * Update ddpg_policy.py
This commit is contained in:
parent
ae03c42dd6
commit
81d297f87e
1 changed files with 3 additions and 5 deletions
|
@ -231,17 +231,15 @@ class DDPGTFPolicy(DDPGPostprocessing, TFPolicy):
|
|||
if config["l2_reg"] is not None:
|
||||
for var in self.policy_vars:
|
||||
if "bias" not in var.name:
|
||||
self.actor_loss += (
|
||||
config["l2_reg"] * 0.5 * tf.nn.l2_loss(var))
|
||||
self.actor_loss += (config["l2_reg"] * tf.nn.l2_loss(var))
|
||||
for var in self.q_func_vars:
|
||||
if "bias" not in var.name:
|
||||
self.critic_loss += (
|
||||
config["l2_reg"] * 0.5 * tf.nn.l2_loss(var))
|
||||
self.critic_loss += (config["l2_reg"] * tf.nn.l2_loss(var))
|
||||
if self.config["twin_q"]:
|
||||
for var in self.twin_q_func_vars:
|
||||
if "bias" not in var.name:
|
||||
self.critic_loss += (
|
||||
config["l2_reg"] * 0.5 * tf.nn.l2_loss(var))
|
||||
config["l2_reg"] * tf.nn.l2_loss(var))
|
||||
|
||||
# update_target_fn will be called periodically to copy Q network to
|
||||
# target Q network
|
||||
|
|
Loading…
Add table
Reference in a new issue