diff --git a/rllib/evaluation/sampler.py b/rllib/evaluation/sampler.py
index 12cadb528..a86491ff5 100644
--- a/rllib/evaluation/sampler.py
+++ b/rllib/evaluation/sampler.py
@@ -37,7 +37,7 @@ from ray.rllib.utils.debug import summarize
 from ray.rllib.utils.deprecation import deprecation_warning
 from ray.rllib.utils.filter import Filter
 from ray.rllib.utils.framework import try_import_tf
-from ray.rllib.utils.numpy import convert_to_numpy
+from ray.rllib.utils.numpy import convert_to_numpy, make_action_immutable
 from ray.rllib.utils.spaces.space_utils import clip_action, unsquash_action, unbatch
 from ray.rllib.utils.typing import (
     SampleBatchType,
@@ -1250,6 +1250,9 @@ def _process_policy_eval_results(
                 episode._set_last_action(agent_id, action)
 
             assert agent_id not in actions_to_send[env_id]
+            # Flag the action's numpy arrays as read-only (in place) to notify the
+            # user when trying to change them and to avoid hard-to-trace errors.
+            tree.traverse(make_action_immutable, action_to_send, top_down=False)
             actions_to_send[env_id][agent_id] = action_to_send
 
     return actions_to_send
diff --git a/rllib/evaluation/tests/test_rollout_worker.py b/rllib/evaluation/tests/test_rollout_worker.py
index faa888b9d..8516377cf 100644
--- a/rllib/evaluation/tests/test_rollout_worker.py
+++ b/rllib/evaluation/tests/test_rollout_worker.py
@@ -360,6 +360,55 @@ class TestRolloutWorker(unittest.TestCase):
         self.assertLess(np.min(sample["actions"]), action_space.low[0])
         ev.stop()
 
+    def test_action_immutability(self):
+        from ray.rllib.examples.env.random_env import RandomEnv
+
+        action_space = gym.spaces.Box(0.0001, 0.0002, (5,))
+
+        class ActionMutationEnv(RandomEnv):
+            def __init__(self, config):
+                self.test_case = config["test_case"]
+                super().__init__(config=config)
+
+            def step(self, action):
+                # Ensure that it is called from inside the sampling process.
+                import inspect
+
+                curframe = inspect.currentframe()
+                called_from_check = any(
+                    [
+                        frame[3] == "check_gym_environments"
+                        for frame in inspect.getouterframes(curframe, 2)
+                    ]
+                )
+                # Check, whether the action is immutable.
+                if action.flags.writeable and not called_from_check:
+                    self.test_case.assertFalse(
+                        action.flags.writeable, "Action is mutable"
+                    )
+                return super().step(action)
+
+        ev = RolloutWorker(
+            env_creator=lambda _: ActionMutationEnv(
+                config=dict(
+                    test_case=self,
+                    action_space=action_space,
+                    max_episode_len=10,
+                    p_done=0.0,
+                    check_action_bounds=True,
+                )
+            ),
+            policy_spec=RandomPolicy,
+            policy_config=dict(
+                action_space=action_space,
+                ignore_action_bounds=True,
+            ),
+            clip_actions=False,
+            batch_mode="complete_episodes",
+        )
+        ev.sample()
+        ev.stop()
+
     def test_reward_clipping(self):
         # Clipping: True (clip between -1.0 and 1.0).
         ev = RolloutWorker(
diff --git a/rllib/utils/numpy.py b/rllib/utils/numpy.py
index f5a3d2057..43975d6ee 100644
--- a/rllib/utils/numpy.py
+++ b/rllib/utils/numpy.py
@@ -1,6 +1,8 @@
+from collections import OrderedDict
 from gym.spaces import Discrete, MultiDiscrete
 import numpy as np
 import tree  # pip install dm_tree
+from types import MappingProxyType
 from typing import List, Optional
 
 from ray.rllib.utils.deprecation import DEPRECATED_VALUE, deprecation_warning
@@ -300,6 +302,42 @@ def flatten_inputs_to_1d_tensor(
     return merged
 
 
+def make_action_immutable(obj):
+    """Flags actions immutable to notify users when trying to change
+    them.
+
+    Can also be used with any tree-like structure containing either
+    dictionaries, numpy arrays or already immutable objects per se.
+    Note, however that `tree.map_structure()` will in general not
+    include the shallow object containing all others and therefore
+    immutability will hold only for all objects contained in it.
+    Use `tree.traverse(fun, action, top_down=False)` to include
+    also the containing object.
+
+    Args:
+        obj: The object to be made immutable.
+
+    Returns:
+        The immutable object.
+
+    Examples:
+        >>> import tree
+        >>> import numpy as np
+        >>> arr = np.arange(1, 10)
+        >>> d = dict(a=1, b=(arr, arr))
+        >>> tree.traverse(make_action_immutable, d, top_down=False)
+    """
+    if isinstance(obj, np.ndarray):
+        obj.setflags(write=False)
+        return obj
+    elif isinstance(obj, OrderedDict):
+        return MappingProxyType(dict(obj))
+    elif isinstance(obj, dict):
+        return MappingProxyType(obj)
+    else:
+        return obj
+
+
 def huber_loss(x: np.ndarray, delta: float = 1.0) -> np.ndarray:
     """Reference: https://en.wikipedia.org/wiki/Huber_loss."""
     return np.where(
diff --git a/rllib/utils/tests/test_utils.py b/rllib/utils/tests/test_utils.py
index a114ab772..80889ea7f 100644
--- a/rllib/utils/tests/test_utils.py
+++ b/rllib/utils/tests/test_utils.py
@@ -6,6 +6,7 @@ import unittest
 import ray
 from ray.rllib.utils.framework import try_import_tf, try_import_torch
 from ray.rllib.utils.numpy import flatten_inputs_to_1d_tensor as flatten_np
+from ray.rllib.utils.numpy import make_action_immutable
 from ray.rllib.utils.test_utils import check
 from ray.rllib.utils.tf_utils import flatten_inputs_to_1d_tensor as flatten_tf
 from ray.rllib.utils.torch_utils import flatten_inputs_to_1d_tensor as flatten_torch
@@ -60,6 +61,63 @@ class TestUtils(unittest.TestCase):
     def tearDownClass(cls) -> None:
         ray.shutdown()
 
+    def test_make_action_immutable(self):
+        import gym
+        from types import MappingProxyType
+
+        # Test Box space.
+        space = gym.spaces.Box(low=-1.0, high=1.0, shape=(8,), dtype=np.float32)
+        action = space.sample()
+        action = make_action_immutable(action)
+        self.assertFalse(action.flags["WRITEABLE"])
+
+        # Test Discrete space.
+        # Nothing to be tested as sampled actions are integers
+        # and integers are immutable by nature.
+
+        # Test MultiDiscrete space.
+        space = gym.spaces.MultiDiscrete([3, 3, 3])
+        action = space.sample()
+        action = make_action_immutable(action)
+        self.assertFalse(action.flags["WRITEABLE"])
+
+        # Test MultiBinary space.
+        space = gym.spaces.MultiBinary([2, 2, 2])
+        action = space.sample()
+        action = make_action_immutable(action)
+        self.assertFalse(action.flags["WRITEABLE"])
+
+        # Test Tuple space.
+        space = gym.spaces.Tuple(
+            (
+                gym.spaces.Discrete(2),
+                gym.spaces.Box(low=-1.0, high=1.0, shape=(8,), dtype=np.float32),
+            )
+        )
+        action = space.sample()
+        action = tree.traverse(make_action_immutable, action, top_down=False)
+        self.assertFalse(action[1].flags["WRITEABLE"])
+
+        # Test Dict space.
+        space = gym.spaces.Dict({
+            "a": gym.spaces.Discrete(2),
+            "b": gym.spaces.Box(low=-1.0, high=1.0, shape=(8,), dtype=np.float32),
+            "c": gym.spaces.Tuple(
+                (
+                    gym.spaces.Discrete(2),
+                    gym.spaces.Box(low=-1.0, high=1.0, shape=(8,), dtype=np.float32),
+                )
+            )
+        })
+        action = space.sample()
+        action = tree.traverse(make_action_immutable, action, top_down=False)
+        def fail_fun(obj):
+            obj["a"] = 5
+        self.assertRaises(TypeError, fail_fun, action)
+        self.assertFalse(action["b"].flags["WRITEABLE"])
+        self.assertFalse(action["c"][1].flags["WRITEABLE"])
+        self.assertIsInstance(action, MappingProxyType)
+
     def test_flatten_inputs_to_1d_tensor(self):
         # B=3; no time axis.
         check(