mirror of
https://github.com/vale981/ray
synced 2025-03-05 10:01:43 -05:00
[RLlib] Fix TorchPolicyV2 bug. (#25203)
This commit is contained in:
parent
c90dacb09b
commit
9684ea3af6
1 changed files with 1 additions and 3 deletions
|
@ -588,7 +588,6 @@ class TorchPolicyV2(Policy):
|
|||
# Action dist class and inputs are generated via custom function.
|
||||
if is_overridden(self.action_distribution_fn):
|
||||
dist_inputs, dist_class, state_out = self.action_distribution_fn(
|
||||
self,
|
||||
self.model,
|
||||
input_dict=input_dict,
|
||||
state_batches=state_batches,
|
||||
|
@ -1025,9 +1024,8 @@ class TorchPolicyV2(Policy):
|
|||
self.exploration.before_compute_actions(explore=explore, timestep=timestep)
|
||||
if is_overridden(self.action_distribution_fn):
|
||||
dist_inputs, dist_class, state_out = self.action_distribution_fn(
|
||||
self,
|
||||
self.model,
|
||||
input_dict=input_dict,
|
||||
obs_batch=input_dict,
|
||||
state_batches=state_batches,
|
||||
seq_lens=seq_lens,
|
||||
explore=explore,
|
||||
|
|
Loading…
Add table
Reference in a new issue