"""Example of using two different training methods at once in multi-agent.

Here we create a number of CartPole agents, some of which are trained with
DQN, and some of which are trained with PPO. We periodically sync weights
between the two trainers (note that no such syncing is needed when using just
a single training method).

For a simpler example, see also: multiagent_cartpole.py
"""

import argparse
import gym
import os

import ray
from ray.rllib.agents.dqn import DQNTrainer, DQNTFPolicy, DQNTorchPolicy
from ray.rllib.agents.ppo import PPOTrainer, PPOTFPolicy, PPOTorchPolicy
from ray.rllib.examples.env.multi_agent import MultiAgentCartPole
from ray.tune.logger import pretty_print
from ray.tune.registry import register_env

parser = argparse.ArgumentParser()
# The chosen framework (tf or torch) is used for both policies.
parser.add_argument(
    "--framework",
    choices=["tf", "tf2", "tfe", "torch"],
    default="tf",
    help="The DL framework specifier.")
parser.add_argument(
    "--as-test",
    action="store_true",
    help="Whether this script should be run as a test: --stop-reward must "
    "be achieved within --stop-timesteps AND --stop-iters.")
parser.add_argument(
    "--stop-iters",
    type=int,
    default=20,
    help="Number of iterations to train.")
parser.add_argument(
    "--stop-timesteps",
    type=int,
    default=100000,
    help="Number of timesteps to train.")
parser.add_argument(
    "--stop-reward",
    type=float,
    default=50.0,
    help="Reward at which we stop training.")

if __name__ == "__main__":
    args = parser.parse_args()

    ray.init()

    # Simple environment with 4 independent cartpole entities
    register_env("multi_agent_cartpole",
                 lambda _: MultiAgentCartPole({"num_agents": 4}))
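    # All agents in MultiAgentCartPole are independent CartPole copies, so we
    # can read the (shared) observation- and action spaces off a single dummy
    # CartPole env.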
    single_dummy_env = gym.make("CartPole-v0")
    obs_space = single_dummy_env.observation_space
    act_space = single_dummy_env.action_space

    # You can also have multiple policies per trainer, but here we just
    # show one each for PPO and DQN.
    policies = {
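        # Each policy is defined by a (policy_class, obs_space, action_space,
        # extra_config) tuple; the class is picked to match --framework.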
        "ppo_policy": (PPOTorchPolicy if args.framework == "torch" else
                       PPOTFPolicy, obs_space, act_space, {}),
        "dqn_policy": (DQNTorchPolicy if args.framework == "torch" else
                       DQNTFPolicy, obs_space, act_space, {}),
    }

    def policy_mapping_fn(agent_id, episode, **kwargs):
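        # MultiAgentCartPole uses integer agent ids (0 .. num_agents - 1):
        # even-numbered agents act with the PPO policy, odd-numbered agents
        # with the DQN policy.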
        if agent_id % 2 == 0:
            return "ppo_policy"
        else:
            return "dqn_policy"

    ppo_trainer = PPOTrainer(
        env="multi_agent_cartpole",
        config={
            "multiagent": {
                "policies": policies,
                "policy_mapping_fn": policy_mapping_fn,
                "policies_to_train": ["ppo_policy"],
            },
            "model": {
                "vf_share_layers": True,
            },
            "num_sgd_iter": 6,
            "vf_loss_coeff": 0.01,
            # Note: Unlike the policy weights, the MeanStdFilter's running
            # statistics are not synchronized to the DQN trainer by the weight
            # swapping below, so the two trainers will see differently
            # filtered observations.
            "observation_filter": "MeanStdFilter",
            # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
            "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
            "framework": args.framework,
        })

    dqn_trainer = DQNTrainer(
        env="multi_agent_cartpole",
        config={
            "multiagent": {
                "policies": policies,
                "policy_mapping_fn": policy_mapping_fn,
                "policies_to_train": ["dqn_policy"],
            },
            "model": {
                "vf_share_layers": True,
            },
            "gamma": 0.95,
            "n_step": 3,
            # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
            "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
            "framework": args.framework,
        })

    # You should see both the printed X and Y approach 200 as this trains:
    # info:
    #   policy_reward_mean:
    #     dqn_policy: X
    #     ppo_policy: Y
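    # Note: The two trainers do not share rollout workers; each collects its
    # own samples and they only exchange information through the explicit
    # weight syncing at the end of each iteration.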
    for i in range(args.stop_iters):
        print("== Iteration", i, "==")

        # improve the DQN policy
        print("-- DQN --")
        result_dqn = dqn_trainer.train()
        print(pretty_print(result_dqn))

        # improve the PPO policy
        print("-- PPO --")
        result_ppo = ppo_trainer.train()
        print(pretty_print(result_ppo))

        # Test passed gracefully.
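        # In this multi-agent env, `episode_reward_mean` is the per-episode
        # reward summed over all four agents, measured in each trainer's own
        # rollouts.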
        if args.as_test and \
                result_dqn["episode_reward_mean"] > args.stop_reward and \
                result_ppo["episode_reward_mean"] > args.stop_reward:
            print("test passed (both agents above requested reward)")
            quit(0)

        # swap weights to synchronize
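        # `get_weights(["ppo_policy"])` returns a dict mapping that policy ID
        # to its weights, and `set_weights()` only updates the policies that
        # appear in the given dict. Since both trainers hold copies of both
        # policies, this keeps each trainer's copy of the policy it does NOT
        # train in sync with the trainer that does train it.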
        dqn_trainer.set_weights(ppo_trainer.get_weights(["ppo_policy"]))
        ppo_trainer.set_weights(dqn_trainer.get_weights(["dqn_policy"]))

    # Desired reward not reached.
    if args.as_test:
        raise ValueError("Desired reward ({}) not reached!".format(
            args.stop_reward))