mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
[RLlib] Allow extra keys in info in multi-agent (#20793)
This commit is contained in:
parent
a8286c55af
commit
39c202fa66
2 changed files with 8 additions and 3 deletions
8
rllib/env/base_env.py
vendored
8
rllib/env/base_env.py
vendored
|
@ -499,7 +499,9 @@ class _MultiAgentEnvToBaseEnv(BaseEnv):
|
|||
assert isinstance(rewards, dict), "Not a multi-agent reward"
|
||||
assert isinstance(dones, dict), "Not a multi-agent return"
|
||||
assert isinstance(infos, dict), "Not a multi-agent info"
|
||||
if set(infos).difference(set(obs)):
|
||||
# Allow a `__common__` entry in `infos` for data that is not tied to any
|
||||
# single agent, but rather to the environment itself.
|
||||
if set(infos).difference(set(obs) | {"__common__"}):
|
||||
raise ValueError("Key set for infos must be a subset of obs: "
|
||||
"{} vs {}".format(infos.keys(), obs.keys()))
|
||||
if "__all__" not in dones:
|
||||
|
@ -552,7 +554,7 @@ class _MultiAgentEnvState:
|
|||
observations = self.last_obs
|
||||
rewards = {}
|
||||
dones = {"__all__": self.last_dones["__all__"]}
|
||||
infos = {}
|
||||
infos = {"__common__": self.last_infos.get("__common__")}
|
||||
|
||||
# If episode is done, release everything we have.
|
||||
if dones["__all__"]:
|
||||
|
@ -599,7 +601,7 @@ class _MultiAgentEnvState:
|
|||
self.last_obs = self.env.reset()
|
||||
self.last_rewards = {}
|
||||
self.last_dones = {"__all__": False}
|
||||
self.last_infos = {}
|
||||
self.last_infos = {"__common__": {}}
|
||||
return self.last_obs
|
||||
|
||||
|
||||
|
|
|
@ -828,6 +828,9 @@ def _process_observations(
|
|||
raise ValueError(
|
||||
"observe() must return a dict of agent observations")
|
||||
|
||||
common_infos = infos[env_id].get("__common__", {})
|
||||
episode._set_last_info("__common__", common_infos)
|
||||
|
||||
# For each agent in the environment.
|
||||
# types: AgentID, EnvObsType
|
||||
for agent_id, raw_obs in all_agents_obs.items():
|
||||
|
|
Loading…
Add table
Reference in a new issue