import logging

import numpy as np
from gym.spaces import Discrete

from ray.rllib.env.base_env import BaseEnv, convert_to_base_env
from ray.rllib.env.vector_env import VectorEnv
from ray.rllib.evaluation.rollout_worker import get_global_worker
from ray.rllib.utils.annotations import override
from ray.rllib.utils.typing import EnvType

logger = logging.getLogger(__name__)


def model_vector_env(env: EnvType) -> BaseEnv:
    """Returns a VectorizedEnv wrapper around the given environment.

    To obtain worker configs, one can call get_global_worker().

    Args:
        env: The input environment (of any supported environment
            type) to be converted to a _VectorizedModelGymEnv (wrapped
            as an RLlib BaseEnv).

    Returns:
        The input `env`, converted to an RLlib BaseEnv.
    """
    worker = get_global_worker()
    worker_index = worker.worker_index
    # Only wrap on remote rollout workers (worker_index > 0); the local
    # worker (index 0) keeps the plain env.
    if worker_index:
        env = _VectorizedModelGymEnv(
            make_env=worker.make_sub_env_fn,
            existing_envs=[env],
            num_envs=worker.num_envs,
            observation_space=env.observation_space,
            action_space=env.action_space,
        )
    return convert_to_base_env(
        env,
        make_env=worker.make_sub_env_fn,
        num_envs=worker.num_envs,
        remote_envs=False,
        remote_env_batch_wait_ms=0,
    )
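

# Usage sketch (illustrative): this wrapper is meant to be installed by a
# model-based algorithm such as MB-MPO rather than called directly. The
# config key below (`custom_vector_env`) mirrors how MBMPO's default config
# is assumed to wire it up; the exact key and import path may differ across
# RLlib versions:
#
#     config = {
#         "custom_vector_env": model_vector_env,
#         "num_envs_per_worker": 20,
#     }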


class _VectorizedModelGymEnv(VectorEnv):
    """Vectorized environment wrapper for MB-MPO.

    The primary change is in the `vector_step` method, which calls the
    dynamics models for the next_obs "calculation" (instead of stepping the
    actual env). Also, the actual envs need to implement two extra methods:
    `reward(obs_batch, action_batch, next_obs_batch)` and (optionally)
    `done(next_obs_batch)`, both operating on batches. If `done` is not
    implemented, we assume that episodes in the env never terminate (see
    the illustrative `_ExampleModelBasedEnv` sketch at the end of this
    file).
    """

    def __init__(
        self,
        make_env=None,
        existing_envs=None,
        num_envs=1,
        *,
        observation_space=None,
        action_space=None,
        env_config=None,  # Accepted for call compatibility; unused here.
    ):
        self.make_env = make_env
        self.envs = existing_envs
        self.num_envs = num_envs
        # Create any missing sub-envs until we have `num_envs` of them.
        while len(self.envs) < num_envs:
            self.envs.append(self.make_env(len(self.envs)))
        # Last seen (real or predicted) observations; set by `vector_reset`.
        self.cur_obs = None

        super().__init__(
            observation_space=observation_space or self.envs[0].observation_space,
            action_space=action_space or self.envs[0].action_space,
            num_envs=num_envs,
        )
        # Fetch the dynamics model and its device from this worker's policy.
        worker = get_global_worker()
        self.model, self.device = worker.foreach_policy(
            lambda x, y: (x.dynamics_model, x.device)
        )[0]

    @override(VectorEnv)
    def vector_reset(self):
        """Override parent to store actual env obs for upcoming predictions."""
        self.cur_obs = [e.reset() for e in self.envs]
        return self.cur_obs

    @override(VectorEnv)
    def reset_at(self, index):
        """Override parent to store actual env obs for upcoming predictions."""
        obs = self.envs[index].reset()
        self.cur_obs[index] = obs
        return obs

    @override(VectorEnv)
    def vector_step(self, actions):
        if self.cur_obs is None:
            raise ValueError("Need to reset env first")

        # If discrete, one-hot encode the actions, e.g. for n=3:
        # actions=[2, 0] -> [[0.0, 0.0, 1.0], [1.0, 0.0, 0.0]].
        if isinstance(self.action_space, Discrete):
            act = np.array(actions)
            # Use the space's size (not `act.max() + 1`) so the one-hot width
            # does not depend on which actions happen to be in this batch.
            new_act = np.zeros((act.size, self.action_space.n))
            new_act[np.arange(act.size), act] = 1
            actions = new_act.astype("float32")

        # Batch the TD-model prediction.
        obs_batch = np.stack(self.cur_obs, axis=0)
        action_batch = np.stack(actions, axis=0)
        # Predict the next observation, given previous a) real obs
        # (after a reset), b) predicted obs (any other time).
        next_obs_batch = self.model.predict_model_batches(
            obs_batch, action_batch, device=self.device
        )
        # Guard against diverging model predictions.
        next_obs_batch = np.clip(next_obs_batch, -1000, 1000)

        # Call env's reward function.
        # Note: Each actual env must implement one to output exact rewards.
        rew_batch = self.envs[0].reward(obs_batch, action_batch, next_obs_batch)

        # If env has a `done` method, use it.
        if hasattr(self.envs[0], "done"):
            dones_batch = self.envs[0].done(next_obs_batch)
        # Otherwise, assume the episode does not end.
        else:
            dones_batch = np.asarray([False for _ in range(self.num_envs)])

        info_batch = [{} for _ in range(self.num_envs)]

        # The predicted obs become the inputs of the next step's prediction.
        self.cur_obs = next_obs_batch

        return list(next_obs_batch), list(rew_batch), list(dones_batch), info_batch

    @override(VectorEnv)
    def get_sub_environments(self):
        return self.envs
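

# ---------------------------------------------------------------------------
# Illustrative sketch only (not part of RLlib's API): a minimal example of
# the `reward()` / `done()` contract that `_VectorizedModelGymEnv` expects
# on the real envs. The class name and the reward/termination rules below
# are hypothetical.
# ---------------------------------------------------------------------------
class _ExampleModelBasedEnv:
    """Example env exposing batched `reward()` and (optional) `done()`."""

    def reward(self, obs_batch, action_batch, next_obs_batch):
        # Hypothetical reward: a negative squared-action penalty, returning
        # one scalar per sub-env (shape [num_envs]).
        return -np.sum(np.square(action_batch), axis=-1)

    def done(self, next_obs_batch):
        # Hypothetical termination: end the episode once any component of
        # the predicted obs leaves [-100, 100] (shape [num_envs] bools).
        return np.any(np.abs(next_obs_batch) > 100.0, axis=-1)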