ray/rllib/utils/spaces/tests/test_space_utils.py
Balaji Veeramani 7f1bacc7dc
[CI] Format Python code with Black (#21975)
See #21316 and #21311 for the motivation behind these changes.
2022-01-29 18:41:57 -08:00

59 lines
2.1 KiB
Python

"""Test utils in rllib/utils/space_utils.py."""
import unittest
import numpy as np
from gym.spaces import Box, Discrete, MultiDiscrete, MultiBinary, Tuple, Dict
from ray.rllib.utils.spaces.space_utils import convert_element_to_space_type
class TestSpaceUtils(unittest.TestCase):
    """Checks for the helpers imported from ray.rllib.utils.spaces.space_utils."""

    def test_convert_element_to_space_type(self):
        """A converted nested element must be contained in its Dict space."""
        # Leaf spaces; each one is reused at several nesting levels below.
        box = Box(low=-1, high=1, shape=(2,))
        discrete = Discrete(2)
        multi_discrete = MultiDiscrete([2, 2])
        multi_binary = MultiBinary(2)

        # Composite spaces: a tuple, a nested dict, and the top-level dict that
        # combines every leaf and composite space.
        pair = Tuple((box, discrete))
        nested = Dict({"box2": box, "discrete2": discrete})
        dict_space = Dict(
            {
                "box": box,
                "discrete": discrete,
                "multi_discrete": multi_discrete,
                "multi_binary": multi_binary,
                "dict_space": nested,
                "tuple_space": pair,
            }
        )

        # Build values whose types were deliberately changed after sampling:
        # arrays are re-cast to float64 / int32 and discrete entries are plain
        # Python floats (presumably none of these match what the spaces
        # themselves produce -- the conversion under test must repair that).
        raw_box = box.sample().astype(np.float64)
        raw_multi_discrete = multi_discrete.sample().astype(np.int32)
        raw_multi_binary = multi_binary.sample().astype(np.int32)
        raw_pair = (raw_box, float(0))
        raw_nested = {"box2": raw_box, "discrete2": float(0)}
        raw_element = {
            "box": raw_box,
            "discrete": float(0),
            "multi_discrete": raw_multi_discrete,
            "multi_binary": raw_multi_binary,
            "tuple_space": raw_pair,
            "dict_space": raw_nested,
        }

        # A fresh sample from the full space is passed as the second argument
        # alongside the element to convert.
        reference = dict_space.sample()
        converted = convert_element_to_space_type(raw_element, reference)

        assert dict_space.contains(converted)
if __name__ == "__main__":
    # Running this file directly hands it to pytest in verbose mode and uses
    # pytest's return value as the process exit status.
    import sys

    import pytest

    exit_code = pytest.main(["-v", __file__])
    sys.exit(exit_code)