From d7301a51f452189166bdafb5a7c63081bf1c0910 Mon Sep 17 00:00:00 2001 From: Sven Mika Date: Tue, 9 Feb 2021 17:05:26 +0100 Subject: [PATCH] [RLlib]: Trajectory View API: Keep env infos (e.g. for postprocessing callbacks), no matter what. (#13555) --- rllib/policy/dynamic_tf_policy.py | 4 ++-- rllib/policy/policy.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/rllib/policy/dynamic_tf_policy.py b/rllib/policy/dynamic_tf_policy.py index 10ecf9931..a5b01db87 100644 --- a/rllib/policy/dynamic_tf_policy.py +++ b/rllib/policy/dynamic_tf_policy.py @@ -590,12 +590,12 @@ class DynamicTFPolicy(TFPolicy): del self._loss_input_dict[key] # Remove those not needed at all (leave those that are needed # by Sampler to properly execute sample collection). - # Also always leave DONES and REWARDS, no matter what. + # Also always leave DONES, REWARDS, and INFOS, no matter what. for key in list(self.view_requirements.keys()): if key not in all_accessed_keys and key not in [ SampleBatch.EPS_ID, SampleBatch.AGENT_INDEX, SampleBatch.UNROLL_ID, SampleBatch.DONES, - SampleBatch.REWARDS] and \ + SampleBatch.REWARDS, SampleBatch.INFOS] and \ key not in self.model.view_requirements: # If user deleted this key manually in postprocessing # fn, warn about it and do not remove from diff --git a/rllib/policy/policy.py b/rllib/policy/policy.py index 1bce4b96d..d208c7d15 100644 --- a/rllib/policy/policy.py +++ b/rllib/policy/policy.py @@ -676,12 +676,12 @@ class Policy(metaclass=ABCMeta): self.view_requirements[key].used_for_training = False # Remove those not needed at all (leave those that are needed # by Sampler to properly execute sample collection). - # Also always leave DONES and REWARDS, no matter what. + # Also always leave DONES, REWARDS, and INFOS, no matter what.
for key in list(self.view_requirements.keys()): if key not in all_accessed_keys and key not in [ SampleBatch.EPS_ID, SampleBatch.AGENT_INDEX, SampleBatch.UNROLL_ID, SampleBatch.DONES, - SampleBatch.REWARDS] and \ + SampleBatch.REWARDS, SampleBatch.INFOS] and \ key not in self.model.view_requirements: # If user deleted this key manually in postprocessing # fn, warn about it and do not remove from