[RLlib] Fix env rendering and recording options (for non-local mode; >0 workers; +evaluation-workers). (#14796)

Sven Mika 2021-03-23 10:06:06 +01:00 committed by GitHub
parent 9ccf291f4d
commit f859ebb99f
7 changed files with 236 additions and 32 deletions

View file

@ -4,6 +4,8 @@ gast
torch>=1.6.0
# Version requirement to match Tune
torchvision>=0.6.0
# For auto-generating a rendering window.
pyglet
smart_open
# For testing in MuJoCo-like envs (in PyBullet).
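
The new pyglet requirement backs the window that gym's SimpleImageViewer opens when RLlib displays rendered frames. A minimal sketch for checking that this rendering stack is importable (assumes a gym version that still ships gym.envs.classic_control.rendering; the dummy frame size is arbitrary):

import numpy as np

try:
    from gym.envs.classic_control.rendering import SimpleImageViewer
    viewer = SimpleImageViewer()
    # Display one black 300x400 RGB frame, then close the window again.
    viewer.imshow(np.zeros((300, 400, 3), dtype=np.uint8))
    viewer.close()
except (ImportError, ModuleNotFoundError) as e:
    print(f"Rendering stack unavailable ({e}); try `pip install gym[all] pyglet`.")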

View file

@ -124,10 +124,17 @@ COMMON_CONFIG: TrainerConfigDict = {
# If True, try to render the environment on the local worker or on worker
# 1 (if num_workers > 0). For vectorized envs, this usually means that only
# the first sub-environment will be rendered.
# In order for this to work, your env will have to implement the
# `render()` method which either:
# a) handles window generation and rendering itself (returning True) or
# b) returns a numpy uint8 image of shape [height x width x 3 (RGB)].
"render_env": False,
# If True, stores videos in the default output dir (~/ray_results/...).
# Alternatively, you can specify a relative directory name (stored inside
# the default output dir) or an absolute path (str), in which case the
# env recordings are stored there instead.
# Set to False for not recording anything.
# Note: This setting replaces the deprecated `monitor` key.
"record_env": False,
# Unsquash actions to the upper and lower bounds of env's action space
"normalize_actions": False,
@ -146,9 +153,6 @@ COMMON_CONFIG: TrainerConfigDict = {
"lr": 0.0001,
# === Debug Settings ===
# Whether to write episode stats and videos to the agent log dir. This is
# typically located in ~/ray_results.
"monitor": False,
# Set the ray.rllib.* log level for the agent process and its workers.
# Should be one of DEBUG, INFO, WARN, or ERROR. The DEBUG level will also
# periodically print out summaries of relevant internal dataflow (this is
@ -410,12 +414,15 @@ COMMON_CONFIG: TrainerConfigDict = {
# Default value None allows overwriting with nested dicts
"logger_config": None,
# === Deprecated keys ===
# Uses the sync samples optimizer instead of the multi-gpu one. This is
# usually slower, but you might want to try it if you run into issues with
# the default optimizer.
# This will be set automatically from now on.
"simple_optimizer": DEPRECATED_VALUE,
# Whether to write episode stats and videos to the agent log dir. This is
# typically located in ~/ray_results.
"monitor": DEPRECATED_VALUE,
}
# __sphinx_doc_end__
# yapf: enable
@ -710,7 +717,13 @@ class Trainer(Trainable):
# Assert that user has not unset "in_evaluation".
assert "in_evaluation" not in extra_config or \
extra_config["in_evaluation"] is True
evaluation_config = merge_dicts(self.config, extra_config)
# Validate evaluation config.
self._validate_config(evaluation_config)
# Switch on complete_episode rollouts (evaluations are
# always done on n complete episodes) and set the
# `in_evaluation` flag.
evaluation_config.update({
"batch_mode": "complete_episodes",
"in_evaluation": True,
})
@ -721,7 +734,7 @@ class Trainer(Trainable):
env_creator=self.env_creator,
validate_env=None,
policy_class=self._policy_class,
config=evaluation_config,
num_workers=self.config["evaluation_num_workers"])
self.evaluation_metrics = {}
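
The evaluation workers thus run on a deep merge of the top-level config and `evaluation_config`, with complete-episode batching and the `in_evaluation` flag forced on. A rough, self-contained sketch of those merge semantics (the `deep_merge` helper is illustrative; RLlib uses its own `merge_dicts` utility):

import copy

def deep_merge(base, overrides):
    # Nested dicts are merged key by key; everything else is overwritten.
    merged = copy.deepcopy(base)
    for key, value in overrides.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = value
    return merged

config = {"num_workers": 2, "record_env": False, "explore": True}
extra_config = {"record_env": "videos", "explore": False}
evaluation_config = deep_merge(config, extra_config)
evaluation_config.update({
    "batch_mode": "complete_episodes",
    "in_evaluation": True,
})
# evaluation_config == {"num_workers": 2, "record_env": "videos",
#                       "explore": False, "batch_mode": "complete_episodes",
#                       "in_evaluation": True}
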
@ -1138,6 +1151,16 @@ class Trainer(Trainable):
if model_config is None:
config["model"] = model_config = {}
# Monitor should be replaced by `record_env`.
if config.get("monitor", DEPRECATED_VALUE) != DEPRECATED_VALUE:
deprecation_warning("monitor", "record_env", error=False)
config["record_env"] = config.get("monitor", False)
# Empty string would fail some if-blocks checking for this setting.
# Set to True instead, meaning: use default output dir to store
# the videos.
if config.get("record_env") == "":
config["record_env"] = True
# Multi-GPU settings.
simple_optim_setting = config.get("simple_optimizer", DEPRECATED_VALUE)
if simple_optim_setting != DEPRECATED_VALUE:
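
A standalone, hedged sketch of the translation above (the function name is made up; RLlib performs these steps during config validation):

DEPRECATED_VALUE = object()  # stand-in for RLlib's deprecation sentinel

def translate_record_env(config):
    # Deprecated `monitor` setting is carried over into `record_env`.
    if config.get("monitor", DEPRECATED_VALUE) is not DEPRECATED_VALUE:
        config["record_env"] = config.get("monitor", False)
    # An empty string would fail `if record_env:` checks downstream,
    # so normalize it to True (= record into the default output dir).
    if config.get("record_env") == "":
        config["record_env"] = True
    return config

assert translate_record_env({"monitor": True}) == \
    {"monitor": True, "record_env": True}
assert translate_record_env({"record_env": ""}) == {"record_env": True}
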

View file

@ -1,6 +1,5 @@
import logging
import gym
from gym import wrappers as gym_wrappers
import numpy as np
from typing import Callable, List, Optional, Tuple
@ -145,18 +144,6 @@ class _VectorizedGymEnv(VectorEnv):
while len(self.envs) < num_envs:
self.envs.append(make_env(len(self.envs)))
# Wrap all envs with video recorder if necessary.
if policy_config is not None and policy_config.get("record_env"):
def wrapper_(env):
return gym_wrappers.Monitor(
env=env,
directory=policy_config["record_env"],
video_callable=lambda _: True,
force=True)
self.envs = [wrapper_(e) for e in self.envs]
super().__init__(
observation_space=observation_space
or self.envs[0].observation_space,
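
The per-sub-env wrapping removed above now happens once per worker inside RolloutWorker instead (see the next file). For reference, manually wrapping a single env with gym's Monitor, as the removed code did, looks roughly like this (output directory is illustrative; gym <= 0.20 API):

import gym
from gym import wrappers as gym_wrappers

env = gym_wrappers.Monitor(
    env=gym.make("CartPole-v0"),
    directory="/tmp/cartpole_videos",        # where mp4s and stats land
    video_callable=lambda episode_id: True,  # record every episode
    force=True)                              # overwrite previous recordings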

View file

@ -5,6 +5,7 @@ import logging
import pickle
import platform
import os
import re
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, \
TYPE_CHECKING, Union
@ -159,7 +160,7 @@ class RolloutWorker(ParallelIteratorWorker):
policy_config: TrainerConfigDict = None,
worker_index: int = 0,
num_workers: int = 0,
record_env: Union[bool, str] = False,
log_dir: str = None,
log_level: str = None,
callbacks: Type["DefaultCallbacks"] = None,
@ -182,6 +183,7 @@ class RolloutWorker(ParallelIteratorWorker):
policy: Union[type, Dict[
str, Tuple[Optional[type], gym.Space, gym.Space,
PartialTrainerConfigDict]]] = None,
monitor_path=None,
):
"""Initialize a rollout worker.
@ -258,8 +260,10 @@ class RolloutWorker(ParallelIteratorWorker):
through EnvContext so that envs can be configured per worker.
num_workers (int): For remote workers, how many workers altogether
have been created?
record_env (Union[bool, str]): Write out episode stats and videos
using gym.wrappers.Monitor to this directory if specified. If
True, use the default output dir in ~/ray_results/.... If
False, do not record anything.
log_dir (str): Directory where logs can be placed.
log_level (str): Set the root log level on creation.
callbacks (Type[DefaultCallbacks]): Custom sub-class of
@ -303,13 +307,17 @@ class RolloutWorker(ParallelIteratorWorker):
_use_trajectory_view_api (bool): Whether to collect samples through
the experimental Trajectory View API.
policy: Obsoleted arg. Use `policy_spec` instead.
monitor_path: Obsoleted arg. Use `record_env` instead.
"""
# Deprecated args.
if policy is not None:
deprecation_warning("policy", "policy_spec", error=False)
policy_spec = policy
assert policy_spec is not None, "Must provide `policy_spec` when " \
"creating RolloutWorker!"
if monitor_path is not None:
deprecation_warning("monitor_path", "record_env", error=False)
record_env = monitor_path
self._original_kwargs: dict = locals().copy()
del self._original_kwargs["self"]
@ -419,16 +427,44 @@ class RolloutWorker(ParallelIteratorWorker):
dim=model_config.get("dim"),
framestack=framestack,
framestack_via_traj_view_api=framestack_traj_view)
if record_env:
from gym import wrappers
path_ = record_env if isinstance(record_env,
str) else log_dir
# Relative path: Add logdir here, otherwise, this would
# not work for non-local workers.
if not re.search("[/\\\]", path_):
path_ = os.path.join(log_dir, path_)
print(f"Setting the path for recording to {path_}")
env = wrappers.Monitor(
env,
path_,
resume=True,
force=True,
video_callable=lambda _: True,
mode="evaluation"
if policy_config["in_evaluation"] else "training")
return env
else:
def wrap(env):
if record_env:
from gym import wrappers
path_ = record_env if isinstance(record_env,
str) else log_dir
# Relative path: Add logdir here, otherwise, this would
# not work for non-local workers.
if not re.search("[/\\\]", path_):
path_ = os.path.join(log_dir, path_)
print(f"Setting the path for recording to {path_}")
env = wrappers.Monitor(
env,
path_,
resume=True,
force=True,
video_callable=lambda _: True,
mode="evaluation"
if policy_config["in_evaluation"] else "training")
return env
self.env: EnvType = wrap(self.env)
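
A standalone sketch of the path handling above: a bare directory name without a path separator is treated as relative to the worker's log dir, so that remote (non-local) workers also record into their own results directory. The helper name is made up for illustration:

import os
import re

def resolve_record_path(record_env, log_dir):
    # A string setting is the target dir; True falls back to log_dir itself.
    path_ = record_env if isinstance(record_env, str) else log_dir
    # No "/" or "\" in the path -> treat it as relative to log_dir.
    if not re.search("[/\\\\]", path_):
        path_ = os.path.join(log_dir, path_)
    return path_

assert resolve_record_path("videos", "/tmp/run1") == "/tmp/run1/videos"
assert resolve_record_path("/abs/videos", "/tmp/run1") == "/abs/videos"
assert resolve_record_path(True, "/tmp/run1") == "/tmp/run1"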

View file

@ -519,6 +519,9 @@ def _env_runner(
terminal condition, and other fields as dictated by `policy`.
"""
# May be populated with a SimpleImageViewer used for image rendering.
simple_image_viewer: Optional["SimpleImageViewer"] = None
# Try to get Env's `max_episode_steps` prop. If it doesn't exist, ignore
# error and continue with max_episode_steps=None.
max_episode_steps = None
@ -704,7 +707,24 @@ def _env_runner(
# Try to render the env, if required.
if render:
t5 = time.time()
# Render can either return an RGB image (uint8 [height x width x 3]
# numpy array) or take care of rendering itself (returning True).
rendered = base_env.try_render()
# Rendering returned an image -> Display it in a SimpleImageViewer.
if isinstance(rendered, np.ndarray) and len(rendered.shape) == 3:
# ImageViewer not defined yet, try to create one.
if simple_image_viewer is None:
try:
from gym.envs.classic_control.rendering import \
SimpleImageViewer
simple_image_viewer = SimpleImageViewer()
except (ImportError, ModuleNotFoundError):
render = False # disable rendering
logger.warning(
"Could not import gym.envs.classic_control."
"rendering! Try `pip install gym[all]`.")
if simple_image_viewer:
simple_image_viewer.imshow(rendered)
perf_stats.env_render_time += time.time() - t5
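
A self-contained sketch of the display path added above: if the env's render() (reached via `base_env.try_render()`) hands back an RGB array, it is shown in gym's SimpleImageViewer; returning True instead means the env manages its own window. The snippet mimics that logic on a plain gym env (assumes gym's classic_control rendering module is available):

import gym
import numpy as np
from gym.envs.classic_control.rendering import SimpleImageViewer

env = gym.make("CartPole-v0")
env.reset()
viewer = SimpleImageViewer()
for _ in range(20):
    _, _, done, _ = env.step(env.action_space.sample())
    rendered = env.render(mode="rgb_array")
    if isinstance(rendered, np.ndarray) and len(rendered.shape) == 3:
        viewer.imshow(rendered)  # display the returned frame
    if done:
        env.reset()
viewer.close()
env.close()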

View file

@ -339,7 +339,7 @@ class WorkerSet:
policy_config=config,
worker_index=worker_index,
num_workers=num_workers,
record_env=config["record_env"],
log_dir=self._logdir,
log_level=config["log_level"],
callbacks=config["callbacks"],

View file

@ -0,0 +1,136 @@
# ---------------
# IMPORTANT NOTE:
# ---------------
# A recent bug in OpenAI Gym prevents RLlib's "record_env" option
# from recording videos properly. Instead, the produced mp4 files
# have a size of 1 KB and are corrupted.
# A simple fix for this is described here:
# https://github.com/openai/gym/issues/1925
import argparse
import gym
import numpy as np
import ray
from gym.spaces import Box, Discrete
from ray import tune
parser = argparse.ArgumentParser()
parser.add_argument(
"--framework", choices=["tf2", "tf", "tfe", "torch"], default="tf")
parser.add_argument("--stop-iters", type=int, default=10)
parser.add_argument("--stop-timesteps", type=int, default=10000)
parser.add_argument("--stop-reward", type=float, default=9.0)
class CustomRenderedEnv(gym.Env):
"""Example of a custom env, for which you can specify rendering behavior.
"""
# Must specify which render modes are supported by your custom env.
# For RLlib to render your env via the "render_env" config key, only
# mode="rgb_array" is needed. RLlib will automatically display the
# returned RGB images in a simple viewer (covering what mode="human"
# would normally do), so you don't have to provide your own window
# and render handling.
metadata = {
"render.modes": ["rgb_array"],
}
def __init__(self, config):
self.end_pos = config.get("corridor_length", 10)
self.max_steps = config.get("max_steps", 100)
self.cur_pos = 0
self.steps = 0
self.action_space = Discrete(2)
self.observation_space = Box(0.0, 999.0, shape=(1, ), dtype=np.float32)
def reset(self):
self.cur_pos = 0.0
self.steps = 0
return [self.cur_pos]
def step(self, action):
self.steps += 1
assert action in [0, 1], action
if action == 0 and self.cur_pos > 0:
self.cur_pos -= 1.0
elif action == 1:
self.cur_pos += 1.0
done = self.cur_pos >= self.end_pos or \
self.steps >= self.max_steps
return [self.cur_pos], 10.0 if done else -0.1, done, {}
def render(self, mode="rgb"):
"""Implements rendering logic for this env (given current state).
You can either return an RGB image:
np.array([height, width, 3], dtype=np.uint8) or take care of
rendering in a window yourself here (return True then).
For RLlib, though, only mode=rgb (returning an image) is needed,
even when "render_env" is True in the RLlib config.
Args:
mode (str): One of "rgb", "human", or "ascii". See gym.Env for
more information.
Returns:
Union[np.ndarray, bool]: An image to render or True (if rendering
is handled entirely in here).
"""
# Just generate a random image here for demonstration purposes.
# Also see `gym/envs/classic_control/cartpole.py` for
# an example on how to use a Viewer object.
return np.random.randint(0, 256, size=(300, 400, 3), dtype=np.uint8)
if __name__ == "__main__":
# Note: Recording and rendering in this example
# should work for both local_mode=True|False.
ray.init(num_cpus=4)
args = parser.parse_args()
# Example config enabling rendering and recording during evaluation.
config = {
# Also try common gym envs like: "CartPole-v0" or "Pendulum-v0".
"env": CustomRenderedEnv,
"env_config": {
"corridor_length": 10,
"max_steps": 100,
},
# Evaluate once per training iteration.
"evaluation_interval": 1,
# Run evaluation on (at least) two episodes
"evaluation_num_episodes": 2,
# ... using one evaluation worker (setting this to 0 will cause
# evaluation to run on the local evaluation worker, blocking
# training until evaluation is done).
"evaluation_num_workers": 1,
# Special evaluation config. Keys specified here will override
# the same keys in the main config, but only for evaluation.
"evaluation_config": {
# Store videos in this relative directory inside
# the default output dir (~/ray_results/...).
# Alternatively, you can specify an absolute path.
# Set to True for using the default output dir (~/ray_results/...).
# Set to False for not recording anything.
"record_env": "videos",
# "record_env": "/Users/xyz/my_videos/",
# Render the env while evaluating.
# Note that this will always only render the 1st RolloutWorker's
# env and only the 1st sub-env in a vectorized env.
"render_env": True,
},
"num_workers": 1,
# Use a vectorized env with 2 sub-envs.
"num_envs_per_worker": 2,
"framework": args.framework,
}
stop = {
"training_iteration": args.stop_iters,
"timesteps_total": args.stop_timesteps,
"episode_reward_mean": args.stop_reward,
}
results = tune.run("PPO", config=config, stop=stop)
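
After the run finishes, the recorded mp4 files should land in a "videos" subdirectory of each trial's logdir, since `record_env` above is the relative path "videos". A hedged way to locate them from the returned analysis object (assumes `tune.run()` returned an ExperimentAnalysis, as in this Ray version):

import glob
import os

for trial in results.trials:
    print(glob.glob(os.path.join(trial.logdir, "videos", "*.mp4")))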