# Source file: ray/rllib/agents/qmix/tests/test_qmix.py
# (Web-viewer header residue: 121 lines, 3.4 KiB, Python.)

from gym.spaces import Box, Dict, Discrete, MultiDiscrete, Tuple
import numpy as np
import unittest
import ray
from ray.tune import register_env
from ray.rllib.agents.qmix import QMixConfig
from ray.rllib.env.multi_agent_env import MultiAgentEnv
class AvailActionsTestEnv(MultiAgentEnv):
    """Two-agent test env that checks chosen actions against an action mask.

    Every observation handed out for ``agent_1`` and ``agent_2`` carries the
    same ``action_mask``. From the second ``step()`` call of an episode
    onward, the env asserts that both agents picked an action listed in
    ``avail_actions``. An episode ends once 20 steps have been taken.
    """

    num_actions = 10
    action_space = Discrete(num_actions)
    observation_space = Dict(
        {
            "obs": Dict(
                {
                    "test": Dict({"a": Discrete(2), "b": MultiDiscrete([2, 3, 4])}),
                    "state": MultiDiscrete([2, 2, 2]),
                }
            ),
            "action_mask": Box(0, 1, (num_actions,)),
        }
    )

    def __init__(self, env_config):
        """Build the env and its fixed action mask.

        Args:
            env_config: Mapping that may hold ``"avail_actions"``, the
                action indices to mark as allowed. Defaults to ``[3]``.
        """
        super().__init__()
        self.state = None
        self.avail = env_config.get("avail_actions", [3])
        # Size the mask from `num_actions` (not a literal) so it cannot
        # drift from the action space / `action_mask` Box shape above.
        self.action_mask = np.zeros(self.num_actions, dtype=int)
        for a in self.avail:
            self.action_mask[a] = 1

    def _make_obs(self):
        """Return a fresh per-agent observation dict for both agents."""
        return {
            "agent_1": {
                "obs": self.observation_space["obs"].sample(),
                "action_mask": self.action_mask,
            },
            "agent_2": {
                "obs": self.observation_space["obs"].sample(),
                "action_mask": self.action_mask,
            },
        }

    def reset(self):
        """Reset the step counter and return the initial observations."""
        self.state = 0
        return self._make_obs()

    def step(self, action_dict):
        """Advance one step.

        Args:
            action_dict: Mapping with the chosen action for ``"agent_1"``
                and ``"agent_2"``.

        Returns:
            Tuple ``(obs, rewards, dones, infos)``; ``infos`` is empty.

        Raises:
            AssertionError: If, after the first step of the episode, either
                agent's action is not in the configured available actions.
        """
        # The first step after a reset (state == 0) is not checked.
        if self.state > 0:
            assert (
                action_dict["agent_1"] in self.avail
                and action_dict["agent_2"] in self.avail
            ), "Failed to obey available actions mask!"
        self.state += 1
        rewards = {"agent_1": 1, "agent_2": 0.5}
        obs = self._make_obs()
        dones = {"__all__": self.state >= 20}
        return obs, rewards, dones, {}
class TestQMix(unittest.TestCase):
    """QMix tests that run against a class-wide local Ray instance."""

    @classmethod
    def setUpClass(cls) -> None:
        """Start Ray once for all tests in this class."""
        ray.init()

    @classmethod
    def tearDownClass(cls) -> None:
        """Shut Ray down after the last test in this class."""
        ray.shutdown()

    def test_avail_actions_qmix(self):
        """Train QMix on the grouped action-mask env and check the reward.

        The env itself asserts that chosen actions respect the mask, so the
        training iterations fail if QMix ignores it.
        """
        grouping = {
            "group_1": ["agent_1", "agent_2"],
        }
        obs_space = Tuple(
            [
                AvailActionsTestEnv.observation_space,
                AvailActionsTestEnv.observation_space,
            ]
        )
        act_space = Tuple(
            [AvailActionsTestEnv.action_space, AvailActionsTestEnv.action_space]
        )
        register_env(
            "action_mask_test",
            lambda config: AvailActionsTestEnv(config).with_agent_groups(
                grouping, obs_space=obs_space, act_space=act_space
            ),
        )

        config = (
            QMixConfig()
            .framework(framework="torch")
            .environment(
                env="action_mask_test",
                env_config={"avail_actions": [3, 4, 8]},
            )
            # Test with vectorization on.
            .rollouts(num_envs_per_worker=5)
        )

        trainer = config.build()
        # Always stop the trainer, even if training or the check below
        # fails. Ray itself is shut down once, in tearDownClass.
        try:
            for _ in range(4):
                trainer.train()  # OK if it doesn't trip the action assertion error
            self.assertEqual(trainer.train()["episode_reward_mean"], 30.0)
        finally:
            trainer.stop()
if __name__ == "__main__":
    # Run this file's tests verbosely and exit with pytest's status code.
    import sys

    import pytest

    exit_status = pytest.main(["-v", __file__])
    sys.exit(exit_status)