[RLlib] Allow extra keys in info in multi-agent (#20793)

This commit is contained in:
Tomasz Wrona 2021-12-09 14:44:33 +01:00 committed by GitHub
parent a8286c55af
commit 39c202fa66
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 8 additions and 3 deletions

View file

@ -499,7 +499,9 @@ class _MultiAgentEnvToBaseEnv(BaseEnv):
assert isinstance(rewards, dict), "Not a multi-agent reward"
assert isinstance(dones, dict), "Not a multi-agent return"
assert isinstance(infos, dict), "Not a multi-agent info"
if set(infos).difference(set(obs)):
# Allow a `__common__` entry in `infos` for data that relates to the
# environment itself rather than to any individual agent.
if set(infos).difference(set(obs) | {"__common__"}):
raise ValueError("Key set for infos must be a subset of obs: "
"{} vs {}".format(infos.keys(), obs.keys()))
if "__all__" not in dones:
@ -552,7 +554,7 @@ class _MultiAgentEnvState:
observations = self.last_obs
rewards = {}
dones = {"__all__": self.last_dones["__all__"]}
infos = {}
infos = {"__common__": self.last_infos.get("__common__")}
# If episode is done, release everything we have.
if dones["__all__"]:
@ -599,7 +601,7 @@ class _MultiAgentEnvState:
self.last_obs = self.env.reset()
self.last_rewards = {}
self.last_dones = {"__all__": False}
self.last_infos = {}
self.last_infos = {"__common__": {}}
return self.last_obs

View file

@ -828,6 +828,9 @@ def _process_observations(
raise ValueError(
"observe() must return a dict of agent observations")
common_infos = infos[env_id].get("__common__", {})
episode._set_last_info("__common__", common_infos)
# For each agent in the environment.
# types: AgentID, EnvObsType
for agent_id, raw_obs in all_agents_obs.items():