2020-09-04 01:37:46 -04:00
|
|
|
from ray import tune
|
|
|
|
from ray.tune.registry import register_env
|
2021-01-19 10:09:39 +01:00
|
|
|
from ray.rllib.env.wrappers.pettingzoo_env import PettingZooEnv
|
2021-01-06 03:14:54 +05:30
|
|
|
from pettingzoo.sisl import waterworld_v2
|
2020-09-04 01:37:46 -04:00
|
|
|
|
|
|
|
# Based on code from github.com/parametersharingmadrl/parametersharingmadrl
|
|
|
|
|
|
|
|
def map_agent_to_policy(agent_id, *args, **kwargs):
    """Return the policy id for ``agent_id``; each agent uses its own policy.

    The policies dict built below is keyed by agent name, so the agent id
    doubles as the policy id. Extra positional and keyword arguments are
    accepted and ignored so that this callable works both when it is invoked
    with the agent id alone and when the caller supplies additional context
    (newer RLlib releases also pass the episode and a ``worker`` keyword).
    """
    return agent_id


if __name__ == "__main__":

    def env_creator(args):
        """Build a Waterworld environment wrapped in ``PettingZooEnv``.

        ``args`` is accepted because the creator is registered via
        ``register_env`` and called with a config argument; it is not used.
        """
        return PettingZooEnv(waterworld_v2.env())

    # A local instance is created only to read the spaces and agent names
    # needed for the policy specs below.
    env = env_creator({})
    register_env("waterworld", env_creator)

    obs_space = env.observation_space
    act_spc = env.action_space

    # One policy spec per agent; all agents share the same spaces.
    policies = {agent: (None, obs_space, act_spc, {}) for agent in env.agents}

    # Train with the "APEX_DDPG" trainer until 60000 episodes have been run,
    # checkpointing every 10 training iterations.
    tune.run(
        "APEX_DDPG",
        stop={"episodes_total": 60000},
        checkpoint_freq=10,
        config={
            # Environment specific
            "env": "waterworld",
            # General
            "num_gpus": 1,
            "num_workers": 2,
            # Method specific
            "multiagent": {
                "policies": policies,
                "policy_mapping_fn": map_agent_to_policy,
            },
        },
    )
|