mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
30 lines
986 B
Python
30 lines
986 B
Python
import unittest
|
|
|
|
from ray.rllib.env.wrappers.group_agents_wrapper import GroupAgentsWrapper
|
|
from ray.rllib.env.multi_agent_env import make_multi_agent
|
|
|
|
|
|
class TestGroupAgentsWrapper(unittest.TestCase):
|
|
def test_group_agents_wrapper(self):
|
|
MultiAgentCartPole = make_multi_agent("CartPole-v0")
|
|
grouped_ma_cartpole = GroupAgentsWrapper(
|
|
env=MultiAgentCartPole({
|
|
"num_agents": 4
|
|
}),
|
|
groups={
|
|
"group1": [0, 1],
|
|
"group2": [2, 3]
|
|
})
|
|
obs = grouped_ma_cartpole.reset()
|
|
self.assertTrue(len(obs) == 2)
|
|
self.assertTrue("group1" in obs and "group2" in obs)
|
|
self.assertTrue(
|
|
isinstance(obs["group1"], list) and len(obs["group1"]) == 2)
|
|
self.assertTrue(
|
|
isinstance(obs["group2"], list) and len(obs["group2"]) == 2)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
import sys
|
|
import pytest
|
|
sys.exit(pytest.main(["-v", __file__]))
|