# Mirrored copy; synced 2025-03-08 19:41:38 -05:00 (185 lines, 5.7 KiB).
import gym
import numpy as np
from ray.rllib.env.vector_env import VectorEnv
from ray.rllib.utils.annotations import override
class MockEnv(gym.Env):
    """Minimal fixed-length environment for tests.

    Every observation is 0 and every step pays a reward of 1.0. An episode
    ends once `episode_length` steps have been taken. The action passed to
    `step` has no effect.
    """

    def __init__(self, episode_length, config=None):
        """Store the episode horizon and optional config; set up the spaces."""
        # A single possible observation and two (ignored) actions.
        self.observation_space = gym.spaces.Discrete(1)
        self.action_space = gym.spaces.Discrete(2)
        self.config = config
        self.episode_length = episode_length
        # Number of steps taken so far in the current episode.
        self.i = 0

    def reset(self):
        """Restart the step counter and return the constant observation 0."""
        self.i = 0
        return 0

    def step(self, action):
        """Advance one step and return `(obs, reward, done, info)`."""
        self.i += 1
        episode_over = self.i >= self.episode_length
        return 0, 1.0, episode_over, {}
class MockEnv2(gym.Env):
    """Fixed-length test environment whose observation is the time step.

    The observation is the current step counter (a discrete space of size
    `episode_length + 1`), each step pays a reward of 100.0, and the episode
    ends after `episode_length` steps. Actions are ignored.
    """

    metadata = {
        "render.modes": ["rgb_array"],
    }

    def __init__(self, episode_length):
        """Store the episode horizon and set up the spaces."""
        self.episode_length = episode_length
        # Counter values 0..episode_length are all valid observations.
        self.observation_space = gym.spaces.Discrete(self.episode_length + 1)
        self.action_space = gym.spaces.Discrete(2)
        # Step counter of the current episode.
        self.i = 0
        # Last value handed to `seed()`; None until it is called.
        self.rng_seed = None

    def reset(self):
        """Restart the step counter and return it as the first observation."""
        self.i = 0
        return self.i

    def step(self, action):
        """Advance one step and return `(obs, reward, done, info)`."""
        self.i += 1
        episode_over = self.i >= self.episode_length
        return self.i, 100.0, episode_over, {}

    def seed(self, rng_seed):
        """Record the given seed; it is only stored, not applied to any RNG."""
        self.rng_seed = rng_seed

    def render(self, mode="rgb_array"):
        """Return a random 300x400 RGB uint8 image as a stand-in frame."""
        # Just generate a random image here for demonstration purposes.
        # Also see `gym/envs/classic_control/cartpole.py` for
        # an example on how to use a Viewer object.
        frame_shape = (300, 400, 3)
        return np.random.randint(0, 256, size=frame_shape, dtype=np.uint8)
class MockEnv3(gym.Env):
    """Fixed-length test environment that reports the time step everywhere.

    The observation and the reward both equal the current step counter, and
    the info dict carries it under the "timestep" key. The episode ends
    after `episode_length` steps. Actions are ignored.
    """

    def __init__(self, episode_length):
        """Store the episode horizon and set up the spaces."""
        self.episode_length = episode_length
        self.observation_space = gym.spaces.Discrete(100)
        self.action_space = gym.spaces.Discrete(2)
        # Step counter of the current episode.
        self.i = 0

    def reset(self):
        """Restart the step counter and return it as the first observation."""
        self.i = 0
        return self.i

    def step(self, action):
        """Advance one step and return `(obs, reward, done, info)`."""
        self.i += 1
        episode_over = self.i >= self.episode_length
        step_info = {"timestep": self.i}
        return self.i, self.i, episode_over, step_info
