From 39c202fa66e80d3a113970fe9b9d2f01503bf968 Mon Sep 17 00:00:00 2001 From: Tomasz Wrona Date: Thu, 9 Dec 2021 14:44:33 +0100 Subject: [PATCH] [RLlib] Allow extra keys in info in multi-agent (#20793) --- rllib/env/base_env.py | 8 +++++--- rllib/evaluation/sampler.py | 3 +++ 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/rllib/env/base_env.py b/rllib/env/base_env.py index 03b986d6d..32cb2f072 100644 --- a/rllib/env/base_env.py +++ b/rllib/env/base_env.py @@ -499,7 +499,9 @@ class _MultiAgentEnvToBaseEnv(BaseEnv): assert isinstance(rewards, dict), "Not a multi-agent reward" assert isinstance(dones, dict), "Not a multi-agent return" assert isinstance(infos, dict), "Not a multi-agent info" - if set(infos).difference(set(obs)): + # Allow `__common__` entry in `infos` for data related not to any + # agent, but rather to the environment itself. + if set(infos).difference(set(obs) | {"__common__"}): raise ValueError("Key set for infos must be a subset of obs: " "{} vs {}".format(infos.keys(), obs.keys())) if "__all__" not in dones: @@ -552,7 +554,7 @@ class _MultiAgentEnvState: observations = self.last_obs rewards = {} dones = {"__all__": self.last_dones["__all__"]} - infos = {} + infos = {"__common__": self.last_infos.get("__common__")} # If episode is done, release everything we have. if dones["__all__"]: @@ -599,7 +601,7 @@ class _MultiAgentEnvState: self.last_obs = self.env.reset() self.last_rewards = {} self.last_dones = {"__all__": False} - self.last_infos = {} + self.last_infos = {"__common__": {}} return self.last_obs diff --git a/rllib/evaluation/sampler.py b/rllib/evaluation/sampler.py index 5ade28217..f52e14e9e 100644 --- a/rllib/evaluation/sampler.py +++ b/rllib/evaluation/sampler.py @@ -828,6 +828,9 @@ def _process_observations( raise ValueError( "observe() must return a dict of agent observations") + common_infos = infos[env_id].get("__common__", {}) + episode._set_last_info("__common__", common_infos) + # For each agent in the environment. 
# types: AgentID, EnvObsType for agent_id, raw_obs in all_agents_obs.items():