From c288b97e5f6f0d376b074927e3ff5596ff902dcc Mon Sep 17 00:00:00 2001
From: Sven Mika
Date: Mon, 24 Jan 2022 19:38:21 +0100
Subject: [PATCH] [RLlib] Issue 21629: Video recorder env wrapper not working.
 Added test case. (#21670)

---
 python/requirements/ml/requirements_rllib.txt |  1 +
 rllib/env/base_env.py                         |  9 ++-
 rllib/env/external_env.py                     |  2 +-
 rllib/env/multi_agent_env.py                  |  2 +-
 rllib/env/tests/test_record_env_wrapper.py    | 61 +++++++++++++++----
 rllib/env/utils.py                            | 14 +++--
 rllib/env/vector_env.py                       |  2 +-
 rllib/examples/env/mock_env.py                | 11 ++++
 rllib/examples/env/multi_agent.py             | 11 ++++
 9 files changed, 91 insertions(+), 22 deletions(-)

diff --git a/python/requirements/ml/requirements_rllib.txt b/python/requirements/ml/requirements_rllib.txt
index 600207108..3c182da51 100644
--- a/python/requirements/ml/requirements_rllib.txt
+++ b/python/requirements/ml/requirements_rllib.txt
@@ -26,6 +26,7 @@ tensorflow_estimator==2.6.0
 higher==0.2.1
 # For auto-generating an env-rendering Window.
 pyglet==1.5.15
+imageio-ffmpeg==0.4.5
 # For JSON reader/writer.
 smart_open==5.0.0
 # Ray Serve example
diff --git a/rllib/env/base_env.py b/rllib/env/base_env.py
index 780f10ebd..a1e0b451f 100644
--- a/rllib/env/base_env.py
+++ b/rllib/env/base_env.py
@@ -79,7 +79,7 @@ class BaseEnv:
 
     def to_base_env(
         self,
-        make_env: Callable[[int], EnvType] = None,
+        make_env: Optional[Callable[[int], EnvType]] = None,
         num_envs: int = 1,
         remote_envs: bool = False,
         remote_env_batch_wait_ms: int = 0,
@@ -729,7 +729,12 @@ def convert_to_base_env(
 
     # Given `env` is already a BaseEnv -> Return as is.
     if isinstance(env, (BaseEnv, MultiAgentEnv, VectorEnv, ExternalEnv)):
-        return env.to_base_env()
+        return env.to_base_env(
+            make_env=make_env,
+            num_envs=num_envs,
+            remote_envs=remote_envs,
+            remote_env_batch_wait_ms=remote_env_batch_wait_ms,
+        )
     # `env` is not a BaseEnv yet -> Need to convert/vectorize.
     else:
         # Sub-environments are ray.remote actors:
diff --git a/rllib/env/external_env.py b/rllib/env/external_env.py
index 302f82a43..19601223a 100644
--- a/rllib/env/external_env.py
+++ b/rllib/env/external_env.py
@@ -190,7 +190,7 @@ class ExternalEnv(threading.Thread):
 
     def to_base_env(
         self,
-        make_env: Callable[[int], EnvType] = None,
+        make_env: Optional[Callable[[int], EnvType]] = None,
         num_envs: int = 1,
         remote_envs: bool = False,
         remote_env_batch_wait_ms: int = 0,
diff --git a/rllib/env/multi_agent_env.py b/rllib/env/multi_agent_env.py
index 1f5c78cd7..fc45ecc54 100644
--- a/rllib/env/multi_agent_env.py
+++ b/rllib/env/multi_agent_env.py
@@ -263,7 +263,7 @@ class MultiAgentEnv(gym.Env):
     @PublicAPI
     def to_base_env(
         self,
-        make_env: Callable[[int], EnvType] = None,
+        make_env: Optional[Callable[[int], EnvType]] = None,
         num_envs: int = 1,
         remote_envs: bool = False,
         remote_env_batch_wait_ms: int = 0,
diff --git a/rllib/env/tests/test_record_env_wrapper.py b/rllib/env/tests/test_record_env_wrapper.py
index 00d68a9eb..58cd32926 100644
--- a/rllib/env/tests/test_record_env_wrapper.py
+++ b/rllib/env/tests/test_record_env_wrapper.py
@@ -1,51 +1,90 @@
-from gym import wrappers
-import tempfile
+import glob
+import gym
+import numpy as np
+import os
+import shutil
 import unittest
 
 from ray.rllib.env.utils import VideoMonitor, record_env_wrapper
 from ray.rllib.examples.env.mock_env import MockEnv2
 from ray.rllib.examples.env.multi_agent import BasicMultiAgent
+from ray.rllib.utils.test_utils import check
 
 
 class TestRecordEnvWrapper(unittest.TestCase):
     def test_wrap_gym_env(self):
+        record_env_dir = os.popen("mktemp -d").read()[:-1]
+        print(f"tmp dir for videos={record_env_dir}")
+
+        if not os.path.exists(record_env_dir):
+            self.fail(f"Could not create temp dir {record_env_dir}.")
+
+        num_steps_per_episode = 10
         wrapped = record_env_wrapper(
-            env=MockEnv2(10),
-            record_env=tempfile.gettempdir(),
+            env=MockEnv2(num_steps_per_episode),
+            record_env=record_env_dir,
             log_dir="",
             policy_config={
                 "in_evaluation": False,
             })
-        # Type is wrappers.Monitor.
-        self.assertTrue(isinstance(wrapped, wrappers.Monitor))
+        # Non-MultiAgentEnv: Wrapper's type is wrappers.Monitor.
+        self.assertTrue(isinstance(wrapped, gym.wrappers.Monitor))
         self.assertFalse(isinstance(wrapped, VideoMonitor))
 
         wrapped.reset()
+        # Expect one video file to have been produced in the tmp dir.
+        os.chdir(record_env_dir)
+        ls = glob.glob("*.mp4")
+        self.assertTrue(len(ls) == 1)
         # 10 steps for a complete episode.
-        for i in range(10):
+        for i in range(num_steps_per_episode):
             wrapped.step(0)
+        # Another episode.
+        wrapped.reset()
+        for i in range(num_steps_per_episode):
+            wrapped.step(0)
+        # Expect another video file to have been produced (2nd episode).
+        ls = glob.glob("*.mp4")
+        self.assertTrue(len(ls) == 2)
+
         # MockEnv2 returns a reward of 100.0 every step.
-        # So total reward is 1000.0.
-        self.assertEqual(wrapped.get_episode_rewards(), [1000.0])
+        # So the total reward is 1000.0 per episode (10 steps).
+        check(
+            np.array([100.0, 100.0]) * num_steps_per_episode,
+            wrapped.get_episode_rewards())
+        # Erase all generated files and the temp path just in case,
+        # so as not to disturb subsequent CI tests.
+        shutil.rmtree(record_env_dir)
 
     def test_wrap_multi_agent_env(self):
+        record_env_dir = os.popen("mktemp -d").read()[:-1]
+        print(f"tmp dir for videos={record_env_dir}")
+
+        if not os.path.exists(record_env_dir):
+            self.fail(f"Could not create temp dir {record_env_dir}.")
+
         wrapped = record_env_wrapper(
             env=BasicMultiAgent(3),
-            record_env=tempfile.gettempdir(),
+            record_env=record_env_dir,
             log_dir="",
             policy_config={
                 "in_evaluation": False,
             })
         # Type is VideoMonitor.
-        self.assertTrue(isinstance(wrapped, wrappers.Monitor))
+        self.assertTrue(isinstance(wrapped, gym.wrappers.Monitor))
         self.assertTrue(isinstance(wrapped, VideoMonitor))
 
         wrapped.reset()
+        # BasicMultiAgent is hardcoded to run 25-step episodes.
         for i in range(25):
             wrapped.step({0: 0, 1: 0, 2: 0})
+        # Expect one video file to have been produced in the tmp dir.
+        os.chdir(record_env_dir)
+        ls = glob.glob("*.mp4")
+        self.assertTrue(len(ls) == 1)
+
         # However, VideoMonitor's _after_step is overwritten to not
         # use stats_recorder. So nothing to verify here, except that
         # it runs fine.
diff --git a/rllib/env/utils.py b/rllib/env/utils.py
index 43d4a7eab..fbf56f9a0 100644
--- a/rllib/env/utils.py
+++ b/rllib/env/utils.py
@@ -1,3 +1,4 @@
+import gym
 from gym import wrappers
 import os
 
@@ -7,7 +8,7 @@ from ray.rllib.utils import add_mixins
 from ray.rllib.utils.error import ERR_MSG_INVALID_ENV_DESCRIPTOR, EnvError
 
 
-def gym_env_creator(env_context: EnvContext, env_descriptor: str):
+def gym_env_creator(env_context: EnvContext, env_descriptor: str) -> gym.Env:
     """Tries to create a gym env given an EnvContext object and descriptor.
 
     Note: This function tries to construct the env from a string descriptor
@@ -17,20 +18,19 @@ necessary imports and construction logic below.
 
     Args:
-        env_context (EnvContext): The env context object to configure the env.
+        env_context: The env context object to configure the env.
             Note that this is a config dict, plus the properties:
             `worker_index`, `vector_index`, and `remote`.
-        env_descriptor (str): The env descriptor, e.g. CartPole-v0,
+        env_descriptor: The env descriptor, e.g. CartPole-v0,
             MsPacmanNoFrameskip-v4, VizdoomBasic-v0, or
             CartPoleContinuousBulletEnv-v0.
 
     Returns:
-        gym.Env: The actual gym environment object.
+        The actual gym environment object.
 
     Raises:
         gym.error.Error: If the env cannot be constructed.
     """
-    import gym
 
     # Allow for PyBullet or VizdoomGym envs to be used as well
     # (via string). This allows for doing things like
     # `env=CartPoleContinuousBulletEnv-v0` or
@@ -85,7 +85,9 @@ def record_env_wrapper(env, record_env, log_dir, policy_config):
         print(f"Setting the path for recording to {path_}")
         wrapper_cls = VideoMonitor if isinstance(env, MultiAgentEnv) \
             else wrappers.Monitor
-        wrapper_cls = add_mixins(wrapper_cls, [MultiAgentEnv], reversed=True)
+        if isinstance(env, MultiAgentEnv):
+            wrapper_cls = add_mixins(
+                wrapper_cls, [MultiAgentEnv], reversed=True)
         env = wrapper_cls(
             env,
             path_,
diff --git a/rllib/env/vector_env.py b/rllib/env/vector_env.py
index 18e0476f9..78ce6ee11 100644
--- a/rllib/env/vector_env.py
+++ b/rllib/env/vector_env.py
@@ -138,7 +138,7 @@ class VectorEnv:
     @PublicAPI
     def to_base_env(
         self,
-        make_env: Callable[[int], EnvType] = None,
+        make_env: Optional[Callable[[int], EnvType]] = None,
         num_envs: int = 1,
         remote_envs: bool = False,
         remote_env_batch_wait_ms: int = 0,
diff --git a/rllib/examples/env/mock_env.py b/rllib/examples/env/mock_env.py
index f7950bfaf..5e587fa3a 100644
--- a/rllib/examples/env/mock_env.py
+++ b/rllib/examples/env/mock_env.py
@@ -1,4 +1,5 @@
 import gym
+import numpy as np
 
 from ray.rllib.env.vector_env import VectorEnv
 from ray.rllib.utils.annotations import override
@@ -34,6 +35,10 @@ class MockEnv2(gym.Env):
     configurable. Actions are ignored.
     """
 
+    metadata = {
+        "render.modes": ["rgb_array"],
+    }
+
     def __init__(self, episode_length):
         self.episode_length = episode_length
         self.i = 0
@@ -52,6 +57,12 @@ class MockEnv2(gym.Env):
     def seed(self, rng_seed):
         self.rng_seed = rng_seed
 
+    def render(self, mode="rgb_array"):
+        # Just generate a random image here for demonstration purposes.
+        # Also see `gym/envs/classic_control/cartpole.py` for
+        # an example of how to use a Viewer object.
+        return np.random.randint(0, 256, size=(300, 400, 3), dtype=np.uint8)
+
 
 class MockEnv3(gym.Env):
     """Mock environment for testing purposes.
diff --git a/rllib/examples/env/multi_agent.py b/rllib/examples/env/multi_agent.py
index d080a45df..b975c68ef 100644
--- a/rllib/examples/env/multi_agent.py
+++ b/rllib/examples/env/multi_agent.py
@@ -1,4 +1,5 @@
 import gym
+import numpy as np
 import random
 
 from ray.rllib.env.multi_agent_env import MultiAgentEnv, make_multi_agent
@@ -18,6 +19,10 @@ def make_multiagent(env_name_or_creator):
 class BasicMultiAgent(MultiAgentEnv):
     """Env of N independent agents, each of which exits after 25 steps."""
 
+    metadata = {
+        "render.modes": ["rgb_array"],
+    }
+
     def __init__(self, num):
         super().__init__()
         self.agents = [MockEnv(25) for _ in range(num)]
@@ -40,6 +45,12 @@ class BasicMultiAgent(MultiAgentEnv):
         done["__all__"] = len(self.dones) == len(self.agents)
         return obs, rew, done, info
 
+    def render(self, mode="rgb_array"):
+        # Just generate a random image here for demonstration purposes.
+        # Also see `gym/envs/classic_control/cartpole.py` for
+        # an example of how to use a Viewer object.
+        return np.random.randint(0, 256, size=(200, 300, 3), dtype=np.uint8)
+
 
 class EarlyDoneMultiAgent(MultiAgentEnv):
     """Env for testing when the env terminates (after agent 0 does)."""
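
For anyone who wants to exercise the fix outside the test suite, the repaired
wrapper can be driven directly. The following is a minimal sketch, not part of
the patch: it assumes a writable temp directory and a working ffmpeg backend
(e.g. the `imageio-ffmpeg` pin added above), and it only uses APIs that appear
in this diff (`record_env_wrapper`, `BasicMultiAgent`); the `video_dir` name
is illustrative.

    import glob
    import os
    import tempfile

    from ray.rllib.env.utils import record_env_wrapper
    from ray.rllib.examples.env.multi_agent import BasicMultiAgent

    # Hypothetical output dir; any writable directory works.
    video_dir = tempfile.mkdtemp()

    # For a MultiAgentEnv, record_env_wrapper selects the VideoMonitor class.
    env = record_env_wrapper(
        env=BasicMultiAgent(2),
        record_env=video_dir,  # Directory in which to store the .mp4 files.
        log_dir="",
        policy_config={"in_evaluation": False},
    )

    env.reset()
    # BasicMultiAgent episodes are hardcoded to 25 steps.
    for _ in range(25):
        env.step({0: 0, 1: 0})
    env.close()  # Flush the recorder.

    # One video file per recorded episode should now exist.
    print(glob.glob(os.path.join(video_dir, "*.mp4")))

In a full training run, the same wrapping is applied on the rollout workers;
with RLlib versions of this era it is driven by the `record_env` config key
(a truthy value or an output-directory string). Treat that key's exact
semantics as an assumption and verify it against the common-config docs of
the installed version.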