mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
[RLlib] SlateQ + tf; release test fixes, related to TD-error not properly being formatted. (#24521) (#24542)
Co-authored-by: Sven Mika <svenmika1977@gmail.com>
This commit is contained in:
parent
01c99c0347
commit
a20cdaf8f0
1 changed files with 1 additions and 1 deletions
|
@ -204,7 +204,6 @@ def build_slateq_stats(policy: Policy, batch) -> Dict[str, TensorType]:
|
|||
"next_q_target_slate": policy._next_q_target_slate,
|
||||
"next_q_target_max": policy._next_q_target_max,
|
||||
"target_clicked": policy._target_clicked,
|
||||
"td_error": policy._td_error,
|
||||
"mean_td_error": policy._mean_td_error,
|
||||
"q_loss": policy._q_loss,
|
||||
"mean_actions": policy._mean_actions,
|
||||
|
@ -371,6 +370,7 @@ SlateQTFPolicy = build_tf_policy(
|
|||
make_model=build_slateq_model,
|
||||
loss_fn=build_slateq_losses,
|
||||
stats_fn=build_slateq_stats,
|
||||
extra_learn_fetches_fn=lambda policy: {"td_error": policy._td_error},
|
||||
optimizer_fn=rmsprop_optimizer,
|
||||
# Define how to act.
|
||||
action_distribution_fn=action_distribution_fn,
|
||||
|
|
Loading…
Add table
Reference in a new issue