ray/rllib/env/wrappers/model_vector_env.py

import logging
import numpy as np
from gym.spaces import Discrete
from ray.rllib.utils.annotations import override
from ray.rllib.env.vector_env import VectorEnv
from ray.rllib.evaluation.rollout_worker import get_global_worker
from ray.rllib.env.base_env import BaseEnv, convert_to_base_env
from ray.rllib.utils.typing import EnvType

logger = logging.getLogger(__name__)


def model_vector_env(env: EnvType) -> BaseEnv:
    """Returns a VectorizedEnv wrapper around the given environment.

    Only remote workers (`worker_index` > 0) are wrapped into a
    _VectorizedModelGymEnv (which steps a learned dynamics model); the local
    worker (index 0) keeps the actual env. To obtain worker configs, one can
    call get_global_worker().

    Args:
        env (EnvType): The input environment (of any supported environment
            type) to be converted into a _VectorizedModelGymEnv (wrapped as
            an RLlib BaseEnv).

    Returns:
        BaseEnv: The input `env`, converted to a BaseEnv.
    """
    worker = get_global_worker()
    worker_index = worker.worker_index
    if worker_index:
        env = _VectorizedModelGymEnv(
            make_env=worker.make_sub_env_fn,
            existing_envs=[env],
            num_envs=worker.num_envs,
            observation_space=env.observation_space,
            action_space=env.action_space,
        )
    return convert_to_base_env(
        env,
        make_env=worker.make_sub_env_fn,
        num_envs=worker.num_envs,
        remote_envs=False,
        remote_env_batch_wait_ms=0,
    )
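
# Usage sketch (illustrative, not part of this module's API): `model_vector_env`
# is meant to be plugged into an algorithm config as the custom env vectorizer,
# so that remote rollout workers step the learned dynamics model instead of the
# real environment. Assuming MB-MPO's `custom_vector_env` config key:
#
#     config = MBMPO_DEFAULT_CONFIG.copy()  # hypothetical name for the default config
#     config["custom_vector_env"] = model_vector_env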


class _VectorizedModelGymEnv(VectorEnv):
    """Vectorized Environment Wrapper for MB-MPO.

    Primary change is in the `vector_step` method, which calls the dynamics
    models for next_obs "calculation" (instead of the actual env). Also, the
    actual envs need to implement two extra methods: a batched
    `reward(obs, action, next_obs)` and (optionally) a batched `done(obs)`.
    If `done` is not implemented, we will assume that episodes in the env
    never terminate (see the illustrative sketch at the end of this module).
    """

    def __init__(
        self,
        make_env=None,
        existing_envs=None,
        num_envs=1,
        *,
        observation_space=None,
        action_space=None,
        env_config=None
    ):
        self.make_env = make_env
        self.envs = existing_envs
        self.num_envs = num_envs
        # Last observed (or model-predicted) observations; set by
        # `vector_reset()`/`reset_at()` and updated by `vector_step()`.
        self.cur_obs = None
        while len(self.envs) < num_envs:
            self.envs.append(self.make_env(len(self.envs)))

        super().__init__(
            observation_space=observation_space or self.envs[0].observation_space,
            action_space=action_space or self.envs[0].action_space,
            num_envs=num_envs,
        )
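
        # Fetch the learned dynamics model (and its torch device) from this
        # worker's policy. `foreach_policy` yields one result per policy;
        # only the first policy's `(dynamics_model, device)` pair is used.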
        worker = get_global_worker()
        self.model, self.device = worker.foreach_policy(
            lambda x, y: (x.dynamics_model, x.device)
        )[0]

    @override(VectorEnv)
    def vector_reset(self):
        """Override parent to store actual env obs for upcoming predictions."""
        self.cur_obs = [e.reset() for e in self.envs]
        return self.cur_obs

    @override(VectorEnv)
    def reset_at(self, index):
        """Override parent to store actual env obs for upcoming predictions."""
        obs = self.envs[index].reset()
        self.cur_obs[index] = obs
        return obs

    @override(VectorEnv)
    def vector_step(self, actions):
        if self.cur_obs is None:
            raise ValueError("Need to reset env first")

        # If discrete, need to one-hot actions
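        # (e.g., with a Discrete(3) action space, actions [1, 0] become
        # [[0., 1., 0.], [1., 0., 0.]]).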
        if isinstance(self.action_space, Discrete):
            act = np.array(actions)
            # One-hot matrix of shape (batch_size, num_actions).
            new_act = np.zeros((act.size, self.action_space.n))
            new_act[np.arange(act.size), act] = 1
            actions = new_act.astype("float32")

        # Batch the TD-model prediction.
        obs_batch = np.stack(self.cur_obs, axis=0)
        action_batch = np.stack(actions, axis=0)
        # Predict the next observation, given previous a) real obs
        # (after a reset), b) predicted obs (any other time).
        next_obs_batch = self.model.predict_model_batches(
            obs_batch, action_batch, device=self.device
        )
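        # Clip the predicted next-obs to a fixed range; presumably this guards
        # against diverging model predictions overflowing downstream
        # computations such as the reward function.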
        next_obs_batch = np.clip(next_obs_batch, -1000, 1000)

        # Call env's reward function.
        # Note: Each actual env must implement one to output exact rewards.
        rew_batch = self.envs[0].reward(obs_batch, action_batch, next_obs_batch)

        # If env has a `done` method, use it.
        if hasattr(self.envs[0], "done"):
            dones_batch = self.envs[0].done(next_obs_batch)
        # Otherwise, assume the episode does not end.
        else:
            dones_batch = np.asarray([False for _ in range(self.num_envs)])

        info_batch = [{} for _ in range(self.num_envs)]

        self.cur_obs = next_obs_batch

        return list(next_obs_batch), list(rew_batch), list(dones_batch), info_batch

    @override(VectorEnv)
    def get_sub_environments(self):
        return self.envs
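

# Illustrative sketch (not part of RLlib): a sub-env compatible with
# _VectorizedModelGymEnv must expose a batched `reward()` method and may
# expose a batched `done()` method, roughly like this (class name, reward
# formula, and done threshold below are made up for illustration; only the
# method names and batched call signatures are dictated by `vector_step()`
# above):
#
#     class MyModelBasedEnv(gym.Env):
#         def reward(self, obs_batch, action_batch, next_obs_batch):
#             # One float reward per batch row.
#             return -np.sum(np.square(next_obs_batch), axis=1)
#
#         def done(self, next_obs_batch):
#             # One bool per batch row; omit this method entirely to get
#             # never-terminating episodes.
#             return np.abs(next_obs_batch[:, 0]) > 100.0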