[RLlib] Update sac_tf_policy.py (add tf.cast to float32 for rewards) (#14843)
This commit is contained in:
parent
6708211b59
commit
8874ccec2d
1 changed file with 1 addition and 1 deletion
@@ -323,7 +323,7 @@ def sac_actor_critic_loss(

     # Compute RHS of bellman equation for the Q-loss (critic(s)).
     q_t_selected_target = tf.stop_gradient(
-        train_batch[SampleBatch.REWARDS] +
+        tf.cast(train_batch[SampleBatch.REWARDS], tf.float32) +
         policy.config["gamma"]**policy.config["n_step"] * q_tp1_best_masked)

     # Compute the TD-error (potentially clipped).
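The one-line change guards against reward batches arriving as float64 (for example, when rewards originate as NumPy doubles), which would otherwise fail to add against the float32 discounted Q-target. Below is a minimal, self-contained sketch of that failure mode and the fix; the tensor values and the gamma/n_step settings are made up for illustration and are not from the commit.

import numpy as np
import tensorflow as tf

# Hypothetical stand-ins for train_batch[SampleBatch.REWARDS] and
# q_tp1_best_masked; the names mirror the diff, the values do not.
rewards_f64 = tf.constant(np.array([1.0, 0.0, -1.0]))         # dtype=float64
q_tp1_best_masked = tf.constant([0.5, 0.2, 0.1], tf.float32)  # dtype=float32
gamma, n_step = 0.99, 1  # illustrative config values

# Without the cast, TensorFlow raises InvalidArgumentError on the
# float64 + float32 dtype mismatch:
#   rewards_f64 + gamma**n_step * q_tp1_best_masked
# Casting the rewards first, as the commit does, makes the dtypes agree:
q_t_selected_target = tf.stop_gradient(
    tf.cast(rewards_f64, tf.float32) + gamma**n_step * q_tp1_best_masked)
print(q_t_selected_target)  # float32 tensor, safe to feed into the Q-loss

Note that stop_gradient is kept around the whole target so the critic loss does not backpropagate through the bootstrapped Q-value.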