##########
# Contribution by the Center on Long-Term Risk:
# https://github.com/longtermrisk/marltoolbox
##########
import random

from ray.rllib.examples.env.matrix_sequential_social_dilemma import (
    IteratedPrisonersDilemma,
    IteratedChicken,
    IteratedStagHunt,
    IteratedBoS,
)

ENVS = [IteratedPrisonersDilemma, IteratedChicken, IteratedStagHunt, IteratedBoS]


def test_reset():
    max_steps = 20
    env_all = [init_env(max_steps, env_class) for env_class in ENVS]

    for env in env_all:
        obs = env.reset()
        check_obs(obs, env)
        assert_logger_buffer_size_two_players(env, n_steps=0)


def init_env(max_steps, env_class, seed=None):
    config = {
        "max_steps": max_steps,
    }
    env = env_class(config)
    env.seed(seed)
    return env


def check_obs(obs, env):
    assert len(obs) == 2, "two players"
    for player_obs in obs.values():
        # Observations are discrete state indices below env.NUM_STATES.
        assert isinstance(player_obs, int)
        assert player_obs < env.NUM_STATES


def assert_logger_buffer_size_two_players(env, n_steps):
    # The env logs one entry per step taken in the current episode for each
    # of the four joint outcomes (CC, DD, CD, DC).
    assert len(env.cc_count) == n_steps
    assert len(env.dd_count) == n_steps
    assert len(env.cd_count) == n_steps
    assert len(env.dc_count) == n_steps


def test_step():
    max_steps = 20
    env_all = [init_env(max_steps, env_class) for env_class in ENVS]

    for env in env_all:
        obs = env.reset()
        check_obs(obs, env)
        assert_logger_buffer_size_two_players(env, n_steps=0)

        actions = {
            policy_id: random.randint(0, env.NUM_ACTIONS - 1)
            for policy_id in env.players_ids
        }
        obs, reward, done, info = env.step(actions)
        check_obs(obs, env)
        assert_logger_buffer_size_two_players(env, n_steps=1)
        assert not done["__all__"]


def test_multiple_steps():
    max_steps = 20
    env_all = [init_env(max_steps, env_class) for env_class in ENVS]
    # Stay strictly below max_steps so no episode terminates.
    n_steps = int(max_steps * 0.75)

    for env in env_all:
        obs = env.reset()
        check_obs(obs, env)
        assert_logger_buffer_size_two_players(env, n_steps=0)

        for step_i in range(1, n_steps):
            actions = {
                policy_id: random.randint(0, env.NUM_ACTIONS - 1)
                for policy_id in env.players_ids
            }
            obs, reward, done, info = env.step(actions)
            check_obs(obs, env)
            assert_logger_buffer_size_two_players(env, n_steps=step_i)
            assert not done["__all__"]


def test_multiple_episodes():
    max_steps = 20
    env_all = [init_env(max_steps, env_class) for env_class in ENVS]
    # Run past several episode boundaries (8.25 episodes' worth of steps).
    n_steps = int(max_steps * 8.25)

    for env in env_all:
        obs = env.reset()
        check_obs(obs, env)
        assert_logger_buffer_size_two_players(env, n_steps=0)

        step_i = 0
        for _ in range(n_steps):
            step_i += 1
            actions = {
                policy_id: random.randint(0, env.NUM_ACTIONS - 1)
                for policy_id in env.players_ids
            }
            obs, reward, done, info = env.step(actions)
            check_obs(obs, env)
            assert_logger_buffer_size_two_players(env, n_steps=step_i)
            # An episode may only terminate at exactly max_steps.
            assert not done["__all__"] or (step_i == max_steps and done["__all__"])
            if done["__all__"]:
                obs = env.reset()
                check_obs(obs, env)
                step_i = 0


def assert_info(n_steps, p_row_act, p_col_act, env, max_steps, CC, DD, CD, DC):
    # Play the given per-step action sequences for n_steps (resetting at each
    # episode end) and check that the episode-end info dict reports the
    # expected fraction of each joint outcome for both players.
    step_i = 0
    for _ in range(n_steps):
        step_i += 1
        actions = {
            "player_row": p_row_act[step_i - 1],
            "player_col": p_col_act[step_i - 1],
        }
        obs, reward, done, info = env.step(actions)
        check_obs(obs, env)
        assert_logger_buffer_size_two_players(env, n_steps=step_i)
        assert not done["__all__"] or (step_i == max_steps and done["__all__"])

        if done["__all__"]:
            assert info["player_row"]["CC"] == CC
            assert info["player_col"]["CC"] == CC
            assert info["player_row"]["DD"] == DD
            assert info["player_col"]["DD"] == DD
            assert info["player_row"]["CD"] == CD
            assert info["player_col"]["CD"] == CD
            assert info["player_row"]["DC"] == DC
            assert info["player_col"]["DC"] == DC
            obs = env.reset()
            check_obs(obs, env)
            assert_logger_buffer_size_two_players(env, n_steps=0)
            step_i = 0

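
# The test_logged_info_* tests below drive every environment with a fixed
# four-step action sequence and use assert_info to check the outcome
# fractions reported in the episode-end info dict. Reading the full-CC and
# full-DD cases off against their expected fractions, action 0 plays as
# "cooperate" and action 1 as "defect" (the C and D of the outcome names).
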
def test_logged_info_full_CC():
    p_row_act = [0, 0, 0, 0]
    p_col_act = [0, 0, 0, 0]
    max_steps = 4
    env_all = [init_env(max_steps, env_class) for env_class in ENVS]
    n_steps = int(max_steps * 8.25)

    for env in env_all:
        obs = env.reset()
        check_obs(obs, env)
        assert_logger_buffer_size_two_players(env, n_steps=0)

        assert_info(
            n_steps,
            p_row_act,
            p_col_act,
            env,
            max_steps,
            CC=1.0,
            DD=0.0,
            CD=0.0,
            DC=0.0,
        )


def test_logged_info_full_DD():
    p_row_act = [1, 1, 1, 1]
    p_col_act = [1, 1, 1, 1]
    max_steps = 4
    env_all = [init_env(max_steps, env_class) for env_class in ENVS]
    n_steps = int(max_steps * 8.25)

    for env in env_all:
        obs = env.reset()
        check_obs(obs, env)
        assert_logger_buffer_size_two_players(env, n_steps=0)

        assert_info(
            n_steps,
            p_row_act,
            p_col_act,
            env,
            max_steps,
            CC=0.0,
            DD=1.0,
            CD=0.0,
            DC=0.0,
        )


def test_logged_info_full_CD():
    p_row_act = [0, 0, 0, 0]
    p_col_act = [1, 1, 1, 1]
    max_steps = 4
    env_all = [init_env(max_steps, env_class) for env_class in ENVS]
    n_steps = int(max_steps * 8.25)

    for env in env_all:
        obs = env.reset()
        check_obs(obs, env)
        assert_logger_buffer_size_two_players(env, n_steps=0)

        assert_info(
            n_steps,
            p_row_act,
            p_col_act,
            env,
            max_steps,
            CC=0.0,
            DD=0.0,
            CD=1.0,
            DC=0.0,
        )


def test_logged_info_full_DC():
    p_row_act = [1, 1, 1, 1]
    p_col_act = [0, 0, 0, 0]
    max_steps = 4
    env_all = [init_env(max_steps, env_class) for env_class in ENVS]
    n_steps = int(max_steps * 8.25)

    for env in env_all:
        obs = env.reset()
        check_obs(obs, env)
        assert_logger_buffer_size_two_players(env, n_steps=0)

        assert_info(
            n_steps,
            p_row_act,
            p_col_act,
            env,
            max_steps,
            CC=0.0,
            DD=0.0,
            CD=0.0,
            DC=1.0,
        )


def test_logged_info_mix_CC_DD():
    p_row_act = [0, 1, 1, 1]
    p_col_act = [0, 1, 1, 1]
    max_steps = 4
    env_all = [init_env(max_steps, env_class) for env_class in ENVS]
    n_steps = int(max_steps * 8.25)

    for env in env_all:
        obs = env.reset()
        check_obs(obs, env)
        assert_logger_buffer_size_two_players(env, n_steps=0)

        assert_info(
            n_steps,
            p_row_act,
            p_col_act,
            env,
            max_steps,
            CC=0.25,
            DD=0.75,
            CD=0.0,
            DC=0.0,
        )


def test_logged_info_mix_CD_DC():
    # Alternating mirrored actions: DC, CD, DC, CD.
    p_row_act = [1, 0, 1, 0]
    p_col_act = [0, 1, 0, 1]
    max_steps = 4
    env_all = [init_env(max_steps, env_class) for env_class in ENVS]
    n_steps = int(max_steps * 8.25)

    for env in env_all:
        obs = env.reset()
        check_obs(obs, env)
        assert_logger_buffer_size_two_players(env, n_steps=0)

        assert_info(
            n_steps,
            p_row_act,
            p_col_act,
            env,
            max_steps,
            CC=0.0,
            DD=0.0,
            CD=0.5,
            DC=0.5,
        )


def test_observations_are_invariant_to_the_player_trained():
    # Steps 1-2 play symmetric joint actions (CC, then DD); steps 3-4 play
    # mirrored joint actions (DC, then CD).
    p_row_act = [0, 1, 1, 0]
    p_col_act = [0, 1, 0, 1]
    max_steps = 4
    env_all = [init_env(max_steps, env_class) for env_class in ENVS]
    n_steps = 4

    for env in env_all:
        _ = env.reset()
        step_i = 0
        for _ in range(n_steps):
            step_i += 1
            actions = {
                "player_row": p_row_act[step_i - 1],
                "player_col": p_col_act[step_i - 1],
            }
            obs, reward, done, info = env.step(actions)
            # Assert that the observations are symmetric with respect to the
            # players' actions.
            if step_i == 1:
                # Symmetric joint action: both players see the same state.
                assert obs[env.players_ids[0]] == obs[env.players_ids[1]]
            elif step_i == 2:
                # Symmetric joint action: both players see the same state.
                assert obs[env.players_ids[0]] == obs[env.players_ids[1]]
            elif step_i == 3:
                obs_step_3 = obs
            elif step_i == 4:
                # Steps 3 and 4 swap the players' roles, so each player's
                # observation at step 4 must equal the other player's at
                # step 3.
                assert obs[env.players_ids[0]] == obs_step_3[env.players_ids[1]]
                assert obs[env.players_ids[1]] == obs_step_3[env.players_ids[0]]
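

# Optional convenience entry point: a minimal sketch assuming pytest is
# available on the path; pytest normally discovers and runs the test_*
# functions in this module on its own.
if __name__ == "__main__":
    import sys

    import pytest

    sys.exit(pytest.main([__file__, "-v"]))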