ray/rllib/env/wrappers/tests/test_unity3d_env.py

55 lines
2 KiB
Python

import unittest
from unittest.mock import patch
from ray.rllib.env.wrappers.unity3d_env import Unity3DEnv
@patch("mlagents_envs.environment.UnityEnvironment")
class TestUnity3DEnv(unittest.TestCase):
    """Tests for the port arguments ``Unity3DEnv`` passes to Unity.

    ``mlagents_envs.environment.UnityEnvironment`` is patched with a mock for
    every test method (handed in as ``mock_unity3d``); each test inspects the
    keyword arguments that mock was called with.
    """

    @staticmethod
    def _effective_port(call_kwargs):
        """Return ``base_port + worker_id`` taken from a mock call's kwargs."""
        return call_kwargs.get("base_port") + call_kwargs.get("worker_id")

    def test_port_editor(self, mock_unity3d):
        """Test if the environment uses the editor port
        when no environment file is provided"""
        _ = Unity3DEnv(port=None)
        # Verify the call first: `call_args` is None on an uncalled mock, and
        # unpacking it would raise a TypeError that hides the real failure.
        mock_unity3d.assert_called_once()
        _, kwargs = mock_unity3d.call_args
        self.assertEqual(5004, kwargs.get("base_port"))

    def test_port_app(self, mock_unity3d):
        """Test if the environment uses the correct port
        when the environment file is provided"""
        _ = Unity3DEnv(file_name="app", port=None)
        # Same ordering as above: assert the call before reading its args.
        mock_unity3d.assert_called_once()
        _, kwargs = mock_unity3d.call_args
        self.assertEqual(5005, kwargs.get("base_port"))

    def test_ports_multi_app(self, mock_unity3d):
        """Test if the base_port + worker_id
        is different for each environment"""
        _ = Unity3DEnv(file_name="app", port=None)
        # `call_args` always holds the most recent call, so capture it
        # before constructing the second environment.
        _, kwargs_first = mock_unity3d.call_args
        _ = Unity3DEnv(file_name="app", port=None)
        _, kwargs_second = mock_unity3d.call_args
        self.assertNotEqual(
            self._effective_port(kwargs_first),
            self._effective_port(kwargs_second),
        )

    def test_custom_port_app(self, mock_unity3d):
        """Test if the base_port + worker_id is different
        for each environment when using custom ports"""
        _ = Unity3DEnv(file_name="app", port=5010)
        # Capture the first call's kwargs before the second call replaces them.
        _, kwargs_first = mock_unity3d.call_args
        _ = Unity3DEnv(file_name="app", port=5010)
        _, kwargs_second = mock_unity3d.call_args
        self.assertNotEqual(
            self._effective_port(kwargs_first),
            self._effective_port(kwargs_second),
        )
if __name__ == "__main__":
    # Script entry point: run this file's tests verbosely under pytest and
    # use pytest's return value as the process exit status.
    import sys
    import pytest

    exit_code = pytest.main(["-v", __file__])
    sys.exit(exit_code)