ray/rllib/env/vector_env.py

import logging

import numpy as np

from ray.rllib.utils.annotations import override, PublicAPI

logger = logging.getLogger(__name__)


@PublicAPI
class VectorEnv:
"""An environment that supports batch evaluation.
Subclasses must define the following attributes:
Attributes:
action_space (gym.Space): Action space of individual envs.
observation_space (gym.Space): Observation space of individual envs.
num_envs (int): Number of envs in this vector env.
"""

    @staticmethod
    def wrap(make_env=None,
             existing_envs=None,
             num_envs=1,
             action_space=None,
             observation_space=None):
        """Wraps existing gym envs (plus a factory) into a VectorEnv.

        Any envs missing up to num_envs are created by calling make_env.
        """
        return _VectorizedGymEnv(make_env, existing_envs or [], num_envs,
                                 action_space, observation_space)

    @PublicAPI
    def vector_reset(self):
        """Resets all environments.

        Returns:
            obs (list): Vector of observations from each environment.
        """
        raise NotImplementedError

    @PublicAPI
    def reset_at(self, index):
        """Resets a single environment.

        Arguments:
            index (int): Index of the environment to reset.

        Returns:
            obs (obj): Observations from the reset environment.
        """
        raise NotImplementedError

    @PublicAPI
    def vector_step(self, actions):
        """Vectorized step.

        Arguments:
            actions (list): Actions for each env.

        Returns:
            obs (list): New observations for each env.
            rewards (list): Reward values for each env.
            dones (list): Done values for each env.
            infos (list): Info values for each env.
        """
        raise NotImplementedError

    @PublicAPI
    def get_unwrapped(self):
        """Returns the underlying env instances."""
        raise NotImplementedError
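

# A minimal sketch of a custom VectorEnv subclass (hypothetical, for
# illustration only; not part of the RLlib API). It simulates `num_envs`
# independent copies of a trivial counting env in lock-step, assuming
# gym-style Discrete spaces.
class _ExampleCounterVectorEnv(VectorEnv):
    def __init__(self, num_envs):
        import gym

        self.action_space = gym.spaces.Discrete(2)
        self.observation_space = gym.spaces.Discrete(10)
        self.num_envs = num_envs
        self.counters = [0] * num_envs

    @override(VectorEnv)
    def vector_reset(self):
        self.counters = [0] * self.num_envs
        return list(self.counters)

    @override(VectorEnv)
    def reset_at(self, index):
        self.counters[index] = 0
        return 0

    @override(VectorEnv)
    def vector_step(self, actions):
        obs, rewards, dones, infos = [], [], [], []
        for i, action in enumerate(actions):
            # Action 1 advances the counter; an env is done once it hits 9.
            self.counters[i] = min(self.counters[i] + int(action), 9)
            obs.append(self.counters[i])
            rewards.append(1.0 if self.counters[i] == 9 else 0.0)
            dones.append(self.counters[i] == 9)
            infos.append({})
        return obs, rewards, dones, infos

    @override(VectorEnv)
    def get_unwrapped(self):
        return []  # No underlying gym envs in this toy example.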


class _VectorizedGymEnv(VectorEnv):
    """Internal wrapper to implement VectorEnv on top of gym envs.

    Arguments:
        make_env (func|None): Factory that produces a new gym env. Must be
            defined if the number of existing envs is less than num_envs.
        existing_envs (list): List of existing gym envs.
        num_envs (int): Desired total number of gym envs to keep.
    """

    def __init__(self,
                 make_env,
                 existing_envs,
                 num_envs,
                 action_space=None,
                 observation_space=None):
        self.make_env = make_env
        self.envs = existing_envs
        self.num_envs = num_envs
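        # Create any missing envs by calling the factory with the index
        # of the env to be created.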
while len(self.envs) < self.num_envs:
self.envs.append(self.make_env(len(self.envs)))
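        # Fall back to the first env's spaces if none are given explicitly.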
self.action_space = action_space or self.envs[0].action_space
self.observation_space = observation_space or \
self.envs[0].observation_space

    @override(VectorEnv)
    def vector_reset(self):
        return [e.reset() for e in self.envs]

    @override(VectorEnv)
    def reset_at(self, index):
        return self.envs[index].reset()

    @override(VectorEnv)
    def vector_step(self, actions):
        obs_batch, rew_batch, done_batch, info_batch = [], [], [], []
        for i in range(self.num_envs):
            obs, r, done, info = self.envs[i].step(actions[i])
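            # Sanity-check env outputs: rewards must be finite scalars and
            # infos must be dicts, or downstream batching will break.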
if not np.isscalar(r) or not np.isreal(r) or not np.isfinite(r):
raise ValueError(
"Reward should be finite scalar, got {} ({})".format(
r, type(r)))
if type(info) is not dict:
raise ValueError("Info should be a dict, got {} ({})".format(
info, type(info)))
obs_batch.append(obs)
rew_batch.append(r)
done_batch.append(done)
info_batch.append(info)
return obs_batch, rew_batch, done_batch, info_batch

    @override(VectorEnv)
    def get_unwrapped(self):
        return self.envs
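

# A minimal usage sketch (an assumption-laden example, not part of this
# module): it assumes the `gym` package with the legacy 4-tuple `step()`
# API that this file was written against, plus the "CartPole-v0" env id.
# Guarded so it only runs when executed directly, never on import.
if __name__ == "__main__":
    import gym

    vector_env = VectorEnv.wrap(
        make_env=lambda index: gym.make("CartPole-v0"), num_envs=4)
    obs = vector_env.vector_reset()
    # Step all envs with randomly sampled actions.
    actions = [vector_env.action_space.sample() for _ in range(4)]
    obs, rewards, dones, infos = vector_env.vector_step(actions)
    # Reset any env whose episode just finished.
    for i, done in enumerate(dones):
        if done:
            obs[i] = vector_env.reset_at(i)
    print("rewards:", rewards, "dones:", dones)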