ray/rllib/env/vector_env.py
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import logging
import numpy as np

from ray.rllib.utils.annotations import override, PublicAPI

logger = logging.getLogger(__name__)

@PublicAPI
class VectorEnv:
    """An environment that supports batch evaluation.

    Subclasses must define the following attributes:

    Attributes:
        action_space (gym.Space): Action space of individual envs.
        observation_space (gym.Space): Observation space of individual envs.
        num_envs (int): Number of envs in this vector env.
    """
    @staticmethod
    def wrap(make_env=None,
             existing_envs=None,
             num_envs=1,
             action_space=None,
             observation_space=None):
        """Batches the given gym envs into a single VectorEnv.

        Additional envs are created via make_env() if fewer than num_envs
        existing envs are passed in.
        """
        return _VectorizedGymEnv(make_env, existing_envs or [], num_envs,
                                 action_space, observation_space)

    @PublicAPI
    def vector_reset(self):
        """Resets all environments.

        Returns:
            obs (list): Vector of observations from each environment.
        """
        raise NotImplementedError

    @PublicAPI
    def reset_at(self, index):
        """Resets a single environment.

        Arguments:
            index (int): Index of the environment to reset.

        Returns:
            obs (obj): Observations from the reset environment.
        """
        raise NotImplementedError

    @PublicAPI
    def vector_step(self, actions):
        """Vectorized step.

        Arguments:
            actions (list): Actions for each env.

        Returns:
            obs (list): New observations for each env.
            rewards (list): Reward values for each env.
            dones (list): Done values for each env.
            infos (list): Info values for each env.
        """
        raise NotImplementedError

    @PublicAPI
    def get_unwrapped(self):
        """Returns the underlying env instances."""
        raise NotImplementedError
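

# --- Illustrative sketch (not part of the original RLlib module) ------------
# The class below only demonstrates the VectorEnv contract documented above:
# vector_reset()/vector_step() consume and produce lists indexed by
# sub-environment. The name _ConstantVectorEnv and its constant observation
# are hypothetical; real subclasses must also set the gym.Space attributes
# action_space and observation_space (omitted here for brevity).
class _ConstantVectorEnv(VectorEnv):
    def __init__(self, num_envs):
        self.num_envs = num_envs

    def vector_reset(self):
        # One observation per sub-environment.
        return [0.0] * self.num_envs

    def reset_at(self, index):
        return 0.0

    def vector_step(self, actions):
        # Each returned batch is a list of length num_envs.
        obs = [0.0] * len(actions)
        rewards = [1.0] * len(actions)
        dones = [False] * len(actions)
        infos = [{} for _ in actions]
        return obs, rewards, dones, infos

    def get_unwrapped(self):
        return []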


class _VectorizedGymEnv(VectorEnv):
    """Internal wrapper for gym envs to implement VectorEnv.

    Arguments:
        make_env (func|None): Factory that produces a new gym env. Must be
            defined if the number of existing envs is less than num_envs.
        existing_envs (list): List of existing gym envs.
        num_envs (int): Desired num gym envs to keep total.
        action_space (gym.Space|None): Action space. If None, inferred from
            the first env.
        observation_space (gym.Space|None): Observation space. If None,
            inferred from the first env.
    """

    def __init__(self,
                 make_env,
                 existing_envs,
                 num_envs,
                 action_space=None,
                 observation_space=None):
        self.make_env = make_env
        self.envs = existing_envs
        self.num_envs = num_envs
        # Create additional envs via the factory until num_envs are present.
        while len(self.envs) < self.num_envs:
            self.envs.append(self.make_env(len(self.envs)))
        # Fall back to the spaces of the first sub-env if none were given.
        self.action_space = action_space or self.envs[0].action_space
        self.observation_space = observation_space or \
            self.envs[0].observation_space

    @override(VectorEnv)
    def vector_reset(self):
        return [e.reset() for e in self.envs]

    @override(VectorEnv)
    def reset_at(self, index):
        return self.envs[index].reset()

    @override(VectorEnv)
    def vector_step(self, actions):
        obs_batch, rew_batch, done_batch, info_batch = [], [], [], []
        for i in range(self.num_envs):
            obs, r, done, info = self.envs[i].step(actions[i])
            # Sanity-check the per-env step outputs before batching them.
            if not np.isscalar(r) or not np.isreal(r) or not np.isfinite(r):
                raise ValueError(
                    "Reward should be finite scalar, got {} ({})".format(
                        r, type(r)))
            if type(info) is not dict:
                raise ValueError("Info should be a dict, got {} ({})".format(
                    info, type(info)))
            obs_batch.append(obs)
            rew_batch.append(r)
            done_batch.append(done)
            info_batch.append(info)
        return obs_batch, rew_batch, done_batch, info_batch

    @override(VectorEnv)
    def get_unwrapped(self):
        return self.envs
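

# --- Minimal usage sketch (illustrative only, not in the original module) ---
# Shows how VectorEnv.wrap() batches a few gym envs and steps them in
# lockstep. Assumes the `gym` package and its "CartPole-v0" env are installed.
if __name__ == "__main__":
    import gym

    env = VectorEnv.wrap(
        make_env=lambda env_index: gym.make("CartPole-v0"), num_envs=2)
    first_obs = env.vector_reset()
    # Sample one action per sub-environment and take a single batched step.
    actions = [env.action_space.sample() for _ in range(env.num_envs)]
    obs, rewards, dones, infos = env.vector_step(actions)
    print(rewards, dones)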