[RLlib] Fix TorchPolicyV2 bug. (#25203)

This commit is contained in:
kourosh hakhamaneshi 2022-05-26 11:49:26 -07:00 committed by GitHub
parent c90dacb09b
commit 9684ea3af6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -588,7 +588,6 @@ class TorchPolicyV2(Policy):
# Action dist class and inputs are generated via custom function.
if is_overridden(self.action_distribution_fn):
dist_inputs, dist_class, state_out = self.action_distribution_fn(
self,
self.model,
input_dict=input_dict,
state_batches=state_batches,
@ -1025,9 +1024,8 @@ class TorchPolicyV2(Policy):
self.exploration.before_compute_actions(explore=explore, timestep=timestep)
if is_overridden(self.action_distribution_fn):
dist_inputs, dist_class, state_out = self.action_distribution_fn(
self,
self.model,
input_dict=input_dict,
obs_batch=input_dict,
state_batches=state_batches,
seq_lens=seq_lens,
explore=explore,