mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
42 lines
1.4 KiB
Python
42 lines
1.4 KiB
Python
from gym.spaces import Box, Dict, Discrete
|
|
import numpy as np
|
|
|
|
from ray.rllib.examples.env.random_env import RandomEnv
|
|
|
|
|
|
class ActionMaskEnv(RandomEnv):
    """A randomly acting environment that publishes an action-mask each step.

    Observations are dicts with two keys:

    - ``"action_mask"``: a 0.0/1.0 float vector of length ``action_space.n``;
      1.0 marks an action that may be sent to the next ``step()`` call.
    - ``"observations"``: an observation from the originally configured
      observation space.

    ``step()`` raises ``ValueError`` if it receives an action that the most
    recently published mask marks as invalid.
    """

    def __init__(self, config):
        super().__init__(config)
        self._skip_env_checking = True
        # Masking only works for Discrete actions. Raise explicitly instead of
        # using `assert`, which is stripped when Python runs with -O.
        if not isinstance(self.action_space, Discrete):
            raise ValueError(
                "ActionMaskEnv requires a Discrete action space, got "
                f"{self.action_space}"
            )
        # Add action_mask to observations. The parent presumably samples its
        # observations from `self.observation_space`, so replacing the space
        # here makes every sampled observation carry a random raw mask
        # (NOTE(review): confirm against RandomEnv).
        self.observation_space = Dict(
            {
                "action_mask": Box(0.0, 1.0, shape=(self.action_space.n,)),
                "observations": self.observation_space,
            }
        )
        # Binarized mask from the most recent reset()/step(); None until the
        # first reset() call.
        self.valid_actions = None

    def reset(self):
        """Reset the env and return its first observation with a 0/1 mask."""
        obs = super().reset()
        self._fix_action_mask(obs)
        return obs

    def step(self, action):
        """Check ``action`` against the current mask, then step the env.

        Returns:
            The parent's ``(obs, rew, done, info)`` tuple, with
            ``obs["action_mask"]`` binarized for the following step.

        Raises:
            ValueError: If ``action`` is masked out by the mask published
                with the previous observation.
        """
        # Check whether action is valid.
        if not self.valid_actions[action]:
            raise ValueError(
                f"Invalid action ({action}) sent to env! "
                f"valid_actions={self.valid_actions}"
            )
        obs, rew, done, info = super().step(action)
        self._fix_action_mask(obs)
        return obs, rew, done, info

    def _fix_action_mask(self, obs):
        """Binarize ``obs["action_mask"]`` in place and remember it.

        Values above 0.5 become 1.0 and everything else becomes 0.0.
        ``np.round`` rounds halves to even, so exactly 0.5 also becomes 0.0.
        If rounding would leave no valid action, the action with the largest
        raw mask value is marked valid, so the agent always has a legal move.
        """
        raw_mask = obs["action_mask"]
        valid = np.round(raw_mask)  # new array; raw_mask is not modified
        if valid.size and not valid.any():
            # An all-zero mask would make every action raise in step(), so
            # the episode could not continue. With uniformly sampled raw
            # masks (presumably the case here) this happens with probability
            # 0.5 ** n per observation.
            valid[np.argmax(raw_mask)] = 1.0
        self.valid_actions = valid
        obs["action_mask"] = valid