2020-05-01 22:59:34 +02:00
|
|
|
import gym
|
|
|
|
|
|
|
|
from ray.rllib.env.multi_agent_env import MultiAgentEnv
|
2020-12-09 01:41:45 +01:00
|
|
|
from ray.rllib.examples.env.mock_env import MockEnv, MockEnv2
|
2020-07-29 21:15:09 +02:00
|
|
|
from ray.rllib.examples.env.stateless_cartpole import StatelessCartPole
|
2020-05-01 22:59:34 +02:00
|
|
|
|
|
|
|
|
2020-05-30 22:48:34 +02:00
|
|
|
def make_multiagent(env_name_or_creator):
    """Create a MultiAgentEnv class wrapping N independent copies of one env.

    Args:
        env_name_or_creator: Either a registered gym env id (``str``), in
            which case every copy is built with ``gym.make``, or a callable
            that receives the config dict and returns one env instance.

    Returns:
        A ``MultiAgentEnv`` subclass. Its constructor takes an optional
        config dict: the ``"num_agents"`` key (default 1) sets how many
        copies (agents) are created, and the remaining keys are passed to
        ``env_name_or_creator`` when it is a callable. Agent ids are the
        integers ``0 .. num_agents - 1``.
    """

    class MultiEnv(MultiAgentEnv):
        def __init__(self, config=None):
            import copy

            # Pop "num_agents" from a shallow copy rather than from the
            # caller's dict. Mutating the original would strip the key for
            # every later env constructed from the same config object, and
            # those envs would silently fall back to a single agent.
            # copy.copy (not dict(...)) keeps the config's concrete type and
            # instance attributes for the creator.
            config = copy.copy(config) if config is not None else {}
            num = config.pop("num_agents", 1)
            if isinstance(env_name_or_creator, str):
                self.agents = [
                    gym.make(env_name_or_creator) for _ in range(num)
                ]
            else:
                self.agents = [env_name_or_creator(config) for _ in range(num)]
            # Ids of agents that have reported done since the last reset().
            self.dones = set()
            # All copies are built the same way; expose the first one's
            # spaces as the per-agent spaces.
            self.observation_space = self.agents[0].observation_space
            self.action_space = self.agents[0].action_space

        def reset(self):
            """Reset every copy; return ``{agent_id: initial_obs}``."""
            self.dones = set()
            return {i: a.reset() for i, a in enumerate(self.agents)}

        def step(self, action_dict):
            """Step only the agents present in ``action_dict``.

            Returns per-agent ``obs``, ``rew``, ``done``, ``info`` dicts.
            ``done["__all__"]`` is True once every agent has reported done
            since the last reset().
            """
            obs, rew, done, info = {}, {}, {}, {}
            for i, action in action_dict.items():
                obs[i], rew[i], done[i], info[i] = self.agents[i].step(action)
                if done[i]:
                    self.dones.add(i)
            done["__all__"] = len(self.dones) == len(self.agents)
            return obs, rew, done, info

    return MultiEnv
|
|
|
|
|
|
|
|
|
|
|
|
class BasicMultiAgent(MultiAgentEnv):
    """Env of N independent agents, each of which exits after 25 steps."""

    def __init__(self, num):
        # One MockEnv(25) per agent; agent ids are the list indices.
        self.agents = [MockEnv(25) for _ in range(num)]
        # Ids of agents that have reported done since the last reset().
        self.dones = set()
        self.observation_space = gym.spaces.Discrete(2)
        self.action_space = gym.spaces.Discrete(2)
        # Becomes True on the first reset() call.
        self.resetted = False

    def reset(self):
        """Mark the env as reset and return every agent's initial obs."""
        self.resetted = True
        self.dones = set()
        return {
            agent_id: sub_env.reset()
            for agent_id, sub_env in enumerate(self.agents)
        }

    def step(self, action_dict):
        """Step the agents named in ``action_dict``.

        ``"__all__"`` in the returned done dict is True once every agent
        has reported done since the last reset().
        """
        observations, rewards, done_dict, infos = {}, {}, {}, {}
        for agent_id, action in action_dict.items():
            (observations[agent_id], rewards[agent_id], done_dict[agent_id],
             infos[agent_id]) = self.agents[agent_id].step(action)
            if done_dict[agent_id]:
                self.dones.add(agent_id)
        done_dict["__all__"] = len(self.dones) == len(self.agents)
        return observations, rewards, done_dict, infos
