[RLlib] Support native tf.keras.Models (part 2) - Default keras models for Vision/RNN/Attention. (#15273)
Parent: bdbf39f9d5
Commit: e973b726c2
26 changed files with 924 additions and 143 deletions
@@ -531,6 +531,12 @@ py_test(
     size = "large",
     srcs = ["agents/dqn/tests/test_dqn.py"]
 )
+py_test(
+    name = "test_r2d2",
+    tags = ["agents_dir"],
+    size = "large",
+    srcs = ["agents/dqn/tests/test_r2d2.py"]
+)
 py_test(
     name = "test_simple_q",
     tags = ["agents_dir"],
@@ -36,7 +36,7 @@ class TestR2D2(unittest.TestCase):
         config["lr"] = 5e-4
         config["exploration_config"]["epsilon_timesteps"] = 100000

-        num_iterations = 2
+        num_iterations = 1

         # Test building an R2D2 agent in all frameworks.
         for _ in framework_iterator(config):
@@ -82,7 +82,7 @@ class TestPPO(unittest.TestCase):
         # Settings in case we use an LSTM.
         config["model"]["lstm_cell_size"] = 10
         config["model"]["max_seq_len"] = 20
-        # Use default-native keras model whenever possible.
+        # Use default-native keras models whenever possible.
         config["model"]["_use_default_native_models"] = True

         config["train_batch_size"] = 128
@@ -93,7 +93,7 @@ class TestPPO(unittest.TestCase):
         for _ in framework_iterator(config):
             for env in ["CartPole-v0", "MsPacmanNoFrameskip-v4"]:
                 print("Env={}".format(env))
-                for lstm in [False, True]:
+                for lstm in [True, False]:
                     print("LSTM={}".format(lstm))
                     config["model"]["use_lstm"] = lstm
                     config["model"]["lstm_use_prev_action"] = lstm
@@ -366,7 +366,7 @@ class _PolicyCollector:
            this policy.
        """
        # Create batch from our buffers.
-        batch = SampleBatch(self.buffers, _seq_lens=self.seq_lens)
+        batch = SampleBatch(self.buffers, seq_lens=self.seq_lens)
        # Clear buffers for future samples.
        self.buffers.clear()
        # Reset agent steps to 0 and seq-lens to empty list.
@@ -121,7 +121,7 @@ def compute_gae_for_sample_batch(
        # Create an input dict according to the Model's requirements.
        input_dict = sample_batch.get_single_step_input_dict(
            policy.model.view_requirements, index="last")
-        last_r = policy._value(**input_dict, seq_lens=input_dict.seq_lens)
+        last_r = policy._value(**input_dict)

        # Adds the policy logits, VF preds, and advantages to the batch,
        # using GAE ("generalized advantage estimation") or not.
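Note: since `seq_lens` now lives inside the SampleBatch itself (see the
SampleBatch changes further below), the value function no longer needs it as
an explicit kwarg. A minimal sketch of the new pattern (hypothetical data,
not part of this commit):

    import numpy as np
    from ray.rllib.policy.sample_batch import SampleBatch

    # Old style (removed above): forward seq_lens explicitly.
    #   last_r = policy._value(**input_dict, seq_lens=input_dict.seq_lens)
    # New style: seq_lens is an ordinary dict column of the input dict.
    batch = SampleBatch({"obs": np.ones((4, 2), np.float32)}, seq_lens=[2, 2])
    print(batch["seq_lens"])  # -> array([2, 2], dtype=int32)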
@@ -26,10 +26,10 @@ class MyCallbacks(DefaultCallbacks):
     @override(DefaultCallbacks)
     def on_learn_on_batch(self, *, policy, train_batch, result, **kwargs):
         assert train_batch.count == 201
-        assert sum(train_batch.seq_lens) == 201
+        assert sum(train_batch["seq_lens"]) == 201
         for k, v in train_batch.items():
             if k == "state_in_0":
-                assert len(v) == len(train_batch.seq_lens)
+                assert len(v) == len(train_batch["seq_lens"])
             else:
                 assert len(v) == 201
         current = None
@@ -17,7 +17,8 @@ parser = argparse.ArgumentParser()
 parser.add_argument("--run", type=str, default="PPO")
 parser.add_argument("--env", type=str, default="RepeatAfterMeEnv")
 parser.add_argument("--num-cpus", type=int, default=3)
-parser.add_argument("--framework", choices=["tf", "torch"], default="tf")
+parser.add_argument(
+    "--framework", choices=["tf", "tf2", "tfe", "torch"], default="tf")
 parser.add_argument("--as-test", action="store_true")
 parser.add_argument("--stop-iters", type=int, default=200)
 parser.add_argument("--stop-timesteps", type=int, default=500000)
@@ -48,6 +49,9 @@ if __name__ == "__main__":
         "num_sgd_iter": 10,
         "vf_loss_coeff": 1e-5,
         "model": {
+            # Attention net wrapping (for tf) can already use the native keras
+            # model versions. For torch, this will have no effect.
+            "_use_default_native_models": True,
             "use_attention": True,
             "max_seq_len": 10,
             "attention_num_transformer_units": 1,
@@ -36,11 +36,11 @@ class RNNModel(tf.keras.models.Model if tf else object):

     def call(self, sample_batch):
         dense_out = self.dense(sample_batch["obs"])
-        B = tf.shape(sample_batch.seq_lens)[0]
+        B = tf.shape(sample_batch["seq_lens"])[0]
         lstm_in = tf.reshape(dense_out, [B, -1, dense_out.shape.as_list()[1]])
         lstm_out, h, c = self.lstm(
             inputs=lstm_in,
-            mask=tf.sequence_mask(sample_batch.seq_lens),
+            mask=tf.sequence_mask(sample_batch["seq_lens"]),
             initial_state=[
                 sample_batch["state_in_0"], sample_batch["state_in_1"]
             ],
@@ -82,7 +82,7 @@ class TorchParametricActionsModel(DQNTorchModel):
             model_config, name + "_action_embed")

     def forward(self, input_dict, state, seq_lens):
-        # Extract the available actions tensor from the observation.
+        # Extract the available act ions tensor from the observation.
         avail_actions = input_dict["obs"]["avail_actions"]
         action_mask = input_dict["obs"]["action_mask"]

@@ -39,8 +39,10 @@ MODEL_DEFAULTS: ModelConfigDict = {
    # If True, try to use a native (tf.keras.Model or torch.Module) default
    # model instead of our built-in ModelV2 defaults.
    # If False (default), use "classic" ModelV2 default models.
-    # Note that this currently only works for framework != torch AND fully
-    # connected default networks.
+    # Note that this currently only works for:
+    # 1) framework != torch AND
+    # 2) fully connected and CNN default networks as well as
+    # auto-wrapped LSTM- and attention nets.
    "_use_default_native_models": False,

    # === Built-in options ===
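Note: a minimal sketch of switching the expanded flag on in a Trainer config
(hypothetical values; the native keras path requires framework != torch, per
the comment above):

    config = {
        "framework": "tf2",
        "model": {
            # Opt in to the native tf.keras default models.
            "_use_default_native_models": True,
            # Auto-wrapping via the new Keras_LSTMWrapper also works now:
            "use_lstm": True,
            "lstm_cell_size": 256,
            "max_seq_len": 20,
        },
    }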
@@ -418,15 +420,26 @@ class ModelCatalog:
             if model_config.get("use_lstm") or \
                     model_config.get("use_attention"):
                 from ray.rllib.models.tf.attention_net import \
-                    AttentionWrapper
-                from ray.rllib.models.tf.recurrent_net import LSTMWrapper
+                    AttentionWrapper, Keras_AttentionWrapper
+                from ray.rllib.models.tf.recurrent_net import \
+                    LSTMWrapper, Keras_LSTMWrapper

                 wrapped_cls = model_cls
-                forward = wrapped_cls.forward
-                model_cls = ModelCatalog._wrap_if_needed(
-                    wrapped_cls, LSTMWrapper
-                    if model_config.get("use_lstm") else AttentionWrapper)
-                model_cls._wrapped_forward = forward
+                # Wrapped (custom) model is itself a keras Model ->
+                # wrap with keras LSTM/GTrXL (attention) wrappers.
+                if issubclass(wrapped_cls, tf.keras.Model):
+                    model_cls = Keras_LSTMWrapper if \
+                        model_config.get("use_lstm") else \
+                        Keras_AttentionWrapper
+                    model_config["wrapped_cls"] = wrapped_cls
+                # Wrapped (custom) model is ModelV2 ->
+                # wrap with ModelV2 LSTM/GTrXL (attention) wrappers.
+                else:
+                    forward = wrapped_cls.forward
+                    model_cls = ModelCatalog._wrap_if_needed(
+                        wrapped_cls, LSTMWrapper if
+                        model_config.get("use_lstm") else AttentionWrapper)
+                    model_cls._wrapped_forward = forward

             # Obsolete: Track and warn if vars were created but not
             # registered. Only still do this, if users do register their
@@ -561,31 +574,41 @@ class ModelCatalog:
                     model_config.get("use_attention"):

                 from ray.rllib.models.tf.attention_net import \
-                    AttentionWrapper
-                from ray.rllib.models.tf.recurrent_net import LSTMWrapper
+                    AttentionWrapper, Keras_AttentionWrapper
+                from ray.rllib.models.tf.recurrent_net import LSTMWrapper, \
+                    Keras_LSTMWrapper

                 wrapped_cls = v2_class
-                forward = wrapped_cls.forward
                 if model_config.get("use_lstm"):
-                    v2_class = ModelCatalog._wrap_if_needed(
-                        wrapped_cls, LSTMWrapper)
+                    if issubclass(wrapped_cls, tf.keras.Model):
+                        v2_class = Keras_LSTMWrapper
+                        model_config["wrapped_cls"] = wrapped_cls
+                    else:
+                        v2_class = ModelCatalog._wrap_if_needed(
+                            wrapped_cls, LSTMWrapper)
+                        v2_class._wrapped_forward = wrapped_cls.forward
                 else:
-                    v2_class = ModelCatalog._wrap_if_needed(
-                        wrapped_cls, AttentionWrapper)
-
-                v2_class._wrapped_forward = forward
+                    if issubclass(wrapped_cls, tf.keras.Model):
+                        v2_class = Keras_AttentionWrapper
+                        model_config["wrapped_cls"] = wrapped_cls
+                    else:
+                        v2_class = ModelCatalog._wrap_if_needed(
+                            wrapped_cls, AttentionWrapper)
+                        v2_class._wrapped_forward = wrapped_cls.forward

             # Wrap in the requested interface.
             wrapper = ModelCatalog._wrap_if_needed(v2_class, model_interface)

             if issubclass(wrapper, tf.keras.Model):
-                return wrapper(
+                model = wrapper(
                     input_space=obs_space,
                     action_space=action_space,
                     num_outputs=num_outputs,
                     name=name,
                     **dict(model_kwargs, **model_config),
                 )
+                return model

             return wrapper(obs_space, action_space, num_outputs, model_config,
                            name, **model_kwargs)
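Note: the dispatch added in the two hunks above boils down to one question:
is the class to be wrapped already a native tf.keras.Model? A condensed
sketch (`pick_wrapper` is a hypothetical helper, not part of the commit):

    def pick_wrapper(wrapped_cls, model_config):
        # Native keras model -> use the new keras wrappers directly and hand
        # the original class over via the model config for instantiation.
        if issubclass(wrapped_cls, tf.keras.Model):
            model_config["wrapped_cls"] = wrapped_cls
            return (Keras_LSTMWrapper if model_config.get("use_lstm")
                    else Keras_AttentionWrapper)
        # Classic ModelV2 -> keep the old _wrap_if_needed() path.
        wrapper = ModelCatalog._wrap_if_needed(
            wrapped_cls,
            LSTMWrapper if model_config.get("use_lstm") else AttentionWrapper)
        wrapper._wrapped_forward = wrapped_cls.forward
        return wrapper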
@@ -759,13 +782,15 @@ class ModelCatalog:
             VisionNet = None
             ComplexNet = None
             Keras_FCNet = None
+            Keras_VisionNet = None

             if framework in ["tf2", "tf", "tfe"]:
                 from ray.rllib.models.tf.fcnet import \
                     FullyConnectedNetwork as FCNet, \
                     Keras_FullyConnectedNetwork as Keras_FCNet
                 from ray.rllib.models.tf.visionnet import \
-                    VisionNetwork as VisionNet
+                    VisionNetwork as VisionNet, \
+                    Keras_VisionNetwork as Keras_VisionNet
                 from ray.rllib.models.tf.complex_input_net import \
                     ComplexInputNetwork as ComplexNet
             elif framework == "torch":
@@ -802,10 +827,8 @@ class ModelCatalog:
                 len(input_space.shape) == 1 or (
                     len(input_space.shape) == 2 and (
                         num_framestacks == "auto" or num_framestacks <= 1)):
-            # Keras native requested AND no auto-rnn-wrapping AND .
-            if model_config.get("_use_default_native_models") and \
-                    Keras_FCNet and not model_config.get("use_lstm") and \
-                    not model_config.get("use_attention"):
+            # Keras native requested AND no auto-rnn-wrapping.
+            if model_config.get("_use_default_native_models") and Keras_FCNet:
                 return Keras_FCNet
             # Classic ModelV2 FCNet.
             else:
@@ -815,6 +838,8 @@ class ModelCatalog:
             raise NotImplementedError("No non-FC default net for JAX yet!")

         # Last resort: Conv2D stack for single image spaces.
+        if model_config.get("_use_default_native_models") and Keras_VisionNet:
+            return Keras_VisionNet
         return VisionNet

     @staticmethod
@@ -210,7 +210,7 @@ class ModelV2:
        restored = input_dict.copy(shallow=True)
        # Backward compatibility.
        if seq_lens is None:
-            seq_lens = input_dict.seq_lens
+            seq_lens = input_dict.get("seq_lens")
        if not state:
            state = []
            i = 0
@@ -11,7 +11,7 @@
 from gym.spaces import Box, Discrete, MultiDiscrete
 import numpy as np
 import gym
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, Optional, Type, Union

 from ray.rllib.models.modelv2 import ModelV2
 from ray.rllib.models.tf.layers import GRUGate, RelativeMultiHeadAttention, \
@@ -491,3 +491,347 @@ class AttentionWrapper(TFModelV2):
     def value_function(self) -> TensorType:
         assert self._features is not None, "Must call forward() first!"
         return tf.reshape(self._value_branch(self._features), [-1])
+
+
+class Keras_GTrXLNet(tf.keras.Model if tf else object):
+    """A GTrXL net Model described in [2].
+
+    This is still in an experimental phase.
+    Can be used as a drop-in replacement for LSTMs in PPO and IMPALA.
+    For an example script, see: `ray/rllib/examples/attention_net.py`.
+
+    To use this network as a replacement for an RNN, configure your Trainer
+    as follows:
+
+    Examples:
+        >> config["model"]["custom_model"] = GTrXLNet
+        >> config["model"]["max_seq_len"] = 10
+        >> config["model"]["custom_model_config"] = {
+        >>     num_transformer_units=1,
+        >>     attention_dim=32,
+        >>     num_heads=2,
+        >>     memory_inference=100,
+        >>     memory_training=50,
+        >>     etc..
+        >> }
+    """
+
+    def __init__(self,
+                 input_space: gym.spaces.Space,
+                 action_space: gym.spaces.Space,
+                 *,
+                 name: str,
+                 max_seq_len: int = 20,
+                 num_transformer_units: int = 1,
+                 attention_dim: int = 64,
+                 num_heads: int = 2,
+                 memory_inference: int = 50,
+                 memory_training: int = 50,
+                 head_dim: int = 32,
+                 position_wise_mlp_dim: int = 32,
+                 init_gru_gate_bias: float = 2.0):
+        """Initializes a GTrXLNet instance.
+
+        Args:
+            num_transformer_units (int): The number of Transformer repeats to
+                use (denoted L in [2]).
+            attention_dim (int): The input and output dimensions of one
+                Transformer unit.
+            num_heads (int): The number of attention heads to use in parallel.
+                Denoted as `H` in [3].
+            memory_inference (int): The number of timesteps to concat (time
+                axis) and feed into the next transformer unit as inference
+                input. The first transformer unit will receive this number of
+                past observations (plus the current one), instead.
+            memory_training (int): The number of timesteps to concat (time
+                axis) and feed into the next transformer unit as training
+                input (plus the actual input sequence of len=max_seq_len).
+                The first transformer unit will receive this number of
+                past observations (plus the input sequence), instead.
+            head_dim (int): The dimension of a single(!) attention head within
+                a multi-head attention unit. Denoted as `d` in [3].
+            position_wise_mlp_dim (int): The dimension of the hidden layer
+                within the position-wise MLP (after the multi-head attention
+                block within one Transformer unit). This is the size of the
+                first of the two layers within the PositionwiseFeedforward. The
+                second layer always has size=`attention_dim`.
+            init_gru_gate_bias (float): Initial bias values for the GRU gates
+                (two GRUs per Transformer unit, one after the MHA, one after
+                the position-wise MLP).
+        """
+
+        super().__init__(name=name)
+
+        self.num_transformer_units = num_transformer_units
+        self.attention_dim = attention_dim
+        self.num_heads = num_heads
+        self.memory_inference = memory_inference
+        self.memory_training = memory_training
+        self.head_dim = head_dim
+        self.max_seq_len = max_seq_len
+        self.obs_dim = input_space.shape[0]
+
+        # Raw observation input (plus (None) time axis).
+        input_layer = tf.keras.layers.Input(
+            shape=(
+                None,
+                self.obs_dim,
+            ), name="inputs")
+        memory_ins = [
+            tf.keras.layers.Input(
+                shape=(
+                    None,
+                    self.attention_dim,
+                ),
+                dtype=tf.float32,
+                name="memory_in_{}".format(i))
+            for i in range(self.num_transformer_units)
+        ]
+
+        # Map observation dim to input/output transformer (attention) dim.
+        E_out = tf.keras.layers.Dense(self.attention_dim)(input_layer)
+        # Output, collected and concat'd to build the internal, tau-len
+        # Memory units used for additional contextual information.
+        memory_outs = [E_out]
+
+        # 2) Create L Transformer blocks according to [2].
+        for i in range(self.num_transformer_units):
+            # RelativeMultiHeadAttention part.
+            MHA_out = SkipConnection(
+                RelativeMultiHeadAttention(
+                    out_dim=self.attention_dim,
+                    num_heads=num_heads,
+                    head_dim=head_dim,
+                    input_layernorm=True,
+                    output_activation=tf.nn.relu),
+                fan_in_layer=GRUGate(init_gru_gate_bias),
+                name="mha_{}".format(i + 1))(
+                    E_out, memory=memory_ins[i])
+            # Position-wise MLP part.
+            E_out = SkipConnection(
+                tf.keras.Sequential(
+                    (tf.keras.layers.LayerNormalization(axis=-1),
+                     PositionwiseFeedforward(
+                         out_dim=self.attention_dim,
+                         hidden_dim=position_wise_mlp_dim,
+                         output_activation=tf.nn.relu))),
+                fan_in_layer=GRUGate(init_gru_gate_bias),
+                name="pos_wise_mlp_{}".format(i + 1))(MHA_out)
+            # Output of position-wise MLP == E(l-1), which is concat'd
+            # to the current Mem block (M(l-1)) to yield E~(l-1), which is then
+            # used by the next transformer block.
+            memory_outs.append(E_out)
+
+        self._logits = None
+        self._value_out = None
+
+        self.trxl_model = tf.keras.Model(
+            inputs=[input_layer] + memory_ins,
+            outputs=[E_out] + memory_outs[:-1])
+
+        self.view_requirements = {
+            SampleBatch.OBS: ViewRequirement(space=input_space),
+        }
+        # Setup trajectory views (`memory-inference` x past memory outs).
+        for i in range(self.num_transformer_units):
+            space = Box(-1.0, 1.0, shape=(self.attention_dim, ))
+            self.view_requirements["state_in_{}".format(i)] = \
+                ViewRequirement(
+                    "state_out_{}".format(i),
+                    shift="-{}:-1".format(self.memory_inference),
+                    # Repeat the incoming state every max-seq-len times.
+                    batch_repeat_value=self.max_seq_len,
+                    space=space)
+            self.view_requirements["state_out_{}".format(i)] = \
+                ViewRequirement(
+                    space=space,
+                    used_for_training=False)
+
+    def call(self, inputs, memory_ins) -> (TensorType, List[TensorType]):
+        # Add the time dim to observations.
+        B = tf.shape(memory_ins[0])[0]
+
+        shape = tf.shape(inputs)
+        T = shape[0] // B
+        inputs = tf.reshape(inputs, tf.concat([[-1, T], shape[1:]], axis=0))
+
+        all_out = self.trxl_model([inputs] + memory_ins)
+        out = tf.reshape(all_out[0], [-1, self.attention_dim])
+        memory_outs = all_out[1:]
+
+        return out, [
+            tf.reshape(m, [-1, self.attention_dim]) for m in memory_outs
+        ]
+
+
+class Keras_AttentionWrapper(tf.keras.Model if tf else object):
+    """A tf keras auto-GTrXL wrapper used when `use_attention`=True."""
+
+    def __init__(
+            self,
+            input_space: gym.spaces.Space,
+            action_space: gym.spaces.Space,
+            num_outputs: Optional[int] = None,
+            *,
+            name: str,
+            wrapped_cls: Type["tf.keras.Model"],
+            max_seq_len: int = 20,
+            attention_num_transformer_units: int = 1,
+            attention_dim: int = 64,
+            attention_num_heads: int = 1,
+            attention_head_dim: int = 32,
+            attention_memory_inference: int = 50,
+            attention_memory_training: int = 50,
+            attention_position_wise_mlp_dim: int = 32,
+            attention_init_gru_gate_bias: int = 2.0,
+            attention_use_n_prev_actions: int = 0,
+            attention_use_n_prev_rewards: int = 0,
+            **kwargs,
+    ):
+
+        super().__init__(name=name)
+        self.wrapped_keras_model = wrapped_cls(
+            input_space, action_space, None, name="wrapped_" + name, **kwargs)
+
+        self.action_space = action_space
+        self.max_seq_len = max_seq_len
+        self.use_n_prev_actions = attention_use_n_prev_actions
+        self.use_n_prev_rewards = attention_use_n_prev_rewards
+        self.attention_dim = attention_dim
+
+        # Guess the number of outputs for the wrapped model by looking
+        # at its first output's shape.
+        # This will be the input size for the LSTM layer (plus
+        # maybe prev-actions/rewards).
+        # If no layers in the wrapped model, set it to the
+        # observation space.
+        if self.wrapped_keras_model.layers:
+            assert self.wrapped_keras_model.layers[-1].outputs
+            assert len(
+                self.wrapped_keras_model.layers[-1].outputs[0].shape) == 2
+            wrapped_num_outputs = int(
+                self.wrapped_keras_model.layers[-1].outputs[0].shape[1])
+        else:
+            wrapped_num_outputs = int(np.product(self.obs_space.shape))
+
+        if isinstance(action_space, Discrete):
+            self.action_dim = action_space.n
+        elif isinstance(action_space, MultiDiscrete):
+            self.action_dim = np.product(action_space.nvec)
+        elif action_space.shape is not None:
+            self.action_dim = int(np.product(action_space.shape))
+        else:
+            self.action_dim = int(len(action_space))
+
+        # Add prev-action/reward nodes to input to LSTM.
+        if self.use_n_prev_actions:
+            wrapped_num_outputs += self.use_n_prev_actions * self.action_dim
+        if self.use_n_prev_rewards:
+            wrapped_num_outputs += self.use_n_prev_rewards
+
+        in_space = gym.spaces.Box(
+            float("-inf"),
+            float("inf"),
+            shape=(wrapped_num_outputs, ),
+            dtype=np.float32)
+
+        input_ = tf.keras.layers.Input(
+            shape=(wrapped_num_outputs, ), name="inputs")
+        memory_ins = [
+            tf.keras.layers.Input(
+                shape=(
+                    None,
+                    self.attention_dim,
+                ), name=f"memory_in_{i}")
+            for i in range(attention_num_transformer_units)
+        ]
+        # Construct GTrXL sub-module.
+        self.gtrxl = Keras_GTrXLNet(
+            in_space,
+            action_space,
+            name="gtrxl",
+            max_seq_len=max_seq_len,
+            num_transformer_units=attention_num_transformer_units,
+            attention_dim=self.attention_dim,
+            num_heads=attention_num_heads,
+            head_dim=attention_head_dim,
+            memory_inference=attention_memory_inference,
+            memory_training=attention_memory_training,
+            position_wise_mlp_dim=attention_position_wise_mlp_dim,
+            init_gru_gate_bias=attention_init_gru_gate_bias,
+        )
+        keras_gtrxl_model_out, memory_outs = self.gtrxl(input_, memory_ins)
+
+        # Postprocess GTrXL output with another hidden layer and compute
+        # values.
+        logits = tf.keras.layers.Dense(
+            num_outputs, activation=None)(keras_gtrxl_model_out)
+
+        value_outs = tf.keras.layers.Dense(
+            1, activation=None)(keras_gtrxl_model_out)
+        self.base_model = tf.keras.models.Model(
+            [input_, memory_ins], [logits, memory_outs, value_outs])
+
+        self.view_requirements = self.gtrxl.view_requirements
+        self.view_requirements["obs"].space = input_space
+
+        # Add prev-a/r to this model's view, if required.
+        if self.use_n_prev_actions:
+            self.view_requirements[SampleBatch.PREV_ACTIONS] = \
+                ViewRequirement(
+                    SampleBatch.ACTIONS,
+                    space=self.action_space,
+                    shift="-{}:-1".format(self.use_n_prev_actions))
+        if self.use_n_prev_rewards:
+            self.view_requirements[SampleBatch.PREV_REWARDS] = \
+                ViewRequirement(
+                    SampleBatch.REWARDS,
+                    shift="-{}:-1".format(self.use_n_prev_rewards))
+
+    def call(self, input_dict: SampleBatch) -> \
+            (TensorType, List[TensorType], Dict[str, TensorType]):
+        assert input_dict["seq_lens"] is not None
+        # Push obs through "unwrapped" net's `forward()` first.
+        wrapped_out, _, _ = self.wrapped_keras_model(input_dict)
+
+        # Concat. prev-action/reward if required.
+        prev_a_r = []
+        if self.use_n_prev_actions:
+            if isinstance(self.action_space, Discrete):
+                for i in range(self.use_n_prev_actions):
+                    prev_a_r.append(
+                        one_hot(input_dict[SampleBatch.PREV_ACTIONS][:, i],
+                                self.action_space))
+            elif isinstance(self.action_space, MultiDiscrete):
+                for i in range(
+                        self.use_n_prev_actions,
+                        step=self.action_space.shape[0]):
+                    prev_a_r.append(
+                        one_hot(
+                            tf.cast(
+                                input_dict[SampleBatch.PREV_ACTIONS]
+                                [:, i:i + self.action_space.shape[0]],
+                                tf.float32), self.action_space))
+            else:
+                prev_a_r.append(
+                    tf.reshape(
+                        tf.cast(input_dict[SampleBatch.PREV_ACTIONS],
+                                tf.float32),
+                        [-1, self.use_n_prev_actions * self.action_dim]))
+        if self.use_n_prev_rewards:
+            prev_a_r.append(
+                tf.reshape(
+                    tf.cast(input_dict[SampleBatch.PREV_REWARDS], tf.float32),
+                    [-1, self.use_n_prev_rewards]))
+
+        if prev_a_r:
+            wrapped_out = tf.concat([wrapped_out] + prev_a_r, axis=1)
+
+        memory_ins = [
+            s for k, s in input_dict.items() if k.startswith("state_in_")
+        ]
+        model_out, memory_outs, value_outs = self.base_model([wrapped_out] +
+                                                             memory_ins)
+        return model_out, memory_outs, {
+            SampleBatch.VF_PREDS: tf.reshape(value_outs, [-1])
+        }
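Note: as a usage illustration (hypothetical; not part of the commit), the new
keras wrappers follow a uniform call convention: a 3-tuple of (outputs,
state-outs, extra-outs), with value predictions carried in the extra-outs
dict rather than a separate value_function() method:

    logits, state_outs, extra_outs = model(sample_batch)
    vf_preds = extra_outs[SampleBatch.VF_PREDS]  # shape [B]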
@@ -1,17 +1,18 @@
 import numpy as np
 import gym
-from typing import Optional, Sequence, Tuple
+from typing import Dict, Optional, Sequence

 from ray.rllib.models.tf.misc import normc_initializer
 from ray.rllib.models.tf.tf_modelv2 import TFModelV2
 from ray.rllib.models.utils import get_activation_fn
+from ray.rllib.policy.sample_batch import SampleBatch
 from ray.rllib.utils.framework import try_import_tf
-from ray.rllib.utils.typing import Dict, TensorType, List, ModelConfigDict
+from ray.rllib.utils.typing import TensorType, List, ModelConfigDict

 tf1, tf, tfv = try_import_tf()


+# TODO: (sven) obsolete this class once we only support native keras models.
 class FullyConnectedNetwork(TFModelV2):
     """Generic fully connected network implemented in ModelV2 API."""
@@ -138,7 +139,7 @@ class Keras_FullyConnectedNetwork(tf.keras.Model if tf else object):
            self,
            input_space: gym.spaces.Space,
            action_space: gym.spaces.Space,
-            num_outputs: int,
+            num_outputs: Optional[int] = None,
            *,
            name: str = "",
            fcnet_hiddens: Optional[Sequence[int]] = (),
@@ -209,10 +210,6 @@ class Keras_FullyConnectedNetwork(tf.keras.Model if tf else object):
                name="fc_out",
                activation=None,
                kernel_initializer=normc_initializer(0.01))(last_layer)
-        # Adjust num_outputs to be the number of nodes in the last layer.
-        else:
-            self.num_outputs = (
-                [int(np.product(input_space.shape))] + hiddens[-1:])[-1]

        # Concat the log std vars to the end of the state-dependent means.
        if free_log_std and logits_out is not None:
@@ -249,9 +246,8 @@ class Keras_FullyConnectedNetwork(tf.keras.Model if tf else object):
            inputs, [(logits_out
                      if logits_out is not None else last_layer), value_out])

-    def call(self, input_dict: Dict[str, TensorType]) -> \
-            Tuple[TensorType, List[TensorType], TensorType]:
-        model_out, value_out = self.base_model(input_dict["obs"])
-        return model_out, [], {
-            SampleBatch.VF_PREDS: tf.reshape(value_out, [-1])
-        }
+    def call(self, input_dict: SampleBatch) -> \
+            (TensorType, List[TensorType], Dict[str, TensorType]):
+        model_out, value_out = self.base_model(input_dict[SampleBatch.OBS])
+        extra_outs = {SampleBatch.VF_PREDS: tf.reshape(value_out, [-1])}
+        return model_out, [], extra_outs
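Note: a minimal instantiation sketch for the keras FC net (hypothetical
spaces and sizes; keyword names taken from the signature hunks above, all
remaining constructor kwargs assumed to have workable defaults):

    import gym
    import numpy as np
    from ray.rllib.models.tf.fcnet import Keras_FullyConnectedNetwork
    from ray.rllib.policy.sample_batch import SampleBatch

    obs_space = gym.spaces.Box(-1.0, 1.0, shape=(4, ), dtype=np.float32)
    act_space = gym.spaces.Discrete(2)
    # num_outputs may now be None; the net then exits with its last hidden.
    model = Keras_FullyConnectedNetwork(
        obs_space, act_space, num_outputs=2, name="fc",
        fcnet_hiddens=(16, 16))
    out, states, extra = model(
        SampleBatch({"obs": np.ones((1, 4), np.float32)}))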
@@ -1,7 +1,7 @@
 import numpy as np
 import gym
-from gym.spaces import Discrete, MultiDiscrete
-from typing import Dict, List
+from gym.spaces import Box, Discrete, MultiDiscrete
+from typing import Dict, List, Optional, Type

 from ray.rllib.models.modelv2 import ModelV2
 from ray.rllib.models.tf.tf_modelv2 import TFModelV2
@@ -235,3 +235,167 @@ class LSTMWrapper(RecurrentNetwork):
     @override(ModelV2)
     def value_function(self) -> TensorType:
         return tf.reshape(self._value_out, [-1])
+
+
+class Keras_LSTMWrapper(tf.keras.Model if tf else object):
+    """A tf keras auto-LSTM wrapper used when `use_lstm`=True."""
+
+    def __init__(
+            self,
+            input_space: gym.spaces.Space,
+            action_space: gym.spaces.Space,
+            num_outputs: Optional[int] = None,
+            *,
+            name: str,
+            wrapped_cls: Type["tf.keras.Model"],
+            max_seq_len: int = 20,
+            lstm_cell_size: int = 256,
+            lstm_use_prev_action: bool = False,
+            lstm_use_prev_reward: bool = False,
+            **kwargs,
+    ):
+
+        super().__init__(name=name)
+        self.wrapped_keras_model = wrapped_cls(
+            input_space, action_space, None, name="wrapped_" + name, **kwargs)
+
+        self.action_space = action_space
+        self.max_seq_len = max_seq_len
+
+        # Guess the number of outputs for the wrapped model by looking
+        # at its first output's shape.
+        # This will be the input size for the LSTM layer (plus
+        # maybe prev-actions/rewards).
+        # If no layers in the wrapped model, set it to the
+        # observation space.
+        if self.wrapped_keras_model.layers:
+            assert self.wrapped_keras_model.layers[-1].outputs
+            assert len(
+                self.wrapped_keras_model.layers[-1].outputs[0].shape) == 2
+            wrapped_num_outputs = int(
+                self.wrapped_keras_model.layers[-1].outputs[0].shape[1])
+        else:
+            wrapped_num_outputs = int(np.product(self.obs_space.shape))
+
+        self.lstm_cell_size = lstm_cell_size
+        self.lstm_use_prev_action = lstm_use_prev_action
+        self.lstm_use_prev_reward = lstm_use_prev_reward
+
+        if isinstance(self.action_space, Discrete):
+            self.action_dim = self.action_space.n
+        elif isinstance(self.action_space, MultiDiscrete):
+            self.action_dim = np.sum(self.action_space.nvec)
+        elif self.action_space.shape is not None:
+            self.action_dim = int(np.product(self.action_space.shape))
+        else:
+            self.action_dim = int(len(self.action_space))
+
+        # Add prev-action/reward nodes to input to LSTM.
+        if self.lstm_use_prev_action:
+            wrapped_num_outputs += self.action_dim
+        if self.lstm_use_prev_reward:
+            wrapped_num_outputs += 1
+
+        # Define input layers.
+        input_layer = tf.keras.layers.Input(
+            shape=(None, wrapped_num_outputs), name="inputs")
+
+        state_in_h = tf.keras.layers.Input(
+            shape=(self.lstm_cell_size, ), name="h")
+        state_in_c = tf.keras.layers.Input(
+            shape=(self.lstm_cell_size, ), name="c")
+        seq_in = tf.keras.layers.Input(shape=(), name="seq_in", dtype=tf.int32)
+
+        # Preprocess observation with a hidden layer and send to LSTM cell
+        lstm_out, state_h, state_c = tf.keras.layers.LSTM(
+            self.lstm_cell_size,
+            return_sequences=True,
+            return_state=True,
+            name="lstm")(
+                inputs=input_layer,
+                mask=tf.sequence_mask(seq_in),
+                initial_state=[state_in_h, state_in_c])
+
+        # Postprocess LSTM output with another hidden layer
+        # if num_outputs not None.
+        if num_outputs:
+            logits = tf.keras.layers.Dense(
+                num_outputs,
+                activation=tf.keras.activations.linear,
+                name="logits")(lstm_out)
+        else:
+            logits = lstm_out
+        # Compute values.
+        values = tf.keras.layers.Dense(
+            1, activation=None, name="values")(lstm_out)
+
+        # Create the RNN model
+        self._rnn_model = tf.keras.Model(
+            inputs=[input_layer, seq_in, state_in_h, state_in_c],
+            outputs=[logits, values, state_h, state_c])
+
+        # Use view-requirements of wrapped model and add own
+        # requirements.
+        self.view_requirements = \
+            getattr(self.wrapped_keras_model, "view_requirements", {
+                SampleBatch.OBS: ViewRequirement(space=input_space)
+            })
+
+        # Add prev-a/r to this model's view, if required.
+        if self.lstm_use_prev_action:
+            self.view_requirements[SampleBatch.PREV_ACTIONS] = \
+                ViewRequirement(SampleBatch.ACTIONS, space=self.action_space,
+                                shift=-1)
+        if self.lstm_use_prev_reward:
+            self.view_requirements[SampleBatch.PREV_REWARDS] = \
+                ViewRequirement(SampleBatch.REWARDS, shift=-1)
+
+        # Internal states view requirements.
+        for i in range(2):
+            space = Box(-1.0, 1.0, shape=(self.lstm_cell_size, ))
+            self.view_requirements["state_in_{}".format(i)] = \
+                ViewRequirement(
+                    "state_out_{}".format(i),
+                    shift=-1,
+                    used_for_compute_actions=True,
+                    batch_repeat_value=max_seq_len,
+                    space=space)
+            self.view_requirements["state_out_{}".format(i)] = \
+                ViewRequirement(space=space, used_for_training=True)
+
+    def call(self, input_dict: SampleBatch) -> \
+            (TensorType, List[TensorType], Dict[str, TensorType]):
+        assert input_dict.get("seq_lens") is not None
+        # Push obs through underlying (wrapped) model first.
+        wrapped_out, _, _ = self.wrapped_keras_model(input_dict)
+
+        # Concat. prev-action/reward if required.
+        prev_a_r = []
+        if self.lstm_use_prev_action:
+            prev_a = input_dict[SampleBatch.PREV_ACTIONS]
+            if isinstance(self.action_space, (Discrete, MultiDiscrete)):
+                prev_a = one_hot(prev_a, self.action_space)
+            prev_a_r.append(
+                tf.reshape(tf.cast(prev_a, tf.float32), [-1, self.action_dim]))
+        if self.lstm_use_prev_reward:
+            prev_a_r.append(
+                tf.reshape(
+                    tf.cast(input_dict[SampleBatch.PREV_REWARDS], tf.float32),
+                    [-1, 1]))
+
+        if prev_a_r:
+            wrapped_out = tf.concat([wrapped_out] + prev_a_r, axis=1)
+
+        max_seq_len = tf.shape(wrapped_out)[0] // tf.shape(
+            input_dict["seq_lens"])[0]
+        wrapped_out_plus_time_dim = add_time_dimension(
+            wrapped_out, max_seq_len=max_seq_len, framework="tf")
+        model_out, value_out, h, c = self._rnn_model([
+            wrapped_out_plus_time_dim, input_dict["seq_lens"],
+            input_dict["state_in_0"], input_dict["state_in_1"]
+        ])
+        model_out_no_time_dim = tf.reshape(
+            model_out, tf.concat([[-1], tf.shape(model_out)[2:]], axis=0))
+        return model_out_no_time_dim, [h, c], {
+            SampleBatch.VF_PREDS: tf.reshape(value_out, [-1])
+        }
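Note: the wrapper folds the flat [B*T, feat] batch into [B, T, feat] before
the LSTM and flattens the output again. A worked shape example (hypothetical
numbers) of the reshape logic in call() above:

    import numpy as np

    B_times_T, feat = 8, 32              # flat wrapped_out: [8, 32]
    num_seqs = 2                         # len(input_dict["seq_lens"])
    max_seq_len = B_times_T // num_seqs  # -> 4
    x = np.zeros((B_times_T, feat))
    x_time = x.reshape(num_seqs, max_seq_len, feat)  # LSTM input [2, 4, 32]
    y = x_time.reshape(-1, feat)                     # back to [8, 32]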
@@ -1,5 +1,5 @@
-from typing import Dict, List
 import gym
+from typing import Dict, List, Optional, Sequence

 from ray.rllib.models.tf.tf_modelv2 import TFModelV2
 from ray.rllib.models.tf.misc import normc_initializer
@@ -12,6 +12,7 @@ from ray.rllib.utils.typing import ModelConfigDict, TensorType
 tf1, tf, tfv = try_import_tf()


+# TODO: (sven) obsolete this class once we only support native keras models.
 class VisionNetwork(TFModelV2):
     """Generic vision network implemented in ModelV2 API.

@@ -19,10 +20,6 @@ class VisionNetwork(TFModelV2):
    via the config keys:
    `post_fcnet_hiddens`: Dense layer sizes after the Conv2D stack.
    `post_fcnet_activation`: Activation function to use for this FC stack.
-
-    Examples:
-
-
    """

    def __init__(self, obs_space: gym.spaces.Space,
@@ -245,3 +242,238 @@ class VisionNetwork(TFModelV2):

     def value_function(self) -> TensorType:
         return tf.reshape(self._value_out, [-1])
+
+
+class Keras_VisionNetwork(tf.keras.Model if tf else object):
+    """Generic vision network implemented in tf keras.
+
+    An additional post-conv fully connected stack can be added and configured
+    via the config keys:
+    `post_fcnet_hiddens`: Dense layer sizes after the Conv2D stack.
+    `post_fcnet_activation`: Activation function to use for this FC stack.
+    """
+
+    def __init__(
+            self,
+            input_space: gym.spaces.Space,
+            action_space: gym.spaces.Space,
+            num_outputs: Optional[int] = None,
+            *,
+            name: str = "",
+            conv_filters: Optional[Sequence[Sequence[int]]] = None,
+            conv_activation: Optional[str] = None,
+            post_fcnet_hiddens: Optional[Sequence[int]] = (),
+            post_fcnet_activation: Optional[str] = None,
+            no_final_linear: bool = False,
+            vf_share_layers: bool = False,
+            free_log_std: bool = False,
+            **kwargs,
+    ):
+
+        super().__init__(name=name)
+
+        if not conv_filters:
+            conv_filters = get_filter_config(input_space.shape)
+        assert len(conv_filters) > 0,\
+            "Must provide at least 1 entry in `conv_filters`!"
+
+        conv_activation = get_activation_fn(conv_activation, framework="tf")
+        post_fcnet_activation = get_activation_fn(
+            post_fcnet_activation, framework="tf")
+
+        self.traj_view_framestacking = False
+
+        # Perform Atari framestacking via traj. view API.
+        num_framestacks = kwargs.get("num_framestacks")
+        if num_framestacks != "auto" and num_framestacks and \
+                num_framestacks > 1:
+            input_shape = input_space.shape + (num_framestacks, )
+            self.data_format = "channels_first"
+            self.traj_view_framestacking = True
+        else:
+            input_shape = input_space.shape
+            self.data_format = "channels_last"
+
+        inputs = tf.keras.layers.Input(shape=input_shape, name="observations")
+        last_layer = inputs
+        # Whether the last layer is the output of a Flattened (rather than
+        # a n x (1,1) Conv2D).
+        self.last_layer_is_flattened = False
+
+        # Build the action layers
+        for i, (out_size, kernel, stride) in enumerate(conv_filters[:-1], 1):
+            last_layer = tf.keras.layers.Conv2D(
+                out_size,
+                kernel,
+                strides=stride
+                if isinstance(stride, (list, tuple)) else (stride, stride),
+                activation=conv_activation,
+                padding="same",
+                data_format="channels_last",
+                name="conv{}".format(i))(last_layer)
+
+        out_size, kernel, stride = conv_filters[-1]
+
+        # No final linear: Last layer has activation function and exits with
+        # num_outputs nodes (this could be a 1x1 conv or a FC layer, depending
+        # on `post_fcnet_...` settings).
+        if no_final_linear and num_outputs:
+            last_layer = tf.keras.layers.Conv2D(
+                out_size if post_fcnet_hiddens else num_outputs,
+                kernel,
+                strides=stride
+                if isinstance(stride, (list, tuple)) else (stride, stride),
+                activation=conv_activation,
+                padding="valid",
+                data_format="channels_last",
+                name="conv_out")(last_layer)
+            # Add (optional) post-fc-stack after last Conv2D layer.
+            layer_sizes = post_fcnet_hiddens[:-1] + ([num_outputs]
+                                                     if post_fcnet_hiddens else
+                                                     [])
+            for i, out_size in enumerate(layer_sizes):
+                last_layer = tf.keras.layers.Dense(
+                    out_size,
+                    name="post_fcnet_{}".format(i),
+                    activation=post_fcnet_activation,
+                    kernel_initializer=normc_initializer(1.0))(last_layer)
+
+        # Finish network normally (w/o overriding last layer size with
+        # `num_outputs`), then add another linear one of size `num_outputs`.
+        else:
+            last_layer = tf.keras.layers.Conv2D(
+                out_size,
+                kernel,
+                strides=stride
+                if isinstance(stride, (list, tuple)) else (stride, stride),
+                activation=conv_activation,
+                padding="valid",
+                data_format="channels_last",
+                name="conv{}".format(len(conv_filters)))(last_layer)
+
+            # num_outputs defined. Use that to create an exact
+            # `num_output`-sized (1,1)-Conv2D.
+            if num_outputs:
+                if post_fcnet_hiddens:
+                    last_cnn = last_layer = tf.keras.layers.Conv2D(
+                        post_fcnet_hiddens[0], [1, 1],
+                        activation=post_fcnet_activation,
+                        padding="same",
+                        data_format="channels_last",
+                        name="conv_out")(last_layer)
+                    # Add (optional) post-fc-stack after last Conv2D layer.
+                    for i, out_size in enumerate(post_fcnet_hiddens[1:] +
+                                                 [num_outputs]):
+                        last_layer = tf.keras.layers.Dense(
+                            out_size,
+                            name="post_fcnet_{}".format(i + 1),
+                            activation=post_fcnet_activation
+                            if i < len(post_fcnet_hiddens) - 1 else None,
+                            kernel_initializer=normc_initializer(1.0))(
+                                last_layer)
+                else:
+                    last_cnn = last_layer = tf.keras.layers.Conv2D(
+                        num_outputs, [1, 1],
+                        activation=None,
+                        padding="same",
+                        data_format="channels_last",
+                        name="conv_out")(last_layer)
+
+                if last_cnn.shape[1] != 1 or last_cnn.shape[2] != 1:
+                    raise ValueError(
+                        "Given `conv_filters` ({}) do not result in a [B, 1, "
+                        "1, {} (`num_outputs`)] shape (but in {})! Please "
+                        "adjust your Conv2D stack such that the dims 1 and 2 "
+                        "are both 1.".format(
+                            self.model_config["conv_filters"], num_outputs,
+                            list(last_cnn.shape)))
+
+            # num_outputs not known -> Flatten.
+            else:
+                self.last_layer_is_flattened = True
+                last_layer = tf.keras.layers.Flatten(
+                    data_format="channels_last")(last_layer)
+
+                # Add (optional) post-fc-stack after last Conv2D layer.
+                for i, out_size in enumerate(post_fcnet_hiddens):
+                    last_layer = tf.keras.layers.Dense(
+                        out_size,
+                        name="post_fcnet_{}".format(i),
+                        activation=post_fcnet_activation,
+                        kernel_initializer=normc_initializer(1.0))(last_layer)
+        logits_out = last_layer
+
+        # Build the value layers
+        if vf_share_layers:
+            if not self.last_layer_is_flattened:
+                last_layer = tf.keras.layers.Lambda(
+                    lambda x: tf.squeeze(x, axis=[1, 2]))(last_layer)
+            value_out = tf.keras.layers.Dense(
+                1,
+                name="value_out",
+                activation=None,
+                kernel_initializer=normc_initializer(0.01))(last_layer)
+        else:
+            # build a parallel set of hidden layers for the value net
+            last_layer = inputs
+            for i, (out_size, kernel, stride) in enumerate(
+                    conv_filters[:-1], 1):
+                last_layer = tf.keras.layers.Conv2D(
+                    out_size,
+                    kernel,
+                    strides=stride
+                    if isinstance(stride, (list, tuple)) else (stride, stride),
+                    activation=conv_activation,
+                    padding="same",
+                    data_format="channels_last",
+                    name="conv_value_{}".format(i))(last_layer)
+            out_size, kernel, stride = conv_filters[-1]
+            last_layer = tf.keras.layers.Conv2D(
+                out_size,
+                kernel,
+                strides=stride
+                if isinstance(stride, (list, tuple)) else (stride, stride),
+                activation=conv_activation,
+                padding="valid",
+                data_format="channels_last",
+                name="conv_value_{}".format(len(conv_filters)))(last_layer)
+            last_layer = tf.keras.layers.Conv2D(
+                1, [1, 1],
+                activation=None,
+                padding="same",
+                data_format="channels_last",
+                name="conv_value_out")(last_layer)
+            value_out = tf.keras.layers.Lambda(
+                lambda x: tf.squeeze(x, axis=[1, 2]))(last_layer)
+
+        self.base_model = tf.keras.Model(inputs, [logits_out, value_out])
+
+        # Optional: framestacking obs/new_obs for Atari.
+        if self.traj_view_framestacking:
+            from_ = num_framestacks - 1
+            self.view_requirements[SampleBatch.OBS].shift = \
+                "-{}:0".format(from_)
+            self.view_requirements[SampleBatch.OBS].shift_from = -from_
+            self.view_requirements[SampleBatch.OBS].shift_to = 0
+            self.view_requirements[SampleBatch.NEXT_OBS] = ViewRequirement(
+                data_col=SampleBatch.OBS,
+                shift="-{}:1".format(from_ - 1),
+                space=self.view_requirements[SampleBatch.OBS].space,
+                used_for_compute_actions=False,
+            )
+
+    def call(self, input_dict: SampleBatch) -> \
+            (TensorType, List[TensorType], Dict[str, TensorType]):
+        obs = input_dict["obs"]
+        if self.data_format == "channels_first":
+            obs = tf.transpose(obs, [0, 2, 3, 1])
+        # Explicit cast to float32 needed in eager.
+        model_out, self._value_out = self.base_model(tf.cast(obs, tf.float32))
+        state = [v for k, v in input_dict.items() if k.startswith("state_in_")]
+        extra_outs = {SampleBatch.VF_PREDS: tf.reshape(self._value_out, [-1])}
+        # Our last layer is already flat.
+        if self.last_layer_is_flattened:
+            return model_out, state, extra_outs
+        # Last layer is a n x [1,1] Conv2D -> Flatten.
+        else:
+            return tf.squeeze(model_out, axis=[1, 2]), state, extra_outs
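Note: the ValueError above guards the same invariant as the classic
VisionNetwork: after the final Conv2D, both spatial dims must equal 1. As a
worked example (assuming RLlib's usual 84x84 Atari filters): [16, [8, 8], 4]
maps 84 -> 21 ("same" padding), [32, [4, 4], 2] maps 21 -> 11, and the final
[256, [11, 11], 1] layer with "valid" padding maps 11 -> 1, so the
(1,1)-Conv2D to num_outputs then yields the required [B, 1, 1, num_outputs].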
@@ -456,7 +456,7 @@ class DynamicTFPolicy(TFPolicy):
            dummy_batch = self._get_dummy_batch_from_view_requirements(
                batch_size=32)

-        return SampleBatch(input_dict, _seq_lens=self._seq_lens), dummy_batch
+        return SampleBatch(input_dict, seq_lens=self._seq_lens), dummy_batch

    def _initialize_loss_from_dummy_batch(
            self, auto_remove_unneeded_view_reqs: bool = True,
@@ -503,8 +503,8 @@ class DynamicTFPolicy(TFPolicy):
            dict(self._input_dict, **self._loss_input_dict))

        if self._state_inputs:
-            train_batch.seq_lens = self._seq_lens
-            self._loss_input_dict.update({"seq_lens": train_batch.seq_lens})
+            train_batch["seq_lens"] = self._seq_lens
+            self._loss_input_dict.update({"seq_lens": train_batch["seq_lens"]})

        self._loss_input_dict.update({k: v for k, v in train_batch.items()})
@@ -522,8 +522,8 @@ class DynamicTFPolicy(TFPolicy):

        TFPolicy._initialize_loss(self, loss, [
            (k, v) for k, v in train_batch.items() if k in all_accessed_keys
-        ] + ([("seq_lens", train_batch.seq_lens)]
-             if train_batch.seq_lens is not None else []))
+        ] + ([("seq_lens", train_batch["seq_lens"])]
+             if "seq_lens" in train_batch else []))

        if "is_training" in self._loss_input_dict:
            del self._loss_input_dict["is_training"]
@@ -27,8 +27,6 @@ logger = logging.getLogger(__name__)
 def _convert_to_tf(x, dtype=None):
     if isinstance(x, SampleBatch):
         dict_ = {k: v for k, v in x.items() if k != SampleBatch.INFOS}
-        if x.seq_lens is not None:
-            dict_["seq_lens"] = x.seq_lens
         return tf.nest.map_structure(_convert_to_tf, dict_)
     elif isinstance(x, Policy):
         return x
@@ -364,8 +362,6 @@ def build_eager_tf_policy(
                    batch_divisibility_req=self.batch_divisibility_req,
                    view_requirements=self.view_requirements,
                )
-            else:
-                postprocessed_batch["seq_lens"] = postprocessed_batch.seq_lens

            self._is_training = True
            postprocessed_batch["is_training"] = True
@@ -526,6 +522,9 @@ def build_eager_tf_policy(
                    raise e
            elif isinstance(self.model, tf.keras.Model):
                input_dict = SampleBatch(input_dict, seq_lens=seq_lens)
+                if state_batches and "state_in_0" not in input_dict:
+                    for i, s in enumerate(state_batches):
+                        input_dict[f"state_in_{i}"] = s
                self._lazy_tensor_dict(input_dict)
                dist_inputs, state_out, extra_fetches = \
                    self.model(input_dict)
@@ -11,13 +11,14 @@ from ray.rllib.policy.sample_batch import SampleBatch
 from ray.rllib.policy.view_requirement import ViewRequirement
 from ray.rllib.utils.annotations import DeveloperAPI
 from ray.rllib.utils.exploration.exploration import Exploration
-from ray.rllib.utils.framework import try_import_torch
+from ray.rllib.utils.framework import try_import_tf, try_import_torch
 from ray.rllib.utils.from_config import from_config
 from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space, \
     unbatch
 from ray.rllib.utils.typing import AgentID, ModelGradients, ModelWeights, \
     TensorType, TrainerConfigDict, Tuple, Union

+tf1, tf, tfv = try_import_tf()
 torch, _ = try_import_torch()

 logger = logging.getLogger(__name__)
@@ -648,13 +649,13 @@ class Policy(metaclass=ABCMeta):
                i += 1
            seq_len = sample_batch_size // B
            seq_lens = np.array([seq_len for _ in range(B)], dtype=np.int32)
-            postprocessed_batch.seq_lens = seq_lens
+            postprocessed_batch["seq_lens"] = seq_lens
        # Switch on lazy to-tensor conversion on `postprocessed_batch`.
        train_batch = self._lazy_tensor_dict(postprocessed_batch)
        # Calling loss, so set `is_training` to True.
        train_batch.is_training = True
        if seq_lens is not None:
-            train_batch.seq_lens = seq_lens
+            train_batch["seq_lens"] = seq_lens
        train_batch.count = self._dummy_batch.count
        # Call the loss function, if it exists.
        if self._loss is not None:
@@ -761,6 +762,7 @@ class Policy(metaclass=ABCMeta):
        """
+        self._model_init_state_automatically_added = True
        model = getattr(self, "model", None)

        obj = model or self
        if model and not hasattr(model, "view_requirements"):
            model.view_requirements = {
@@ -773,7 +775,17 @@ class Policy(metaclass=ABCMeta):
                obj.get_initial_state):
            init_state = obj.get_initial_state()
        else:
-            obj.get_initial_state = lambda: []
+            # Add this functionality automatically for new native model API.
+            if tf and isinstance(model, tf.keras.Model) and \
+                    "state_in_0" not in view_reqs:
+                obj.get_initial_state = lambda: [
+                    np.zeros_like(view_req.space.sample())
+                    for k, view_req in model.view_requirements.items()
+                    if k.startswith("state_in_")]
+            else:
+                obj.get_initial_state = lambda: []
        if "state_in_0" in view_reqs:
            self.is_recurrent = lambda: True
        for i, state in enumerate(init_state):
            space = Box(-1.0, 1.0, shape=state.shape) if \
                hasattr(state, "shape") else state
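Note: in effect, a native keras model that declares state_in_ view
requirements now gets a zero-filled default initial state for free. A small
hedged illustration (hypothetical dims):

    # A model with one state_in_0 view requirement over Box(-1, 1, (64, ))
    # now yields one all-zeros init state without any user code:
    init_state = policy.get_initial_state()
    # -> [np.zeros((64, ), np.float32)]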
@@ -75,8 +75,8 @@ def pad_batch_to_sequences_of_same_size(
    if "state_in_0" in batch or "state_out_0" in batch:
        # Check, whether the state inputs have already been reduced to their
        # init values at the beginning of each max_seq_len chunk.
-        if batch.seq_lens is not None and \
-                len(batch["state_in_0"]) == len(batch.seq_lens):
+        if batch.get("seq_lens") is not None and \
+                len(batch["state_in_0"]) == len(batch["seq_lens"]):
            states_already_reduced_to_init = True

    # RNN (or single timestep state-in): Set the max dynamically.
@@ -113,7 +113,7 @@ def pad_batch_to_sequences_of_same_size(
        episode_ids=batch.get(SampleBatch.EPS_ID),
        unroll_ids=batch.get(SampleBatch.UNROLL_ID),
        agent_indices=batch.get(SampleBatch.AGENT_INDEX),
-        seq_lens=getattr(batch, "seq_lens", batch.get("seq_lens")),
+        seq_lens=batch.get("seq_lens"),
        max_seq_len=max_seq_len,
        dynamic_max=dynamic_max,
        states_already_reduced_to_init=states_already_reduced_to_init,
@@ -318,12 +318,12 @@ def timeslice_along_seq_lens_with_overlap(
        zero_init_states=True) -> List["SampleBatch"]:
    """Slices batch along `seq_lens` (each seq-len item produces one batch).

-    Asserts that seq_lens is given or sample_batch.seq_lens is not None.
+    Asserts that seq_lens is given or sample_batch["seq_lens"] is not None.

    Args:
        sample_batch (SampleBatch): The SampleBatch to timeslice.
        seq_lens (Optional[List[int]]): An optional list of seq_lens to slice
-            at. If None, use `sample_batch.seq_lens`.
+            at. If None, use `sample_batch["seq_lens"]`.
        zero_pad_max_seq_len (int): If >0, already zero-pad the resulting
            slices up to this length. NOTE: This max-len will include the
            additional timesteps gained via setting pre_overlap or
@@ -354,10 +354,10 @@ def timeslice_along_seq_lens_with_overlap(
        # count (makes sure each slice has exactly length 10).
    """
    if seq_lens is None:
-        seq_lens = sample_batch.seq_lens
+        seq_lens = sample_batch.get("seq_lens")
    assert seq_lens is not None and len(seq_lens) > 0, \
        "Cannot timeslice along `seq_lens` when `seq_lens` is empty or None!"
-    # Generate n slices based on self.seq_lens.
+    # Generate n slices based on seq_lens.
    start = 0
    slices = []
    for seq_len in seq_lens:
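Note: a worked slicing example (hypothetical batch) for the slice loop that
starts above: with seq_lens=[2, 3], a 5-row batch yields two slices, rows
[0:2] and [2:5]:

    seq_lens = [2, 3]
    start, slices = 0, []
    for seq_len in seq_lens:
        slices.append((start, start + seq_len))  # (0, 2), then (2, 5)
        start += seq_len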
@@ -391,10 +391,13 @@ def timeslice_along_seq_lens_with_overlap(
                        shape=(zero_length, ) + v.shape[1:], dtype=v.dtype),
                    v[data_begin:end]
                ])
-                for k, v in sample_batch.items()
+                for k, v in sample_batch.items() if k != "seq_lens"
            }
        else:
-            data = {k: v[begin:end] for k, v in sample_batch.items()}
+            data = {
+                k: v[begin:end]
+                for k, v in sample_batch.items() if k != "seq_lens"
+            }

        if zero_init_states_:
            i = 0
@@ -416,7 +419,7 @@ def timeslice_along_seq_lens_with_overlap(
                i += 1
                key = "state_in_{}".format(i)

-        timeslices.append(SampleBatch(data, _seq_lens=[end - begin]))
+        timeslices.append(SampleBatch(data, seq_lens=[end - begin]))

    # Zero-pad each slice if necessary.
    if zero_pad_max_seq_len > 0:
@@ -64,19 +64,7 @@ class SampleBatch(dict):
        # Possible seq_lens (TxB or BxT) setup.
        self.time_major = kwargs.pop("_time_major", None)

-        self.seq_lens = kwargs.pop("_seq_lens", kwargs.pop("seq_lens", None))
-        if self.seq_lens is None and len(args) > 0 and isinstance(
-                args[0], dict):
-            self.seq_lens = args[0].pop("_seq_lens", args[0].pop(
-                "seq_lens", None))
-        if isinstance(self.seq_lens, list):
-            self.seq_lens = np.array(self.seq_lens, dtype=np.int32)
-
        self.max_seq_len = kwargs.pop("_max_seq_len", None)
-        if self.max_seq_len is None and self.seq_lens is not None and \
-                not (tf and tf.is_tensor(self.seq_lens)) and \
-                len(self.seq_lens) > 0:
-            self.max_seq_len = max(self.seq_lens)
        self.zero_padded = kwargs.pop("_zero_padded", False)
        self.is_training = kwargs.pop("_is_training", None)
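Note: with the `_seq_lens` attribute gone, sequence lengths are now just
another column of the batch dict, passed via the plain `seq_lens` keyword and
normalized in the constructor (None/[] dropped, lists numpyfied to int32), as
the next hunk shows. A minimal usage sketch (hypothetical data):

    import numpy as np
    from ray.rllib.policy.sample_batch import SampleBatch

    batch = SampleBatch({"obs": np.zeros((5, 4))}, seq_lens=[2, 3])
    assert isinstance(batch["seq_lens"], np.ndarray)  # int32 array
    assert batch.max_seq_len == 3  # max(seq_lens)
    assert batch.count == 5        # sum(seq_lens)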
@@ -88,14 +76,25 @@ class SampleBatch(dict):
        self.added_keys = set()
        self.deleted_keys = set()
        self.intercepted_values = {}

        self.get_interceptor = None

+        # Clear out None seq-lens.
+        if self.get("seq_lens") is None or self.get("seq_lens") == []:
+            self.pop("seq_lens", None)
+        # Numpyfy seq_lens if list.
+        elif isinstance(self.get("seq_lens"), list):
+            self["seq_lens"] = np.array(self["seq_lens"], dtype=np.int32)
+
+        if self.max_seq_len is None and self.get("seq_lens") is not None and \
+                not (tf and tf.is_tensor(self["seq_lens"])) and \
+                len(self["seq_lens"]) > 0:
+            self.max_seq_len = max(self["seq_lens"])
+
        if self.is_training is None:
            self.is_training = self.pop("is_training", False)

        lengths = []
-        copy_ = {k: v for k, v in self.items()}
+        copy_ = {k: v for k, v in self.items() if k != "seq_lens"}
        for k, v in copy_.items():
            assert isinstance(k, str), self
            len_ = len(v) if isinstance(
@@ -105,10 +104,10 @@ class SampleBatch(dict):
            if isinstance(v, list):
                self[k] = np.array(v)

-        if self.seq_lens is not None and \
-                not (tf and tf.is_tensor(self.seq_lens)) and \
-                len(self.seq_lens) > 0:
-            self.count = sum(self.seq_lens)
+        if self.get("seq_lens") is not None and \
+                not (tf and tf.is_tensor(self["seq_lens"])) and \
+                len(self["seq_lens"]) > 0:
+            self.count = sum(self["seq_lens"])
        else:
            self.count = lengths[0] if lengths else 0
@@ -142,8 +141,8 @@ class SampleBatch(dict):
            if zero_padded:
                assert s.max_seq_len == max_seq_len
            concat_samples.append(s)
-            if s.seq_lens is not None:
-                seq_lens.extend(s.seq_lens)
+            if s.get("seq_lens") is not None:
+                seq_lens.extend(s["seq_lens"])

        out = {}
        for k in concat_samples[0].keys():
@@ -152,7 +151,7 @@ class SampleBatch(dict):
                time_major=concat_samples[0].time_major)
        return SampleBatch(
            out,
-            _seq_lens=np.array(seq_lens, dtype=np.int32),
+            seq_lens=seq_lens,
            _time_major=concat_samples[0].time_major,
            _zero_padded=zero_padded,
            _max_seq_len=max_seq_len,
@@ -202,7 +201,7 @@ class SampleBatch(dict):
                if isinstance(v, np.ndarray) else v
                for (k, v) in self.items()
            },
-            _seq_lens=self.seq_lens,
+            seq_lens=self.get("seq_lens"),
        )
        copy_.set_get_interceptor(self.get_interceptor)
        return copy_
@@ -304,7 +303,7 @@ class SampleBatch(dict):
            SampleBatch: A new SampleBatch, which has a slice of this batch's
                data.
        """
-        if self.seq_lens is not None and len(self.seq_lens) > 0:
+        if self.get("seq_lens") is not None and len(self["seq_lens"]) > 0:
            if start < 0:
                data = {
                    k: np.concatenate([
@@ -312,15 +311,18 @@ class SampleBatch(dict):
                            shape=(-start, ) + v.shape[1:], dtype=v.dtype),
                        v[0:end]
                    ])
                    for k, v in self.items()
                    for k, v in self.items() if k != "seq_lens"
                }
            else:
                data = {k: v[start:end] for k, v in self.items()}
                data = {
                    k: v[start:end]
                    for k, v in self.items() if k != "seq_lens"
                }
            # Fix state_in_x data.
            count = 0
            state_start = None
            seq_lens = None
            for i, seq_len in enumerate(self.seq_lens):
            for i, seq_len in enumerate(self["seq_lens"]):
                count += seq_len
                if count >= end:
                    state_idx = 0
@@ -331,7 +333,7 @@ class SampleBatch(dict):
                            data[state_key] = self[state_key][state_start:i + 1]
                            state_idx += 1
                            state_key = "state_in_{}".format(state_idx)
                        seq_lens = list(self.seq_lens[state_start:i]) + [
                        seq_lens = list(self["seq_lens"][state_start:i]) + [
                            seq_len - (count - end)
                        ]
                        if start < 0:
@@ -343,14 +345,13 @@ class SampleBatch(dict):

                return SampleBatch(
                    data,
                    _seq_lens=np.array(seq_lens, dtype=np.int32),
                    seq_lens=seq_lens,
                    _time_major=self.time_major,
                )
        else:
            return SampleBatch(
                {k: v[start:end]
                 for k, v in self.items()},
                _seq_lens=None,
                _is_training=self.is_training,
                _time_major=self.time_major)
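`slice()` thus excludes `seq_lens` from the plain per-column slice and rebuilds it (together with any matching `state_in_x` rows) from the sequence boundaries. A hedged sketch of the intended behavior; the exact truncation of a cut sequence follows the loop above (toy batch, no RNN state columns):

    import numpy as np
    from ray.rllib.policy.sample_batch import SampleBatch

    b = SampleBatch({
        "obs": np.arange(5, dtype=np.float32),
        "seq_lens": [2, 3],
    })
    s = b.slice(0, 4)  # cuts into the second sequence
    assert s.count == 4
    # seq_lens is rebuilt, e.g. as [2, 2]: the last sequence is shortened
    # by (count - end) so the slice still covers whole timesteps.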
@@ -383,8 +384,9 @@ class SampleBatch(dict):
                data. If False, leave `state_in_x` keys as-is.
        """
        for col in self.keys():
            # Skip state in columns.
            if exclude_states is True and col.startswith("state_in_"):
            # Skip "state_in_..." columns and "seq_lens".
            if (exclude_states is True and col.startswith("state_in_")) or \
                    col == "seq_lens":
                continue

            f = self[col]
@@ -395,7 +397,7 @@ class SampleBatch(dict):
            if f.shape[0] == max_seq_len:
                continue
            # Generate zero-filled primer of len=max_seq_len.
            length = len(self.seq_lens) * max_seq_len
            length = len(self["seq_lens"]) * max_seq_len
            if f.dtype == np.object or f.dtype.type is np.str_:
                f_pad = [None] * length
            else:
@@ -403,7 +405,7 @@ class SampleBatch(dict):
                f_pad = np.zeros((length, ) + np.shape(f)[1:], dtype=f.dtype)
            # Fill primer with data.
            f_pad_base = f_base = 0
            for len_ in self.seq_lens:
            for len_ in self["seq_lens"]:
                f_pad[f_pad_base:f_pad_base + len_] = f[f_base:f_base + len_]
                f_pad_base += max_seq_len
                f_base += len_
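The padding loop above copies each sequence into its own `max_seq_len`-wide slot of a zero-filled buffer. The same arithmetic in isolation (plain NumPy, hypothetical values):

    import numpy as np

    seq_lens = [2, 3]
    max_seq_len = 3
    f = np.arange(5, dtype=np.float32)  # 5 timesteps of a flat column

    f_pad = np.zeros((len(seq_lens) * max_seq_len, ), dtype=f.dtype)
    f_pad_base = f_base = 0
    for len_ in seq_lens:
        f_pad[f_pad_base:f_pad_base + len_] = f[f_base:f_base + len_]
        f_pad_base += max_seq_len
        f_base += len_

    # Each sequence is right-zero-padded to length 3.
    assert list(f_pad) == [0.0, 1.0, 0.0, 2.0, 3.0, 4.0]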
@@ -441,7 +443,7 @@ class SampleBatch(dict):
        Returns:
            TensorType: The data under the given key.
        """
        if not hasattr(self, key):
        if not hasattr(self, key) and key in self:
            self.accessed_keys.add(key)

        # Backward compatibility for when "input-dicts" were used.
@@ -452,13 +454,6 @@ class SampleBatch(dict):
                new="SampleBatch.is_training",
                error=False)
            return self.is_training
        elif key == "seq_lens":
            if self.get_interceptor is not None and self.seq_lens is not None:
                if "seq_lens" not in self.intercepted_values:
                    self.intercepted_values["seq_lens"] = self.get_interceptor(
                        self.seq_lens)
                return self.intercepted_values["seq_lens"]
            return self.seq_lens

        value = dict.__getitem__(self, key)
        if self.get_interceptor is not None:
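With the special-case branch gone, `seq_lens` flows through the same lazy get-interceptor path as any other column. A sketch of that path (the torch conversion is just one hypothetical interceptor choice):

    import numpy as np
    import torch
    from ray.rllib.policy.sample_batch import SampleBatch

    b = SampleBatch({"obs": np.ones((5, 2)), "seq_lens": [2, 3]})
    # Lazily convert any accessed column to a torch tensor.
    b.set_get_interceptor(lambda v: torch.from_numpy(np.asarray(v)))
    assert torch.is_tensor(b["seq_lens"])  # intercepted like any other key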
@@ -475,12 +470,9 @@ class SampleBatch(dict):
            key (str): The column name to set a value for.
            item (TensorType): The data to insert.
        """
        if key == "seq_lens":
            self.seq_lens = item
            return
        # Defend against creating SampleBatch via pickle (no property
        # `added_keys` and first item is already set).
        elif not hasattr(self, "added_keys"):
        if not hasattr(self, "added_keys"):
            dict.__setitem__(self, key, item)
            return
@@ -548,15 +540,15 @@ class SampleBatch(dict):
    def _get_slice_indices(self, slice_size):
        i = 0
        slices = []
        if self.seq_lens is not None and len(self.seq_lens) > 0:
            assert np.all(self.seq_lens < slice_size), \
        if self.get("seq_lens") is not None and len(self["seq_lens"]) > 0:
            assert np.all(self["seq_lens"] < slice_size), \
                "ERROR: `slice_size` must be larger than the max. seq-len " \
                "in the batch!"
            start_pos = 0
            current_slize_size = 0
            idx = 0
            while idx < len(self.seq_lens):
                seq_len = self.seq_lens[idx]
            while idx < len(self["seq_lens"]):
                seq_len = self["seq_lens"][idx]
                current_slize_size += seq_len
                # Complete minibatch -> Append to slices.
                if current_slize_size >= slice_size:
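`_get_slice_indices` walks the sequence lengths and closes a minibatch once the accumulated timesteps reach `slice_size`, so slices never cut through a sequence. The accumulation logic, simplified and in isolation (hypothetical values; the real method also tracks state-row indices):

    seq_lens = [2, 3, 2, 3]  # 10 timesteps total
    slice_size = 5

    slices, start_pos, current = [], 0, 0
    for seq_len in seq_lens:
        current += seq_len
        if current >= slice_size:  # minibatch complete -> record slice
            slices.append((start_pos, start_pos + current))
            start_pos += current
            current = 0

    assert slices == [(0, 5), (5, 10)]  # whole sequences only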
@@ -645,7 +637,7 @@ class SampleBatch(dict):
                input_dict[view_col] = self[data_col][
                    index:index + 1 if index != -1 else None]

        return SampleBatch(input_dict, _seq_lens=np.array([1], dtype=np.int32))
        return SampleBatch(input_dict, seq_lens=np.array([1], dtype=np.int32))


@PublicAPI
@@ -34,14 +34,14 @@ class TestSampleBatch(unittest.TestCase):

        # Add an item and check, whether it's in the "added" list.
        batch["d"] = np.array(1)
        assert batch.added_keys == {"d"}
        assert batch.added_keys == {"d"}, batch.added_keys
        # Access two keys and check, whether they are in the
        # "accessed" list.
        print(batch["a"], batch["b"])
        assert batch.accessed_keys == {"a", "b"}
        assert batch.accessed_keys == {"a", "b"}, batch.accessed_keys
        # Delete a key and check, whether it's in the "deleted" list.
        del batch["c"]
        assert batch.deleted_keys == {"c"}
        assert batch.deleted_keys == {"c"}, batch.deleted_keys


if __name__ == "__main__":
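The tracking sets exercised by this test are part of SampleBatch's bookkeeping: additions, accesses, and deletions are each recorded. A compact standalone illustration mirroring the test:

    import numpy as np
    from ray.rllib.policy.sample_batch import SampleBatch

    batch = SampleBatch({
        "a": np.array([1]),
        "b": np.array([2]),
        "c": np.array([3]),
    })
    batch["d"] = np.array([4])  # -> added_keys
    print(batch["a"])  # -> accessed_keys
    del batch["c"]  # -> deleted_keys
    assert batch.added_keys == {"d"}, batch.added_keys
    assert "a" in batch.accessed_keys
    assert batch.deleted_keys == {"c"}, batch.deleted_keys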
@@ -905,6 +905,7 @@ class TFPolicy(Policy):
            Feed dict of data.
        """

        # Get batch ready for RNNs, if applicable.
        if not isinstance(train_batch,
                          SampleBatch) or not train_batch.zero_padded:
            pad_batch_to_sequences_of_same_size(
@@ -915,10 +916,6 @@ class TFPolicy(Policy):
                feature_keys=list(self._loss_input_dict_no_rnn.keys()),
                view_requirements=self.view_requirements,
            )
        else:
            train_batch["seq_lens"] = train_batch.seq_lens

        # Get batch ready for RNNs, if applicable.

        # Mark the batch as "is_training" so the Model can use this
        # information.
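Since `seq_lens` is now always a dict key, the `else` branch that copied `train_batch.seq_lens` back into the dict becomes dead code; only the zero-padding guard remains. A condensed sketch of the surviving control flow (attribute names such as `policy.max_seq_len` are illustrative; `pad_batch_to_sequences_of_same_size` is RLlib's existing helper, with its argument list abbreviated here):

    from ray.rllib.policy.rnn_sequencing import \
        pad_batch_to_sequences_of_same_size
    from ray.rllib.policy.sample_batch import SampleBatch

    def prepare_batch_for_rnn(policy, train_batch):
        # Pad only if the batch was not already zero-padded upstream.
        if not isinstance(train_batch, SampleBatch) or \
                not train_batch.zero_padded:
            pad_batch_to_sequences_of_same_size(
                train_batch,
                max_seq_len=policy.max_seq_len,
                shuffle=False,
                batch_divisibility_req=policy.batch_divisibility_req,
            )
        return train_batch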
@@ -477,8 +477,6 @@ class TorchPolicy(Policy):
                batch_divisibility_req=self.batch_divisibility_req,
                view_requirements=self.view_requirements,
            )
        else:
            postprocessed_batch["seq_lens"] = postprocessed_batch.seq_lens

        # Mark the batch as "is_training" so the Model can use this
        # information.
@@ -712,12 +710,13 @@ class TorchPolicy(Policy):
        if "state_in_0" not in self._dummy_batch:
            self._dummy_batch["state_in_0"] = \
                self._dummy_batch["seq_lens"] = np.array([1.0])
        seq_lens = self._dummy_batch["seq_lens"]

        state_ins = []
        i = 0
        while "state_in_{}".format(i) in self._dummy_batch:
            state_ins.append(self._dummy_batch["state_in_{}".format(i)])
            i += 1
        seq_lens = self._dummy_batch["seq_lens"]
        dummy_inputs = {
            k: self._dummy_batch[k]
            for k in self._dummy_batch.keys() if k != "is_training"
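The `state_in_{i}` collection loop is a small reusable pattern: walk consecutively numbered keys until the first gap. In isolation (hypothetical dummy batch):

    import numpy as np

    dummy_batch = {
        "obs": np.zeros((1, 4)),
        "state_in_0": np.zeros((1, 8)),
        "state_in_1": np.zeros((1, 8)),
        "seq_lens": np.array([1]),
    }

    state_ins = []
    i = 0
    while "state_in_{}".format(i) in dummy_batch:  # stops at first gap
        state_ins.append(dummy_batch["state_in_{}".format(i)])
        i += 1

    assert len(state_ins) == 2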
@@ -23,7 +23,7 @@ class TestAttentionNetLearning(unittest.TestCase):

    @classmethod
    def setUpClass(cls) -> None:
        ray.init(num_cpus=5, ignore_reinit_error=True)
        ray.init(num_cpus=5)

    @classmethod
    def tearDownClass(cls) -> None:
@@ -63,7 +63,7 @@ def minibatches(samples, sgd_minibatch_size):
        raise NotImplementedError(
            "Minibatching not implemented for multi-agent in simple mode")

    # Replace with `if samples.seq_lens` check.
    # Replace with `if samples["seq_lens"]` check.
    if "state_in_0" in samples or "state_out_0" in samples:
        if log_once("not_shuffling_rnn_data_in_simple_mode"):
            logger.warning("Not time-shuffling RNN data for SGD.")
@@ -288,6 +288,7 @@ def check_compute_single_action(trainer,
        pol = trainer.get_policy()
    except AttributeError:
        pol = trainer.policy
    model = pol.model

    action_space = pol.action_space
@@ -328,7 +329,14 @@ def check_compute_single_action(trainer,
            obs = np.clip(obs, -1.0, 1.0)
        state_in = None
        if include_state:
            state_in = pol.model.get_initial_state()
            state_in = model.get_initial_state()
            if not state_in:
                state_in = []
                i = 0
                while f"state_in_{i}" in model.view_requirements:
                    state_in.append(model.view_requirements[
                        f"state_in_{i}"].space.sample())
                    i += 1
        action_in = action_space.sample() \
            if include_prev_action_reward else None
        reward_in = 1.0 if include_prev_action_reward else None
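The new fallback samples an initial state from the declared view-requirement spaces whenever a model does not implement `get_initial_state()`. The same pattern in isolation (`ViewReq` is a hypothetical stand-in for RLlib's view-requirement objects, which carry a gym `space`):

    from gym.spaces import Box

    class ViewReq:
        def __init__(self, space):
            self.space = space

    view_requirements = {
        "obs": ViewReq(Box(-1.0, 1.0, (4, ))),
        "state_in_0": ViewReq(Box(-1.0, 1.0, (8, ))),
        "state_in_1": ViewReq(Box(-1.0, 1.0, (8, ))),
    }

    state_in = []
    i = 0
    while f"state_in_{i}" in view_requirements:
        state_in.append(view_requirements[f"state_in_{i}"].space.sample())
        i += 1

    assert len(state_in) == 2 and state_in[0].shape == (8, )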