Pettingzoo environment support (#9271)

* added pettingzoo wrapper env and example

* added docs, examples for pettingzoo env support

* fixed pettingzoo env flake8, added test

* fixed pettingzoo env import

* fixed pettingzoo env import

* fixed pettingzoo import issue

* fixed pettingzoo test

* fixed linting problem

* fixed bad quotes

* future proofed pettingzoo dependency

* fixed ray init in pettingzoo env

* lint

* manual lint

Co-authored-by: Eric Liang <ekhliang@gmail.com>
This commit is contained in:
Benjamin Black 2020-07-06 22:32:26 -06:00 committed by GitHub
parent b42d6a1ddc
commit 1425cdf834
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 375 additions and 1 deletions

View file

@ -232,7 +232,7 @@ install_dependencies() {
opencv-python-headless pyyaml pandas==1.0.5 requests feather-format lxml openpyxl xlrd \ opencv-python-headless pyyaml pandas==1.0.5 requests feather-format lxml openpyxl xlrd \
py-spy pytest pytest-timeout networkx tabulate aiohttp uvicorn dataclasses pygments werkzeug \ py-spy pytest pytest-timeout networkx tabulate aiohttp uvicorn dataclasses pygments werkzeug \
kubernetes flask grpcio pytest-sugar pytest-rerunfailures pytest-asyncio scikit-learn==0.22.2 numba \ kubernetes flask grpcio pytest-sugar pytest-rerunfailures pytest-asyncio scikit-learn==0.22.2 numba \
Pillow prometheus_client boto3) Pillow prometheus_client boto3 pettingzoo)
if [ "${OSTYPE}" != msys ]; then if [ "${OSTYPE}" != msys ]; then
# These packages aren't Windows-compatible # These packages aren't Windows-compatible
pip_packages+=(blist) # https://github.com/DanielStutzbach/blist/issues/81#issue-391460716 pip_packages+=(blist) # https://github.com/DanielStutzbach/blist/issues/81#issue-391460716

View file

@ -203,6 +203,29 @@ Here is a simple `example training script <https://github.com/ray-project/ray/bl
To scale to hundreds of agents, MultiAgentEnv batches policy evaluations across multiple agents internally. It can also be auto-vectorized by setting ``num_envs_per_worker > 1``. To scale to hundreds of agents, MultiAgentEnv batches policy evaluations across multiple agents internally. It can also be auto-vectorized by setting ``num_envs_per_worker > 1``.
PettingZoo Multi-Agent Environments
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
`PettingZoo <https://github.com/PettingZoo-Team/PettingZoo>`__ is a repository of over 50 diverse multi-agent environments. However, the API is note directly compatible with rllib, but it can be converted into an rllib MultiAgentEnv like in this example
.. code-block:: python
from ray.tune.registry import register_env
# import the pettingzoo environment
from pettingzoo.gamma import prison_v0
# import rllib pettingzoo interface
from ray.rllib.env import PettingZooEnv
# define how to make the environment. This way takes an optinoal environment config, num_floors
env_creator = lambda config: prison_v0.env(num_floors=config.get("num_floors", 4))
# register that way to make the environment under an rllib name
register_env('prison', lambda config: PettingZooEnv(env_creator(config)))
# now you can use `prison` as an environment
# you can pass arguments to the environment creator with the env_config option in the config
config['env_config'] = {"num_floors": 5}
A more complete example is here: `pettingzoo_env.py <https://github.com/ray-project/ray/blob/master/rllib/examples/pettingzoo_env.py>`__
Rock Paper Scissors Example Rock Paper Scissors Example
~~~~~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~~~~~

View file

