mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
77 lines
2.7 KiB
Python
77 lines
2.7 KiB
Python
"""Test utils in rllib/utils/spaces/space_utils.py."""
|
|
|
|
import unittest
|
|
|
|
import numpy as np
|
|
from gym.spaces import Box, Discrete, MultiDiscrete, MultiBinary, Tuple, Dict
|
|
from ray.rllib.utils.spaces.space_utils import (
|
|
convert_element_to_space_type,
|
|
get_base_struct_from_space,
|
|
unsquash_action,
|
|
)
|
|
|
|
|
|
class TestSpaceUtils(unittest.TestCase):
    """Unit tests for helpers in ray.rllib.utils.spaces.space_utils."""

    def test_convert_element_to_space_type(self):
        """Test if space converter works for all elements/space permutations.

        Builds a nested Dict space (Box, Discrete, MultiDiscrete, MultiBinary,
        a nested Dict and a Tuple), fills a structure with the same layout
        using values of mismatched Python/NumPy types, and checks that
        `convert_element_to_space_type` yields an element the space contains.
        """
        box_space = Box(low=-1, high=1, shape=(2,))
        discrete_space = Discrete(2)
        multi_discrete_space = MultiDiscrete([2, 2])
        multi_binary_space = MultiBinary(2)
        tuple_space = Tuple((box_space, discrete_space))
        dict_space = Dict(
            {
                "box": box_space,
                "discrete": discrete_space,
                "multi_discrete": multi_discrete_space,
                "multi_binary": multi_binary_space,
                "dict_space": Dict(
                    {
                        "box2": box_space,
                        "discrete2": discrete_space,
                    }
                ),
                "tuple_space": tuple_space,
            }
        )

        # Values sampled from each space, then cast to float64 / int32.
        # Assumed to differ from the spaces' own dtypes, so the converter
        # has real work to do.
        box_space_unconverted = box_space.sample().astype(np.float64)
        multi_discrete_unconverted = multi_discrete_space.sample().astype(np.int32)
        multi_binary_unconverted = multi_binary_space.sample().astype(np.int32)
        # Discrete entries are given as Python floats rather than ints.
        tuple_unconverted = (box_space_unconverted, float(0))
        modified_element = {
            "box": box_space_unconverted,
            "discrete": float(0),
            "multi_discrete": multi_discrete_unconverted,
            "multi_binary": multi_binary_unconverted,
            "tuple_space": tuple_unconverted,
            "dict_space": {
                "box2": box_space_unconverted,
                "discrete2": float(0),
            },
        }

        # A fresh sample from the space is passed as the reference structure;
        # presumably the converter takes its target types from it.
        element_with_correct_types = convert_element_to_space_type(
            modified_element, dict_space.sample()
        )

        # Use the TestCase assertion (not a bare `assert`, which `python -O`
        # strips) so failures report the offending element.
        self.assertTrue(
            dict_space.contains(element_with_correct_types),
            f"Converted element not contained in space: {element_with_correct_types}",
        )

    def test_unsquash_action(self):
        """Test to make sure unsquash_action works for both float and int Box spaces.

        Compares the whole returned action vector (values and length) against
        the expected result, rather than spot-checking individual indices.
        """
        # Float Box over [3, 8]: the expected value equals
        # 3 + (0.5 + 1) * (8 - 3) / 2, i.e. a linear map of 0.5 from [-1, 1].
        space = Box(low=3, high=8, shape=(2,), dtype=np.float32)
        struct = get_base_struct_from_space(space)
        action = unsquash_action(0.5, struct)
        np.testing.assert_array_equal(action, [6.75, 6.75])

        # Int Box over [3, 8]: the expected value equals 3 + 3,
        # i.e. the input offset by `low`.
        space = Box(low=3, high=8, shape=(2,), dtype=np.int32)
        struct = get_base_struct_from_space(space)
        action = unsquash_action(3, struct)
        np.testing.assert_array_equal(action, [6, 6])
|
|
|
|
|
|
if __name__ == "__main__":
    # Running this file directly hands it to pytest (verbose mode) and
    # propagates pytest's return code as the process exit status.
    import pytest

    raise SystemExit(pytest.main(["-v", __file__]))
|