[RLlib] Fix Atari learning test regressions (2 bugs) and 1 minor attention net bug. (#18306)

Sven Mika 2021-09-03 13:29:57 +02:00 committed by GitHub
parent fb38d06cfb
commit 9a8ca6a69d
13 changed files with 97 additions and 176 deletions


@@ -52,6 +52,33 @@ apex-breakoutnoframeskip-v4:
target_network_update_freq: 50000
timesteps_per_iteration: 25000
appo-pong-no-frameskip-v4:
env: PongNoFrameskip-v4
run: APPO
# Minimum reward and total ts (in given time_total_s) to pass this test.
pass_criteria:
episode_reward_mean: 18.0
timesteps_total: 5000000
stop:
time_total_s: 2000
config:
vtrace: True
use_kl_loss: False
rollout_fragment_length: 50
train_batch_size: 750
num_workers: 31
broadcast_interval: 1
max_sample_requests_in_flight_per_worker: 1
num_multi_gpu_tower_stacks: 1
num_envs_per_worker: 8
num_sgd_iter: 2
vf_loss_coeff: 1.0
clip_param: 0.3
num_gpus: 1
grad_clip: 10
model:
dim: 42
ddpg-hopperbulletenv-v0:
env: HopperBulletEnv-v0
run: DDPG
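The new APPO tuned example above can also be launched outside the release-test harness. A minimal sketch using tune.run() (the config dict mirrors the yaml; pass_criteria checking is left to the harness, and the resource settings assume a machine with 1 GPU and room for 31 workers):

    import ray
    from ray import tune

    ray.init()
    tune.run(
        "APPO",
        config={
            "env": "PongNoFrameskip-v4",
            "vtrace": True,
            "use_kl_loss": False,
            "rollout_fragment_length": 50,
            "train_batch_size": 750,
            "num_workers": 31,
            "num_envs_per_worker": 8,
            "num_sgd_iter": 2,
            "num_gpus": 1,
            "grad_clip": 10,
            # Shrink Atari frames to 42x42 (instead of the default 84x84):
            "model": {"dim": 42},
        },
        stop={"time_total_s": 2000},
    )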


@@ -1845,7 +1845,7 @@ py_test(
tags = ["team:ml", "examples", "examples_A"],
size = "medium",
srcs = ["examples/attention_net.py"],
args = ["--as-test", "--stop-reward=20"]
args = ["--as-test", "--stop-reward=60"]
)
py_test(
@@ -1854,7 +1854,7 @@ py_test(
tags = ["team:ml", "examples", "examples_A"],
size = "medium",
srcs = ["examples/attention_net.py"],
args = ["--as-test", "--stop-reward=20", "--framework torch"]
args = ["--as-test", "--stop-reward=60", "--framework torch"]
)
py_test(
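These two targets raise the attention-net example's pass threshold from an episode reward of 20 to 60. The same checks can be reproduced locally; the commands below simply mirror the srcs/args above:

    python rllib/examples/attention_net.py --as-test --stop-reward=60
    python rllib/examples/attention_net.py --as-test --stop-reward=60 --framework torch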


@@ -12,10 +12,12 @@ from ray.rllib.utils import check, check_compute_single_action, fc, \
class TestPG(unittest.TestCase):
def setUp(self):
@classmethod
def setUpClass(cls) -> None:
ray.init()
def tearDown(self):
@classmethod
def tearDownClass(cls) -> None:
ray.shutdown()
def test_pg_compilation(self):
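Ray is now initialized once per test class rather than once per test method, which removes redundant cluster start/stop cycles. The pattern in isolation (hypothetical test body):

    import unittest
    import ray

    class TestExample(unittest.TestCase):
        @classmethod
        def setUpClass(cls) -> None:
            # Start Ray once, shared by all test methods in this class.
            ray.init()

        @classmethod
        def tearDownClass(cls) -> None:
            # Tear the shared Ray instance down after the last test.
            ray.shutdown()

        def test_ray_is_up(self):
            self.assertTrue(ray.is_initialized())

    if __name__ == "__main__":
        unittest.main()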

rllib/env/utils.py

@@ -61,7 +61,7 @@ a) For Atari support: `pip install gym[atari] atari_py`.
For VizDoom support: Install VizDoom
(https://github.com/mwydmuch/ViZDoom/blob/master/doc/Building.md) and
`pip install vizdoomgym`.
For PyBullet support: `pip install pybullet pybullet_envs`.
For PyBullet support: `pip install pybullet`.
b) To register your custom env, do `from ray import tune;
tune.register('[name]', lambda cfg: [return env obj from here using cfg])`.
Then in your config, do `config['env'] = [name]`.
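For completeness, the registration step this message refers to, sketched with the register_env helper from ray.tune.registry (env name and creator below are hypothetical placeholders):

    import gym
    from ray.tune.registry import register_env

    def env_creator(cfg):
        # Build and return your env object here, using the passed-in config.
        return gym.make("CartPole-v0")

    register_env("my_env", env_creator)
    # Then, in the trainer config:
    config = {"env": "my_env"}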


@@ -283,17 +283,13 @@ class ScaledFloatFrame(gym.ObservationWrapper):
return np.array(observation).astype(np.float32) / 255.0
def wrap_deepmind(
env,
dim=84,
# TODO: (sven) Remove once traj. view is norm.
framestack=True,
framestack_via_traj_view_api=False):
def wrap_deepmind(env, dim=84, framestack=True):
"""Configure environment for DeepMind-style Atari.
Note that we assume reward clipping is done outside the wrapper.
Args:
env (EnvType): The env object to wrap.
dim (int): Dimension to resize observations to (dim x dim).
framestack (bool): Whether to framestack observations.
"""
@@ -307,12 +303,7 @@ def wrap_deepmind(
env = WarpFrame(env, dim)
# env = ScaledFloatFrame(env) # TODO: use for dqn?
# env = ClipRewardEnv(env) # reward clipping is handled by policy eval
# New way of frame stacking via the trajectory view API (model config key:
# `num_framestacks=[int]`.
if framestack_via_traj_view_api:
env = FrameStackTrajectoryView(env)
# Old way (w/o traj. view API) via model config key: `framestack=True`.
# TODO: (sven) Remove once traj. view is norm.
elif framestack is True:
# 4x image framestacking.
if framestack is True:
env = FrameStack(env, 4)
return env
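Usage of the simplified wrapper, as a minimal sketch (assumes `gym[atari]` is installed and that the wrapper module path below matches this tree's layout):

    import gym
    # Import path assumed from this commit's tree layout:
    from ray.rllib.env.wrappers.atari_wrappers import wrap_deepmind

    env = gym.make("PongNoFrameskip-v4")
    # 84x84 grayscale frames plus legacy 4x framestacking:
    env = wrap_deepmind(env, dim=84, framestack=True)
    obs = env.reset()
    print(obs.shape)  # -> (84, 84, 4)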


@@ -192,8 +192,10 @@ class _AgentCollector:
d.itemsize * int(np.product(d.shape[i + 1:]))
for i in range(1, len(d.shape))
]
start = self.shift_before - shift_win + 1 + obs_shift + \
view_req.shift_to
data = np.lib.stride_tricks.as_strided(
d[self.shift_before - shift_win:],
d[start:start + self.agent_steps],
[self.agent_steps, shift_win
] + [d.shape[i] for i in range(1, len(d.shape))],
[data_size, data_size] + strides)
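The fix: the strided view must start at the window's true offset (derived from shift_before, the obs shift, and view_req.shift_to) rather than always at the head of the buffer. The zero-copy windowing trick itself, as a standalone sketch with toy shapes (not RLlib's collector):

    import numpy as np

    def rolling_window(d, w):
        # d: [T, ...] buffer of per-timestep frames. Returns a zero-copy
        # [T - w + 1, w, ...] view whose t-th entry is d[t:t + w].
        row_bytes = d.itemsize * int(np.prod(d.shape[1:]))
        shape = (d.shape[0] - w + 1, w) + d.shape[1:]
        strides = (row_bytes, row_bytes) + d.strides[1:]
        return np.lib.stride_tricks.as_strided(d, shape, strides)

    buf = np.arange(6 * 2 * 2, dtype=np.float32).reshape(6, 2, 2)
    stacks = rolling_window(buf, 4)          # shape: (3, 4, 2, 2)
    assert (stacks[1] == buf[1:5]).all()     # window starting at t=1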


@@ -31,7 +31,7 @@ from ray.rllib.policy.torch_policy import TorchPolicy
from ray.rllib.utils import force_list, merge_dicts
from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.utils.debug import summarize, update_global_seed_if_necessary
from ray.rllib.utils.deprecation import DEPRECATED_VALUE, deprecation_warning
from ray.rllib.utils.deprecation import deprecation_warning
from ray.rllib.utils.error import EnvError
from ray.rllib.utils.filter import get_filter, Filter
from ray.rllib.utils.framework import try_import_tf, try_import_torch
@@ -400,7 +400,8 @@ class RolloutWorker(ParallelIteratorWorker):
self.callbacks: DefaultCallbacks = DefaultCallbacks()
self.worker_index: int = worker_index
self.num_workers: int = num_workers
model_config: ModelConfigDict = model_config or {}
model_config: ModelConfigDict = \
model_config or self.policy_config.get("model") or {}
# Default policy mapping fn is to always return DEFAULT_POLICY_ID,
# independent on the agent ID and the episode passed in.
@@ -464,27 +465,14 @@ class RolloutWorker(ParallelIteratorWorker):
if clip_rewards is None:
clip_rewards = True
# Deprecated way of framestacking is used.
framestack = model_config.get("framestack") is True
# framestacking via trajectory view API is enabled.
num_framestacks = model_config.get("num_framestacks", 0)
# Trajectory view API is on and num_framestacks=auto:
# Only stack traj. view based if old
# `framestack=[invalid value]`.
if num_framestacks == "auto":
if framestack == DEPRECATED_VALUE:
model_config["num_framestacks"] = num_framestacks = 4
else:
model_config["num_framestacks"] = num_framestacks = 0
framestack_traj_view = num_framestacks > 1
# Framestacking is used.
use_framestack = model_config.get("framestack") is True
def wrap(env):
env = wrap_deepmind(
env,
dim=model_config.get("dim"),
framestack=framestack,
framestack_via_traj_view_api=framestack_traj_view)
framestack=use_framestack)
env = record_env_wrapper(env, record_env, log_dir,
policy_config)
return env
@@ -740,7 +728,8 @@ class RolloutWorker(ParallelIteratorWorker):
return self.last_batch
elif self.input_reader is None:
raise ValueError("RolloutWorker has no `input_reader` object! "
"Cannot call `sample()`.")
"Cannot call `sample()`. You can try setting "
"`create_env_on_driver` to True.")
if log_once("sample_start"):
logger.info("Generating sample batch of size {}".format(
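The expanded error message points at the relevant config switch: when the local (driver) worker is built without an env (the default with num_workers > 0), it has no input reader, so sample() on it fails. A config sketch (other keys elided):

    config = {
        "env": "CartPole-v0",
        # Also build an env on the driver's local worker, so that
        # local_worker().sample() has an input reader to pull from:
        "create_env_on_driver": True,
    }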
@@ -1423,6 +1412,8 @@ def _determine_spaces_for_multi_agent_dict(
policy_config: Optional[PartialTrainerConfigDict] = None,
) -> MultiAgentPolicyConfigDict:
policy_config = policy_config or {}
# Try extracting spaces from env or from given spaces dict.
env_obs_space = None
env_act_space = None
@@ -1455,7 +1446,7 @@
obs_space = spaces[pid][0]
elif env_obs_space is not None:
obs_space = env_obs_space
elif policy_config and policy_config.get("observation_space"):
elif policy_config.get("observation_space"):
obs_space = policy_config["observation_space"]
else:
raise ValueError(
@@ -1463,6 +1454,7 @@
f"{pid} and env does not have an observation space OR "
"no spaces received from other workers' env(s) OR no "
"`observation_space` specified in config!")
multi_agent_dict[pid] = multi_agent_dict[pid]._replace(
observation_space=obs_space)
@@ -1471,7 +1463,7 @@
act_space = spaces[pid][1]
elif env_act_space is not None:
act_space = env_act_space
elif policy_config and policy_config.get("action_space"):
elif policy_config.get("action_space"):
act_space = policy_config["action_space"]
else:
raise ValueError(
@@ -1489,8 +1481,7 @@ def _validate_env(env: EnvType, env_context: EnvContext = None):
msg = f"Validating sub-env at vector index={env_context.vector_index} ..."
allowed_types = [
gym.Env, MultiAgentEnv, ExternalEnv, VectorEnv, BaseEnv,
ray.actor.ActorHandle
gym.Env, ExternalEnv, VectorEnv, BaseEnv, ray.actor.ActorHandle
]
if not any(isinstance(env, tpe) for tpe in allowed_types):
# Allow this as a special case (assumed gym.Env).
@@ -1508,7 +1499,7 @@ def _validate_env(env: EnvType, env_context: EnvContext = None):
f"(type={type(env)}).")
# Do some test runs with the provided env.
if isinstance(env, gym.Env):
if isinstance(env, gym.Env) and not isinstance(env, MultiAgentEnv):
# Make sure the gym.Env has the two space attributes properly set.
assert hasattr(env, "observation_space") and hasattr(
env, "action_space")


@@ -418,7 +418,6 @@ class WorkerSet:
normalize_actions=config["normalize_actions"],
clip_actions=config["clip_actions"],
env_config=config["env_config"],
model_config=config["model"],
policy_config=config,
worker_index=worker_index,
num_workers=num_workers,


@@ -131,17 +131,8 @@ MODEL_DEFAULTS: ModelConfigDict = {
"attention_use_n_prev_rewards": 0,
# == Atari ==
# Which framestacking size to use for Atari envs.
# "auto": Use a value of 4, but only if the env is an Atari env.
# > 1: Use the trajectory view API in the default VisionNets to request the
# last n observations (single, grayscaled 84x84 image frames) as
# inputs. The time axis in the so provided observation tensors
# will come right after the batch axis (channels first format),
# e.g. BxTx84x84, where T=num_framestacks.
# 0 or 1: No framestacking used.
# Use the deprecated `framestack=True` to disable the above behavior and to
# enable legacy stacking behavior (w/o trajectory view API) instead.
"num_framestacks": "auto",
# Set to True to enable 4x stacking behavior.
"framestack": True,
# Final resized frame dimension
"dim": 84,
# (deprecated) Converts ATARI frame to 1 Channel Grayscale image
@@ -166,8 +157,6 @@ MODEL_DEFAULTS: ModelConfigDict = {
# Deprecated keys:
# Use `lstm_use_prev_action` or `lstm_use_prev_reward` instead.
"lstm_use_prev_action_reward": DEPRECATED_VALUE,
# Use `num_framestacks` (int) instead.
"framestack": True,
}
# __sphinx_doc_end__
# yapf: enable
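With num_framestacks removed, Atari preprocessing is again driven by the two model keys above. A minimal config sketch matching the restored defaults:

    config = {
        "env": "PongNoFrameskip-v4",
        "model": {
            "framestack": True,  # 4x frame stacking in the deepmind wrapper
            "dim": 42,           # resize frames to 42x42 (default: 84)
        },
    }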
@@ -807,10 +796,6 @@ class ModelCatalog:
"framework={} not supported in `ModelCatalog._get_v2_model_"
"class`!".format(framework))
# Discrete/1D obs-spaces or 2D obs space but traj. view framestacking
# disabled.
num_framestacks = model_config.get("num_framestacks", "auto")
# Tuple space, where at least one sub-space is image.
# -> Complex input model.
space_to_check = input_space if not hasattr(
@@ -824,8 +809,7 @@
# Single, flattenable/one-hot-able space -> Simple FCNet.
if isinstance(input_space, (Discrete, MultiDiscrete)) or \
len(input_space.shape) == 1 or (
len(input_space.shape) == 2 and (
num_framestacks == "auto" or num_framestacks <= 1)):
len(input_space.shape) == 2):
# Keras native requested AND no auto-rnn-wrapping.
if model_config.get("_use_default_native_models") and Keras_FCNet:
return Keras_FCNet
@@ -886,10 +870,3 @@
elif config.get("use_lstm"):
raise ValueError("`use_lstm` not available for "
"framework=jax so far!")
if config.get("framestack") != DEPRECATED_VALUE:
# deprecation_warning(
# old="framestack", new="num_framestacks (int)", error=False)
# If old behavior is desired, disable traj. view-style
# framestacking.
config["num_framestacks"] = 0


@@ -5,7 +5,6 @@ from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.models.tf.misc import normc_initializer
from ray.rllib.models.utils import get_activation_fn, get_filter_config
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.view_requirement import ViewRequirement
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.typing import ModelConfigDict, TensorType
@@ -44,17 +43,9 @@ class VisionNetwork(TFModelV2):
no_final_linear = self.model_config.get("no_final_linear")
vf_share_layers = self.model_config.get("vf_share_layers")
self.traj_view_framestacking = False
# Perform Atari framestacking via traj. view API.
if model_config.get("num_framestacks") != "auto" and \
model_config.get("num_framestacks", 0) > 1:
input_shape = obs_space.shape + (model_config["num_framestacks"], )
self.data_format = "channels_first"
self.traj_view_framestacking = True
else:
input_shape = obs_space.shape
self.data_format = "channels_last"
input_shape = obs_space.shape
self.data_format = "channels_last"
inputs = tf.keras.layers.Input(shape=input_shape, name="observations")
last_layer = inputs
@@ -93,7 +84,10 @@
layer_sizes = post_fcnet_hiddens[:-1] + ([num_outputs]
if post_fcnet_hiddens else
[])
feature_out = last_layer
for i, out_size in enumerate(layer_sizes):
feature_out = last_layer
last_layer = tf.keras.layers.Dense(
out_size,
name="post_fcnet_{}".format(i),
@@ -126,6 +120,7 @@
# Add (optional) post-fc-stack after last Conv2D layer.
for i, out_size in enumerate(post_fcnet_hiddens[1:] +
[num_outputs]):
feature_out = last_layer
last_layer = tf.keras.layers.Dense(
out_size,
name="post_fcnet_{}".format(i + 1),
@@ -134,6 +129,7 @@
kernel_initializer=normc_initializer(1.0))(
last_layer)
else:
feature_out = last_layer
last_cnn = last_layer = tf.keras.layers.Conv2D(
num_outputs, [1, 1],
activation=None,
@@ -164,19 +160,20 @@
name="post_fcnet_{}".format(i),
activation=post_fcnet_activation,
kernel_initializer=normc_initializer(1.0))(last_layer)
feature_out = last_layer
self.num_outputs = last_layer.shape[1]
logits_out = last_layer
# Build the value layers
if vf_share_layers:
if not self.last_layer_is_flattened:
last_layer = tf.keras.layers.Lambda(
lambda x: tf.squeeze(x, axis=[1, 2]))(last_layer)
feature_out = tf.keras.layers.Lambda(
lambda x: tf.squeeze(x, axis=[1, 2]))(feature_out)
value_out = tf.keras.layers.Dense(
1,
name="value_out",
activation=None,
kernel_initializer=normc_initializer(0.01))(last_layer)
kernel_initializer=normc_initializer(0.01))(feature_out)
else:
# build a parallel set of hidden layers for the value net
last_layer = inputs
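The feature_out bookkeeping above fixes the shared-value-function path: with vf_share_layers=True, the value head must branch off the last feature layer, not off the logits layer. The intended wiring in isolation (hypothetical layer sizes, 6 actions):

    import tensorflow as tf

    inputs = tf.keras.layers.Input(shape=(84, 84, 4), name="observations")
    x = tf.keras.layers.Conv2D(16, 8, strides=4, activation="relu")(inputs)
    x = tf.keras.layers.Conv2D(32, 4, strides=2, activation="relu")(x)
    feature_out = tf.keras.layers.Flatten()(x)          # shared trunk output
    logits_out = tf.keras.layers.Dense(6, name="logits")(feature_out)
    # Correct: the value branches off the features. The bug fixed here
    # effectively fed the logits tensor into the value head instead.
    value_out = tf.keras.layers.Dense(
        1, name="value_out", activation=None)(feature_out)
    model = tf.keras.Model(inputs, [logits_out, value_out])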
@@ -211,20 +208,6 @@
self.base_model = tf.keras.Model(inputs, [logits_out, value_out])
# Optional: framestacking obs/new_obs for Atari.
if self.traj_view_framestacking:
from_ = model_config["num_framestacks"] - 1
self.view_requirements[SampleBatch.OBS].shift = \
"-{}:0".format(from_)
self.view_requirements[SampleBatch.OBS].shift_from = -from_
self.view_requirements[SampleBatch.OBS].shift_to = 0
self.view_requirements[SampleBatch.NEXT_OBS] = ViewRequirement(
data_col=SampleBatch.OBS,
shift="-{}:1".format(from_ - 1),
space=self.view_requirements[SampleBatch.OBS].space,
used_for_compute_actions=False,
)
def forward(self, input_dict: Dict[str, TensorType],
state: List[TensorType],
seq_lens: TensorType) -> (TensorType, List[TensorType]):
@@ -281,18 +264,8 @@ class Keras_VisionNetwork(tf.keras.Model if tf else object):
post_fcnet_activation = get_activation_fn(
post_fcnet_activation, framework="tf")
self.traj_view_framestacking = False
# Perform Atari framestacking via traj. view API.
num_framestacks = kwargs.get("num_framestacks")
if num_framestacks != "auto" and num_framestacks and \
num_framestacks > 1:
input_shape = input_space.shape + (num_framestacks, )
self.data_format = "channels_first"
self.traj_view_framestacking = True
else:
input_shape = input_space.shape
self.data_format = "channels_last"
input_shape = input_space.shape
self.data_format = "channels_last"
inputs = tf.keras.layers.Input(shape=input_shape, name="observations")
last_layer = inputs
@@ -448,20 +421,6 @@ class Keras_VisionNetwork(tf.keras.Model if tf else object):
self.base_model = tf.keras.Model(inputs, [logits_out, value_out])
# Optional: framestacking obs/new_obs for Atari.
if self.traj_view_framestacking:
from_ = num_framestacks - 1
self.view_requirements[SampleBatch.OBS].shift = \
"-{}:0".format(from_)
self.view_requirements[SampleBatch.OBS].shift_from = -from_
self.view_requirements[SampleBatch.OBS].shift_to = 0
self.view_requirements[SampleBatch.NEXT_OBS] = ViewRequirement(
data_col=SampleBatch.OBS,
shift="-{}:1".format(from_ - 1),
space=self.view_requirements[SampleBatch.OBS].space,
used_for_compute_actions=False,
)
def call(self, input_dict: SampleBatch) -> \
(TensorType, List[TensorType], Dict[str, TensorType]):
obs = input_dict["obs"]


@@ -6,8 +6,6 @@ from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.models.torch.misc import normc_initializer, same_padding, \
SlimConv2d, SlimFC
from ray.rllib.models.utils import get_activation_fn, get_filter_config
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.view_requirement import ViewRequirement
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.typing import ModelConfigDict, TensorType
@@ -46,17 +44,9 @@ class VisionNetwork(TorchModelV2, nn.Module):
# a n x (1,1) Conv2D).
self.last_layer_is_flattened = False
self._logits = None
self.traj_view_framestacking = False
layers = []
# Perform Atari framestacking via traj. view API.
if model_config.get("num_framestacks") != "auto" and \
model_config.get("num_framestacks", 0) > 1:
(w, h) = obs_space.shape
in_channels = model_config["num_framestacks"]
self.traj_view_framestacking = True
else:
(w, h, in_channels) = obs_space.shape
(w, h, in_channels) = obs_space.shape
in_size = [w, h]
for out_channels, kernel, stride in filters[:-1]:
@@ -165,11 +155,7 @@
activation_fn=None)
else:
vf_layers = []
if self.traj_view_framestacking:
(w, h) = obs_space.shape
in_channels = model_config["num_framestacks"]
else:
(w, h, in_channels) = obs_space.shape
(w, h, in_channels) = obs_space.shape
in_size = [w, h]
for out_channels, kernel, stride in filters[:-1]:
padding, out_size = same_padding(in_size, kernel, stride)
@@ -207,27 +193,13 @@
# Holds the current "base" output (before logits layer).
self._features = None
# Optional: framestacking obs/new_obs for Atari.
if self.traj_view_framestacking:
from_ = model_config["num_framestacks"] - 1
self.view_requirements[SampleBatch.OBS].shift = \
"-{}:0".format(from_)
self.view_requirements[SampleBatch.OBS].shift_from = -from_
self.view_requirements[SampleBatch.OBS].shift_to = 0
self.view_requirements[SampleBatch.NEXT_OBS] = ViewRequirement(
data_col=SampleBatch.OBS,
shift="-{}:1".format(from_ - 1),
space=self.view_requirements[SampleBatch.OBS].space,
)
@override(TorchModelV2)
def forward(self, input_dict: Dict[str, TensorType],
state: List[TensorType],
seq_lens: TensorType) -> (TensorType, List[TensorType]):
self._features = input_dict["obs"].float()
# No framestacking:
if not self.traj_view_framestacking:
self._features = self._features.permute(0, 3, 1, 2)
# Permute, because data comes in as [B, dim, dim, channels]:
self._features = self._features.permute(0, 3, 1, 2)
conv_out = self._convs(self._features)
# Store features to save forward pass when getting value_function out.
if not self._value_branch_separate:
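The forward pass now permutes unconditionally, since observations always arrive channels-last. The transform in isolation (hypothetical batch):

    import torch

    obs = torch.rand(32, 84, 84, 4)       # [B, dim, dim, channels], as delivered
    features = obs.permute(0, 3, 1, 2)    # -> [B, C, H, W], as Conv2d expects
    conv = torch.nn.Conv2d(4, 16, kernel_size=8, stride=4)
    print(conv(features).shape)           # torch.Size([32, 16, 20, 20])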


@@ -943,20 +943,27 @@ class SampleBatch(dict):
# Range needed.
if view_req.shift_from is not None:
data = self[view_col][-1]
traj_len = len(self[data_col])
missing_at_end = traj_len % view_req.batch_repeat_value
obs_shift = -1 if data_col in [
SampleBatch.OBS, SampleBatch.NEXT_OBS
] else 0
from_ = view_req.shift_from + obs_shift
to_ = view_req.shift_to + obs_shift + 1
if to_ == 0:
to_ = None
input_dict[view_col] = np.array([
np.concatenate(
[data,
self[data_col][-missing_at_end:]])[from_:to_]
])
# Batch repeat value > 1: We have single frames in the
# batch at each timestep.
if view_req.batch_repeat_value > 1:
traj_len = len(self[data_col])
missing_at_end = traj_len % view_req.batch_repeat_value
obs_shift = -1 if data_col in [
SampleBatch.OBS, SampleBatch.NEXT_OBS
] else 0
from_ = view_req.shift_from + obs_shift
to_ = view_req.shift_to + obs_shift + 1
if to_ == 0:
to_ = None
input_dict[view_col] = np.array([
np.concatenate(
[self[data_col][-missing_at_end:],
data])[from_:to_]
])
# Batch repeat value = 1: We already have framestacks
# at each timestep.
else:
input_dict[view_col] = data[None]
# Single index.
else:
data = self[data_col][-1]


@@ -336,12 +336,6 @@ def check_compute_single_action(trainer,
call_kwargs["clip_actions"] = True
obs = obs_space.sample()
# Framestacking w/ traj. view API.
framestacks = pol.config["model"].get("num_framestacks",
"auto")
if isinstance(framestacks, int) and framestacks > 1:
obs = np.stack(
[obs] * pol.config["model"]["num_framestacks"])
if isinstance(obs_space, gym.spaces.Box):
obs = np.clip(obs, -1.0, 1.0)
state_in = None