mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
[RLlib] Make actions sent by RLlib to the env immutable. (#24262)
This commit is contained in:
parent
5f12c62226
commit
ff575eeafc
4 changed files with 149 additions and 1 deletions
|
@ -37,7 +37,7 @@ from ray.rllib.utils.debug import summarize
|
|||
from ray.rllib.utils.deprecation import deprecation_warning
|
||||
from ray.rllib.utils.filter import Filter
|
||||
from ray.rllib.utils.framework import try_import_tf
|
||||
from ray.rllib.utils.numpy import convert_to_numpy
|
||||
from ray.rllib.utils.numpy import convert_to_numpy, make_action_immutable
|
||||
from ray.rllib.utils.spaces.space_utils import clip_action, unsquash_action, unbatch
|
||||
from ray.rllib.utils.typing import (
|
||||
SampleBatchType,
|
||||
|
@ -1250,6 +1250,9 @@ def _process_policy_eval_results(
|
|||
episode._set_last_action(agent_id, action)
|
||||
|
||||
assert agent_id not in actions_to_send[env_id]
|
||||
# Flag actions as immutable to notify the user when trying to change them
|
||||
# and to avoid hard-to-trace errors.
|
||||
tree.traverse(make_action_immutable, action_to_send, top_down=False)
|
||||
actions_to_send[env_id][agent_id] = action_to_send
|
||||
|
||||
return actions_to_send
|
||||
|
|
|
@ -360,6 +360,55 @@ class TestRolloutWorker(unittest.TestCase):
|
|||
self.assertLess(np.min(sample["actions"]), action_space.low[0])
|
||||
ev.stop()
|
||||
|
||||
def test_action_immutability(self):
    """Check that actions passed to `env.step()` during sampling are read-only.

    A `RandomEnv` subclass inspects the `writeable` flag of every action it
    receives and fails this test case if the action is still writeable.
    """
    from ray.rllib.examples.env.random_env import RandomEnv

    action_space = gym.spaces.Box(0.0001, 0.0002, (5,))

    class ActionMutationEnv(RandomEnv):
        # This has to be `__init__`: a method merely named `init` is never
        # invoked on construction, which would leave `self.test_case` unset
        # and turn a failed immutability check in `step()` into an
        # unrelated AttributeError.
        def __init__(self, config):
            self.test_case = config["test_case"]
            super().__init__(config=config)

        def step(self, action):
            # Ensure that it is called from inside the sampling process.
            import inspect

            # Calls with a `check_gym_environments` frame on the stack are
            # exempt from the check (presumably an env pre-check that steps
            # the env with its own, mutable actions - verify against the
            # caller).
            called_from_check = any(
                frame.function == "check_gym_environments"
                for frame in inspect.getouterframes(inspect.currentframe(), 2)
            )
            # Check, whether the action is immutable.
            if not called_from_check:
                self.test_case.assertFalse(
                    action.flags.writeable, "Action is mutable"
                )
            return super().step(action)

    ev = RolloutWorker(
        env_creator=lambda _: ActionMutationEnv(
            config=dict(
                test_case=self,
                action_space=action_space,
                max_episode_len=10,
                p_done=0.0,
                check_action_bounds=True,
            )
        ),
        policy_spec=RandomPolicy,
        policy_config=dict(
            action_space=action_space,
            ignore_action_bounds=True,
        ),
        clip_actions=False,
        batch_mode="complete_episodes",
    )
    ev.sample()
    ev.stop()
|
||||
|
||||
def test_reward_clipping(self):
|
||||
# Clipping: True (clip between -1.0 and 1.0).
|
||||
ev = RolloutWorker(
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
from collections import OrderedDict
|
||||
from gym.spaces import Discrete, MultiDiscrete
|
||||
import numpy as np
|
||||
import tree # pip install dm_tree
|
||||
from types import MappingProxyType
|
||||
from typing import List, Optional
|
||||
|
||||
from ray.rllib.utils.deprecation import DEPRECATED_VALUE, deprecation_warning
|
||||
|
@ -300,6 +302,42 @@ def flatten_inputs_to_1d_tensor(
|
|||
return merged
|
||||
|
||||
|
||||
def make_action_immutable(obj):
    """Return a read-only version of a single action component.

    Intended as the per-node callback for a tree traversal over a
    (possibly nested) action. Note that `tree.map_structure()` only visits
    the leaves, so containers would stay mutable; use
    `tree.traverse(fun, action, top_down=False)` to also convert the
    containing objects.

    Behavior by type:
        * `np.ndarray`: flagged non-writeable in place; the same array
          object is returned.
        * `OrderedDict`: copied into a plain `dict`, which is wrapped in a
          `MappingProxyType`.
        * any other `dict`: wrapped in a `MappingProxyType` directly (a
          read-only view of the same dict, not a copy).
        * everything else: returned unchanged.

    Args:
        obj: The object to be made immutable.

    Returns:
        The immutable object.

    Examples:
        >>> import tree
        >>> import numpy as np
        >>> arr = np.arange(1,10)
        >>> d = dict(a = 1, b = (arr, arr))
        >>> tree.traverse(make_action_immutable, d, top_down=False)
    """
    # Arrays are locked in place; no new object is created.
    if isinstance(obj, np.ndarray):
        obj.setflags(write=False)
        return obj

    # Anything that is neither an array nor a mapping passes through as is.
    if not isinstance(obj, dict):
        return obj

    # OrderedDicts are first converted to a plain dict; other dicts are
    # wrapped as they are.
    backing = dict(obj) if isinstance(obj, OrderedDict) else obj
    return MappingProxyType(backing)
|
||||
|
||||
|
||||
def huber_loss(x: np.ndarray, delta: float = 1.0) -> np.ndarray:
|
||||
"""Reference: https://en.wikipedia.org/wiki/Huber_loss."""
|
||||
return np.where(
|
||||
|
|
|
@ -6,6 +6,7 @@ import unittest
|
|||
import ray
|
||||
from ray.rllib.utils.framework import try_import_tf, try_import_torch
|
||||
from ray.rllib.utils.numpy import flatten_inputs_to_1d_tensor as flatten_np
|
||||
from ray.rllib.utils.numpy import make_action_immutable
|
||||
from ray.rllib.utils.test_utils import check
|
||||
from ray.rllib.utils.tf_utils import flatten_inputs_to_1d_tensor as flatten_tf
|
||||
from ray.rllib.utils.torch_utils import flatten_inputs_to_1d_tensor as flatten_torch
|
||||
|
@ -60,6 +61,63 @@ class TestUtils(unittest.TestCase):
|
|||
def tearDownClass(cls) -> None:
|
||||
ray.shutdown()
|
||||
|
||||
def test_make_action_immutable(self):
    """Exercise `make_action_immutable` on samples from several gym spaces.

    Array-valued samples must come back non-writeable; a Dict sample that
    was traversed bottom-up must come back as a `MappingProxyType` that
    rejects item assignment.
    """
    import gym
    from types import MappingProxyType

    def new_box():
        return gym.spaces.Box(low=-1.0, high=1.0, shape=(8,), dtype=np.float32)

    def freeze_tree(sample):
        # Bottom-up traversal so that containers get converted as well.
        return tree.traverse(make_action_immutable, sample, top_down=False)

    def assert_frozen_sample(space):
        frozen = make_action_immutable(space.sample())
        self.assertFalse(frozen.flags["WRITEABLE"])

    # Box space.
    assert_frozen_sample(new_box())

    # Discrete space: nothing to be tested, as sampled actions are
    # integers and integers are immutable by nature.

    # MultiDiscrete space.
    assert_frozen_sample(gym.spaces.MultiDiscrete([3, 3, 3]))

    # MultiBinary space.
    assert_frozen_sample(gym.spaces.MultiBinary([2, 2, 2]))

    # Tuple space.
    tuple_space = gym.spaces.Tuple((gym.spaces.Discrete(2), new_box()))
    frozen_tuple = freeze_tree(tuple_space.sample())
    self.assertFalse(frozen_tuple[1].flags["WRITEABLE"])

    # Dict space (with a nested Tuple).
    dict_space = gym.spaces.Dict(
        {
            "a": gym.spaces.Discrete(2),
            "b": new_box(),
            "c": gym.spaces.Tuple((gym.spaces.Discrete(2), new_box())),
        }
    )
    frozen_dict = freeze_tree(dict_space.sample())

    # Item assignment on the converted mapping must be rejected.
    with self.assertRaises(TypeError):
        frozen_dict["a"] = 5

    self.assertFalse(frozen_dict["b"].flags["WRITEABLE"])
    self.assertFalse(frozen_dict["c"][1].flags["WRITEABLE"])
    self.assertTrue(isinstance(frozen_dict, MappingProxyType))
|
||||
|
||||
def test_flatten_inputs_to_1d_tensor(self):
|
||||
# B=3; no time axis.
|
||||
check(
|
||||
|
|
Loading…
Add table
Reference in a new issue