[RLlib] Issue 21629: Video recorder env wrapper not working. Added test case. (#21670)

This commit is contained in:
Sven Mika 2022-01-24 19:38:21 +01:00 committed by GitHub
parent 2010f13175
commit c288b97e5f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 91 additions and 22 deletions

View file

@@ -26,6 +26,7 @@ tensorflow_estimator==2.6.0
higher==0.2.1
# For auto-generating an env-rendering Window.
pyglet==1.5.15
imageio-ffmpeg==0.4.5
# For JSON reader/writer.
smart_open==5.0.0
# Ray Serve example

View file

@@ -79,7 +79,7 @@ class BaseEnv:
def to_base_env(
self,
make_env: Callable[[int], EnvType] = None,
make_env: Optional[Callable[[int], EnvType]] = None,
num_envs: int = 1,
remote_envs: bool = False,
remote_env_batch_wait_ms: int = 0,
@@ -729,7 +729,12 @@ def convert_to_base_env(
# Given `env` is already a BaseEnv -> Return as is.
if isinstance(env, (BaseEnv, MultiAgentEnv, VectorEnv, ExternalEnv)):
return env.to_base_env()
return env.to_base_env(
make_env=make_env,
num_envs=num_envs,
remote_envs=remote_envs,
remote_env_batch_wait_ms=remote_env_batch_wait_ms,
)
# `env` is not a BaseEnv yet -> Need to convert/vectorize.
else:
# Sub-environments are ray.remote actors:

View file

@@ -190,7 +190,7 @@ class ExternalEnv(threading.Thread):
def to_base_env(
self,
make_env: Callable[[int], EnvType] = None,
make_env: Optional[Callable[[int], EnvType]] = None,
num_envs: int = 1,
remote_envs: bool = False,
remote_env_batch_wait_ms: int = 0,

View file

@@ -263,7 +263,7 @@ class MultiAgentEnv(gym.Env):
@PublicAPI
def to_base_env(
self,
make_env: Callable[[int], EnvType] = None,
make_env: Optional[Callable[[int], EnvType]] = None,
num_envs: int = 1,
remote_envs: bool = False,
remote_env_batch_wait_ms: int = 0,

View file

@@ -1,51 +1,90 @@
from gym import wrappers
import tempfile
import glob
import gym
import numpy as np
import os
import shutil
import unittest
from ray.rllib.env.utils import VideoMonitor, record_env_wrapper
from ray.rllib.examples.env.mock_env import MockEnv2
from ray.rllib.examples.env.multi_agent import BasicMultiAgent
from ray.rllib.utils.test_utils import check
class TestRecordEnvWrapper(unittest.TestCase):
def test_wrap_gym_env(self):
record_env_dir = os.popen("mktemp -d").read()[:-1]
print(f"tmp dir for videos={record_env_dir}")
if not os.path.exists(record_env_dir):
sys.exit(1)
num_steps_per_episode = 10
wrapped = record_env_wrapper(
env=MockEnv2(10),
record_env=tempfile.gettempdir(),
env=MockEnv2(num_steps_per_episode),
record_env=record_env_dir,
log_dir="",
policy_config={
"in_evaluation": False,
})
# Type is wrappers.Monitor.
self.assertTrue(isinstance(wrapped, wrappers.Monitor))
# Non MultiAgentEnv: Wrapper's type is wrappers.Monitor.
self.assertTrue(isinstance(wrapped, gym.wrappers.Monitor))
self.assertFalse(isinstance(wrapped, VideoMonitor))
wrapped.reset()
# Expect one video file to have been produced in the tmp dir.
os.chdir(record_env_dir)
ls = glob.glob("*.mp4")
self.assertTrue(len(ls) == 1)
# 10 steps for a complete episode.
for i in range(10):
for i in range(num_steps_per_episode):
wrapped.step(0)
# Another episode.
wrapped.reset()
for i in range(num_steps_per_episode):
wrapped.step(0)
# Expect another video file to have been produced (2nd episode).
ls = glob.glob("*.mp4")
self.assertTrue(len(ls) == 2)
# MockEnv2 returns a reward of 100.0 every step.
# So total reward is 1000.0.
self.assertEqual(wrapped.get_episode_rewards(), [1000.0])
# So total reward is 1000.0 per episode (10 steps).
check(
np.array([100.0, 100.0]) * num_steps_per_episode,
wrapped.get_episode_rewards())
# Erase all generated files and the temp path just in case,
# as to not disturb further CI-tests.
shutil.rmtree(record_env_dir)
def test_wrap_multi_agent_env(self):
record_env_dir = os.popen("mktemp -d").read()[:-1]
print(f"tmp dir for videos={record_env_dir}")
if not os.path.exists(record_env_dir):
sys.exit(1)
wrapped = record_env_wrapper(
env=BasicMultiAgent(3),
record_env=tempfile.gettempdir(),
record_env=record_env_dir,
log_dir="",
policy_config={
"in_evaluation": False,
})
# Type is VideoMonitor.
self.assertTrue(isinstance(wrapped, wrappers.Monitor))
self.assertTrue(isinstance(wrapped, gym.wrappers.Monitor))
self.assertTrue(isinstance(wrapped, VideoMonitor))
wrapped.reset()
# BasicMultiAgent is hardcoded to run 25-step episodes.
for i in range(25):
wrapped.step({0: 0, 1: 0, 2: 0})
# Expect one video file to have been produced in the tmp dir.
os.chdir(record_env_dir)
ls = glob.glob("*.mp4")
self.assertTrue(len(ls) == 1)
# However VideoMonitor's _after_step is overwritten to not
# use stats_recorder. So nothing to verify here, except that
# it runs fine.

14
rllib/env/utils.py vendored
View file

@@ -1,3 +1,4 @@
import gym
from gym import wrappers
import os
@@ -7,7 +8,7 @@ from ray.rllib.utils import add_mixins
from ray.rllib.utils.error import ERR_MSG_INVALID_ENV_DESCRIPTOR, EnvError
def gym_env_creator(env_context: EnvContext, env_descriptor: str):
def gym_env_creator(env_context: EnvContext, env_descriptor: str) -> gym.Env:
"""Tries to create a gym env given an EnvContext object and descriptor.
Note: This function tries to construct the env from a string descriptor
@@ -17,20 +18,19 @@ def gym_env_creator(env_context: EnvContext, env_descriptor: str):
necessary imports and construction logic below.
Args:
env_context (EnvContext): The env context object to configure the env.
env_context: The env context object to configure the env.
Note that this is a config dict, plus the properties:
`worker_index`, `vector_index`, and `remote`.
env_descriptor (str): The env descriptor, e.g. CartPole-v0,
env_descriptor: The env descriptor, e.g. CartPole-v0,
MsPacmanNoFrameskip-v4, VizdoomBasic-v0, or
CartPoleContinuousBulletEnv-v0.
Returns:
gym.Env: The actual gym environment object.
The actual gym environment object.
Raises:
gym.error.Error: If the env cannot be constructed.
"""
import gym
# Allow for PyBullet or VizdoomGym envs to be used as well
# (via string). This allows for doing things like
# `env=CartPoleContinuousBulletEnv-v0` or
@@ -85,7 +85,9 @@ def record_env_wrapper(env, record_env, log_dir, policy_config):
print(f"Setting the path for recording to {path_}")
wrapper_cls = VideoMonitor if isinstance(env, MultiAgentEnv) \
else wrappers.Monitor
wrapper_cls = add_mixins(wrapper_cls, [MultiAgentEnv], reversed=True)
if isinstance(env, MultiAgentEnv):
wrapper_cls = add_mixins(
wrapper_cls, [MultiAgentEnv], reversed=True)
env = wrapper_cls(
env,
path_,

View file

@@ -138,7 +138,7 @@ class VectorEnv:
@PublicAPI
def to_base_env(
self,
make_env: Callable[[int], EnvType] = None,
make_env: Optional[Callable[[int], EnvType]] = None,
num_envs: int = 1,
remote_envs: bool = False,
remote_env_batch_wait_ms: int = 0,

View file

@@ -1,4 +1,5 @@
import gym
import numpy as np
from ray.rllib.env.vector_env import VectorEnv
from ray.rllib.utils.annotations import override
@@ -34,6 +35,10 @@ class MockEnv2(gym.Env):
configurable. Actions are ignored.
"""
metadata = {
"render.modes": ["rgb_array"],
}
def __init__(self, episode_length):
self.episode_length = episode_length
self.i = 0
@@ -52,6 +57,12 @@
def seed(self, rng_seed):
self.rng_seed = rng_seed
def render(self, mode="rgb_array"):
# Just generate a random image here for demonstration purposes.
# Also see `gym/envs/classic_control/cartpole.py` for
# an example on how to use a Viewer object.
return np.random.randint(0, 256, size=(300, 400, 3), dtype=np.uint8)
class MockEnv3(gym.Env):
"""Mock environment for testing purposes.

View file

@@ -1,4 +1,5 @@
import gym
import numpy as np
import random
from ray.rllib.env.multi_agent_env import MultiAgentEnv, make_multi_agent
@@ -18,6 +19,10 @@ def make_multiagent(env_name_or_creator):
class BasicMultiAgent(MultiAgentEnv):
"""Env of N independent agents, each of which exits after 25 steps."""
metadata = {
"render.modes": ["rgb_array"],
}
def __init__(self, num):
super().__init__()
self.agents = [MockEnv(25) for _ in range(num)]
@@ -40,6 +45,12 @@ class BasicMultiAgent(MultiAgentEnv):
done["__all__"] = len(self.dones) == len(self.agents)
return obs, rew, done, info
def render(self, mode="rgb_array"):
# Just generate a random image here for demonstration purposes.
# Also see `gym/envs/classic_control/cartpole.py` for
# an example on how to use a Viewer object.
return np.random.randint(0, 256, size=(200, 300, 3), dtype=np.uint8)
class EarlyDoneMultiAgent(MultiAgentEnv):
"""Env for testing when the env terminates (after agent 0 does)."""