# ray/rllib/env/vector_env.py

import logging
import gym
import numpy as np
from typing import Callable, List, Optional, Tuple, Union, Set

from ray.rllib.env.base_env import BaseEnv, _DUMMY_AGENT_ID
from ray.rllib.utils.annotations import Deprecated, override, PublicAPI
from ray.rllib.utils.typing import (
    EnvActionType,
    EnvID,
    EnvInfoDict,
    EnvObsType,
    EnvType,
    MultiEnvDict,
    AgentID,
)
from ray.util import log_once

logger = logging.getLogger(__name__)


@PublicAPI
class VectorEnv:
    """An environment that supports batch evaluation using clones of sub-envs."""

    def __init__(
        self, observation_space: gym.Space, action_space: gym.Space, num_envs: int
    ):
        """Initializes a VectorEnv instance.

        Args:
            observation_space: The observation Space of a single sub-env.
            action_space: The action Space of a single sub-env.
            num_envs: The number of clones to make of the given sub-env.
        """
        self.observation_space = observation_space
        self.action_space = action_space
        self.num_envs = num_envs

    @staticmethod
    def vectorize_gym_envs(
        make_env: Optional[Callable[[int], EnvType]] = None,
        existing_envs: Optional[List[gym.Env]] = None,
        num_envs: int = 1,
        action_space: Optional[gym.Space] = None,
        observation_space: Optional[gym.Space] = None,
        restart_failed_sub_environments: bool = False,
        # Deprecated. These seem to have never been used.
        env_config=None,
        policy_config=None,
    ) -> "_VectorizedGymEnv":
        """Translates any given gym.Env(s) into a _VectorizedGymEnv object.

        Args:
            make_env: Factory that produces a new gym.Env, taking the sub-env's
                vector index as its only arg. Must be defined if the number of
                `existing_envs` is less than `num_envs`.
            existing_envs: Optional list of already instantiated sub
                environments.
            num_envs: Total number of sub environments in this VectorEnv.
            action_space: The action space. If None, use existing_envs[0]'s
                action space.
            observation_space: The observation space. If None, use
                existing_envs[0]'s observation space.
            restart_failed_sub_environments: If True and any sub-environment
                (within a vectorized env) throws any error during env stepping,
                the Sampler will try to restart the faulty sub-environment. This
                is done without disturbing the other (still intact)
                sub-environments and without the RolloutWorker crashing.

        Returns:
            The resulting _VectorizedGymEnv object (a subclass of VectorEnv).
        """
        return _VectorizedGymEnv(
            make_env=make_env,
            existing_envs=existing_envs or [],
            num_envs=num_envs,
            observation_space=observation_space,
            action_space=action_space,
            restart_failed_sub_environments=restart_failed_sub_environments,
        )
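
    # Usage sketch (not part of the original file), assuming the standard gym
    # env "CartPole-v1" is available:
    #
    #   vec_env = VectorEnv.vectorize_gym_envs(
    #       make_env=lambda vector_index: gym.make("CartPole-v1"),
    #       num_envs=4,
    #   )
    #   obs_batch = vec_env.vector_reset()  # -> list of 4 observations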

    @PublicAPI
    def vector_reset(self) -> List[EnvObsType]:
        """Resets all sub-environments.

        Returns:
            List of observations from each environment.
        """
        raise NotImplementedError

    @PublicAPI
    def reset_at(self, index: Optional[int] = None) -> EnvObsType:
        """Resets a single sub-environment.

        Args:
            index: An optional sub-env index to reset.

        Returns:
            Observations from the reset sub environment.
        """
        raise NotImplementedError

    @PublicAPI
    def restart_at(self, index: Optional[int] = None) -> None:
        """Restarts a single sub-environment.

        Args:
            index: An optional sub-env index to restart.
        """
        raise NotImplementedError

    @PublicAPI
    def vector_step(
        self, actions: List[EnvActionType]
    ) -> Tuple[List[EnvObsType], List[float], List[bool], List[EnvInfoDict]]:
        """Performs a vectorized step on all sub environments using `actions`.

        Args:
            actions: List of actions (one for each sub-env).

        Returns:
            A tuple consisting of:
            1) New observations for each sub-env.
            2) Reward values for each sub-env.
            3) Done values for each sub-env.
            4) Info values for each sub-env.
        """
        raise NotImplementedError

    @PublicAPI
    def get_sub_environments(self) -> List[EnvType]:
        """Returns the underlying sub environments.

        Returns:
            List of all underlying sub environments.
        """
        return []

    # TODO: (sven) Experimental method. Make @PublicAPI at some point.
    def try_render_at(self, index: Optional[int] = None) -> Optional[np.ndarray]:
        """Renders a single environment.

        Args:
            index: An optional sub-env index to render.

        Returns:
            Either a numpy RGB image (shape=(w x h x 3), dtype=uint8) or
            None in case rendering is handled directly by this method.
        """
        pass

    @Deprecated(new="vectorize_gym_envs", error=False)
    def wrap(self, *args, **kwargs) -> "_VectorizedGymEnv":
        return self.vectorize_gym_envs(*args, **kwargs)

    @Deprecated(new="get_sub_environments", error=False)
    def get_unwrapped(self) -> List[EnvType]:
        return self.get_sub_environments()

    @PublicAPI
    def to_base_env(
        self,
        make_env: Optional[Callable[[int], EnvType]] = None,
        num_envs: int = 1,
        remote_envs: bool = False,
        remote_env_batch_wait_ms: int = 0,
        restart_failed_sub_environments: bool = False,
    ) -> "BaseEnv":
        """Converts this VectorEnv into a BaseEnv object.

        The resulting BaseEnv is always vectorized (contains n
        sub-environments) to support batched forward passes, where n may
        also be 1. BaseEnv also supports async execution via the `poll` and
        `send_actions` methods and thus supports external simulators.

        Args:
            make_env: A callable taking an int as input (the sub-env's vector
                index within the final vectorized BaseEnv) and returning one
                individual sub-environment.
            num_envs: The number of sub-environments to create in the
                resulting (vectorized) BaseEnv. This VectorEnv's already
                existing sub-envs will be among the `num_envs`.
            remote_envs: Whether each sub-env should be a @ray.remote
                actor. You can set this behavior in your config via the
                `remote_worker_envs=True` option.
            remote_env_batch_wait_ms: The wait time (in ms) to poll remote
                sub-environments for, if applicable. Only used if
                `remote_envs` is True.
            restart_failed_sub_environments: If True and any sub-environment
                throws an error during env stepping, try to restart the faulty
                sub-environment instead of crashing.

        Returns:
            The resulting BaseEnv object.
        """
        env = VectorEnvWrapper(self)
        return env
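

# A minimal sketch (not part of the original file) of a custom VectorEnv
# subclass, illustrating the `vector_reset`/`vector_step` contract defined
# above. `_ExampleEchoVectorEnv` is hypothetical: each sub-env simply returns
# its own vector index as observation and a constant reward.
class _ExampleEchoVectorEnv(VectorEnv):
    def __init__(self, num_envs: int = 2):
        super().__init__(
            observation_space=gym.spaces.Discrete(num_envs),
            action_space=gym.spaces.Discrete(2),
            num_envs=num_envs,
        )

    @override(VectorEnv)
    def vector_reset(self):
        # One initial observation per sub-env.
        return list(range(self.num_envs))

    @override(VectorEnv)
    def reset_at(self, index=None):
        return index if index is not None else 0

    @override(VectorEnv)
    def vector_step(self, actions):
        # Must return four equal-length lists: obs, rewards, dones, infos.
        obs = list(range(self.num_envs))
        rewards = [1.0] * self.num_envs
        dones = [False] * self.num_envs
        infos = [{} for _ in range(self.num_envs)]
        return obs, rewards, dones, infos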


class _VectorizedGymEnv(VectorEnv):
    """Internal wrapper to translate any gym.Envs into a VectorEnv object."""

    def __init__(
        self,
        make_env: Optional[Callable[[int], EnvType]] = None,
        existing_envs: Optional[List[gym.Env]] = None,
        num_envs: int = 1,
        *,
        observation_space: Optional[gym.Space] = None,
        action_space: Optional[gym.Space] = None,
        restart_failed_sub_environments: bool = False,
        # Deprecated. These seem to have never been used.
        env_config=None,
        policy_config=None,
    ):
        """Initializes a _VectorizedGymEnv object.

        Args:
            make_env: Factory that produces a new gym.Env, taking the sub-env's
                vector index as its only arg. Must be defined if the number of
                `existing_envs` is less than `num_envs`.
            existing_envs: Optional list of already instantiated sub
                environments.
            num_envs: Total number of sub environments in this VectorEnv.
            action_space: The action space. If None, use existing_envs[0]'s
                action space.
            observation_space: The observation space. If None, use
                existing_envs[0]'s observation space.
            restart_failed_sub_environments: If True and any sub-environment
                (within a vectorized env) throws any error during env stepping,
                we will try to restart the faulty sub-environment. This is done
                without disturbing the other (still intact) sub-environments.
        """
        self.envs = existing_envs if existing_envs is not None else []
        self.make_env = make_env
        self.restart_failed_sub_environments = restart_failed_sub_environments

        # Fill up missing envs (so we have exactly `num_envs` sub-envs in this
        # VectorEnv).
        while len(self.envs) < num_envs:
            self.envs.append(make_env(len(self.envs)))

        super().__init__(
            observation_space=observation_space or self.envs[0].observation_space,
            action_space=action_space or self.envs[0].action_space,
            num_envs=num_envs,
        )

    @override(VectorEnv)
    def vector_reset(self):
        # Use `reset_at(index)` (which may restart the sub-env on failure) and
        # retry until we receive an actual observation.
        resetted_obs = []
        for i in range(len(self.envs)):
            while True:
                obs = self.reset_at(i)
                if not isinstance(obs, Exception):
                    break
            resetted_obs.append(obs)
        return resetted_obs

    @override(VectorEnv)
    def reset_at(self, index: Optional[int] = None) -> EnvObsType:
        if index is None:
            index = 0
        try:
            obs = self.envs[index].reset()
        except Exception as e:
            if self.restart_failed_sub_environments:
                logger.exception(e.args[0])
                self.restart_at(index)
                obs = e
            else:
                raise e
        return obs

    @override(VectorEnv)
    def restart_at(self, index: Optional[int] = None) -> None:
        if index is None:
            index = 0

        # Try closing down the old (possibly faulty) sub-env, but ignore errors.
        try:
            self.envs[index].close()
        except Exception as e:
            if log_once("close_sub_env"):
                logger.warning(
                    "Trying to close old and replaced sub-environment (at vector "
                    f"index={index}), but closing resulted in error:\n{e}"
                )

        # Re-create the sub-env at the same index.
        logger.warning(f"Trying to restart sub-environment at index {index}.")
        self.envs[index] = self.make_env(index)
        logger.warning(f"Sub-environment at index {index} restarted successfully.")

    @override(VectorEnv)
    def vector_step(self, actions):
        obs_batch, rew_batch, done_batch, info_batch = [], [], [], []
        for i in range(self.num_envs):
            try:
                obs, r, done, info = self.envs[i].step(actions[i])
            except Exception as e:
                if self.restart_failed_sub_environments:
                    logger.exception(e.args[0])
                    self.restart_at(i)
                    obs, r, done, info = e, 0.0, True, {}
                else:
                    raise e

            if not isinstance(info, dict):
                raise ValueError(
                    "Info should be a dict, got {} ({})".format(info, type(info))
                )
            obs_batch.append(obs)
            rew_batch.append(r)
            done_batch.append(done)
            info_batch.append(info)
        return obs_batch, rew_batch, done_batch, info_batch

    @override(VectorEnv)
    def get_sub_environments(self):
        return self.envs

    @override(VectorEnv)
    def try_render_at(self, index: Optional[int] = None):
        if index is None:
            index = 0
        return self.envs[index].render()
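

# Sketch (not part of the original file) of the fault-tolerance path above:
# with `restart_failed_sub_environments=True`, a sub-env that raises in
# `step()` is closed, re-created via `make_env`, and reported with done=True
# (its obs slot holds the raised exception). `_FaultyOnceEnv` is a
# hypothetical gym.Env used only for illustration.
def _example_restart_failed_sub_env():
    class _FaultyOnceEnv(gym.Env):
        observation_space = gym.spaces.Discrete(2)
        action_space = gym.spaces.Discrete(2)

        def __init__(self):
            self.failed = False

        def reset(self):
            return 0

        def step(self, action):
            if not self.failed:
                self.failed = True
                raise RuntimeError("simulated sub-env crash")
            return 1, 0.0, False, {}

    vec_env = VectorEnv.vectorize_gym_envs(
        make_env=lambda idx: _FaultyOnceEnv(),
        num_envs=2,
        restart_failed_sub_environments=True,
    )
    vec_env.vector_reset()
    # Both sub-envs raise on their first step -> both are restarted in place.
    obs, rewards, dones, infos = vec_env.vector_step([0, 0])
    assert all(dones) and all(isinstance(o, Exception) for o in obs)
    return obs, rewards, dones, infos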


@PublicAPI
class VectorEnvWrapper(BaseEnv):
    """Internal adapter of VectorEnv to BaseEnv.

    We assume the caller will always send the full vector of actions in each
    call to send_actions(), and that they call reset_at() on all completed
    environments before calling send_actions().
    """

    def __init__(self, vector_env: VectorEnv):
        self.vector_env = vector_env
        self.num_envs = vector_env.num_envs
        self._observation_space = vector_env.observation_space
        self._action_space = vector_env.action_space
        # Sub-environments' states.
        self.new_obs = None
        self.cur_rewards = None
        self.cur_dones = None
        self.cur_infos = None
        # At first `poll()`, reset everything (all sub-environments).
        self.first_reset_done = False
        # Initialize sub-environments' state.
        self._init_env_state(idx=None)

    @override(BaseEnv)
    def poll(
        self,
    ) -> Tuple[MultiEnvDict, MultiEnvDict, MultiEnvDict, MultiEnvDict, MultiEnvDict]:
        from ray.rllib.env.base_env import with_dummy_agent_id

        if not self.first_reset_done:
            self.first_reset_done = True
            self.new_obs = self.vector_env.vector_reset()
        new_obs = dict(enumerate(self.new_obs))
        rewards = dict(enumerate(self.cur_rewards))
        dones = dict(enumerate(self.cur_dones))
        infos = dict(enumerate(self.cur_infos))

        # Empty all states (in case `poll()` gets called again).
        self.new_obs = []
        self.cur_rewards = []
        self.cur_dones = []
        self.cur_infos = []

        return (
            with_dummy_agent_id(new_obs),
            with_dummy_agent_id(rewards),
            with_dummy_agent_id(dones, "__all__"),
            with_dummy_agent_id(infos),
            {},
        )
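
    # Illustration (not from the original file) of `poll()`'s return format
    # for a 2-sub-env wrapper: the first four dicts map env index to a
    # per-agent dict keyed by the dummy agent ID (or "__all__" for dones):
    #   obs   = {0: {_DUMMY_AGENT_ID: obs0}, 1: {_DUMMY_AGENT_ID: obs1}}
    #   dones = {0: {"__all__": False}, 1: {"__all__": False}}
    # The trailing empty dict is the (unused here) off-policy-actions slot.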

    @override(BaseEnv)
    def send_actions(self, action_dict: MultiEnvDict) -> None:
        from ray.rllib.env.base_env import _DUMMY_AGENT_ID

        action_vector = [None] * self.num_envs
        for i in range(self.num_envs):
            action_vector[i] = action_dict[i][_DUMMY_AGENT_ID]
        (
            self.new_obs,
            self.cur_rewards,
            self.cur_dones,
            self.cur_infos,
        ) = self.vector_env.vector_step(action_vector)

    @override(BaseEnv)
    def try_reset(self, env_id: Optional[EnvID] = None) -> MultiEnvDict:
        from ray.rllib.env.base_env import _DUMMY_AGENT_ID

        assert env_id is None or isinstance(env_id, int)
        env_id = env_id if env_id is not None else 0
        obs = self.vector_env.reset_at(env_id)
        return {env_id: obs if isinstance(obs, Exception) else {_DUMMY_AGENT_ID: obs}}

    @override(BaseEnv)
    def try_restart(self, env_id: Optional[EnvID] = None) -> None:
        assert env_id is None or isinstance(env_id, int)
        # Restart the sub-env at the given index.
        self.vector_env.restart_at(env_id)
        # Auto-reset its state (get ready for the next `poll()`).
        self._init_env_state(env_id)

    @override(BaseEnv)
    def get_sub_environments(
        self, as_dict: bool = False
    ) -> Union[List[EnvType], dict]:
        if not as_dict:
            return self.vector_env.get_sub_environments()
        else:
            return {
                _id: env
                for _id, env in enumerate(self.vector_env.get_sub_environments())
            }

    @override(BaseEnv)
    def try_render(self, env_id: Optional[EnvID] = None) -> None:
        assert env_id is None or isinstance(env_id, int)
        return self.vector_env.try_render_at(env_id)

    @property
    @override(BaseEnv)
    @PublicAPI
    def observation_space(self) -> gym.Space:
        return self._observation_space

    @property
    @override(BaseEnv)
    @PublicAPI
    def action_space(self) -> gym.Space:
        return self._action_space

    @override(BaseEnv)
    @PublicAPI
    def action_space_sample(self, agent_id: list = None) -> MultiEnvDict:
        del agent_id
        return {0: {_DUMMY_AGENT_ID: self._action_space.sample()}}

    @override(BaseEnv)
    @PublicAPI
    def observation_space_sample(self, agent_id: list = None) -> MultiEnvDict:
        del agent_id
        return {0: {_DUMMY_AGENT_ID: self._observation_space.sample()}}

    @override(BaseEnv)
    @PublicAPI
    def get_agent_ids(self) -> Set[AgentID]:
        return {_DUMMY_AGENT_ID}

    def _init_env_state(self, idx: Optional[int] = None) -> None:
        """Resets all or one particular sub-environment's state (by index).

        Args:
            idx: The index to reset at. If None, reset all the
                sub-environments' states.
        """
        # If index is None, reset all sub-envs' states.
        if idx is None:
            self.new_obs = [None for _ in range(self.num_envs)]
            self.cur_rewards = [0.0 for _ in range(self.num_envs)]
            self.cur_dones = [False for _ in range(self.num_envs)]
            self.cur_infos = [{} for _ in range(self.num_envs)]
        # Index provided: reset only the sub-env's state at the given index.
        else:
            self.new_obs[idx] = self.vector_env.reset_at(idx)
            # Reset all other states to null values.
            self.cur_rewards[idx] = 0.0
            self.cur_dones[idx] = False
            self.cur_infos[idx] = {}
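

# Demo sketch (not part of the original module): driving a vectorized gym env
# through the BaseEnv interface via `poll()` / `send_actions()`. Assumes the
# standard gym env "CartPole-v1" is available and that no sub-env terminates
# within the few steps taken (a real loop would `try_reset()` done sub-envs).
if __name__ == "__main__":
    base_env = VectorEnv.vectorize_gym_envs(
        make_env=lambda idx: gym.make("CartPole-v1"),
        num_envs=2,
    ).to_base_env()
    for _ in range(3):
        obs, rewards, dones, infos, _ = base_env.poll()
        # Send one action per sub-env, keyed by env index and dummy agent ID.
        actions = {
            env_id: {_DUMMY_AGENT_ID: base_env.action_space.sample()}
            for env_id in obs.keys()
        }
        base_env.send_actions(actions)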