ray/rllib/contrib/alpha_zero/environments/cartpole.py

from copy import deepcopy

import gym
import numpy as np
from gym.spaces import Box, Dict, Discrete

class CartPole:
    """Wrapper for the gym CartPole environment in which the reward is
    accumulated over the episode and only returned on the final step.
    """

    def __init__(self, config=None):
        self.env = gym.make("CartPole-v0")
        self.action_space = Discrete(2)
        # The observation carries an action mask alongside the raw CartPole
        # observation; both of CartPole's actions are always valid here.
        self.observation_space = Dict(
            {
                "obs": self.env.observation_space,
                "action_mask": Box(low=0, high=1, shape=(self.action_space.n,)),
            }
        )
        self.running_reward = 0

    def reset(self):
        self.running_reward = 0
        return {
            "obs": self.env.reset(),
            "action_mask": np.array([1, 1], dtype=np.float32),
        }

    def step(self, action):
        obs, rew, done, info = self.env.step(action)
        self.running_reward += rew
        # Withhold the reward until the episode terminates; the final step
        # returns the accumulated total as the score.
        score = self.running_reward if done else 0
        return (
            {"obs": obs, "action_mask": np.array([1, 1], dtype=np.float32)},
            score,
            done,
            info,
        )

    def set_state(self, state):
        # Restore a snapshot produced by get_state().
        self.running_reward = state[1]
        self.env = deepcopy(state[0])
        obs = np.array(list(self.env.unwrapped.state))
        return {"obs": obs, "action_mask": np.array([1, 1], dtype=np.float32)}

    def get_state(self):
        # Snapshot the wrapped environment and the accumulated reward so a
        # tree search can roll back to this exact point.
        return deepcopy(self.env), self.running_reward
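

# A minimal usage sketch (not part of the original module), assuming gym's
# pre-0.26 API in which reset() returns only the observation and step()
# returns a 4-tuple, matching the calls above. It exercises the
# snapshot/restore hooks that an AlphaZero-style tree search relies on.
if __name__ == "__main__":
    env = CartPole()
    obs = env.reset()

    # Step once, then snapshot the environment and the running reward.
    obs, score, done, info = env.step(0)
    snapshot = env.get_state()

    # Keep stepping; score stays 0 until the episode terminates.
    obs, score, done, info = env.step(1)

    # Roll back to the snapshot; the pole state and reward are restored.
    obs = env.set_state(snapshot)
    print("restored obs:", obs["obs"])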