mirror of
https://github.com/vale981/ray
synced 2025-03-07 02:51:39 -05:00
318 lines
8.8 KiB
Python
318 lines
8.8 KiB
Python
##########
# Contribution by the Center on Long-Term Risk:
# https://github.com/longtermrisk/marltoolbox
##########
|
|
import random
|
|
|
|
from ray.rllib.examples.env.matrix_sequential_social_dilemma import \
|
|
IteratedPrisonersDilemma, IteratedChicken, \
|
|
IteratedStagHunt, IteratedBoS
|
|
|
|
# Environment classes exercised by the tests in this module; each test
# instantiates every class listed here.
ENVS = [
    IteratedPrisonersDilemma,
    IteratedChicken,
    IteratedStagHunt,
    IteratedBoS,
]
|
|
|
|
|
|
def test_reset():
    """Reset every env in ``ENVS`` and validate the initial state.

    After ``reset()`` the observation dict must pass ``check_obs`` and the
    four logging buffers must be empty (``n_steps=0``).
    """
    episode_length = 20
    envs = [init_env(episode_length, cls) for cls in ENVS]

    for env in envs:
        initial_obs = env.reset()
        check_obs(initial_obs, env)
        assert_logger_buffer_size_two_players(env, n_steps=0)
|
|
|
|
|
|
def init_env(max_steps, env_class, seed=None):
    """Instantiate ``env_class`` and seed the resulting environment.

    Args:
        max_steps: Value stored under the ``"max_steps"`` key of the config
            dict handed to ``env_class``.
        env_class: Callable that takes the config dict and returns an env.
        seed: Value forwarded to ``env.seed``; defaults to ``None``.

    Returns:
        The newly created environment, after ``env.seed(seed)`` was called.
    """
    env = env_class({"max_steps": max_steps})
    env.seed(seed)
    return env
|
|
|
|
|
|
def check_obs(obs, env):
    """Assert that ``obs`` is a valid two-player observation dict for ``env``.

    Args:
        obs: Mapping from player id to that player's observation.
        env: Environment exposing a ``NUM_STATES`` attribute.

    Raises:
        AssertionError: If ``obs`` does not hold exactly two entries, or if
            any observation is not an ``int`` in ``[0, env.NUM_STATES)``.
    """
    assert len(obs) == 2, "two players"
    # Only the observation values are validated; player ids are not inspected.
    for player_obs in obs.values():
        assert isinstance(player_obs, int)
        # Observations are compared against NUM_STATES, so they are presumably
        # discrete state indices; a negative value would be just as invalid as
        # one that is too large.
        assert 0 <= player_obs < env.NUM_STATES
|
|
|
|
|
|
def assert_logger_buffer_size_two_players(env, n_steps):
    """Assert that each of the env's four count buffers has ``n_steps`` items.

    Args:
        env: Object exposing ``cc_count``, ``dd_count``, ``cd_count`` and
            ``dc_count``, each supporting ``len()``.
        n_steps: Expected length of every buffer.

    Raises:
        AssertionError: On the first buffer whose length differs.
    """
    # Attributes are looked up one at a time, in this order, so the first
    # mismatch is reported before later buffers are touched.
    for buffer_name in ("cc_count", "dd_count", "cd_count", "dc_count"):
        assert len(getattr(env, buffer_name)) == n_steps
|
|
|
|
|
|
def test_step():
    """Take one random joint action in every env and validate the outcome.

    After a single step the observations must pass ``check_obs``, each
    logging buffer must hold one entry, and the episode (``max_steps`` is
    20) must not be flagged as done.
    """
    max_steps = 20
    envs = [init_env(max_steps, cls) for cls in ENVS]

    for env in envs:
        check_obs(env.reset(), env)
        assert_logger_buffer_size_two_players(env, n_steps=0)

        # Draw one action per player id, uniformly over the action indices.
        actions = {}
        for player_id in env.players_ids:
            actions[player_id] = random.randint(0, env.NUM_ACTIONS - 1)

        obs, _reward, done, _info = env.step(actions)
        check_obs(obs, env)
        assert_logger_buffer_size_two_players(env, n_steps=1)
        assert not done["__all__"]
|
|
|
|
|
|
def test_multiple_steps():
    """Play ``n_steps`` random joint actions without finishing an episode.

    ``n_steps`` is 75% of ``max_steps``, so ``done["__all__"]`` is expected
    to stay false on every step. After each step the observations are
    validated and the logging buffers must have grown by one entry.
    """
    max_steps = 20
    env_all = [init_env(max_steps, env_class) for env_class in ENVS]
    n_steps = int(max_steps * 0.75)

    for env in env_all:
        obs = env.reset()
        check_obs(obs, env)
        assert_logger_buffer_size_two_players(env, n_steps=0)

        # range() excludes its stop value: stopping at n_steps would take
        # only n_steps - 1 steps, so the upper bound is n_steps + 1.
        for step_i in range(1, n_steps + 1):
            actions = {
                policy_id: random.randint(0, env.NUM_ACTIONS - 1)
                for policy_id in env.players_ids
            }
            obs, reward, done, info = env.step(actions)
            check_obs(obs, env)
            assert_logger_buffer_size_two_players(env, n_steps=step_i)
            assert not done["__all__"]
