mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
Pettingzoo environment support (#9271)
* added pettingzoo wrapper env and example * added docs, examples for pettingzoo env support * fixed pettingzoo env flake8, added test * fixed pettingzoo env import * fixed pettingzoo env import * fixed pettingzoo import issue * fixed pettingzoo test * fixed linting problem * fixed bad quotes * future proofed pettingzoo dependency * fixed ray init in pettingzoo env * lint * manual lint Co-authored-by: Eric Liang <ekhliang@gmail.com>
This commit is contained in:
parent
b42d6a1ddc
commit
1425cdf834
7 changed files with 375 additions and 1 deletions
|
@ -232,7 +232,7 @@ install_dependencies() {
|
|||
opencv-python-headless pyyaml pandas==1.0.5 requests feather-format lxml openpyxl xlrd \
|
||||
py-spy pytest pytest-timeout networkx tabulate aiohttp uvicorn dataclasses pygments werkzeug \
|
||||
kubernetes flask grpcio pytest-sugar pytest-rerunfailures pytest-asyncio scikit-learn==0.22.2 numba \
|
||||
Pillow prometheus_client boto3)
|
||||
Pillow prometheus_client boto3 pettingzoo)
|
||||
if [ "${OSTYPE}" != msys ]; then
|
||||
# These packages aren't Windows-compatible
|
||||
pip_packages+=(blist) # https://github.com/DanielStutzbach/blist/issues/81#issue-391460716
|
||||
|
|
|
@ -203,6 +203,29 @@ Here is a simple `example training script <https://github.com/ray-project/ray/bl
|
|||
|
||||
To scale to hundreds of agents, MultiAgentEnv batches policy evaluations across multiple agents internally. It can also be auto-vectorized by setting ``num_envs_per_worker > 1``.
|
||||
|
||||
|
||||
PettingZoo Multi-Agent Environments
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
`PettingZoo <https://github.com/PettingZoo-Team/PettingZoo>`__ is a repository of over 50 diverse multi-agent environments. However, the API is note directly compatible with rllib, but it can be converted into an rllib MultiAgentEnv like in this example
|
||||
|
||||
.. code-block:: python
|
||||
from ray.tune.registry import register_env
|
||||
# import the pettingzoo environment
|
||||
from pettingzoo.gamma import prison_v0
|
||||
# import rllib pettingzoo interface
|
||||
from ray.rllib.env import PettingZooEnv
|
||||
# define how to make the environment. This way takes an optinoal environment config, num_floors
|
||||
env_creator = lambda config: prison_v0.env(num_floors=config.get("num_floors", 4))
|
||||
# register that way to make the environment under an rllib name
|
||||
register_env('prison', lambda config: PettingZooEnv(env_creator(config)))
|
||||
# now you can use `prison` as an environment
|
||||
# you can pass arguments to the environment creator with the env_config option in the config
|
||||
config['env_config'] = {"num_floors": 5}
|
||||
|
||||
A more complete example is here: `pettingzoo_env.py <https://github.com/ray-project/ray/blob/master/rllib/examples/pettingzoo_env.py>`__
|
||||
|
||||
|
||||
Rock Paper Scissors Example
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
|
|
|
@ -1310,6 +1310,13 @@ py_test(
|
|||
args = ["TestSupportedMultiAgentOffPolicy"]
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "tests/test_pettingzoo_env",
|
||||
tags = ["tests_dir", "tests_dir_S"],
|
||||
size = "medium",
|
||||
srcs = ["tests/test_pettingzoo_env.py"]
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "tests/test_supported_spaces",
|
||||
tags = ["tests_dir", "tests_dir_S"],
|
||||
|
|
2
rllib/env/__init__.py
vendored
2
rllib/env/__init__.py
vendored
|
@ -2,6 +2,7 @@ from ray.rllib.env.base_env import BaseEnv
|
|||
from ray.rllib.env.dm_env_wrapper import DMEnv
|
||||
from ray.rllib.env.dm_control_wrapper import DMCEnv
|
||||
from ray.rllib.env.unity3d_env import Unity3DEnv
|
||||
from ray.rllib.env.pettingzoo_env import PettingZooEnv
|
||||
from ray.rllib.env.multi_agent_env import MultiAgentEnv
|
||||
from ray.rllib.env.external_env import ExternalEnv
|
||||
from ray.rllib.env.external_multi_agent_env import ExternalMultiAgentEnv
|
||||
|
@ -20,6 +21,7 @@ __all__ = [
|
|||
"DMEnv",
|
||||
"DMCEnv",
|
||||
"Unity3DEnv",
|
||||
"PettingZooEnv",
|
||||
"PolicyClient",
|
||||
"PolicyServerInput",
|
||||
]
|
||||
|
|
207
rllib/env/pettingzoo_env.py
vendored
Normal file
207
rllib/env/pettingzoo_env.py
vendored
Normal file
|
@ -0,0 +1,207 @@
|
|||
from .multi_agent_env import MultiAgentEnv
|
||||
|
||||
|
||||
class PettingZooEnv(MultiAgentEnv):
|
||||
"""An interface to the PettingZoo MARL environment library.
|
||||
See: https://github.com/PettingZoo-Team/PettingZoo
|
||||
|
||||
Inherits from MultiAgentEnv and exposes a given AEC
|
||||
(actor-environment-cycle) game from the PettingZoo project via the
|
||||
MultiAgentEnv public API.
|
||||
|
||||
It reduces the class of AEC games to Partially Observable Markov (POM)
|
||||
games by imposing the following important restrictions onto an AEC
|
||||
environment:
|
||||
|
||||
1. Each agent steps in order specified in agents list (unless they are
|
||||
done, in which case, they should be skipped).
|
||||
2. Agents act simultaneously (-> No hard-turn games like chess).
|
||||
3. All agents have the same action_spaces and observation_spaces.
|
||||
Note: If, within your aec game, agents do not have homogeneous action /
|
||||
observation spaces, apply SuperSuit wrappers
|
||||
to apply padding functionality: https://github.com/PettingZoo-Team/
|
||||
SuperSuit#built-in-multi-agent-only-functions
|
||||
4. Environments are positive sum games (-> Agents are expected to cooperate
|
||||
to maximize reward). This isn't a hard restriction, it just that
|
||||
standard algorithms aren't expected to work well in highly competitive
|
||||
games.
|
||||
|
||||
Examples:
|
||||
>>> from pettingzoo.gamma import prison_v0
|
||||
>>> env = POMGameEnv(env_creator=prison_v0})
|
||||
>>> obs = env.reset()
|
||||
>>> print(obs)
|
||||
|
||||
{
|
||||
"0": [110, 119],
|
||||
"1": [105, 102],
|
||||
"2": [99, 95],
|
||||
}
|
||||
>>> obs, rewards, dones, infos = env.step(
|
||||
action_dict={
|
||||
"0": 1, "1": 0, "2": 2,
|
||||
})
|
||||
>>> print(rewards)
|
||||
{
|
||||
"0": 0,
|
||||
"1": 1,
|
||||
"2": 0,
|
||||
}
|
||||
>>> print(dones)
|
||||
{
|
||||
"0": False, # agent 0 is still running
|
||||
"1": True, # agent 1 is done
|
||||
"__all__": False, # the env is not done
|
||||
}
|
||||
>>> print(infos)
|
||||
{
|
||||
"0": {}, # info for agent 0
|
||||
"1": {}, # info for agent 1
|
||||
}
|
||||
"""
|
||||
|
||||
def __init__(self, env):
|
||||
"""
|
||||
Parameters:
|
||||
-----------
|
||||
env: AECenv object.
|
||||
"""
|
||||
self.aec_env = env
|
||||
|
||||
# agent idx list
|
||||
self.agents = self.aec_env.agents
|
||||
|
||||
# Get dictionaries of obs_spaces and act_spaces
|
||||
self.observation_spaces = self.aec_env.observation_spaces
|
||||
self.action_spaces = self.aec_env.action_spaces
|
||||
|
||||
# Get first observation space, assuming all agents have equal space
|
||||
self.observation_space = self.observation_spaces[self.agents[0]]
|
||||
|
||||
# Get first action space, assuming all agents have equal space
|
||||
self.action_space = self.action_spaces[self.agents[0]]
|
||||
|
||||
assert all(obs_space == self.observation_space
|
||||
for obs_space
|
||||
in self.aec_env.observation_spaces.values()), \
|
||||
"Observation spaces for all agents must be identical. Perhaps " \
|
||||
"SuperSuit's pad_observations wrapper can help (useage: " \
|
||||
"`supersuit.aec_wrappers.pad_observations(env)`"
|
||||
|
||||
assert all(act_space == self.action_space
|
||||
for act_space in self.aec_env.action_spaces.values()), \
|
||||
"Action spaces for all agents must be identical. Perhaps " \
|
||||
"SuperSuit's pad_action_space wrapper can help (useage: " \
|
||||
"`supersuit.aec_wrappers.pad_action_space(env)`"
|
||||
|
||||
self.rewards = {}
|
||||
self.dones = {}
|
||||
self.obs = {}
|
||||
self.infos = {}
|
||||
|
||||
_ = self.reset()
|
||||
|
||||
def _init_dicts(self):
|
||||
# initialize with zero
|
||||
self.rewards = dict(zip(self.agents, [0 for _ in self.agents]))
|
||||
# initialize with False
|
||||
self.dones = dict(zip(self.agents, [False for _ in self.agents]))
|
||||
self.dones["__all__"] = False
|
||||
|
||||
# initialize with None info object
|
||||
self.infos = dict(zip(self.agents, [{} for _ in self.agents]))
|
||||
|
||||
# initialize empty observations
|
||||
self.obs = dict(zip(self.agents, [None for _ in self.agents]))
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
Resets the env and returns observations from ready agents.
|
||||
|
||||
Returns:
|
||||
obs (dict): New observations for each ready agent.
|
||||
"""
|
||||
# 1. Reset environment; agent pointer points to first agent.
|
||||
self.aec_env.reset(observe=False)
|
||||
|
||||
# 2. Copy agents from environment
|
||||
self.agents = self.aec_env.agents
|
||||
|
||||
# 3. Reset dictionaries
|
||||
self._init_dicts()
|
||||
|
||||
# 4. Get initial observations
|
||||
for agent in self.agents:
|
||||
|
||||
# For each agent get initial observations
|
||||
self.obs[agent] = self.aec_env.observe(agent)
|
||||
|
||||
return self.obs
|
||||
|
||||
def step(self, action_dict):
|
||||
"""
|
||||
Executes input actions from RL agents and returns observations from
|
||||
environment agents.
|
||||
|
||||
The returns are dicts mapping from agent_id strings to values. The
|
||||
number of agents in the env can vary over time.
|
||||
|
||||
Returns
|
||||
-------
|
||||
obs (dict): New observations for each ready agent.
|
||||
rewards (dict): Reward values for each ready agent. If the
|
||||
episode is just started, the value will be None.
|
||||
dones (dict): Done values for each ready agent. The special key
|
||||
"__all__" (required) is used to indicate env termination.
|
||||
infos (dict): Optional info values for each agent id.
|
||||
"""
|
||||
env_done = False
|
||||
# iterate over self.agents
|
||||
for agent in self.agents:
|
||||
|
||||
# Execute only for agents that have not been done in previous steps
|
||||
if agent in action_dict.keys():
|
||||
if not env_done:
|
||||
assert agent == self.aec_env.agent_selection, \
|
||||
f"environment has a nontrivial ordering, and " \
|
||||
"cannot be used with the POMGameEnv wrapper\"" \
|
||||
"nCurrent agent: {self.aec_env.agent_selection}" \
|
||||
"\nExpected agent: {agent}"
|
||||
# Execute agent action in environment
|
||||
self.obs[agent] = self.aec_env.step(
|
||||
action_dict[agent], observe=True)
|
||||
if all(self.aec_env.dones.values()):
|
||||
env_done = True
|
||||
self.dones["__all__"] = True
|
||||
else:
|
||||
self.obs[agent] = self.aec_env.observe(agent)
|
||||
# Get reward
|
||||
self.rewards[agent] = self.aec_env.rewards[agent]
|
||||
# Update done status
|
||||
self.dones[agent] = self.aec_env.dones[agent]
|
||||
|
||||
# For agents with done = True, remove from dones, rewards and
|
||||
# observations.
|
||||
else:
|
||||
del self.dones[agent]
|
||||
del self.rewards[agent]
|
||||
del self.obs[agent]
|
||||
del self.infos[agent]
|
||||
|
||||
# update self.agents
|
||||
self.agents = list(action_dict.keys())
|
||||
|
||||
# Update infos stepwise
|
||||
for agent in self.agents:
|
||||
self.infos[agent] = self.aec_env.infos[agent]
|
||||
|
||||
return self.obs, self.rewards, self.dones, self.infos
|
||||
|
||||
def render(self, mode="human"):
|
||||
self.aec_env.render(mode=mode)
|
||||
|
||||
def close(self):
|
||||
self.aec_env.close()
|
||||
|
||||
def with_agent_groups(self, groups, obs_space=None, act_space=None):
|
||||
raise NotImplementedError
|
82
rllib/examples/pettingzoo_env.py
Normal file
82
rllib/examples/pettingzoo_env.py
Normal file
|
@ -0,0 +1,82 @@
|
|||
from copy import deepcopy
|
||||
import ray
|
||||
try:
|
||||
from ray.rllib.agents.agent import get_agent_class
|
||||
except ImportError:
|
||||
from ray.rllib.agents.registry import get_agent_class
|
||||
from ray.tune.registry import register_env
|
||||
from ray.rllib.env import PettingZooEnv
|
||||
from pettingzoo.gamma import prison_v0
|
||||
from supersuit.aec_wrappers import normalize_obs, dtype, color_reduction
|
||||
|
||||
from numpy import float32
|
||||
|
||||
if __name__ == "__main__":
|
||||
"""For this script, you need:
|
||||
1. Algorithm name and according module, e.g.: "PPo" + agents.ppo as agent
|
||||
2. Name of the aec game you want to train on, e.g.: "prison".
|
||||
3. num_cpus
|
||||
4. num_rollouts
|
||||
|
||||
Does require SuperSuit
|
||||
"""
|
||||
alg_name = "PPO"
|
||||
|
||||
# function that outputs the environment you wish to register.
|
||||
def env_creator(config):
|
||||
env = prison_v0.env(num_floors=config.get("num_floors", 4))
|
||||
env = dtype(env, dtype=float32)
|
||||
env = color_reduction(env, dtype=float32)
|
||||
env = normalize_obs(env, mode="R")
|
||||
return env
|
||||
|
||||
num_cpus = 1
|
||||
num_rollouts = 2
|
||||
|
||||
# 1. Gets default training configuration and specifies the POMgame to load.
|
||||
config = deepcopy(get_agent_class(alg_name)._default_config)
|
||||
|
||||
# 2. Set environment config. This will be passed to
|
||||
# the env_creator function via the register env lambda below
|
||||
config["env_config"] = {"num_floors": 5}
|
||||
|
||||
# 3. Register env
|
||||
register_env("prison", lambda config: PettingZooEnv(env_creator(config)))
|
||||
|
||||
# 4. Extract space dimensions
|
||||
test_env = PettingZooEnv(env_creator({}))
|
||||
obs_space = test_env.observation_space
|
||||
act_space = test_env.action_space
|
||||
|
||||
# 5. Configuration for multiagent setup with policy sharing:
|
||||
config["multiagent"] = {
|
||||
"policies": {
|
||||
# the first tuple value is None -> uses default policy
|
||||
"av": (None, obs_space, act_space, {}),
|
||||
},
|
||||
"policy_mapping_fn": lambda agent_id: "av"
|
||||
}
|
||||
|
||||
config["log_level"] = "DEBUG"
|
||||
config["num_workers"] = 1
|
||||
# Fragment length, collected at once from each worker and for each agent!
|
||||
config["sample_batch_size"] = 30
|
||||
# Training batch size -> Fragments are concatenated up to this point.
|
||||
config["train_batch_size"] = 200
|
||||
# After n steps, force reset simulation
|
||||
config["horizon"] = 200
|
||||
# Default: False
|
||||
config["no_done_at_end"] = False
|
||||
# Info: If False, each agents trajectory is expected to have
|
||||
# maximum one done=True in the last step of the trajectory.
|
||||
# If no_done_at_end = True, environment is not resetted
|
||||
# when dones[__all__]= True.
|
||||
|
||||
# 6. Initialize ray and trainer object
|
||||
ray.init(num_cpus=num_cpus + 1)
|
||||
trainer = get_agent_class(alg_name)(env="prison", config=config)
|
||||
|
||||
# 7. Train once
|
||||
trainer.train()
|
||||
|
||||
test_env.reset()
|
53
rllib/tests/test_pettingzoo_env.py
Normal file
53
rllib/tests/test_pettingzoo_env.py
Normal file
|
@ -0,0 +1,53 @@
|
|||
import unittest
|
||||
from copy import deepcopy
|
||||
|
||||
import ray
|
||||
from ray.tune.registry import register_env
|
||||
from ray.rllib.env import PettingZooEnv
|
||||
from ray.rllib.agents.registry import get_agent_class
|
||||
|
||||
from pettingzoo.mpe import simple_spread_v0
|
||||
|
||||
|
||||
class TestPettingZooEnv(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
ray.init()
|
||||
|
||||
def tearDown(self) -> None:
|
||||
ray.shutdown()
|
||||
|
||||
def test_pettingzoo_env(self):
|
||||
register_env("prison", lambda _: PettingZooEnv(simple_spread_v0.env()))
|
||||
|
||||
agent_class = get_agent_class("PPO")
|
||||
|
||||
config = deepcopy(agent_class._default_config)
|
||||
|
||||
test_env = PettingZooEnv(simple_spread_v0.env())
|
||||
obs_space = test_env.observation_space
|
||||
act_space = test_env.action_space
|
||||
test_env.close()
|
||||
|
||||
config["multiagent"] = {
|
||||
"policies": {
|
||||
# the first tuple value is None -> uses default policy
|
||||
"av": (None, obs_space, act_space, {}),
|
||||
},
|
||||
"policy_mapping_fn": lambda agent_id: "av"
|
||||
}
|
||||
|
||||
config["log_level"] = "DEBUG"
|
||||
config["num_workers"] = 0
|
||||
config["rollout_fragment_length"] = 30
|
||||
config["train_batch_size"] = 200
|
||||
config["horizon"] = 200 # After n steps, force reset simulation
|
||||
config["no_done_at_end"] = False
|
||||
|
||||
agent = agent_class(env="prison", config=config)
|
||||
agent.train()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import pytest
|
||||
import sys
|
||||
sys.exit(pytest.main(["-v", __file__]))
|
Loading…
Add table
Reference in a new issue