[rllib] Fix linting (#24335)

#24262 broke linting. This fixes this.
This commit is contained in:
Kai Fricke 2022-04-29 15:21:11 +01:00 committed by GitHub
parent 46cd7f1830
commit 242706922b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 33 additions and 27 deletions

View file

@ -308,7 +308,7 @@ def make_action_immutable(obj):
Can also be used with any tree-like structure containing either Can also be used with any tree-like structure containing either
dictionaries, numpy arrays or already immutable objects per se. dictionaries, numpy arrays or already immutable objects per se.
Note, however that `tree.map_structure()` will in general not Note, however that `tree.map_structure()` will in general not
include the shallow object containing all others and therefore include the shallow object containing all others and therefore
immutability will hold only for all objects contained in it. immutability will hold only for all objects contained in it.
Use `tree.traverse(fun, action, top_down=False)` to include Use `tree.traverse(fun, action, top_down=False)` to include
@ -332,7 +332,7 @@ def make_action_immutable(obj):
return obj return obj
elif isinstance(obj, OrderedDict): elif isinstance(obj, OrderedDict):
return MappingProxyType(dict(obj)) return MappingProxyType(dict(obj))
elif isinstance(obj, dict): elif isinstance(obj, dict):
return MappingProxyType(obj) return MappingProxyType(obj)
else: else:
return obj return obj

View file

@ -61,32 +61,32 @@ class TestUtils(unittest.TestCase):
def tearDownClass(cls) -> None: def tearDownClass(cls) -> None:
ray.shutdown() ray.shutdown()
def test_make_action_immutable(self): def test_make_action_immutable(self):
import gym import gym
from types import MappingProxyType from types import MappingProxyType
# Test Box space. # Test Box space.
space = gym.spaces.Box(low=-1.0, high=1.0, shape=(8,), dtype=np.float32) space = gym.spaces.Box(low=-1.0, high=1.0, shape=(8,), dtype=np.float32)
action = space.sample() action = space.sample()
action = make_action_immutable(action) action = make_action_immutable(action)
self.assertFalse(action.flags["WRITEABLE"]) self.assertFalse(action.flags["WRITEABLE"])
# Test Discrete space. # Test Discrete space.
# Nothing to be tested as sampled actions are integers # Nothing to be tested as sampled actions are integers
# and integers are immutable by nature. # and integers are immutable by nature.
# Test MultiDiscrete space. # Test MultiDiscrete space.
space = gym.spaces.MultiDiscrete([3,3,3]) space = gym.spaces.MultiDiscrete([3, 3, 3])
action = space.sample() action = space.sample()
action = make_action_immutable(action) action = make_action_immutable(action)
self.assertFalse(action.flags["WRITEABLE"]) self.assertFalse(action.flags["WRITEABLE"])
# Test MultiBinary space. # Test MultiBinary space.
space = gym.spaces.MultiBinary([2,2,2]) space = gym.spaces.MultiBinary([2, 2, 2])
action = space.sample() action = space.sample()
action = make_action_immutable(action) action = make_action_immutable(action)
self.assertFalse(action.flags["WRITEABLE"]) self.assertFalse(action.flags["WRITEABLE"])
# Test Tuple space. # Test Tuple space.
space = gym.spaces.Tuple( space = gym.spaces.Tuple(
( (
@ -97,27 +97,33 @@ class TestUtils(unittest.TestCase):
action = space.sample() action = space.sample()
action = tree.traverse(make_action_immutable, action, top_down=False) action = tree.traverse(make_action_immutable, action, top_down=False)
self.assertFalse(action[1].flags["WRITEABLE"]) self.assertFalse(action[1].flags["WRITEABLE"])
# Test Dict space. # Test Dict space.
space = gym.spaces.Dict({ space = gym.spaces.Dict(
"a": gym.spaces.Discrete(2), {
"b": gym.spaces.Box(low=-1.0, high=1.0, shape=(8,), dtype=np.float32), "a": gym.spaces.Discrete(2),
"c": gym.spaces.Tuple( "b": gym.spaces.Box(low=-1.0, high=1.0, shape=(8,), dtype=np.float32),
( "c": gym.spaces.Tuple(
gym.spaces.Discrete(2), (
gym.spaces.Box(low=-1.0, high=1.0, shape=(8,), dtype=np.float32), gym.spaces.Discrete(2),
) gym.spaces.Box(
) low=-1.0, high=1.0, shape=(8,), dtype=np.float32
}) ),
)
),
}
)
action = space.sample() action = space.sample()
action = tree.traverse(make_action_immutable, action, top_down=False) action = tree.traverse(make_action_immutable, action, top_down=False)
def fail_fun(obj): def fail_fun(obj):
obj["a"] = 5 obj["a"] = 5
self.assertRaises(TypeError, fail_fun, action) self.assertRaises(TypeError, fail_fun, action)
self.assertFalse(action["b"].flags["WRITEABLE"]) self.assertFalse(action["b"].flags["WRITEABLE"])
self.assertFalse(action["c"][1].flags["WRITEABLE"]) self.assertFalse(action["c"][1].flags["WRITEABLE"])
self.assertTrue(isinstance(action, MappingProxyType)) self.assertTrue(isinstance(action, MappingProxyType))
def test_flatten_inputs_to_1d_tensor(self): def test_flatten_inputs_to_1d_tensor(self):
# B=3; no time axis. # B=3; no time axis.
check( check(