[RLlib] SlateQ + tf; release test fixes, related to TD-error not properly being formatted. (#24521) (#24542)

Co-authored-by: Sven Mika <svenmika1977@gmail.com>
This commit is contained in:
Avnish Narayan 2022-05-06 14:18:43 -04:00 committed by GitHub
parent 01c99c0347
commit a20cdaf8f0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -204,7 +204,6 @@ def build_slateq_stats(policy: Policy, batch) -> Dict[str, TensorType]:
"next_q_target_slate": policy._next_q_target_slate, "next_q_target_slate": policy._next_q_target_slate,
"next_q_target_max": policy._next_q_target_max, "next_q_target_max": policy._next_q_target_max,
"target_clicked": policy._target_clicked, "target_clicked": policy._target_clicked,
"td_error": policy._td_error,
"mean_td_error": policy._mean_td_error, "mean_td_error": policy._mean_td_error,
"q_loss": policy._q_loss, "q_loss": policy._q_loss,
"mean_actions": policy._mean_actions, "mean_actions": policy._mean_actions,
@ -371,6 +370,7 @@ SlateQTFPolicy = build_tf_policy(
make_model=build_slateq_model, make_model=build_slateq_model,
loss_fn=build_slateq_losses, loss_fn=build_slateq_losses,
stats_fn=build_slateq_stats, stats_fn=build_slateq_stats,
extra_learn_fetches_fn=lambda policy: {"td_error": policy._td_error},
optimizer_fn=rmsprop_optimizer, optimizer_fn=rmsprop_optimizer,
# Define how to act. # Define how to act.
action_distribution_fn=action_distribution_fn, action_distribution_fn=action_distribution_fn,