from copy import deepcopy

import gym
import numpy as np
from gym.spaces import Discrete, Dict, Box


class CartPole:
    """
    Wrapper for the gym CartPole environment where the reward
    is accumulated until the end of the episode.
    """

    def __init__(self, config=None):
        self.env = gym.make("CartPole-v0")
        self.action_space = Discrete(2)
        # Expose a dict observation so policies can read an action mask
        # alongside the raw CartPole observation.
        self.observation_space = Dict({
            "obs": self.env.observation_space,
            "action_mask": Box(low=0, high=1, shape=(self.action_space.n, ))
        })
        self.running_reward = 0

    def reset(self):
        self.running_reward = 0
        # Both actions are always valid in CartPole, so the mask is all ones.
        return {"obs": self.env.reset(), "action_mask": np.array([1, 1])}

    def step(self, action):
        obs, rew, done, info = self.env.step(action)
        self.running_reward += rew
        # Defer the reward: emit 0 on intermediate steps and pay out the
        # full accumulated return only on the terminal step.
        score = self.running_reward if done else 0
        return {"obs": obs, "action_mask": np.array([1, 1])}, score, done, info

    def set_state(self, state):
        # Restore a snapshot produced by get_state().
        self.running_reward = state[1]
        self.env = deepcopy(state[0])
        obs = np.array(list(self.env.unwrapped.state))
        return {"obs": obs, "action_mask": np.array([1, 1])}

    def get_state(self):
        # Snapshot the whole environment plus the accumulated reward so an
        # episode can be rewound, e.g. during tree-search rollouts.
        return deepcopy(self.env), self.running_reward
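

# A minimal usage sketch, not part of the original module: it assumes a gym
# version with the old 4-tuple step API, matching the wrapper above. It runs
# a random episode, snapshots the env with get_state(), and restores it with
# set_state(), illustrating that intermediate rewards are 0 and the
# accumulated return is only paid out on the terminal step.
if __name__ == "__main__":
    env = CartPole()
    obs = env.reset()
    snapshot = None
    done = False
    while not done:
        obs, score, done, _ = env.step(env.action_space.sample())
        if snapshot is None:
            # Save a snapshot after the first step.
            snapshot = env.get_state()
        # Reward is deferred until termination.
        assert done or score == 0
    print("terminal score:", score)

    # Rewind to the saved snapshot and continue from there.
    obs = env.set_state(snapshot)
    print("restored running reward:", env.running_reward)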