ray/rllib/examples/env/action_mask_env.py
Balaji Veeramani 7f1bacc7dc
[CI] Format Python code with Black (#21975)
See #21316 and #21311 for the motivation behind these changes.
2022-01-29 18:41:57 -08:00

42 lines
1.4 KiB
Python

from gym.spaces import Box, Dict, Discrete
import numpy as np
from ray.rllib.examples.env.random_env import RandomEnv
class ActionMaskEnv(RandomEnv):
    """A randomly acting environment that publishes an action-mask each step.

    Wraps `RandomEnv` so that every observation is a Dict carrying the
    original observation under "observations" and a 0/1 vector under
    "action_mask". A 1.0 at index i marks action i as currently valid;
    sending a masked-out action to `step()` raises a ValueError.
    """

    def __init__(self, config):
        super().__init__(config)
        # Masking only works for Discrete actions.
        # NOTE: An explicit raise (not `assert`) so the check survives
        # running under `python -O`, which strips assert statements.
        if not isinstance(self.action_space, Discrete):
            raise ValueError(
                "ActionMaskEnv requires a Discrete action space, got "
                f"{self.action_space}!"
            )
        # Add action_mask to observations.
        self.observation_space = Dict(
            {
                "action_mask": Box(0.0, 1.0, shape=(self.action_space.n,)),
                "observations": self.observation_space,
            }
        )
        # Binarized mask of the most recent observation; None until
        # `reset()` has been called at least once.
        self.valid_actions = None

    def reset(self):
        """Reset the env and binarize the initial action mask.

        Returns:
            The initial observation dict (with a fixed-up "action_mask").
        """
        obs = super().reset()
        self._fix_action_mask(obs)
        return obs

    def step(self, action):
        """Step the env, first validating `action` against the current mask.

        Raises:
            ValueError: If `action` is masked out (invalid) per the mask
                published with the previous observation.
        """
        # Check whether action is valid.
        if not self.valid_actions[action]:
            raise ValueError(
                f"Invalid action sent to env! valid_actions={self.valid_actions}"
            )
        obs, rew, done, info = super().step(action)
        self._fix_action_mask(obs)
        return obs, rew, done, info

    def _fix_action_mask(self, obs):
        # Fix action-mask: Everything larger 0.5 is 1.0, everything else 0.0.
        # (For values in [0, 1], np.round maps (0.5, 1.0] -> 1.0 and
        # [0.0, 0.5] -> 0.0.)
        self.valid_actions = np.round(obs["action_mask"])
        obs["action_mask"] = self.valid_actions