import gym
from gym import wrappers
import os

from ray.rllib.env.env_context import EnvContext
from ray.rllib.env.multi_agent_env import MultiAgentEnv
from ray.rllib.utils import add_mixins
from ray.rllib.utils.error import ERR_MSG_INVALID_ENV_DESCRIPTOR, EnvError


def gym_env_creator(env_context: EnvContext, env_descriptor: str) -> gym.Env:
    """Tries to create a gym env given an EnvContext object and descriptor.

    Note: This function tries to construct the env from a string descriptor
    only, using RL env packages that may be installed (such as gym,
    pybullet_envs, vizdoomgym, etc.). These packages are not installation
    requirements of RLlib. In case you would like to support more such env
    packages, add the necessary imports and construction logic below.

    Args:
        env_context: The env context object to configure the env.
            Note that this is a config dict, plus the properties:
            `worker_index`, `vector_index`, and `remote`.
        env_descriptor: The env descriptor, e.g. CartPole-v0,
            MsPacmanNoFrameskip-v4, VizdoomBasic-v0, or
            CartPoleContinuousBulletEnv-v0.

    Returns:
        The actual gym environment object.

    Raises:
        EnvError: If the env cannot be constructed.
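
    Examples:
        A minimal sketch (assumption: the EnvContext(env_config, worker_index)
        call signature; adjust to your installed RLlib version):

        >>> ctx = EnvContext({}, worker_index=0)  # empty config dict
        >>> env = gym_env_creator(ctx, "CartPole-v0")
        >>> obs = env.reset()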
    """
    # Allow for PyBullet or VizdoomGym envs to be used as well
    # (via string). This allows for doing things like
    # `env=CartPoleContinuousBulletEnv-v0` or
    # `env=VizdoomBasic-v0`.
    try:
        import pybullet_envs

        pybullet_envs.getList()
    except (ModuleNotFoundError, ImportError):
        pass
    try:
        import vizdoomgym

        vizdoomgym.__name__  # Trick the linter.
    except (ModuleNotFoundError, ImportError):
        pass

    # Try creating a gym env. If this fails, we can output a
    # decent error message.
    try:
        return gym.make(env_descriptor, **env_context)
    except gym.error.Error:
        raise EnvError(ERR_MSG_INVALID_ENV_DESCRIPTOR.format(env_descriptor))


class VideoMonitor(wrappers.Monitor):
    # Same as the original method, but doesn't use the StatsRecorder, as that
    # would try to add up multi-agent reward dicts, which throws errors.
    def _after_step(self, observation, reward, done, info):
        if not self.enabled:
            return done

        # Use `done["__all__"]` b/c this is a multi-agent dict.
        if done["__all__"] and self.env_semantics_autoreset:
            # For envs with BlockingReset wrapping VNCEnv, this observation
            # will be the first one of the new episode.
            self.reset_video_recorder()
            self.episode_id += 1
            self._flush()

        # Record video.
        self.video_recorder.capture_frame()

        return done
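

# A hypothetical direct-use sketch (not part of the original file):
# VideoMonitor can wrap a MultiAgentEnv directly, though the intended entry
# point is `record_env_wrapper()` below, which also mixes MultiAgentEnv back
# into the wrapper class; `my_multi_agent_env` and the path are placeholders:
#
#     monitored = VideoMonitor(my_multi_agent_env, "/tmp/videos", force=True)
#     obs = monitored.reset()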


def record_env_wrapper(env, record_env, log_dir, policy_config):
    if record_env:
        path_ = record_env if isinstance(record_env, str) else log_dir
        # Relative path: Prepend the log_dir here; otherwise, this would
        # not work for non-local workers.
        if not os.path.isabs(path_):
            path_ = os.path.join(log_dir, path_)
        print(f"Setting the path for recording to {path_}")
        wrapper_cls = (
            VideoMonitor if isinstance(env, MultiAgentEnv) else wrappers.Monitor
        )
        if isinstance(env, MultiAgentEnv):
            wrapper_cls = add_mixins(wrapper_cls, [MultiAgentEnv], reversed=True)
        env = wrapper_cls(
            env,
            path_,
            resume=True,
            force=True,
            video_callable=lambda _: True,
            mode="evaluation" if policy_config["in_evaluation"] else "training",
        )
    return env
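

if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the original file). Assumes
    # gym's CartPole-v0 is available, "/tmp/rllib_videos" is writable, and
    # a renderer is present for video capture; the EnvContext call signature
    # (env_config, worker_index) is also an assumption.
    ctx = EnvContext({}, worker_index=0)
    env = gym_env_creator(ctx, "CartPole-v0")
    env = record_env_wrapper(
        env,
        record_env=True,  # True -> record into `log_dir`
        log_dir="/tmp/rllib_videos",
        policy_config={"in_evaluation": False},
    )
    env.reset()
    env.close()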