[rllib] Fixed pettingzoo wrapper (#11060)

* fixed PettingZooEnv, relevant docs, examples

* fixed linting error

* pettingzoo wrapper fix

* fixed linting issue
This commit is contained in:
Benjamin Black 2020-09-30 18:34:47 -04:00 committed by GitHub
parent c77cfaa5ad
commit 4445f32798
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 24 additions and 9 deletions

View file

@@ -56,7 +56,7 @@ mypy
networkx
numba
openpyxl
pettingzoo
pettingzoo>=1.3.2
Pillow; platform_system != "Windows"
pygments
pytest==5.4.3

View file

@@ -156,13 +156,21 @@ class PettingZooEnv(MultiAgentEnv):
infos (dict): Optional info values for each agent id.
"""
stepped_agents = set()
while self.aec_env.agent_selection not in stepped_agents:
while (self.aec_env.agent_selection not in stepped_agents
and self.aec_env.dones[self.aec_env.agent_selection]):
agent = self.aec_env.agent_selection
assert agent in action_dict, \
self.aec_env.step(None)
stepped_agents.add(agent)
stepped_agents = set()
# print(action_dict)
while (self.aec_env.agent_selection not in stepped_agents):
agent = self.aec_env.agent_selection
assert agent in action_dict or self.aec_env.dones[agent], \
"Live environment agent is not in actions dictionary"
self.aec_env.step(action_dict[agent])
stepped_agents.add(agent)
# print(self.aec_env.dones)
# print(stepped_agents)
assert all(agent in stepped_agents or self.aec_env.dones[agent]
for agent in action_dict), \
"environment has a nontrivial ordering, and cannot be used with"\
@@ -234,11 +242,18 @@ class ParallelPettingZooEnv(MultiAgentEnv):
return self.par_env.reset()
def step(self, action_dict):
for agent in self.agents:
action_dict[agent] = self.action_space.sample()
obs, rew, dones, info = self.par_env.step(action_dict)
dones["__all__"] = all(dones.values())
return obs, rew, dones, info
aobs, arew, adones, ainfo = self.par_env.step(action_dict)
obss = {}
rews = {}
dones = {}
infos = {}
for agent in action_dict:
obss[agent] = aobs[agent]
rews[agent] = arew[agent]
dones[agent] = adones[agent]
infos[agent] = ainfo[agent]
dones["__all__"] = all(adones.values())
return obss, rews, dones, infos
def close(self):
self.par_env.close()