ray/rllib/examples/env/mock_env.py

import gym
from ray.rllib.env.vector_env import VectorEnv
from ray.rllib.utils.annotations import override


class MockEnv(gym.Env):
    """Mock environment for testing purposes.

    Observation=0, reward=1.0, episode-len is configurable.
    Actions are ignored.
    """

    def __init__(self, episode_length, config=None):
        self.episode_length = episode_length
        self.config = config
        self.i = 0
        self.observation_space = gym.spaces.Discrete(1)
        self.action_space = gym.spaces.Discrete(2)

    def reset(self):
        self.i = 0
        return 0

    def step(self, action):
        self.i += 1
        return 0, 1.0, self.i >= self.episode_length, {}
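
# A minimal usage sketch (illustrative only, not part of the original file):
# stepping MockEnv by hand. Actions are ignored, so sampling from the action
# space is sufficient.
#
#   env = MockEnv(episode_length=5)
#   obs = env.reset()
#   done = False
#   while not done:
#       obs, reward, done, info = env.step(env.action_space.sample())
#       assert obs == 0 and reward == 1.0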


class MockEnv2(gym.Env):
    """Mock environment for testing purposes.

    Observation=ts (discrete space!), reward=100.0, episode-len is
    configurable. Actions are ignored.
    """

    def __init__(self, episode_length):
        self.episode_length = episode_length
        self.i = 0
        self.observation_space = gym.spaces.Discrete(self.episode_length + 1)
        self.action_space = gym.spaces.Discrete(2)
        self.rng_seed = None

    def reset(self):
        self.i = 0
        return self.i

    def step(self, action):
        self.i += 1
        return self.i, 100.0, self.i >= self.episode_length, {}

    def seed(self, rng_seed):
        self.rng_seed = rng_seed
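
# Illustrative sketch (an assumption about intended use, not original code):
# MockEnv2's observation tracks the timestep, and `seed` merely records the
# given value (the env itself is deterministic).
#
#   env = MockEnv2(episode_length=3)
#   env.seed(42)          # stored in env.rng_seed; has no other effect
#   obs = env.reset()     # -> 0
#   obs, rew, done, _ = env.step(0)   # -> (1, 100.0, False, {})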


class MockEnv3(gym.Env):
    """Mock environment for testing purposes.

    Observation=ts (discrete space!), reward=ts (the current timestep),
    episode-len is configurable. Actions are ignored.
    """

    def __init__(self, episode_length):
        self.episode_length = episode_length
        self.i = 0
        # Note: The observation space only covers timesteps 0..99, so
        # `episode_length` should be < 100 to keep observations in-space.
        self.observation_space = gym.spaces.Discrete(100)
        self.action_space = gym.spaces.Discrete(2)

    def reset(self):
        self.i = 0
        return self.i

    def step(self, action):
        self.i += 1
        return self.i, self.i, self.i >= self.episode_length, {
            "timestep": self.i
        }
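
# Illustrative sketch (not part of the original file): MockEnv3 returns the
# current timestep both as observation and reward, and also exposes it via
# the info dict.
#
#   env = MockEnv3(episode_length=2)
#   env.reset()       # -> 0
#   env.step(1)       # -> (1, 1, False, {"timestep": 1})
#   env.step(1)       # -> (2, 2, True, {"timestep": 2})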


class VectorizedMockEnv(VectorEnv):
    """Vectorized version of the MockEnv.

    Contains `num_envs` MockEnv instances, each one having its own
    `episode_length` horizon.
    """

    def __init__(self, episode_length, num_envs):
        super().__init__(
            observation_space=gym.spaces.Discrete(1),
            action_space=gym.spaces.Discrete(2),
            num_envs=num_envs)
        self.envs = [MockEnv(episode_length) for _ in range(num_envs)]

    @override(VectorEnv)
    def vector_reset(self):
        return [e.reset() for e in self.envs]

    @override(VectorEnv)
    def reset_at(self, index):
        return self.envs[index].reset()

    @override(VectorEnv)
    def vector_step(self, actions):
        obs_batch, rew_batch, done_batch, info_batch = [], [], [], []
        for i in range(len(self.envs)):
            obs, rew, done, info = self.envs[i].step(actions[i])
            obs_batch.append(obs)
            rew_batch.append(rew)
            done_batch.append(done)
            info_batch.append(info)
        return obs_batch, rew_batch, done_batch, info_batch

    @override(VectorEnv)
    def get_unwrapped(self):
        return self.envs
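
# Illustrative sketch (not part of the original file): driving the vectorized
# env directly. Each call steps all sub-envs in lockstep.
#
#   vec_env = VectorizedMockEnv(episode_length=4, num_envs=2)
#   obs_batch = vec_env.vector_reset()            # -> [0, 0]
#   obs, rews, dones, infos = vec_env.vector_step([0, 1])
#   # obs == [0, 0], rews == [1.0, 1.0], dones == [False, False]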


class MockVectorEnv(VectorEnv):
    """A custom vector env that uses a single(!) CartPole sub-env.

    However, this env pretends to be a vectorized one to illustrate how
    one could create custom VectorEnvs without actually vectorizing any
    sub-envs under the hood.
    """

    def __init__(self, episode_length, mocked_num_envs):
        self.env = gym.make("CartPole-v0")
        super().__init__(
            observation_space=self.env.observation_space,
            action_space=self.env.action_space,
            num_envs=mocked_num_envs)
        self.episode_len = episode_length
        self.ts = 0

    @override(VectorEnv)
    def vector_reset(self):
        # Reset the (artificial) timestep counter along with the sub-env.
        self.ts = 0
        obs = self.env.reset()
        return [obs for _ in range(self.num_envs)]

    @override(VectorEnv)
    def reset_at(self, index):
        self.ts = 0
        return self.env.reset()

    @override(VectorEnv)
    def vector_step(self, actions):
        self.ts += 1
        # Apply all actions sequentially to the same (single) sub-env.
        # Whether this makes sense is debatable; it only serves to mock
        # the vectorized interface.
        obs_batch, rew_batch, done_batch, info_batch = [], [], [], []
        for i in range(self.num_envs):
            obs, rew, done, info = self.env.step(actions[i])
            # Artificially terminate once the time step limit is reached.
            # Note: We also terminate when the underlying CartPole episode
            # is done.
            if self.ts >= self.episode_len:
                done = True
            obs_batch.append(obs)
            rew_batch.append(rew)
            done_batch.append(done)
            info_batch.append(info)
            # Once done, pad the remaining batch slots with the last
            # results and stop stepping.
            if done:
                remaining = self.num_envs - (i + 1)
                obs_batch.extend([obs for _ in range(remaining)])
                rew_batch.extend([rew for _ in range(remaining)])
                done_batch.extend([done for _ in range(remaining)])
                info_batch.extend([info for _ in range(remaining)])
                break
        return obs_batch, rew_batch, done_batch, info_batch

    @override(VectorEnv)
    def get_unwrapped(self):
        # You may also leave this method as-is, in which case it would
        # return an empty list.
        return [self.env for _ in range(self.num_envs)]
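
# Illustrative sketch (an assumption, not part of the original file): such a
# custom VectorEnv can be handed to RLlib via the env registry, e.g.:
#
#   from ray.tune.registry import register_env
#   register_env(
#       "mock_vector_env",
#       lambda cfg: MockVectorEnv(episode_length=20, mocked_num_envs=4))
#
# RLlib would then treat the single CartPole instance as if it were 4
# independent sub-envs.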