ray/rllib/env/wrappers/tests/test_group_agents_wrapper.py

30 lines
986 B
Python

import unittest
from ray.rllib.env.wrappers.group_agents_wrapper import GroupAgentsWrapper
from ray.rllib.env.multi_agent_env import make_multi_agent
class TestGroupAgentsWrapper(unittest.TestCase):
def test_group_agents_wrapper(self):
MultiAgentCartPole = make_multi_agent("CartPole-v0")
grouped_ma_cartpole = GroupAgentsWrapper(
env=MultiAgentCartPole({
"num_agents": 4
}),
groups={
"group1": [0, 1],
"group2": [2, 3]
})
obs = grouped_ma_cartpole.reset()
self.assertTrue(len(obs) == 2)
self.assertTrue("group1" in obs and "group2" in obs)
self.assertTrue(
isinstance(obs["group1"], list) and len(obs["group1"]) == 2)
self.assertTrue(
isinstance(obs["group2"], list) and len(obs["group2"]) == 2)
if __name__ == "__main__":
import sys
import pytest
sys.exit(pytest.main(["-v", __file__]))