class VectorizedMockEnv(VectorEnv):
    """Vectorized version of the MockEnv.

    Holds `num_envs` independent MockEnv instances, each created with the
    same `episode_length` horizon and each keeping its own step counter.
    """

    def __init__(self, episode_length, num_envs):
        """Create the sub-environments and initialize the VectorEnv base.

        Args:
            episode_length: Number of steps after which each sub-env is done.
            num_envs: Number of MockEnv instances to create.
        """
        # The base class stores the spaces and `num_envs`. The spaces mirror
        # the ones MockEnv defines for itself.
        super().__init__(
            observation_space=gym.spaces.Discrete(1),
            action_space=gym.spaces.Discrete(2),
            num_envs=num_envs,
        )
        self.envs = [MockEnv(episode_length) for _ in range(num_envs)]

    @override(VectorEnv)
    def vector_reset(self):
        """Reset every sub-env and return the list of first observations."""
        return [env.reset() for env in self.envs]

    @override(VectorEnv)
    def reset_at(self, index):
        """Reset only the sub-env at `index` and return its observation."""
        return self.envs[index].reset()

    @override(VectorEnv)
    def vector_step(self, actions):
        """Step every sub-env with its own action.

        Args:
            actions: Sequence with one action per sub-env, in env order.

        Returns:
            Four lists `(obs, rewards, dones, infos)` with one entry per
            sub-env.
        """
        obs_batch, rew_batch, done_batch, info_batch = [], [], [], []
        for i, env in enumerate(self.envs):
            obs, rew, done, info = env.step(actions[i])
            # Collect the per-env results; without this the returned
            # batches would always be empty.
            obs_batch.append(obs)
            rew_batch.append(rew)
            done_batch.append(done)
            info_batch.append(info)
        return obs_batch, rew_batch, done_batch, info_batch

    @override(VectorEnv)
    def get_sub_environments(self):
        """Return the list of underlying MockEnv instances."""
        return self.envs
class MockVectorEnv(VectorEnv):
    """A custom vector env that uses a single(!) CartPole sub-env.

    However, this env pretends to be a vectorized one to illustrate how one
    could create custom VectorEnvs w/o the need for actual vectorizations of
    sub-envs under the hood.
    """

    def __init__(self, episode_length, mocked_num_envs):
        """Create the single CartPole env and initialize the VectorEnv base.

        Args:
            episode_length: Number of `vector_step` calls after which all
                mocked sub-envs are reported as done.
            mocked_num_envs: Number of sub-envs this env pretends to have.
        """
        self.env = gym.make("CartPole-v0")
        # The base class stores the spaces and `num_envs`, which the other
        # methods read via `self.num_envs`.
        super().__init__(
            observation_space=self.env.observation_space,
            action_space=self.env.action_space,
            num_envs=mocked_num_envs,
        )
        self.episode_len = episode_length
        # Number of `vector_step` calls since the last reset.
        self.ts = 0

    @override(VectorEnv)
    def vector_reset(self):
        """Reset the shared env and return its observation once per slot."""
        # Restart the artificial time limit too, as `reset_at` does.
        self.ts = 0
        obs = self.env.reset()
        return [obs for _ in range(self.num_envs)]

    @override(VectorEnv)
    def reset_at(self, index):
        """Reset the shared env (`index` is irrelevant) and the step count."""
        self.ts = 0
        return self.env.reset()

    @override(VectorEnv)
    def vector_step(self, actions):
        """Apply the actions one after another to the single shared env.

        Args:
            actions: Sequence with one action per mocked sub-env.

        Returns:
            Four lists `(obs, rewards, dones, infos)`, each of length
            `num_envs`.
        """
        self.ts += 1
        # Apply all actions sequentially to the same env.
        # Whether this would make a lot of sense is debatable.
        obs_batch, rew_batch, done_batch, info_batch = [], [], [], []
        for i in range(self.num_envs):
            obs, rew, done, info = self.env.step(actions[i])
            # Artificially terminate once time step limit has been reached.
            # Note: Also terminate, when underlying CartPole is terminated.
            if self.ts >= self.episode_len:
                done = True
            obs_batch.append(obs)
            rew_batch.append(rew)
            done_batch.append(done)
            info_batch.append(info)
            if done:
                # The shared env is finished: fill the remaining slots with
                # this last result and stop stepping it.
                remaining = self.num_envs - (i + 1)
                obs_batch.extend([obs for _ in range(remaining)])
                rew_batch.extend([rew for _ in range(remaining)])
                done_batch.extend([done for _ in range(remaining)])
                info_batch.extend([info for _ in range(remaining)])
                break
        return obs_batch, rew_batch, done_batch, info_batch

    @override(VectorEnv)
    def get_sub_environments(self):
        """Return the shared env repeated once per mocked sub-env slot."""
        # You may also leave this method as-is, in which case, it would
        # return an empty list.
        return [self.env for _ in range(self.num_envs)]