[RLlib] Unity3D example broken due to change in ML-Agents API. Attention-net prev-n-a/r. Attention-wrapper works with images. (#14569)
Commit ee4b6e7e3b (parent c93961e070)
13 changed files with 350 additions and 51 deletions
@@ -609,19 +609,25 @@ class Trainer(Trainable):
         elif "." in env:
             self.env_creator = \
                 lambda env_context: from_config(env, env_context)
-        # Try gym/PyBullet.
+        # Try gym/PyBullet/Vizdoom.
         else:

             def _creator(env_context):
                 import gym
-                # Allow for PyBullet envs to be used as well (via string).
-                # This allows for doing things like
-                # `env=CartPoleContinuousBulletEnv-v0`.
+                # Allow for PyBullet or VizdoomGym envs to be used as well
+                # (via string). This allows for doing things like
+                # `env=CartPoleContinuousBulletEnv-v0` or
+                # `env=VizdoomBasic-v0`.
                 try:
                     import pybullet_envs
                     pybullet_envs.getList()
                 except (ModuleNotFoundError, ImportError):
                     pass
+                try:
+                    import vizdoomgym
+                    vizdoomgym.__name__  # trick LINTer.
+                except (ModuleNotFoundError, ImportError):
+                    pass
                 # Try creating a gym env. If this fails we can output a
                 # decent error message.
                 try:
@@ -629,12 +635,12 @@ class Trainer(Trainable):
                 except gym.error.Error:
                     raise ValueError(
                         "The env string you provided ({}) is a) not a "
-                        "known gym/PyBullet environment specifier or b) "
-                        "not registered! To register your custom envs, "
-                        "do `from ray import tune; tune.register('[name]',"
-                        " lambda cfg: [return actual "
-                        "env from here using cfg])`. Then you can use "
-                        "[name] as your config['env'].".format(env))
+                        "known gym/PyBullet/VizdoomEnv environment "
+                        "specifier or b) not registered! To register your "
+                        "custom envs, do `from ray import tune; "
+                        "tune.register('[name]', lambda cfg: [return "
+                        "actual env from here using cfg])`. Then you can "
+                        "use [name] as your config['env'].".format(env))

             self.env_creator = _creator
         else:
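Note: the error text above says `tune.register(...)`; the registry helper commonly used for this is `register_env` (importable from `ray.tune` or `ray.tune.registry`). A minimal sketch of that workflow, with `MyEnv` as a placeholder class that is not part of this commit:

```python
# Sketch: register a custom env under a string name and use that name as
# config["env"]. `MyEnv` is a placeholder, not part of this commit.
import gym
import ray
from ray import tune
from ray.tune.registry import register_env


class MyEnv(gym.Env):
    def __init__(self, env_config):
        self.observation_space = gym.spaces.Box(-1.0, 1.0, (4, ))
        self.action_space = gym.spaces.Discrete(2)

    def reset(self):
        return self.observation_space.sample()

    def step(self, action):
        # One-step episodes with a constant reward: just enough to run.
        return self.observation_space.sample(), 1.0, True, {}


if __name__ == "__main__":
    ray.init()
    # The creator receives the `env_config` dict from the Trainer config.
    register_env("my_env-v0", lambda cfg: MyEnv(cfg))
    tune.run("PPO", config={"env": "my_env-v0"},
             stop={"training_iteration": 1})
```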
rllib/env/wrappers/unity3d_env.py (vendored, 56 changed lines)
@@ -73,11 +73,13 @@ class Unity3DEnv(MultiAgentEnv):
         from mlagents_envs.environment import UnityEnvironment

         # Try connecting to the Unity3D game instance. If a port is blocked
+        port_ = None
         while True:
             # Sleep for random time to allow for concurrent startup of many
             # environments (num_workers >> 1). Otherwise, would lead to port
             # conflicts sometimes.
-            time.sleep(random.randint(1, 10))
+            if port_ is not None:
+                time.sleep(random.randint(1, 10))
             port_ = port or (self._BASE_PORT_ENVIRONMENT
                              if file_name else self._BASE_PORT_EDITOR)
             # cache the worker_id and
@@ -101,6 +103,10 @@ class Unity3DEnv(MultiAgentEnv):
             else:
                 break

+        # ML-Agents API version.
+        self.api_version = self.unity_env.API_VERSION.split(".")
+        self.api_version = [int(s) for s in self.api_version]
+
         # Reset entire env every this number of step calls.
         self.episode_horizon = episode_horizon
         # Keep track of how many times we have called `step` so far.
@@ -128,16 +134,37 @@ class Unity3DEnv(MultiAgentEnv):
                 it. __all__=True, if episode is done for all agents.
             - infos: An (empty) info dict.
         """
+        from mlagents_envs.base_env import ActionTuple

         # Set only the required actions (from the DecisionSteps) in Unity3D.
         all_agents = []
         for behavior_name in self.unity_env.behavior_specs:
-            for agent_id in self.unity_env.get_steps(behavior_name)[
-                    0].agent_id_to_index.keys():
-                key = behavior_name + "_{}".format(agent_id)
-                all_agents.append(key)
-                self.unity_env.set_action_for_agent(behavior_name, agent_id,
-                                                    action_dict[key])
+            # New ML-Agents API: Set all agents actions at the same time
+            # via an ActionTuple. Since API v1.4.0.
+            if self.api_version[0] > 1 or (self.api_version[0] == 1
+                                           and self.api_version[1] >= 4):
+                actions = []
+                for agent_id in self.unity_env.get_steps(behavior_name)[
+                        0].agent_id:
+                    key = behavior_name + "_{}".format(agent_id)
+                    all_agents.append(key)
+                    actions.append(action_dict[key])
+                if actions:
+                    if actions[0].dtype == np.float32:
+                        action_tuple = ActionTuple(
+                            continuous=np.array(actions))
+                    else:
+                        action_tuple = ActionTuple(discrete=np.array(actions))
+                    self.unity_env.set_actions(behavior_name, action_tuple)
+            # Old behavior: Do not use an ActionTuple and set each agent's
+            # action individually.
+            else:
+                for agent_id in self.unity_env.get_steps(behavior_name)[
+                        0].agent_id_to_index.keys():
+                    key = behavior_name + "_{}".format(agent_id)
+                    all_agents.append(key)
+                    self.unity_env.set_action_for_agent(
+                        behavior_name, agent_id, action_dict[key])
         # Do the step.
         self.unity_env.step()
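The step() change above gates on the installed ML-Agents version and, on 1.4.0 or newer, batches all per-agent actions of a behavior into a single ActionTuple. A condensed sketch of just that gate (the agent-id bookkeeping is simplified to an `enumerate`; the real code takes agent ids from `get_steps()`):

```python
# Sketch of the version gate used above. Assumes mlagents_envs is installed.
import numpy as np
from mlagents_envs.base_env import ActionTuple
from mlagents_envs.environment import UnityEnvironment

# API_VERSION looks like "1.4.0" -> [1, 4, 0].
api_version = [int(s) for s in UnityEnvironment.API_VERSION.split(".")]


def send_actions(unity_env, behavior_name, actions):
    if api_version[0] > 1 or (api_version[0] == 1 and api_version[1] >= 4):
        # New API (>= 1.4.0): one ActionTuple per behavior; continuous vs.
        # discrete is decided by the dtype of the stacked action array.
        arr = np.array(actions)
        if arr.dtype == np.float32:
            unity_env.set_actions(behavior_name, ActionTuple(continuous=arr))
        else:
            unity_env.set_actions(behavior_name, ActionTuple(discrete=arr))
    else:
        # Old API: set each agent's action individually.
        for agent_id, action in enumerate(actions):
            unity_env.set_action_for_agent(behavior_name, agent_id, action)
```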
@@ -215,6 +242,8 @@ class Unity3DEnv(MultiAgentEnv):
             "3DBall": Box(float("-inf"), float("inf"), (8, )),
             # 3DBallHard.
             "3DBallHard": Box(float("-inf"), float("inf"), (45, )),
+            # GridFoodCollector
+            "GridFoodCollector": Box(float("-inf"), float("inf"), (40, 40, 6)),
             # Pyramids.
             "Pyramids": TupleSpace([
                 Box(float("-inf"), float("inf"), (56, )),
@@ -228,6 +257,15 @@ class Unity3DEnv(MultiAgentEnv):
                 Box(float("-inf"), float("inf"), (231, )),
                 Box(float("-inf"), float("inf"), (63, )),
             ]),
+            # Sorter.
+            "Sorter": TupleSpace([
+                Box(float("-inf"), float("inf"), (
+                    20,
+                    23,
+                )),
+                Box(float("-inf"), float("inf"), (10, )),
+                Box(float("-inf"), float("inf"), (8, )),
+            ]),
             # Tennis.
             "Tennis": Box(float("-inf"), float("inf"), (27, )),
             # VisualHallway.
@@ -247,11 +285,15 @@ class Unity3DEnv(MultiAgentEnv):
             # 3DBallHard.
             "3DBallHard": Box(
                 float("-inf"), float("inf"), (2, ), dtype=np.float32),
+            # GridFoodCollector.
+            "GridFoodCollector": MultiDiscrete([3, 3, 3, 2]),
             # Pyramids.
             "Pyramids": MultiDiscrete([5]),
             # SoccerStrikersVsGoalie.
             "Goalie": MultiDiscrete([3, 3, 3]),
             "Striker": MultiDiscrete([3, 3, 3]),
+            # Sorter.
+            "Sorter": MultiDiscrete([3, 3, 3]),
             # Tennis.
             "Tennis": Box(float("-inf"), float("inf"), (3, )),
             # VisualHallway.
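Each newly supported game gets both an observation-space and an action-space entry. For reference, the Sorter pair added above can be built standalone with plain gym spaces:

```python
# The Sorter spaces added above, as standalone gym spaces: a Tuple of three
# float Boxes for observations, three discrete branches of 3 for actions.
from gym.spaces import Box, MultiDiscrete, Tuple as TupleSpace

sorter_obs_space = TupleSpace([
    Box(float("-inf"), float("inf"), (20, 23)),
    Box(float("-inf"), float("inf"), (10, )),
    Box(float("-inf"), float("inf"), (8, )),
])
sorter_action_space = MultiDiscrete([3, 3, 3])

print(sorter_obs_space.spaces[0].shape)  # (20, 23)
print(sorter_action_space.sample())      # e.g. [0 2 1]
```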
@@ -34,12 +34,19 @@ parser.add_argument(
     type=str,
     default="3DBall",
     choices=[
-        "3DBall", "3DBallHard", "Pyramids", "SoccerStrikersVsGoalie", "Tennis",
-        "VisualHallway", "Walker"
+        "3DBall",
+        "3DBallHard",
+        "GridFoodCollector",
+        "Pyramids",
+        "SoccerStrikersVsGoalie",
+        "Sorter",
+        "Tennis",
+        "VisualHallway",
+        "Walker",
     ],
     help="The name of the Env to run in the Unity3D editor: `3DBall(Hard)?|"
-    "Pyramids|SoccerStrikersVsGoalie|Tennis|VisualHallway|Walker`"
-    "(feel free to add more and PR!)")
+    "Pyramids|GridFoodCollector|SoccerStrikersVsGoalie|Sorter|Tennis|"
+    "VisualHallway|Walker` (feel free to add more and PR!)")
 parser.add_argument(
     "--file-name",
     type=str,
@@ -135,6 +142,13 @@ if __name__ == "__main__":
             "forward_net_activation": "relu",
             "inverse_net_activation": "relu",
         }
+    elif args.env == "GridFoodCollector":
+        config["model"] = {
+            "conv_filters": [[16, [4, 4], 2], [32, [4, 4], 2],
+                             [256, [10, 10], 1]],
+        }
+    elif args.env == "Sorter":
+        config["model"]["use_attention"] = True

     stop = {
         "training_iteration": args.stop_iters,
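The GridFoodCollector filters above are sized for its (40, 40, 6) observation (see the space added to unity3d_env.py earlier): two stride-2 "same"-padded layers take 40x40 down to 10x10, and the final 10x10 "valid" layer reduces that to 1x1, so the 256-channel output acts as a flat feature vector. A quick check of that arithmetic (the same/valid split is the VisionNetwork convention visible in the visionnet hunks further down):

```python
import math

# Output-size check for the GridFoodCollector conv stack above.
def same_out(size, stride):            # "same" padding: ceil(size / stride)
    return math.ceil(size / stride)

def valid_out(size, kernel, stride):   # "valid" padding
    return (size - kernel) // stride + 1

s = 40                     # obs is (40, 40, 6)
s = same_out(s, 2)         # [16, [4, 4], 2]    -> 20
s = same_out(s, 2)         # [32, [4, 4], 2]    -> 10
s = valid_out(s, 10, 1)    # [256, [10, 10], 1] -> 1
print(s)                   # 1
```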
@@ -148,7 +162,8 @@ if __name__ == "__main__":
         config=config,
         stop=stop,
         verbose=1,
-        checkpoint_freq=10,
+        checkpoint_freq=5,
+        checkpoint_at_end=True,
         restore=args.from_checkpoint)

     # And check the results.
rllib/examples/vizdoom_with_attention_net.py (new file, 77 lines)
@@ -0,0 +1,77 @@
+import argparse
+import os
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--run", type=str, default="PPO")
+parser.add_argument("--num-cpus", type=int, default=0)
+parser.add_argument(
+    "--framework", choices=["tf2", "tf", "tfe", "torch"], default="tf")
+parser.add_argument(
+    "--from-checkpoint",
+    type=str,
+    default=None,
+    help="Full path to a checkpoint file for restoring a previously saved "
+    "Trainer state.")
+parser.add_argument("--num-workers", type=int, default=0)
+parser.add_argument(
+    "--use-n-prev-actions",
+    type=int,
+    default=0,
+    help="How many of the previous actions to use as attention input.")
+parser.add_argument(
+    "--use-n-prev-rewards",
+    type=int,
+    default=0,
+    help="How many of the previous rewards to use as attention input.")
+parser.add_argument("--stop-iters", type=int, default=9999)
+parser.add_argument("--stop-timesteps", type=int, default=100000000)
+parser.add_argument("--stop-reward", type=float, default=1000.0)
+
+if __name__ == "__main__":
+    import ray
+    from ray import tune
+
+    args = parser.parse_args()
+
+    ray.init(num_cpus=args.num_cpus or None)
+
+    config = {
+        "env": "VizdoomBasic-v0",
+        # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
+        "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
+        "model": {
+            "conv_filters": [],
+            "use_attention": True,
+            "attention_num_transformer_units": 1,
+            "attention_dim": 64,
+            "attention_num_heads": 2,
+            "attention_memory_inference": 100,
+            "attention_memory_training": 50,
+            "vf_share_layers": True,
+            "attention_use_n_prev_actions": args.use_n_prev_actions,
+            "attention_use_n_prev_rewards": args.use_n_prev_rewards,
+        },
+        "framework": args.framework,
+        # Run with tracing enabled for tfe/tf2.
+        "eager_tracing": args.framework in ["tfe", "tf2"],
+        "num_workers": args.num_workers,
+        "vf_loss_coeff": 0.01,
+    }
+
+    stop = {
+        "training_iteration": args.stop_iters,
+        "timesteps_total": args.stop_timesteps,
+        "episode_reward_mean": args.stop_reward,
+    }
+
+    results = tune.run(
+        args.run,
+        config=config,
+        stop=stop,
+        verbose=2,
+        checkpoint_freq=5,
+        checkpoint_at_end=True,
+        restore=args.from_checkpoint,
+    )
+    print(results)
+    ray.shutdown()
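The example relies on the trainer.py change above to resolve the bare "VizdoomBasic-v0" string, which only works when the `vizdoomgym` package (and its `vizdoom` dependency) is importable. A quick local sanity check under that assumption:

```python
# Verify the env id used by the example resolves locally. Assumes the
# `vizdoomgym` package is installed; importing it registers the Vizdoom* ids.
import gym
import vizdoomgym  # noqa: F401

env = gym.make("VizdoomBasic-v0")
obs = env.reset()
print(obs.shape)  # an image observation; 240x320 per the filter config below
env.close()
```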
@@ -115,10 +115,10 @@ MODEL_DEFAULTS: ModelConfigDict = {
     "attention_position_wise_mlp_dim": 32,
     # The initial bias values for the 2 GRU gates within a transformer unit.
     "attention_init_gru_gate_bias": 2.0,
-    # TODO: Whether to feed a_{t-n:t-1} to GTrXL (one-hot encoded if discrete).
-    # "attention_use_n_prev_actions": 0,
+    # Whether to feed a_{t-n:t-1} to GTrXL (one-hot encoded if discrete).
+    "attention_use_n_prev_actions": 0,
     # Whether to feed r_{t-n:t-1} to GTrXL.
-    # "attention_use_n_prev_rewards": 0,
+    "attention_use_n_prev_rewards": 0,

     # == Atari ==
     # Which framestacking size to use for Atari envs.
@@ -414,14 +414,13 @@ def restore_original_dimensions(obs: TensorType,
             observation space.
     """
-
-    if tensorlib == "tf":
+    if tensorlib in ["tf", "tfe", "tf2"]:
         assert tf is not None
         tensorlib = tf
     elif tensorlib == "torch":
         assert torch is not None
         tensorlib = torch
     original_space = getattr(obs_space, "original_space", obs_space)
     if original_space is obs_space:
         return obs
     return _unpack_obs(obs, original_space, tensorlib=tensorlib)

@@ -450,7 +449,12 @@ def _unpack_obs(obs: TensorType, space: gym.Space,
             # Make an attempt to cache the result, if enough space left.
             if len(_cache) < 999:
                 _cache[id(space)] = prep
-        if len(obs.shape) < 2 or obs.shape[-1] != prep.shape[0]:
+        # Already unpacked?
+        if (isinstance(space, gym.spaces.Tuple) and
+                isinstance(obs, (list, tuple))) or \
+                (isinstance(space, gym.spaces.Dict) and isinstance(obs, dict)):
+            return obs
+        elif len(obs.shape) < 2 or obs.shape[-1] != prep.shape[0]:
             raise ValueError(
                 "Expected flattened obs shape of [..., {}], got {}".format(
                     prep.shape[0], obs.shape))
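The new guard lets `_unpack_obs` pass through observations that are already structured (a list/tuple for a Tuple space, a dict for a Dict space) instead of requiring a flattened tensor; only flat inputs with the wrong last dimension still raise. The shape logic in isolation, with plain gym and numpy:

```python
# Minimal illustration of the "already unpacked?" check, outside of RLlib.
import numpy as np
import gym

space = gym.spaces.Dict({
    "pos": gym.spaces.Box(0.0, 1.0, (4, )),
    "vel": gym.spaces.Box(0.0, 1.0, (2, )),
})

def looks_unpacked(space, obs):
    # Structured obs for a structured space pass through untouched.
    if isinstance(space, gym.spaces.Dict) and isinstance(obs, dict):
        return True
    if isinstance(space, gym.spaces.Tuple) and isinstance(obs, (list, tuple)):
        return True
    return False

dict_obs = {"pos": np.zeros((32, 4)), "vel": np.zeros((32, 2))}
flat_obs = np.zeros((32, gym.spaces.flatdim(space)))  # (32, 6)

print(looks_unpacked(space, dict_obs))  # True  -> returned as-is
print(looks_unpacked(space, flat_obs))  # False -> goes through unflattening
```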
@@ -22,6 +22,7 @@ from ray.rllib.policy.sample_batch import SampleBatch
 from ray.rllib.policy.view_requirement import ViewRequirement
 from ray.rllib.utils.annotations import override
 from ray.rllib.utils.framework import try_import_tf
+from ray.rllib.utils.tf_ops import one_hot
 from ray.rllib.utils.typing import ModelConfigDict, TensorType, List

 tf1, tf, tfv = try_import_tf()
@@ -353,6 +354,9 @@ class AttentionWrapper(TFModelV2):

         super().__init__(obs_space, action_space, None, model_config, name)

+        self.use_n_prev_actions = model_config["attention_use_n_prev_actions"]
+        self.use_n_prev_rewards = model_config["attention_use_n_prev_rewards"]
+
         if isinstance(action_space, Discrete):
             self.action_dim = action_space.n
         elif isinstance(action_space, MultiDiscrete):
@@ -362,15 +366,30 @@ class AttentionWrapper(TFModelV2):
         else:
             self.action_dim = int(len(action_space))

+        # Add prev-action/reward nodes to input to LSTM.
+        if self.use_n_prev_actions:
+            self.num_outputs += self.use_n_prev_actions * self.action_dim
+        if self.use_n_prev_rewards:
+            self.num_outputs += self.use_n_prev_rewards
+
         cfg = model_config

         self.attention_dim = cfg["attention_dim"]

+        if self.num_outputs is not None:
+            in_space = gym.spaces.Box(
+                float("-inf"),
+                float("inf"),
+                shape=(self.num_outputs, ),
+                dtype=np.float32)
+        else:
+            in_space = obs_space
+
         # Construct GTrXL sub-module w/ num_outputs=None (so it does not
         # create a logits/value output; we'll do this ourselves in this wrapper
         # here).
         self.gtrxl = GTrXLNet(
-            obs_space,
+            in_space,
             action_space,
             None,
             model_config,
@@ -401,6 +420,20 @@ class AttentionWrapper(TFModelV2):
         self._value_branch = tf.keras.models.Model([input_], [out])

         self.view_requirements = self.gtrxl.view_requirements
+        self.view_requirements["obs"].space = self.obs_space

+        # Add prev-a/r to this model's view, if required.
+        if self.use_n_prev_actions:
+            self.view_requirements[SampleBatch.PREV_ACTIONS] = \
+                ViewRequirement(
+                    SampleBatch.ACTIONS,
+                    space=self.action_space,
+                    shift="-{}:-1".format(self.use_n_prev_actions))
+        if self.use_n_prev_rewards:
+            self.view_requirements[SampleBatch.PREV_REWARDS] = \
+                ViewRequirement(
+                    SampleBatch.REWARDS,
+                    shift="-{}:-1".format(self.use_n_prev_rewards))
+
     @override(RecurrentNetwork)
     def forward(self, input_dict: Dict[str, TensorType],
@@ -410,8 +443,41 @@ class AttentionWrapper(TFModelV2):
         # Push obs through "unwrapped" net's `forward()` first.
         wrapped_out, _ = self._wrapped_forward(input_dict, [], None)

+        # Concat. prev-action/reward if required.
+        prev_a_r = []
+        if self.use_n_prev_actions:
+            if isinstance(self.action_space, Discrete):
+                for i in range(self.use_n_prev_actions):
+                    prev_a_r.append(
+                        one_hot(input_dict[SampleBatch.PREV_ACTIONS][:, i],
+                                self.action_space))
+            elif isinstance(self.action_space, MultiDiscrete):
+                for i in range(
+                        self.use_n_prev_actions,
+                        step=self.action_space.shape[0]):
+                    prev_a_r.append(
+                        one_hot(
+                            tf.cast(
+                                input_dict[SampleBatch.PREV_ACTIONS]
+                                [:, i:i + self.action_space.shape[0]],
+                                tf.float32), self.action_space))
+            else:
+                prev_a_r.append(
+                    tf.reshape(
+                        tf.cast(input_dict[SampleBatch.PREV_ACTIONS],
+                                tf.float32),
+                        [-1, self.use_n_prev_actions * self.action_dim]))
+        if self.use_n_prev_rewards:
+            prev_a_r.append(
+                tf.reshape(
+                    tf.cast(input_dict[SampleBatch.PREV_REWARDS], tf.float32),
+                    [-1, self.use_n_prev_rewards]))
+
+        if prev_a_r:
+            wrapped_out = tf.concat([wrapped_out] + prev_a_r, axis=1)
+
         # Then through our GTrXL.
-        input_dict["obs_flat"] = wrapped_out
+        input_dict["obs_flat"] = input_dict["obs"] = wrapped_out

         self._features, memory_outs = self.gtrxl(input_dict, state, seq_lens)
         model_out = self._logits_branch(self._features)
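One detail worth flagging in the MultiDiscrete branch above (the torch wrapper further down uses the same pattern): Python's built-in `range()` takes no `step` keyword, so `range(self.use_n_prev_actions, step=...)` would raise a TypeError if that branch were hit. The loop appears meant to walk the flattened prev-action columns one MultiDiscrete vector at a time; with positional range arguments that would look roughly like this (names are local to the sketch):

```python
# Sketch of the intended column walk with range(start, stop, step).
n_prev = 2      # number of previous MultiDiscrete actions fed to the model
branches = 3    # action_space.shape[0], i.e. number of sub-action dimensions

for i in range(0, n_prev * branches, branches):
    # Each slice [:, i:i + branches] would be one previous action vector.
    print("columns", i, "to", i + branches - 1)
```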
@@ -70,7 +70,8 @@ class VisionNetwork(TFModelV2):
             last_layer = tf.keras.layers.Conv2D(
                 out_size,
                 kernel,
-                strides=(stride, stride),
+                strides=stride
+                if isinstance(stride, (list, tuple)) else (stride, stride),
                 activation=activation,
                 padding="same",
                 data_format="channels_last",
@@ -85,7 +86,8 @@ class VisionNetwork(TFModelV2):
             last_layer = tf.keras.layers.Conv2D(
                 out_size if post_fcnet_hiddens else num_outputs,
                 kernel,
-                strides=(stride, stride),
+                strides=stride
+                if isinstance(stride, (list, tuple)) else (stride, stride),
                 activation=activation,
                 padding="valid",
                 data_format="channels_last",
@@ -107,7 +109,8 @@ class VisionNetwork(TFModelV2):
             last_layer = tf.keras.layers.Conv2D(
                 out_size,
                 kernel,
-                strides=(stride, stride),
+                strides=stride
+                if isinstance(stride, (list, tuple)) else (stride, stride),
                 activation=activation,
                 padding="valid",
                 data_format="channels_last",
@@ -169,8 +172,9 @@ class VisionNetwork(TFModelV2):

         # Build the value layers
         if vf_share_layers:
-            last_layer = tf.keras.layers.Lambda(
-                lambda x: tf.squeeze(x, axis=[1, 2]))(last_layer)
+            if not self.last_layer_is_flattened:
+                last_layer = tf.keras.layers.Lambda(
+                    lambda x: tf.squeeze(x, axis=[1, 2]))(last_layer)
             value_out = tf.keras.layers.Dense(
                 1,
                 name="value_out",
@@ -183,7 +187,8 @@ class VisionNetwork(TFModelV2):
                 last_layer = tf.keras.layers.Conv2D(
                     out_size,
                     kernel,
-                    strides=(stride, stride),
+                    strides=stride
+                    if isinstance(stride, (list, tuple)) else (stride, stride),
                     activation=activation,
                     padding="same",
                     data_format="channels_last",
@@ -192,7 +197,8 @@ class VisionNetwork(TFModelV2):
                 last_layer = tf.keras.layers.Conv2D(
                     out_size,
                     kernel,
-                    strides=(stride, stride),
+                    strides=stride
+                    if isinstance(stride, (list, tuple)) else (stride, stride),
                     activation=activation,
                     padding="valid",
                     data_format="channels_last",
@@ -23,6 +23,7 @@ from ray.rllib.policy.sample_batch import SampleBatch
 from ray.rllib.policy.view_requirement import ViewRequirement
 from ray.rllib.utils.annotations import override
 from ray.rllib.utils.framework import try_import_torch
+from ray.rllib.utils.torch_ops import one_hot
 from ray.rllib.utils.typing import ModelConfigDict, TensorType, List

 torch, nn = try_import_torch()
@@ -250,6 +251,9 @@ class AttentionWrapper(TorchModelV2, nn.Module):
         nn.Module.__init__(self)
         super().__init__(obs_space, action_space, None, model_config, name)

+        self.use_n_prev_actions = model_config["attention_use_n_prev_actions"]
+        self.use_n_prev_rewards = model_config["attention_use_n_prev_rewards"]
+
         if isinstance(action_space, Discrete):
             self.action_dim = action_space.n
         elif isinstance(action_space, MultiDiscrete):
@@ -259,15 +263,30 @@ class AttentionWrapper(TorchModelV2, nn.Module):
         else:
             self.action_dim = int(len(action_space))

+        # Add prev-action/reward nodes to input to LSTM.
+        if self.use_n_prev_actions:
+            self.num_outputs += self.use_n_prev_actions * self.action_dim
+        if self.use_n_prev_rewards:
+            self.num_outputs += self.use_n_prev_rewards
+
         cfg = model_config

         self.attention_dim = cfg["attention_dim"]

+        if self.num_outputs is not None:
+            in_space = gym.spaces.Box(
+                float("-inf"),
+                float("inf"),
+                shape=(self.num_outputs, ),
+                dtype=np.float32)
+        else:
+            in_space = obs_space
+
         # Construct GTrXL sub-module w/ num_outputs=None (so it does not
         # create a logits/value output; we'll do this ourselves in this wrapper
         # here).
         self.gtrxl = GTrXLNet(
-            obs_space,
+            in_space,
             action_space,
             None,
             model_config,
@@ -299,6 +318,20 @@ class AttentionWrapper(TorchModelV2, nn.Module):
             initializer=torch.nn.init.xavier_uniform_)

         self.view_requirements = self.gtrxl.view_requirements
+        self.view_requirements["obs"].space = self.obs_space

+        # Add prev-a/r to this model's view, if required.
+        if self.use_n_prev_actions:
+            self.view_requirements[SampleBatch.PREV_ACTIONS] = \
+                ViewRequirement(
+                    SampleBatch.ACTIONS,
+                    space=self.action_space,
+                    shift="-{}:-1".format(self.use_n_prev_actions))
+        if self.use_n_prev_rewards:
+            self.view_requirements[SampleBatch.PREV_REWARDS] = \
+                ViewRequirement(
+                    SampleBatch.REWARDS,
+                    shift="-{}:-1".format(self.use_n_prev_rewards))
+
     @override(RecurrentNetwork)
     def forward(self, input_dict: Dict[str, TensorType],
@@ -308,12 +341,43 @@ class AttentionWrapper(TorchModelV2, nn.Module):
         # Push obs through "unwrapped" net's `forward()` first.
         wrapped_out, _ = self._wrapped_forward(input_dict, [], None)

+        # Concat. prev-action/reward if required.
+        prev_a_r = []
+        if self.use_n_prev_actions:
+            if isinstance(self.action_space, Discrete):
+                for i in range(self.use_n_prev_actions):
+                    prev_a_r.append(
+                        one_hot(
+                            input_dict[SampleBatch.PREV_ACTIONS][:, i].float(),
+                            self.action_space))
+            elif isinstance(self.action_space, MultiDiscrete):
+                for i in range(
+                        self.use_n_prev_actions,
+                        step=self.action_space.shape[0]):
+                    prev_a_r.append(
+                        one_hot(
+                            input_dict[SampleBatch.PREV_ACTIONS]
+                            [:, i:i + self.action_space.shape[0]].float(),
+                            self.action_space))
+            else:
+                prev_a_r.append(
+                    torch.reshape(
+                        input_dict[SampleBatch.PREV_ACTIONS].float(),
+                        [-1, self.use_n_prev_actions * self.action_dim]))
+        if self.use_n_prev_rewards:
+            prev_a_r.append(
+                torch.reshape(input_dict[SampleBatch.PREV_REWARDS].float(),
+                              [-1, self.use_n_prev_rewards]))
+
+        if prev_a_r:
+            wrapped_out = torch.cat([wrapped_out] + prev_a_r, dim=1)
+
         # Then through our GTrXL.
-        input_dict["obs_flat"] = wrapped_out
+        input_dict["obs_flat"] = input_dict["obs"] = wrapped_out

         self._features, memory_outs = self.gtrxl(input_dict, state, seq_lens)
         model_out = self._logits_branch(self._features)
-        return model_out, [torch.squeeze(m, 0) for m in memory_outs]
+        return model_out, memory_outs

     @override(ModelV2)
     def get_initial_state(self) -> Union[List[np.ndarray], List[TensorType]]:
@@ -39,7 +39,10 @@ def same_padding(in_size: Tuple[int, int], filter_size: Tuple[int, int],
         filter_height, filter_width = filter_size, filter_size
     else:
         filter_height, filter_width = filter_size
-    stride_height, stride_width = stride_size
+    if isinstance(stride_size, (int, float)):
+        stride_height, stride_width = int(stride_size), int(stride_size)
+    else:
+        stride_height, stride_width = int(stride_size[0]), int(stride_size[1])

     out_height = np.ceil(float(in_height) / float(stride_height))
     out_width = np.ceil(float(in_width) / float(stride_width))
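`same_padding` now accepts either a scalar stride or a per-dimension pair, which the new 240x320 Vizdoom filter stack relies on (its first layer uses stride [7, 9]). With the ceil-division above, a 240x320 input comes out at 35x36 after that layer:

```python
import numpy as np

# Output-size part of same_padding's math for the Vizdoom first layer:
# stride [7, 9] on a 240x320 input (the [12, 16] kernel only affects padding).
in_height, in_width = 240, 320
stride_height, stride_width = 7, 9

out_height = int(np.ceil(float(in_height) / float(stride_height)))  # 35
out_width = int(np.ceil(float(in_width) / float(stride_width)))     # 36
print(out_height, out_width)
```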
@@ -60,7 +60,7 @@ class VisionNetwork(TorchModelV2, nn.Module):

        in_size = [w, h]
        for out_channels, kernel, stride in filters[:-1]:
-            padding, out_size = same_padding(in_size, kernel, [stride, stride])
+            padding, out_size = same_padding(in_size, kernel, stride)
            layers.append(
                SlimConv2d(
                    in_channels,
@@ -172,8 +172,7 @@ class VisionNetwork(TorchModelV2, nn.Module):
             (w, h, in_channels) = obs_space.shape
             in_size = [w, h]
             for out_channels, kernel, stride in filters[:-1]:
-                padding, out_size = same_padding(in_size, kernel,
-                                                 [stride, stride])
+                padding, out_size = same_padding(in_size, kernel, stride)
                 vf_layers.append(
                     SlimConv2d(
                         in_channels,
@@ -70,18 +70,29 @@ def get_filter_config(shape):
             inside a model config dict.
     """
     shape = list(shape)
+    # VizdoomGym.
+    filters_240x320x = [
+        [16, [12, 16], [7, 9]],
+        [32, [6, 6], 4],
+        [256, [9, 9], 1],
+    ]
+    # Atari.
     filters_84x84 = [
         [16, [8, 8], 4],
         [32, [4, 4], 2],
         [256, [11, 11], 1],
     ]
+    # Small (1/2) Atari.
     filters_42x42 = [
         [16, [4, 4], 2],
         [32, [4, 4], 2],
         [256, [11, 11], 1],
     ]
-    if len(shape) in [2, 3] and (shape[:2] == [84, 84]
-                                 or shape[1:] == [84, 84]):
+    if len(shape) in [2, 3] and (shape[:2] == [240, 320]
+                                 or shape[1:] == [240, 320]):
+        return filters_240x320x
+    elif len(shape) in [2, 3] and (shape[:2] == [84, 84]
+                                   or shape[1:] == [84, 84]):
        return filters_84x84
     elif len(shape) in [2, 3] and (shape[:2] == [42, 42]
                                    or shape[1:] == [42, 42]):
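With this lookup in place, a 240x320 image observation (channels first or last) automatically picks up the VizdoomGym filter stack whenever no `conv_filters` are given. A usage sketch, assuming `get_filter_config` is imported from its module in this tree (rllib/models/utils.py):

```python
# Sketch: automatic filter selection for Vizdoom-sized and Atari-sized images.
# Assumes get_filter_config lives in ray.rllib.models.utils in this tree.
from ray.rllib.models.utils import get_filter_config

print(get_filter_config((240, 320, 3)))   # channels-last Vizdoom frame
# -> [[16, [12, 16], [7, 9]], [32, [6, 6], 4], [256, [9, 9], 1]]
print(get_filter_config((3, 240, 320)))   # channels-first also matches
print(get_filter_config((84, 84, 3)))     # classic Atari stack
```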
@@ -214,15 +214,21 @@ class DynamicTFPolicy(TFPolicy):
                     self.view_requirements, existing_inputs)
         else:
             action_ph = ModelCatalog.get_action_placeholder(action_space)
-            prev_action_ph = ModelCatalog.get_action_placeholder(
-                action_space, "prev_action")
             if self.config["_use_trajectory_view_api"]:
+                prev_action_ph = {}
+                if SampleBatch.PREV_ACTIONS not in self.view_requirements:
+                    prev_action_ph = {
+                        SampleBatch.PREV_ACTIONS: ModelCatalog.
+                        get_action_placeholder(action_space, "prev_action")
+                    }
                 self._input_dict, self._dummy_batch = \
                     self._get_input_dict_and_dummy_batch(
                         self.view_requirements,
-                        {SampleBatch.ACTIONS: action_ph,
-                         SampleBatch.PREV_ACTIONS: prev_action_ph})
+                        dict({SampleBatch.ACTIONS: action_ph},
+                             **prev_action_ph))
             else:
+                prev_action_ph = ModelCatalog.get_action_placeholder(
+                    action_space, "prev_action")
                 self._input_dict = {
                     SampleBatch.CUR_OBS: tf1.placeholder(
                         tf.float32,