ray/rllib/examples/coin_game_env.py

93 lines
2.7 KiB
Python
Raw Normal View History

##########
# Contribution by the Center on Long-Term Risk:
# https://github.com/longtermrisk/marltoolbox
##########
import argparse
import os
import ray
from ray import air, tune
from ray.rllib.algorithms.ppo import PPO
from ray.rllib.examples.env.coin_game_non_vectorized_env import CoinGame, AsymCoinGame
parser = argparse.ArgumentParser()
parser.add_argument("--tf", action="store_true")
parser.add_argument("--stop-iters", type=int, default=2000)
def main(debug, stop_iters=2000, tf=False, asymmetric_env=False):
train_n_replicates = 1 if debug else 1
seeds = list(range(train_n_replicates))
ray.init()
stop = {
"training_iteration": 2 if debug else stop_iters,
}
env_config = {
"players_ids": ["player_red", "player_blue"],
"max_steps": 20,
"grid_size": 3,
"get_additional_info": True,
}
rllib_config = {
"env": AsymCoinGame if asymmetric_env else CoinGame,
"env_config": env_config,
"multiagent": {
"policies": {
env_config["players_ids"][0]: (
None,
AsymCoinGame(env_config).OBSERVATION_SPACE,
AsymCoinGame.ACTION_SPACE,
{},
),
env_config["players_ids"][1]: (
None,
AsymCoinGame(env_config).OBSERVATION_SPACE,
AsymCoinGame.ACTION_SPACE,
{},
),
},
"policy_mapping_fn": lambda agent_id, **kwargs: agent_id,
},
# Size of batches collected from each worker.
"rollout_fragment_length": 20,
# Number of timesteps collected for each SGD round.
# This defines the size of each SGD epoch.
"train_batch_size": 512,
"model": {
"dim": env_config["grid_size"],
"conv_filters": [
[16, [3, 3], 1],
[32, [3, 3], 1],
], # [Channel, [Kernel, Kernel], Stride]]
},
"lr": 5e-3,
"seed": tune.grid_search(seeds),
"num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
"framework": "tf" if tf else "torch",
}
tuner = tune.Tuner(
PPO,
param_space=rllib_config,
run_config=air.RunConfig(
name="PPO_AsymCG",
stop=stop,
checkpoint_config=air.CheckpointConfig(
checkpoint_frequency=0, checkpoint_at_end=True
),
),
)
tuner.fit()
ray.shutdown()
if __name__ == "__main__":
args = parser.parse_args()
debug_mode = True
use_asymmetric_env = False
main(debug_mode, args.stop_iters, args.tf, use_asymmetric_env)