[RLlib] Fix Unity3D example broken by a change in the ML-Agents API. Attention-net prev-n-actions/rewards. Attention-wrapper works with images. (#14569)

Sven Mika 2021-03-12 18:27:25 +01:00 committed by GitHub
parent c93961e070
commit ee4b6e7e3b
13 changed files with 350 additions and 51 deletions


@ -609,19 +609,25 @@ class Trainer(Trainable):
elif "." in env:
self.env_creator = \
lambda env_context: from_config(env, env_context)
# Try gym/PyBullet.
# Try gym/PyBullet/Vizdoom.
else:
def _creator(env_context):
import gym
# Allow for PyBullet envs to be used as well (via string).
# This allows for doing things like
# `env=CartPoleContinuousBulletEnv-v0`.
# Allow for PyBullet or VizdoomGym envs to be used as well
# (via string). This allows for doing things like
# `env=CartPoleContinuousBulletEnv-v0` or
# `env=VizdoomBasic-v0`.
try:
import pybullet_envs
pybullet_envs.getList()
except (ModuleNotFoundError, ImportError):
pass
try:
import vizdoomgym
vizdoomgym.__name__ # trick LINTer.
except (ModuleNotFoundError, ImportError):
pass
# Try creating a gym env. If this fails we can output a
# decent error message.
try:
@ -629,12 +635,12 @@ class Trainer(Trainable):
except gym.error.Error:
raise ValueError(
"The env string you provided ({}) is a) not a "
"known gym/PyBullet environment specifier or b) "
"not registered! To register your custom envs, "
"do `from ray import tune; tune.register('[name]',"
" lambda cfg: [return actual "
"env from here using cfg])`. Then you can use "
"[name] as your config['env'].".format(env))
"known gym/PyBullet/VizdoomEnv environment "
"specifier or b) not registered! To register your "
"custom envs, do `from ray import tune; "
"tune.register('[name]', lambda cfg: [return "
"actual env from here using cfg])`. Then you can "
"use [name] as your config['env'].".format(env))
self.env_creator = _creator
else:
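For reference, a minimal sketch (not part of this diff) of what the new string lookup relies on: importing vizdoomgym registers its env IDs with gym as a side effect, after which a plain gym.make() on the string succeeds and RLlib can wrap the result.

import gym

try:
    # Side effect of the import: registers "VizdoomBasic-v0" and friends with gym.
    import vizdoomgym  # noqa: F401
except (ModuleNotFoundError, ImportError):
    vizdoomgym = None

if vizdoomgym is not None:
    env = gym.make("VizdoomBasic-v0")
    obs = env.reset()
    # Presumably a 240x320x3 frame, matching the new default conv filters below.
    print(obs.shape)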


@ -73,11 +73,13 @@ class Unity3DEnv(MultiAgentEnv):
from mlagents_envs.environment import UnityEnvironment
# Try connecting to the Unity3D game instance. If a port is blocked
port_ = None
while True:
# Sleep for random time to allow for concurrent startup of many
# environments (num_workers >> 1). Otherwise, would lead to port
# conflicts sometimes.
time.sleep(random.randint(1, 10))
if port_ is not None:
time.sleep(random.randint(1, 10))
port_ = port or (self._BASE_PORT_ENVIRONMENT
if file_name else self._BASE_PORT_EDITOR)
# cache the worker_id and
@ -101,6 +103,10 @@ class Unity3DEnv(MultiAgentEnv):
else:
break
# ML-Agents API version.
self.api_version = self.unity_env.API_VERSION.split(".")
self.api_version = [int(s) for s in self.api_version]
# Reset entire env every this number of step calls.
self.episode_horizon = episode_horizon
# Keep track of how many times we have called `step` so far.
@ -128,16 +134,37 @@ class Unity3DEnv(MultiAgentEnv):
it. __all__=True, if episode is done for all agents.
- infos: An (empty) info dict.
"""
from mlagents_envs.base_env import ActionTuple
# Set only the required actions (from the DecisionSteps) in Unity3D.
all_agents = []
for behavior_name in self.unity_env.behavior_specs:
for agent_id in self.unity_env.get_steps(behavior_name)[
0].agent_id_to_index.keys():
key = behavior_name + "_{}".format(agent_id)
all_agents.append(key)
self.unity_env.set_action_for_agent(behavior_name, agent_id,
action_dict[key])
# New ML-Agents API: Set all agents actions at the same time
# via an ActionTuple. Since API v1.4.0.
if self.api_version[0] > 1 or (self.api_version[0] == 1
and self.api_version[1] >= 4):
actions = []
for agent_id in self.unity_env.get_steps(behavior_name)[
0].agent_id:
key = behavior_name + "_{}".format(agent_id)
all_agents.append(key)
actions.append(action_dict[key])
if actions:
if actions[0].dtype == np.float32:
action_tuple = ActionTuple(
continuous=np.array(actions))
else:
action_tuple = ActionTuple(discrete=np.array(actions))
self.unity_env.set_actions(behavior_name, action_tuple)
# Old behavior: Do not use an ActionTuple and set each agent's
# action individually.
else:
for agent_id in self.unity_env.get_steps(behavior_name)[
0].agent_id_to_index.keys():
key = behavior_name + "_{}".format(agent_id)
all_agents.append(key)
self.unity_env.set_action_for_agent(
behavior_name, agent_id, action_dict[key])
# Do the step.
self.unity_env.step()
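For orientation, a minimal standalone sketch (not part of this diff) of the batched API the first branch above targets, assuming mlagents_envs with API version >= 1.4.0 and a Unity editor/binary waiting for a connection; the action shape below is illustrative only.

import numpy as np
from mlagents_envs.base_env import ActionTuple
from mlagents_envs.environment import UnityEnvironment

unity_env = UnityEnvironment(file_name=None)  # attach to a running editor instance
unity_env.reset()
for behavior_name in unity_env.behavior_specs:
    decision_steps, _ = unity_env.get_steps(behavior_name)
    n_agents = len(decision_steps.agent_id)
    # Hypothetical 2-dim continuous actions; real callers derive the size from
    # the behavior's action spec (or, as above, from the incoming action_dict).
    action_tuple = ActionTuple(
        continuous=np.zeros((n_agents, 2), dtype=np.float32))
    unity_env.set_actions(behavior_name, action_tuple)
unity_env.step()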
@ -215,6 +242,8 @@ class Unity3DEnv(MultiAgentEnv):
"3DBall": Box(float("-inf"), float("inf"), (8, )),
# 3DBallHard.
"3DBallHard": Box(float("-inf"), float("inf"), (45, )),
# GridFoodCollector
"GridFoodCollector": Box(float("-inf"), float("inf"), (40, 40, 6)),
# Pyramids.
"Pyramids": TupleSpace([
Box(float("-inf"), float("inf"), (56, )),
@ -228,6 +257,15 @@ class Unity3DEnv(MultiAgentEnv):
Box(float("-inf"), float("inf"), (231, )),
Box(float("-inf"), float("inf"), (63, )),
]),
# Sorter.
"Sorter": TupleSpace([
Box(float("-inf"), float("inf"), (
20,
23,
)),
Box(float("-inf"), float("inf"), (10, )),
Box(float("-inf"), float("inf"), (8, )),
]),
# Tennis.
"Tennis": Box(float("-inf"), float("inf"), (27, )),
# VisualHallway.
@ -247,11 +285,15 @@ class Unity3DEnv(MultiAgentEnv):
# 3DBallHard.
"3DBallHard": Box(
float("-inf"), float("inf"), (2, ), dtype=np.float32),
# GridFoodCollector.
"GridFoodCollector": MultiDiscrete([3, 3, 3, 2]),
# Pyramids.
"Pyramids": MultiDiscrete([5]),
# SoccerStrikersVsGoalie.
"Goalie": MultiDiscrete([3, 3, 3]),
"Striker": MultiDiscrete([3, 3, 3]),
# Sorter.
"Sorter": MultiDiscrete([3, 3, 3]),
# Tennis.
"Tennis": Box(float("-inf"), float("inf"), (3, )),
# VisualHallway.
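These per-game spaces feed the policy setup that the example script consumes, presumably via Unity3DEnv.get_policy_configs_for_game(). A hedged sketch; the import path may differ between Ray versions (ray.rllib.env.unity3d_env vs. ray.rllib.env.wrappers.unity3d_env).

from ray.rllib.env.wrappers.unity3d_env import Unity3DEnv

policies, policy_mapping_fn = Unity3DEnv.get_policy_configs_for_game("Sorter")
# Roughly: `policies` maps policy id(s) to (None, obs_space, action_space, config)
# tuples built from the dicts above; together with the mapping fn it plugs into
# config["multiagent"] in the example script.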


@ -34,12 +34,19 @@ parser.add_argument(
type=str,
default="3DBall",
choices=[
"3DBall", "3DBallHard", "Pyramids", "SoccerStrikersVsGoalie", "Tennis",
"VisualHallway", "Walker"
"3DBall",
"3DBallHard",
"GridFoodCollector",
"Pyramids",
"SoccerStrikersVsGoalie",
"Sorter",
"Tennis",
"VisualHallway",
"Walker",
],
help="The name of the Env to run in the Unity3D editor: `3DBall(Hard)?|"
"Pyramids|SoccerStrikersVsGoalie|Tennis|VisualHallway|Walker`"
"(feel free to add more and PR!)")
"Pyramids|GridFoodCollector|SoccerStrikersVsGoalie|Sorter|Tennis|"
"VisualHallway|Walker` (feel free to add more and PR!)")
parser.add_argument(
"--file-name",
type=str,
@ -135,6 +142,13 @@ if __name__ == "__main__":
"forward_net_activation": "relu",
"inverse_net_activation": "relu",
}
elif args.env == "GridFoodCollector":
config["model"] = {
"conv_filters": [[16, [4, 4], 2], [32, [4, 4], 2],
[256, [10, 10], 1]],
}
elif args.env == "Sorter":
config["model"]["use_attention"] = True
stop = {
"training_iteration": args.stop_iters,
@ -148,7 +162,8 @@ if __name__ == "__main__":
config=config,
stop=stop,
verbose=1,
checkpoint_freq=10,
checkpoint_freq=5,
checkpoint_at_end=True,
restore=args.from_checkpoint)
# And check the results.
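As a cross-check on the GridFoodCollector filter stack above: it reduces the (40, 40, 6) observation to a single 256-dim embedding, assuming RLlib's usual "same" padding for all but the last conv layer and "valid" padding for the last one (as in the vision-net code further down).

import numpy as np

# Spatial size after each "same"-padded, stride-2 layer of [[16, [4, 4], 2],
# [32, [4, 4], 2], ...]:
size = 40
for stride in [2, 2]:
    size = int(np.ceil(size / stride))  # 40 -> 20 -> 10
# The final [256, [10, 10], 1] layer covers the remaining 10x10 map with a
# "valid" convolution, leaving a 1x1x256 output that flattens to 256 features.
print(size)  # 10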


@ -0,0 +1,77 @@
import argparse
import os
parser = argparse.ArgumentParser()
parser.add_argument("--run", type=str, default="PPO")
parser.add_argument("--num-cpus", type=int, default=0)
parser.add_argument(
"--framework", choices=["tf2", "tf", "tfe", "torch"], default="tf")
parser.add_argument(
"--from-checkpoint",
type=str,
default=None,
help="Full path to a checkpoint file for restoring a previously saved "
"Trainer state.")
parser.add_argument("--num-workers", type=int, default=0)
parser.add_argument(
"--use-n-prev-actions",
type=int,
default=0,
help="How many of the previous actions to use as attention input.")
parser.add_argument(
"--use-n-prev-rewards",
type=int,
default=0,
help="How many of the previous rewards to use as attention input.")
parser.add_argument("--stop-iters", type=int, default=9999)
parser.add_argument("--stop-timesteps", type=int, default=100000000)
parser.add_argument("--stop-reward", type=float, default=1000.0)
if __name__ == "__main__":
import ray
from ray import tune
args = parser.parse_args()
ray.init(num_cpus=args.num_cpus or None)
config = {
"env": "VizdoomBasic-v0",
# Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
"num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
"model": {
"conv_filters": [],
"use_attention": True,
"attention_num_transformer_units": 1,
"attention_dim": 64,
"attention_num_heads": 2,
"attention_memory_inference": 100,
"attention_memory_training": 50,
"vf_share_layers": True,
"attention_use_n_prev_actions": args.use_n_prev_actions,
"attention_use_n_prev_rewards": args.use_n_prev_rewards,
},
"framework": args.framework,
# Run with tracing enabled for tfe/tf2.
"eager_tracing": args.framework in ["tfe", "tf2"],
"num_workers": args.num_workers,
"vf_loss_coeff": 0.01,
}
stop = {
"training_iteration": args.stop_iters,
"timesteps_total": args.stop_timesteps,
"episode_reward_mean": args.stop_reward,
}
results = tune.run(
args.run,
config=config,
stop=stop,
verbose=2,
checkpoint_freq=5,
checkpoint_at_end=True,
restore=args.from_checkpoint,
)
print(results)
ray.shutdown()
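The config above sets "conv_filters": [] and leans on RLlib's per-resolution defaults, which this commit extends with a 240x320 entry (see the get_filter_config hunk further down). A hedged sketch; the import path and the empty-list fallback are assumptions on my part.

from ray.rllib.models.utils import get_filter_config

obs_shape = (240, 320, 3)  # assumed VizdoomGym screen resolution
filters = get_filter_config(obs_shape)
# Expected: [[16, [12, 16], [7, 9]], [32, [6, 6], 4], [256, [9, 9], 1]], i.e. the
# new 240x320 defaults; an empty "conv_filters" setting appears to fall back to
# exactly this lookup inside the vision nets.
print(filters)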


@ -115,10 +115,10 @@ MODEL_DEFAULTS: ModelConfigDict = {
"attention_position_wise_mlp_dim": 32,
# The initial bias values for the 2 GRU gates within a transformer unit.
"attention_init_gru_gate_bias": 2.0,
# TODO: Whether to feed a_{t-n:t-1} to GTrXL (one-hot encoded if discrete).
# "attention_use_n_prev_actions": 0,
# Whether to feed a_{t-n:t-1} to GTrXL (one-hot encoded if discrete).
"attention_use_n_prev_actions": 0,
# Whether to feed r_{t-n:t-1} to GTrXL.
# "attention_use_n_prev_rewards": 0,
"attention_use_n_prev_rewards": 0,
# == Atari ==
# Which framestacking size to use for Atari envs.
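The two new settings enlarge whatever embedding feeds the GTrXL. A small worked sketch with hypothetical numbers (a Discrete(4) action space, not tied to any config above):

num_actions = 4            # Discrete(4) -> one-hot size 4 per previous action
wrapped_embedding = 256    # output size of the wrapped (e.g. vision) net
n_prev_actions = 2
n_prev_rewards = 2
gtrxl_input_size = (wrapped_embedding
                    + n_prev_actions * num_actions  # 2 one-hot actions -> 8
                    + n_prev_rewards)               # 2 scalar rewards  -> 2
print(gtrxl_input_size)  # 266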


@ -414,14 +414,13 @@ def restore_original_dimensions(obs: TensorType,
observation space.
"""
if tensorlib == "tf":
if tensorlib in ["tf", "tfe", "tf2"]:
assert tf is not None
tensorlib = tf
elif tensorlib == "torch":
assert torch is not None
tensorlib = torch
original_space = getattr(obs_space, "original_space", obs_space)
if original_space is obs_space:
return obs
return _unpack_obs(obs, original_space, tensorlib=tensorlib)
@ -450,7 +449,12 @@ def _unpack_obs(obs: TensorType, space: gym.Space,
# Make an attempt to cache the result, if enough space left.
if len(_cache) < 999:
_cache[id(space)] = prep
if len(obs.shape) < 2 or obs.shape[-1] != prep.shape[0]:
# Already unpacked?
if (isinstance(space, gym.spaces.Tuple) and
isinstance(obs, (list, tuple))) or \
(isinstance(space, gym.spaces.Dict) and isinstance(obs, dict)):
return obs
elif len(obs.shape) < 2 or obs.shape[-1] != prep.shape[0]:
raise ValueError(
"Expected flattened obs shape of [..., {}], got {}".format(
prep.shape[0], obs.shape))
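A small sketch (not from this diff) of what the new early-out accepts versus what still has to be flat; the spaces mirror the Sorter example above.

import gym
import numpy as np

space = gym.spaces.Tuple([
    gym.spaces.Box(-1.0, 1.0, (20, 23)),
    gym.spaces.Box(-1.0, 1.0, (10, )),
])
# Tuple space + already-structured (tuple) obs -> returned as-is.
structured = (np.zeros((4, 20, 23), np.float32), np.zeros((4, 10), np.float32))
# Flat obs -> still validated against the flattened size, here 20*23 + 10 = 470.
flat = np.zeros((4, 470), np.float32)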


@ -22,6 +22,7 @@ from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.view_requirement import ViewRequirement
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.tf_ops import one_hot
from ray.rllib.utils.typing import ModelConfigDict, TensorType, List
tf1, tf, tfv = try_import_tf()
@ -353,6 +354,9 @@ class AttentionWrapper(TFModelV2):
super().__init__(obs_space, action_space, None, model_config, name)
self.use_n_prev_actions = model_config["attention_use_n_prev_actions"]
self.use_n_prev_rewards = model_config["attention_use_n_prev_rewards"]
if isinstance(action_space, Discrete):
self.action_dim = action_space.n
elif isinstance(action_space, MultiDiscrete):
@ -362,15 +366,30 @@ class AttentionWrapper(TFModelV2):
else:
self.action_dim = int(len(action_space))
# Add prev-action/reward nodes to input to LSTM.
if self.use_n_prev_actions:
self.num_outputs += self.use_n_prev_actions * self.action_dim
if self.use_n_prev_rewards:
self.num_outputs += self.use_n_prev_rewards
cfg = model_config
self.attention_dim = cfg["attention_dim"]
if self.num_outputs is not None:
in_space = gym.spaces.Box(
float("-inf"),
float("inf"),
shape=(self.num_outputs, ),
dtype=np.float32)
else:
in_space = obs_space
# Construct GTrXL sub-module w/ num_outputs=None (so it does not
# create a logits/value output; we'll do this ourselves in this wrapper
# here).
self.gtrxl = GTrXLNet(
obs_space,
in_space,
action_space,
None,
model_config,
@ -401,6 +420,20 @@ class AttentionWrapper(TFModelV2):
self._value_branch = tf.keras.models.Model([input_], [out])
self.view_requirements = self.gtrxl.view_requirements
self.view_requirements["obs"].space = self.obs_space
# Add prev-a/r to this model's view, if required.
if self.use_n_prev_actions:
self.view_requirements[SampleBatch.PREV_ACTIONS] = \
ViewRequirement(
SampleBatch.ACTIONS,
space=self.action_space,
shift="-{}:-1".format(self.use_n_prev_actions))
if self.use_n_prev_rewards:
self.view_requirements[SampleBatch.PREV_REWARDS] = \
ViewRequirement(
SampleBatch.REWARDS,
shift="-{}:-1".format(self.use_n_prev_rewards))
@override(RecurrentNetwork)
def forward(self, input_dict: Dict[str, TensorType],
@ -410,8 +443,41 @@ class AttentionWrapper(TFModelV2):
# Push obs through "unwrapped" net's `forward()` first.
wrapped_out, _ = self._wrapped_forward(input_dict, [], None)
# Concat. prev-action/reward if required.
prev_a_r = []
if self.use_n_prev_actions:
if isinstance(self.action_space, Discrete):
for i in range(self.use_n_prev_actions):
prev_a_r.append(
one_hot(input_dict[SampleBatch.PREV_ACTIONS][:, i],
self.action_space))
elif isinstance(self.action_space, MultiDiscrete):
# Slice the flattened prev-actions in chunks of one full MultiDiscrete action.
for i in range(0, self.use_n_prev_actions,
               self.action_space.shape[0]):
prev_a_r.append(
one_hot(
tf.cast(
input_dict[SampleBatch.PREV_ACTIONS]
[:, i:i + self.action_space.shape[0]],
tf.float32), self.action_space))
else:
prev_a_r.append(
tf.reshape(
tf.cast(input_dict[SampleBatch.PREV_ACTIONS],
tf.float32),
[-1, self.use_n_prev_actions * self.action_dim]))
if self.use_n_prev_rewards:
prev_a_r.append(
tf.reshape(
tf.cast(input_dict[SampleBatch.PREV_REWARDS], tf.float32),
[-1, self.use_n_prev_rewards]))
if prev_a_r:
wrapped_out = tf.concat([wrapped_out] + prev_a_r, axis=1)
# Then through our GTrXL.
input_dict["obs_flat"] = wrapped_out
input_dict["obs_flat"] = input_dict["obs"] = wrapped_out
self._features, memory_outs = self.gtrxl(input_dict, state, seq_lens)
model_out = self._logits_branch(self._features)
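For context, a small sketch mirroring the view-requirement registration above; the shift string is what makes the last n actions show up as one extra column per row.

from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.view_requirement import ViewRequirement

# With use_n_prev_actions=3, the wrapper registers roughly this requirement:
vr = ViewRequirement(SampleBatch.ACTIONS, shift="-3:-1")
# At timestep t, the PREV_ACTIONS column then holds [a_{t-3}, a_{t-2}, a_{t-1}],
# which `forward()` one-hot encodes (if discrete) and concatenates to the
# wrapped model's output before it enters the GTrXL.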


@ -70,7 +70,8 @@ class VisionNetwork(TFModelV2):
last_layer = tf.keras.layers.Conv2D(
out_size,
kernel,
strides=(stride, stride),
strides=stride
if isinstance(stride, (list, tuple)) else (stride, stride),
activation=activation,
padding="same",
data_format="channels_last",
@ -85,7 +86,8 @@ class VisionNetwork(TFModelV2):
last_layer = tf.keras.layers.Conv2D(
out_size if post_fcnet_hiddens else num_outputs,
kernel,
strides=(stride, stride),
strides=stride
if isinstance(stride, (list, tuple)) else (stride, stride),
activation=activation,
padding="valid",
data_format="channels_last",
@ -107,7 +109,8 @@ class VisionNetwork(TFModelV2):
last_layer = tf.keras.layers.Conv2D(
out_size,
kernel,
strides=(stride, stride),
strides=stride
if isinstance(stride, (list, tuple)) else (stride, stride),
activation=activation,
padding="valid",
data_format="channels_last",
@ -169,8 +172,9 @@ class VisionNetwork(TFModelV2):
# Build the value layers
if vf_share_layers:
last_layer = tf.keras.layers.Lambda(
lambda x: tf.squeeze(x, axis=[1, 2]))(last_layer)
if not self.last_layer_is_flattened:
last_layer = tf.keras.layers.Lambda(
lambda x: tf.squeeze(x, axis=[1, 2]))(last_layer)
value_out = tf.keras.layers.Dense(
1,
name="value_out",
@ -183,7 +187,8 @@ class VisionNetwork(TFModelV2):
last_layer = tf.keras.layers.Conv2D(
out_size,
kernel,
strides=(stride, stride),
strides=stride
if isinstance(stride, (list, tuple)) else (stride, stride),
activation=activation,
padding="same",
data_format="channels_last",
@ -192,7 +197,8 @@ class VisionNetwork(TFModelV2):
last_layer = tf.keras.layers.Conv2D(
out_size,
kernel,
strides=(stride, stride),
strides=stride
if isinstance(stride, (list, tuple)) else (stride, stride),
activation=activation,
padding="valid",
data_format="channels_last",
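With this change a filter entry's stride may itself be a [stride_h, stride_w] pair, as in the new Vizdoom defaults. A sketch of the two accepted stride formats (values taken from the defaults added elsewhere in this commit):

conv_filters = [
    [16, [12, 16], [7, 9]],  # non-square kernel/stride -> passed through as (7, 9)
    [32, [6, 6], 4],         # scalar stride -> expanded to (4, 4), as before
    [256, [9, 9], 1],
]
for out_size, kernel, stride in conv_filters:
    strides = stride if isinstance(stride, (list, tuple)) else (stride, stride)
    print(out_size, kernel, strides)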


@ -23,6 +23,7 @@ from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.view_requirement import ViewRequirement
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.torch_ops import one_hot
from ray.rllib.utils.typing import ModelConfigDict, TensorType, List
torch, nn = try_import_torch()
@ -250,6 +251,9 @@ class AttentionWrapper(TorchModelV2, nn.Module):
nn.Module.__init__(self)
super().__init__(obs_space, action_space, None, model_config, name)
self.use_n_prev_actions = model_config["attention_use_n_prev_actions"]
self.use_n_prev_rewards = model_config["attention_use_n_prev_rewards"]
if isinstance(action_space, Discrete):
self.action_dim = action_space.n
elif isinstance(action_space, MultiDiscrete):
@ -259,15 +263,30 @@ class AttentionWrapper(TorchModelV2, nn.Module):
else:
self.action_dim = int(len(action_space))
# Add prev-action/reward nodes to input to LSTM.
if self.use_n_prev_actions:
self.num_outputs += self.use_n_prev_actions * self.action_dim
if self.use_n_prev_rewards:
self.num_outputs += self.use_n_prev_rewards
cfg = model_config
self.attention_dim = cfg["attention_dim"]
if self.num_outputs is not None:
in_space = gym.spaces.Box(
float("-inf"),
float("inf"),
shape=(self.num_outputs, ),
dtype=np.float32)
else:
in_space = obs_space
# Construct GTrXL sub-module w/ num_outputs=None (so it does not
# create a logits/value output; we'll do this ourselves in this wrapper
# here).
self.gtrxl = GTrXLNet(
obs_space,
in_space,
action_space,
None,
model_config,
@ -299,6 +318,20 @@ class AttentionWrapper(TorchModelV2, nn.Module):
initializer=torch.nn.init.xavier_uniform_)
self.view_requirements = self.gtrxl.view_requirements
self.view_requirements["obs"].space = self.obs_space
# Add prev-a/r to this model's view, if required.
if self.use_n_prev_actions:
self.view_requirements[SampleBatch.PREV_ACTIONS] = \
ViewRequirement(
SampleBatch.ACTIONS,
space=self.action_space,
shift="-{}:-1".format(self.use_n_prev_actions))
if self.use_n_prev_rewards:
self.view_requirements[SampleBatch.PREV_REWARDS] = \
ViewRequirement(
SampleBatch.REWARDS,
shift="-{}:-1".format(self.use_n_prev_rewards))
@override(RecurrentNetwork)
def forward(self, input_dict: Dict[str, TensorType],
@ -308,12 +341,43 @@ class AttentionWrapper(TorchModelV2, nn.Module):
# Push obs through "unwrapped" net's `forward()` first.
wrapped_out, _ = self._wrapped_forward(input_dict, [], None)
# Concat. prev-action/reward if required.
prev_a_r = []
if self.use_n_prev_actions:
if isinstance(self.action_space, Discrete):
for i in range(self.use_n_prev_actions):
prev_a_r.append(
one_hot(
input_dict[SampleBatch.PREV_ACTIONS][:, i].float(),
self.action_space))
elif isinstance(self.action_space, MultiDiscrete):
# Slice the flattened prev-actions in chunks of one full MultiDiscrete action.
for i in range(0, self.use_n_prev_actions,
               self.action_space.shape[0]):
prev_a_r.append(
one_hot(
input_dict[SampleBatch.PREV_ACTIONS]
[:, i:i + self.action_space.shape[0]].float(),
self.action_space))
else:
prev_a_r.append(
torch.reshape(
input_dict[SampleBatch.PREV_ACTIONS].float(),
[-1, self.use_n_prev_actions * self.action_dim]))
if self.use_n_prev_rewards:
prev_a_r.append(
torch.reshape(input_dict[SampleBatch.PREV_REWARDS].float(),
[-1, self.use_n_prev_rewards]))
if prev_a_r:
wrapped_out = torch.cat([wrapped_out] + prev_a_r, dim=1)
# Then through our GTrXL.
input_dict["obs_flat"] = wrapped_out
input_dict["obs_flat"] = input_dict["obs"] = wrapped_out
self._features, memory_outs = self.gtrxl(input_dict, state, seq_lens)
model_out = self._logits_branch(self._features)
return model_out, [torch.squeeze(m, 0) for m in memory_outs]
return model_out, memory_outs
@override(ModelV2)
def get_initial_state(self) -> Union[List[np.ndarray], List[TensorType]]:
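A small sketch (assumptions: gym's MultiDiscrete and RLlib's torch_ops.one_hot as imported above; the values are made up) of how a flattened batch of previous MultiDiscrete sub-actions gets sliced and one-hot encoded in chunks of action_space.shape[0], matching the loop in forward():

import torch
from gym.spaces import MultiDiscrete
from ray.rllib.utils.torch_ops import one_hot

action_space = MultiDiscrete([3, 3, 3])
prev_actions = torch.tensor([[0, 1, 2, 2, 0, 1]]).float()  # flattened sub-actions
chunks = [
    one_hot(prev_actions[:, i:i + 3], action_space)
    for i in range(0, prev_actions.shape[1], 3)
]
# Each chunk encodes one 3-component action -> shape [1, 3 + 3 + 3] == [1, 9].
print([tuple(c.shape) for c in chunks])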


@ -39,7 +39,10 @@ def same_padding(in_size: Tuple[int, int], filter_size: Tuple[int, int],
filter_height, filter_width = filter_size, filter_size
else:
filter_height, filter_width = filter_size
stride_height, stride_width = stride_size
if isinstance(stride_size, (int, float)):
stride_height, stride_width = int(stride_size), int(stride_size)
else:
stride_height, stride_width = int(stride_size[0]), int(stride_size[1])
out_height = np.ceil(float(in_height) / float(stride_height))
out_width = np.ceil(float(in_width) / float(stride_width))
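A worked check of the output-size arithmetic above, using the 240x320 Vizdoom defaults and their (7, 9) stride from this commit:

import numpy as np

in_height, in_width = 240, 320
stride_height, stride_width = 7, 9
out_height = int(np.ceil(in_height / stride_height))  # ceil(34.29) -> 35
out_width = int(np.ceil(in_width / stride_width))     # ceil(35.56) -> 36
print(out_height, out_width)  # 35 36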


@ -60,7 +60,7 @@ class VisionNetwork(TorchModelV2, nn.Module):
in_size = [w, h]
for out_channels, kernel, stride in filters[:-1]:
padding, out_size = same_padding(in_size, kernel, [stride, stride])
padding, out_size = same_padding(in_size, kernel, stride)
layers.append(
SlimConv2d(
in_channels,
@ -172,8 +172,7 @@ class VisionNetwork(TorchModelV2, nn.Module):
(w, h, in_channels) = obs_space.shape
in_size = [w, h]
for out_channels, kernel, stride in filters[:-1]:
padding, out_size = same_padding(in_size, kernel,
[stride, stride])
padding, out_size = same_padding(in_size, kernel, stride)
vf_layers.append(
SlimConv2d(
in_channels,


@ -70,18 +70,29 @@ def get_filter_config(shape):
inside a model config dict.
"""
shape = list(shape)
# VizdoomGym.
filters_240x320x = [
[16, [12, 16], [7, 9]],
[32, [6, 6], 4],
[256, [9, 9], 1],
]
# Atari.
filters_84x84 = [
[16, [8, 8], 4],
[32, [4, 4], 2],
[256, [11, 11], 1],
]
# Small (1/2) Atari.
filters_42x42 = [
[16, [4, 4], 2],
[32, [4, 4], 2],
[256, [11, 11], 1],
]
if len(shape) in [2, 3] and (shape[:2] == [84, 84]
or shape[1:] == [84, 84]):
if len(shape) in [2, 3] and (shape[:2] == [240, 320]
or shape[1:] == [240, 320]):
return filters_240x320x
elif len(shape) in [2, 3] and (shape[:2] == [84, 84]
or shape[1:] == [84, 84]):
return filters_84x84
elif len(shape) in [2, 3] and (shape[:2] == [42, 42]
or shape[1:] == [42, 42]):


@ -214,15 +214,21 @@ class DynamicTFPolicy(TFPolicy):
self.view_requirements, existing_inputs)
else:
action_ph = ModelCatalog.get_action_placeholder(action_space)
prev_action_ph = ModelCatalog.get_action_placeholder(
action_space, "prev_action")
if self.config["_use_trajectory_view_api"]:
prev_action_ph = {}
if SampleBatch.PREV_ACTIONS not in self.view_requirements:
prev_action_ph = {
SampleBatch.PREV_ACTIONS: ModelCatalog.
get_action_placeholder(action_space, "prev_action")
}
self._input_dict, self._dummy_batch = \
self._get_input_dict_and_dummy_batch(
self.view_requirements,
{SampleBatch.ACTIONS: action_ph,
SampleBatch.PREV_ACTIONS: prev_action_ph})
dict({SampleBatch.ACTIONS: action_ph},
**prev_action_ph))
else:
prev_action_ph = ModelCatalog.get_action_placeholder(
action_space, "prev_action")
self._input_dict = {
SampleBatch.CUR_OBS: tf1.placeholder(
tf.float32,