import gym
from gym.spaces import Box, Dict, Discrete
import numpy as np
import random


class ParametricActionsCartPole(gym.Env):
    """CartPole with a parametric (masked) action space.

    Only two actions are ever valid (LEFT/RIGHT), but the env pretends
    there are up to ``max_avail_actions`` discrete actions and hides the
    two real ones at random slots. Each observation is a dict of:

      - "cart": the underlying CartPole observation
      - "action_mask": 1.0 at the two valid slots, 0.0 elsewhere
        (e.g. [0, 0, 1, 0, 0, 1] for max_avail_actions=6)
      - "avail_actions": per-slot 2-d action embeddings, zero rows for
        invalid slots

    A real environment would use larger embeddings and a variable
    number of valid actions per step instead of always [LEFT, RIGHT].
    """

    def __init__(self, max_avail_actions):
        # Random 2-d embeddings standing in for the LEFT/RIGHT actions.
        self.left_action_embed = np.random.randn(2)
        self.right_action_embed = np.random.randn(2)
        self.action_space = Discrete(max_avail_actions)
        self.wrapped = gym.make("CartPole-v0")
        # NOTE(review): the mask/embedding arrays below are float64 while
        # these Boxes default to float32 — confirm `contains` checks pass
        # on the gym version in use.
        self.observation_space = Dict({
            "action_mask": Box(0, 1, shape=(max_avail_actions, )),
            "avail_actions": Box(-10, 10, shape=(max_avail_actions, 2)),
            "cart": self.wrapped.observation_space,
        })

    def update_avail_actions(self):
        """Re-randomize which two slots hold the valid LEFT/RIGHT actions."""
        n = self.action_space.n
        self.action_assignments = np.array([[0., 0.]] * n)
        self.action_mask = np.array([0.] * n)
        # Pick two distinct slots for the real actions.
        self.left_idx, self.right_idx = random.sample(range(n), 2)
        self.action_assignments[self.left_idx] = self.left_action_embed
        self.action_assignments[self.right_idx] = self.right_action_embed
        self.action_mask[self.left_idx] = 1
        self.action_mask[self.right_idx] = 1

    def _obs(self, cart_obs):
        # Assemble the dict observation from the current mask/embeddings.
        return {
            "action_mask": self.action_mask,
            "avail_actions": self.action_assignments,
            "cart": cart_obs,
        }

    def reset(self):
        self.update_avail_actions()
        return self._obs(self.wrapped.reset())

    def step(self, action):
        # Map the chosen slot back to the real LEFT/RIGHT action, rejecting
        # any slot whose embedding is zero (i.e. masked out).
        if action == self.left_idx:
            actual_action = 0
        elif action == self.right_idx:
            actual_action = 1
        else:
            raise ValueError(
                "Chosen action was not one of the non-zero action embeddings",
                action, self.action_assignments, self.action_mask,
                self.left_idx, self.right_idx)
        orig_obs, rew, done, info = self.wrapped.step(actual_action)
        # Re-hide the valid actions every step.
        self.update_avail_actions()
        return self._obs(orig_obs), rew, done, info