"""Example of using two different training methods at once in multi-agent.
|
|
|
|
Here we create a number of CartPole agents, some of which are trained with
|
|
DQN, and some of which are trained with PPO. We periodically sync weights
|
|
between the two trainers (note that no such syncing is needed when using just
|
|
a single training method).
|
|
|
|
For a simpler example, see also: multiagent_cartpole.py
|
|
"""
|
|
|
|
import argparse
|
|
import gym
|
|
import os
|
|
|
|
import ray
|
|
from ray.rllib.agents.dqn import DQNTrainer, DQNTFPolicy, DQNTorchPolicy
|
|
from ray.rllib.agents.ppo import PPOTrainer, PPOTFPolicy, PPOTorchPolicy
|
|
from ray.rllib.examples.env.multi_agent import MultiAgentCartPole
|
|
from ray.tune.logger import pretty_print
|
|
from ray.tune.registry import register_env

parser = argparse.ArgumentParser()
# The chosen framework is used for both policies.
parser.add_argument(
    "--framework",
    choices=["tf", "tf2", "tfe", "torch"],
    default="tf",
    help="The DL framework specifier.",
)
parser.add_argument(
    "--as-test",
    action="store_true",
    help="Whether this script should be run as a test: --stop-reward must "
    "be achieved within --stop-timesteps AND --stop-iters.",
)
parser.add_argument(
    "--stop-iters", type=int, default=20, help="Number of iterations to train."
)
parser.add_argument(
    "--stop-timesteps", type=int, default=100000, help="Number of timesteps to train."
)
parser.add_argument(
    "--stop-reward", type=float, default=50.0, help="Reward at which we stop training."
)
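
# Example invocation (the script filename is assumed here; all flags are optional):
#   python multi_agent_two_trainers.py --framework=torch --as-test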

if __name__ == "__main__":
    args = parser.parse_args()

    ray.init()

    # Simple environment with 4 independent cartpole entities
    register_env(
        "multi_agent_cartpole", lambda _: MultiAgentCartPole({"num_agents": 4})
    )
    single_dummy_env = gym.make("CartPole-v0")
    obs_space = single_dummy_env.observation_space
    act_space = single_dummy_env.action_space
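    # All four CartPole agents share the same observation and action spaces,
    # so a single dummy env is enough to define the spaces for both policies.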

    # You can also have multiple policies per trainer, but here we just
    # show one each for PPO and DQN.
    policies = {
        "ppo_policy": (
            PPOTorchPolicy if args.framework == "torch" else PPOTFPolicy,
            obs_space,
            act_space,
            {},
        ),
        "dqn_policy": (
            DQNTorchPolicy if args.framework == "torch" else DQNTFPolicy,
            obs_space,
            act_space,
            {},
        ),
    }

    def policy_mapping_fn(agent_id, episode, worker, **kwargs):
        if agent_id % 2 == 0:
            return "ppo_policy"
        else:
            return "dqn_policy"
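
    # With 4 agents and the mapping above, agents 0 and 2 are controlled by the
    # PPO policy and agents 1 and 3 by the DQN policy.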

    ppo_trainer = PPOTrainer(
        env="multi_agent_cartpole",
        config={
            "multiagent": {
                "policies": policies,
                "policy_mapping_fn": policy_mapping_fn,
                "policies_to_train": ["ppo_policy"],
            },
            "model": {
                "vf_share_layers": True,
            },
            "num_sgd_iter": 6,
            "vf_loss_coeff": 0.01,
            # Disable observation filtering, otherwise we would need to
            # synchronize the filter state to the DQN trainer as well.
            "observation_filter": "NoFilter",
            # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
            "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
            "framework": args.framework,
        },
    )

    dqn_trainer = DQNTrainer(
        env="multi_agent_cartpole",
        config={
            "multiagent": {
                "policies": policies,
                "policy_mapping_fn": policy_mapping_fn,
                "policies_to_train": ["dqn_policy"],
            },
            "model": {
                "vf_share_layers": True,
            },
            "gamma": 0.95,
            "n_step": 3,
            # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
            "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
            "framework": args.framework,
        },
    )
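
    # Both trainers are given the full `policies` dict, so each holds a copy of
    # the other's policy (needed for the weight swap below), while
    # `policies_to_train` restricts each trainer to optimizing only its own policy.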

    # You should see both the printed X and Y approach 200 as this trains:
    # info:
    #   policy_reward_mean:
    #     dqn_policy: X
    #     ppo_policy: Y
    for i in range(args.stop_iters):
        print("== Iteration", i, "==")

        # improve the DQN policy
        print("-- DQN --")
        result_dqn = dqn_trainer.train()
        print(pretty_print(result_dqn))

        # improve the PPO policy
        print("-- PPO --")
        result_ppo = ppo_trainer.train()
        print(pretty_print(result_ppo))

        # Test passed gracefully.
        if (
            args.as_test
            and result_dqn["episode_reward_mean"] > args.stop_reward
            and result_ppo["episode_reward_mean"] > args.stop_reward
        ):
            print("test passed (both agents above requested reward)")
            quit(0)

        # swap weights to synchronize
        dqn_trainer.set_weights(ppo_trainer.get_weights(["ppo_policy"]))
        ppo_trainer.set_weights(dqn_trainer.get_weights(["dqn_policy"]))
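        # `get_weights(["ppo_policy"])` returns a dict mapping that policy ID to
        # its current weights; `set_weights` loads the dict into the receiving
        # trainer's copy of the same policy. This keeps each trainer's
        # non-trained policy up to date with its counterpart.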

    # Desired reward not reached.
    if args.as_test:
        raise ValueError("Desired reward ({}) not reached!".format(args.stop_reward))