mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
parent
46cd7f1830
commit
242706922b
2 changed files with 33 additions and 27 deletions
|
@ -308,7 +308,7 @@ def make_action_immutable(obj):
|
||||||
|
|
||||||
Can also be used with any tree-like structure containing either
|
Can also be used with any tree-like structure containing either
|
||||||
dictionaries, numpy arrays or already immutable objects per se.
|
dictionaries, numpy arrays or already immutable objects per se.
|
||||||
Note, however that `tree.map_structure()` will in general not
|
Note, however that `tree.map_structure()` will in general not
|
||||||
include the shallow object containing all others and therefore
|
include the shallow object containing all others and therefore
|
||||||
immutability will hold only for all objects contained in it.
|
immutability will hold only for all objects contained in it.
|
||||||
Use `tree.traverse(fun, action, top_down=False)` to include
|
Use `tree.traverse(fun, action, top_down=False)` to include
|
||||||
|
@ -332,7 +332,7 @@ def make_action_immutable(obj):
|
||||||
return obj
|
return obj
|
||||||
elif isinstance(obj, OrderedDict):
|
elif isinstance(obj, OrderedDict):
|
||||||
return MappingProxyType(dict(obj))
|
return MappingProxyType(dict(obj))
|
||||||
elif isinstance(obj, dict):
|
elif isinstance(obj, dict):
|
||||||
return MappingProxyType(obj)
|
return MappingProxyType(obj)
|
||||||
else:
|
else:
|
||||||
return obj
|
return obj
|
||||||
|
|
|
@ -61,32 +61,32 @@ class TestUtils(unittest.TestCase):
|
||||||
def tearDownClass(cls) -> None:
|
def tearDownClass(cls) -> None:
|
||||||
ray.shutdown()
|
ray.shutdown()
|
||||||
|
|
||||||
def test_make_action_immutable(self):
|
def test_make_action_immutable(self):
|
||||||
import gym
|
import gym
|
||||||
from types import MappingProxyType
|
from types import MappingProxyType
|
||||||
|
|
||||||
# Test Box space.
|
# Test Box space.
|
||||||
space = gym.spaces.Box(low=-1.0, high=1.0, shape=(8,), dtype=np.float32)
|
space = gym.spaces.Box(low=-1.0, high=1.0, shape=(8,), dtype=np.float32)
|
||||||
action = space.sample()
|
action = space.sample()
|
||||||
action = make_action_immutable(action)
|
action = make_action_immutable(action)
|
||||||
self.assertFalse(action.flags["WRITEABLE"])
|
self.assertFalse(action.flags["WRITEABLE"])
|
||||||
|
|
||||||
# Test Discrete space.
|
# Test Discrete space.
|
||||||
# Nothing to be tested as sampled actions are integers
|
# Nothing to be tested as sampled actions are integers
|
||||||
# and integers are immutable by nature.
|
# and integers are immutable by nature.
|
||||||
|
|
||||||
# Test MultiDiscrete space.
|
# Test MultiDiscrete space.
|
||||||
space = gym.spaces.MultiDiscrete([3,3,3])
|
space = gym.spaces.MultiDiscrete([3, 3, 3])
|
||||||
action = space.sample()
|
action = space.sample()
|
||||||
action = make_action_immutable(action)
|
action = make_action_immutable(action)
|
||||||
self.assertFalse(action.flags["WRITEABLE"])
|
self.assertFalse(action.flags["WRITEABLE"])
|
||||||
|
|
||||||
# Test MultiBinary space.
|
# Test MultiBinary space.
|
||||||
space = gym.spaces.MultiBinary([2,2,2])
|
space = gym.spaces.MultiBinary([2, 2, 2])
|
||||||
action = space.sample()
|
action = space.sample()
|
||||||
action = make_action_immutable(action)
|
action = make_action_immutable(action)
|
||||||
self.assertFalse(action.flags["WRITEABLE"])
|
self.assertFalse(action.flags["WRITEABLE"])
|
||||||
|
|
||||||
# Test Tuple space.
|
# Test Tuple space.
|
||||||
space = gym.spaces.Tuple(
|
space = gym.spaces.Tuple(
|
||||||
(
|
(
|
||||||
|
@ -97,27 +97,33 @@ class TestUtils(unittest.TestCase):
|
||||||
action = space.sample()
|
action = space.sample()
|
||||||
action = tree.traverse(make_action_immutable, action, top_down=False)
|
action = tree.traverse(make_action_immutable, action, top_down=False)
|
||||||
self.assertFalse(action[1].flags["WRITEABLE"])
|
self.assertFalse(action[1].flags["WRITEABLE"])
|
||||||
|
|
||||||
# Test Dict space.
|
# Test Dict space.
|
||||||
space = gym.spaces.Dict({
|
space = gym.spaces.Dict(
|
||||||
"a": gym.spaces.Discrete(2),
|
{
|
||||||
"b": gym.spaces.Box(low=-1.0, high=1.0, shape=(8,), dtype=np.float32),
|
"a": gym.spaces.Discrete(2),
|
||||||
"c": gym.spaces.Tuple(
|
"b": gym.spaces.Box(low=-1.0, high=1.0, shape=(8,), dtype=np.float32),
|
||||||
(
|
"c": gym.spaces.Tuple(
|
||||||
gym.spaces.Discrete(2),
|
(
|
||||||
gym.spaces.Box(low=-1.0, high=1.0, shape=(8,), dtype=np.float32),
|
gym.spaces.Discrete(2),
|
||||||
)
|
gym.spaces.Box(
|
||||||
)
|
low=-1.0, high=1.0, shape=(8,), dtype=np.float32
|
||||||
})
|
),
|
||||||
|
)
|
||||||
|
),
|
||||||
|
}
|
||||||
|
)
|
||||||
action = space.sample()
|
action = space.sample()
|
||||||
action = tree.traverse(make_action_immutable, action, top_down=False)
|
action = tree.traverse(make_action_immutable, action, top_down=False)
|
||||||
|
|
||||||
def fail_fun(obj):
|
def fail_fun(obj):
|
||||||
obj["a"] = 5
|
obj["a"] = 5
|
||||||
|
|
||||||
self.assertRaises(TypeError, fail_fun, action)
|
self.assertRaises(TypeError, fail_fun, action)
|
||||||
self.assertFalse(action["b"].flags["WRITEABLE"])
|
self.assertFalse(action["b"].flags["WRITEABLE"])
|
||||||
self.assertFalse(action["c"][1].flags["WRITEABLE"])
|
self.assertFalse(action["c"][1].flags["WRITEABLE"])
|
||||||
self.assertTrue(isinstance(action, MappingProxyType))
|
self.assertTrue(isinstance(action, MappingProxyType))
|
||||||
|
|
||||||
def test_flatten_inputs_to_1d_tensor(self):
|
def test_flatten_inputs_to_1d_tensor(self):
|
||||||
# B=3; no time axis.
|
# B=3; no time axis.
|
||||||
check(
|
check(
|
||||||
|
|
Loading…
Add table
Reference in a new issue