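"""Test that QMix respects the per-agent action masks provided by the env."""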
from gym.spaces import Box, Dict, Discrete, MultiDiscrete, Tuple
import numpy as np
import unittest

import ray
from ray.tune import register_env
from ray.rllib.algorithms.qmix import QMixConfig
from ray.rllib.env.multi_agent_env import MultiAgentEnv


class AvailActionsTestEnv(MultiAgentEnv):
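    """Two-agent env whose observations carry an `action_mask`.

    Only the action indices listed in `env_config["avail_actions"]` are marked
    as available; `step()` asserts that both agents pick one of them.
    """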
    num_actions = 10
    action_space = Discrete(num_actions)
    observation_space = Dict(
        {
            "obs": Dict(
                {
                    "test": Dict({"a": Discrete(2), "b": MultiDiscrete([2, 3, 4])}),
                    "state": MultiDiscrete([2, 2, 2]),
                }
            ),
            "action_mask": Box(0, 1, (num_actions,)),
        }
    )

    def __init__(self, env_config):
        super().__init__()
        self.state = None
        self.avail = env_config.get("avail_actions", [3])
        # Mark only the actions listed in `avail_actions` as available.
        self.action_mask = np.array([0] * self.num_actions)
        for a in self.avail:
            self.action_mask[a] = 1

    def reset(self):
        self.state = 0
        return {
            "agent_1": {
                "obs": self.observation_space["obs"].sample(),
                "action_mask": self.action_mask,
            },
            "agent_2": {
                "obs": self.observation_space["obs"].sample(),
                "action_mask": self.action_mask,
            },
        }

    def step(self, action_dict):
        if self.state > 0:
            assert (
                action_dict["agent_1"] in self.avail
                and action_dict["agent_2"] in self.avail
            ), "Failed to obey available actions mask!"
        self.state += 1
        rewards = {"agent_1": 1, "agent_2": 0.5}
        obs = {
            "agent_1": {
                "obs": self.observation_space["obs"].sample(),
                "action_mask": self.action_mask,
            },
            "agent_2": {
                "obs": self.observation_space["obs"].sample(),
                "action_mask": self.action_mask,
            },
        }
        dones = {"__all__": self.state >= 20}
        return obs, rewards, dones, {}


class TestQMix(unittest.TestCase):
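    """Tests that QMix training honors action masks in a grouped multi-agent env."""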
    @classmethod
    def setUpClass(cls) -> None:
        ray.init()

    @classmethod
    def tearDownClass(cls) -> None:
        ray.shutdown()

    def test_avail_actions_qmix(self):
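        """Trains QMix on the masked env and checks the final episode reward.

        The env raises an AssertionError as soon as any agent picks an action
        outside `avail_actions`, so a crash-free run proves the mask is obeyed.
        """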
        grouping = {
            "group_1": ["agent_1", "agent_2"],
        }
        obs_space = Tuple(
            [
                AvailActionsTestEnv.observation_space,
                AvailActionsTestEnv.observation_space,
            ]
        )
        act_space = Tuple(
            [AvailActionsTestEnv.action_space, AvailActionsTestEnv.action_space]
        )
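        # `with_agent_groups()` merges both agents into a single "group_1" agent
        # whose observation and action are Tuples over the two group members.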
        register_env(
            "action_mask_test",
            lambda config: AvailActionsTestEnv(config).with_agent_groups(
                grouping, obs_space=obs_space, act_space=act_space
            ),
        )

        config = (
            QMixConfig()
            .framework(framework="torch")
            .environment(
                env="action_mask_test",
                env_config={"avail_actions": [3, 4, 8]},
            )
            .rollouts(num_envs_per_worker=5)
        )  # Test with vectorization on.

        trainer = config.build()
        for _ in range(4):
            # OK as long as training never trips the env's action-mask assertion.
            trainer.train()
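        # Each step yields a group reward of 1.0 + 0.5 and episodes last 20 steps,
        # so a mask-compliant policy should average an episode reward of 30.0.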
        assert trainer.train()["episode_reward_mean"] == 30.0
        trainer.stop()
        ray.shutdown()


if __name__ == "__main__":
    import pytest
    import sys

    sys.exit(pytest.main(["-v", __file__]))