[RLlib] Issue 21629: Video recorder env wrapper not working. Added test case. (#21670)
commit c288b97e5f (parent: 2010f13175)
9 changed files with 91 additions and 22 deletions
@@ -26,6 +26,7 @@ tensorflow_estimator==2.6.0
 higher==0.2.1
 # For auto-generating an env-rendering Window.
 pyglet==1.5.15
+imageio-ffmpeg==0.4.5
 # For JSON reader/writer.
 smart_open==5.0.0
 # Ray Serve example
rllib/env/base_env.py | 9

@@ -79,7 +79,7 @@ class BaseEnv:

     def to_base_env(
             self,
-            make_env: Callable[[int], EnvType] = None,
+            make_env: Optional[Callable[[int], EnvType]] = None,
             num_envs: int = 1,
             remote_envs: bool = False,
             remote_env_batch_wait_ms: int = 0,
@@ -729,7 +729,12 @@ def convert_to_base_env(

     # Given `env` is already a BaseEnv -> Return as is.
     if isinstance(env, (BaseEnv, MultiAgentEnv, VectorEnv, ExternalEnv)):
-        return env.to_base_env()
+        return env.to_base_env(
+            make_env=make_env,
+            num_envs=num_envs,
+            remote_envs=remote_envs,
+            remote_env_batch_wait_ms=remote_env_batch_wait_ms,
+        )
     # `env` is not a BaseEnv yet -> Need to convert/vectorize.
     else:
         # Sub-environments are ray.remote actors:
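The second hunk is the substantive part: `convert_to_base_env` previously discarded its own arguments whenever the incoming env already implemented `to_base_env`. A minimal sketch of the now-working call path (the CartPole usage is an illustrative assumption, not part of the commit):

    import gym
    from ray.rllib.env.base_env import convert_to_base_env

    # After this fix, the keyword arguments below are forwarded to the env's
    # own to_base_env() when `env` is already a MultiAgentEnv / VectorEnv /
    # ExternalEnv, instead of being silently dropped.
    base_env = convert_to_base_env(
        gym.make("CartPole-v0"),
        make_env=lambda idx: gym.make("CartPole-v0"),
        num_envs=2,
        remote_envs=False,
        remote_env_batch_wait_ms=0,
    )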
rllib/env/external_env.py | 2

@@ -190,7 +190,7 @@ class ExternalEnv(threading.Thread):

     def to_base_env(
             self,
-            make_env: Callable[[int], EnvType] = None,
+            make_env: Optional[Callable[[int], EnvType]] = None,
             num_envs: int = 1,
             remote_envs: bool = False,
             remote_env_batch_wait_ms: int = 0,
rllib/env/multi_agent_env.py | 2

@@ -263,7 +263,7 @@ class MultiAgentEnv(gym.Env):
     @PublicAPI
     def to_base_env(
             self,
-            make_env: Callable[[int], EnvType] = None,
+            make_env: Optional[Callable[[int], EnvType]] = None,
             num_envs: int = 1,
             remote_envs: bool = False,
             remote_env_batch_wait_ms: int = 0,
rllib/env/tests/test_record_env_wrapper.py | 61

@@ -1,51 +1,90 @@
-from gym import wrappers
-import tempfile
+import glob
+import gym
+import numpy as np
 import os
+import shutil
 import unittest

 from ray.rllib.env.utils import VideoMonitor, record_env_wrapper
 from ray.rllib.examples.env.mock_env import MockEnv2
 from ray.rllib.examples.env.multi_agent import BasicMultiAgent
+from ray.rllib.utils.test_utils import check


 class TestRecordEnvWrapper(unittest.TestCase):
     def test_wrap_gym_env(self):
+        record_env_dir = os.popen("mktemp -d").read()[:-1]
+        print(f"tmp dir for videos={record_env_dir}")
+
+        if not os.path.exists(record_env_dir):
+            sys.exit(1)
+
+        num_steps_per_episode = 10
         wrapped = record_env_wrapper(
-            env=MockEnv2(10),
-            record_env=tempfile.gettempdir(),
+            env=MockEnv2(num_steps_per_episode),
+            record_env=record_env_dir,
             log_dir="",
             policy_config={
                 "in_evaluation": False,
             })
-        # Type is wrappers.Monitor.
-        self.assertTrue(isinstance(wrapped, wrappers.Monitor))
+        # Non MultiAgentEnv: Wrapper's type is wrappers.Monitor.
+        self.assertTrue(isinstance(wrapped, gym.wrappers.Monitor))
         self.assertFalse(isinstance(wrapped, VideoMonitor))

         wrapped.reset()
+        # Expect one video file to have been produced in the tmp dir.
+        os.chdir(record_env_dir)
+        ls = glob.glob("*.mp4")
+        self.assertTrue(len(ls) == 1)
         # 10 steps for a complete episode.
-        for i in range(10):
+        for i in range(num_steps_per_episode):
             wrapped.step(0)
+        # Another episode.
+        wrapped.reset()
+        for i in range(num_steps_per_episode):
+            wrapped.step(0)
+        # Expect another video file to have been produced (2nd episode).
+        ls = glob.glob("*.mp4")
+        self.assertTrue(len(ls) == 2)
+
         # MockEnv2 returns a reward of 100.0 every step.
-        # So total reward is 1000.0.
-        self.assertEqual(wrapped.get_episode_rewards(), [1000.0])
+        # So total reward is 1000.0 per episode (10 steps).
+        check(
+            np.array([100.0, 100.0]) * num_steps_per_episode,
+            wrapped.get_episode_rewards())
+        # Erase all generated files and the temp path just in case,
+        # as to not disturb further CI-tests.
+        shutil.rmtree(record_env_dir)

     def test_wrap_multi_agent_env(self):
+        record_env_dir = os.popen("mktemp -d").read()[:-1]
+        print(f"tmp dir for videos={record_env_dir}")
+
+        if not os.path.exists(record_env_dir):
+            sys.exit(1)
+
         wrapped = record_env_wrapper(
             env=BasicMultiAgent(3),
-            record_env=tempfile.gettempdir(),
+            record_env=record_env_dir,
             log_dir="",
             policy_config={
                 "in_evaluation": False,
             })
         # Type is VideoMonitor.
-        self.assertTrue(isinstance(wrapped, wrappers.Monitor))
+        self.assertTrue(isinstance(wrapped, gym.wrappers.Monitor))
         self.assertTrue(isinstance(wrapped, VideoMonitor))

         wrapped.reset()

         # BasicMultiAgent is hardcoded to run 25-step episodes.
         for i in range(25):
             wrapped.step({0: 0, 1: 0, 2: 0})

+        # Expect one video file to have been produced in the tmp dir.
+        os.chdir(record_env_dir)
+        ls = glob.glob("*.mp4")
+        self.assertTrue(len(ls) == 1)
+
         # However VideoMonitor's _after_step is overwritten to not
         # use stats_recorder. So nothing to verify here, except that
         # it runs fine.
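The test drives `record_env_wrapper` directly; in normal use the wrapper is enabled through the trainer config instead. A hedged sketch of that path, assuming the `record_env` config key of this RLlib era (the output directory and trainer choice are arbitrary examples):

    from ray.rllib.agents.ppo import PPOTrainer

    # Sketch: `record_env` may be True (record into the default log dir) or
    # a directory path; episodes are then written out as .mp4 files via
    # gym's Monitor (or RLlib's VideoMonitor for multi-agent envs).
    trainer = PPOTrainer(
        env="CartPole-v0",
        config={
            "record_env": "/tmp/rllib-videos",  # example output path
            "num_workers": 0,
        })
    results = trainer.train()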
rllib/env/utils.py | 14

@@ -1,3 +1,4 @@
+import gym
 from gym import wrappers
 import os

@@ -7,7 +8,7 @@ from ray.rllib.utils import add_mixins
 from ray.rllib.utils.error import ERR_MSG_INVALID_ENV_DESCRIPTOR, EnvError


-def gym_env_creator(env_context: EnvContext, env_descriptor: str):
+def gym_env_creator(env_context: EnvContext, env_descriptor: str) -> gym.Env:
     """Tries to create a gym env given an EnvContext object and descriptor.

     Note: This function tries to construct the env from a string descriptor
@@ -17,20 +18,19 @@ def gym_env_creator(env_context: EnvContext, env_descriptor: str)
     necessary imports and construction logic below.

     Args:
-        env_context (EnvContext): The env context object to configure the env.
+        env_context: The env context object to configure the env.
             Note that this is a config dict, plus the properties:
             `worker_index`, `vector_index`, and `remote`.
-        env_descriptor (str): The env descriptor, e.g. CartPole-v0,
+        env_descriptor: The env descriptor, e.g. CartPole-v0,
             MsPacmanNoFrameskip-v4, VizdoomBasic-v0, or
             CartPoleContinuousBulletEnv-v0.

     Returns:
-        gym.Env: The actual gym environment object.
+        The actual gym environment object.

     Raises:
         gym.error.Error: If the env cannot be constructed.
     """
-    import gym
     # Allow for PyBullet or VizdoomGym envs to be used as well
     # (via string). This allows for doing things like
     # `env=CartPoleContinuousBulletEnv-v0` or
@@ -85,7 +85,9 @@ def record_env_wrapper(env, record_env, log_dir, policy_config):
     print(f"Setting the path for recording to {path_}")
     wrapper_cls = VideoMonitor if isinstance(env, MultiAgentEnv) \
         else wrappers.Monitor
-    wrapper_cls = add_mixins(wrapper_cls, [MultiAgentEnv], reversed=True)
+    if isinstance(env, MultiAgentEnv):
+        wrapper_cls = add_mixins(
+            wrapper_cls, [MultiAgentEnv], reversed=True)
     env = wrapper_cls(
         env,
         path_,
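The `record_env_wrapper` change is the actual fix for issue 21629: the `MultiAgentEnv` mixin used to be added unconditionally, so even a wrapped single-agent env passed `isinstance(..., MultiAgentEnv)` checks downstream. A quick sketch of the behavior the new code guarantees (the output path is a placeholder assumption):

    import gym
    from gym import wrappers
    from ray.rllib.env.multi_agent_env import MultiAgentEnv
    from ray.rllib.env.utils import record_env_wrapper

    wrapped = record_env_wrapper(
        env=gym.make("CartPole-v0"),
        record_env="/tmp/videos",  # placeholder output dir
        log_dir="",
        policy_config={"in_evaluation": False})
    # A single-agent env stays a plain Monitor; before the fix, the second
    # assertion failed because the MultiAgentEnv mixin was always added.
    assert isinstance(wrapped, wrappers.Monitor)
    assert not isinstance(wrapped, MultiAgentEnv)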
rllib/env/vector_env.py | 2

@@ -138,7 +138,7 @@ class VectorEnv:
     @PublicAPI
     def to_base_env(
             self,
-            make_env: Callable[[int], EnvType] = None,
+            make_env: Optional[Callable[[int], EnvType]] = None,
             num_envs: int = 1,
             remote_envs: bool = False,
             remote_env_batch_wait_ms: int = 0,
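The same one-line signature fix recurs in base_env.py, external_env.py, multi_agent_env.py, and vector_env.py: a parameter that defaults to `None` needs an `Optional[...]` annotation. A minimal illustration of what a strict type checker objects to (hypothetical function names):

    from typing import Callable, Optional

    # Flagged by strict type checkers: the default None is not a Callable.
    def before(make_env: Callable[[int], object] = None):
        ...

    # Correct: Optional[...] admits the None default.
    def after(make_env: Optional[Callable[[int], object]] = None):
        ...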
rllib/examples/env/mock_env.py | 11

@@ -1,4 +1,5 @@
 import gym
+import numpy as np

 from ray.rllib.env.vector_env import VectorEnv
 from ray.rllib.utils.annotations import override
@@ -34,6 +35,10 @@ class MockEnv2(gym.Env):
     configurable. Actions are ignored.
     """

+    metadata = {
+        "render.modes": ["rgb_array"],
+    }
+
     def __init__(self, episode_length):
         self.episode_length = episode_length
         self.i = 0
@@ -52,6 +57,12 @@ class MockEnv2(gym.Env):
     def seed(self, rng_seed):
         self.rng_seed = rng_seed

+    def render(self, mode="rgb_array"):
+        # Just generate a random image here for demonstration purposes.
+        # Also see `gym/envs/classic_control/cartpole.py` for
+        # an example on how to use a Viewer object.
+        return np.random.randint(0, 256, size=(300, 400, 3), dtype=np.uint8)
+

 class MockEnv3(gym.Env):
     """Mock environment for testing purposes.
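These additions are what make MockEnv2 recordable at all: gym's video recorder consults `metadata["render.modes"]` and, when "rgb_array" is offered, pulls frames from `render()` as numpy arrays. A small sanity check along those lines (a sketch, not part of the commit):

    import numpy as np
    from ray.rllib.examples.env.mock_env import MockEnv2

    env = MockEnv2(10)
    # The recorder needs "rgb_array" among the advertised render modes ...
    assert "rgb_array" in env.metadata.get("render.modes", [])
    # ... and render() must return an HxWx3 uint8 frame (here 300x400).
    frame = env.render(mode="rgb_array")
    assert isinstance(frame, np.ndarray)
    assert frame.shape == (300, 400, 3) and frame.dtype == np.uint8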
rllib/examples/env/multi_agent.py | 11

@@ -1,4 +1,5 @@
 import gym
+import numpy as np
 import random

 from ray.rllib.env.multi_agent_env import MultiAgentEnv, make_multi_agent
@@ -18,6 +19,10 @@ def make_multiagent(env_name_or_creator):
 class BasicMultiAgent(MultiAgentEnv):
     """Env of N independent agents, each of which exits after 25 steps."""

+    metadata = {
+        "render.modes": ["rgb_array"],
+    }
+
     def __init__(self, num):
         super().__init__()
         self.agents = [MockEnv(25) for _ in range(num)]
@@ -40,6 +45,12 @@ class BasicMultiAgent(MultiAgentEnv):
         done["__all__"] = len(self.dones) == len(self.agents)
         return obs, rew, done, info

+    def render(self, mode="rgb_array"):
+        # Just generate a random image here for demonstration purposes.
+        # Also see `gym/envs/classic_control/cartpole.py` for
+        # an example on how to use a Viewer object.
+        return np.random.randint(0, 256, size=(200, 300, 3), dtype=np.uint8)
+

 class EarlyDoneMultiAgent(MultiAgentEnv):
     """Env for testing when the env terminates (after agent 0 does)."""