[RLlib] Update sac_tf_policy.py (add tf.cast to float32 for rewards) (#14843)

This commit is contained in:
astronauti 2021-03-24 16:12:55 +01:00 committed by GitHub
parent 6708211b59
commit 8874ccec2d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@@ -323,7 +323,7 @@ def sac_actor_critic_loss(
# Compute RHS of bellman equation for the Q-loss (critic(s)).
q_t_selected_target = tf.stop_gradient(
-train_batch[SampleBatch.REWARDS] +
+tf.cast(train_batch[SampleBatch.REWARDS], tf.float32) +
policy.config["gamma"]**policy.config["n_step"] * q_tp1_best_masked)
# Compute the TD-error (potentially clipped).