"""Example of using two different training methods at once in multi-agent.
|
|
|
|
|
|
|
|
Here we create a number of CartPole agents, some of which are trained with
|
|
|
|
DQN, and some of which are trained with PPO. We periodically sync weights
|
|
|
|
between the two trainers (note that no such syncing is needed when using just
|
|
|
|
a single training method).
|
|
|
|
|
|
|
|
For a simpler example, see also: multiagent_cartpole.py
|
|
|
|
"""

import argparse
import gym
import os

import ray
from ray.rllib.agents.dqn import DQNTrainer, DQNTFPolicy, DQNTorchPolicy
from ray.rllib.agents.ppo import PPOTrainer, PPOTFPolicy, PPOTorchPolicy
from ray.rllib.examples.env.multi_agent import MultiAgentCartPole
from ray.tune.logger import pretty_print
from ray.tune.registry import register_env

parser = argparse.ArgumentParser()
# Use torch for both policies.
parser.add_argument("--torch", action="store_true")
parser.add_argument("--as-test", action="store_true")
parser.add_argument("--stop-iters", type=int, default=20)
parser.add_argument("--stop-reward", type=float, default=50)
parser.add_argument("--stop-timesteps", type=int, default=100000)

if __name__ == "__main__":
    args = parser.parse_args()

    ray.init()

    # Simple environment with 4 independent CartPole entities.
    register_env("multi_agent_cartpole",
                 lambda _: MultiAgentCartPole({"num_agents": 4}))
    single_dummy_env = gym.make("CartPole-v0")
    obs_space = single_dummy_env.observation_space
    act_space = single_dummy_env.action_space

    # You can also have multiple policies per trainer, but here we just
    # show one each for PPO and DQN.
    policies = {
        "ppo_policy": (PPOTorchPolicy if args.torch else PPOTFPolicy,
                       obs_space, act_space, {}),
        "dqn_policy": (DQNTorchPolicy if args.torch else DQNTFPolicy,
                       obs_space, act_space, {}),
    }
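    # Each policy spec above is a tuple of
    # (policy_class, observation_space, action_space, per-policy config
    # overrides); the per-policy overrides are left empty here.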

    def policy_mapping_fn(agent_id):
        if agent_id % 2 == 0:
            return "ppo_policy"
        else:
            return "dqn_policy"

    ppo_trainer = PPOTrainer(
        env="multi_agent_cartpole",
        config={
            "multiagent": {
                "policies": policies,
                "policy_mapping_fn": policy_mapping_fn,
                "policies_to_train": ["ppo_policy"],
            },
            "explore": False,
            # Disable filters; otherwise we would need to synchronize those
            # to the DQN trainer as well.
            "observation_filter": "NoFilter",
            # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
            "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
            "framework": "torch" if args.torch else "tf",
        })

    dqn_trainer = DQNTrainer(
        env="multi_agent_cartpole",
        config={
            "multiagent": {
                "policies": policies,
                "policy_mapping_fn": policy_mapping_fn,
                "policies_to_train": ["dqn_policy"],
            },
            "gamma": 0.95,
            "n_step": 3,
            # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
            "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
            "framework": "torch" if args.torch else "tf",
        })

    # You should see both the printed X and Y approach 200 as this trains:
    # info:
    #   policy_reward_mean:
    #     dqn_policy: X
    #     ppo_policy: Y
    for i in range(args.stop_iters):
|
2018-07-22 05:09:25 -07:00
|
|
|
print("== Iteration", i, "==")
|
|
|
|
|
|
|
|
# improve the DQN policy
|
|
|
|
print("-- DQN --")
|
2020-05-12 08:23:10 +02:00
|
|
|
result_dqn = dqn_trainer.train()
|
|
|
|
print(pretty_print(result_dqn))

        # Improve the PPO policy.
        print("-- PPO --")
        result_ppo = ppo_trainer.train()
        print(pretty_print(result_ppo))

        # Test passed gracefully.
        if args.as_test and \
                result_dqn["episode_reward_mean"] > args.stop_reward and \
                result_ppo["episode_reward_mean"] > args.stop_reward:
            print("test passed (both agents above requested reward)")
            quit(0)

        # Swap weights to synchronize.
        dqn_trainer.set_weights(ppo_trainer.get_weights(["ppo_policy"]))
        ppo_trainer.set_weights(dqn_trainer.get_weights(["dqn_policy"]))
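        # `get_weights(["ppo_policy"])` returns a dict holding only that
        # policy's weights, so `set_weights()` on the other trainer updates
        # just that (non-trained) policy and leaves its own trained policy
        # untouched.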

    # Desired reward was not reached within --stop-iters iterations.
    if args.as_test:
        raise ValueError("Desired reward ({}) not reached!".format(
            args.stop_reward))