Mirror of https://github.com/vale981/ray (synced 2025-03-06 02:21:39 -05:00)
[RLlib] Working/learning example: PPO + torch + LSTM. (#7797)
Parent: c23e56ce9a
Commit: 66df8b8c35
17 changed files with 578 additions and 213 deletions
@@ -1266,6 +1266,15 @@ py_test(
    args = ["--iters=2", "--num-cpus=4"]
)

py_test(
    name = "examples/custom_torch_rnn_model",
    main = "examples/custom_torch_rnn_model.py",
    tags = ["examples", "examples_C"],
    size = "medium",
    srcs = ["examples/custom_torch_rnn_model.py"],
    args = ["--run=PPO", "--stop=90", "--num-cpus=4"]
)

py_test(
    name = "examples/custom_torch_policy",
    tags = ["examples", "examples_C"],
@@ -68,9 +68,10 @@ class PPOLoss:
            use_gae (bool): If true, use the Generalized Advantage Estimator.
        """
        if valid_mask is not None:
            num_valid = torch.sum(valid_mask)

            def reduce_mean_valid(t):
                return torch.mean(t * valid_mask)
                return torch.sum(t * valid_mask) / num_valid

        else:
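Note on the loss change above: dividing by the number of valid timesteps, rather than taking torch.mean over the whole masked tensor, keeps padded timesteps from diluting the average. A minimal standalone sketch of the difference, using made-up numbers (not taken from the commit):

    import torch

    vals = torch.tensor([1.0, 3.0, 0.0, 0.0])          # last two entries are padding
    mask = torch.tensor([1.0, 1.0, 0.0, 0.0])          # 1 = valid timestep, 0 = padded
    num_valid = torch.sum(mask)

    diluted = torch.mean(vals * mask)                  # 4.0 / 4 = 1.0 (padding counted)
    masked_mean = torch.sum(vals * mask) / num_valid   # 4.0 / 2 = 2.0 (valid steps only)
    print(diluted.item(), masked_mean.item())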
@@ -190,14 +191,14 @@ class ValueNetworkMixin:

            def value(ob, prev_action, prev_reward, *state):
                model_out, _ = self.model({
                    SampleBatch.CUR_OBS: torch.Tensor([ob]).to(self.device),
                    SampleBatch.PREV_ACTIONS: torch.Tensor([prev_action]).to(
                        self.device),
                    SampleBatch.PREV_REWARDS: torch.Tensor([prev_reward]).to(
                        self.device),
                    SampleBatch.CUR_OBS: self._convert_to_tensor([ob]),
                    SampleBatch.PREV_ACTIONS: self._convert_to_tensor(
                        [prev_action]),
                    SampleBatch.PREV_REWARDS: self._convert_to_tensor(
                        [prev_reward]),
                    "is_training": False,
                }, [torch.Tensor([s]).to(self.device) for s in state],
                   torch.Tensor([1]).to(self.device))
                }, [self._convert_to_tensor(s) for s in state],
                   self._convert_to_tensor([1]))
                return self.model.value_function()[0]

        else:
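The old inline torch.Tensor(...).to(self.device) calls are replaced by the policy's _convert_to_tensor helper. Roughly, such a helper centralizes array-to-tensor conversion and device placement in one place; a hypothetical standalone version (not the actual TorchPolicy code) might look like:

    import numpy as np
    import torch

    def convert_to_tensor(arr, device="cpu"):
        # Hypothetical helper: accept lists or np.ndarrays, force float32
        # for floating-point data, and place the result on the given device.
        tensor = torch.as_tensor(np.asarray(arr))
        if tensor.is_floating_point():
            tensor = tensor.float()
        return tensor.to(device)

    print(convert_to_tensor([0.1, 0.2]).dtype)  # torch.float32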
@@ -94,10 +94,10 @@ class TestPPO(unittest.TestCase):
        """Tests the PPO loss function math."""
        config = ppo.DEFAULT_CONFIG.copy()
        config["num_workers"] = 0  # Run locally.
        config["eager"] = True
        config["gamma"] = 0.99
        config["model"]["fcnet_hiddens"] = [10]
        config["model"]["fcnet_activation"] = "linear"
        config["vf_share_layers"] = True

        # Fake CartPole episode of n time steps.
        train_batch = {

@@ -114,46 +114,21 @@ class TestPPO(unittest.TestCase):
            ACTION_LOGP: np.array([-0.5, -0.1, -0.2], dtype=np.float32)
        }

        # tf.
        for fw in ["tf", "torch"]:
            print("framework={}".format(fw))
            config["use_pytorch"] = fw == "torch"
            config["eager"] = fw == "tf"

            trainer = ppo.PPOTrainer(config=config, env="CartPole-v0")
            policy = trainer.get_policy()

            # Post-process (calculate simple (non-GAE) advantages) and attach to
            # train_batch dict.
            # Post-process (calculate simple (non-GAE) advantages) and attach
            # to train_batch dict.
            # A = [0.99^2 * 0.5 + 0.99 * -1.0 + 1.0, 0.99 * 0.5 - 1.0, 0.5] =
            # [0.50005, -0.505, 0.5]
            if fw == "tf":
                train_batch = postprocess_ppo_gae_tf(policy, train_batch)
                # Check Advantage values.
                check(train_batch[Postprocessing.VALUE_TARGETS],
                      [0.50005, -0.505, 0.5])

        # Calculate actual PPO loss (results are stored in policy.loss_obj) for
        # tf.
        ppo_surrogate_loss_tf(policy, policy.model, Categorical, train_batch)

        vars = policy.model.trainable_variables()
        expected_logits = fc(
            fc(train_batch[SampleBatch.CUR_OBS], vars[0].numpy(),
               vars[1].numpy()), vars[4].numpy(), vars[5].numpy())
        expected_value_outs = fc(
            fc(train_batch[SampleBatch.CUR_OBS], vars[2].numpy(),
               vars[3].numpy()), vars[6].numpy(), vars[7].numpy())

        kl, entropy, pg_loss, vf_loss, overall_loss = \
            self._ppo_loss_helper(
                policy, policy.model, Categorical, train_batch,
                expected_logits, expected_value_outs
            )
        check(policy.loss_obj.mean_kl, kl)
        check(policy.loss_obj.mean_entropy, entropy)
        check(policy.loss_obj.mean_policy_loss, np.mean(-pg_loss))
        check(policy.loss_obj.mean_vf_loss, np.mean(vf_loss), decimals=4)
        check(policy.loss_obj.loss, overall_loss, decimals=4)

        # Torch.
        config["use_pytorch"] = True
        trainer = ppo.PPOTrainer(config=config, env="CartPole-v0")
        policy = trainer.get_policy()
            else:
                train_batch = postprocess_ppo_gae_torch(policy, train_batch)
                train_batch = policy._lazy_tensor_dict(train_batch)

@@ -163,14 +138,26 @@ class TestPPO(unittest.TestCase):

            # Calculate actual PPO loss (results are stored in policy.loss_obj)
            # for tf.
            ppo_surrogate_loss_torch(policy, policy.model, TorchCategorical,
            if fw == "tf":
                ppo_surrogate_loss_tf(policy, policy.model, Categorical,
                                      train_batch)
            else:
                ppo_surrogate_loss_torch(policy, policy.model,
                                         TorchCategorical, train_batch)

            vars = policy.model.variables() if fw == "tf" else \
                list(policy.model.parameters())
            expected_shared_out = fc(train_batch[SampleBatch.CUR_OBS], vars[0],
                                     vars[1])
            expected_logits = fc(expected_shared_out, vars[2], vars[3])
            expected_value_outs = fc(expected_shared_out, vars[4], vars[5])

            kl, entropy, pg_loss, vf_loss, overall_loss = \
                self._ppo_loss_helper(
                    policy, policy.model, TorchCategorical, train_batch,
                    policy.model.last_output(),
                    policy.model.value_function().detach().numpy()
                    policy, policy.model,
                    Categorical if fw == "tf" else TorchCategorical,
                    train_batch,
                    expected_logits, expected_value_outs
                )
            check(policy.loss_obj.mean_kl, kl)
            check(policy.loss_obj.mean_entropy, entropy)
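The hand-computed advantage comment in the test can be reproduced directly. Assuming the fake episode's rewards are [1.0, -1.0, 0.5] and gamma = 0.99 (the full train_batch is not shown in this hunk), the non-GAE value targets are plain discounted returns:

    gamma = 0.99
    rewards = [1.0, -1.0, 0.5]   # assumed from the hand-computed comment above

    returns, running = [], 0.0
    for r in reversed(rewards):
        running = r + gamma * running   # discounted return, no bootstrapping / no GAE
        returns.insert(0, running)

    print(returns)   # ~[0.50005, -0.505, 0.5]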
@@ -1,10 +1,10 @@
"""Example of using a custom RNN keras model."""

import argparse
import gym
from gym.spaces import Discrete
import numpy as np
import random
import argparse

import ray
from ray import tune

@@ -89,13 +89,17 @@ class MyKerasRNN(RecurrentTFModelV2):


class RepeatInitialEnv(gym.Env):
    """Simple env in which the policy learns to repeat the initial observation
    seen at timestep 0."""
    """Simple env where policy has to always repeat the initial observation.

    def __init__(self):
    Runs for 100 steps.
    r=1 if action correct, -1 otherwise (max. R=100).
    """

    def __init__(self, episode_len=100):
        self.observation_space = Discrete(2)
        self.action_space = Discrete(2)
        self.token = None
        self.episode_len = episode_len
        self.num_steps = 0

    def reset(self):

@@ -109,7 +113,7 @@ class RepeatInitialEnv(gym.Env):
        else:
            reward = -1
        self.num_steps += 1
        done = self.num_steps > 100
        done = self.num_steps >= self.episode_len
        return 0, reward, done, {}


@@ -148,10 +152,8 @@ if __name__ == "__main__":
    ModelCatalog.register_custom_model("rnn", MyKerasRNN)
    register_env("RepeatAfterMeEnv", lambda c: RepeatAfterMeEnv(c))
    register_env("RepeatInitialEnv", lambda _: RepeatInitialEnv())
    tune.run(
        args.run,
        stop={"episode_reward_mean": args.stop},
        config={

    config = {
        "env": args.env,
        "env_config": {
            "repeat_delay": 2,

@@ -166,4 +168,10 @@ if __name__ == "__main__":
            "custom_model": "rnn",
            "max_seq_len": 20,
        },
    })
    }

    tune.run(
        args.run,
        config=config,
        stop={"episode_reward_mean": args.stop},
    )
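The expanded docstring pins down RepeatInitialEnv's contract: remember the observation handed out at reset() and keep choosing it as the action; +1 per correct repeat, -1 otherwise, for episode_len steps. A compact sketch mirroring that description (hypothetical, not a copy of the file; observation/action space is {0, 1}):

    import random

    class RepeatInitialSketch:
        def __init__(self, episode_len=100):
            self.episode_len = episode_len
            self.token = None
            self.num_steps = 0

        def reset(self):
            self.token = random.choice([0, 1])   # the value the agent must keep repeating
            self.num_steps = 0
            return self.token

        def step(self, action):
            reward = 1 if action == self.token else -1
            self.num_steps += 1
            done = self.num_steps >= self.episode_len   # max return = episode_len
            return 0, reward, done, {}                  # later observations carry no signal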
rllib/examples/custom_torch_rnn_model.py (new file, 128 lines)

@@ -0,0 +1,128 @@
import argparse

import ray
from ray.rllib.examples.cartpole_lstm import CartPoleStatelessEnv
from ray.rllib.examples.custom_keras_rnn_model import RepeatInitialEnv, \
    RepeatAfterMeEnv
from ray.rllib.models.preprocessors import get_preprocessor
from ray.rllib.models.torch.recurrent_torch_model import RecurrentTorchModel
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.utils.annotations import override
from ray.rllib.utils import try_import_torch
from ray.rllib.models import ModelCatalog
import ray.tune as tune

torch, nn = try_import_torch()

parser = argparse.ArgumentParser()
parser.add_argument("--run", type=str, default="PPO")
parser.add_argument("--env", type=str, default="repeat_initial")
parser.add_argument("--stop", type=int, default=90)
parser.add_argument("--num-cpus", type=int, default=0)
parser.add_argument("--fc-size", type=int, default=64)
parser.add_argument("--lstm-cell-size", type=int, default=256)


class RNNModel(RecurrentTorchModel):
    def __init__(self,
                 obs_space,
                 action_space,
                 num_outputs,
                 model_config,
                 name,
                 fc_size=64,
                 lstm_state_size=256):
        super().__init__(obs_space, action_space, num_outputs, model_config,
                         name)

        self.obs_size = get_preprocessor(obs_space)(obs_space).size
        self.fc_size = fc_size
        self.lstm_state_size = lstm_state_size

        # Build the Module from fc + LSTM + 2xfc (action + value outs).
        self.fc1 = nn.Linear(self.obs_size, self.fc_size)
        self.lstm = nn.LSTM(
            self.fc_size, self.lstm_state_size, batch_first=True)
        self.action_branch = nn.Linear(self.lstm_state_size, num_outputs)
        self.value_branch = nn.Linear(self.lstm_state_size, 1)
        # Store the value output to save an extra forward pass.
        self._cur_value = None

    @override(ModelV2)
    def get_initial_state(self):
        # Make hidden states on same device as model.
        h = [
            self.fc1.weight.new(1, self.lstm_state_size).zero_().squeeze(0),
            self.fc1.weight.new(1, self.lstm_state_size).zero_().squeeze(0)
        ]
        return h

    @override(ModelV2)
    def value_function(self):
        assert self._cur_value is not None, "must call forward() first"
        return self._cur_value

    @override(RecurrentTorchModel)
    def forward_rnn(self, inputs, state, seq_lens):
        """Feeds `inputs` (B x T x ..) through the LSTM unit.

        Returns the resulting outputs as a sequence (B x T x ...).
        Values are stored in self._cur_value in simple (B) shape (where B
        contains both the B and T dims!).

        Returns:
            NN Outputs (B x T x ...) as sequence.
            The state batches as a List of two items (c- and h-states).
        """
        x = nn.functional.relu(self.fc1(inputs))
        lstm_out = self.lstm(
            x, [torch.unsqueeze(state[0], 0),
                torch.unsqueeze(state[1], 0)])
        action_out = self.action_branch(lstm_out[0])
        self._cur_value = torch.reshape(self.value_branch(lstm_out[0]), [-1])
        return action_out, [
            torch.squeeze(lstm_out[1][0], 0),
            torch.squeeze(lstm_out[1][1], 0)
        ]


if __name__ == "__main__":
    args = parser.parse_args()

    ray.init(num_cpus=args.num_cpus or None)
    ModelCatalog.register_custom_model("rnn", RNNModel)
    tune.register_env(
        "repeat_initial", lambda _: RepeatInitialEnv(episode_len=100))
    tune.register_env(
        "repeat_after_me", lambda _: RepeatAfterMeEnv({"repeat_delay": 1}))
    tune.register_env("cartpole_stateless", lambda _: CartPoleStatelessEnv())

    config = {
        "env": args.env,
        "use_pytorch": True,
        "num_workers": 0,
        "num_envs_per_worker": 20,
        "gamma": 0.9,
        "entropy_coeff": 0.0001,
        "model": {
            "custom_model": "rnn",
            "max_seq_len": 20,
            "lstm_use_prev_action_reward": "store_true",
            "custom_options": {
                "fc_size": args.fc_size,
                "lstm_state_size": args.lstm_cell_size,
            }
        },
        "lr": 3e-4,
        "num_sgd_iter": 5,
        "vf_loss_coeff": 0.0003,
    }

    tune.run(
        args.run,
        stop={
            "episode_reward_mean": args.stop,
            "timesteps_total": 100000
        },
        config=config,
    )
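The forward_rnn() above relies on nn.LSTM with batch_first=True: it takes [B, T, fc_size] features plus an (h, c) pair shaped [1, B, cell_size] (hence the unsqueeze calls) and returns per-timestep outputs plus the new (h, c). A standalone shape check of that contract, independent of the example script:

    import torch
    import torch.nn as nn

    B, T, fc_size, cell = 4, 5, 8, 16
    lstm = nn.LSTM(fc_size, cell, batch_first=True)

    x = torch.randn(B, T, fc_size)
    h0 = torch.zeros(1, B, cell)   # layer dim first, as produced by torch.unsqueeze(state[0], 0)
    c0 = torch.zeros(1, B, cell)

    out, (h1, c1) = lstm(x, (h0, c0))
    print(out.shape, h1.shape, c1.shape)   # [4, 5, 16], [1, 4, 16], [1, 4, 16]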
@@ -44,7 +44,15 @@ class ModelV2:
        """Get the initial recurrent state values for the model.

        Returns:
            list of np.array objects, if any
            List[np.ndarray]: List of np.array objects containing the initial
                hidden state of an RNN, if applicable.

        Examples:
            >>> def get_initial_state(self):
            >>>     return [
            >>>         np.zeros(self.cell_size, np.float32),
            >>>         np.zeros(self.cell_size, np.float32),
            >>>     ]
        """
        return []
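The docstring example returns plain numpy zeros; the torch models added in this commit instead derive their initial state from an existing weight tensor so the zeros land on the same device (and dtype) as the model, as in RNNModel.get_initial_state above. A minimal illustration of that idiom:

    import torch.nn as nn

    fc1 = nn.Linear(4, 8)
    cell_size = 16

    # Zero states created from fc1's weight inherit its device and dtype.
    h = [
        fc1.weight.new(1, cell_size).zero_().squeeze(0),
        fc1.weight.new(1, cell_size).zero_().squeeze(0),
    ]
    print(h[0].shape, h[0].device)   # torch.Size([16]) and fc1's device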
@@ -0,0 +1,14 @@
# TODO(sven): Add once ModelV1 is deprecated and we no longer cause circular
# dependencies b/c of that.
# from ray.rllib.models.tf.tf_modelv2 import TFModelV2
# from ray.rllib.models.tf.fcnet_v2 import FullyConnectedNetwork
# from ray.rllib.models.tf.recurrent_tf_modelv2 import \
#     RecurrentTFModelV2
# from ray.rllib.models.tf.visionnet_v2 import VisionNetwork

# __all__ = [
#     "FullyConnectedNetwork",
#     "RecurrentTFModelV2",
#     "TFModelV2",
#     "VisionNetwork",
# ]
@@ -12,11 +12,7 @@ class RecurrentTFModelV2(TFModelV2):
    """Helper class to simplify implementing RNN models with TFModelV2.

    Instead of implementing forward(), you can implement forward_rnn() which
    takes batches with the time dimension added already."""

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        """Initialize a TFModelV2.
    takes batches with the time dimension added already.

    Here is an example implementation for a subclass
    ``MyRNNClass(RecurrentTFModelV2)``::

@@ -48,8 +44,6 @@ class RecurrentTFModelV2(TFModelV2):
        self.register_variables(self.rnn_model.variables)
        self.rnn_model.summary()
    """
        TFModelV2.__init__(self, obs_space, action_space, num_outputs,
                           model_config, name)

    @override(ModelV2)
    def forward(self, input_dict, state, seq_lens):

@@ -57,7 +51,8 @@ class RecurrentTFModelV2(TFModelV2):

        You should implement forward_rnn() in your subclass."""
        output, new_state = self.forward_rnn(
            add_time_dimension(input_dict["obs_flat"], seq_lens), state,
            add_time_dimension(
                input_dict["obs_flat"], seq_lens, framework="tf"), state,
            seq_lens)
        return tf.reshape(output, [-1, self.num_outputs]), new_state
@@ -0,0 +1,14 @@
# TODO(sven): Add once ModelV1 is deprecated and we no longer cause circular
# dependencies b/c of that.
# from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
# from ray.rllib.models.torch.fcnet import FullyConnectedNetwork
# from ray.rllib.models.torch.recurrent_torch_model import \
#     RecurrentTorchModel
# from ray.rllib.models.torch.visionnet import VisionNetwork

# __all__ = [
#     "FullyConnectedNetwork",
#     "RecurrentTorchModel",
#     "TorchModelV2",
#     "VisionNetwork",
# ]
rllib/models/torch/recurrent_torch_model.py (new file, 92 lines)

@@ -0,0 +1,92 @@
import numpy as np

from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.policy.rnn_sequencing import add_time_dimension
from ray.rllib.utils.annotations import override, DeveloperAPI
from ray.rllib.utils.framework import try_import_torch

torch, nn = try_import_torch()


@DeveloperAPI
class RecurrentTorchModel(TorchModelV2, nn.Module):
    """Helper class to simplify implementing RNN models with TorchModelV2.

    Instead of implementing forward(), you can implement forward_rnn() which
    takes batches with the time dimension added already.

    Here is an example implementation for a subclass
    ``MyRNNClass(nn.Module, RecurrentTorchModel)``::

        def __init__(self, obs_space, num_outputs):
            self.obs_size = _get_size(obs_space)
            self.rnn_hidden_dim = model_config["lstm_cell_size"]
            self.fc1 = nn.Linear(self.obs_size, self.rnn_hidden_dim)
            self.rnn = nn.GRUCell(self.rnn_hidden_dim, self.rnn_hidden_dim)
            self.fc2 = nn.Linear(self.rnn_hidden_dim, num_outputs)

            self.value_branch = nn.Linear(self.rnn_hidden_dim, 1)
            self._cur_value = None

        @override(ModelV2)
        def get_initial_state(self):
            # Make hidden states on same device as model.
            h = [self.fc1.weight.new(
                1, self.rnn_hidden_dim).zero_().squeeze(0)]
            return h

        @override(ModelV2)
        def value_function(self):
            assert self._cur_value is not None, "must call forward() first"
            return self._cur_value

        @override(RecurrentTorchModel)
        def forward_rnn(self, input_dict, state, seq_lens):
            x = nn.functional.relu(self.fc1(input_dict["obs_flat"].float()))
            h_in = state[0].reshape(-1, self.rnn_hidden_dim)
            h = self.rnn(x, h_in)
            q = self.fc2(h)
            self._cur_value = self.value_branch(h).squeeze(1)
            return q, [h]
    """

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                              model_config, name)
        nn.Module.__init__(self)

    @override(ModelV2)
    def forward(self, input_dict, state, seq_lens):
        """Adds time dimension to batch before sending inputs to forward_rnn().

        You should implement forward_rnn() in your subclass."""
        if isinstance(seq_lens, np.ndarray):
            seq_lens = torch.Tensor(seq_lens).int()
        output, new_state = self.forward_rnn(
            add_time_dimension(
                input_dict["obs_flat"].float(), seq_lens, framework="torch"),
            state, seq_lens)
        return torch.reshape(output, [-1, self.num_outputs]), new_state

    def forward_rnn(self, inputs, state, seq_lens):
        """Call the model with the given input tensors and state.

        Args:
            inputs (dict): Observation tensor with shape [B, T, obs_size].
            state (list): List of state tensors, each with shape [B, size].
            seq_lens (Tensor): 1D tensor holding input sequence lengths.
                Note: len(seq_lens) == B.

        Returns:
            (outputs, new_state): The model output tensor of shape
                [B, T, num_outputs] and the list of new state tensors each with
                shape [B, size].

        Examples:
            def forward_rnn(self, inputs, state, seq_lens):
                model_out, h, c = self.rnn_model([inputs, seq_lens] + state)
                return model_out, [h, c]
        """
        raise NotImplementedError("You must implement this for an RNN model")
@@ -54,7 +54,8 @@ class TorchCategorical(TorchDistributionWrapper):

    @override(ActionDistribution)
    def deterministic_sample(self):
        return self.dist.probs.argmax(dim=1)
        self.last_sample = self.dist.probs.argmax(dim=1)
        return self.last_sample

    @staticmethod
    @override(ActionDistribution)

@@ -68,7 +69,8 @@ class TorchMultiCategorical(TorchDistributionWrapper):
    @override(TorchDistributionWrapper)
    def __init__(self, inputs, model, input_lens):
        super().__init__(inputs, model)
        inputs_split = self.inputs.split(input_lens, dim=1)
        # If input_lens is np.ndarray or list, force-make it a tuple.
        inputs_split = self.inputs.split(tuple(input_lens), dim=1)
        self.cats = [
            torch.distributions.categorical.Categorical(logits=input_)
            for input_ in inputs_split

@@ -77,14 +79,14 @@ class TorchMultiCategorical(TorchDistributionWrapper):
    @override(TorchDistributionWrapper)
    def sample(self):
        arr = [cat.sample() for cat in self.cats]
        ret = torch.stack(arr, dim=1)
        return ret
        self.last_sample = torch.stack(arr, dim=1)
        return self.last_sample

    @override(ActionDistribution)
    def deterministic_sample(self):
        arr = [torch.argmax(cat.probs, -1) for cat in self.cats]
        ret = torch.stack(arr, dim=1)
        return ret
        self.last_sample = torch.stack(arr, dim=1)
        return self.last_sample

    @override(TorchDistributionWrapper)
    def logp(self, actions):

@@ -134,7 +136,8 @@ class TorchDiagGaussian(TorchDistributionWrapper):

    @override(ActionDistribution)
    def deterministic_sample(self):
        return self.dist.mean
        self.last_sample = self.dist.mean
        return self.last_sample

    @override(TorchDistributionWrapper)
    def logp(self, actions):
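All three torch distribution wrappers now cache their most recent draw in self.last_sample instead of returning it anonymously, so code that later needs, e.g., the log-prob of the action that was actually emitted can refer back to it. A rough standalone illustration of the pattern (plain torch, not RLlib's wrapper classes):

    import torch

    class CachedCategorical:
        # Sketch: wrap a Categorical and remember the last (deterministic) draw.
        def __init__(self, logits):
            self.dist = torch.distributions.Categorical(logits=logits)
            self.last_sample = None

        def deterministic_sample(self):
            self.last_sample = self.dist.probs.argmax(dim=1)
            return self.last_sample

        def last_sample_logp(self):
            return self.dist.log_prob(self.last_sample)

    d = CachedCategorical(torch.tensor([[2.0, 0.5], [0.1, 1.0]]))
    print(d.deterministic_sample())   # tensor([0, 1])
    print(d.last_sample_logp())       # log-probs of the cached actions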
@@ -11,16 +11,105 @@ meaningfully affect the loss function. This happens to be true for all the
current algorithms: https://github.com/ray-project/ray/issues/2992
"""

import logging
import numpy as np

from ray.util import log_once
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.utils import try_import_tf
from ray.rllib.utils.debug import summarize
from ray.rllib.utils.framework import try_import_tf, try_import_torch

tf = try_import_tf()
torch, _ = try_import_torch()

logger = logging.getLogger(__name__)


@DeveloperAPI
def add_time_dimension(padded_inputs, seq_lens):
def pad_batch_to_sequences_of_same_size(batch,
                                        max_seq_len,
                                        shuffle=False,
                                        batch_divisibility_req=1,
                                        feature_keys=None):
    """Applies padding to `batch` so it's choppable into same-size sequences.

    Shuffles `batch` (if desired), makes sure divisibility requirement is met,
    then pads the batch ([B, ...]) into same-size chunks ([B, ...]) w/o
    adding a time dimension (yet).
    Padding depends on episodes found in batch and `max_seq_len`.

    Args:
        batch (SampleBatch): The SampleBatch object. All values in here have
            the shape [B, ...].
        max_seq_len (int): The max. sequence length to use for chopping.
        shuffle (bool): Whether to shuffle batch sequences. Shuffle may
            be done in-place. This only makes sense if you're further
            applying minibatch SGD after getting the outputs.
        batch_divisibility_req (int): The int by which the batch dimension
            must be dividable.
        feature_keys (Optional[List[str]]): An optional list of keys to apply
            sequence-chopping to. If None, use all keys in batch that are not
            "state_in/out_"-type keys.
    """
    if batch_divisibility_req > 1:
        meets_divisibility_reqs = (
            len(batch[SampleBatch.CUR_OBS]) % batch_divisibility_req == 0
            # not multiagent
            and max(batch[SampleBatch.AGENT_INDEX]) == 0)
    else:
        meets_divisibility_reqs = True

    # RNN-case.
    if "state_in_0" in batch:
        dynamic_max = True
    # Multi-agent case.
    elif not meets_divisibility_reqs:
        max_seq_len = batch_divisibility_req
        dynamic_max = False
    # Simple case: not RNN nor do we need to pad.
    else:
        if shuffle:
            batch.shuffle()
        return

    # RNN or multi-agent case.
    state_keys = []
    feature_keys_ = feature_keys or []
    for k in batch.keys():
        if "state_in_" in k:
            state_keys.append(k)
        elif not feature_keys and "state_out_" not in k and k != "infos":
            feature_keys_.append(k)

    feature_sequences, initial_states, seq_lens = \
        chop_into_sequences(
            batch[SampleBatch.EPS_ID],
            batch[SampleBatch.UNROLL_ID],
            batch[SampleBatch.AGENT_INDEX],
            [batch[k] for k in feature_keys_],
            [batch[k] for k in state_keys],
            max_seq_len,
            dynamic_max=dynamic_max,
            shuffle=shuffle)
    for i, k in enumerate(feature_keys_):
        batch[k] = feature_sequences[i]
    for i, k in enumerate(state_keys):
        batch[k] = initial_states[i]
    batch["seq_lens"] = seq_lens

    if log_once("rnn_ma_feed_dict"):
        logger.info("Padded input for RNN:\n\n{}\n".format(
            summarize({
                "features": feature_sequences,
                "initial_states": initial_states,
                "seq_lens": seq_lens,
                "max_seq_len": max_seq_len,
            })))


@DeveloperAPI
def add_time_dimension(padded_inputs, seq_lens, framework="tf"):
    """Adds a time dimension to padded inputs.

    Arguments:

@@ -37,6 +126,7 @@ def add_time_dimension(padded_inputs, seq_lens):
    # Sequence lengths have to be specified for LSTM batch inputs. The
    # input batch must be padded to the max seq length given here. That is,
    # batch_size == len(seq_lens) * max(seq_lens)
    if framework == "tf":
        padded_batch_size = tf.shape(padded_inputs)[0]
        max_seq_len = padded_batch_size // tf.shape(seq_lens)[0]


@@ -45,6 +135,15 @@ def add_time_dimension(padded_inputs, seq_lens):
        new_shape = ([new_batch_size, max_seq_len] +
                     padded_inputs.get_shape().as_list()[1:])
        return tf.reshape(padded_inputs, new_shape)
    else:
        assert framework == "torch", "`framework` must be either tf or torch!"
        padded_batch_size = padded_inputs.shape[0]
        max_seq_len = padded_batch_size // seq_lens.shape[0]

        # Dynamically reshape the padded batch to introduce a time dimension.
        new_batch_size = padded_batch_size // max_seq_len
        new_shape = (new_batch_size, max_seq_len) + padded_inputs.shape[1:]
        return torch.reshape(padded_inputs, new_shape)


@DeveloperAPI
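The new torch branch of add_time_dimension does the same reshape as the tf branch, just with tensor shapes instead of tf ops: a time-flattened [B * T, ...] batch plus B sequence lengths becomes [B, T, ...]. A small shape-only walkthrough (standalone, made-up sizes):

    import torch

    # B=2 sequences padded to max_seq_len=3, feature size 4 -> flat shape [6, 4].
    padded_inputs = torch.arange(24, dtype=torch.float32).reshape(6, 4)
    seq_lens = torch.tensor([3, 2])   # real lengths; the 2nd sequence has one padded step

    max_seq_len = padded_inputs.shape[0] // seq_lens.shape[0]       # 6 // 2 = 3
    new_batch_size = padded_inputs.shape[0] // max_seq_len          # 2
    new_shape = (new_batch_size, max_seq_len) + padded_inputs.shape[1:]
    print(torch.reshape(padded_inputs, new_shape).shape)            # torch.Size([2, 3, 4])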
@@ -8,7 +8,7 @@ import ray.experimental.tf_utils
from ray.util.debug import log_once
from ray.rllib.policy.policy import Policy, LEARNER_STATS_KEY, \
    ACTION_PROB, ACTION_LOGP
from ray.rllib.policy.rnn_sequencing import chop_into_sequences
from ray.rllib.policy.rnn_sequencing import pad_batch_to_sequences_of_same_size
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.utils.annotations import override, DeveloperAPI

@@ -602,69 +602,27 @@ class TFPolicy(Policy):
        return fetches

    def _get_loss_inputs_dict(self, batch, shuffle):
        """Return a feed dict from a batch.

        Arguments:
            batch (SampleBatch): batch of data to derive inputs from
            shuffle (bool): whether to shuffle batch sequences. Shuffle may
                be done in-place. This only makes sense if you're further
                applying minibatch SGD after getting the outputs.

        Returns:
            feed dict of data
        """
        # Get batch ready for RNNs, if applicable.
        pad_batch_to_sequences_of_same_size(
            batch,
            shuffle=shuffle,
            max_seq_len=self._max_seq_len,
            batch_divisibility_req=self._batch_divisibility_req,
            feature_keys=[k for k, v in self._loss_inputs])

        # Build the feed dict from the batch.
        feed_dict = {}
        if self._batch_divisibility_req > 1:
            meets_divisibility_reqs = (
                len(batch[SampleBatch.CUR_OBS]) %
                self._batch_divisibility_req == 0
                and max(batch[SampleBatch.AGENT_INDEX]) == 0)  # not multiagent
        else:
            meets_divisibility_reqs = True

        # Simple case: not RNN nor do we need to pad
        if not self._state_inputs and meets_divisibility_reqs:
            if shuffle:
                batch.shuffle()
        for k, ph in self._loss_inputs:
            feed_dict[ph] = batch[k]
            return feed_dict

        if self._state_inputs:
            max_seq_len = self._max_seq_len
            dynamic_max = True
        else:
            max_seq_len = self._batch_divisibility_req
            dynamic_max = False

        # RNN or multi-agent case
        feature_keys = [k for k, v in self._loss_inputs]
        state_keys = [
            "state_in_{}".format(i) for i in range(len(self._state_inputs))
        ]
        feature_sequences, initial_states, seq_lens = chop_into_sequences(
            batch[SampleBatch.EPS_ID],
            batch[SampleBatch.UNROLL_ID],
            batch[SampleBatch.AGENT_INDEX], [batch[k] for k in feature_keys],
            [batch[k] for k in state_keys],
            max_seq_len,
            dynamic_max=dynamic_max,
            shuffle=shuffle)
        for k, v in zip(feature_keys, feature_sequences):
            feed_dict[self._loss_input_dict[k]] = v
        for k, v in zip(state_keys, initial_states):
            feed_dict[self._loss_input_dict[k]] = v
        feed_dict[self._seq_lens] = seq_lens
        for k in state_keys:
            feed_dict[self._loss_input_dict[k]] = batch[k]
        if state_keys:
            feed_dict[self._seq_lens] = batch["seq_lens"]

        if log_once("rnn_feed_dict"):
            logger.info("Padded input for RNN:\n\n{}\n".format(
                summarize({
                    "features": feature_sequences,
                    "initial_states": initial_states,
                    "seq_lens": seq_lens,
                    "max_seq_len": max_seq_len,
                })))
        return feed_dict
@@ -4,6 +4,7 @@ import time
from ray.rllib.policy.policy import Policy, LEARNER_STATS_KEY, ACTION_PROB, \
    ACTION_LOGP
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.rnn_sequencing import pad_batch_to_sequences_of_same_size
from ray.rllib.utils.annotations import override, DeveloperAPI
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.schedules import ConstantSchedule, PiecewiseSchedule

@@ -26,8 +27,15 @@ class TorchPolicy(Policy):
        dist_class (type): Torch action distribution class.
    """

    def __init__(self, observation_space, action_space, config, model, loss,
                 action_distribution_class):
    def __init__(self,
                 observation_space,
                 action_space,
                 config,
                 model,
                 loss,
                 action_distribution_class,
                 max_seq_len=20,
                 get_batch_divisibility_req=None):
        """Build a policy from policy and loss torch modules.

        Note that model will be placed on GPU device if CUDA_VISIBLE_DEVICES

@@ -44,6 +52,9 @@ class TorchPolicy(Policy):
                train_batch) and returns a single scalar loss.
            action_distribution_class (ActionDistribution): Class for action
                distribution.
            max_seq_len (int): Max sequence length for LSTM training.
            get_batch_divisibility_req (Optional[callable]): Optional callable
                that returns the divisibility requirement for sample batches.
        """
        self.framework = "torch"
        super().__init__(observation_space, action_space, config)

@@ -58,6 +69,11 @@ class TorchPolicy(Policy):
        # If set, means we are using distributed allreduce during learning.
        self.distributed_world_size = None

        self.max_seq_len = max_seq_len
        self.batch_divisibility_req = \
            get_batch_divisibility_req(self) if get_batch_divisibility_req \
            else 1

    @override(Policy)
    def compute_actions(self,
                        obs_batch,

@@ -72,6 +88,7 @@ class TorchPolicy(Policy):

        explore = explore if explore is not None else self.config["explore"]
        timestep = timestep if timestep is not None else self.global_timestep
        seq_lens = torch.ones(len(obs_batch), dtype=torch.int32)

        with torch.no_grad():
            input_dict = self._lazy_tensor_dict({

@@ -86,8 +103,7 @@ class TorchPolicy(Policy):
            # Call the exploration before_compute_actions hook.
            self.exploration.before_compute_actions(timestep=timestep)

            model_out = self.model(input_dict, state_batches,
                                   self._convert_to_tensor([1]))
            model_out = self.model(input_dict, state_batches, seq_lens)
            logits, state = model_out
            action_dist = None
            actions, logp = \

@@ -115,6 +131,7 @@ class TorchPolicy(Policy):
                                prev_action_batch=None,
                                prev_reward_batch=None):
        with torch.no_grad():
            seq_lens = torch.ones(len(obs_batch), dtype=torch.int32)
            input_dict = self._lazy_tensor_dict({
                SampleBatch.CUR_OBS: obs_batch,
                SampleBatch.ACTIONS: actions

@@ -124,15 +141,21 @@ class TorchPolicy(Policy):
            if prev_reward_batch:
                input_dict[SampleBatch.PREV_REWARDS] = prev_reward_batch

            parameters, _ = self.model(input_dict, state_batches, [1])
            parameters, _ = self.model(input_dict, state_batches, seq_lens)
            action_dist = self.dist_class(parameters, self.model)
            log_likelihoods = action_dist.logp(input_dict[SampleBatch.ACTIONS])
            return log_likelihoods

    @override(Policy)
    def learn_on_batch(self, postprocessed_batch):
        train_batch = self._lazy_tensor_dict(postprocessed_batch)
        # Get batch ready for RNNs, if applicable.
        pad_batch_to_sequences_of_same_size(
            postprocessed_batch,
            max_seq_len=self.max_seq_len,
            shuffle=False,
            batch_divisibility_req=self.batch_divisibility_req)

        train_batch = self._lazy_tensor_dict(postprocessed_batch)
        loss_out = self._loss(self, self.model, self.dist_class, train_batch)
        self._optimizer.zero_grad()
        loss_out.backward()
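Two small points about the TorchPolicy changes: at inference time every observation is treated as its own length-1 sequence, which is why compute_actions and compute_log_likelihoods now build seq_lens as a vector of ones; and learn_on_batch pads the raw (numpy) batch before wrapping it into tensors, since the padding utility rewrites the batch in place. A trivial sketch of the seq_lens construction:

    import torch

    obs_batch = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]            # 3 independent observations
    seq_lens = torch.ones(len(obs_batch), dtype=torch.int32)    # each obs = a length-1 sequence
    print(seq_lens)   # tensor([1, 1, 1], dtype=torch.int32)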
@@ -22,7 +22,8 @@ def build_torch_policy(name,
                       before_init=None,
                       after_init=None,
                       make_model_and_action_dist=None,
                       mixins=None):
                       mixins=None,
                       get_batch_divisibility_req=None):
    """Helper function for creating a torch policy at runtime.

    Arguments:

@@ -52,6 +53,8 @@ def build_torch_policy(name,
        mixins (list): list of any class mixins for the returned policy class.
            These mixins will be applied in order and will have higher
            precedence than the TorchPolicy class
        get_batch_divisibility_req (Optional[callable]): Optional callable that
            returns the divisibility requirement for sample batches.

    Returns:
        a TorchPolicy instance that uses the specified args

@@ -80,14 +83,24 @@ def build_torch_policy(name,
            self.dist_class, logit_dim = ModelCatalog.get_action_dist(
                action_space, self.config["model"], framework="torch")
            self.model = ModelCatalog.get_model_v2(
                obs_space=obs_space,
                action_space=action_space,
                num_outputs=logit_dim,
                model_config=self.config["model"],
                framework="torch",
                **self.config["model"].get("custom_options", {}))

            TorchPolicy.__init__(
                self,
                obs_space,
                action_space,
                logit_dim,
                self.config["model"],
                framework="torch")

            TorchPolicy.__init__(self, obs_space, action_space, config,
                                 self.model, loss_fn, self.dist_class)
                config,
                model=self.model,
                loss=loss_fn,
                action_distribution_class=self.dist_class,
                max_seq_len=config["model"]["max_seq_len"],
                get_batch_divisibility_req=get_batch_divisibility_req,
            )

            if after_init:
                after_init(self, obs_space, action_space, config)
@@ -1,7 +1,8 @@
import numpy as np

from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.framework import try_import_tf, try_import_torch

tf = try_import_tf()
torch, _ = try_import_torch()

SMALL_NUMBER = 1e-6

@@ -123,9 +124,19 @@ def fc(x, weights, biases=None):
    Returns:
        The dense layer's output.
    """

    # Torch stores matrices in transpose (faster for backprop).
    if torch and isinstance(weights, torch.Tensor):
        weights = np.transpose(weights.numpy())
    if torch:  # and isinstance(weights, torch.Tensor):
        x = x.detach().numpy() if isinstance(x, torch.Tensor) else x
        weights = np.transpose(weights.detach().numpy()) if \
            isinstance(weights, torch.Tensor) else weights
        biases = biases.detach().numpy() if \
            isinstance(biases, torch.Tensor) else biases
    if tf:
        x = x.numpy() if isinstance(x, tf.Variable) else x
        weights = weights.numpy() if isinstance(weights, tf.Variable) else \
            weights
        biases = biases.numpy() if isinstance(biases, tf.Variable) else biases
    return np.matmul(x, weights) + (0.0 if biases is None else biases)
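The updated fc() test helper detaches torch tensors and transposes torch weight matrices before the numpy matmul, because nn.Linear stores its weight as [out_features, in_features] and computes x @ W.T + b. A small standalone check of that equivalence (not the repo's helper itself):

    import numpy as np
    import torch
    import torch.nn as nn

    layer = nn.Linear(3, 2)
    x = torch.randn(4, 3)

    torch_out = layer(x).detach().numpy()
    np_out = np.matmul(x.numpy(), np.transpose(layer.weight.detach().numpy())) \
        + layer.bias.detach().numpy()
    print(np.allclose(torch_out, np_out, atol=1e-5))   # True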
@@ -5,6 +5,7 @@ import logging
from collections import defaultdict
import random

from ray.util import log_once
from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY
from ray.rllib.policy.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \
    MultiAgentBatch

@@ -63,6 +64,7 @@ def minibatches(samples, sgd_minibatch_size):
            "Minibatching not implemented for multi-agent in simple mode")

    if "state_in_0" in samples.data:
        if log_once("not_shuffling_rnn_data_in_simple_mode"):
            logger.warning("Not shuffling RNN data for SGD in simple mode")
    else:
        samples.shuffle()