# ray/rllib/examples/two_step_game.py

"""The two-step game from QMIX: https://arxiv.org/pdf/1803.11485.pdf
Configurations you can try:
- normal policy gradients (PG)
- MADDPG
- QMIX
See also: centralized_critic.py for centralized critic PPO on this game.
"""
import argparse
from gym.spaces import Dict, Discrete, Tuple, MultiDiscrete
import logging
import os

import ray
from ray import tune
from ray.tune import register_env
from ray.rllib.algorithms.qmix import QMixConfig
from ray.rllib.env.multi_agent_env import ENV_STATE
from ray.rllib.examples.env.two_step_game import TwoStepGame
from ray.rllib.policy.policy import PolicySpec
from ray.rllib.utils.test_utils import check_learning_achieved

logger = logging.getLogger(__name__)

parser = argparse.ArgumentParser()
parser.add_argument(
    "--run", type=str, default="PG", help="The RLlib-registered algorithm to use."
)
parser.add_argument(
    "--framework",
    choices=["tf", "tf2", "tfe", "torch"],
    default="tf",
    help="The DL framework specifier.",
)
parser.add_argument("--num-cpus", type=int, default=0)
parser.add_argument(
    "--mixer",
    type=str,
    default="qmix",
    choices=["qmix", "vdn", "none"],
    help="The mixer model to use.",
)
parser.add_argument(
    "--as-test",
    action="store_true",
    help="Whether this script should be run as a test: --stop-reward must "
    "be achieved within --stop-timesteps AND --stop-iters.",
)
parser.add_argument(
    "--stop-iters", type=int, default=200, help="Number of iterations to train."
)
parser.add_argument(
    "--stop-timesteps", type=int, default=70000, help="Number of timesteps to train."
)
parser.add_argument(
    "--stop-reward", type=float, default=8.0, help="Reward at which we stop training."
)
parser.add_argument(
    "--local-mode",
    action="store_true",
    help="Init Ray in local mode for easier debugging.",
)
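
# With --as-test, `check_learning_achieved()` (called at the bottom of this
# script) raises an error if --stop-reward is not reached in time.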
if __name__ == "__main__":
    args = parser.parse_args()

    ray.init(num_cpus=args.num_cpus or None, local_mode=args.local_mode)

    if args.run == "contrib/MADDPG":
        logger.warning(
            "`contrib/MADDPG` is no longer a valid algorithm descriptor! "
            "Use `MADDPG` instead."
        )
        args.run = "MADDPG"
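
    # QMIX trains on "agent groups": both agents are merged into one group
    # whose Tuple observation/action spaces stack the per-agent spaces. The
    # ENV_STATE key carries the global state consumed by QMIX's mixing network.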
    grouping = {
        "group_1": [0, 1],
    }
    obs_space = Tuple(
        [
            Dict(
                {
                    "obs": MultiDiscrete([2, 2, 2, 3]),
                    ENV_STATE: MultiDiscrete([2, 2, 2]),
                }
            ),
            Dict(
                {
                    "obs": MultiDiscrete([2, 2, 2, 3]),
                    ENV_STATE: MultiDiscrete([2, 2, 2]),
                }
            ),
        ]
    )
    act_space = Tuple(
        [
            TwoStepGame.action_space,
            TwoStepGame.action_space,
        ]
    )

    register_env(
        "grouped_twostep",
        lambda config: TwoStepGame(config).with_agent_groups(
            grouping, obs_space=obs_space, act_space=act_space
        ),
    )
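
    # `with_agent_groups()` wraps the env so that "group_1" acts as a single
    # super-agent observing/acting through the Tuple spaces defined above.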
if args.run == "MADDPG":
obs_space = Discrete(6)
act_space = TwoStepGame.action_space
config = {
"env": TwoStepGame,
"env_config": {
"actions_are_logits": True,
},
"replay_buffer_config": {"learning_starts": 100},
"multiagent": {
"policies": {
"pol1": PolicySpec(
observation_space=obs_space,
action_space=act_space,
config={"agent_id": 0},
),
"pol2": PolicySpec(
observation_space=obs_space,
action_space=act_space,
config={"agent_id": 1},
),
},
"policy_mapping_fn": (lambda aid, **kwargs: "pol2" if aid else "pol1"),
},
"framework": args.framework,
# Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
"num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
}
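
    # QMIX is configured via the builder-style `QMixConfig` API; the result is
    # converted back to a plain dict for `tune.run()` below.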
    elif args.run == "QMIX":
        config = (
            QMixConfig()
            .training(mixer=args.mixer, train_batch_size=32)
            .rollouts(num_rollout_workers=0, rollout_fragment_length=4)
            .exploration(
                exploration_config={
                    "final_epsilon": 0.0,
                }
            )
            .environment(
                env="grouped_twostep",
                env_config={
                    "separate_state_space": True,
                    "one_hot_state_encoding": True,
                },
            )
            .resources(num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "0")))
        )
        config = config.to_dict()
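
    # Any other algorithm (e.g. the default PG) trains directly on the
    # non-grouped env with a single default policy.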
    else:
        config = {
            "env": TwoStepGame,
            # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
            "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
            "framework": args.framework,
        }
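
    # Tune stops a trial as soon as ANY one of these criteria is met.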
    stop = {
        "episode_reward_mean": args.stop_reward,
        "timesteps_total": args.stop_timesteps,
        "training_iteration": args.stop_iters,
    }

    results = tune.run(args.run, stop=stop, config=config, verbose=2)

    if args.as_test:
        check_learning_achieved(results, args.stop_reward)

    ray.shutdown()
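
# Note: the default --stop-reward of 8.0 matches the optimal joint return of
# the two-step game, so a successful run should approach that value.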