ray/rllib/examples/twostep_game.py

"""The two-step game from QMIX: https://arxiv.org/pdf/1803.11485.pdf
Configurations you can try:
- normal policy gradients (PG)
- contrib/MADDPG
- QMIX
- APEX_QMIX
See also: centralized_critic.py for centralized critic PPO on this game.
"""
import argparse
from gym.spaces import Tuple, MultiDiscrete, Dict, Discrete
import numpy as np

import ray
from ray import tune
from ray.tune import register_env, grid_search
from ray.rllib.env.multi_agent_env import MultiAgentEnv
from ray.rllib.agents.qmix.qmix_policy import ENV_STATE

parser = argparse.ArgumentParser()
parser.add_argument("--stop", type=int, default=50000)
parser.add_argument("--run", type=str, default="PG")


class TwoStepGame(MultiAgentEnv):
    action_space = Discrete(2)

    def __init__(self, env_config):
        self.state = None
        self.agent_1 = 0
        self.agent_2 = 1
        # MADDPG emits action logits instead of actual discrete actions
        self.actions_are_logits = env_config.get("actions_are_logits", False)
        self.one_hot_state_encoding = env_config.get("one_hot_state_encoding",
                                                     False)
        self.with_state = env_config.get("separate_state_space", False)

        if not self.one_hot_state_encoding:
            self.observation_space = Discrete(6)
            self.with_state = False
        else:
            # Each agent gets the full state (one-hot encoding of which of the
            # three states are active) as input with the receiving agent's
            # ID (1 or 2) concatenated onto the end.
            if self.with_state:
                self.observation_space = Dict({
                    "obs": MultiDiscrete([2, 2, 2, 3]),
                    ENV_STATE: MultiDiscrete([2, 2, 2])
                })
            else:
                self.observation_space = MultiDiscrete([2, 2, 2, 3])

    def reset(self):
        self.state = np.array([1, 0, 0])
        return self._obs()
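
    # Two-step matrix game: in the initial state, agent 1's action alone
    # decides which payoff matrix is played next. State index 1 always pays a
    # global reward of 7; state index 2 pays 8 if both agents choose action 1,
    # 0 if both choose action 0, and 1 otherwise. The global reward is split
    # evenly between the two agents below.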
    def step(self, action_dict):
        if self.actions_are_logits:
            action_dict = {
                k: np.random.choice([0, 1], p=v)
                for k, v in action_dict.items()
            }

        state_index = np.flatnonzero(self.state)
        if state_index == 0:
            action = action_dict[self.agent_1]
            assert action in [0, 1], action
            if action == 0:
                self.state = np.array([0, 1, 0])
            else:
                self.state = np.array([0, 0, 1])
            global_rew = 0
            done = False
        elif state_index == 1:
            global_rew = 7
            done = True
        else:
            if action_dict[self.agent_1] == 0 and \
                    action_dict[self.agent_2] == 0:
                global_rew = 0
            elif action_dict[self.agent_1] == 1 and \
                    action_dict[self.agent_2] == 1:
                global_rew = 8
            else:
                global_rew = 1
            done = True

        rewards = {
            self.agent_1: global_rew / 2.0,
            self.agent_2: global_rew / 2.0
        }
        obs = self._obs()
        dones = {"__all__": done}
        infos = {}
        return obs, rewards, dones, infos

    def _obs(self):
        if self.with_state:
            return {
                self.agent_1: {
                    "obs": self.agent_1_obs(),
                    ENV_STATE: self.state
                },
                self.agent_2: {
                    "obs": self.agent_2_obs(),
                    ENV_STATE: self.state
                }
            }
        else:
            return {
                self.agent_1: self.agent_1_obs(),
                self.agent_2: self.agent_2_obs()
            }
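
    # Worked example: in the initial state [1, 0, 0], the flat Discrete(6)
    # encoding gives agent 1 observation 0 and agent 2 observation 3, while
    # the one-hot encoding gives [1, 0, 0, 1] and [1, 0, 0, 2] respectively
    # (the state plus the observing agent's ID).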
    def agent_1_obs(self):
        if self.one_hot_state_encoding:
            return np.concatenate([self.state, [1]])
        else:
            return np.flatnonzero(self.state)[0]

    def agent_2_obs(self):
        if self.one_hot_state_encoding:
            return np.concatenate([self.state, [2]])
        else:
            return np.flatnonzero(self.state)[0] + 3


if __name__ == "__main__":
    args = parser.parse_args()

    grouping = {
        "group_1": [0, 1],
    }
    obs_space = Tuple([
        Dict({
            "obs": MultiDiscrete([2, 2, 2, 3]),
            ENV_STATE: MultiDiscrete([2, 2, 2])
        }),
        Dict({
            "obs": MultiDiscrete([2, 2, 2, 3]),
            ENV_STATE: MultiDiscrete([2, 2, 2])
        }),
    ])
    act_space = Tuple([
        TwoStepGame.action_space,
        TwoStepGame.action_space,
    ])
    register_env(
        "grouped_twostep",
        lambda config: TwoStepGame(config).with_agent_groups(
            grouping, obs_space=obs_space, act_space=act_space))
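
    # QMIX and APEX_QMIX train on the grouped env: with_agent_groups() merges
    # agents 0 and 1 into "group_1", so the trainer sees a single agent whose
    # observation and action spaces are the Tuple spaces defined above.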
    if args.run == "contrib/MADDPG":
        obs_space_dict = {
            "agent_1": Discrete(6),
            "agent_2": Discrete(6),
        }
        act_space_dict = {
            "agent_1": TwoStepGame.action_space,
            "agent_2": TwoStepGame.action_space,
        }
        config = {
            "learning_starts": 100,
            "env_config": {
                "actions_are_logits": True,
            },
            "multiagent": {
                "policies": {
                    "pol1": (None, Discrete(6), TwoStepGame.action_space, {
                        "agent_id": 0,
                    }),
                    "pol2": (None, Discrete(6), TwoStepGame.action_space, {
                        "agent_id": 1,
                    }),
                },
                "policy_mapping_fn": lambda x: "pol1" if x == 0 else "pol2",
            },
        }
        group = False
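    # grid_search launches one QMIX trial per mixer option: None trains
    # independent Q-functions, "vdn" sums the per-agent Q-values, and "qmix"
    # combines them through QMIX's monotonic mixing network.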
elif args.run == "QMIX":
config = {
"sample_batch_size": 4,
"train_batch_size": 32,
"exploration_fraction": .4,
"exploration_final_eps": 0.0,
"num_workers": 0,
"mixer": grid_search([None, "qmix", "vdn"]),
"env_config": {
"separate_state_space": True,
"one_hot_state_encoding": True
},
}
group = True
    elif args.run == "APEX_QMIX":
        config = {
            "num_gpus": 0,
            "num_workers": 2,
            "optimizer": {
                "num_replay_buffer_shards": 1,
            },
            "min_iter_time_s": 3,
            "buffer_size": 1000,
            "learning_starts": 1000,
            "train_batch_size": 128,
            "sample_batch_size": 32,
            "target_network_update_freq": 500,
            "timesteps_per_iteration": 1000,
            "env_config": {
                "separate_state_space": True,
                "one_hot_state_encoding": True
            },
        }
        group = True
    else:
        config = {}
        group = False
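
    # For any other --run value (e.g. the default PG), the config stays empty
    # and Tune trains directly on the ungrouped TwoStepGame env below.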
    ray.init()
    tune.run(
        args.run,
        stop={
            "timesteps_total": args.stop,
        },
        config=dict(config, **{
            "env": "grouped_twostep" if group else TwoStepGame,
        }),
    )