[RLlib] Working/learning example: PPO + torch + LSTM. (#7797)

Sven Mika 2020-04-01 07:00:28 +02:00 committed by GitHub
parent c23e56ce9a
commit 66df8b8c35
17 changed files with 578 additions and 213 deletions


@ -1266,6 +1266,15 @@ py_test(
args = ["--iters=2", "--num-cpus=4"]
)
py_test(
name = "examples/custom_torch_rnn_model",
main = "examples/custom_torch_rnn_model.py",
tags = ["examples", "examples_C"],
size = "medium",
srcs = ["examples/custom_torch_rnn_model.py"],
args = ["--run=PPO", "--stop=90", "--num-cpus=4"]
)
py_test(
name = "examples/custom_torch_policy",
tags = ["examples", "examples_C"],


@ -68,9 +68,10 @@ class PPOLoss:
use_gae (bool): If true, use the Generalized Advantage Estimator.
"""
if valid_mask is not None:
num_valid = torch.sum(valid_mask)
def reduce_mean_valid(t):
return torch.mean(t * valid_mask)
return torch.sum(t * valid_mask) / num_valid
else:
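
The new reduce_mean_valid divides the masked sum by the number of valid timesteps instead of averaging over the whole (padded) batch; with zero-padded RNN sequences, torch.mean(t * valid_mask) counts the padding rows and biases the average toward zero. A tiny standalone torch check of the difference (values are illustrative only)::

    import torch

    t = torch.tensor([1.0, 2.0, 3.0, 0.0])           # last entry is padding
    valid_mask = torch.tensor([1.0, 1.0, 1.0, 0.0])
    num_valid = torch.sum(valid_mask)                 # 3

    biased = torch.mean(t * valid_mask)               # 6 / 4 = 1.5 (counts padding)
    correct = torch.sum(t * valid_mask) / num_valid   # 6 / 3 = 2.0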
@ -190,14 +191,14 @@ class ValueNetworkMixin:
def value(ob, prev_action, prev_reward, *state):
model_out, _ = self.model({
SampleBatch.CUR_OBS: torch.Tensor([ob]).to(self.device),
SampleBatch.PREV_ACTIONS: torch.Tensor([prev_action]).to(
self.device),
SampleBatch.PREV_REWARDS: torch.Tensor([prev_reward]).to(
self.device),
SampleBatch.CUR_OBS: self._convert_to_tensor([ob]),
SampleBatch.PREV_ACTIONS: self._convert_to_tensor(
[prev_action]),
SampleBatch.PREV_REWARDS: self._convert_to_tensor(
[prev_reward]),
"is_training": False,
}, [torch.Tensor([s]).to(self.device) for s in state],
torch.Tensor([1]).to(self.device))
}, [self._convert_to_tensor(s) for s in state],
self._convert_to_tensor([1]))
return self.model.value_function()[0]
else:
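
The value() helper above now funnels all inputs through the policy's _convert_to_tensor instead of building torch.Tensor objects and calling .to(self.device) by hand. Roughly, such a helper is assumed to behave like the sketch below (the name comes from the diff; the exact dtype/device handling in TorchPolicy may differ)::

    import numpy as np
    import torch

    def _convert_to_tensor(self, arr):
        # Assumed behavior: accept tensors, numpy arrays or (nested) lists and
        # return a float32 tensor on the policy's device.
        if torch.is_tensor(arr):
            return arr.to(self.device)
        tensor = torch.from_numpy(np.asarray(arr))
        if tensor.dtype == torch.double:
            tensor = tensor.float()
        return tensor.to(self.device)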


@ -94,10 +94,10 @@ class TestPPO(unittest.TestCase):
"""Tests the PPO loss function math."""
config = ppo.DEFAULT_CONFIG.copy()
config["num_workers"] = 0 # Run locally.
config["eager"] = True
config["gamma"] = 0.99
config["model"]["fcnet_hiddens"] = [10]
config["model"]["fcnet_activation"] = "linear"
config["vf_share_layers"] = True
# Fake CartPole episode of n time steps.
train_batch = {
@ -114,69 +114,56 @@ class TestPPO(unittest.TestCase):
ACTION_LOGP: np.array([-0.5, -0.1, -0.2], dtype=np.float32)
}
# tf.
trainer = ppo.PPOTrainer(config=config, env="CartPole-v0")
policy = trainer.get_policy()
for fw in ["tf", "torch"]:
print("framework={}".format(fw))
config["use_pytorch"] = fw == "torch"
config["eager"] = fw == "tf"
# Post-process (calculate simple (non-GAE) advantages) and attach to
# train_batch dict.
# A = [0.99^2 * 0.5 + 0.99 * -1.0 + 1.0, 0.99 * 0.5 - 1.0, 0.5] =
# [0.50005, -0.505, 0.5]
train_batch = postprocess_ppo_gae_tf(policy, train_batch)
# Check Advantage values.
check(train_batch[Postprocessing.VALUE_TARGETS],
[0.50005, -0.505, 0.5])
trainer = ppo.PPOTrainer(config=config, env="CartPole-v0")
policy = trainer.get_policy()
# Calculate actual PPO loss (results are stored in policy.loss_obj) for
# tf.
ppo_surrogate_loss_tf(policy, policy.model, Categorical, train_batch)
# Post-process (calculate simple (non-GAE) advantages) and attach
# to train_batch dict.
# A = [0.99^2 * 0.5 + 0.99 * -1.0 + 1.0, 0.99 * 0.5 - 1.0, 0.5] =
# [0.50005, -0.505, 0.5]
if fw == "tf":
train_batch = postprocess_ppo_gae_tf(policy, train_batch)
else:
train_batch = postprocess_ppo_gae_torch(policy, train_batch)
train_batch = policy._lazy_tensor_dict(train_batch)
vars = policy.model.trainable_variables()
expected_logits = fc(
fc(train_batch[SampleBatch.CUR_OBS], vars[0].numpy(),
vars[1].numpy()), vars[4].numpy(), vars[5].numpy())
expected_value_outs = fc(
fc(train_batch[SampleBatch.CUR_OBS], vars[2].numpy(),
vars[3].numpy()), vars[6].numpy(), vars[7].numpy())
# Check Advantage values.
check(train_batch[Postprocessing.VALUE_TARGETS],
[0.50005, -0.505, 0.5])
kl, entropy, pg_loss, vf_loss, overall_loss = \
self._ppo_loss_helper(
policy, policy.model, Categorical, train_batch,
expected_logits, expected_value_outs
)
check(policy.loss_obj.mean_kl, kl)
check(policy.loss_obj.mean_entropy, entropy)
check(policy.loss_obj.mean_policy_loss, np.mean(-pg_loss))
check(policy.loss_obj.mean_vf_loss, np.mean(vf_loss), decimals=4)
check(policy.loss_obj.loss, overall_loss, decimals=4)
# Calculate actual PPO loss (results are stored in policy.loss_obj)
# for tf.
if fw == "tf":
ppo_surrogate_loss_tf(policy, policy.model, Categorical,
train_batch)
else:
ppo_surrogate_loss_torch(policy, policy.model,
TorchCategorical, train_batch)
# Torch.
config["use_pytorch"] = True
trainer = ppo.PPOTrainer(config=config, env="CartPole-v0")
policy = trainer.get_policy()
train_batch = postprocess_ppo_gae_torch(policy, train_batch)
train_batch = policy._lazy_tensor_dict(train_batch)
vars = policy.model.variables() if fw == "tf" else \
list(policy.model.parameters())
expected_shared_out = fc(train_batch[SampleBatch.CUR_OBS], vars[0],
vars[1])
expected_logits = fc(expected_shared_out, vars[2], vars[3])
expected_value_outs = fc(expected_shared_out, vars[4], vars[5])
# Check Advantage values.
check(train_batch[Postprocessing.VALUE_TARGETS],
[0.50005, -0.505, 0.5])
# Calculate actual PPO loss (results are stored in policy.loss_obj)
# for tf.
ppo_surrogate_loss_torch(policy, policy.model, TorchCategorical,
train_batch)
kl, entropy, pg_loss, vf_loss, overall_loss = \
self._ppo_loss_helper(
policy, policy.model, TorchCategorical, train_batch,
policy.model.last_output(),
policy.model.value_function().detach().numpy()
)
check(policy.loss_obj.mean_kl, kl)
check(policy.loss_obj.mean_entropy, entropy)
check(policy.loss_obj.mean_policy_loss, np.mean(-pg_loss))
check(policy.loss_obj.mean_vf_loss, np.mean(vf_loss), decimals=4)
check(policy.loss_obj.loss, overall_loss, decimals=4)
kl, entropy, pg_loss, vf_loss, overall_loss = \
self._ppo_loss_helper(
policy, policy.model,
Categorical if fw == "tf" else TorchCategorical,
train_batch,
expected_logits, expected_value_outs
)
check(policy.loss_obj.mean_kl, kl)
check(policy.loss_obj.mean_entropy, entropy)
check(policy.loss_obj.mean_policy_loss, np.mean(-pg_loss))
check(policy.loss_obj.mean_vf_loss, np.mean(vf_loss), decimals=4)
check(policy.loss_obj.loss, overall_loss, decimals=4)
def _ppo_loss_helper(self, policy, model, dist_class, train_batch, logits,
vf_outs):
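
The expected VALUE_TARGETS of [0.50005, -0.505, 0.5] quoted in the comments are plain discounted returns for a terminal 3-step episode with gamma=0.99; the reward column of the fake batch is not shown in this hunk, so the sequence [1.0, -1.0, 0.5] is inferred from the comment's arithmetic. A standalone re-derivation::

    import numpy as np

    gamma = 0.99
    rewards = np.array([1.0, -1.0, 0.5])   # inferred from the comment's math
    returns = np.zeros_like(rewards)

    running = 0.0                          # terminal episode -> no bootstrap value
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        returns[t] = running

    print(returns)                         # [0.50005, -0.505, 0.5]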


@ -1,10 +1,10 @@
"""Example of using a custom RNN keras model."""
import argparse
import gym
from gym.spaces import Discrete
import numpy as np
import random
import argparse
import ray
from ray import tune
@ -89,13 +89,17 @@ class MyKerasRNN(RecurrentTFModelV2):
class RepeatInitialEnv(gym.Env):
"""Simple env in which the policy learns to repeat the initial observation
seen at timestep 0."""
"""Simple env where policy has to always repeat the initial observation.
def __init__(self):
Runs for 100 steps.
r=1 if action correct, -1 otherwise (max. R=100).
"""
def __init__(self, episode_len=100):
self.observation_space = Discrete(2)
self.action_space = Discrete(2)
self.token = None
self.episode_len = episode_len
self.num_steps = 0
def reset(self):
@ -109,7 +113,7 @@ class RepeatInitialEnv(gym.Env):
else:
reward = -1
self.num_steps += 1
done = self.num_steps > 100
done = self.num_steps >= self.episode_len
return 0, reward, done, {}
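
Only fragments of RepeatInitialEnv appear in the hunks above. A self-contained sketch consistent with those fragments (the reset() body, in particular the random choice of the initial token, is an assumption)::

    import random

    import gym
    from gym.spaces import Discrete

    class RepeatInitialEnv(gym.Env):
        """Agent must keep answering with the observation it saw at t=0."""

        def __init__(self, episode_len=100):
            self.observation_space = Discrete(2)
            self.action_space = Discrete(2)
            self.token = None
            self.episode_len = episode_len
            self.num_steps = 0

        def reset(self):
            self.token = random.choice([0, 1])   # assumed: random initial token
            self.num_steps = 0
            return self.token

        def step(self, action):
            reward = 1 if action == self.token else -1
            self.num_steps += 1
            done = self.num_steps >= self.episode_len
            return 0, reward, done, {}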
@ -148,22 +152,26 @@ if __name__ == "__main__":
ModelCatalog.register_custom_model("rnn", MyKerasRNN)
register_env("RepeatAfterMeEnv", lambda c: RepeatAfterMeEnv(c))
register_env("RepeatInitialEnv", lambda _: RepeatInitialEnv())
config = {
"env": args.env,
"env_config": {
"repeat_delay": 2,
},
"gamma": 0.9,
"num_workers": 0,
"num_envs_per_worker": 20,
"entropy_coeff": 0.001,
"num_sgd_iter": 5,
"vf_loss_coeff": 1e-5,
"model": {
"custom_model": "rnn",
"max_seq_len": 20,
},
}
tune.run(
args.run,
config=config,
stop={"episode_reward_mean": args.stop},
config={
"env": args.env,
"env_config": {
"repeat_delay": 2,
},
"gamma": 0.9,
"num_workers": 0,
"num_envs_per_worker": 20,
"entropy_coeff": 0.001,
"num_sgd_iter": 5,
"vf_loss_coeff": 1e-5,
"model": {
"custom_model": "rnn",
"max_seq_len": 20,
},
})
)


@ -0,0 +1,128 @@
import argparse
import ray
from ray.rllib.examples.cartpole_lstm import CartPoleStatelessEnv
from ray.rllib.examples.custom_keras_rnn_model import RepeatInitialEnv, \
RepeatAfterMeEnv
from ray.rllib.models.preprocessors import get_preprocessor
from ray.rllib.models.torch.recurrent_torch_model import RecurrentTorchModel
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.utils.annotations import override
from ray.rllib.utils import try_import_torch
from ray.rllib.models import ModelCatalog
import ray.tune as tune
torch, nn = try_import_torch()
parser = argparse.ArgumentParser()
parser.add_argument("--run", type=str, default="PPO")
parser.add_argument("--env", type=str, default="repeat_initial")
parser.add_argument("--stop", type=int, default=90)
parser.add_argument("--num-cpus", type=int, default=0)
parser.add_argument("--fc-size", type=int, default=64)
parser.add_argument("--lstm-cell-size", type=int, default=256)
class RNNModel(RecurrentTorchModel):
def __init__(self,
obs_space,
action_space,
num_outputs,
model_config,
name,
fc_size=64,
lstm_state_size=256):
super().__init__(obs_space, action_space, num_outputs, model_config,
name)
self.obs_size = get_preprocessor(obs_space)(obs_space).size
self.fc_size = fc_size
self.lstm_state_size = lstm_state_size
# Build the Module from fc + LSTM + 2xfc (action + value outs).
self.fc1 = nn.Linear(self.obs_size, self.fc_size)
self.lstm = nn.LSTM(
self.fc_size, self.lstm_state_size, batch_first=True)
self.action_branch = nn.Linear(self.lstm_state_size, num_outputs)
self.value_branch = nn.Linear(self.lstm_state_size, 1)
# Store the value output to save an extra forward pass.
self._cur_value = None
@override(ModelV2)
def get_initial_state(self):
# make hidden states on same device as model
h = [
self.fc1.weight.new(1, self.lstm_state_size).zero_().squeeze(0),
self.fc1.weight.new(1, self.lstm_state_size).zero_().squeeze(0)
]
return h
@override(ModelV2)
def value_function(self):
assert self._cur_value is not None, "must call forward() first"
return self._cur_value
@override(RecurrentTorchModel)
def forward_rnn(self, inputs, state, seq_lens):
"""Feeds `inputs` (B x T x ..) through the Gru Unit.
Returns the resulting outputs as a sequence (B x T x ...).
Values are stored in self._cur_value in simple (B) shape (where B
contains both the B and T dims!).
Returns:
NN Outputs (B x T x ...) as sequence.
The state batches as a List of two items (c- and h-states).
"""
x = nn.functional.relu(self.fc1(inputs))
lstm_out = self.lstm(
x, [torch.unsqueeze(state[0], 0),
torch.unsqueeze(state[1], 0)])
action_out = self.action_branch(lstm_out[0])
self._cur_value = torch.reshape(self.value_branch(lstm_out[0]), [-1])
return action_out, [
torch.squeeze(lstm_out[1][0], 0),
torch.squeeze(lstm_out[1][1], 0)
]
if __name__ == "__main__":
args = parser.parse_args()
ray.init(num_cpus=args.num_cpus or None)
ModelCatalog.register_custom_model("rnn", RNNModel)
tune.register_env(
"repeat_initial", lambda _: RepeatInitialEnv(episode_len=100))
tune.register_env(
"repeat_after_me", lambda _: RepeatAfterMeEnv({"repeat_delay": 1}))
tune.register_env("cartpole_stateless", lambda _: CartPoleStatelessEnv())
config = {
"env": args.env,
"use_pytorch": True,
"num_workers": 0,
"num_envs_per_worker": 20,
"gamma": 0.9,
"entropy_coeff": 0.0001,
"model": {
"custom_model": "rnn",
"max_seq_len": 20,
"lstm_use_prev_action_reward": "store_true",
"custom_options": {
"fc_size": args.fc_size,
"lstm_state_size": args.lstm_cell_size,
}
},
"lr": 3e-4,
"num_sgd_iter": 5,
"vf_loss_coeff": 0.0003,
}
tune.run(
args.run,
stop={
"episode_reward_mean": args.stop,
"timesteps_total": 100000
},
config=config,
)
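
With the RNNModel class above in scope, its input and state shapes can be sanity-checked outside of RLlib by pushing a dummy padded batch through forward_rnn. The batch size, sequence length, Discrete(2) spaces, name and empty model_config below are arbitrary choices for illustration::

    import torch
    from gym.spaces import Discrete

    B, T = 4, 20
    model = RNNModel(
        obs_space=Discrete(2),
        action_space=Discrete(2),
        num_outputs=2,
        model_config={},
        name="rnn_check",
        fc_size=64,
        lstm_state_size=256)

    obs = torch.rand(B, T, model.obs_size)            # [B, T, obs_size]
    state = [s.unsqueeze(0).repeat(B, 1)               # two [B, 256] tensors
             for s in model.get_initial_state()]
    seq_lens = torch.full((B,), T, dtype=torch.int32)

    logits, new_state = model.forward_rnn(obs, state, seq_lens)
    assert logits.shape == (B, T, 2)
    assert new_state[0].shape == (B, 256)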


@ -44,7 +44,15 @@ class ModelV2:
"""Get the initial recurrent state values for the model.
Returns:
list of np.array objects, if any
List[np.ndarray]: List of np.array objects containing the initial
hidden state of an RNN, if applicable.
Examples:
>>> def get_initial_state(self):
>>> return [
>>> np.zeros(self.cell_size, np.float32),
>>> np.zeros(self.cell_size, np.float32),
>>> ]
"""
return []


@ -0,0 +1,14 @@
# TODO(sven): Add once ModelV1 is deprecated and we no longer cause circular
# dependencies b/c of that.
# from ray.rllib.models.tf.tf_modelv2 import TFModelV2
# from ray.rllib.models.tf.fcnet_v2 import FullyConnectedNetwork
# from ray.rllib.models.tf.recurrent_tf_modelv2 import \
# RecurrentTFModelV2
# from ray.rllib.models.tf.visionnet_v2 import VisionNetwork
# __all__ = [
# "FullyConnectedNetwork",
# "RecurrentTFModelV2",
# "TFModelV2",
# "VisionNetwork",
# ]


@ -12,44 +12,38 @@ class RecurrentTFModelV2(TFModelV2):
"""Helper class to simplify implementing RNN models with TFModelV2.
Instead of implementing forward(), you can implement forward_rnn() which
takes batches with the time dimension added already."""
takes batches with the time dimension added already.
def __init__(self, obs_space, action_space, num_outputs, model_config,
name):
"""Initialize a TFModelV2.
Here is an example implementation for a subclass
``MyRNNClass(RecurrentTFModelV2)``::
Here is an example implementation for a subclass
``MyRNNClass(RecurrentTFModelV2)``::
def __init__(self, *args, **kwargs):
super(MyModelClass, self).__init__(*args, **kwargs)
cell_size = 256
def __init__(self, *args, **kwargs):
super(MyModelClass, self).__init__(*args, **kwargs)
cell_size = 256
# Define input layers
input_layer = tf.keras.layers.Input(
shape=(None, obs_space.shape[0]))
state_in_h = tf.keras.layers.Input(shape=(256, ))
state_in_c = tf.keras.layers.Input(shape=(256, ))
seq_in = tf.keras.layers.Input(shape=(), dtype=tf.int32)
# Define input layers
input_layer = tf.keras.layers.Input(
shape=(None, obs_space.shape[0]))
state_in_h = tf.keras.layers.Input(shape=(256, ))
state_in_c = tf.keras.layers.Input(shape=(256, ))
seq_in = tf.keras.layers.Input(shape=(), dtype=tf.int32)
# Send to LSTM cell
lstm_out, state_h, state_c = tf.keras.layers.LSTM(
cell_size, return_sequences=True, return_state=True,
name="lstm")(
inputs=input_layer,
mask=tf.sequence_mask(seq_in),
initial_state=[state_in_h, state_in_c])
output_layer = tf.keras.layers.Dense(...)(lstm_out)
# Send to LSTM cell
lstm_out, state_h, state_c = tf.keras.layers.LSTM(
cell_size, return_sequences=True, return_state=True,
name="lstm")(
inputs=input_layer,
mask=tf.sequence_mask(seq_in),
initial_state=[state_in_h, state_in_c])
output_layer = tf.keras.layers.Dense(...)(lstm_out)
# Create the RNN model
self.rnn_model = tf.keras.Model(
inputs=[input_layer, seq_in, state_in_h, state_in_c],
outputs=[output_layer, state_h, state_c])
self.register_variables(self.rnn_model.variables)
self.rnn_model.summary()
"""
TFModelV2.__init__(self, obs_space, action_space, num_outputs,
model_config, name)
# Create the RNN model
self.rnn_model = tf.keras.Model(
inputs=[input_layer, seq_in, state_in_h, state_in_c],
outputs=[output_layer, state_h, state_c])
self.register_variables(self.rnn_model.variables)
self.rnn_model.summary()
"""
@override(ModelV2)
def forward(self, input_dict, state, seq_lens):
@ -57,7 +51,8 @@ class RecurrentTFModelV2(TFModelV2):
You should implement forward_rnn() in your subclass."""
output, new_state = self.forward_rnn(
add_time_dimension(input_dict["obs_flat"], seq_lens), state,
add_time_dimension(
input_dict["obs_flat"], seq_lens, framework="tf"), state,
seq_lens)
return tf.reshape(output, [-1, self.num_outputs]), new_state


@ -0,0 +1,14 @@
# TODO(sven): Add once ModelV1 is deprecated and we no longer cause circular
# dependencies b/c of that.
# from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
# from ray.rllib.models.torch.fcnet import FullyConnectedNetwork
# from ray.rllib.models.torch.recurrent_torch_model import \
# RecurrentTorchModel
# from ray.rllib.models.torch.visionnet import VisionNetwork
# __all__ = [
# "FullyConnectedNetwork",
# "RecurrentTorchModel",
# "TorchModelV2",
# "VisionNetwork",
# ]


@ -0,0 +1,92 @@
import numpy as np
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.policy.rnn_sequencing import add_time_dimension
from ray.rllib.utils.annotations import override, DeveloperAPI
from ray.rllib.utils.framework import try_import_torch
torch, nn = try_import_torch()
@DeveloperAPI
class RecurrentTorchModel(TorchModelV2, nn.Module):
"""Helper class to simplify implementing RNN models with TFModelV2.
Instead of implementing forward(), you can implement forward_rnn() which
takes batches with the time dimension added already.
Here is an example implementation for a subclass
``MyRNNClass(nn.Module, RecurrentTorchModel)``::
def __init__(self, obs_space, num_outputs):
self.obs_size = _get_size(obs_space)
self.rnn_hidden_dim = model_config["lstm_cell_size"]
self.fc1 = nn.Linear(self.obs_size, self.rnn_hidden_dim)
self.rnn = nn.GRUCell(self.rnn_hidden_dim, self.rnn_hidden_dim)
self.fc2 = nn.Linear(self.rnn_hidden_dim, num_outputs)
self.value_branch = nn.Linear(self.rnn_hidden_dim, 1)
self._cur_value = None
@override(ModelV2)
def get_initial_state(self):
# make hidden states on same device as model
h = [self.fc1.weight.new(
1, self.rnn_hidden_dim).zero_().squeeze(0)]
return h
@override(ModelV2)
def value_function(self):
assert self._cur_value is not None, "must call forward() first"
return self._cur_value
@override(RecurrentTorchModel)
def forward_rnn(self, input_dict, state, seq_lens):
x = nn.functional.relu(self.fc1(input_dict["obs_flat"].float()))
h_in = state[0].reshape(-1, self.rnn_hidden_dim)
h = self.rnn(x, h_in)
q = self.fc2(h)
self._cur_value = self.value_branch(h).squeeze(1)
return q, [h]
"""
def __init__(self, obs_space, action_space, num_outputs, model_config,
name):
TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
model_config, name)
nn.Module.__init__(self)
@override(ModelV2)
def forward(self, input_dict, state, seq_lens):
"""Adds time dimension to batch before sending inputs to forward_rnn().
You should implement forward_rnn() in your subclass."""
if isinstance(seq_lens, np.ndarray):
seq_lens = torch.Tensor(seq_lens).int()
output, new_state = self.forward_rnn(
add_time_dimension(
input_dict["obs_flat"].float(), seq_lens, framework="torch"),
state, seq_lens)
return torch.reshape(output, [-1, self.num_outputs]), new_state
def forward_rnn(self, inputs, state, seq_lens):
"""Call the model with the given input tensors and state.
Args:
inputs (Tensor): Observation tensor with shape [B, T, obs_size].
state (list): List of state tensors, each with shape [B, size].
seq_lens (Tensor): 1D tensor holding input sequence lengths.
Note: len(seq_lens) == B.
Returns:
(outputs, new_state): The model output tensor of shape
[B, T, num_outputs] and the list of new state tensors each with
shape [B, size].
Examples:
def forward_rnn(self, inputs, state, seq_lens):
model_out, h, c = self.rnn_model([inputs, seq_lens] + state)
return model_out, [h, c]
"""
raise NotImplementedError("You must implement this for an RNN model")


@ -54,7 +54,8 @@ class TorchCategorical(TorchDistributionWrapper):
@override(ActionDistribution)
def deterministic_sample(self):
return self.dist.probs.argmax(dim=1)
self.last_sample = self.dist.probs.argmax(dim=1)
return self.last_sample
@staticmethod
@override(ActionDistribution)
@ -68,7 +69,8 @@ class TorchMultiCategorical(TorchDistributionWrapper):
@override(TorchDistributionWrapper)
def __init__(self, inputs, model, input_lens):
super().__init__(inputs, model)
inputs_split = self.inputs.split(input_lens, dim=1)
# If input_lens is np.ndarray or list, force-make it a tuple.
inputs_split = self.inputs.split(tuple(input_lens), dim=1)
self.cats = [
torch.distributions.categorical.Categorical(logits=input_)
for input_ in inputs_split
@ -77,14 +79,14 @@ class TorchMultiCategorical(TorchDistributionWrapper):
@override(TorchDistributionWrapper)
def sample(self):
arr = [cat.sample() for cat in self.cats]
ret = torch.stack(arr, dim=1)
return ret
self.last_sample = torch.stack(arr, dim=1)
return self.last_sample
@override(ActionDistribution)
def deterministic_sample(self):
arr = [torch.argmax(cat.probs, -1) for cat in self.cats]
ret = torch.stack(arr, dim=1)
return ret
self.last_sample = torch.stack(arr, dim=1)
return self.last_sample
@override(TorchDistributionWrapper)
def logp(self, actions):
@ -134,7 +136,8 @@ class TorchDiagGaussian(TorchDistributionWrapper):
@override(ActionDistribution)
def deterministic_sample(self):
return self.dist.mean
self.last_sample = self.dist.mean
return self.last_sample
@override(TorchDistributionWrapper)
def logp(self, actions):
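
Caching the greedy or sampled action in self.last_sample, instead of just returning it, is presumably what lets a shared wrapper method report the log-prob of whatever was produced last, regardless of which sampling path ran. A standalone sketch of that pattern (DistWrapper and sampled_action_logp here are illustrative names, not the RLlib classes)::

    import torch

    class DistWrapper:
        def __init__(self, logits):
            self.dist = torch.distributions.Categorical(logits=logits)
            self.last_sample = None

        def sample(self):
            self.last_sample = self.dist.sample()
            return self.last_sample

        def deterministic_sample(self):
            self.last_sample = self.dist.probs.argmax(dim=1)
            return self.last_sample

        def sampled_action_logp(self):
            # Works no matter which of the two sampling paths ran last.
            return self.dist.log_prob(self.last_sample)

    d = DistWrapper(torch.randn(3, 4))
    actions = d.deterministic_sample()      # greedy actions, shape [3]
    logp = d.sampled_action_logp()          # their log-probs, shape [3]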


@ -11,16 +11,105 @@ meaningfully affect the loss function. This happens to be true for all the
current algorithms: https://github.com/ray-project/ray/issues/2992
"""
import logging
import numpy as np
from ray.util import log_once
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.utils import try_import_tf
from ray.rllib.utils.debug import summarize
from ray.rllib.utils.framework import try_import_tf, try_import_torch
tf = try_import_tf()
torch, _ = try_import_torch()
logger = logging.getLogger(__name__)
@DeveloperAPI
def add_time_dimension(padded_inputs, seq_lens):
def pad_batch_to_sequences_of_same_size(batch,
max_seq_len,
shuffle=False,
batch_divisibility_req=1,
feature_keys=None):
"""Applies padding to `batch` so it's choppable into same-size sequences.
Shuffles `batch` (if desired), makes sure divisibility requirement is met,
then pads the batch ([B, ...]) into same-size chunks ([B, ...]) w/o
adding a time dimension (yet).
Padding depends on episodes found in batch and `max_seq_len`.
Args:
batch (SampleBatch): The SampleBatch object. All values in here have
the shape [B, ...].
max_seq_len (int): The max. sequence length to use for chopping.
shuffle (bool): Whether to shuffle batch sequences. Shuffle may
be done in-place. This only makes sense if you're further
applying minibatch SGD after getting the outputs.
batch_divisibility_req (int): The int by which the batch dimension
must be divisible.
feature_keys (Optional[List[str]]): An optional list of keys to apply
sequence-chopping to. If None, use all keys in batch that are not
"state_in/out_"-type keys.
"""
if batch_divisibility_req > 1:
meets_divisibility_reqs = (
len(batch[SampleBatch.CUR_OBS]) % batch_divisibility_req == 0
# not multiagent
and max(batch[SampleBatch.AGENT_INDEX]) == 0)
else:
meets_divisibility_reqs = True
# RNN-case.
if "state_in_0" in batch:
dynamic_max = True
# Multi-agent case.
elif not meets_divisibility_reqs:
max_seq_len = batch_divisibility_req
dynamic_max = False
# Simple case: not RNN nor do we need to pad.
else:
if shuffle:
batch.shuffle()
return
# RNN or multi-agent case.
state_keys = []
feature_keys_ = feature_keys or []
for k in batch.keys():
if "state_in_" in k:
state_keys.append(k)
elif not feature_keys and "state_out_" not in k and k != "infos":
feature_keys_.append(k)
feature_sequences, initial_states, seq_lens = \
chop_into_sequences(
batch[SampleBatch.EPS_ID],
batch[SampleBatch.UNROLL_ID],
batch[SampleBatch.AGENT_INDEX],
[batch[k] for k in feature_keys_],
[batch[k] for k in state_keys],
max_seq_len,
dynamic_max=dynamic_max,
shuffle=shuffle)
for i, k in enumerate(feature_keys_):
batch[k] = feature_sequences[i]
for i, k in enumerate(state_keys):
batch[k] = initial_states[i]
batch["seq_lens"] = seq_lens
if log_once("rnn_ma_feed_dict"):
logger.info("Padded input for RNN:\n\n{}\n".format(
summarize({
"features": feature_sequences,
"initial_states": initial_states,
"seq_lens": seq_lens,
"max_seq_len": max_seq_len,
})))
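
For intuition, the padding step amounts to zero-filling every chopped sequence up to the (dynamic) max length, so the feature rows become len(seq_lens) * max(seq_lens) long, as noted in the comment further below. A hand-rolled illustration for one scalar feature column and two episodes of lengths 3 and 2 (this mimics the effect; it does not call chop_into_sequences itself)::

    import numpy as np

    feature = np.array([1, 2, 3, 10, 20], dtype=np.float32)   # two episodes back to back
    seq_lens = np.array([3, 2])
    max_len = seq_lens.max()                                   # dynamic max -> 3

    padded = np.zeros(len(seq_lens) * max_len, dtype=np.float32)
    offset = 0
    for i, length in enumerate(seq_lens):
        padded[i * max_len:i * max_len + length] = feature[offset:offset + length]
        offset += length

    # padded == [1., 2., 3., 10., 20., 0.], and seq_lens ([3, 2]) is stored under
    # batch["seq_lens"] so add_time_dimension() can later reshape to [B, T, ...].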
@DeveloperAPI
def add_time_dimension(padded_inputs, seq_lens, framework="tf"):
"""Adds a time dimension to padded inputs.
Arguments:
@ -37,14 +126,24 @@ def add_time_dimension(padded_inputs, seq_lens):
# Sequence lengths have to be specified for LSTM batch inputs. The
# input batch must be padded to the max seq length given here. That is,
# batch_size == len(seq_lens) * max(seq_lens)
padded_batch_size = tf.shape(padded_inputs)[0]
max_seq_len = padded_batch_size // tf.shape(seq_lens)[0]
if framework == "tf":
padded_batch_size = tf.shape(padded_inputs)[0]
max_seq_len = padded_batch_size // tf.shape(seq_lens)[0]
# Dynamically reshape the padded batch to introduce a time dimension.
new_batch_size = padded_batch_size // max_seq_len
new_shape = ([new_batch_size, max_seq_len] +
padded_inputs.get_shape().as_list()[1:])
return tf.reshape(padded_inputs, new_shape)
# Dynamically reshape the padded batch to introduce a time dimension.
new_batch_size = padded_batch_size // max_seq_len
new_shape = ([new_batch_size, max_seq_len] +
padded_inputs.get_shape().as_list()[1:])
return tf.reshape(padded_inputs, new_shape)
else:
assert framework == "torch", "`framework` must be either tf or torch!"
padded_batch_size = padded_inputs.shape[0]
max_seq_len = padded_batch_size // seq_lens.shape[0]
# Dynamically reshape the padded batch to introduce a time dimension.
new_batch_size = padded_batch_size // max_seq_len
new_shape = (new_batch_size, max_seq_len) + padded_inputs.shape[1:]
return torch.reshape(padded_inputs, new_shape)
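
The torch branch mirrors the tf one: T is recovered as padded_batch_size // len(seq_lens) and the flat batch is reshaped to [B, T, ...]. A quick check against the function above::

    import torch
    from ray.rllib.policy.rnn_sequencing import add_time_dimension

    # Two sequences padded to a common length of 3, feature size 4 -> 6 flat rows.
    padded_inputs = torch.arange(24, dtype=torch.float32).reshape(6, 4)
    seq_lens = torch.tensor([3, 2], dtype=torch.int32)   # second group has one padded row

    out = add_time_dimension(padded_inputs, seq_lens, framework="torch")
    assert out.shape == (2, 3, 4)                        # [B, T, features]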
@DeveloperAPI


@ -8,7 +8,7 @@ import ray.experimental.tf_utils
from ray.util.debug import log_once
from ray.rllib.policy.policy import Policy, LEARNER_STATS_KEY, \
ACTION_PROB, ACTION_LOGP
from ray.rllib.policy.rnn_sequencing import chop_into_sequences
from ray.rllib.policy.rnn_sequencing import pad_batch_to_sequences_of_same_size
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.utils.annotations import override, DeveloperAPI
@ -602,69 +602,27 @@ class TFPolicy(Policy):
return fetches
def _get_loss_inputs_dict(self, batch, shuffle):
"""Return a feed dict from a batch.
Arguments:
batch (SampleBatch): batch of data to derive inputs from
shuffle (bool): whether to shuffle batch sequences. Shuffle may
be done in-place. This only makes sense if you're further
applying minibatch SGD after getting the outputs.
Returns:
feed dict of data
"""
# Get batch ready for RNNs, if applicable.
pad_batch_to_sequences_of_same_size(
batch,
shuffle=shuffle,
max_seq_len=self._max_seq_len,
batch_divisibility_req=self._batch_divisibility_req,
feature_keys=[k for k, v in self._loss_inputs])
# Build the feed dict from the batch.
feed_dict = {}
if self._batch_divisibility_req > 1:
meets_divisibility_reqs = (
len(batch[SampleBatch.CUR_OBS]) %
self._batch_divisibility_req == 0
and max(batch[SampleBatch.AGENT_INDEX]) == 0) # not multiagent
else:
meets_divisibility_reqs = True
for k, ph in self._loss_inputs:
feed_dict[ph] = batch[k]
# Simple case: not RNN nor do we need to pad
if not self._state_inputs and meets_divisibility_reqs:
if shuffle:
batch.shuffle()
for k, ph in self._loss_inputs:
feed_dict[ph] = batch[k]
return feed_dict
if self._state_inputs:
max_seq_len = self._max_seq_len
dynamic_max = True
else:
max_seq_len = self._batch_divisibility_req
dynamic_max = False
# RNN or multi-agent case
feature_keys = [k for k, v in self._loss_inputs]
state_keys = [
"state_in_{}".format(i) for i in range(len(self._state_inputs))
]
feature_sequences, initial_states, seq_lens = chop_into_sequences(
batch[SampleBatch.EPS_ID],
batch[SampleBatch.UNROLL_ID],
batch[SampleBatch.AGENT_INDEX], [batch[k] for k in feature_keys],
[batch[k] for k in state_keys],
max_seq_len,
dynamic_max=dynamic_max,
shuffle=shuffle)
for k, v in zip(feature_keys, feature_sequences):
feed_dict[self._loss_input_dict[k]] = v
for k, v in zip(state_keys, initial_states):
feed_dict[self._loss_input_dict[k]] = v
feed_dict[self._seq_lens] = seq_lens
for k in state_keys:
feed_dict[self._loss_input_dict[k]] = batch[k]
if state_keys:
feed_dict[self._seq_lens] = batch["seq_lens"]
if log_once("rnn_feed_dict"):
logger.info("Padded input for RNN:\n\n{}\n".format(
summarize({
"features": feature_sequences,
"initial_states": initial_states,
"seq_lens": seq_lens,
"max_seq_len": max_seq_len,
})))
return feed_dict


@ -4,6 +4,7 @@ import time
from ray.rllib.policy.policy import Policy, LEARNER_STATS_KEY, ACTION_PROB, \
ACTION_LOGP
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.rnn_sequencing import pad_batch_to_sequences_of_same_size
from ray.rllib.utils.annotations import override, DeveloperAPI
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.schedules import ConstantSchedule, PiecewiseSchedule
@ -26,8 +27,15 @@ class TorchPolicy(Policy):
dist_class (type): Torch action distribution class.
"""
def __init__(self, observation_space, action_space, config, model, loss,
action_distribution_class):
def __init__(self,
observation_space,
action_space,
config,
model,
loss,
action_distribution_class,
max_seq_len=20,
get_batch_divisibility_req=None):
"""Build a policy from policy and loss torch modules.
Note that model will be placed on GPU device if CUDA_VISIBLE_DEVICES
@ -44,6 +52,9 @@ class TorchPolicy(Policy):
train_batch) and returns a single scalar loss.
action_distribution_class (ActionDistribution): Class for action
distribution.
max_seq_len (int): Max sequence length for LSTM training.
get_batch_divisibility_req (Optional[callable]): Optional callable
that returns the divisibility requirement for sample batches.
"""
self.framework = "torch"
super().__init__(observation_space, action_space, config)
@ -58,6 +69,11 @@ class TorchPolicy(Policy):
# If set, means we are using distributed allreduce during learning.
self.distributed_world_size = None
self.max_seq_len = max_seq_len
self.batch_divisibility_req = \
get_batch_divisibility_req(self) if get_batch_divisibility_req \
else 1
@override(Policy)
def compute_actions(self,
obs_batch,
@ -72,6 +88,7 @@ class TorchPolicy(Policy):
explore = explore if explore is not None else self.config["explore"]
timestep = timestep if timestep is not None else self.global_timestep
seq_lens = torch.ones(len(obs_batch), dtype=torch.int32)
with torch.no_grad():
input_dict = self._lazy_tensor_dict({
@ -86,8 +103,7 @@ class TorchPolicy(Policy):
# Call the exploration before_compute_actions hook.
self.exploration.before_compute_actions(timestep=timestep)
model_out = self.model(input_dict, state_batches,
self._convert_to_tensor([1]))
model_out = self.model(input_dict, state_batches, seq_lens)
logits, state = model_out
action_dist = None
actions, logp = \
@ -115,6 +131,7 @@ class TorchPolicy(Policy):
prev_action_batch=None,
prev_reward_batch=None):
with torch.no_grad():
seq_lens = torch.ones(len(obs_batch), dtype=torch.int32)
input_dict = self._lazy_tensor_dict({
SampleBatch.CUR_OBS: obs_batch,
SampleBatch.ACTIONS: actions
@ -124,15 +141,21 @@ class TorchPolicy(Policy):
if prev_reward_batch:
input_dict[SampleBatch.PREV_REWARDS] = prev_reward_batch
parameters, _ = self.model(input_dict, state_batches, [1])
parameters, _ = self.model(input_dict, state_batches, seq_lens)
action_dist = self.dist_class(parameters, self.model)
log_likelihoods = action_dist.logp(input_dict[SampleBatch.ACTIONS])
return log_likelihoods
@override(Policy)
def learn_on_batch(self, postprocessed_batch):
train_batch = self._lazy_tensor_dict(postprocessed_batch)
# Get batch ready for RNNs, if applicable.
pad_batch_to_sequences_of_same_size(
postprocessed_batch,
max_seq_len=self.max_seq_len,
shuffle=False,
batch_divisibility_req=self.batch_divisibility_req)
train_batch = self._lazy_tensor_dict(postprocessed_batch)
loss_out = self._loss(self, self.model, self.dist_class, train_batch)
self._optimizer.zero_grad()
loss_out.backward()


@ -22,7 +22,8 @@ def build_torch_policy(name,
before_init=None,
after_init=None,
make_model_and_action_dist=None,
mixins=None):
mixins=None,
get_batch_divisibility_req=None):
"""Helper function for creating a torch policy at runtime.
Arguments:
@ -52,6 +53,8 @@ def build_torch_policy(name,
mixins (list): list of any class mixins for the returned policy class.
These mixins will be applied in order and will have higher
precedence than the TorchPolicy class
get_batch_divisibility_req (Optional[callable]): Optional callable that
returns the divisibility requirement for sample batches.
Returns:
a TorchPolicy instance that uses the specified args
@ -80,14 +83,24 @@ def build_torch_policy(name,
self.dist_class, logit_dim = ModelCatalog.get_action_dist(
action_space, self.config["model"], framework="torch")
self.model = ModelCatalog.get_model_v2(
obs_space,
action_space,
logit_dim,
self.config["model"],
framework="torch")
obs_space=obs_space,
action_space=action_space,
num_outputs=logit_dim,
model_config=self.config["model"],
framework="torch",
**self.config["model"].get("custom_options", {}))
TorchPolicy.__init__(self, obs_space, action_space, config,
self.model, loss_fn, self.dist_class)
TorchPolicy.__init__(
self,
obs_space,
action_space,
config,
model=self.model,
loss=loss_fn,
action_distribution_class=self.dist_class,
max_seq_len=config["model"]["max_seq_len"],
get_batch_divisibility_req=get_batch_divisibility_req,
)
if after_init:
after_init(self, obs_space, action_space, config)
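
The **self.config["model"].get("custom_options", {}) forwarding above is what connects the custom_options block in the torch example's config to RNNModel's fc_size/lstm_state_size keyword arguments. Stripped of RLlib specifics, the pattern is simply (a hand-rolled illustration, not the catalog code)::

    def build_custom_model(model_cls, obs_space, action_space, num_outputs,
                           model_config, name="custom"):
        # Whatever sits under "custom_options" becomes constructor kwargs.
        return model_cls(obs_space, action_space, num_outputs, model_config,
                         name, **model_config.get("custom_options", {}))

    # e.g. model_config = {"custom_options": {"fc_size": 16, "lstm_state_size": 32}}
    # ends up calling model_cls(..., fc_size=16, lstm_state_size=32).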


@ -1,7 +1,8 @@
import numpy as np
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.framework import try_import_tf, try_import_torch
tf = try_import_tf()
torch, _ = try_import_torch()
SMALL_NUMBER = 1e-6
@ -123,9 +124,19 @@ def fc(x, weights, biases=None):
Returns:
The dense layer's output.
"""
# Torch stores matrices in transpose (faster for backprop).
if torch and isinstance(weights, torch.Tensor):
weights = np.transpose(weights.numpy())
if torch: # and isinstance(weights, torch.Tensor):
x = x.detach().numpy() if isinstance(x, torch.Tensor) else x
weights = np.transpose(weights.detach().numpy()) if \
isinstance(weights, torch.Tensor) else weights
biases = biases.detach().numpy() if \
isinstance(biases, torch.Tensor) else biases
if tf:
x = x.numpy() if isinstance(x, tf.Variable) else x
weights = weights.numpy() if isinstance(weights, tf.Variable) else \
weights
biases = biases.numpy() if isinstance(biases, tf.Variable) else biases
return np.matmul(x, weights) + (0.0 if biases is None else biases)
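
After this change fc() accepts torch tensors and tf variables as well as numpy arrays, but what it computes is still just a dense layer. A pure-numpy example of the result::

    import numpy as np

    x = np.array([[1.0, 2.0]])                  # one input row, 2 features
    weights = np.array([[1.0, 0.0, 1.0],
                        [0.0, 1.0, 1.0]])       # [in=2, out=3]
    biases = np.array([0.5, 0.5, 0.5])

    out = np.matmul(x, weights) + biases        # [[1.5, 2.5, 3.5]]
    # fc(x, weights, biases) gives the same result; torch weight matrices are
    # transposed first, since torch stores them as [out, in].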


@ -5,6 +5,7 @@ import logging
from collections import defaultdict
import random
from ray.util import log_once
from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY
from ray.rllib.policy.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \
MultiAgentBatch
@ -63,7 +64,8 @@ def minibatches(samples, sgd_minibatch_size):
"Minibatching not implemented for multi-agent in simple mode")
if "state_in_0" in samples.data:
logger.warning("Not shuffling RNN data for SGD in simple mode")
if log_once("not_shuffling_rnn_data_in_simple_mode"):
logger.warning("Not shuffling RNN data for SGD in simple mode")
else:
samples.shuffle()