ray/python/ray/rllib/examples/multiagent_cartpole.py


"""Simple example of setting up a multi-agent policy mapping.

Control the number of agents and policies via --num-agents and --num-policies.

This works with hundreds of agents and policies, but note that initializing
many TF policy graphs will take some time.

Also, TF evals might slow down with large numbers of policies. To debug TF
execution, set the TF_TIMELINE_DIR environment variable.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import gym
import random

import ray
from ray.rllib.agents.pg.pg import PGAgent
from ray.rllib.agents.pg.pg_policy_graph import PGPolicyGraph
from ray.rllib.test.test_multi_agent_env import MultiCartpole
from ray.tune.logger import pretty_print
from ray.tune.registry import register_env

parser = argparse.ArgumentParser()
parser.add_argument("--num-agents", type=int, default=4)
parser.add_argument("--num-policies", type=int, default=2)
parser.add_argument("--num-iters", type=int, default=20)

if __name__ == "__main__":
    args = parser.parse_args()
    ray.init()

    # Simple environment with `num_agents` independent cartpole entities
    register_env("multi_cartpole", lambda _: MultiCartpole(args.num_agents))
    single_env = gym.make("CartPole-v0")
    obs_space = single_env.observation_space
    act_space = single_env.action_space
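
    # Each policy can have a different configuration (including custom model);
    # here, gamma and n_step are drawn at random so the ensemble members differ.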
    def gen_policy():
        config = {
            "gamma": random.choice([0.5, 0.8, 0.9, 0.95, 0.99]),
            "n_step": random.choice([1, 2, 3, 4, 5]),
        }
        return (PGPolicyGraph, obs_space, act_space, config)

    # Setup PG with an ensemble of `num_policies` different policy graphs
    policy_graphs = {
        "policy_{}".format(i): gen_policy()
        for i in range(args.num_policies)
    }
    policy_ids = list(policy_graphs.keys())
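
    # Each agent id is mapped to one of the policy graphs above; the mapping
    # function here simply picks a policy at random.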
    agent = PGAgent(
        env="multi_cartpole",
        config={
            "multiagent": {
                "policy_graphs": policy_graphs,
                "policy_mapping_fn": (
                    lambda agent_id: random.choice(policy_ids)),
            },
        })
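
    # Run --num-iters training iterations, pretty-printing the full result
    # dict returned by agent.train() at each step.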
    for i in range(args.num_iters):
        print("== Iteration", i, "==")
        print(pretty_print(agent.train()))