|
|
|
|
|
|
|
|
|
|
|
|
class EarlyDoneMultiAgent(MultiAgentEnv):
    """Env for testing when the env terminates (after agent 0 does).

    Turn-based: each reset()/step() returns data for exactly one agent,
    cycling through the agents in order. The episode is declared over
    (``"__all__"``) as soon as all but one agent have reported done.
    """

    def __init__(self):
        # Two agents; presumably MockEnv(n) ends its episode after n steps,
        # so agent 0 finishes first -- TODO confirm against MockEnv.
        self.agents = [MockEnv(3), MockEnv(5)]
        self.dones = set()
        # Most recent (obs, rew, done, info) per agent id.
        self.last_obs = {}
        self.last_rew = {}
        self.last_done = {}
        self.last_info = {}
        # Id of the agent whose turn it is to be reported next.
        self.i = 0
        self.observation_space = gym.spaces.Discrete(10)
        self.action_space = gym.spaces.Discrete(2)

    def reset(self):
        """Reset every agent and return the initial obs of agent 0 only."""
        self.dones = set()
        self.last_obs = {}
        self.last_rew = {}
        self.last_done = {}
        self.last_info = {}
        self.i = 0
        for agent_id, sub_env in enumerate(self.agents):
            self.last_obs[agent_id] = sub_env.reset()
            # No reward exists yet for an observation right after reset.
            self.last_rew[agent_id] = None
            self.last_done[agent_id] = False
            self.last_info[agent_id] = {}
        current = self.i
        self.i = (current + 1) % len(self.agents)
        return {current: self.last_obs[current]}

    def step(self, action_dict):
        """Apply the given actions, then report only the next agent's data."""
        assert len(self.dones) != len(self.agents)
        for agent_id, action in action_dict.items():
            (self.last_obs[agent_id], self.last_rew[agent_id],
             self.last_done[agent_id],
             self.last_info[agent_id]) = self.agents[agent_id].step(action)
        current = self.i
        obs = {current: self.last_obs[current]}
        rew = {current: self.last_rew[current]}
        done = {current: self.last_done[current]}
        info = {current: self.last_info[current]}
        if done[current]:
            # A finished agent's reported reward is overridden to 0.
            rew[current] = 0
            self.dones.add(current)
        self.i = (current + 1) % len(self.agents)
        # Terminate early: one agent may still be running.
        done["__all__"] = len(self.dones) == len(self.agents) - 1
        return obs, rew, done, info
|
|
|
|
|
|
|
|
|
|
|
|
class RoundRobinMultiAgent(MultiAgentEnv):
    """Env of N independent agents, each of which exits after 5 steps.

    Turn-based: every reset()/step() returns data for exactly one agent,
    visiting agents 0, 1, ..., N-1 and then wrapping around.
    """

    def __init__(self, num, increment_obs=False):
        # MockEnv2: observations are 0, 1, 2, 3... etc. as time advances.
        # MockEnv: observations are all zeros.
        sub_env_cls = MockEnv2 if increment_obs else MockEnv
        self.agents = [sub_env_cls(5) for _ in range(num)]
        self.dones = set()
        # Most recent (obs, rew, done, info) per agent id.
        self.last_obs = {}
        self.last_rew = {}
        self.last_done = {}
        self.last_info = {}
        # Id of the agent whose turn it is to be reported next.
        self.i = 0
        self.num = num
        self.observation_space = gym.spaces.Discrete(10)
        self.action_space = gym.spaces.Discrete(2)

    def reset(self):
        """Reset every agent and return the initial obs of agent 0 only."""
        self.dones = set()
        self.last_obs = {}
        self.last_rew = {}
        self.last_done = {}
        self.last_info = {}
        self.i = 0
        for agent_id, sub_env in enumerate(self.agents):
            self.last_obs[agent_id] = sub_env.reset()
            # No reward exists yet for an observation right after reset.
            self.last_rew[agent_id] = None
            self.last_done[agent_id] = False
            self.last_info[agent_id] = {}
        turn = self.i
        self.i = (turn + 1) % self.num
        return {turn: self.last_obs[turn]}

    def step(self, action_dict):
        """Apply the given actions, then report only the next agent's data.

        ``"__all__"`` becomes True once every agent has reported done.
        """
        assert len(self.dones) != len(self.agents)
        for agent_id, action in action_dict.items():
            (self.last_obs[agent_id], self.last_rew[agent_id],
             self.last_done[agent_id],
             self.last_info[agent_id]) = self.agents[agent_id].step(action)
        turn = self.i
        obs = {turn: self.last_obs[turn]}
        rew = {turn: self.last_rew[turn]}
        done = {turn: self.last_done[turn]}
        info = {turn: self.last_info[turn]}
        if done[turn]:
            # A finished agent's reported reward is overridden to 0.
            rew[turn] = 0
            self.dones.add(turn)
        self.i = (turn + 1) % self.num
        done["__all__"] = len(self.dones) == len(self.agents)
        return obs, rew, done, info
|
|
|
|
|
|
|
|
|
|
|
|
# Ready-made multi-agent versions of common single-agent envs. Each name is a
# MultiAgentEnv class built by make_multiagent; its config's "num_agents" key
# controls how many independent copies are created.
MultiAgentCartPole = make_multiagent("CartPole-v0")
MultiAgentMountainCar = make_multiagent("MountainCarContinuous-v0")
MultiAgentPendulum = make_multiagent("Pendulum-v0")
# make_multiagent calls a non-str creator as creator(config), so the class
# itself serves as the creator; no lambda wrapper is needed.
MultiAgentStatelessCartPole = make_multiagent(StatelessCartPole)
|