ray/rllib/examples/multi_agent_cartpole.py
Sven Mika 57544b1ff9
[RLlib] Examples folder restructuring (Model examples; final part). (#8278)
- This PR completes any previously missing PyTorch Model counterparts to TFModels in examples/models.
- It also makes sure that all example scripts in the rllib/examples folder are tested for both frameworks and learn the given task (currently this is often not checked), using an --as-test flag in combination with a --stop-reward.
2020-05-12 08:23:10 +02:00

"""Simple example of setting up a multi-agent policy mapping.
Control the number of agents and policies via --num-agents and --num-policies.
This works with hundreds of agents and policies, but note that initializing
many TF policies will take some time.
Also, TF evals might slow down with large numbers of policies. To debug TF
execution, set the TF_TIMELINE_DIR environment variable.
"""
import argparse
import gym
import random

import ray
from ray import tune
from ray.rllib.examples.env.multi_agent import MultiAgentCartPole
from ray.rllib.examples.models.shared_weights_model import \
    SharedWeightsModel1, SharedWeightsModel2, TorchSharedWeightsModel
from ray.rllib.models import ModelCatalog
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.test_utils import check_learning_achieved

tf = try_import_tf()
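# NOTE (added comment): `try_import_tf()` returns None (rather than raising) if
# TensorFlow is not installed, so the --torch path still works without TF.
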
parser = argparse.ArgumentParser()
parser.add_argument("--num-agents", type=int, default=4)
parser.add_argument("--num-policies", type=int, default=2)
parser.add_argument("--stop-iters", type=int, default=20)
parser.add_argument("--stop-reward", type=float, default=150)
parser.add_argument("--stop-timesteps", type=int, default=100000)
parser.add_argument("--simple", action="store_true")
parser.add_argument("--num-cpus", type=int, default=0)
parser.add_argument("--as-test", action="store_true")
parser.add_argument("--torch", action="store_true")
if __name__ == "__main__":
    args = parser.parse_args()

    ray.init(num_cpus=args.num_cpus or None)

    # Register the models to use.
    mod1 = TorchSharedWeightsModel if args.torch else SharedWeightsModel1
    mod2 = TorchSharedWeightsModel if args.torch else SharedWeightsModel2
    ModelCatalog.register_custom_model("model1", mod1)
    ModelCatalog.register_custom_model("model2", mod2)

    # Get observation and action spaces from a single CartPole env instance.
    single_env = gym.make("CartPole-v0")
    obs_space = single_env.observation_space
    act_space = single_env.action_space

    # Each policy can have a different configuration (including custom model).
    def gen_policy(i):
        config = {
            "model": {
                "custom_model": ["model1", "model2"][i % 2],
            },
            "gamma": random.choice([0.95, 0.99]),
        }
        return (None, obs_space, act_space, config)
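
    # NOTE (added comment): each value in the `policies` dict below is an RLlib
    # policy spec tuple (policy_cls, obs_space, act_space, config); a None
    # policy class tells RLlib to use the trainer's default policy (PPO's here).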

    # Set up PPO with an ensemble of `num_policies` different policies.
    policies = {
        "policy_{}".format(i): gen_policy(i)
        for i in range(args.num_policies)
    }
    policy_ids = list(policies.keys())

    config = {
        "env": MultiAgentCartPole,
        "env_config": {
            "num_agents": args.num_agents,
        },
        "log_level": "DEBUG",
        "simple_optimizer": args.simple,
        "num_sgd_iter": 10,
        "multiagent": {
            "policies": policies,
            # Map each agent to a randomly chosen policy.
            "policy_mapping_fn": (lambda agent_id: random.choice(policy_ids)),
        },
        "use_pytorch": args.torch,
    }

    stop = {
        "episode_reward_mean": args.stop_reward,
        "timesteps_total": args.stop_timesteps,
        "training_iteration": args.stop_iters,
    }

    results = tune.run("PPO", stop=stop, config=config)

    if args.as_test:
        check_learning_achieved(results, args.stop_reward)

    ray.shutdown()