ray/rllib/env/wrappers/tests/test_kaggle_wrapper.py
Balaji Veeramani 7f1bacc7dc
[CI] Format Python code with Black (#21975)
See #21316 and #21311 for the motivation behind these changes.
2022-01-29 18:41:57 -08:00

68 lines
2.5 KiB
Python

import unittest
class TestKaggleFootballMultiAgentEnv(unittest.TestCase):
def test_football_env(self):
from ray.rllib.env.wrappers.kaggle_wrapper import KaggleFootballMultiAgentEnv
env = KaggleFootballMultiAgentEnv()
obs = env.reset()
self.assertEqual(list(obs.keys()), ["agent0", "agent1"])
action_dict = {"agent0": 0, "agent1": 0}
obs, reward, done, info = env.step(action_dict)
self.assertEqual(list(obs.keys()), ["agent0", "agent1"])
self.assertEqual(reward, {"agent0": 0, "agent1": 0})
self.assertEqual(
done,
{
"agent0": False,
"agent1": False,
"__all__": False,
},
)
self.assertEqual(info, {"agent0": {}, "agent1": {}})
def test_football_env_run_30_steps(self):
from ray.rllib.env.wrappers.kaggle_wrapper import KaggleFootballMultiAgentEnv
env = KaggleFootballMultiAgentEnv()
# use the built-in agents in the kaggle environment
run_right_agent = env.kaggle_env.agents["run_right"]
do_nothing_agent = env.kaggle_env.agents["do_nothing"]
obs = env.reset()
self.assertEqual(list(obs.keys()), ["agent0", "agent1"])
done = {"__all__": False}
num_steps_completed = 0
while not done["__all__"] and num_steps_completed <= 30:
action0 = run_right_agent(structify(obs["agent0"]))[0]
action1 = do_nothing_agent(structify(obs["agent1"]))[0]
action_dict = {"agent0": action0, "agent1": action1}
obs, _, done, _ = env.step(action_dict)
num_steps_completed += 1
def test_kaggle_football_agent_spaces(self):
from ray.rllib.env.wrappers.kaggle_wrapper import KaggleFootballMultiAgentEnv
env = KaggleFootballMultiAgentEnv()
obs = env.reset()
action_space, obs_space = env.build_agent_spaces()
self.assertTrue(obs_space.contains(obs["agent0"]))
self.assertTrue(obs_space.contains(obs["agent1"]))
action_dict = {
"agent0": action_space.sample(),
"agent1": action_space.sample(),
}
obs, _, _, _ = env.step(action_dict)
self.assertTrue(obs_space.contains(obs["agent0"]))
self.assertTrue(obs_space.contains(obs["agent1"]))
if __name__ == "__main__":
    # Script entry point: run this file's tests through pytest in verbose
    # mode and use pytest's return value as the process exit status.
    import sys

    import pytest

    # NOTE(review): this binds ``structify`` only in the ``__main__``
    # namespace; confirm whether the tests still see it once pytest imports
    # this file as a regular module.
    from kaggle_environments.utils import structify  # noqa: F401

    exit_status = pytest.main(["-v", __file__])
    sys.exit(exit_status)