@ -1310,6 +1310,13 @@ py_test(
args = ["TestSupportedMultiAgentOffPolicy"] args = ["TestSupportedMultiAgentOffPolicy"]
) )
py_test(
name = "tests/test_pettingzoo_env",
tags = ["tests_dir", "tests_dir_S"],
size = "medium",
srcs = ["tests/test_pettingzoo_env.py"]
)
py_test( py_test(
name = "tests/test_supported_spaces", name = "tests/test_supported_spaces",
tags = ["tests_dir", "tests_dir_S"], tags = ["tests_dir", "tests_dir_S"],

View file

@ -2,6 +2,7 @@ from ray.rllib.env.base_env import BaseEnv
from ray.rllib.env.dm_env_wrapper import DMEnv from ray.rllib.env.dm_env_wrapper import DMEnv
from ray.rllib.env.dm_control_wrapper import DMCEnv from ray.rllib.env.dm_control_wrapper import DMCEnv
from ray.rllib.env.unity3d_env import Unity3DEnv from ray.rllib.env.unity3d_env import Unity3DEnv
from ray.rllib.env.pettingzoo_env import PettingZooEnv
from ray.rllib.env.multi_agent_env import MultiAgentEnv from ray.rllib.env.multi_agent_env import MultiAgentEnv
from ray.rllib.env.external_env import ExternalEnv from ray.rllib.env.external_env import ExternalEnv
from ray.rllib.env.external_multi_agent_env import ExternalMultiAgentEnv from ray.rllib.env.external_multi_agent_env import ExternalMultiAgentEnv
@ -20,6 +21,7 @@ __all__ = [
"DMEnv", "DMEnv",
"DMCEnv", "DMCEnv",
"Unity3DEnv", "Unity3DEnv",
"PettingZooEnv",
"PolicyClient", "PolicyClient",
"PolicyServerInput", "PolicyServerInput",
] ]

207
rllib/env/pettingzoo_env.py vendored Normal file
View file

@ -0,0 +1,207 @@
from .multi_agent_env import MultiAgentEnv
class PettingZooEnv(MultiAgentEnv):
"""An interface to the PettingZoo MARL environment library.
See: https://github.com/PettingZoo-Team/PettingZoo
Inherits from MultiAgentEnv and exposes a given AEC
(actor-environment-cycle) game from the PettingZoo project via the
MultiAgentEnv public API.
It reduces the class of AEC games to Partially Observable Markov (POM)
games by imposing the following important restrictions onto an AEC
environment:
1. Each agent steps in order specified in agents list (unless they are
done, in which case, they should be skipped).
2. Agents act simultaneously (-> No hard-turn games like chess).
3. All agents have the same action_spaces and observation_spaces.
Note: If, within your aec game, agents do not have homogeneous action /
observation spaces, apply SuperSuit wrappers
to apply padding functionality: https://github.com/PettingZoo-Team/
SuperSuit#built-in-multi-agent-only-functions
4. Environments are positive sum games (-> Agents are expected to cooperate
to maximize reward). This isn't a hard restriction, it just that
standard algorithms aren't expected to work well in highly competitive
games.
Examples:
>>> from pettingzoo.gamma import prison_v0
>>> env = POMGameEnv(env_creator=prison_v0})
>>> obs = env.reset()
>>> print(obs)
{
"0": [110, 119],
"1": [105, 102],
"2": [99, 95],
}
>>> obs, rewards, dones, infos = env.step(
action_dict={
"0": 1, "1": 0, "2": 2,
})
>>> print(rewards)
{
"0": 0,
"1": 1,
"2": 0,
}
>>> print(dones)
{
"0": False, # agent 0 is still running
"1": True, # agent 1 is done
"__all__": False, # the env is not done
}
>>> print(infos)
{
"0": {}, # info for agent 0
"1": {}, # info for agent 1
}
"""
def __init__(self, env):
"""
Parameters:
-----------
env: AECenv object.
"""
self.aec_env = env
# agent idx list
self.agents = self.aec_env.agents
# Get dictionaries of obs_spaces and act_spaces
self.observation_spaces = self.aec_env.observation_spaces
self.action_spaces = self.aec_env.action_spaces
# Get first observation space, assuming all agents have equal space
self.observation_space = self.observation_spaces[self.agents[0]]
# Get first action space, assuming all agents have equal space
self.action_space = self.action_spaces[self.agents[0]]
assert all(obs_space == self.observation_space
for obs_space
in self.aec_env.observation_spaces.values()), \
"Observation spaces for all agents must be identical. Perhaps " \
"SuperSuit's pad_observations wrapper can help (useage: " \
"`supersuit.aec_wrappers.pad_observations(env)`"
assert all(act_space == self.action_space
for act_space in self.aec_env.action_spaces.values()), \
"Action spaces for all agents must be identical. Perhaps " \
"SuperSuit's pad_action_space wrapper can help (useage: " \
"`supersuit.aec_wrappers.pad_action_space(env)`"
self.rewards = {}
self.dones = {}
self.obs = {}
self.infos = {}
_ = self.reset()
def _init_dicts(self):
# initialize with zero
self.rewards = dict(zip(self.agents, [0 for _ in self.agents]))
# initialize with False
self.dones = dict(zip(self.agents, [False for _ in self.agents]))
self.dones["__all__"] = False
# initialize with None info object
self.infos = dict(zip(self.agents, [{} for _ in self.agents]))
# initialize empty observations
self.obs = dict(zip(self.agents, [None for _ in self.agents]))
def reset(self):
"""
Resets the env and returns observations from ready agents.
Returns:
obs (dict): New observations for each ready agent.
"""
# 1. Reset environment; agent pointer points to first agent.
self.aec_env.reset(observe=False)
# 2. Copy agents from environment
self.agents = self.aec_env.agents
# 3. Reset dictionaries
self._init_dicts()
# 4. Get initial observations
for agent in self.agents:
# For each agent get initial observations
self.obs[agent] = self.aec_env.observe(agent)
return self.obs
def step(self, action_dict):
"""
Executes input actions from RL agents and returns observations from
environment agents.
The returns are dicts mapping from agent_id strings to values. The
number of agents in the env can vary over time.
Returns
-------
obs (dict): New observations for each ready agent.
rewards (dict): Reward values for each ready agent. If the
episode is just started, the value will be None.
dones (dict): Done values for each ready agent. The special key
"__all__" (required) is used to indicate env termination.
infos (dict): Optional info values for each agent id.
"""
env_done = False
# iterate over self.agents
for agent in self.agents:
# Execute only for agents that have not been done in previous steps
if agent in action_dict.keys():
if not env_done:
assert agent == self.aec_env.agent_selection, \
f"environment has a nontrivial ordering, and " \
"cannot be used with the POMGameEnv wrapper\"" \
"nCurrent agent: {self.aec_env.agent_selection}" \
"\nExpected agent: {agent}"
# Execute agent action in environment
self.obs[agent] = self.aec_env.step(
action_dict[agent], observe=True)
if all(self.aec_env.dones.values()):
env_done = True
self.dones["__all__"] = True
else:
self.obs[agent] = self.aec_env.observe(agent)
# Get reward
self.rewards[agent] = self.aec_env.rewards[agent]
# Update done status
self.dones[agent] = self.aec_env.dones[agent]
# For agents with done = True, remove from dones, rewards and
# observations.
else:
del self.dones[agent]
del self.rewards[agent]
del self.obs[agent]
del self.infos[agent]
# update self.agents
self.agents = list(action_dict.keys())
# Update infos stepwise
for agent in self.agents:
self.infos[agent] = self.aec_env.infos[agent]
return self.obs, self.rewards, self.dones, self.infos
def render(self, mode="human"):
self.aec_env.render(mode=mode)
def close(self):
self.aec_env.close()
def with_agent_groups(self, groups, obs_space=None, act_space=None):
raise NotImplementedError

View file

@ -0,0 +1,82 @@
from copy import deepcopy
import ray
try:
from ray.rllib.agents.agent import get_agent_class
except ImportError:
from ray.rllib.agents.registry import get_agent_class
from ray.tune.registry import register_env
from ray.rllib.env import PettingZooEnv
from pettingzoo.gamma import prison_v0
from supersuit.aec_wrappers import normalize_obs, dtype, color_reduction
from numpy import float32
if __name__ == "__main__":
"""For this script, you need:
1. Algorithm name and according module, e.g.: "PPo" + agents.ppo as agent
2. Name of the aec game you want to train on, e.g.: "prison".
3. num_cpus
4. num_rollouts
Does require SuperSuit
"""
alg_name = "PPO"
# function that outputs the environment you wish to register.
def env_creator(config):
env = prison_v0.env(num_floors=config.get("num_floors", 4))
env = dtype(env, dtype=float32)
env = color_reduction(env, dtype=float32)
env = normalize_obs(env, mode="R")
return env
num_cpus = 1
num_rollouts = 2
# 1. Gets default training configuration and specifies the POMgame to load.
config = deepcopy(get_agent_class(alg_name)._default_config)
# 2. Set environment config. This will be passed to
# the env_creator function via the register env lambda below
config["env_config"] = {"num_floors": 5}
# 3. Register env
register_env("prison", lambda config: PettingZooEnv(env_creator(config)))
# 4. Extract space dimensions
test_env = PettingZooEnv(env_creator({}))
obs_space = test_env.observation_space
act_space = test_env.action_space
# 5. Configuration for multiagent setup with policy sharing:
config["multiagent"] = {
"policies": {
# the first tuple value is None -> uses default policy
"av": (None, obs_space, act_space, {}),
},
"policy_mapping_fn": lambda agent_id: "av"
}
config["log_level"] = "DEBUG"
config["num_workers"] = 1
# Fragment length, collected at once from each worker and for each agent!
config["sample_batch_size"] = 30
# Training batch size -> Fragments are concatenated up to this point.
config["train_batch_size"] = 200
# After n steps, force reset simulation
config["horizon"] = 200
# Default: False
config["no_done_at_end"] = False
# Info: If False, each agents trajectory is expected to have
# maximum one done=True in the last step of the trajectory.
# If no_done_at_end = True, environment is not resetted
# when dones[__all__]= True.
# 6. Initialize ray and trainer object
ray.init(num_cpus=num_cpus + 1)
trainer = get_agent_class(alg_name)(env="prison", config=config)
# 7. Train once
trainer.train()
test_env.reset()

View file

@ -0,0 +1,53 @@
import unittest
from copy import deepcopy
import ray
from ray.tune.registry import register_env
from ray.rllib.env import PettingZooEnv
from ray.rllib.agents.registry import get_agent_class
from pettingzoo.mpe import simple_spread_v0
class TestPettingZooEnv(unittest.TestCase):
def setUp(self) -> None:
ray.init()
def tearDown(self) -> None:
ray.shutdown()
def test_pettingzoo_env(self):
register_env("prison", lambda _: PettingZooEnv(simple_spread_v0.env()))
agent_class = get_agent_class("PPO")
config = deepcopy(agent_class._default_config)
test_env = PettingZooEnv(simple_spread_v0.env())
obs_space = test_env.observation_space
act_space = test_env.action_space
test_env.close()
config["multiagent"] = {
"policies": {
# the first tuple value is None -> uses default policy
"av": (None, obs_space, act_space, {}),
},
"policy_mapping_fn": lambda agent_id: "av"
}
config["log_level"] = "DEBUG"
config["num_workers"] = 0
config["rollout_fragment_length"] = 30
config["train_batch_size"] = 200
config["horizon"] = 200 # After n steps, force reset simulation
config["no_done_at_end"] = False
agent = agent_class(env="prison", config=config)
agent.train()
if __name__ == "__main__":
import pytest
import sys
sys.exit(pytest.main(["-v", __file__]))