ray/rllib/contrib/alpha_zero/environments/cartpole.py

41 lines
1.2 KiB
Python
Raw Normal View History

from copy import deepcopy
import gym
import numpy as np
from gym.spaces import Discrete, Dict, Box
class CartPole:
    """Wrapper around gym's CartPole-v0 that pays out the reward at the end.

    Per-step rewards from the wrapped environment are summed into
    ``running_reward``; ``step`` reports 0 until the episode is done and
    then reports the accumulated total.

    Observations are dicts with two entries:

    * ``"obs"``: the observation of the wrapped environment.
    * ``"action_mask"``: a float32 vector of ones with one entry per
      action (every action is always allowed).

    ``get_state`` / ``set_state`` snapshot and restore the wrapped
    environment together with the accumulated reward.
    """

    def __init__(self, config=None):
        """Create the wrapped environment and declare the spaces.

        Args:
            config: Accepted for interface compatibility; not used here.
        """
        self.env = gym.make("CartPole-v0")
        self.action_space = Discrete(2)
        self.observation_space = Dict({
            "obs": self.env.observation_space,
            # Explicit dtype so the mask arrays built in `_action_mask`
            # match the declared space exactly.
            "action_mask": Box(
                low=0,
                high=1,
                shape=(self.action_space.n, ),
                dtype=np.float32,
            ),
        })
        self.running_reward = 0

    def _action_mask(self):
        """Return a fresh all-ones float32 mask, one entry per action."""
        # A new array per call so callers can never mutate a shared mask.
        return np.ones(self.action_space.n, dtype=np.float32)

    def reset(self):
        """Reset the wrapped environment and the accumulated reward.

        Returns:
            The initial observation dict (``"obs"`` and ``"action_mask"``).
        """
        self.running_reward = 0
        return {"obs": self.env.reset(), "action_mask": self._action_mask()}

    def step(self, action):
        """Advance the wrapped environment by one step.

        Args:
            action: Action forwarded unchanged to the wrapped environment.

        Returns:
            A tuple ``(observation_dict, score, done, info)`` where
            ``score`` is the accumulated reward if ``done`` is true and
            0 otherwise.
        """
        obs, rew, done, info = self.env.step(action)
        self.running_reward += rew
        # The reward is only handed out once, when the episode terminates.
        score = self.running_reward if done else 0
        observation = {"obs": obs, "action_mask": self._action_mask()}
        return observation, score, done, info

    def set_state(self, state):
        """Restore a snapshot previously produced by ``get_state``.

        Args:
            state: A pair ``(env, running_reward)``.

        Returns:
            The observation dict rebuilt from the restored environment's
            ``unwrapped.state``.
        """
        self.running_reward = state[1]
        # Copy so that later steps do not modify the caller's snapshot.
        self.env = deepcopy(state[0])
        obs = np.array(list(self.env.unwrapped.state))
        return {"obs": obs, "action_mask": self._action_mask()}

    def get_state(self):
        """Return a snapshot ``(deep copy of env, running_reward)``."""
        return deepcopy(self.env), self.running_reward