"""Simple example of setting up a multi-agent policy mapping.

Control the number of agents and policies via --num-agents and --num-policies.

This works with hundreds of agents and policies, but note that initializing
many TF policies will take some time.

Also, TF evals might slow down with large numbers of policies. To debug TF
execution, set the TF_TIMELINE_DIR environment variable.
"""
import argparse
import gym
import random

import ray
from ray import tune
from ray.rllib.models import ModelCatalog
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.tests.test_multi_agent_env import MultiCartpole
from ray.tune.registry import register_env
from ray.rllib.utils import try_import_tf
from ray.rllib.utils.annotations import override

tf = try_import_tf()

parser = argparse.ArgumentParser()
parser.add_argument("--num-agents", type=int, default=4)
parser.add_argument("--num-policies", type=int, default=2)
parser.add_argument("--num-iters", type=int, default=20)
parser.add_argument("--simple", action="store_true")
parser.add_argument("--num-cpus", type=int, default=0)


class CustomModel1(TFModelV2):
    def __init__(self, observation_space, action_space, num_outputs,
                 model_config, name):
        super().__init__(observation_space, action_space, num_outputs,
                         model_config, name)

        inputs = tf.keras.layers.Input(observation_space.shape)

        # Example of (optional) weight sharing between two different policies.
        # Here, we share the variables defined in the 'shared' variable scope
        # by entering it explicitly with tf.AUTO_REUSE. This creates the
        # variables for the 'fc1' layer in a global scope called 'shared'
        # outside of the policy's normal variable scope.
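        # (CustomModel2 below enters the same 'shared' scope and also names
        # its first layer 'fc1', so the two custom models end up reusing one
        # set of 'fc1' weights.)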
        with tf.variable_scope(
                tf.VariableScope(tf.AUTO_REUSE, "shared"),
                reuse=tf.AUTO_REUSE,
                auxiliary_name_scope=False):
            last_layer = tf.keras.layers.Dense(
                units=64, activation=tf.nn.relu, name="fc1")(inputs)
        output = tf.keras.layers.Dense(
            units=num_outputs, activation=None, name="fc_out")(last_layer)
        vf = tf.keras.layers.Dense(
            units=1, activation=None, name="value_out")(last_layer)
        self.base_model = tf.keras.models.Model(inputs, [output, vf])
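        # register_variables() declares these as the model's variables so the
        # policy can find them for training and checkpointing.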
        self.register_variables(self.base_model.variables)

    @override(ModelV2)
    def forward(self, input_dict, state, seq_lens):
        out, self._value_out = self.base_model(input_dict["obs"])
        return out, []

    @override(ModelV2)
    def value_function(self):
        return tf.reshape(self._value_out, [-1])


class CustomModel2(TFModelV2):
    def __init__(self, observation_space, action_space, num_outputs,
                 model_config, name):
        super().__init__(observation_space, action_space, num_outputs,
                         model_config, name)

        inputs = tf.keras.layers.Input(observation_space.shape)

        # Weights shared with CustomModel1.
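        # Entering the 'shared' scope with tf.AUTO_REUSE and naming the layer
        # 'fc1' again makes TF return the variables CustomModel1 already
        # created rather than allocating new ones.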
        with tf.variable_scope(
                tf.VariableScope(tf.AUTO_REUSE, "shared"),
                reuse=tf.AUTO_REUSE,
                auxiliary_name_scope=False):
            last_layer = tf.keras.layers.Dense(
                units=64, activation=tf.nn.relu, name="fc1")(inputs)
        output = tf.keras.layers.Dense(
            units=num_outputs, activation=None, name="fc_out")(last_layer)
        vf = tf.keras.layers.Dense(
            units=1, activation=None, name="value_out")(last_layer)
        self.base_model = tf.keras.models.Model(inputs, [output, vf])
        self.register_variables(self.base_model.variables)

    @override(ModelV2)
    def forward(self, input_dict, state, seq_lens):
        out, self._value_out = self.base_model(input_dict["obs"])
        return out, []

    @override(ModelV2)
    def value_function(self):
        return tf.reshape(self._value_out, [-1])


if __name__ == "__main__":
    args = parser.parse_args()
    ray.init(num_cpus=args.num_cpus or None)

    # Simple environment with `num_agents` independent cartpole entities
    register_env("multi_cartpole", lambda _: MultiCartpole(args.num_agents))
    ModelCatalog.register_custom_model("model1", CustomModel1)
    ModelCatalog.register_custom_model("model2", CustomModel2)

    single_env = gym.make("CartPole-v0")
    obs_space = single_env.observation_space
    act_space = single_env.action_space
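    # Every agent in MultiCartpole is an independent cartpole, so each policy
    # below is defined over the obs/action space of a single CartPole-v0 env.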

    # Each policy can have a different configuration (including custom model)
    def gen_policy(i):
        config = {
            "model": {
                "custom_model": ["model1", "model2"][i % 2],
            },
            "gamma": random.choice([0.95, 0.99]),
        }
        return (None, obs_space, act_space, config)
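    # Each policy spec is a (policy_cls, obs_space, act_space, config) tuple;
    # passing None for the class tells RLlib to use the trainer's default
    # policy (PPO's TF policy in this example).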

    # Set up PPO with an ensemble of `num_policies` different policies
    policies = {
        "policy_{}".format(i): gen_policy(i)
        for i in range(args.num_policies)
    }
    policy_ids = list(policies.keys())

    tune.run(
        "PPO",
        stop={"training_iteration": args.num_iters},
        config={
            "env": "multi_cartpole",
            "log_level": "DEBUG",
            "simple_optimizer": args.simple,
            "num_sgd_iter": 10,
            "multiagent": {
                "policies": policies,
                "policy_mapping_fn": (
                    lambda agent_id: random.choice(policy_ids)),
            },
        },
    )