ray/rllib/examples/env/multi_agent.py

205 lines
6.8 KiB
Python

import gym
import random
from ray.rllib.env.multi_agent_env import MultiAgentEnv, make_multi_agent
from ray.rllib.examples.env.mock_env import MockEnv, MockEnv2
from ray.rllib.examples.env.stateless_cartpole import StatelessCartPole
from ray.rllib.utils.annotations import Deprecated
@Deprecated(
old="ray.rllib.examples.env.multi_agent.make_multiagent",
new="ray.rllib.env.multi_agent_env.make_multi_agent",
error=False)
def make_multiagent(env_name_or_creator):
return make_multi_agent(env_name_or_creator)
class BasicMultiAgent(MultiAgentEnv):
"""Env of N independent agents, each of which exits after 25 steps."""
def __init__(self, num):
self.agents = [MockEnv(25) for _ in range(num)]
self.dones = set()
self.observation_space = gym.spaces.Discrete(2)
self.action_space = gym.spaces.Discrete(2)
self.resetted = False
def reset(self):
self.resetted = True
self.dones = set()
return {i: a.reset() for i, a in enumerate(self.agents)}
def step(self, action_dict):
obs, rew, done, info = {}, {}, {}, {}
for i, action in action_dict.items():
obs[i], rew[i], done[i], info[i] = self.agents[i].step(action)
if done[i]:
self.dones.add(i)
done["__all__"] = len(self.dones) == len(self.agents)
return obs, rew, done, info
class EarlyDoneMultiAgent(MultiAgentEnv):
"""Env for testing when the env terminates (after agent 0 does)."""
def __init__(self):
self.agents = [MockEnv(3), MockEnv(5)]
self.dones = set()
self.last_obs = {}
self.last_rew = {}
self.last_done = {}
self.last_info = {}
self.i = 0
self.observation_space = gym.spaces.Discrete(10)
self.action_space = gym.spaces.Discrete(2)
def reset(self):
self.dones = set()
self.last_obs = {}
self.last_rew = {}
self.last_done = {}
self.last_info = {}
self.i = 0
for i, a in enumerate(self.agents):
self.last_obs[i] = a.reset()
self.last_rew[i] = None
self.last_done[i] = False
self.last_info[i] = {}
obs_dict = {self.i: self.last_obs[self.i]}
self.i = (self.i + 1) % len(self.agents)
return obs_dict
def step(self, action_dict):
assert len(self.dones) != len(self.agents)
for i, action in action_dict.items():
(self.last_obs[i], self.last_rew[i], self.last_done[i],
self.last_info[i]) = self.agents[i].step(action)
obs = {self.i: self.last_obs[self.i]}
rew = {self.i: self.last_rew[self.i]}
done = {self.i: self.last_done[self.i]}
info = {self.i: self.last_info[self.i]}
if done[self.i]:
rew[self.i] = 0
self.dones.add(self.i)
self.i = (self.i + 1) % len(self.agents)
done["__all__"] = len(self.dones) == len(self.agents) - 1
return obs, rew, done, info
class FlexAgentsMultiAgent(MultiAgentEnv):
"""Env of independent agents, each of which exits after n steps."""
def __init__(self):
self.agents = {}
self.agentID = 0
self.dones = set()
self.observation_space = gym.spaces.Discrete(2)
self.action_space = gym.spaces.Discrete(2)
self.resetted = False
def spawn(self):
# Spawn a new agent into the current episode.
agentID = self.agentID
self.agents[agentID] = MockEnv(25)
self.agentID += 1
return agentID
def reset(self):
self.agents = {}
self.spawn()
self.resetted = True
self.dones = set()
obs = {}
for i, a in self.agents.items():
obs[i] = a.reset()
return obs
def step(self, action_dict):
obs, rew, done, info = {}, {}, {}, {}
# Apply the actions.
for i, action in action_dict.items():
obs[i], rew[i], done[i], info[i] = self.agents[i].step(action)
if done[i]:
self.dones.add(i)
# Sometimes, add a new agent to the episode.
if random.random() > 0.75:
i = self.spawn()
obs[i], rew[i], done[i], info[i] = self.agents[i].step(action)
if done[i]:
self.dones.add(i)
# Sometimes, kill an existing agent.
if len(self.agents) > 1 and random.random() > 0.25:
keys = list(self.agents.keys())
key = random.choice(keys)
done[key] = True
del self.agents[key]
done["__all__"] = len(self.dones) == len(self.agents)
return obs, rew, done, info
class RoundRobinMultiAgent(MultiAgentEnv):
"""Env of N independent agents, each of which exits after 5 steps.
On each step() of the env, only one agent takes an action."""
def __init__(self, num, increment_obs=False):
if increment_obs:
# Observations are 0, 1, 2, 3... etc. as time advances
self.agents = [MockEnv2(5) for _ in range(num)]
else:
# Observations are all zeros
self.agents = [MockEnv(5) for _ in range(num)]
self.dones = set()
self.last_obs = {}
self.last_rew = {}
self.last_done = {}
self.last_info = {}
self.i = 0
self.num = num
self.observation_space = gym.spaces.Discrete(10)
self.action_space = gym.spaces.Discrete(2)
def reset(self):
self.dones = set()
self.last_obs = {}
self.last_rew = {}
self.last_done = {}
self.last_info = {}
self.i = 0
for i, a in enumerate(self.agents):
self.last_obs[i] = a.reset()
self.last_rew[i] = None
self.last_done[i] = False
self.last_info[i] = {}
obs_dict = {self.i: self.last_obs[self.i]}
self.i = (self.i + 1) % self.num
return obs_dict
def step(self, action_dict):
assert len(self.dones) != len(self.agents)
for i, action in action_dict.items():
(self.last_obs[i], self.last_rew[i], self.last_done[i],
self.last_info[i]) = self.agents[i].step(action)
obs = {self.i: self.last_obs[self.i]}
rew = {self.i: self.last_rew[self.i]}
done = {self.i: self.last_done[self.i]}
info = {self.i: self.last_info[self.i]}
if done[self.i]:
rew[self.i] = 0
self.dones.add(self.i)
self.i = (self.i + 1) % self.num
done["__all__"] = len(self.dones) == len(self.agents)
return obs, rew, done, info
MultiAgentCartPole = make_multi_agent("CartPole-v0")
MultiAgentMountainCar = make_multi_agent("MountainCarContinuous-v0")
MultiAgentPendulum = make_multi_agent("Pendulum-v0")
MultiAgentStatelessCartPole = make_multi_agent(
lambda config: StatelessCartPole(config))