[RLlib] Make actions sent by RLlib to the env immutable. (#24262)

This commit is contained in:
simonsays1980 2022-04-29 10:27:06 +02:00 committed by GitHub
parent 5f12c62226
commit ff575eeafc
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 149 additions and 1 deletions

View file

@ -37,7 +37,7 @@ from ray.rllib.utils.debug import summarize
from ray.rllib.utils.deprecation import deprecation_warning
from ray.rllib.utils.filter import Filter
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.numpy import convert_to_numpy
from ray.rllib.utils.numpy import convert_to_numpy, make_action_immutable
from ray.rllib.utils.spaces.space_utils import clip_action, unsquash_action, unbatch
from ray.rllib.utils.typing import (
SampleBatchType,
@ -1250,6 +1250,9 @@ def _process_policy_eval_results(
episode._set_last_action(agent_id, action)
assert agent_id not in actions_to_send[env_id]
# Flag actions as immutable so that users are notified when they try to
# change them, and to avoid hard-to-trace errors.
tree.traverse(make_action_immutable, action_to_send, top_down=False)
actions_to_send[env_id][agent_id] = action_to_send
return actions_to_send

View file

@ -360,6 +360,55 @@ class TestRolloutWorker(unittest.TestCase):
self.assertLess(np.min(sample["actions"]), action_space.low[0])
ev.stop()
def test_action_immutability(self):
    """Check that actions passed to `env.step()` while sampling are read-only.

    A `RandomEnv` subclass inspects the write flag of every action it
    receives and fails this test case if the action is still writeable.
    """
    import inspect

    from ray.rllib.examples.env.random_env import RandomEnv

    action_space = gym.spaces.Box(0.0001, 0.0002, (5,))

    class ActionMutationEnv(RandomEnv):
        # Must be `__init__` (not `init`): otherwise this override is never
        # invoked and `self.test_case` is missing when `step()` needs it.
        def __init__(self, config):
            self.test_case = config["test_case"]
            super().__init__(config=config)

        def step(self, action):
            # Determine whether this call originates from a function named
            # `check_gym_environments`; such calls are exempt from the check
            # (presumably the env pre-check passes its own, still writeable,
            # actions -- TODO confirm).
            curframe = inspect.currentframe()
            called_from_check = any(
                frame[3] == "check_gym_environments"
                for frame in inspect.getouterframes(curframe, 2)
            )
            # Any other caller must hand over a read-only action.
            if not called_from_check:
                self.test_case.assertFalse(
                    action.flags.writeable, "Action is mutable"
                )
            return super().step(action)

    ev = RolloutWorker(
        env_creator=lambda _: ActionMutationEnv(
            config=dict(
                test_case=self,
                action_space=action_space,
                max_episode_len=10,
                p_done=0.0,
                check_action_bounds=True,
            )
        ),
        policy_spec=RandomPolicy,
        policy_config=dict(
            action_space=action_space,
            ignore_action_bounds=True,
        ),
        clip_actions=False,
        batch_mode="complete_episodes",
    )
    ev.sample()
    ev.stop()
def test_reward_clipping(self):
# Clipping: True (clip between -1.0 and 1.0).
ev = RolloutWorker(

View file

@ -1,6 +1,8 @@
from collections import OrderedDict
from gym.spaces import Discrete, MultiDiscrete
import numpy as np
import tree # pip install dm_tree
from types import MappingProxyType
from typing import List, Optional
from ray.rllib.utils.deprecation import DEPRECATED_VALUE, deprecation_warning
@ -300,6 +302,42 @@ def flatten_inputs_to_1d_tensor(
return merged
def make_action_immutable(obj):
    """Return a read-only version of one node of an action structure.

    The function handles a single object and does not recurse on its own.
    To cover a nested action, apply it through
    `tree.traverse(fun, action, top_down=False)`, which also includes the
    outermost container. `tree.map_structure()` will in general not include
    the object containing all others, so with it immutability holds only
    for the contained objects.

    Args:
        obj: The object to be made immutable. Numpy arrays and
            dictionaries are handled specially; anything else is assumed
            to be immutable already.

    Returns:
        - For a numpy array: the same array, with its write flag cleared
          in place.
        - For an `OrderedDict`: a `MappingProxyType` around a plain `dict`
          copy of it.
        - For any other `dict`: a `MappingProxyType` around that very dict.
        - Otherwise: `obj` itself, unchanged.

    Examples:
        >>> import tree
        >>> import numpy as np
        >>> arr = np.arange(1,10)
        >>> d = dict(a = 1, b = (arr, arr))
        >>> tree.traverse(make_action_immutable, d, top_down=False)
    """
    # Arrays are frozen in place and handed back as the same object.
    if isinstance(obj, np.ndarray):
        obj.setflags(write=False)
        return obj

    # Objects that are neither arrays nor mappings pass through untouched.
    if not isinstance(obj, dict):
        return obj

    # An OrderedDict is first copied into a plain dict; every other dict is
    # wrapped directly, so the proxy is a read-only view onto it.
    backing = dict(obj) if isinstance(obj, OrderedDict) else obj
    return MappingProxyType(backing)
def huber_loss(x: np.ndarray, delta: float = 1.0) -> np.ndarray:
"""Reference: https://en.wikipedia.org/wiki/Huber_loss."""
return np.where(

View file

@ -6,6 +6,7 @@ import unittest
import ray
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.numpy import flatten_inputs_to_1d_tensor as flatten_np
from ray.rllib.utils.numpy import make_action_immutable
from ray.rllib.utils.test_utils import check
from ray.rllib.utils.tf_utils import flatten_inputs_to_1d_tensor as flatten_tf
from ray.rllib.utils.torch_utils import flatten_inputs_to_1d_tensor as flatten_torch
@ -60,6 +61,63 @@ class TestUtils(unittest.TestCase):
def tearDownClass(cls) -> None:
ray.shutdown()
def test_make_action_immutable(self):
    """Exercise `make_action_immutable` on samples from several gym spaces."""
    import gym
    from types import MappingProxyType

    def box_space():
        # The same 8-dim float32 Box is used for every Box in this test.
        return gym.spaces.Box(low=-1.0, high=1.0, shape=(8,), dtype=np.float32)

    # Box, MultiDiscrete and MultiBinary: the sampled action must come back
    # with its write flag cleared.
    # Discrete is not tested: sampled actions are integers, and integers
    # are immutable by nature.
    array_spaces = (
        box_space(),
        gym.spaces.MultiDiscrete([3, 3, 3]),
        gym.spaces.MultiBinary([2, 2, 2]),
    )
    for array_space in array_spaces:
        frozen = make_action_immutable(array_space.sample())
        self.assertFalse(frozen.flags["WRITEABLE"])

    # Tuple space: traverse bottom-up so that nested arrays are covered.
    tuple_space = gym.spaces.Tuple((gym.spaces.Discrete(2), box_space()))
    frozen_tuple = tree.traverse(
        make_action_immutable, tuple_space.sample(), top_down=False
    )
    self.assertFalse(frozen_tuple[1].flags["WRITEABLE"])

    # Dict space with a nested Tuple.
    dict_space = gym.spaces.Dict(
        {
            "a": gym.spaces.Discrete(2),
            "b": box_space(),
            "c": gym.spaces.Tuple((gym.spaces.Discrete(2), box_space())),
        }
    )
    frozen_dict = tree.traverse(
        make_action_immutable, dict_space.sample(), top_down=False
    )

    # Item assignment on the top-level container must be rejected.
    with self.assertRaises(TypeError):
        frozen_dict["a"] = 5
    self.assertFalse(frozen_dict["b"].flags["WRITEABLE"])
    self.assertFalse(frozen_dict["c"][1].flags["WRITEABLE"])
    self.assertTrue(isinstance(frozen_dict, MappingProxyType))
def test_flatten_inputs_to_1d_tensor(self):
# B=3; no time axis.
check(