# ray/rllib/examples/iterated_prisoners_dilemma_env.py
# Balaji Veeramani 7f1bacc7dc
# [CI] Format Python code with Black (#21975)
# See #21316 and #21311 for the motivation behind these changes.
# 2022-01-29 18:41:57 -08:00
#
# 86 lines
# 2.4 KiB
# Python

##########
# Contribution by the Center on Long-Term Risk:
# https://github.com/longtermrisk/marltoolbox
##########
import argparse
import os
import ray
from ray import tune
from ray.rllib.agents.pg import PGTrainer
from ray.rllib.examples.env.matrix_sequential_social_dilemma import (
IteratedPrisonersDilemma,
)
parser = argparse.ArgumentParser()
# Deep-learning backend forwarded to RLlib's "framework" config key.
parser.add_argument(
    "--framework",
    default="tf",
    choices=["tf", "tf2", "tfe", "torch"],
    help="The DL framework specifier.",
)
# Training-iteration stop criterion (used when the script is not in debug mode).
parser.add_argument("--stop-iters", default=200, type=int)
def main(debug, stop_iters=200, tf=False, framework="tf"):
    """Train one PG policy per player on the Iterated Prisoner's Dilemma.

    Args:
        debug: If True, Ray runs in local mode and training stops after 2
            iterations (``stop_iters`` is then ignored).
        stop_iters: Number of training iterations to run when ``debug`` is
            False.
        tf: Ignored; kept only so existing positional callers keep working.
            Use ``framework`` instead.
        framework: RLlib framework specifier ("tf", "tf2", "tfe" or "torch").

    Returns:
        The object returned by ``tune.run``.
    """
    # A single replicate (seed 0) is run regardless of ``debug``.
    train_n_replicates = 1
    seeds = list(range(train_n_replicates))

    ray.init(num_cpus=os.cpu_count(), num_gpus=0, local_mode=debug)
    try:
        rllib_config, stop_config = get_rllib_config(
            seeds, debug, stop_iters, tf, framework=framework
        )
        tune_analysis = tune.run(
            PGTrainer,
            config=rllib_config,
            stop=stop_config,
            checkpoint_freq=0,
            checkpoint_at_end=True,
            name="PG_IPD",
        )
    finally:
        # Release Ray even if building the config or training raises, so a
        # failed run does not leave a Ray instance initialized.
        ray.shutdown()
    return tune_analysis


def get_rllib_config(seeds, debug=False, stop_iters=200, tf=False, framework="tf"):
    """Build the RLlib config and Tune stop criteria for the IPD experiment.

    Args:
        seeds: Values grid-searched over for the "seed" config key.
        debug: If True, stop after 2 training iterations.
        stop_iters: Training iterations to run when ``debug`` is False.
        tf: Ignored; kept only so existing positional callers keep working.
        framework: Value written to the config's "framework" key. Passed in
            explicitly rather than read from a module-level ``args``, so this
            function also works when the module is imported.

    Returns:
        A ``(rllib_config, stop_config)`` tuple.
    """
    stop_config = {
        "training_iteration": 2 if debug else stop_iters,
    }
    env_config = {
        "players_ids": ["player_row", "player_col"],
        "max_steps": 20,
        "get_additional_info": True,
    }
    # One policy per player, keyed by the player id. Each entry gets its own
    # fresh (empty) per-policy config dict.
    policies = {
        player_id: (
            None,
            IteratedPrisonersDilemma.OBSERVATION_SPACE,
            IteratedPrisonersDilemma.ACTION_SPACE,
            {},
        )
        for player_id in env_config["players_ids"]
    }
    rllib_config = {
        "env": IteratedPrisonersDilemma,
        "env_config": env_config,
        "multiagent": {
            "policies": policies,
            # Each agent is controlled by the policy with the same id. Any
            # extra positional/keyword arguments RLlib passes (e.g. the
            # episode) are accepted and ignored.
            "policy_mapping_fn": lambda agent_id, *args, **kwargs: agent_id,
        },
        "seed": tune.grid_search(seeds),
        "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
        "framework": framework,
    }
    return rllib_config, stop_config


if __name__ == "__main__":
    # NOTE(review): debug mode is hard-coded on, so Ray runs in local mode and
    # training stops after 2 iterations; --stop-iters has no effect here.
    debug_mode = True
    args = parser.parse_args()
    # The parser defines --framework (there is no --tf flag); forward it.
    main(debug_mode, args.stop_iters, framework=args.framework)