# Source: ray/rllib/env/wrappers/model_vector_env.py
# (Listing metadata: 135 lines, 4.6 KiB, Python)

import logging
import numpy as np
from gym.spaces import Discrete
from ray.rllib.utils.annotations import override
from ray.rllib.env.vector_env import VectorEnv
from ray.rllib.evaluation.rollout_worker import get_global_worker
from ray.rllib.env.base_env import BaseEnv, convert_to_base_env
from ray.rllib.utils.typing import EnvType
logger = logging.getLogger(__name__)
def model_vector_env(env: EnvType) -> BaseEnv:
    """Convert `env` into an RLlib BaseEnv, model-vectorizing it if needed.

    The worker settings (sub-env factory, number of envs, worker index) are
    read from the global rollout worker, obtained via `get_global_worker()`.

    Args:
        env: The input environment (of any supported environment type).
            When the global worker's `worker_index` is truthy, it is first
            wrapped in a `_VectorizedModelGymEnv`; otherwise it is passed on
            unchanged.

    Returns:
        BaseEnv: The result of `convert_to_base_env` for the (possibly
        wrapped) input `env`.
    """
    rollout_worker = get_global_worker()

    env_to_convert = env
    # A falsy index (0/None) skips the model-based wrapper.
    # NOTE(review): presumably index 0 denotes the local worker — confirm.
    if rollout_worker.worker_index:
        env_to_convert = _VectorizedModelGymEnv(
            make_env=rollout_worker.make_sub_env_fn,
            existing_envs=[env],
            num_envs=rollout_worker.num_envs,
            observation_space=env.observation_space,
            action_space=env.action_space,
        )

    # Remote sub-envs are explicitly disabled for this conversion.
    conversion_kwargs = dict(
        make_env=rollout_worker.make_sub_env_fn,
        num_envs=rollout_worker.num_envs,
        remote_envs=False,
        remote_env_batch_wait_ms=0,
    )
    return convert_to_base_env(env_to_convert, **conversion_kwargs)
class _VectorizedModelGymEnv(VectorEnv):
    """Vectorized Environment Wrapper for MB-MPO.

    Primary change is in the `vector_step` method, which calls the dynamics
    model for next_obs "calculation" (instead of the actual env). Also, the
    actual envs need to have two extra methods implemented:
    `reward(obs_batch, action_batch, next_obs_batch)` and (optionally)
    `done(next_obs_batch)`. Both are called on the first sub-env only and
    receive whole batches. If `done` is not implemented, we will assume
    that episodes in the env do not terminate, ever.
    """

    def __init__(
        self,
        make_env=None,
        existing_envs=None,
        num_envs=1,
        *,
        observation_space=None,
        action_space=None,
        env_config=None
    ):
        """Collects/creates the sub-envs and fetches the dynamics model.

        Args:
            make_env: Callable that is invoked with the index of the sub-env
                to create. Only needed if fewer than `num_envs` existing
                envs are passed in.
            existing_envs: Optional list of already created sub-envs. The
                list is used as-is (not copied) and is extended in place
                until it holds `num_envs` entries.
            num_envs: Total number of sub-envs to hold.
            observation_space: Observation space to report; defaults to the
                first sub-env's `observation_space`.
            action_space: Action space to report; defaults to the first
                sub-env's `action_space`.
            env_config: Not used by this class.

        Raises:
            ValueError: If more sub-envs are required than given and no
                `make_env` callable was provided.
        """
        self.make_env = make_env
        # Guard against the `None` default (`len(None)` would raise).
        self.envs = existing_envs if existing_envs is not None else []
        self.num_envs = num_envs
        # Must exist before the first reset so that `vector_step` can detect
        # the missing reset instead of failing with an AttributeError.
        self.cur_obs = None

        if len(self.envs) < num_envs and self.make_env is None:
            raise ValueError(
                "`make_env` must be provided when fewer than `num_envs` "
                "existing envs are given"
            )
        while len(self.envs) < num_envs:
            self.envs.append(self.make_env(len(self.envs)))

        # Explicit `is None` checks: a valid space object may be falsy
        # (e.g. an empty container space).
        if observation_space is None:
            observation_space = self.envs[0].observation_space
        if action_space is None:
            action_space = self.envs[0].action_space
        super().__init__(
            observation_space=observation_space,
            action_space=action_space,
            num_envs=num_envs,
        )

        # Use the dynamics model (and its device) of the first policy
        # returned by the global worker.
        worker = get_global_worker()
        self.model, self.device = worker.foreach_policy(
            lambda policy, _policy_id: (policy.dynamics_model, policy.device)
        )[0]

    @override(VectorEnv)
    def vector_reset(self):
        """Override parent to store actual env obs for upcoming predictions."""
        self.cur_obs = [e.reset() for e in self.envs]
        return self.cur_obs

    @override(VectorEnv)
    def reset_at(self, index):
        """Override parent to store actual env obs for upcoming predictions.

        Raises:
            ValueError: If no `vector_reset` has happened yet, i.e. there is
                no observation batch to update.
        """
        if self.cur_obs is None:
            raise ValueError("Need to reset env first")
        obs = self.envs[index].reset()
        self.cur_obs[index] = obs
        return obs

    @override(VectorEnv)
    def vector_step(self, actions):
        """Steps all sub-envs using the dynamics model instead of the envs.

        Args:
            actions: One action per sub-env.

        Returns:
            Tuple of lists (next observations, rewards, dones, infos). The
            infos are always empty dicts, one per sub-env.

        Raises:
            ValueError: If the env has not been reset yet.
        """
        if self.cur_obs is None:
            raise ValueError("Need to reset env first")

        # If discrete, need to one-hot actions.
        if isinstance(self.action_space, Discrete):
            act = np.asarray(actions)
            # The width must be the size of the action space, not the
            # largest action in this batch; otherwise the encoding width
            # would vary from batch to batch (and an empty batch would fail).
            # `start` only exists in gym versions with offset support.
            offset = getattr(self.action_space, "start", 0)
            one_hot = np.zeros((act.size, self.action_space.n), dtype="float32")
            one_hot[np.arange(act.size), act - offset] = 1
            actions = one_hot

        # Batch the dynamics-model prediction.
        obs_batch = np.stack(self.cur_obs, axis=0)
        action_batch = np.stack(actions, axis=0)

        # Predict the next observation, given previous a) real obs
        # (after a reset), b) predicted obs (any other time).
        next_obs_batch = self.model.predict_model_batches(
            obs_batch, action_batch, device=self.device
        )
        # NOTE(review): presumably bounds diverging model predictions —
        # confirm the +/-1000 limits.
        next_obs_batch = np.clip(next_obs_batch, -1000, 1000)

        # Call env's reward function.
        # Note: Each actual env must implement one to output exact rewards.
        rew_batch = self.envs[0].reward(obs_batch, action_batch, next_obs_batch)

        # If env has a `done` method, use it.
        if hasattr(self.envs[0], "done"):
            dones_batch = self.envs[0].done(next_obs_batch)
        # Otherwise, assume the episode does not end.
        else:
            dones_batch = np.zeros(self.num_envs, dtype=bool)

        info_batch = [{} for _ in range(self.num_envs)]

        # Predicted obs become the inputs of the next prediction.
        self.cur_obs = next_obs_batch

        return list(next_obs_batch), list(rew_batch), list(dones_batch), info_batch

    @override(VectorEnv)
    def get_sub_environments(self):
        """Returns the list of underlying (actual) sub-environments."""
        return self.envs