[rllib] Fix linting (#24335)

#24262 broke linting. This fixes this.
This commit is contained in:
Kai Fricke 2022-04-29 15:21:11 +01:00 committed by GitHub
parent 46cd7f1830
commit 242706922b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 33 additions and 27 deletions

View file

@ -308,7 +308,7 @@ def make_action_immutable(obj):
Can also be used with any tree-like structure containing either
dictionaries, numpy arrays or already immutable objects per se.
Note, however that `tree.map_structure()` will in general not
Note, however that `tree.map_structure()` will in general not
include the shallow object containing all others and therefore
immutability will hold only for all objects contained in it.
Use `tree.traverse(fun, action, top_down=False)` to include
@ -332,7 +332,7 @@ def make_action_immutable(obj):
return obj
elif isinstance(obj, OrderedDict):
return MappingProxyType(dict(obj))
elif isinstance(obj, dict):
elif isinstance(obj, dict):
return MappingProxyType(obj)
else:
return obj

View file

@ -61,32 +61,32 @@ class TestUtils(unittest.TestCase):
def tearDownClass(cls) -> None:
ray.shutdown()
def test_make_action_immutable(self):
import gym
def test_make_action_immutable(self):
import gym
from types import MappingProxyType
# Test Box space.
space = gym.spaces.Box(low=-1.0, high=1.0, shape=(8,), dtype=np.float32)
action = space.sample()
action = space.sample()
action = make_action_immutable(action)
self.assertFalse(action.flags["WRITEABLE"])
# Test Discrete space.
# Nothing to be tested as sampled actions are integers
# and integers are immutable by nature.
# and integers are immutable by nature.
# Test MultiDiscrete space.
space = gym.spaces.MultiDiscrete([3,3,3])
space = gym.spaces.MultiDiscrete([3, 3, 3])
action = space.sample()
action = make_action_immutable(action)
self.assertFalse(action.flags["WRITEABLE"])
# Test MultiBinary space.
space = gym.spaces.MultiBinary([2,2,2])
space = gym.spaces.MultiBinary([2, 2, 2])
action = space.sample()
action = make_action_immutable(action)
self.assertFalse(action.flags["WRITEABLE"])
# Test Tuple space.
space = gym.spaces.Tuple(
(
@ -97,27 +97,33 @@ class TestUtils(unittest.TestCase):
action = space.sample()
action = tree.traverse(make_action_immutable, action, top_down=False)
self.assertFalse(action[1].flags["WRITEABLE"])
# Test Dict space.
space = gym.spaces.Dict({
"a": gym.spaces.Discrete(2),
"b": gym.spaces.Box(low=-1.0, high=1.0, shape=(8,), dtype=np.float32),
"c": gym.spaces.Tuple(
(
gym.spaces.Discrete(2),
gym.spaces.Box(low=-1.0, high=1.0, shape=(8,), dtype=np.float32),
)
)
})
space = gym.spaces.Dict(
{
"a": gym.spaces.Discrete(2),
"b": gym.spaces.Box(low=-1.0, high=1.0, shape=(8,), dtype=np.float32),
"c": gym.spaces.Tuple(
(
gym.spaces.Discrete(2),
gym.spaces.Box(
low=-1.0, high=1.0, shape=(8,), dtype=np.float32
),
)
),
}
)
action = space.sample()
action = tree.traverse(make_action_immutable, action, top_down=False)
def fail_fun(obj):
obj["a"] = 5
self.assertRaises(TypeError, fail_fun, action)
self.assertFalse(action["b"].flags["WRITEABLE"])
self.assertFalse(action["c"][1].flags["WRITEABLE"])
self.assertTrue(isinstance(action, MappingProxyType))
self.assertFalse(action["c"][1].flags["WRITEABLE"])
self.assertTrue(isinstance(action, MappingProxyType))
def test_flatten_inputs_to_1d_tensor(self):
# B=3; no time axis.
check(