ray/rllib/examples/env/action_mask_env.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

43 lines
1.4 KiB
Python
Raw Normal View History

from gym.spaces import Box, Dict, Discrete
import numpy as np
from ray.rllib.examples.env.random_env import RandomEnv
class ActionMaskEnv(RandomEnv):
    """A randomly acting environment that publishes an action-mask each step.

    The observation space is a ``Dict`` with two entries:

    - ``"action_mask"``: a ``Box(0.0, 1.0)`` vector with one entry per
      discrete action. In every observation returned by ``reset()`` and
      ``step()`` it is binarized: ``1.0`` marks an action as valid for the
      next ``step()`` call, ``0.0`` marks it as invalid.
    - ``"observations"``: the observation space set up by the parent class.

    Sending an action whose mask entry is ``0.0`` raises a ``ValueError``.
    """

    def __init__(self, config):
        """Create the env and wrap the parent's observation space with a mask.

        Args:
            config: Env config; forwarded unchanged to ``RandomEnv.__init__``.

        Raises:
            AssertionError: If the configured action space is not ``Discrete``.
        """
        super().__init__(config)
        # NOTE(review): presumably tells RLlib to skip its env pre-checks
        # (which would not respect the action mask) -- confirm against RLlib.
        self._skip_env_checking = True
        # Masking only works for Discrete actions.
        assert isinstance(self.action_space, Discrete), (
            "ActionMaskEnv requires a Discrete action space, got "
            f"{self.action_space!r}"
        )
        # Add action_mask to observations.
        self.observation_space = Dict(
            {
                "action_mask": Box(0.0, 1.0, shape=(self.action_space.n,)),
                "observations": self.observation_space,
            }
        )
        # Binary mask of the actions allowed in the next `step()`; stays None
        # until the first `reset()`.
        self.valid_actions = None

    def reset(self):
        """Reset the env and return the first observation with a binary mask.

        Returns:
            The parent's reset observation, with ``obs["action_mask"]``
            replaced by its binarized version.
        """
        obs = super().reset()
        self._fix_action_mask(obs)
        return obs

    def step(self, action):
        """Validate `action` against the current mask, then step the parent.

        Args:
            action: Index of the discrete action to take.

        Returns:
            The parent's ``(obs, rew, done, info)`` tuple, with
            ``obs["action_mask"]`` replaced by its binarized version.

        Raises:
            RuntimeError: If called before the first ``reset()``.
            ValueError: If `action` is masked out by the current mask.
        """
        # Without a prior reset there is no mask to validate against; fail
        # with a clear message instead of an opaque error from indexing None.
        if self.valid_actions is None:
            raise RuntimeError(
                "ActionMaskEnv.step() called before reset(): no action mask "
                "is available yet."
            )
        # Check whether action is valid.
        if not self.valid_actions[action]:
            raise ValueError(
                f"Invalid action sent to env! valid_actions={self.valid_actions}"
            )
        obs, rew, done, info = super().step(action)
        self._fix_action_mask(obs)
        return obs, rew, done, info

    def _fix_action_mask(self, obs):
        """Binarize ``obs["action_mask"]`` in place and remember it.

        Everything larger than 0.5 becomes 1.0, everything else 0.0. If no
        entry rounds up to 1.0, the entry with the largest raw value is forced
        to 1.0 so that at least one action is always valid; otherwise every
        possible action in the next ``step()`` would raise.

        Args:
            obs: Observation dict containing an ``"action_mask"`` array; it is
                modified in place.
        """
        raw_mask = obs["action_mask"]
        # `np.round` returns a new array, so `raw_mask` keeps the raw values.
        mask = np.round(raw_mask)
        if mask.size and not mask.any():
            mask[int(np.argmax(raw_mask))] = 1.0
        self.valid_actions = mask
        obs["action_mask"] = self.valid_actions