2021-03-09 17:26:20 +01:00
|
|
|
##########
|
|
|
|
# Contribution by the Center on Long-Term Risk:
|
|
|
|
# https://github.com/longtermrisk/marltoolbox
|
|
|
|
##########
|
|
|
|
import argparse
|
|
|
|
import os
|
|
|
|
|
|
|
|
import ray
|
|
|
|
from ray import tune
|
2022-05-19 09:30:42 -07:00
|
|
|
from ray.rllib.algorithms.pg import PGTrainer
|
2021-03-09 17:26:20 +01:00
|
|
|
from ray.rllib.examples.env.matrix_sequential_social_dilemma import (
|
|
|
|
IteratedPrisonersDilemma,
|
2022-01-29 18:41:57 -08:00
|
|
|
)
|
2021-03-09 17:26:20 +01:00
|
|
|
|
|
|
|
# CLI for running this example standalone.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--framework",
    default="tf",
    choices=["tf", "tf2", "tfe", "torch"],
    help="The DL framework specifier.",
)
parser.add_argument("--stop-iters", default=200, type=int)
|
|
|
|
|
|
|
|
|
2022-04-05 08:36:20 +02:00
|
|
|
def main(debug, stop_iters=200, framework="tf"):
    """Train PG agents on the iterated prisoner's dilemma and return results.

    Args:
        debug: When True, Ray runs in local (serial) mode and training is
            cut to 2 iterations (see ``get_rllib_config``).
        stop_iters: Number of training iterations when not debugging.
        framework: RLlib DL framework specifier ("tf", "tf2", "tfe", "torch").

    Returns:
        The analysis object produced by ``tune.run``.
    """
    # Fix: the original `1 if debug else 1` was a dead conditional — both
    # branches evaluated to 1, so exactly one replicate is always trained.
    train_n_replicates = 1
    seeds = list(range(train_n_replicates))

    ray.init(num_cpus=os.cpu_count(), num_gpus=0, local_mode=debug)
    rllib_config, stop_config = get_rllib_config(seeds, debug, stop_iters, framework)
    tune_analysis = tune.run(
        PGTrainer,
        config=rllib_config,
        stop=stop_config,
        checkpoint_freq=0,
        checkpoint_at_end=True,
        name="PG_IPD",
    )
    ray.shutdown()
    return tune_analysis
|
|
|
|
|
|
|
|
|
2022-04-05 08:36:20 +02:00
|
|
|
def get_rllib_config(seeds, debug=False, stop_iters=200, framework="tf"):
    """Build the RLlib trainer config and the Tune stopping criteria.

    Args:
        seeds: Seeds to grid-search over (one trial per seed).
        debug: When True, stop after only 2 training iterations.
        stop_iters: Training iterations to run when not debugging.
        framework: RLlib DL framework specifier.

    Returns:
        A ``(rllib_config, stop_config)`` tuple for ``tune.run``.
    """
    stop_config = {"training_iteration": 2 if debug else stop_iters}

    env_config = {
        "players_ids": ["player_row", "player_col"],
        "max_steps": 20,
        "get_additional_info": True,
    }

    # One independent PG policy per player; both use the env's shared
    # observation/action spaces. The comprehension gives each policy its
    # own (fresh) per-policy config dict.
    policies = {
        player_id: (
            None,
            IteratedPrisonersDilemma.OBSERVATION_SPACE,
            IteratedPrisonersDilemma.ACTION_SPACE,
            {},
        )
        for player_id in env_config["players_ids"]
    }

    rllib_config = {
        "env": IteratedPrisonersDilemma,
        "env_config": env_config,
        "multiagent": {
            "policies": policies,
            # Agent ids coincide with policy ids, so the mapping is identity.
            "policy_mapping_fn": lambda agent_id, **kwargs: agent_id,
        },
        "seed": tune.grid_search(seeds),
        "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
        "framework": framework,
    }

    return rllib_config, stop_config
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    args = parser.parse_args()
    # NOTE(review): debug is hard-coded True here, so Ray runs in local mode
    # and training stops after 2 iterations regardless of --stop-iters.
    debug_mode = True
    main(debug_mode, args.stop_iters, args.framework)
|