|
|
|
|
|
|
def test_multiple_episodes():
    """Run random play across several consecutive episodes in every env.

    ``int(max_steps * 8.25)`` steps are taken in total. Whenever the env
    reports ``done["__all__"]`` the step counter must equal ``max_steps``;
    the env is then reset and the counter restarted.
    """
    max_steps = 20
    envs = [init_env(max_steps, cls) for cls in ENVS]
    total_steps = int(max_steps * 8.25)

    for env in envs:
        check_obs(env.reset(), env)
        assert_logger_buffer_size_two_players(env, n_steps=0)

        steps_in_episode = 0
        for _ in range(total_steps):
            steps_in_episode += 1
            actions = {}
            for player_id in env.players_ids:
                actions[player_id] = random.randint(0, env.NUM_ACTIONS - 1)

            obs, _reward, done, _info = env.step(actions)
            check_obs(obs, env)
            assert_logger_buffer_size_two_players(
                env, n_steps=steps_in_episode)

            if done["__all__"]:
                # An episode may only end on its max_steps-th step.
                assert steps_in_episode == max_steps
                check_obs(env.reset(), env)
                steps_in_episode = 0
|
|
|
|
|
|
def assert_info(n_steps, p_row_act, p_col_act, env, max_steps, CC, DD, CD, DC):
    """Replay scripted actions and check the info reported when episodes end.

    Args:
        n_steps: Total number of env steps to take.
        p_row_act: Actions of ``"player_row"``, indexed by the step position
            inside the current episode.
        p_col_act: Actions of ``"player_col"``, indexed the same way.
        env: Environment to step; expected to have been reset by the caller.
        max_steps: Step count at which an episode is allowed to be done.
        CC, DD, CD, DC: Values that both players' info dicts must report
            under the keys of the same name once an episode is done.

    Raises:
        AssertionError: If an observation or buffer check fails, if an
            episode ends on a step other than ``max_steps``, or if a
            reported info value differs from the expected one.
    """
    expected = {"CC": CC, "DD": DD, "CD": CD, "DC": DC}
    steps_in_episode = 0

    for _ in range(n_steps):
        # Index with the pre-increment counter: step k uses action k - 1.
        actions = {
            "player_row": p_row_act[steps_in_episode],
            "player_col": p_col_act[steps_in_episode],
        }
        steps_in_episode += 1

        obs, _reward, done, info = env.step(actions)
        check_obs(obs, env)
        assert_logger_buffer_size_two_players(env, n_steps=steps_in_episode)

        if not done["__all__"]:
            continue

        # An episode may only end on its max_steps-th step.
        assert steps_in_episode == max_steps

        for key, value in expected.items():
            for player in ("player_row", "player_col"):
                assert info[player][key] == value

        # Start the next episode from a clean state.
        check_obs(env.reset(), env)
        assert_logger_buffer_size_two_players(env, n_steps=0)
        steps_in_episode = 0
|
|
|
|
|
|
def test_logged_info_full_CC():
    """Both players always play action 0; the info must report CC == 1.0.

    DD, CD and DC are expected to be 0.0. Episodes last 4 steps and
    ``int(4 * 8.25)`` steps are replayed in every env.
    """
    max_steps = 4
    total_steps = int(max_steps * 8.25)
    row_actions = [0] * 4
    col_actions = [0] * 4
    envs = [init_env(max_steps, cls) for cls in ENVS]

    for env in envs:
        check_obs(env.reset(), env)
        assert_logger_buffer_size_two_players(env, n_steps=0)

        assert_info(
            n_steps=total_steps,
            p_row_act=row_actions,
            p_col_act=col_actions,
            env=env,
            max_steps=max_steps,
            CC=1.0,
            DD=0.0,
            CD=0.0,
            DC=0.0,
        )
|
|
|
|
|
|
def test_logged_info_full_DD():
    """Both players always play action 1; the info must report DD == 1.0.

    CC, CD and DC are expected to be 0.0. Episodes last 4 steps and
    ``int(4 * 8.25)`` steps are replayed in every env.
    """
    max_steps = 4
    total_steps = int(max_steps * 8.25)
    row_actions = [1] * 4
    col_actions = [1] * 4
    envs = [init_env(max_steps, cls) for cls in ENVS]

    for env in envs:
        check_obs(env.reset(), env)
        assert_logger_buffer_size_two_players(env, n_steps=0)

        assert_info(
            n_steps=total_steps,
            p_row_act=row_actions,
            p_col_act=col_actions,
            env=env,
            max_steps=max_steps,
            CC=0.0,
            DD=1.0,
            CD=0.0,
            DC=0.0,
        )
