[RLlib] Support native tf.keras.Models (part 2) - Default keras models for Vision/RNN/Attention. (#15273)

Sven Mika 2021-04-30 19:26:30 +02:00 committed by GitHub
parent bdbf39f9d5
commit e973b726c2
26 changed files with 924 additions and 143 deletions
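
The new default models are opt-in via the experimental `_use_default_native_models` model-config flag (see the `MODEL_DEFAULTS` and example-script changes below). A minimal sketch of enabling them; the PPO setup itself is illustrative, only the `model` keys come from this commit:

    import ray
    from ray.rllib.agents.ppo import PPOTrainer

    ray.init()
    config = {
        "env": "CartPole-v0",
        "framework": "tf",  # Native keras default models are tf-only for now.
        "model": {
            # Opt into the native tf.keras default models added here.
            "_use_default_native_models": True,
            # Triggers auto-wrapping with the new Keras_LSTMWrapper
            # (or set "use_attention": True for Keras_AttentionWrapper).
            "use_lstm": True,
            "max_seq_len": 20,
        },
    }
    trainer = PPOTrainer(config=config)
    print(trainer.train())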

@@ -531,6 +531,12 @@ py_test(
     size = "large",
     srcs = ["agents/dqn/tests/test_dqn.py"]
 )
+py_test(
+    name = "test_r2d2",
+    tags = ["agents_dir"],
+    size = "large",
+    srcs = ["agents/dqn/tests/test_r2d2.py"]
+)
 py_test(
     name = "test_simple_q",
     tags = ["agents_dir"],

@@ -36,7 +36,7 @@ class TestR2D2(unittest.TestCase):
         config["lr"] = 5e-4
         config["exploration_config"]["epsilon_timesteps"] = 100000
-        num_iterations = 2
+        num_iterations = 1

         # Test building an R2D2 agent in all frameworks.
         for _ in framework_iterator(config):

@@ -82,7 +82,7 @@ class TestPPO(unittest.TestCase):
         # Settings in case we use an LSTM.
         config["model"]["lstm_cell_size"] = 10
         config["model"]["max_seq_len"] = 20
-        # Use default-native keras model whenever possible.
+        # Use default-native keras models whenever possible.
         config["model"]["_use_default_native_models"] = True

         config["train_batch_size"] = 128
@@ -93,7 +93,7 @@ class TestPPO(unittest.TestCase):
         for _ in framework_iterator(config):
             for env in ["CartPole-v0", "MsPacmanNoFrameskip-v4"]:
                 print("Env={}".format(env))
-                for lstm in [False, True]:
+                for lstm in [True, False]:
                     print("LSTM={}".format(lstm))
                     config["model"]["use_lstm"] = lstm
                     config["model"]["lstm_use_prev_action"] = lstm

@@ -366,7 +366,7 @@ class _PolicyCollector:
                 this policy.
         """
         # Create batch from our buffers.
-        batch = SampleBatch(self.buffers, _seq_lens=self.seq_lens)
+        batch = SampleBatch(self.buffers, seq_lens=self.seq_lens)
         # Clear buffers for future samples.
         self.buffers.clear()
         # Reset agent steps to 0 and seq-lens to empty list.

@@ -121,7 +121,7 @@ def compute_gae_for_sample_batch(
         # Create an input dict according to the Model's requirements.
         input_dict = sample_batch.get_single_step_input_dict(
             policy.model.view_requirements, index="last")
-        last_r = policy._value(**input_dict, seq_lens=input_dict.seq_lens)
+        last_r = policy._value(**input_dict)

     # Adds the policy logits, VF preds, and advantages to the batch,
     # using GAE ("generalized advantage estimation") or not.

@@ -26,10 +26,10 @@ class MyCallbacks(DefaultCallbacks):
     @override(DefaultCallbacks)
     def on_learn_on_batch(self, *, policy, train_batch, result, **kwargs):
         assert train_batch.count == 201
-        assert sum(train_batch.seq_lens) == 201
+        assert sum(train_batch["seq_lens"]) == 201
         for k, v in train_batch.items():
             if k == "state_in_0":
-                assert len(v) == len(train_batch.seq_lens)
+                assert len(v) == len(train_batch["seq_lens"])
             else:
                 assert len(v) == 201
         current = None
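
The pattern above recurs throughout this commit: `seq_lens` moves from a special `SampleBatch` attribute to an ordinary dict key. A small sketch of the new access pattern (toy data, not from the commit):

    import numpy as np
    from ray.rllib.policy.sample_batch import SampleBatch

    # The seq_lens kwarg now simply becomes a dict entry.
    batch = SampleBatch({"obs": np.zeros((6, 4))}, seq_lens=[3, 3])
    assert list(batch["seq_lens"]) == [3, 3]
    # batch.count is derived from seq_lens: 3 + 3 = 6.
    assert batch.count == 6
    # Use .get() where seq_lens may be absent.
    assert SampleBatch({"obs": np.zeros((2, 4))}).get("seq_lens") is None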

@@ -17,7 +17,8 @@ parser = argparse.ArgumentParser()
 parser.add_argument("--run", type=str, default="PPO")
 parser.add_argument("--env", type=str, default="RepeatAfterMeEnv")
 parser.add_argument("--num-cpus", type=int, default=3)
-parser.add_argument("--framework", choices=["tf", "torch"], default="tf")
+parser.add_argument(
+    "--framework", choices=["tf", "tf2", "tfe", "torch"], default="tf")
 parser.add_argument("--as-test", action="store_true")
 parser.add_argument("--stop-iters", type=int, default=200)
 parser.add_argument("--stop-timesteps", type=int, default=500000)
@@ -48,6 +49,9 @@ if __name__ == "__main__":
         "num_sgd_iter": 10,
         "vf_loss_coeff": 1e-5,
         "model": {
+            # Attention net wrapping (for tf) can already use the native keras
+            # model versions. For torch, this will have no effect.
+            "_use_default_native_models": True,
             "use_attention": True,
             "max_seq_len": 10,
             "attention_num_transformer_units": 1,

@@ -36,11 +36,11 @@ class RNNModel(tf.keras.models.Model if tf else object):
     def call(self, sample_batch):
         dense_out = self.dense(sample_batch["obs"])
-        B = tf.shape(sample_batch.seq_lens)[0]
+        B = tf.shape(sample_batch["seq_lens"])[0]
         lstm_in = tf.reshape(dense_out, [B, -1, dense_out.shape.as_list()[1]])
         lstm_out, h, c = self.lstm(
             inputs=lstm_in,
-            mask=tf.sequence_mask(sample_batch.seq_lens),
+            mask=tf.sequence_mask(sample_batch["seq_lens"]),
             initial_state=[
                 sample_batch["state_in_0"], sample_batch["state_in_1"]
             ],

@@ -82,7 +82,7 @@ class TorchParametricActionsModel(DQNTorchModel):
             model_config, name + "_action_embed")

     def forward(self, input_dict, state, seq_lens):
         # Extract the available actions tensor from the observation.
         avail_actions = input_dict["obs"]["avail_actions"]
         action_mask = input_dict["obs"]["action_mask"]

@@ -39,8 +39,10 @@ MODEL_DEFAULTS: ModelConfigDict = {
     # If True, try to use a native (tf.keras.Model or torch.Module) default
     # model instead of our built-in ModelV2 defaults.
     # If False (default), use "classic" ModelV2 default models.
-    # Note that this currently only works for framework != torch AND fully
-    # connected default networks.
+    # Note that this currently only works for:
+    # 1) framework != torch AND
+    # 2) fully connected and CNN default networks as well as
+    #    auto-wrapped LSTM- and attention nets.
     "_use_default_native_models": False,

     # === Built-in options ===
@@ -418,14 +420,25 @@ class ModelCatalog:
             if model_config.get("use_lstm") or \
                     model_config.get("use_attention"):

                 from ray.rllib.models.tf.attention_net import \
-                    AttentionWrapper
-                from ray.rllib.models.tf.recurrent_net import LSTMWrapper
+                    AttentionWrapper, Keras_AttentionWrapper
+                from ray.rllib.models.tf.recurrent_net import \
+                    LSTMWrapper, Keras_LSTMWrapper

                 wrapped_cls = model_cls
-                forward = wrapped_cls.forward
-                model_cls = ModelCatalog._wrap_if_needed(
-                    wrapped_cls, LSTMWrapper
-                    if model_config.get("use_lstm") else AttentionWrapper)
-                model_cls._wrapped_forward = forward
+                # Wrapped (custom) model is itself a keras Model ->
+                # wrap with keras LSTM/GTrXL (attention) wrappers.
+                if issubclass(wrapped_cls, tf.keras.Model):
+                    model_cls = Keras_LSTMWrapper if \
+                        model_config.get("use_lstm") else \
+                        Keras_AttentionWrapper
+                    model_config["wrapped_cls"] = wrapped_cls
+                # Wrapped (custom) model is ModelV2 ->
+                # wrap with ModelV2 LSTM/GTrXL (attention) wrappers.
+                else:
+                    forward = wrapped_cls.forward
+                    model_cls = ModelCatalog._wrap_if_needed(
+                        wrapped_cls, LSTMWrapper if
+                        model_config.get("use_lstm") else AttentionWrapper)
+                    model_cls._wrapped_forward = forward

             # Obsolete: Track and warn if vars were created but not
@@ -561,31 +574,41 @@ class ModelCatalog:
                 model_config.get("use_attention"):

             from ray.rllib.models.tf.attention_net import \
-                AttentionWrapper
-            from ray.rllib.models.tf.recurrent_net import LSTMWrapper
+                AttentionWrapper, Keras_AttentionWrapper
+            from ray.rllib.models.tf.recurrent_net import LSTMWrapper, \
+                Keras_LSTMWrapper

             wrapped_cls = v2_class
-            forward = wrapped_cls.forward
             if model_config.get("use_lstm"):
-                v2_class = ModelCatalog._wrap_if_needed(
-                    wrapped_cls, LSTMWrapper)
+                if issubclass(wrapped_cls, tf.keras.Model):
+                    v2_class = Keras_LSTMWrapper
+                    model_config["wrapped_cls"] = wrapped_cls
+                else:
+                    v2_class = ModelCatalog._wrap_if_needed(
+                        wrapped_cls, LSTMWrapper)
+                    v2_class._wrapped_forward = wrapped_cls.forward
             else:
-                v2_class = ModelCatalog._wrap_if_needed(
-                    wrapped_cls, AttentionWrapper)
-
-            v2_class._wrapped_forward = forward
+                if issubclass(wrapped_cls, tf.keras.Model):
+                    v2_class = Keras_AttentionWrapper
+                    model_config["wrapped_cls"] = wrapped_cls
+                else:
+                    v2_class = ModelCatalog._wrap_if_needed(
+                        wrapped_cls, AttentionWrapper)
+                    v2_class._wrapped_forward = wrapped_cls.forward

         # Wrap in the requested interface.
         wrapper = ModelCatalog._wrap_if_needed(v2_class, model_interface)

         if issubclass(wrapper, tf.keras.Model):
-            return wrapper(
+            model = wrapper(
                 input_space=obs_space,
                 action_space=action_space,
                 num_outputs=num_outputs,
                 name=name,
                 **dict(model_kwargs, **model_config),
             )
+            return model

         return wrapper(obs_space, action_space, num_outputs, model_config,
                        name, **model_kwargs)
@@ -759,13 +782,15 @@ class ModelCatalog:
         VisionNet = None
         ComplexNet = None
         Keras_FCNet = None
+        Keras_VisionNet = None

         if framework in ["tf2", "tf", "tfe"]:
             from ray.rllib.models.tf.fcnet import \
                 FullyConnectedNetwork as FCNet, \
                 Keras_FullyConnectedNetwork as Keras_FCNet
             from ray.rllib.models.tf.visionnet import \
-                VisionNetwork as VisionNet
+                VisionNetwork as VisionNet, \
+                Keras_VisionNetwork as Keras_VisionNet
             from ray.rllib.models.tf.complex_input_net import \
                 ComplexInputNetwork as ComplexNet
         elif framework == "torch":
@@ -802,10 +827,8 @@ class ModelCatalog:
                 len(input_space.shape) == 1 or (
                     len(input_space.shape) == 2 and (
                         num_framestacks == "auto" or num_framestacks <= 1)):
-            # Keras native requested AND no auto-rnn-wrapping AND .
-            if model_config.get("_use_default_native_models") and \
-                    Keras_FCNet and not model_config.get("use_lstm") and \
-                    not model_config.get("use_attention"):
+            # Keras native requested AND no auto-rnn-wrapping.
+            if model_config.get("_use_default_native_models") and Keras_FCNet:
                 return Keras_FCNet
             # Classic ModelV2 FCNet.
             else:
@@ -815,6 +838,8 @@ class ModelCatalog:
             raise NotImplementedError("No non-FC default net for JAX yet!")

         # Last resort: Conv2D stack for single image spaces.
+        if model_config.get("_use_default_native_models") and Keras_VisionNet:
+            return Keras_VisionNet
         return VisionNet

     @staticmethod
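
End to end, the new catalog branches mean a custom model that subclasses `tf.keras.Model` can be auto-wrapped: the catalog picks `Keras_LSTMWrapper`/`Keras_AttentionWrapper` and hands the custom class through `model_config["wrapped_cls"]`. A hypothetical registration; `MyKerasNet` is assumed to follow the new `call()` contract (see the fcnet sketch below), and `config` is a trainer config dict as in the example scripts:

    from ray.rllib.models import ModelCatalog

    ModelCatalog.register_custom_model("my_keras_net", MyKerasNet)
    config["model"] = {
        "custom_model": "my_keras_net",
        "use_lstm": True,       # -> catalog selects Keras_LSTMWrapper
        "lstm_cell_size": 256,  # size of the wrapping LSTM layer
        "max_seq_len": 20,
    }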

@@ -210,7 +210,7 @@ class ModelV2:
         restored = input_dict.copy(shallow=True)
         # Backward compatibility.
         if seq_lens is None:
-            seq_lens = input_dict.seq_lens
+            seq_lens = input_dict.get("seq_lens")
         if not state:
             state = []
         i = 0

@@ -11,7 +11,7 @@
 from gym.spaces import Box, Discrete, MultiDiscrete
 import numpy as np
 import gym
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, Optional, Type, Union

 from ray.rllib.models.modelv2 import ModelV2
 from ray.rllib.models.tf.layers import GRUGate, RelativeMultiHeadAttention, \
@@ -491,3 +491,347 @@ class AttentionWrapper(TFModelV2):
     def value_function(self) -> TensorType:
         assert self._features is not None, "Must call forward() first!"
         return tf.reshape(self._value_branch(self._features), [-1])
+
+
+class Keras_GTrXLNet(tf.keras.Model if tf else object):
+    """A GTrXL net Model described in [2].
+
+    This is still in an experimental phase.
+    Can be used as a drop-in replacement for LSTMs in PPO and IMPALA.
+    For an example script, see: `ray/rllib/examples/attention_net.py`.
+
+    To use this network as a replacement for an RNN, configure your Trainer
+    as follows:
+
+    Examples:
+        >> config["model"]["custom_model"] = Keras_GTrXLNet
+        >> config["model"]["max_seq_len"] = 10
+        >> config["model"]["custom_model_config"] = {
+        >>     "num_transformer_units": 1,
+        >>     "attention_dim": 32,
+        >>     "num_heads": 2,
+        >>     "memory_inference": 100,
+        >>     "memory_training": 50,
+        >>     # etc.
+        >> }
+    """
+
+    def __init__(self,
+                 input_space: gym.spaces.Space,
+                 action_space: gym.spaces.Space,
+                 *,
+                 name: str,
+                 max_seq_len: int = 20,
+                 num_transformer_units: int = 1,
+                 attention_dim: int = 64,
+                 num_heads: int = 2,
+                 memory_inference: int = 50,
+                 memory_training: int = 50,
+                 head_dim: int = 32,
+                 position_wise_mlp_dim: int = 32,
+                 init_gru_gate_bias: float = 2.0):
+        """Initializes a GTrXLNet instance.
+
+        Args:
+            num_transformer_units (int): The number of Transformer repeats to
+                use (denoted L in [2]).
+            attention_dim (int): The input and output dimensions of one
+                Transformer unit.
+            num_heads (int): The number of attention heads to use in parallel.
+                Denoted as `H` in [3].
+            memory_inference (int): The number of timesteps to concat (time
+                axis) and feed into the next transformer unit as inference
+                input. The first transformer unit will receive this number of
+                past observations (plus the current one), instead.
+            memory_training (int): The number of timesteps to concat (time
+                axis) and feed into the next transformer unit as training
+                input (plus the actual input sequence of len=max_seq_len).
+                The first transformer unit will receive this number of
+                past observations (plus the input sequence), instead.
+            head_dim (int): The dimension of a single(!) attention head within
+                a multi-head attention unit. Denoted as `d` in [3].
+            position_wise_mlp_dim (int): The dimension of the hidden layer
+                within the position-wise MLP (after the multi-head attention
+                block within one Transformer unit). This is the size of the
+                first of the two layers within the PositionwiseFeedforward. The
+                second layer always has size=`attention_dim`.
+            init_gru_gate_bias (float): Initial bias values for the GRU gates
+                (two GRUs per Transformer unit, one after the MHA, one after
+                the position-wise MLP).
+        """
+        super().__init__(name=name)
+
+        self.num_transformer_units = num_transformer_units
+        self.attention_dim = attention_dim
+        self.num_heads = num_heads
+        self.memory_inference = memory_inference
+        self.memory_training = memory_training
+        self.head_dim = head_dim
+        self.max_seq_len = max_seq_len
+        self.obs_dim = input_space.shape[0]
+
+        # Raw observation input (plus (None) time axis).
+        input_layer = tf.keras.layers.Input(
+            shape=(
+                None,
+                self.obs_dim,
+            ), name="inputs")
+        memory_ins = [
+            tf.keras.layers.Input(
+                shape=(
+                    None,
+                    self.attention_dim,
+                ),
+                dtype=tf.float32,
+                name="memory_in_{}".format(i))
+            for i in range(self.num_transformer_units)
+        ]
+
+        # Map observation dim to input/output transformer (attention) dim.
+        E_out = tf.keras.layers.Dense(self.attention_dim)(input_layer)
+        # Output, collected and concat'd to build the internal, tau-len
+        # Memory units used for additional contextual information.
+        memory_outs = [E_out]
+
+        # 2) Create L Transformer blocks according to [2].
+        for i in range(self.num_transformer_units):
+            # RelativeMultiHeadAttention part.
+            MHA_out = SkipConnection(
+                RelativeMultiHeadAttention(
+                    out_dim=self.attention_dim,
+                    num_heads=num_heads,
+                    head_dim=head_dim,
+                    input_layernorm=True,
+                    output_activation=tf.nn.relu),
+                fan_in_layer=GRUGate(init_gru_gate_bias),
+                name="mha_{}".format(i + 1))(
+                    E_out, memory=memory_ins[i])
+            # Position-wise MLP part.
+            E_out = SkipConnection(
+                tf.keras.Sequential(
+                    (tf.keras.layers.LayerNormalization(axis=-1),
+                     PositionwiseFeedforward(
+                         out_dim=self.attention_dim,
+                         hidden_dim=position_wise_mlp_dim,
+                         output_activation=tf.nn.relu))),
+                fan_in_layer=GRUGate(init_gru_gate_bias),
+                name="pos_wise_mlp_{}".format(i + 1))(MHA_out)
+            # Output of position-wise MLP == E(l-1), which is concat'd
+            # to the current Mem block (M(l-1)) to yield E~(l-1), which is then
+            # used by the next transformer block.
+            memory_outs.append(E_out)
+
+        self._logits = None
+        self._value_out = None
+
+        self.trxl_model = tf.keras.Model(
+            inputs=[input_layer] + memory_ins,
+            outputs=[E_out] + memory_outs[:-1])
+
+        self.view_requirements = {
+            SampleBatch.OBS: ViewRequirement(space=input_space),
+        }
+        # Setup trajectory views (`memory-inference` x past memory outs).
+        for i in range(self.num_transformer_units):
+            space = Box(-1.0, 1.0, shape=(self.attention_dim, ))
+            self.view_requirements["state_in_{}".format(i)] = \
+                ViewRequirement(
+                    "state_out_{}".format(i),
+                    shift="-{}:-1".format(self.memory_inference),
+                    # Repeat the incoming state every max-seq-len times.
+                    batch_repeat_value=self.max_seq_len,
+                    space=space)
+            self.view_requirements["state_out_{}".format(i)] = \
+                ViewRequirement(
+                    space=space,
+                    used_for_training=False)
+
+    def call(self, inputs, memory_ins) -> (TensorType, List[TensorType]):
+        # Add the time dim to observations.
+        B = tf.shape(memory_ins[0])[0]
+        shape = tf.shape(inputs)
+        T = shape[0] // B
+        inputs = tf.reshape(inputs, tf.concat([[-1, T], shape[1:]], axis=0))
+
+        all_out = self.trxl_model([inputs] + memory_ins)
+
+        out = tf.reshape(all_out[0], [-1, self.attention_dim])
+        memory_outs = all_out[1:]
+
+        return out, [
+            tf.reshape(m, [-1, self.attention_dim]) for m in memory_outs
+        ]
+
+
+class Keras_AttentionWrapper(tf.keras.Model if tf else object):
+    """A tf keras auto-GTrXL wrapper used when `use_attention`=True."""
+
+    def __init__(
+            self,
+            input_space: gym.spaces.Space,
+            action_space: gym.spaces.Space,
+            num_outputs: Optional[int] = None,
+            *,
+            name: str,
+            wrapped_cls: Type["tf.keras.Model"],
+            max_seq_len: int = 20,
+            attention_num_transformer_units: int = 1,
+            attention_dim: int = 64,
+            attention_num_heads: int = 1,
+            attention_head_dim: int = 32,
+            attention_memory_inference: int = 50,
+            attention_memory_training: int = 50,
+            attention_position_wise_mlp_dim: int = 32,
+            attention_init_gru_gate_bias: float = 2.0,
+            attention_use_n_prev_actions: int = 0,
+            attention_use_n_prev_rewards: int = 0,
+            **kwargs,
+    ):
+        super().__init__(name=name)
+        self.wrapped_keras_model = wrapped_cls(
+            input_space, action_space, None, name="wrapped_" + name, **kwargs)
+
+        self.action_space = action_space
+        self.max_seq_len = max_seq_len
+        self.use_n_prev_actions = attention_use_n_prev_actions
+        self.use_n_prev_rewards = attention_use_n_prev_rewards
+        self.attention_dim = attention_dim
+
+        # Guess the number of outputs for the wrapped model by looking
+        # at its first output's shape.
+        # This will be the input size for the attention net (plus
+        # maybe prev-actions/rewards).
+        # If no layers in the wrapped model, set it to the
+        # observation space.
+        if self.wrapped_keras_model.layers:
+            assert self.wrapped_keras_model.layers[-1].outputs
+            assert len(
+                self.wrapped_keras_model.layers[-1].outputs[0].shape) == 2
+            wrapped_num_outputs = int(
+                self.wrapped_keras_model.layers[-1].outputs[0].shape[1])
+        else:
+            wrapped_num_outputs = int(np.product(input_space.shape))
+
+        if isinstance(action_space, Discrete):
+            self.action_dim = action_space.n
+        elif isinstance(action_space, MultiDiscrete):
+            self.action_dim = np.product(action_space.nvec)
+        elif action_space.shape is not None:
+            self.action_dim = int(np.product(action_space.shape))
+        else:
+            self.action_dim = int(len(action_space))
+
+        # Add prev-action/reward nodes to input to attention net.
+        if self.use_n_prev_actions:
+            wrapped_num_outputs += self.use_n_prev_actions * self.action_dim
+        if self.use_n_prev_rewards:
+            wrapped_num_outputs += self.use_n_prev_rewards
+
+        in_space = gym.spaces.Box(
+            float("-inf"),
+            float("inf"),
+            shape=(wrapped_num_outputs, ),
+            dtype=np.float32)
+
+        input_ = tf.keras.layers.Input(
+            shape=(wrapped_num_outputs, ), name="inputs")
+        memory_ins = [
+            tf.keras.layers.Input(
+                shape=(
+                    None,
+                    self.attention_dim,
+                ), name=f"memory_in_{i}")
+            for i in range(attention_num_transformer_units)
+        ]
+
+        # Construct GTrXL sub-module.
+        self.gtrxl = Keras_GTrXLNet(
+            in_space,
+            action_space,
+            name="gtrxl",
+            max_seq_len=max_seq_len,
+            num_transformer_units=attention_num_transformer_units,
+            attention_dim=self.attention_dim,
+            num_heads=attention_num_heads,
+            head_dim=attention_head_dim,
+            memory_inference=attention_memory_inference,
+            memory_training=attention_memory_training,
+            position_wise_mlp_dim=attention_position_wise_mlp_dim,
+            init_gru_gate_bias=attention_init_gru_gate_bias,
+        )
+        keras_gtrxl_model_out, memory_outs = self.gtrxl(input_, memory_ins)
+
+        # Postprocess GTrXL output with another hidden layer and compute
+        # values.
+        logits = tf.keras.layers.Dense(
+            num_outputs, activation=None)(keras_gtrxl_model_out)
+
+        value_outs = tf.keras.layers.Dense(
+            1, activation=None)(keras_gtrxl_model_out)
+        self.base_model = tf.keras.models.Model(
+            [input_, memory_ins], [logits, memory_outs, value_outs])
+
+        self.view_requirements = self.gtrxl.view_requirements
+        self.view_requirements["obs"].space = input_space
+
+        # Add prev-a/r to this model's view, if required.
+        if self.use_n_prev_actions:
+            self.view_requirements[SampleBatch.PREV_ACTIONS] = \
+                ViewRequirement(
+                    SampleBatch.ACTIONS,
+                    space=self.action_space,
+                    shift="-{}:-1".format(self.use_n_prev_actions))
+        if self.use_n_prev_rewards:
+            self.view_requirements[SampleBatch.PREV_REWARDS] = \
+                ViewRequirement(
+                    SampleBatch.REWARDS,
+                    shift="-{}:-1".format(self.use_n_prev_rewards))
+
+    def call(self, input_dict: SampleBatch) -> \
+            (TensorType, List[TensorType], Dict[str, TensorType]):
+        assert input_dict["seq_lens"] is not None
+        # Push obs through "unwrapped" net's `forward()` first.
+        wrapped_out, _, _ = self.wrapped_keras_model(input_dict)
+
+        # Concat. prev-action/reward if required.
+        prev_a_r = []
+        if self.use_n_prev_actions:
+            if isinstance(self.action_space, Discrete):
+                for i in range(self.use_n_prev_actions):
+                    prev_a_r.append(
+                        one_hot(input_dict[SampleBatch.PREV_ACTIONS][:, i],
+                                self.action_space))
+            elif isinstance(self.action_space, MultiDiscrete):
+                for i in range(0, self.use_n_prev_actions,
+                               self.action_space.shape[0]):
+                    prev_a_r.append(
+                        one_hot(
+                            tf.cast(
+                                input_dict[SampleBatch.PREV_ACTIONS]
+                                [:, i:i + self.action_space.shape[0]],
+                                tf.float32), self.action_space))
+            else:
+                prev_a_r.append(
+                    tf.reshape(
+                        tf.cast(input_dict[SampleBatch.PREV_ACTIONS],
+                                tf.float32),
+                        [-1, self.use_n_prev_actions * self.action_dim]))
+        if self.use_n_prev_rewards:
+            prev_a_r.append(
+                tf.reshape(
+                    tf.cast(input_dict[SampleBatch.PREV_REWARDS], tf.float32),
+                    [-1, self.use_n_prev_rewards]))
+
+        if prev_a_r:
+            wrapped_out = tf.concat([wrapped_out] + prev_a_r, axis=1)
+
+        memory_ins = [
+            s for k, s in input_dict.items() if k.startswith("state_in_")
+        ]
+        model_out, memory_outs, value_outs = self.base_model([wrapped_out] +
+                                                             memory_ins)
+        return model_out, memory_outs, {
+            SampleBatch.VF_PREDS: tf.reshape(value_outs, [-1])
+        }

@@ -1,17 +1,18 @@
 import numpy as np
 import gym
-from typing import Optional, Sequence, Tuple
+from typing import Dict, Optional, Sequence

 from ray.rllib.models.tf.misc import normc_initializer
 from ray.rllib.models.tf.tf_modelv2 import TFModelV2
 from ray.rllib.models.utils import get_activation_fn
 from ray.rllib.policy.sample_batch import SampleBatch
 from ray.rllib.utils.framework import try_import_tf
-from ray.rllib.utils.typing import Dict, TensorType, List, ModelConfigDict
+from ray.rllib.utils.typing import TensorType, List, ModelConfigDict

 tf1, tf, tfv = try_import_tf()


+# TODO: (sven) obsolete this class once we only support native keras models.
 class FullyConnectedNetwork(TFModelV2):
     """Generic fully connected network implemented in ModelV2 API."""
@@ -138,7 +139,7 @@ class Keras_FullyConnectedNetwork(tf.keras.Model if tf else object):
             self,
             input_space: gym.spaces.Space,
             action_space: gym.spaces.Space,
-            num_outputs: int,
+            num_outputs: Optional[int] = None,
             *,
             name: str = "",
             fcnet_hiddens: Optional[Sequence[int]] = (),
@@ -209,10 +210,6 @@ class Keras_FullyConnectedNetwork(tf.keras.Model if tf else object):
                 name="fc_out",
                 activation=None,
                 kernel_initializer=normc_initializer(0.01))(last_layer)
-        # Adjust num_outputs to be the number of nodes in the last layer.
-        else:
-            self.num_outputs = (
-                [int(np.product(input_space.shape))] + hiddens[-1:])[-1]

         # Concat the log std vars to the end of the state-dependent means.
         if free_log_std and logits_out is not None:
@@ -249,9 +246,8 @@ class Keras_FullyConnectedNetwork(tf.keras.Model if tf else object):
             inputs, [(logits_out
                       if logits_out is not None else last_layer), value_out])

-    def call(self, input_dict: Dict[str, TensorType]) -> \
-            Tuple[TensorType, List[TensorType], TensorType]:
-        model_out, value_out = self.base_model(input_dict["obs"])
-        return model_out, [], {
-            SampleBatch.VF_PREDS: tf.reshape(value_out, [-1])
-        }
+    def call(self, input_dict: SampleBatch) -> \
+            (TensorType, List[TensorType], Dict[str, TensorType]):
+        model_out, value_out = self.base_model(input_dict[SampleBatch.OBS])
+        extra_outs = {SampleBatch.VF_PREDS: tf.reshape(value_out, [-1])}
+        return model_out, [], extra_outs
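
The rewritten `call()` above fixes the contract for all native keras models in this commit: return a 3-tuple of (outputs, state list, extra-fetches dict, e.g. keyed by `SampleBatch.VF_PREDS`). A minimal hypothetical custom model honoring that contract (class name and layer sizes are illustrative, not part of the commit):

    import tensorflow as tf
    from ray.rllib.policy.sample_batch import SampleBatch

    class MyKerasNet(tf.keras.Model):
        def __init__(self, input_space, action_space, num_outputs, *,
                     name="", **kwargs):
            super().__init__(name=name)
            inputs = tf.keras.layers.Input(shape=input_space.shape)
            hidden = tf.keras.layers.Dense(64, activation="relu")(inputs)
            logits = tf.keras.layers.Dense(num_outputs)(hidden)
            value = tf.keras.layers.Dense(1)(hidden)
            self.base_model = tf.keras.Model(inputs, [logits, value])

        def call(self, input_dict):
            # New contract: (model_out, state, extra_fetches).
            model_out, value_out = self.base_model(
                input_dict[SampleBatch.OBS])
            return model_out, [], {
                SampleBatch.VF_PREDS: tf.reshape(value_out, [-1])
            }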

@@ -1,7 +1,7 @@
 import numpy as np
 import gym
-from gym.spaces import Discrete, MultiDiscrete
-from typing import Dict, List
+from gym.spaces import Box, Discrete, MultiDiscrete
+from typing import Dict, List, Optional, Type

 from ray.rllib.models.modelv2 import ModelV2
 from ray.rllib.models.tf.tf_modelv2 import TFModelV2
@@ -235,3 +235,167 @@ class LSTMWrapper(RecurrentNetwork):
     @override(ModelV2)
     def value_function(self) -> TensorType:
         return tf.reshape(self._value_out, [-1])
+
+
+class Keras_LSTMWrapper(tf.keras.Model if tf else object):
+    """A tf keras auto-LSTM wrapper used when `use_lstm`=True."""
+
+    def __init__(
+            self,
+            input_space: gym.spaces.Space,
+            action_space: gym.spaces.Space,
+            num_outputs: Optional[int] = None,
+            *,
+            name: str,
+            wrapped_cls: Type["tf.keras.Model"],
+            max_seq_len: int = 20,
+            lstm_cell_size: int = 256,
+            lstm_use_prev_action: bool = False,
+            lstm_use_prev_reward: bool = False,
+            **kwargs,
+    ):
+        super().__init__(name=name)
+        self.wrapped_keras_model = wrapped_cls(
+            input_space, action_space, None, name="wrapped_" + name, **kwargs)
+
+        self.action_space = action_space
+        self.max_seq_len = max_seq_len
+
+        # Guess the number of outputs for the wrapped model by looking
+        # at its first output's shape.
+        # This will be the input size for the LSTM layer (plus
+        # maybe prev-actions/rewards).
+        # If no layers in the wrapped model, set it to the
+        # observation space.
+        if self.wrapped_keras_model.layers:
+            assert self.wrapped_keras_model.layers[-1].outputs
+            assert len(
+                self.wrapped_keras_model.layers[-1].outputs[0].shape) == 2
+            wrapped_num_outputs = int(
+                self.wrapped_keras_model.layers[-1].outputs[0].shape[1])
+        else:
+            wrapped_num_outputs = int(np.product(input_space.shape))
+
+        self.lstm_cell_size = lstm_cell_size
+        self.lstm_use_prev_action = lstm_use_prev_action
+        self.lstm_use_prev_reward = lstm_use_prev_reward
+
+        if isinstance(self.action_space, Discrete):
+            self.action_dim = self.action_space.n
+        elif isinstance(self.action_space, MultiDiscrete):
+            self.action_dim = np.sum(self.action_space.nvec)
+        elif self.action_space.shape is not None:
+            self.action_dim = int(np.product(self.action_space.shape))
+        else:
+            self.action_dim = int(len(self.action_space))
+
+        # Add prev-action/reward nodes to input to LSTM.
+        if self.lstm_use_prev_action:
+            wrapped_num_outputs += self.action_dim
+        if self.lstm_use_prev_reward:
+            wrapped_num_outputs += 1
+
+        # Define input layers.
+        input_layer = tf.keras.layers.Input(
+            shape=(None, wrapped_num_outputs), name="inputs")
+        state_in_h = tf.keras.layers.Input(
+            shape=(self.lstm_cell_size, ), name="h")
+        state_in_c = tf.keras.layers.Input(
+            shape=(self.lstm_cell_size, ), name="c")
+        seq_in = tf.keras.layers.Input(shape=(), name="seq_in", dtype=tf.int32)
+
+        # Preprocess observation with a hidden layer and send to LSTM cell.
+        lstm_out, state_h, state_c = tf.keras.layers.LSTM(
+            self.lstm_cell_size,
+            return_sequences=True,
+            return_state=True,
+            name="lstm")(
+                inputs=input_layer,
+                mask=tf.sequence_mask(seq_in),
+                initial_state=[state_in_h, state_in_c])
+
+        # Postprocess LSTM output with another hidden layer
+        # if num_outputs not None.
+        if num_outputs:
+            logits = tf.keras.layers.Dense(
+                num_outputs,
+                activation=tf.keras.activations.linear,
+                name="logits")(lstm_out)
+        else:
+            logits = lstm_out
+        # Compute values.
+        values = tf.keras.layers.Dense(
+            1, activation=None, name="values")(lstm_out)
+
+        # Create the RNN model.
+        self._rnn_model = tf.keras.Model(
+            inputs=[input_layer, seq_in, state_in_h, state_in_c],
+            outputs=[logits, values, state_h, state_c])
+
+        # Use view-requirements of wrapped model and add own
+        # requirements.
+        self.view_requirements = \
+            getattr(self.wrapped_keras_model, "view_requirements", {
+                SampleBatch.OBS: ViewRequirement(space=input_space)
+            })
+        # Add prev-a/r to this model's view, if required.
+        if self.lstm_use_prev_action:
+            self.view_requirements[SampleBatch.PREV_ACTIONS] = \
+                ViewRequirement(SampleBatch.ACTIONS, space=self.action_space,
+                                shift=-1)
+        if self.lstm_use_prev_reward:
+            self.view_requirements[SampleBatch.PREV_REWARDS] = \
+                ViewRequirement(SampleBatch.REWARDS, shift=-1)
+        # Internal states view requirements.
+        for i in range(2):
+            space = Box(-1.0, 1.0, shape=(self.lstm_cell_size, ))
+            self.view_requirements["state_in_{}".format(i)] = \
+                ViewRequirement(
+                    "state_out_{}".format(i),
+                    shift=-1,
+                    used_for_compute_actions=True,
+                    batch_repeat_value=max_seq_len,
+                    space=space)
+            self.view_requirements["state_out_{}".format(i)] = \
+                ViewRequirement(space=space, used_for_training=True)
+
+    def call(self, input_dict: SampleBatch) -> \
+            (TensorType, List[TensorType], Dict[str, TensorType]):
+        assert input_dict.get("seq_lens") is not None
+        # Push obs through underlying (wrapped) model first.
+        wrapped_out, _, _ = self.wrapped_keras_model(input_dict)
+
+        # Concat. prev-action/reward if required.
+        prev_a_r = []
+        if self.lstm_use_prev_action:
+            prev_a = input_dict[SampleBatch.PREV_ACTIONS]
+            if isinstance(self.action_space, (Discrete, MultiDiscrete)):
+                prev_a = one_hot(prev_a, self.action_space)
+            prev_a_r.append(
+                tf.reshape(tf.cast(prev_a, tf.float32), [-1, self.action_dim]))
+        if self.lstm_use_prev_reward:
+            prev_a_r.append(
+                tf.reshape(
+                    tf.cast(input_dict[SampleBatch.PREV_REWARDS], tf.float32),
+                    [-1, 1]))
+
+        if prev_a_r:
+            wrapped_out = tf.concat([wrapped_out] + prev_a_r, axis=1)
+
+        max_seq_len = tf.shape(wrapped_out)[0] // tf.shape(
+            input_dict["seq_lens"])[0]
+        wrapped_out_plus_time_dim = add_time_dimension(
+            wrapped_out, max_seq_len=max_seq_len, framework="tf")
+        model_out, value_out, h, c = self._rnn_model([
+            wrapped_out_plus_time_dim, input_dict["seq_lens"],
+            input_dict["state_in_0"], input_dict["state_in_1"]
+        ])
+        model_out_no_time_dim = tf.reshape(
+            model_out, tf.concat([[-1], tf.shape(model_out)[2:]], axis=0))
+        return model_out_no_time_dim, [h, c], {
+            SampleBatch.VF_PREDS: tf.reshape(value_out, [-1])
+        }
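
Note how `call()` folds and unfolds the time axis itself via `add_time_dimension`, so the wrapped model stays time-agnostic: [B*T, F] in, [B, T, F] through the LSTM, [B*T, num_outputs] back out. A toy check of that folding (assuming TF2 eager mode; the keyword signature matches the call in the code above):

    import tensorflow as tf
    from ray.rllib.policy.rnn_sequencing import add_time_dimension

    # 2 sequences of length 3, feature size 4 -> [B=2, T=3, 4].
    flat = tf.zeros((6, 4))  # [B*T, F]
    folded = add_time_dimension(flat, max_seq_len=3, framework="tf")
    assert folded.shape.as_list() == [2, 3, 4]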

@@ -1,5 +1,5 @@
-from typing import Dict, List
 import gym
+from typing import Dict, List, Optional, Sequence

 from ray.rllib.models.tf.tf_modelv2 import TFModelV2
 from ray.rllib.models.tf.misc import normc_initializer
@@ -12,6 +12,7 @@ from ray.rllib.utils.typing import ModelConfigDict, TensorType
 tf1, tf, tfv = try_import_tf()


+# TODO: (sven) obsolete this class once we only support native keras models.
 class VisionNetwork(TFModelV2):
     """Generic vision network implemented in ModelV2 API.
@@ -19,10 +20,6 @@ class VisionNetwork(TFModelV2):
     via the config keys:
     `post_fcnet_hiddens`: Dense layer sizes after the Conv2D stack.
     `post_fcnet_activation`: Activation function to use for this FC stack.
-
-    Examples:
-
     """

     def __init__(self, obs_space: gym.spaces.Space,
@@ -245,3 +242,238 @@ class VisionNetwork(TFModelV2):
     def value_function(self) -> TensorType:
         return tf.reshape(self._value_out, [-1])
+
+
+class Keras_VisionNetwork(tf.keras.Model if tf else object):
+    """Generic vision network implemented in tf keras.
+
+    An additional post-conv fully connected stack can be added and configured
+    via the config keys:
+    `post_fcnet_hiddens`: Dense layer sizes after the Conv2D stack.
+    `post_fcnet_activation`: Activation function to use for this FC stack.
+    """
+
+    def __init__(
+            self,
+            input_space: gym.spaces.Space,
+            action_space: gym.spaces.Space,
+            num_outputs: Optional[int] = None,
+            *,
+            name: str = "",
+            conv_filters: Optional[Sequence[Sequence[int]]] = None,
+            conv_activation: Optional[str] = None,
+            post_fcnet_hiddens: Optional[Sequence[int]] = (),
+            post_fcnet_activation: Optional[str] = None,
+            no_final_linear: bool = False,
+            vf_share_layers: bool = False,
+            free_log_std: bool = False,
+            **kwargs,
+    ):
+        super().__init__(name=name)
+
+        if not conv_filters:
+            conv_filters = get_filter_config(input_space.shape)
+        assert len(conv_filters) > 0, \
+            "Must provide at least 1 entry in `conv_filters`!"
+
+        conv_activation = get_activation_fn(conv_activation, framework="tf")
+        post_fcnet_activation = get_activation_fn(
+            post_fcnet_activation, framework="tf")
+
+        self.traj_view_framestacking = False
+
+        # Perform Atari framestacking via traj. view API.
+        num_framestacks = kwargs.get("num_framestacks")
+        if num_framestacks != "auto" and num_framestacks and \
+                num_framestacks > 1:
+            input_shape = input_space.shape + (num_framestacks, )
+            self.data_format = "channels_first"
+            self.traj_view_framestacking = True
+        else:
+            input_shape = input_space.shape
+            self.data_format = "channels_last"
+
+        inputs = tf.keras.layers.Input(shape=input_shape, name="observations")
+        last_layer = inputs
+        # Whether the last layer is the output of a Flattened (rather than
+        # a n x (1,1) Conv2D).
+        self.last_layer_is_flattened = False
+
+        # Build the action layers.
+        for i, (out_size, kernel, stride) in enumerate(conv_filters[:-1], 1):
+            last_layer = tf.keras.layers.Conv2D(
+                out_size,
+                kernel,
+                strides=stride
+                if isinstance(stride, (list, tuple)) else (stride, stride),
+                activation=conv_activation,
+                padding="same",
+                data_format="channels_last",
+                name="conv{}".format(i))(last_layer)
+
+        out_size, kernel, stride = conv_filters[-1]
+
+        # No final linear: Last layer has activation function and exits with
+        # num_outputs nodes (this could be a 1x1 conv or a FC layer, depending
+        # on `post_fcnet_...` settings).
+        if no_final_linear and num_outputs:
+            last_layer = tf.keras.layers.Conv2D(
+                out_size if post_fcnet_hiddens else num_outputs,
+                kernel,
+                strides=stride
+                if isinstance(stride, (list, tuple)) else (stride, stride),
+                activation=conv_activation,
+                padding="valid",
+                data_format="channels_last",
+                name="conv_out")(last_layer)
+            # Add (optional) post-fc-stack after last Conv2D layer.
+            layer_sizes = post_fcnet_hiddens[:-1] + ([num_outputs]
+                                                     if post_fcnet_hiddens else
+                                                     [])
+            for i, out_size in enumerate(layer_sizes):
+                last_layer = tf.keras.layers.Dense(
+                    out_size,
+                    name="post_fcnet_{}".format(i),
+                    activation=post_fcnet_activation,
+                    kernel_initializer=normc_initializer(1.0))(last_layer)
+        # Finish network normally (w/o overriding last layer size with
+        # `num_outputs`), then add another linear one of size `num_outputs`.
+        else:
+            last_layer = tf.keras.layers.Conv2D(
+                out_size,
+                kernel,
+                strides=stride
+                if isinstance(stride, (list, tuple)) else (stride, stride),
+                activation=conv_activation,
+                padding="valid",
+                data_format="channels_last",
+                name="conv{}".format(len(conv_filters)))(last_layer)
+
+            # num_outputs defined. Use that to create an exact
+            # `num_output`-sized (1,1)-Conv2D.
+            if num_outputs:
+                if post_fcnet_hiddens:
+                    last_cnn = last_layer = tf.keras.layers.Conv2D(
+                        post_fcnet_hiddens[0], [1, 1],
+                        activation=post_fcnet_activation,
+                        padding="same",
+                        data_format="channels_last",
+                        name="conv_out")(last_layer)
+                    # Add (optional) post-fc-stack after last Conv2D layer.
+                    for i, out_size in enumerate(post_fcnet_hiddens[1:] +
+                                                 [num_outputs]):
+                        last_layer = tf.keras.layers.Dense(
+                            out_size,
+                            name="post_fcnet_{}".format(i + 1),
+                            activation=post_fcnet_activation
+                            if i < len(post_fcnet_hiddens) - 1 else None,
+                            kernel_initializer=normc_initializer(1.0))(
+                                last_layer)
+                else:
+                    last_cnn = last_layer = tf.keras.layers.Conv2D(
+                        num_outputs, [1, 1],
+                        activation=None,
+                        padding="same",
+                        data_format="channels_last",
+                        name="conv_out")(last_layer)
+
+                if last_cnn.shape[1] != 1 or last_cnn.shape[2] != 1:
+                    raise ValueError(
+                        "Given `conv_filters` ({}) do not result in a [B, 1, "
+                        "1, {} (`num_outputs`)] shape (but in {})! Please "
+                        "adjust your Conv2D stack such that the dims 1 and 2 "
+                        "are both 1.".format(conv_filters, num_outputs,
+                                             list(last_cnn.shape)))
+            # num_outputs not known -> Flatten.
+            else:
+                self.last_layer_is_flattened = True
+                last_layer = tf.keras.layers.Flatten(
+                    data_format="channels_last")(last_layer)
+
+                # Add (optional) post-fc-stack after last Conv2D layer.
+                for i, out_size in enumerate(post_fcnet_hiddens):
+                    last_layer = tf.keras.layers.Dense(
+                        out_size,
+                        name="post_fcnet_{}".format(i),
+                        activation=post_fcnet_activation,
+                        kernel_initializer=normc_initializer(1.0))(last_layer)
+        logits_out = last_layer
+
+        # Build the value layers.
+        if vf_share_layers:
+            if not self.last_layer_is_flattened:
+                last_layer = tf.keras.layers.Lambda(
+                    lambda x: tf.squeeze(x, axis=[1, 2]))(last_layer)
+            value_out = tf.keras.layers.Dense(
+                1,
+                name="value_out",
+                activation=None,
+                kernel_initializer=normc_initializer(0.01))(last_layer)
+        else:
+            # Build a parallel set of hidden layers for the value net.
+            last_layer = inputs
+            for i, (out_size, kernel, stride) in enumerate(
+                    conv_filters[:-1], 1):
+                last_layer = tf.keras.layers.Conv2D(
+                    out_size,
+                    kernel,
+                    strides=stride
+                    if isinstance(stride, (list, tuple)) else (stride, stride),
+                    activation=conv_activation,
+                    padding="same",
+                    data_format="channels_last",
+                    name="conv_value_{}".format(i))(last_layer)
+            out_size, kernel, stride = conv_filters[-1]
+            last_layer = tf.keras.layers.Conv2D(
+                out_size,
+                kernel,
+                strides=stride
+                if isinstance(stride, (list, tuple)) else (stride, stride),
+                activation=conv_activation,
+                padding="valid",
+                data_format="channels_last",
+                name="conv_value_{}".format(len(conv_filters)))(last_layer)
+            last_layer = tf.keras.layers.Conv2D(
+                1, [1, 1],
+                activation=None,
+                padding="same",
+                data_format="channels_last",
+                name="conv_value_out")(last_layer)
+            value_out = tf.keras.layers.Lambda(
+                lambda x: tf.squeeze(x, axis=[1, 2]))(last_layer)
+
+        self.base_model = tf.keras.Model(inputs, [logits_out, value_out])
+
+        # Optional: framestacking obs/new_obs for Atari.
+        if self.traj_view_framestacking:
+            from_ = num_framestacks - 1
+            self.view_requirements[SampleBatch.OBS].shift = \
+                "-{}:0".format(from_)
+            self.view_requirements[SampleBatch.OBS].shift_from = -from_
+            self.view_requirements[SampleBatch.OBS].shift_to = 0
+            self.view_requirements[SampleBatch.NEXT_OBS] = ViewRequirement(
+                data_col=SampleBatch.OBS,
+                shift="-{}:1".format(from_ - 1),
+                space=self.view_requirements[SampleBatch.OBS].space,
+                used_for_compute_actions=False,
+            )
+
+    def call(self, input_dict: SampleBatch) -> \
+            (TensorType, List[TensorType], Dict[str, TensorType]):
+        obs = input_dict["obs"]
+        if self.data_format == "channels_first":
+            obs = tf.transpose(obs, [0, 2, 3, 1])
+        # Explicit cast to float32 needed in eager.
+        model_out, self._value_out = self.base_model(tf.cast(obs, tf.float32))
+        state = [v for k, v in input_dict.items() if k.startswith("state_in_")]
+        extra_outs = {SampleBatch.VF_PREDS: tf.reshape(self._value_out, [-1])}
+        # Our last layer is already flat.
+        if self.last_layer_is_flattened:
+            return model_out, state, extra_outs
+        # Last layer is a n x [1,1] Conv2D -> Flatten.
+        else:
+            return tf.squeeze(model_out, axis=[1, 2]), state, extra_outs
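
With the flag set, the catalog change above (`if model_config.get("_use_default_native_models") and Keras_VisionNet: return Keras_VisionNet`) makes this class the default for image observations. A hypothetical direct instantiation, using the constructor signature above with a standard 84x84 Atari-style space (leaving `conv_filters=None` lets `get_filter_config()` pick the built-in 84x84 defaults):

    import gym
    import numpy as np

    obs_space = gym.spaces.Box(0, 255, shape=(84, 84, 3), dtype=np.uint8)
    act_space = gym.spaces.Discrete(4)
    net = Keras_VisionNetwork(
        obs_space,
        act_space,
        num_outputs=4,
        name="vision",
        conv_filters=None,
        conv_activation="relu",
    )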

@@ -456,7 +456,7 @@ class DynamicTFPolicy(TFPolicy):
         dummy_batch = self._get_dummy_batch_from_view_requirements(
             batch_size=32)

-        return SampleBatch(input_dict, _seq_lens=self._seq_lens), dummy_batch
+        return SampleBatch(input_dict, seq_lens=self._seq_lens), dummy_batch

     def _initialize_loss_from_dummy_batch(
             self, auto_remove_unneeded_view_reqs: bool = True,
@@ -503,8 +503,8 @@ class DynamicTFPolicy(TFPolicy):
             dict(self._input_dict, **self._loss_input_dict))
         if self._state_inputs:
-            train_batch.seq_lens = self._seq_lens
-            self._loss_input_dict.update({"seq_lens": train_batch.seq_lens})
+            train_batch["seq_lens"] = self._seq_lens
+            self._loss_input_dict.update({"seq_lens": train_batch["seq_lens"]})

         self._loss_input_dict.update({k: v for k, v in train_batch.items()})
@@ -522,8 +522,8 @@ class DynamicTFPolicy(TFPolicy):
         TFPolicy._initialize_loss(self, loss, [
             (k, v) for k, v in train_batch.items() if k in all_accessed_keys
-        ] + ([("seq_lens", train_batch.seq_lens)]
-             if train_batch.seq_lens is not None else []))
+        ] + ([("seq_lens", train_batch["seq_lens"])]
+             if "seq_lens" in train_batch else []))

         if "is_training" in self._loss_input_dict:
             del self._loss_input_dict["is_training"]

@@ -27,8 +27,6 @@ logger = logging.getLogger(__name__)
 def _convert_to_tf(x, dtype=None):
     if isinstance(x, SampleBatch):
         dict_ = {k: v for k, v in x.items() if k != SampleBatch.INFOS}
-        if x.seq_lens is not None:
-            dict_["seq_lens"] = x.seq_lens
         return tf.nest.map_structure(_convert_to_tf, dict_)
     elif isinstance(x, Policy):
         return x
@@ -364,8 +362,6 @@ def build_eager_tf_policy(
                 batch_divisibility_req=self.batch_divisibility_req,
                 view_requirements=self.view_requirements,
             )
-        else:
-            postprocessed_batch["seq_lens"] = postprocessed_batch.seq_lens

         self._is_training = True
         postprocessed_batch["is_training"] = True
@@ -526,6 +522,9 @@ def build_eager_tf_policy(
                         raise e
                 elif isinstance(self.model, tf.keras.Model):
                     input_dict = SampleBatch(input_dict, seq_lens=seq_lens)
+                    if state_batches and "state_in_0" not in input_dict:
+                        for i, s in enumerate(state_batches):
+                            input_dict[f"state_in_{i}"] = s
                     self._lazy_tensor_dict(input_dict)
                     dist_inputs, state_out, extra_fetches = \
                         self.model(input_dict)

@@ -11,13 +11,14 @@ from ray.rllib.policy.sample_batch import SampleBatch
 from ray.rllib.policy.view_requirement import ViewRequirement
 from ray.rllib.utils.annotations import DeveloperAPI
 from ray.rllib.utils.exploration.exploration import Exploration
-from ray.rllib.utils.framework import try_import_torch
+from ray.rllib.utils.framework import try_import_tf, try_import_torch
 from ray.rllib.utils.from_config import from_config
 from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space, \
     unbatch
 from ray.rllib.utils.typing import AgentID, ModelGradients, ModelWeights, \
     TensorType, TrainerConfigDict, Tuple, Union

+tf1, tf, tfv = try_import_tf()
 torch, _ = try_import_torch()

 logger = logging.getLogger(__name__)
@@ -648,13 +649,13 @@ class Policy(metaclass=ABCMeta):
                 i += 1
             seq_len = sample_batch_size // B
             seq_lens = np.array([seq_len for _ in range(B)], dtype=np.int32)
-            postprocessed_batch.seq_lens = seq_lens
+            postprocessed_batch["seq_lens"] = seq_lens
         # Switch on lazy to-tensor conversion on `postprocessed_batch`.
         train_batch = self._lazy_tensor_dict(postprocessed_batch)
         # Calling loss, so set `is_training` to True.
         train_batch.is_training = True
         if seq_lens is not None:
-            train_batch.seq_lens = seq_lens
+            train_batch["seq_lens"] = seq_lens
         train_batch.count = self._dummy_batch.count
         # Call the loss function, if it exists.
         if self._loss is not None:
@@ -761,6 +762,7 @@ class Policy(metaclass=ABCMeta):
         """
         self._model_init_state_automatically_added = True
         model = getattr(self, "model", None)

        obj = model or self
        if model and not hasattr(model, "view_requirements"):
            model.view_requirements = {
@@ -772,8 +774,18 @@ class Policy(metaclass=ABCMeta):
         if hasattr(obj, "get_initial_state") and callable(
                 obj.get_initial_state):
             init_state = obj.get_initial_state()
         else:
-            obj.get_initial_state = lambda: []
+            # Add this functionality automatically for new native model API.
+            if tf and isinstance(model, tf.keras.Model) and \
+                    "state_in_0" not in view_reqs:
+                obj.get_initial_state = lambda: [
+                    np.zeros_like(view_req.space.sample())
+                    for k, view_req in model.view_requirements.items()
+                    if k.startswith("state_in_")]
+            else:
+                obj.get_initial_state = lambda: []
+                if "state_in_0" in view_reqs:
+                    self.is_recurrent = lambda: True
         for i, state in enumerate(init_state):
             space = Box(-1.0, 1.0, shape=state.shape) if \
                 hasattr(state, "shape") else state
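
The auto-created `get_initial_state` for native keras models simply zero-fills each `state_in_*` view requirement's space. A toy rendering of that comprehension (plain gym spaces stand in for the `ViewRequirement.space` attributes; shapes are hypothetical):

    import numpy as np
    from gym.spaces import Box

    view_requirements = {
        "obs": Box(-1.0, 1.0, shape=(4, )),
        "state_in_0": Box(-1.0, 1.0, shape=(256, )),
        "state_in_1": Box(-1.0, 1.0, shape=(256, )),
    }
    init_state = [
        np.zeros_like(space.sample())
        for k, space in view_requirements.items()
        if k.startswith("state_in_")
    ]
    assert [s.shape for s in init_state] == [(256, ), (256, )]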

@@ -75,8 +75,8 @@ def pad_batch_to_sequences_of_same_size(
     if "state_in_0" in batch or "state_out_0" in batch:
         # Check, whether the state inputs have already been reduced to their
         # init values at the beginning of each max_seq_len chunk.
-        if batch.seq_lens is not None and \
-                len(batch["state_in_0"]) == len(batch.seq_lens):
+        if batch.get("seq_lens") is not None and \
+                len(batch["state_in_0"]) == len(batch["seq_lens"]):
             states_already_reduced_to_init = True

         # RNN (or single timestep state-in): Set the max dynamically.
@@ -113,7 +113,7 @@ def pad_batch_to_sequences_of_same_size(
         episode_ids=batch.get(SampleBatch.EPS_ID),
         unroll_ids=batch.get(SampleBatch.UNROLL_ID),
         agent_indices=batch.get(SampleBatch.AGENT_INDEX),
-        seq_lens=getattr(batch, "seq_lens", batch.get("seq_lens")),
+        seq_lens=batch.get("seq_lens"),
         max_seq_len=max_seq_len,
         dynamic_max=dynamic_max,
         states_already_reduced_to_init=states_already_reduced_to_init,
@@ -318,12 +318,12 @@ def timeslice_along_seq_lens_with_overlap(
         zero_init_states=True) -> List["SampleBatch"]:
     """Slices batch along `seq_lens` (each seq-len item produces one batch).

-    Asserts that seq_lens is given or sample_batch.seq_lens is not None.
+    Asserts that seq_lens is given or sample_batch["seq_lens"] is not None.

     Args:
         sample_batch (SampleBatch): The SampleBatch to timeslice.
         seq_lens (Optional[List[int]]): An optional list of seq_lens to slice
-            at. If None, use `sample_batch.seq_lens`.
+            at. If None, use `sample_batch["seq_lens"]`.
         zero_pad_max_seq_len (int): If >0, already zero-pad the resulting
             slices up to this length. NOTE: This max-len will include the
             additional timesteps gained via setting pre_overlap or
@@ -354,10 +354,10 @@ def timeslice_along_seq_lens_with_overlap(
         # count (makes sure each slice has exactly length 10).
     """
     if seq_lens is None:
-        seq_lens = sample_batch.seq_lens
+        seq_lens = sample_batch.get("seq_lens")
     assert seq_lens is not None and len(seq_lens) > 0, \
         "Cannot timeslice along `seq_lens` when `seq_lens` is empty or None!"

-    # Generate n slices based on self.seq_lens.
+    # Generate n slices based on seq_lens.
     start = 0
     slices = []
     for seq_len in seq_lens:
@@ -391,10 +391,13 @@ def timeslice_along_seq_lens_with_overlap(
                     shape=(zero_length, ) + v.shape[1:], dtype=v.dtype),
                 v[data_begin:end]
             ])
-            for k, v in sample_batch.items()
+            for k, v in sample_batch.items() if k != "seq_lens"
         }
     else:
-        data = {k: v[begin:end] for k, v in sample_batch.items()}
+        data = {
+            k: v[begin:end]
+            for k, v in sample_batch.items() if k != "seq_lens"
+        }

     if zero_init_states_:
         i = 0
@@ -416,7 +419,7 @@ def timeslice_along_seq_lens_with_overlap(
             i += 1
             key = "state_in_{}".format(i)

-        timeslices.append(SampleBatch(data, _seq_lens=[end - begin]))
+        timeslices.append(SampleBatch(data, seq_lens=[end - begin]))

     # Zero-pad each slice if necessary.
     if zero_pad_max_seq_len > 0:

View file

@ -64,19 +64,7 @@ class SampleBatch(dict):
# Possible seq_lens (TxB or BxT) setup. # Possible seq_lens (TxB or BxT) setup.
self.time_major = kwargs.pop("_time_major", None) self.time_major = kwargs.pop("_time_major", None)
self.seq_lens = kwargs.pop("_seq_lens", kwargs.pop("seq_lens", None))
if self.seq_lens is None and len(args) > 0 and isinstance(
args[0], dict):
self.seq_lens = args[0].pop("_seq_lens", args[0].pop(
"seq_lens", None))
if isinstance(self.seq_lens, list):
self.seq_lens = np.array(self.seq_lens, dtype=np.int32)
self.max_seq_len = kwargs.pop("_max_seq_len", None) self.max_seq_len = kwargs.pop("_max_seq_len", None)
if self.max_seq_len is None and self.seq_lens is not None and \
not (tf and tf.is_tensor(self.seq_lens)) and \
len(self.seq_lens) > 0:
self.max_seq_len = max(self.seq_lens)
self.zero_padded = kwargs.pop("_zero_padded", False) self.zero_padded = kwargs.pop("_zero_padded", False)
self.is_training = kwargs.pop("_is_training", None) self.is_training = kwargs.pop("_is_training", None)
@ -88,14 +76,25 @@ class SampleBatch(dict):
self.added_keys = set() self.added_keys = set()
self.deleted_keys = set() self.deleted_keys = set()
self.intercepted_values = {} self.intercepted_values = {}
self.get_interceptor = None self.get_interceptor = None
# Clear out None seq-lens.
if self.get("seq_lens") is None or self.get("seq_lens") == []:
self.pop("seq_lens", None)
# Numpyfy seq_lens if list.
elif isinstance(self.get("seq_lens"), list):
self["seq_lens"] = np.array(self["seq_lens"], dtype=np.int32)
if self.max_seq_len is None and self.get("seq_lens") is not None and \
not (tf and tf.is_tensor(self["seq_lens"])) and \
len(self["seq_lens"]) > 0:
self.max_seq_len = max(self["seq_lens"])
if self.is_training is None: if self.is_training is None:
self.is_training = self.pop("is_training", False) self.is_training = self.pop("is_training", False)
lengths = [] lengths = []
copy_ = {k: v for k, v in self.items()} copy_ = {k: v for k, v in self.items() if k != "seq_lens"}
for k, v in copy_.items(): for k, v in copy_.items():
assert isinstance(k, str), self assert isinstance(k, str), self
len_ = len(v) if isinstance( len_ = len(v) if isinstance(
@ -105,10 +104,10 @@ class SampleBatch(dict):
if isinstance(v, list): if isinstance(v, list):
self[k] = np.array(v) self[k] = np.array(v)
if self.seq_lens is not None and \ if self.get("seq_lens") is not None and \
not (tf and tf.is_tensor(self.seq_lens)) and \ not (tf and tf.is_tensor(self["seq_lens"])) and \
len(self.seq_lens) > 0: len(self["seq_lens"]) > 0:
self.count = sum(self.seq_lens) self.count = sum(self["seq_lens"])
else: else:
self.count = lengths[0] if lengths else 0 self.count = lengths[0] if lengths else 0
@ -142,8 +141,8 @@ class SampleBatch(dict):
if zero_padded: if zero_padded:
assert s.max_seq_len == max_seq_len assert s.max_seq_len == max_seq_len
concat_samples.append(s) concat_samples.append(s)
if s.seq_lens is not None: if s.get("seq_lens") is not None:
seq_lens.extend(s.seq_lens) seq_lens.extend(s["seq_lens"])
out = {} out = {}
for k in concat_samples[0].keys(): for k in concat_samples[0].keys():
@@ -152,7 +151,7 @@ class SampleBatch(dict):
time_major=concat_samples[0].time_major) time_major=concat_samples[0].time_major)
return SampleBatch( return SampleBatch(
out, out,
_seq_lens=np.array(seq_lens, dtype=np.int32), seq_lens=seq_lens,
_time_major=concat_samples[0].time_major, _time_major=concat_samples[0].time_major,
_zero_padded=zero_padded, _zero_padded=zero_padded,
_max_seq_len=max_seq_len, _max_seq_len=max_seq_len,
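
With seq_lens traveling as a column, concatenation can be sketched like this (assuming the public SampleBatch.concat_samples API):

    import numpy as np
    from ray.rllib.policy.sample_batch import SampleBatch

    b1 = SampleBatch({"obs": np.zeros(3)}, seq_lens=[3])
    b2 = SampleBatch({"obs": np.ones(4)}, seq_lens=[4])
    merged = SampleBatch.concat_samples([b1, b2])
    assert list(merged["seq_lens"]) == [3, 4]  # per-batch seq_lens re-attached
    assert merged.count == 7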
@@ -202,7 +201,7 @@ class SampleBatch(dict):
if isinstance(v, np.ndarray) else v if isinstance(v, np.ndarray) else v
for (k, v) in self.items() for (k, v) in self.items()
}, },
_seq_lens=self.seq_lens, seq_lens=self.get("seq_lens"),
) )
copy_.set_get_interceptor(self.get_interceptor) copy_.set_get_interceptor(self.get_interceptor)
return copy_ return copy_
@@ -304,7 +303,7 @@ class SampleBatch(dict):
SampleBatch: A new SampleBatch, which has a slice of this batch's SampleBatch: A new SampleBatch, which has a slice of this batch's
data. data.
""" """
if self.seq_lens is not None and len(self.seq_lens) > 0: if self.get("seq_lens") is not None and len(self["seq_lens"]) > 0:
if start < 0: if start < 0:
data = { data = {
k: np.concatenate([ k: np.concatenate([
@@ -312,15 +311,18 @@ class SampleBatch(dict):
shape=(-start, ) + v.shape[1:], dtype=v.dtype), shape=(-start, ) + v.shape[1:], dtype=v.dtype),
v[0:end] v[0:end]
]) ])
for k, v in self.items() for k, v in self.items() if k != "seq_lens"
} }
else: else:
data = {k: v[start:end] for k, v in self.items()} data = {
k: v[start:end]
for k, v in self.items() if k != "seq_lens"
}
# Fix state_in_x data. # Fix state_in_x data.
count = 0 count = 0
state_start = None state_start = None
seq_lens = None seq_lens = None
for i, seq_len in enumerate(self.seq_lens): for i, seq_len in enumerate(self["seq_lens"]):
count += seq_len count += seq_len
if count >= end: if count >= end:
state_idx = 0 state_idx = 0
@@ -331,7 +333,7 @@ class SampleBatch(dict):
data[state_key] = self[state_key][state_start:i + 1] data[state_key] = self[state_key][state_start:i + 1]
state_idx += 1 state_idx += 1
state_key = "state_in_{}".format(state_idx) state_key = "state_in_{}".format(state_idx)
seq_lens = list(self.seq_lens[state_start:i]) + [ seq_lens = list(self["seq_lens"][state_start:i]) + [
seq_len - (count - end) seq_len - (count - end)
] ]
if start < 0: if start < 0:
@@ -343,14 +345,13 @@ class SampleBatch(dict):
return SampleBatch( return SampleBatch(
data, data,
_seq_lens=np.array(seq_lens, dtype=np.int32), seq_lens=seq_lens,
_time_major=self.time_major, _time_major=self.time_major,
) )
else: else:
return SampleBatch( return SampleBatch(
{k: v[start:end] {k: v[start:end]
for k, v in self.items()}, for k, v in self.items()},
_seq_lens=None,
_is_training=self.is_training, _is_training=self.is_training,
_time_major=self.time_major) _time_major=self.time_major)
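
A hedged sketch of sequence-aware slicing (names and shapes illustrative): ordinary columns are cut per timestep, while state_in_... rows and seq_lens are re-cut per sequence.

    import numpy as np
    from ray.rllib.policy.sample_batch import SampleBatch

    batch = SampleBatch(
        {"obs": np.arange(6), "state_in_0": np.zeros((2, 1))},
        seq_lens=[3, 3])
    sub = batch.slice(0, 3)  # keep only the first sequence
    assert list(sub["seq_lens"]) == [3]
    assert sub["state_in_0"].shape[0] == 1  # one state row per kept sequence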
@@ -383,8 +384,9 @@ class SampleBatch(dict):
data. If False, leave `state_in_x` keys as-is. data. If False, leave `state_in_x` keys as-is.
""" """
for col in self.keys(): for col in self.keys():
# Skip state in columns. # Skip "state_in_..." columns and "seq_lens".
if exclude_states is True and col.startswith("state_in_"): if (exclude_states is True and col.startswith("state_in_")) or \
col == "seq_lens":
continue continue
f = self[col] f = self[col]
@@ -395,7 +397,7 @@ class SampleBatch(dict):
if f.shape[0] == max_seq_len: if f.shape[0] == max_seq_len:
continue continue
# Generate zero-filled primer of len=max_seq_len. # Generate zero-filled primer of len=max_seq_len.
length = len(self.seq_lens) * max_seq_len length = len(self["seq_lens"]) * max_seq_len
if f.dtype == np.object or f.dtype.type is np.str_: if f.dtype == np.object or f.dtype.type is np.str_:
f_pad = [None] * length f_pad = [None] * length
else: else:
@@ -403,7 +405,7 @@ class SampleBatch(dict):
f_pad = np.zeros((length, ) + np.shape(f)[1:], dtype=f.dtype) f_pad = np.zeros((length, ) + np.shape(f)[1:], dtype=f.dtype)
# Fill primer with data. # Fill primer with data.
f_pad_base = f_base = 0 f_pad_base = f_base = 0
for len_ in self.seq_lens: for len_ in self["seq_lens"]:
f_pad[f_pad_base:f_pad_base + len_] = f[f_base:f_base + len_] f_pad[f_pad_base:f_pad_base + len_] = f[f_base:f_base + len_]
f_pad_base += max_seq_len f_pad_base += max_seq_len
f_base += len_ f_base += len_
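
A sketch of the padding behavior, assuming right_zero_pad is called as elsewhere in RLlib:

    import numpy as np
    from ray.rllib.policy.sample_batch import SampleBatch

    batch = SampleBatch(
        {"obs": np.arange(5, dtype=np.float32)}, seq_lens=[3, 2])
    batch.right_zero_pad(max_seq_len=4)
    assert len(batch["obs"]) == 8             # 2 sequences x 4 timesteps
    assert list(batch["seq_lens"]) == [3, 2]  # the column itself is skipped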
@@ -441,7 +443,7 @@ class SampleBatch(dict):
Returns: Returns:
TensorType: The data under the given key. TensorType: The data under the given key.
""" """
if not hasattr(self, key): if not hasattr(self, key) and key in self:
self.accessed_keys.add(key) self.accessed_keys.add(key)
# Backward compatibility for when "input-dicts" were used. # Backward compatibility for when "input-dicts" were used.
@@ -452,13 +454,6 @@ class SampleBatch(dict):
new="SampleBatch.is_training", new="SampleBatch.is_training",
error=False) error=False)
return self.is_training return self.is_training
elif key == "seq_lens":
if self.get_interceptor is not None and self.seq_lens is not None:
if "seq_lens" not in self.intercepted_values:
self.intercepted_values["seq_lens"] = self.get_interceptor(
self.seq_lens)
return self.intercepted_values["seq_lens"]
return self.seq_lens
value = dict.__getitem__(self, key) value = dict.__getitem__(self, key)
if self.get_interceptor is not None: if self.get_interceptor is not None:
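
With the special case removed, a get-interceptor now covers seq_lens like any other column. A minimal sketch:

    import numpy as np
    from ray.rllib.policy.sample_batch import SampleBatch

    batch = SampleBatch({"obs": np.arange(3)}, seq_lens=[3])
    # E.g. an interceptor that turns numpy arrays into plain lists.
    batch.set_get_interceptor(lambda v: v.tolist())
    assert batch["seq_lens"] == [3]  # resolved via the normal dict path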
@@ -475,12 +470,9 @@ class SampleBatch(dict):
key (str): The column name to set a value for. key (str): The column name to set a value for.
item (TensorType): The data to insert. item (TensorType): The data to insert.
""" """
if key == "seq_lens":
self.seq_lens = item
return
# Defend against creating SampleBatch via pickle (no property # Defend against creating SampleBatch via pickle (no property
# `added_keys` and first item is already set). # `added_keys` and first item is already set).
elif not hasattr(self, "added_keys"): if not hasattr(self, "added_keys"):
dict.__setitem__(self, key, item) dict.__setitem__(self, key, item)
return return
@@ -548,15 +540,15 @@ class SampleBatch(dict):
def _get_slice_indices(self, slice_size): def _get_slice_indices(self, slice_size):
i = 0 i = 0
slices = [] slices = []
if self.seq_lens is not None and len(self.seq_lens) > 0: if self.get("seq_lens") is not None and len(self["seq_lens"]) > 0:
assert np.all(self.seq_lens < slice_size), \ assert np.all(self["seq_lens"] < slice_size), \
"ERROR: `slice_size` must be larger than the max. seq-len " \ "ERROR: `slice_size` must be larger than the max. seq-len " \
"in the batch!" "in the batch!"
start_pos = 0 start_pos = 0
current_slice_size = 0 current_slice_size = 0
idx = 0 idx = 0
while idx < len(self.seq_lens): while idx < len(self["seq_lens"]):
seq_len = self.seq_lens[idx] seq_len = self["seq_lens"][idx]
current_slice_size += seq_len current_slice_size += seq_len
# Complete minibatch -> Append to slices. # Complete minibatch -> Append to slices.
if current_slice_size >= slice_size: if current_slice_size >= slice_size:
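
An illustrative call of this (private) helper: slices grow in whole-sequence steps, and slice_size must exceed the longest sequence or the assertion above fires.

    import numpy as np
    from ray.rllib.policy.sample_batch import SampleBatch

    batch = SampleBatch({"obs": np.arange(10)}, seq_lens=[3, 3, 4])
    slices = batch._get_slice_indices(slice_size=5)  # ok: 5 > max seq-len of 4
    # batch._get_slice_indices(slice_size=4)         # would trip the assert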
@@ -645,7 +637,7 @@ class SampleBatch(dict):
input_dict[view_col] = self[data_col][ input_dict[view_col] = self[data_col][
index:index + 1 if index != -1 else None] index:index + 1 if index != -1 else None]
return SampleBatch(input_dict, _seq_lens=np.array([1], dtype=np.int32)) return SampleBatch(input_dict, seq_lens=np.array([1], dtype=np.int32))
@PublicAPI @PublicAPI
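
The returned single-step batch is equivalent to this hand-built sketch (obs shape illustrative):

    import numpy as np
    from ray.rllib.policy.sample_batch import SampleBatch

    single = SampleBatch(
        {"obs": np.zeros((1, 4), dtype=np.float32)},
        seq_lens=np.array([1], dtype=np.int32))
    assert single.count == 1 and single.max_seq_len == 1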

View file

@@ -34,14 +34,14 @@ class TestSampleBatch(unittest.TestCase):
# Add an item and check, whether it's in the "added" list. # Add an item and check, whether it's in the "added" list.
batch["d"] = np.array(1) batch["d"] = np.array(1)
assert batch.added_keys == {"d"} assert batch.added_keys == {"d"}, batch.added_keys
# Access two keys and check, whether they are in the # Access two keys and check, whether they are in the
# "accessed" list. # "accessed" list.
print(batch["a"], batch["b"]) print(batch["a"], batch["b"])
assert batch.accessed_keys == {"a", "b"} assert batch.accessed_keys == {"a", "b"}, batch.accessed_keys
# Delete a key and check, whether it's in the "deleted" list. # Delete a key and check, whether it's in the "deleted" list.
del batch["c"] del batch["c"]
assert batch.deleted_keys == {"c"} assert batch.deleted_keys == {"c"}, batch.deleted_keys
if __name__ == "__main__": if __name__ == "__main__":

View file

@@ -905,6 +905,7 @@ class TFPolicy(Policy):
Feed dict of data. Feed dict of data.
""" """
# Get batch ready for RNNs, if applicable.
if not isinstance(train_batch, if not isinstance(train_batch,
SampleBatch) or not train_batch.zero_padded: SampleBatch) or not train_batch.zero_padded:
pad_batch_to_sequences_of_same_size( pad_batch_to_sequences_of_same_size(
@@ -915,10 +916,6 @@ class TFPolicy(Policy):
feature_keys=list(self._loss_input_dict_no_rnn.keys()), feature_keys=list(self._loss_input_dict_no_rnn.keys()),
view_requirements=self.view_requirements, view_requirements=self.view_requirements,
) )
else:
train_batch["seq_lens"] = train_batch.seq_lens
# Get batch ready for RNNs, if applicable.
# Mark the batch as "is_training" so the Model can use this # Mark the batch as "is_training" so the Model can use this
# information. # information.

View file

@@ -477,8 +477,6 @@ class TorchPolicy(Policy):
batch_divisibility_req=self.batch_divisibility_req, batch_divisibility_req=self.batch_divisibility_req,
view_requirements=self.view_requirements, view_requirements=self.view_requirements,
) )
else:
postprocessed_batch["seq_lens"] = postprocessed_batch.seq_lens
# Mark the batch as "is_training" so the Model can use this # Mark the batch as "is_training" so the Model can use this
# information. # information.
@@ -712,12 +710,13 @@ class TorchPolicy(Policy):
if "state_in_0" not in self._dummy_batch: if "state_in_0" not in self._dummy_batch:
self._dummy_batch["state_in_0"] = \ self._dummy_batch["state_in_0"] = \
self._dummy_batch["seq_lens"] = np.array([1.0]) self._dummy_batch["seq_lens"] = np.array([1.0])
seq_lens = self._dummy_batch["seq_lens"]
state_ins = [] state_ins = []
i = 0 i = 0
while "state_in_{}".format(i) in self._dummy_batch: while "state_in_{}".format(i) in self._dummy_batch:
state_ins.append(self._dummy_batch["state_in_{}".format(i)]) state_ins.append(self._dummy_batch["state_in_{}".format(i)])
i += 1 i += 1
seq_lens = self._dummy_batch["seq_lens"]
dummy_inputs = { dummy_inputs = {
k: self._dummy_batch[k] k: self._dummy_batch[k]
for k in self._dummy_batch.keys() if k != "is_training" for k in self._dummy_batch.keys() if k != "is_training"

View file

@@ -23,7 +23,7 @@ class TestAttentionNetLearning(unittest.TestCase):
@classmethod @classmethod
def setUpClass(cls) -> None: def setUpClass(cls) -> None:
ray.init(num_cpus=5, ignore_reinit_error=True) ray.init(num_cpus=5)
@classmethod @classmethod
def tearDownClass(cls) -> None: def tearDownClass(cls) -> None:

View file

@@ -63,7 +63,7 @@ def minibatches(samples, sgd_minibatch_size):
raise NotImplementedError( raise NotImplementedError(
"Minibatching not implemented for multi-agent in simple mode") "Minibatching not implemented for multi-agent in simple mode")
# Replace with `if samples.seq_lens` check. # Replace with `if samples["seq_lens"]` check.
if "state_in_0" in samples or "state_out_0" in samples: if "state_in_0" in samples or "state_out_0" in samples:
if log_once("not_shuffling_rnn_data_in_simple_mode"): if log_once("not_shuffling_rnn_data_in_simple_mode"):
logger.warning("Not time-shuffling RNN data for SGD.") logger.warning("Not time-shuffling RNN data for SGD.")

View file

@@ -288,6 +288,7 @@ def check_compute_single_action(trainer,
pol = trainer.get_policy() pol = trainer.get_policy()
except AttributeError: except AttributeError:
pol = trainer.policy pol = trainer.policy
model = pol.model
action_space = pol.action_space action_space = pol.action_space
@@ -328,7 +329,14 @@ def check_compute_single_action(trainer,
obs = np.clip(obs, -1.0, 1.0) obs = np.clip(obs, -1.0, 1.0)
state_in = None state_in = None
if include_state: if include_state:
state_in = pol.model.get_initial_state() state_in = model.get_initial_state()
if not state_in:
state_in = []
i = 0
while f"state_in_{i}" in model.view_requirements:
state_in.append(model.view_requirements[
f"state_in_{i}"].space.sample())
i += 1
action_in = action_space.sample() \ action_in = action_space.sample() \
if include_prev_action_reward else None if include_prev_action_reward else None
reward_in = 1.0 if include_prev_action_reward else None reward_in = 1.0 if include_prev_action_reward else None
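
The fallback above reads as a small helper; a sketch, assuming a model whose view_requirements carry gym spaces for its state_in_... columns:

    def sample_initial_state(model):
        # Prefer the model's own initial state, if it provides one.
        state_in = model.get_initial_state()
        if not state_in:
            # Otherwise, sample each declared state_in_... space.
            state_in = []
            i = 0
            while f"state_in_{i}" in model.view_requirements:
                state_in.append(
                    model.view_requirements[f"state_in_{i}"].space.sample())
                i += 1
        return state_in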