ray/rllib/algorithms/qmix/tests/test_qmix.py

"""Tests that QMIX obeys the env-provided action mask in a grouped multi-agent env."""
from gym.spaces import Box, Dict, Discrete, MultiDiscrete, Tuple
import numpy as np
import unittest

import ray
from ray.tune import register_env
from ray.rllib.algorithms.qmix import QMixConfig
from ray.rllib.env.multi_agent_env import MultiAgentEnv


class AvailActionsTestEnv(MultiAgentEnv):
    """Two-agent env whose observations include an action-availability mask."""

    num_actions = 10
    action_space = Discrete(num_actions)
    observation_space = Dict(
        {
            "obs": Dict(
                {
                    "test": Dict({"a": Discrete(2), "b": MultiDiscrete([2, 3, 4])}),
                    "state": MultiDiscrete([2, 2, 2]),
                }
            ),
            "action_mask": Box(0, 1, (num_actions,)),
        }
    )

    def __init__(self, env_config):
        super().__init__()
        self.state = None
        self.avail = env_config.get("avail_actions", [3])
        # Binary mask over the 10 actions: 1 = available, 0 = masked out.
        self.action_mask = np.array([0] * 10)
        for a in self.avail:
            self.action_mask[a] = 1

    def reset(self):
        self.state = 0
        return {
            "agent_1": {
                "obs": self.observation_space["obs"].sample(),
                "action_mask": self.action_mask,
            },
            "agent_2": {
                "obs": self.observation_space["obs"].sample(),
                "action_mask": self.action_mask,
            },
        }

    def step(self, action_dict):
        # From the second step on, both agents must have picked an available action.
        if self.state > 0:
            assert (
                action_dict["agent_1"] in self.avail
                and action_dict["agent_2"] in self.avail
            ), "Failed to obey available actions mask!"
        self.state += 1
        rewards = {"agent_1": 1, "agent_2": 0.5}
        obs = {
            "agent_1": {
                "obs": self.observation_space["obs"].sample(),
                "action_mask": self.action_mask,
            },
            "agent_2": {
                "obs": self.observation_space["obs"].sample(),
                "action_mask": self.action_mask,
            },
        }
        dones = {"__all__": self.state >= 20}
        return obs, rewards, dones, {}


class TestQMix(unittest.TestCase):
    @classmethod
    def setUpClass(cls) -> None:
        ray.init()

    @classmethod
    def tearDownClass(cls) -> None:
        ray.shutdown()

    def test_avail_actions_qmix(self):
        # Group both agents into a single QMIX agent group.
        grouping = {
            "group_1": ["agent_1", "agent_2"],
        }
        obs_space = Tuple(
            [
                AvailActionsTestEnv.observation_space,
                AvailActionsTestEnv.observation_space,
            ]
        )
        act_space = Tuple(
            [AvailActionsTestEnv.action_space, AvailActionsTestEnv.action_space]
        )
        register_env(
            "action_mask_test",
            lambda config: AvailActionsTestEnv(config).with_agent_groups(
                grouping, obs_space=obs_space, act_space=act_space
            ),
        )

        config = (
            QMixConfig()
            .framework(framework="torch")
            .environment(
                env="action_mask_test",
                env_config={"avail_actions": [3, 4, 8]},
            )
            .rollouts(num_envs_per_worker=5)
        )  # Test with vectorization on.

        trainer = config.build()

        for _ in range(4):
            trainer.train()  # OK if it doesn't trip the action assertion error.
        assert trainer.train()["episode_reward_mean"] == 30.0
        trainer.stop()
        ray.shutdown()


if __name__ == "__main__":
    import pytest
    import sys

    sys.exit(pytest.main(["-v", __file__]))