Mirror of https://github.com/vale981/ray, synced 2025-03-05 18:11:42 -05:00
[RLlib] Fix Atari learning test regressions (2 bugs) and 1 minor attention net bug. (#18306)
This commit is contained in:
parent fb38d06cfb
commit 9a8ca6a69d
13 changed files with 97 additions and 176 deletions
@@ -52,6 +52,33 @@ apex-breakoutnoframeskip-v4:
         target_network_update_freq: 50000
         timesteps_per_iteration: 25000

+appo-pong-no-frameskip-v4:
+    env: PongNoFrameskip-v4
+    run: APPO
+    # Minimum reward and total ts (in given time_total_s) to pass this test.
+    pass_criteria:
+        episode_reward_mean: 18.0
+        timesteps_total: 5000000
+    stop:
+        time_total_s: 2000
+    config:
+        vtrace: True
+        use_kl_loss: False
+        rollout_fragment_length: 50
+        train_batch_size: 750
+        num_workers: 31
+        broadcast_interval: 1
+        max_sample_requests_in_flight_per_worker: 1
+        num_multi_gpu_tower_stacks: 1
+        num_envs_per_worker: 8
+        num_sgd_iter: 2
+        vf_loss_coeff: 1.0
+        clip_param: 0.3
+        num_gpus: 1
+        grad_clip: 10
+        model:
+            dim: 42
+
 ddpg-hopperbulletenv-v0:
     env: HopperBulletEnv-v0
     run: DDPG
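For context, the added appo-pong-no-frameskip-v4 entry is a Tune experiment spec. A minimal Python sketch of launching an equivalent run (illustrative only; it assumes a local Ray cluster with 31 workers and one GPU, and the exact YAML file this hunk belongs to is not shown in the diff):

import ray
from ray import tune

ray.init()
tune.run(
    "APPO",                              # registered RLlib trainable
    stop={"time_total_s": 2000},
    config={
        "env": "PongNoFrameskip-v4",
        "vtrace": True,
        "use_kl_loss": False,
        "rollout_fragment_length": 50,
        "train_batch_size": 750,
        "num_workers": 31,
        "num_envs_per_worker": 8,
        "num_gpus": 1,
        "grad_clip": 10,
        "clip_param": 0.3,
        "model": {"dim": 42},            # resize Atari frames to 42x42
    },
)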
@@ -1845,7 +1845,7 @@ py_test(
     tags = ["team:ml", "examples", "examples_A"],
     size = "medium",
     srcs = ["examples/attention_net.py"],
-    args = ["--as-test", "--stop-reward=20"]
+    args = ["--as-test", "--stop-reward=60"]
 )

 py_test(
@@ -1854,7 +1854,7 @@ py_test(
     tags = ["team:ml", "examples", "examples_A"],
     size = "medium",
     srcs = ["examples/attention_net.py"],
-    args = ["--as-test", "--stop-reward=20", "--framework torch"]
+    args = ["--as-test", "--stop-reward=60", "--framework torch"]
 )

 py_test(
@@ -12,10 +12,12 @@ from ray.rllib.utils import check, check_compute_single_action, fc, \


 class TestPG(unittest.TestCase):
-    def setUp(self):
+    @classmethod
+    def setUpClass(cls) -> None:
         ray.init()

-    def tearDown(self):
+    @classmethod
+    def tearDownClass(cls) -> None:
         ray.shutdown()

     def test_pg_compilation(self):
rllib/env/utils.py (vendored, 2 changes)

@@ -61,7 +61,7 @@ a) For Atari support: `pip install gym[atari] atari_py`.
     For VizDoom support: Install VizDoom
     (https://github.com/mwydmuch/ViZDoom/blob/master/doc/Building.md) and
     `pip install vizdoomgym`.
-    For PyBullet support: `pip install pybullet pybullet_envs`.
+    For PyBullet support: `pip install pybullet`.
     b) To register your custom env, do `from ray import tune;
     tune.register('[name]', lambda cfg: [return env obj from here using cfg])`.
     Then in your config, do `config['env'] = [name]`.
rllib/env/wrappers/atari_wrappers.py (vendored, 17 changes)

@@ -283,17 +283,13 @@ class ScaledFloatFrame(gym.ObservationWrapper):
         return np.array(observation).astype(np.float32) / 255.0


-def wrap_deepmind(
-        env,
-        dim=84,
-        # TODO: (sven) Remove once traj. view is norm.
-        framestack=True,
-        framestack_via_traj_view_api=False):
+def wrap_deepmind(env, dim=84, framestack=True):
     """Configure environment for DeepMind-style Atari.

     Note that we assume reward clipping is done outside the wrapper.

     Args:
         env (EnvType): The env object to wrap.
         dim (int): Dimension to resize observations to (dim x dim).
         framestack (bool): Whether to framestack observations.
     """
@@ -307,12 +303,7 @@ def wrap_deepmind(
     env = WarpFrame(env, dim)
     # env = ScaledFloatFrame(env)  # TODO: use for dqn?
     # env = ClipRewardEnv(env)  # reward clipping is handled by policy eval
-    # New way of frame stacking via the trajectory view API (model config key:
-    # `num_framestacks=[int]`.
-    if framestack_via_traj_view_api:
-        env = FrameStackTrajectoryView(env)
-    # Old way (w/o traj. view API) via model config key: `framestack=True`.
-    # TODO: (sven) Remove once traj. view is norm.
-    elif framestack is True:
+    # 4x image framestacking.
+    if framestack is True:
         env = FrameStack(env, 4)
     return env
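A small usage sketch of the simplified wrapper signature (assuming gym and the Atari ROMs are installed; dim=42 matches the tuned APPO config above):

import gym
from ray.rllib.env.wrappers.atari_wrappers import wrap_deepmind

env = gym.make("PongNoFrameskip-v4")
# framestack=True applies the legacy 4x FrameStack wrapper.
env = wrap_deepmind(env, dim=42, framestack=True)
obs = env.reset()
print(obs.shape)  # expected (42, 42, 4): warped, grayscaled, 4-frame stack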
@@ -192,8 +192,10 @@ class _AgentCollector:
                 d.itemsize * int(np.product(d.shape[i + 1:]))
                 for i in range(1, len(d.shape))
             ]
+            start = self.shift_before - shift_win + 1 + obs_shift + \
+                view_req.shift_to
             data = np.lib.stride_tricks.as_strided(
-                d[self.shift_before - shift_win:],
+                d[start:start + self.agent_steps],
                 [self.agent_steps, shift_win
                  ] + [d.shape[i] for i in range(1, len(d.shape))],
                 [data_size, data_size] + strides)
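The fix above changes the base offset handed to np.lib.stride_tricks.as_strided so the sliding window starts at the right position in the per-agent buffer. A self-contained sketch of the pattern (illustrative values, not code from the repo):

import numpy as np

buf = np.arange(10, dtype=np.float32)   # stand-in for a per-agent data buffer
agent_steps, shift_win = 4, 3           # rows to emit, window length per row
start = 2                               # analogous to the computed `start` offset

# Row i is buf[start + i : start + i + shift_win]; the strides reuse the
# underlying buffer instead of copying it.
windows = np.lib.stride_tricks.as_strided(
    buf[start:start + agent_steps],
    shape=(agent_steps, shift_win),
    strides=(buf.itemsize, buf.itemsize))
print(windows)
# [[2. 3. 4.]
#  [3. 4. 5.]
#  [4. 5. 6.]
#  [5. 6. 7.]]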
@@ -31,7 +31,7 @@ from ray.rllib.policy.torch_policy import TorchPolicy
 from ray.rllib.utils import force_list, merge_dicts
 from ray.rllib.utils.annotations import DeveloperAPI
 from ray.rllib.utils.debug import summarize, update_global_seed_if_necessary
-from ray.rllib.utils.deprecation import DEPRECATED_VALUE, deprecation_warning
+from ray.rllib.utils.deprecation import deprecation_warning
 from ray.rllib.utils.error import EnvError
 from ray.rllib.utils.filter import get_filter, Filter
 from ray.rllib.utils.framework import try_import_tf, try_import_torch
@@ -400,7 +400,8 @@ class RolloutWorker(ParallelIteratorWorker):
         self.callbacks: DefaultCallbacks = DefaultCallbacks()
         self.worker_index: int = worker_index
         self.num_workers: int = num_workers
-        model_config: ModelConfigDict = model_config or {}
+        model_config: ModelConfigDict = \
+            model_config or self.policy_config.get("model") or {}

         # Default policy mapping fn is to always return DEFAULT_POLICY_ID,
         # independent on the agent ID and the episode passed in.
@@ -464,27 +465,14 @@ class RolloutWorker(ParallelIteratorWorker):
             if clip_rewards is None:
                 clip_rewards = True

-            # Deprecated way of framestacking is used.
-            framestack = model_config.get("framestack") is True
-            # framestacking via trajectory view API is enabled.
-            num_framestacks = model_config.get("num_framestacks", 0)
-
-            # Trajectory view API is on and num_framestacks=auto:
-            # Only stack traj. view based if old
-            # `framestack=[invalid value]`.
-            if num_framestacks == "auto":
-                if framestack == DEPRECATED_VALUE:
-                    model_config["num_framestacks"] = num_framestacks = 4
-                else:
-                    model_config["num_framestacks"] = num_framestacks = 0
-            framestack_traj_view = num_framestacks > 1
+            # Framestacking is used.
+            use_framestack = model_config.get("framestack") is True

             def wrap(env):
                 env = wrap_deepmind(
                     env,
                     dim=model_config.get("dim"),
-                    framestack=framestack,
-                    framestack_via_traj_view_api=framestack_traj_view)
+                    framestack=use_framestack)
                 env = record_env_wrapper(env, record_env, log_dir,
                                          policy_config)
                 return env
@@ -740,7 +728,8 @@ class RolloutWorker(ParallelIteratorWorker):
             return self.last_batch
         elif self.input_reader is None:
             raise ValueError("RolloutWorker has no `input_reader` object! "
-                             "Cannot call `sample()`.")
+                             "Cannot call `sample()`. You can try setting "
+                             "`create_env_on_driver` to True.")

         if log_once("sample_start"):
             logger.info("Generating sample batch of size {}".format(
@@ -1423,6 +1412,8 @@ def _determine_spaces_for_multi_agent_dict(
         policy_config: Optional[PartialTrainerConfigDict] = None,
 ) -> MultiAgentPolicyConfigDict:
+
+    policy_config = policy_config or {}

     # Try extracting spaces from env or from given spaces dict.
     env_obs_space = None
     env_act_space = None
@@ -1455,7 +1446,7 @@ def _determine_spaces_for_multi_agent_dict(
                 obs_space = spaces[pid][0]
             elif env_obs_space is not None:
                 obs_space = env_obs_space
-            elif policy_config and policy_config.get("observation_space"):
+            elif policy_config.get("observation_space"):
                 obs_space = policy_config["observation_space"]
             else:
                 raise ValueError(
@@ -1463,6 +1454,7 @@ def _determine_spaces_for_multi_agent_dict(
                     f"{pid} and env does not have an observation space OR "
                     "no spaces received from other workers' env(s) OR no "
                     "`observation_space` specified in config!")

         multi_agent_dict[pid] = multi_agent_dict[pid]._replace(
             observation_space=obs_space)

@@ -1471,7 +1463,7 @@ def _determine_spaces_for_multi_agent_dict(
                 act_space = spaces[pid][1]
             elif env_act_space is not None:
                 act_space = env_act_space
-            elif policy_config and policy_config.get("action_space"):
+            elif policy_config.get("action_space"):
                 act_space = policy_config["action_space"]
             else:
                 raise ValueError(
@@ -1489,8 +1481,7 @@ def _validate_env(env: EnvType, env_context: EnvContext = None):
     msg = f"Validating sub-env at vector index={env_context.vector_index} ..."

     allowed_types = [
-        gym.Env, MultiAgentEnv, ExternalEnv, VectorEnv, BaseEnv,
-        ray.actor.ActorHandle
+        gym.Env, ExternalEnv, VectorEnv, BaseEnv, ray.actor.ActorHandle
     ]
     if not any(isinstance(env, tpe) for tpe in allowed_types):
         # Allow this as a special case (assumed gym.Env).
@@ -1508,7 +1499,7 @@ def _validate_env(env: EnvType, env_context: EnvContext = None):
             f"(type={type(env)}).")

     # Do some test runs with the provided env.
-    if isinstance(env, gym.Env):
+    if isinstance(env, gym.Env) and not isinstance(env, MultiAgentEnv):
         # Make sure the gym.Env has the two space attributes properly set.
         assert hasattr(env, "observation_space") and hasattr(
             env, "action_space")
@@ -418,7 +418,6 @@ class WorkerSet:
             normalize_actions=config["normalize_actions"],
             clip_actions=config["clip_actions"],
             env_config=config["env_config"],
-            model_config=config["model"],
             policy_config=config,
             worker_index=worker_index,
             num_workers=num_workers,
@@ -131,17 +131,8 @@ MODEL_DEFAULTS: ModelConfigDict = {
     "attention_use_n_prev_rewards": 0,

     # == Atari ==
-    # Which framestacking size to use for Atari envs.
-    # "auto": Use a value of 4, but only if the env is an Atari env.
-    # > 1: Use the trajectory view API in the default VisionNets to request the
-    #      last n observations (single, grayscaled 84x84 image frames) as
-    #      inputs. The time axis in the so provided observation tensors
-    #      will come right after the batch axis (channels first format),
-    #      e.g. BxTx84x84, where T=num_framestacks.
-    # 0 or 1: No framestacking used.
-    # Use the deprecated `framestack=True`, to disable the above behavor and to
-    # enable legacy stacking behavior (w/o trajectory view API) instead.
-    "num_framestacks": "auto",
+    # Set to True to enable 4x stacking behavior.
+    "framestack": True,
     # Final resized frame dimension
     "dim": 84,
     # (deprecated) Converts ATARI frame to 1 Channel Grayscale image
@@ -166,8 +157,6 @@ MODEL_DEFAULTS: ModelConfigDict = {
     # Deprecated keys:
     # Use `lstm_use_prev_action` or `lstm_use_prev_reward` instead.
     "lstm_use_prev_action_reward": DEPRECATED_VALUE,
-    # Use `num_framestacks` (int) instead.
-    "framestack": True,
 }
 # __sphinx_doc_end__
 # yapf: enable
@@ -807,10 +796,6 @@ class ModelCatalog:
                 "framework={} not supported in `ModelCatalog._get_v2_model_"
                 "class`!".format(framework))

-        # Discrete/1D obs-spaces or 2D obs space but traj. view framestacking
-        # disabled.
-        num_framestacks = model_config.get("num_framestacks", "auto")
-
         # Tuple space, where at least one sub-space is image.
         # -> Complex input model.
         space_to_check = input_space if not hasattr(
@@ -824,8 +809,7 @@ class ModelCatalog:
         # Single, flattenable/one-hot-able space -> Simple FCNet.
         if isinstance(input_space, (Discrete, MultiDiscrete)) or \
                 len(input_space.shape) == 1 or (
-                    len(input_space.shape) == 2 and (
-                        num_framestacks == "auto" or num_framestacks <= 1)):
+                    len(input_space.shape) == 2):
             # Keras native requested AND no auto-rnn-wrapping.
             if model_config.get("_use_default_native_models") and Keras_FCNet:
                 return Keras_FCNet
@@ -886,10 +870,3 @@ class ModelCatalog:
         elif config.get("use_lstm"):
             raise ValueError("`use_lstm` not available for "
                              "framework=jax so far!")
-
-        if config.get("framestack") != DEPRECATED_VALUE:
-            # deprecation_warning(
-            #     old="framestack", new="num_framestacks (int)", error=False)
-            # If old behavior is desired, disable traj. view-style
-            # framestacking.
-            config["num_framestacks"] = 0
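With num_framestacks removed, Atari framestacking is again driven only by the legacy framestack and dim model keys shown in MODEL_DEFAULTS above. A hedged config sketch (key names taken from the defaults; the algorithm choice is illustrative):

config = {
    "env": "BreakoutNoFrameskip-v4",
    "model": {
        "framestack": True,  # 4x frame stacking, applied by wrap_deepmind
        "dim": 84,           # final resized frame dimension
    },
}
# e.g. tune.run("APPO", config=config, stop={"time_total_s": 2000})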
@@ -5,7 +5,6 @@ from ray.rllib.models.tf.tf_modelv2 import TFModelV2
 from ray.rllib.models.tf.misc import normc_initializer
 from ray.rllib.models.utils import get_activation_fn, get_filter_config
 from ray.rllib.policy.sample_batch import SampleBatch
-from ray.rllib.policy.view_requirement import ViewRequirement
 from ray.rllib.utils.framework import try_import_tf
 from ray.rllib.utils.typing import ModelConfigDict, TensorType

@@ -44,17 +43,9 @@ class VisionNetwork(TFModelV2):

         no_final_linear = self.model_config.get("no_final_linear")
         vf_share_layers = self.model_config.get("vf_share_layers")
-        self.traj_view_framestacking = False
-
-        # Perform Atari framestacking via traj. view API.
-        if model_config.get("num_framestacks") != "auto" and \
-                model_config.get("num_framestacks", 0) > 1:
-            input_shape = obs_space.shape + (model_config["num_framestacks"], )
-            self.data_format = "channels_first"
-            self.traj_view_framestacking = True
-        else:
-            input_shape = obs_space.shape
-            self.data_format = "channels_last"
+        input_shape = obs_space.shape
+        self.data_format = "channels_last"

         inputs = tf.keras.layers.Input(shape=input_shape, name="observations")
         last_layer = inputs
@@ -93,7 +84,10 @@ class VisionNetwork(TFModelV2):
             layer_sizes = post_fcnet_hiddens[:-1] + ([num_outputs]
                                                      if post_fcnet_hiddens else
                                                      [])
+            feature_out = last_layer
+
             for i, out_size in enumerate(layer_sizes):
+                feature_out = last_layer
                 last_layer = tf.keras.layers.Dense(
                     out_size,
                     name="post_fcnet_{}".format(i),
@@ -126,6 +120,7 @@ class VisionNetwork(TFModelV2):
                 # Add (optional) post-fc-stack after last Conv2D layer.
                 for i, out_size in enumerate(post_fcnet_hiddens[1:] +
                                              [num_outputs]):
+                    feature_out = last_layer
                     last_layer = tf.keras.layers.Dense(
                         out_size,
                         name="post_fcnet_{}".format(i + 1),
@@ -134,6 +129,7 @@ class VisionNetwork(TFModelV2):
                         kernel_initializer=normc_initializer(1.0))(
                             last_layer)
             else:
+                feature_out = last_layer
                 last_cnn = last_layer = tf.keras.layers.Conv2D(
                     num_outputs, [1, 1],
                     activation=None,
@@ -164,19 +160,20 @@ class VisionNetwork(TFModelV2):
                     name="post_fcnet_{}".format(i),
                     activation=post_fcnet_activation,
                     kernel_initializer=normc_initializer(1.0))(last_layer)
+            feature_out = last_layer
             self.num_outputs = last_layer.shape[1]
         logits_out = last_layer

         # Build the value layers
         if vf_share_layers:
             if not self.last_layer_is_flattened:
-                last_layer = tf.keras.layers.Lambda(
-                    lambda x: tf.squeeze(x, axis=[1, 2]))(last_layer)
+                feature_out = tf.keras.layers.Lambda(
+                    lambda x: tf.squeeze(x, axis=[1, 2]))(feature_out)
             value_out = tf.keras.layers.Dense(
                 1,
                 name="value_out",
                 activation=None,
-                kernel_initializer=normc_initializer(0.01))(last_layer)
+                kernel_initializer=normc_initializer(0.01))(feature_out)
         else:
             # build a parallel set of hidden layers for the value net
             last_layer = inputs
@@ -211,20 +208,6 @@ class VisionNetwork(TFModelV2):

         self.base_model = tf.keras.Model(inputs, [logits_out, value_out])

-        # Optional: framestacking obs/new_obs for Atari.
-        if self.traj_view_framestacking:
-            from_ = model_config["num_framestacks"] - 1
-            self.view_requirements[SampleBatch.OBS].shift = \
-                "-{}:0".format(from_)
-            self.view_requirements[SampleBatch.OBS].shift_from = -from_
-            self.view_requirements[SampleBatch.OBS].shift_to = 0
-            self.view_requirements[SampleBatch.NEXT_OBS] = ViewRequirement(
-                data_col=SampleBatch.OBS,
-                shift="-{}:1".format(from_ - 1),
-                space=self.view_requirements[SampleBatch.OBS].space,
-                used_for_compute_actions=False,
-            )
-
     def forward(self, input_dict: Dict[str, TensorType],
                 state: List[TensorType],
                 seq_lens: TensorType) -> (TensorType, List[TensorType]):
@@ -281,18 +264,8 @@ class Keras_VisionNetwork(tf.keras.Model if tf else object):
         post_fcnet_activation = get_activation_fn(
             post_fcnet_activation, framework="tf")

-        self.traj_view_framestacking = False
-
-        # Perform Atari framestacking via traj. view API.
-        num_framestacks = kwargs.get("num_framestacks")
-        if num_framestacks != "auto" and num_framestacks and \
-                num_framestacks > 1:
-            input_shape = input_space.shape + (num_framestacks, )
-            self.data_format = "channels_first"
-            self.traj_view_framestacking = True
-        else:
-            input_shape = input_space.shape
-            self.data_format = "channels_last"
+        input_shape = input_space.shape
+        self.data_format = "channels_last"

         inputs = tf.keras.layers.Input(shape=input_shape, name="observations")
         last_layer = inputs
@@ -448,20 +421,6 @@ class Keras_VisionNetwork(tf.keras.Model if tf else object):

         self.base_model = tf.keras.Model(inputs, [logits_out, value_out])

-        # Optional: framestacking obs/new_obs for Atari.
-        if self.traj_view_framestacking:
-            from_ = num_framestacks - 1
-            self.view_requirements[SampleBatch.OBS].shift = \
-                "-{}:0".format(from_)
-            self.view_requirements[SampleBatch.OBS].shift_from = -from_
-            self.view_requirements[SampleBatch.OBS].shift_to = 0
-            self.view_requirements[SampleBatch.NEXT_OBS] = ViewRequirement(
-                data_col=SampleBatch.OBS,
-                shift="-{}:1".format(from_ - 1),
-                space=self.view_requirements[SampleBatch.OBS].space,
-                used_for_compute_actions=False,
-            )
-
     def call(self, input_dict: SampleBatch) -> \
             (TensorType, List[TensorType], Dict[str, TensorType]):
         obs = input_dict["obs"]
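The feature_out bookkeeping above makes the shared value head read from the last feature layer instead of the logits layer. A minimal Keras sketch of that design (toy layer sizes, not the RLlib model itself):

import tensorflow as tf

inputs = tf.keras.layers.Input(shape=(84,))
feature_out = tf.keras.layers.Dense(256, activation="relu")(inputs)  # shared trunk
logits_out = tf.keras.layers.Dense(6)(feature_out)   # action logits head
value_out = tf.keras.layers.Dense(1)(feature_out)    # value head taps features, not logits
model = tf.keras.Model(inputs, [logits_out, value_out])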
@@ -6,8 +6,6 @@ from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
 from ray.rllib.models.torch.misc import normc_initializer, same_padding, \
     SlimConv2d, SlimFC
 from ray.rllib.models.utils import get_activation_fn, get_filter_config
-from ray.rllib.policy.sample_batch import SampleBatch
-from ray.rllib.policy.view_requirement import ViewRequirement
 from ray.rllib.utils.annotations import override
 from ray.rllib.utils.framework import try_import_torch
 from ray.rllib.utils.typing import ModelConfigDict, TensorType
@@ -46,17 +44,9 @@ class VisionNetwork(TorchModelV2, nn.Module):
         # a n x (1,1) Conv2D).
         self.last_layer_is_flattened = False
         self._logits = None
-        self.traj_view_framestacking = False

         layers = []
-        # Perform Atari framestacking via traj. view API.
-        if model_config.get("num_framestacks") != "auto" and \
-                model_config.get("num_framestacks", 0) > 1:
-            (w, h) = obs_space.shape
-            in_channels = model_config["num_framestacks"]
-            self.traj_view_framestacking = True
-        else:
-            (w, h, in_channels) = obs_space.shape
+        (w, h, in_channels) = obs_space.shape

         in_size = [w, h]
         for out_channels, kernel, stride in filters[:-1]:
@@ -165,11 +155,7 @@ class VisionNetwork(TorchModelV2, nn.Module):
                 activation_fn=None)
         else:
             vf_layers = []
-            if self.traj_view_framestacking:
-                (w, h) = obs_space.shape
-                in_channels = model_config["num_framestacks"]
-            else:
-                (w, h, in_channels) = obs_space.shape
+            (w, h, in_channels) = obs_space.shape
             in_size = [w, h]
             for out_channels, kernel, stride in filters[:-1]:
                 padding, out_size = same_padding(in_size, kernel, stride)
@@ -207,27 +193,13 @@ class VisionNetwork(TorchModelV2, nn.Module):
         # Holds the current "base" output (before logits layer).
         self._features = None

-        # Optional: framestacking obs/new_obs for Atari.
-        if self.traj_view_framestacking:
-            from_ = model_config["num_framestacks"] - 1
-            self.view_requirements[SampleBatch.OBS].shift = \
-                "-{}:0".format(from_)
-            self.view_requirements[SampleBatch.OBS].shift_from = -from_
-            self.view_requirements[SampleBatch.OBS].shift_to = 0
-            self.view_requirements[SampleBatch.NEXT_OBS] = ViewRequirement(
-                data_col=SampleBatch.OBS,
-                shift="-{}:1".format(from_ - 1),
-                space=self.view_requirements[SampleBatch.OBS].space,
-            )
-
     @override(TorchModelV2)
     def forward(self, input_dict: Dict[str, TensorType],
                 state: List[TensorType],
                 seq_lens: TensorType) -> (TensorType, List[TensorType]):
         self._features = input_dict["obs"].float()
-        # No framestacking:
-        if not self.traj_view_framestacking:
-            self._features = self._features.permute(0, 3, 1, 2)
+        # Permuate b/c data comes in as [B, dim, dim, channels]:
+        self._features = self._features.permute(0, 3, 1, 2)
         conv_out = self._convs(self._features)
         # Store features to save forward pass when getting value_function out.
         if not self._value_branch_separate:
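The restored permute in forward() converts NHWC observations (as produced by the Atari wrappers) into NCHW for nn.Conv2d. A standalone PyTorch sketch with assumed shapes:

import torch

obs = torch.rand(32, 42, 42, 4)       # [B, dim, dim, channels] from the env
nchw = obs.permute(0, 3, 1, 2)        # [B, channels, dim, dim] for Conv2d
assert nchw.shape == (32, 4, 42, 42)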
@@ -943,20 +943,27 @@ class SampleBatch(dict):
             # Range needed.
             if view_req.shift_from is not None:
                 data = self[view_col][-1]
-                traj_len = len(self[data_col])
-                missing_at_end = traj_len % view_req.batch_repeat_value
-                obs_shift = -1 if data_col in [
-                    SampleBatch.OBS, SampleBatch.NEXT_OBS
-                ] else 0
-                from_ = view_req.shift_from + obs_shift
-                to_ = view_req.shift_to + obs_shift + 1
-                if to_ == 0:
-                    to_ = None
-                input_dict[view_col] = np.array([
-                    np.concatenate(
-                        [data,
-                         self[data_col][-missing_at_end:]])[from_:to_]
-                ])
+                # Batch repeat value > 1: We have single frames in the
+                # batch at each timestep.
+                if view_req.batch_repeat_value > 1:
+                    traj_len = len(self[data_col])
+                    missing_at_end = traj_len % view_req.batch_repeat_value
+                    obs_shift = -1 if data_col in [
+                        SampleBatch.OBS, SampleBatch.NEXT_OBS
+                    ] else 0
+                    from_ = view_req.shift_from + obs_shift
+                    to_ = view_req.shift_to + obs_shift + 1
+                    if to_ == 0:
+                        to_ = None
+                    input_dict[view_col] = np.array([
+                        np.concatenate(
+                            [self[data_col][-missing_at_end:],
+                             data])[from_:to_]
+                    ])
+                # Batch repeat value = 1: We already have framestacks
+                # at each timestep.
+                else:
+                    input_dict[view_col] = data[None]
             # Single index.
             else:
                 data = self[data_col][-1]
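One detail in the hunk above: with negative start indices, an end index of 0 selects nothing, which is why to_ == 0 is mapped to None before slicing. A quick illustration:

import numpy as np

arr = np.arange(10)
print(arr[-3:0])     # [] -- empty, end index 0 lies before the start
print(arr[-3:None])  # [7 8 9] -- None means "through the end"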
@@ -336,12 +336,6 @@ def check_compute_single_action(trainer,
                 call_kwargs["clip_actions"] = True

             obs = obs_space.sample()
-            # Framestacking w/ traj. view API.
-            framestacks = pol.config["model"].get("num_framestacks",
-                                                  "auto")
-            if isinstance(framestacks, int) and framestacks > 1:
-                obs = np.stack(
-                    [obs] * pol.config["model"]["num_framestacks"])
             if isinstance(obs_space, gym.spaces.Box):
                 obs = np.clip(obs, -1.0, 1.0)
             state_in = None