mirror of
https://github.com/vale981/ray
synced 2025-03-07 02:51:39 -05:00
59 lines
2.1 KiB
Python
59 lines
2.1 KiB
Python
"""Test utils in rllib/utils/space_utils.py."""
|
|
|
|
import unittest
|
|
|
|
import numpy as np
|
|
from gym.spaces import Box, Discrete, MultiDiscrete, MultiBinary, Tuple, Dict
|
|
from ray.rllib.utils.spaces.space_utils import convert_element_to_space_type
|
|
|
|
|
|
class TestSpaceUtils(unittest.TestCase):
    """Tests for helpers in ``ray.rllib.utils.spaces.space_utils``."""

    def test_convert_element_to_space_type(self):
        """Test if space converter works for all elements/space permutations.

        Builds a nested ``Dict`` space covering Box, Discrete, MultiDiscrete,
        MultiBinary, Tuple and a nested Dict. It then feeds
        ``convert_element_to_space_type`` an element whose leaves hold valid
        values, but with explicitly chosen types: float64/int32 arrays and
        Python floats for the Discrete entries. The converted element must be
        a member of the space, while the raw element must not be. The second
        condition ensures the test actually exercises the conversion.
        """
        box_space = Box(low=-1, high=1, shape=(2,))
        discrete_space = Discrete(2)
        multi_discrete_space = MultiDiscrete([2, 2])
        multi_binary_space = MultiBinary(2)
        tuple_space = Tuple((box_space, discrete_space))
        dict_space = Dict(
            {
                "box": box_space,
                "discrete": discrete_space,
                "multi_discrete": multi_discrete_space,
                "multi_binary": multi_binary_space,
                "dict_space": Dict(
                    {
                        "box2": box_space,
                        "discrete2": discrete_space,
                    }
                ),
                "tuple_space": tuple_space,
            }
        )

        # Valid sampled values, deliberately cast to dtypes that presumably
        # differ from the spaces' own dtypes (TODO confirm against gym
        # defaults).
        box_space_unconverted = box_space.sample().astype(np.float64)
        multi_discrete_unconverted = multi_discrete_space.sample().astype(np.int32)
        multi_binary_unconverted = multi_binary_space.sample().astype(np.int32)
        # Discrete entries are given as Python floats instead of ints.
        tuple_unconverted = (box_space_unconverted, float(0))
        modified_element = {
            "box": box_space_unconverted,
            "discrete": float(0),
            "multi_discrete": multi_discrete_unconverted,
            "multi_binary": multi_binary_unconverted,
            "tuple_space": tuple_unconverted,
            "dict_space": {
                "box2": box_space_unconverted,
                "discrete2": float(0),
            },
        }

        # Sanity check: the raw element must NOT already belong to the space.
        # Otherwise the final assertion would pass without any conversion.
        self.assertFalse(dict_space.contains(modified_element))

        # A fresh sample supplies the target structure and per-leaf types.
        element_with_correct_types = convert_element_to_space_type(
            modified_element, dict_space.sample()
        )
        self.assertTrue(dict_space.contains(element_with_correct_types))
|
|
|
|
|
|
if __name__ == "__main__":
    # Run this file's tests verbosely through pytest and propagate pytest's
    # exit status to the shell.
    import sys

    import pytest

    exit_code = pytest.main(["-v", __file__])
    sys.exit(exit_code)
|