|
|
|
|
|
|
def test_logged_info_full_CD():
    """Row always plays 0 and column always plays 1; expects CD == 1.0.

    CC, DD and DC are expected to be 0.0. Episodes last 4 steps and
    ``int(4 * 8.25)`` steps are replayed in every env.
    """
    max_steps = 4
    total_steps = int(max_steps * 8.25)
    row_actions = [0] * 4
    col_actions = [1] * 4
    envs = [init_env(max_steps, cls) for cls in ENVS]

    for env in envs:
        check_obs(env.reset(), env)
        assert_logger_buffer_size_two_players(env, n_steps=0)

        assert_info(
            n_steps=total_steps,
            p_row_act=row_actions,
            p_col_act=col_actions,
            env=env,
            max_steps=max_steps,
            CC=0.0,
            DD=0.0,
            CD=1.0,
            DC=0.0,
        )
|
|
|
|
|
|
def test_logged_info_full_DC():
    """Row always plays 1 and column always plays 0; expects DC == 1.0.

    CC, DD and CD are expected to be 0.0. Episodes last 4 steps and
    ``int(4 * 8.25)`` steps are replayed in every env.
    """
    max_steps = 4
    total_steps = int(max_steps * 8.25)
    row_actions = [1] * 4
    col_actions = [0] * 4
    envs = [init_env(max_steps, cls) for cls in ENVS]

    for env in envs:
        check_obs(env.reset(), env)
        assert_logger_buffer_size_two_players(env, n_steps=0)

        assert_info(
            n_steps=total_steps,
            p_row_act=row_actions,
            p_col_act=col_actions,
            env=env,
            max_steps=max_steps,
            CC=0.0,
            DD=0.0,
            CD=0.0,
            DC=1.0,
        )
|
|
|
|
|
|
def test_logged_info_mix_CC_DD():
    """Both players play 0 once, then 1 three times per 4-step episode.

    The info is expected to report CC == 0.25 and DD == 0.75, with CD and
    DC at 0.0. ``int(4 * 8.25)`` steps are replayed in every env.
    """
    max_steps = 4
    total_steps = int(max_steps * 8.25)
    row_actions = [0, 1, 1, 1]
    col_actions = [0, 1, 1, 1]
    envs = [init_env(max_steps, cls) for cls in ENVS]

    for env in envs:
        check_obs(env.reset(), env)
        assert_logger_buffer_size_two_players(env, n_steps=0)

        assert_info(
            n_steps=total_steps,
            p_row_act=row_actions,
            p_col_act=col_actions,
            env=env,
            max_steps=max_steps,
            CC=0.25,
            DD=0.75,
            CD=0.0,
            DC=0.0,
        )
|
|
|
|
|
|
def test_logged_info_mix_CD_CD():
    """Players alternate opposite actions over each 4-step episode.

    Row plays 1, 0, 1, 0 while column plays 0, 1, 0, 1. The info is expected
    to report CD == 0.5 and DC == 0.5, with CC and DD at 0.0.
    ``int(4 * 8.25)`` steps are replayed in every env.
    """
    max_steps = 4
    total_steps = int(max_steps * 8.25)
    row_actions = [1, 0, 1, 0]
    col_actions = [0, 1, 0, 1]
    envs = [init_env(max_steps, cls) for cls in ENVS]

    for env in envs:
        check_obs(env.reset(), env)
        assert_logger_buffer_size_two_players(env, n_steps=0)

        assert_info(
            n_steps=total_steps,
            p_row_act=row_actions,
            p_col_act=col_actions,
            env=env,
            max_steps=max_steps,
            CC=0.0,
            DD=0.0,
            CD=0.5,
            DC=0.5,
        )
|
|
|
|
|
|
def test_observations_are_invariant_to_the_player_trained():
    """Check that observations mirror each other when actions are mirrored.

    Steps 1 and 2 use identical actions for both players (0/0, then 1/1), so
    both players must receive the same observation. Steps 3 and 4 use
    swapped actions (1/0, then 0/1), so each player's observation at step 4
    must equal the other player's observation at step 3.
    """
    p_row_act = [0, 1, 1, 0]
    p_col_act = [0, 1, 0, 1]
    max_steps = 4
    n_steps = 4
    envs = [init_env(max_steps, cls) for cls in ENVS]

    for env in envs:
        env.reset()
        for step_i in range(1, n_steps + 1):
            actions = {
                "player_row": p_row_act[step_i - 1],
                "player_col": p_col_act[step_i - 1],
            }
            obs, _reward, _done, _info = env.step(actions)
            first_id = env.players_ids[0]
            second_id = env.players_ids[1]

            # Assert observations are symmetrical with respect to the actions.
            if step_i in (1, 2):
                assert obs[first_id] == obs[second_id]
            elif step_i == 3:
                # Kept for the cross-player comparison at step 4.
                obs_step_3 = obs
            elif step_i == 4:
                assert obs[first_id] == obs_step_3[second_id]
                assert obs[second_id] == obs_step_3[first_id]
|