ray/rllib/examples/iterated_prisoners_dilemma_env.py
xwjiang2010 fcf897ee72
[air] update rllib example to use Tuner API. (#26987)
update rllib example to use Tuner API.

Signed-off-by: xwjiang2010 <xwjiang2010@gmail.com>
2022-07-27 12:12:59 +01:00

90 lines
2.5 KiB
Python

##########
# Contribution by the Center on Long-Term Risk:
# https://github.com/longtermrisk/marltoolbox
##########
import argparse
import os
import ray
from ray import air, tune
from ray.rllib.algorithms.pg import PG
from ray.rllib.examples.env.matrix_sequential_social_dilemma import (
IteratedPrisonersDilemma,
)
# Command-line interface of the example: the deep-learning framework to use
# and the number of training iterations after which to stop.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--framework",
    default="tf",
    choices=["tf", "tf2", "tfe", "torch"],
    help="The DL framework specifier.",
)
parser.add_argument(
    "--stop-iters",
    default=200,
    type=int,
)
def main(debug, stop_iters=200, framework="tf"):
    """Run the PG-vs-PG iterated prisoner's dilemma experiment with Ray Tune.

    Args:
        debug: When truthy, Ray is started in local mode; the flag is also
            forwarded to ``get_rllib_config``.
        stop_iters: Number of training iterations, forwarded to
            ``get_rllib_config``.
        framework: Deep-learning framework specifier, forwarded to
            ``get_rllib_config``.
    """
    # A single replicate (seed 0) is trained whether or not debug is set.
    train_n_replicates = 1
    seeds = list(range(train_n_replicates))

    ray.init(num_cpus=os.cpu_count(), num_gpus=0, local_mode=debug)
    try:
        rllib_config, stop_config = get_rllib_config(
            seeds, debug, stop_iters, framework
        )
        tuner = tune.Tuner(
            PG,
            param_space=rllib_config,
            run_config=air.RunConfig(
                name="PG_IPD",
                stop=stop_config,
                checkpoint_config=air.CheckpointConfig(
                    # No periodic checkpoints; only one when training ends.
                    checkpoint_frequency=0,
                    checkpoint_at_end=True,
                ),
            ),
        )
        tuner.fit()
    finally:
        # Release the Ray runtime even when config building or training
        # raises, so a caller importing main() is not left with a live
        # Ray session.
        ray.shutdown()
def get_rllib_config(seeds, debug=False, stop_iters=200, framework="tf"):
    """Assemble the RLlib config and the stop criterion for the experiment.

    Args:
        seeds: Seed values handed to ``tune.grid_search`` for the ``seed``
            config entry.
        debug: When truthy, training stops after 2 iterations instead of
            ``stop_iters``.
        stop_iters: Training-iteration limit used when ``debug`` is falsy.
        framework: Value stored under the ``framework`` config key.

    Returns:
        A ``(rllib_config, stop_config)`` tuple of dicts.
    """
    player_ids = ["player_row", "player_col"]
    env_config = {
        "players_ids": player_ids,
        "max_steps": 20,
        "get_additional_info": True,
    }

    # One policy per player, keyed by the player id; both share the
    # environment's observation and action spaces and use an empty
    # per-policy config override.
    policies = {
        player_id: (
            None,
            IteratedPrisonersDilemma.OBSERVATION_SPACE,
            IteratedPrisonersDilemma.ACTION_SPACE,
            {},
        )
        for player_id in player_ids
    }

    rllib_config = {
        "env": IteratedPrisonersDilemma,
        "env_config": env_config,
        "multiagent": {
            "policies": policies,
            # Policy ids equal agent ids, so the mapping is the identity.
            "policy_mapping_fn": lambda agent_id, **kwargs: agent_id,
        },
        "seed": tune.grid_search(seeds),
        "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
        "framework": framework,
    }

    # Debug runs are cut short to two iterations.
    if debug:
        max_iterations = 2
    else:
        max_iterations = stop_iters
    stop_config = {"training_iteration": max_iterations}

    return rllib_config, stop_config
if __name__ == "__main__":
    # Script entry point: read the CLI options and launch the experiment.
    # Debug mode is hard-coded on here.
    cli_args = parser.parse_args()
    debug_mode = True
    main(
        debug=debug_mode,
        stop_iters=cli_args.stop_iters,
        framework=cli_args.framework,
    )