mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
[RLlib] SlateQ + tf; release test fixes, related to TD-error not properly being formatted. (#24521) (#24542)
Co-authored-by: Sven Mika <svenmika1977@gmail.com>
This commit is contained in:
parent
01c99c0347
commit
a20cdaf8f0
1 changed file with 1 addition and 1 deletion
|
@@ -204,7 +204,6 @@ def build_slateq_stats(policy: Policy, batch) -> Dict[str, TensorType]:
|
||||||
"next_q_target_slate": policy._next_q_target_slate,
|
"next_q_target_slate": policy._next_q_target_slate,
|
||||||
"next_q_target_max": policy._next_q_target_max,
|
"next_q_target_max": policy._next_q_target_max,
|
||||||
"target_clicked": policy._target_clicked,
|
"target_clicked": policy._target_clicked,
|
||||||
"td_error": policy._td_error,
|
|
||||||
"mean_td_error": policy._mean_td_error,
|
"mean_td_error": policy._mean_td_error,
|
||||||
"q_loss": policy._q_loss,
|
"q_loss": policy._q_loss,
|
||||||
"mean_actions": policy._mean_actions,
|
"mean_actions": policy._mean_actions,
|
||||||
|
@@ -371,6 +370,7 @@ SlateQTFPolicy = build_tf_policy(
|
||||||
make_model=build_slateq_model,
|
make_model=build_slateq_model,
|
||||||
loss_fn=build_slateq_losses,
|
loss_fn=build_slateq_losses,
|
||||||
stats_fn=build_slateq_stats,
|
stats_fn=build_slateq_stats,
|
||||||
|
extra_learn_fetches_fn=lambda policy: {"td_error": policy._td_error},
|
||||||
optimizer_fn=rmsprop_optimizer,
|
optimizer_fn=rmsprop_optimizer,
|
||||||
# Define how to act.
|
# Define how to act.
|
||||||
action_distribution_fn=action_distribution_fn,
|
action_distribution_fn=action_distribution_fn,
|
||||||
|
|
Loading…
Add table
Reference in a new issue