This PR fixes the currently broken lstm_use_prev_action_reward flag for default LSTM models (model.use_lstm=True). (#8970)

Sven Mika 2020-06-27 20:50:01 +02:00 committed by GitHub
parent d7549d6184
commit 5c6d5d4ab1
19 changed files with 118 additions and 119 deletions
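For reference, a minimal sketch (not part of the commit) of the setup this PR repairs. The flag names, LSTM cell size, and max sequence length mirror the updated tests below; the env choice and single training step are illustrative.

import copy
import ray
from ray.rllib.agents import ppo

ray.init()
config = copy.deepcopy(ppo.DEFAULT_CONFIG)
config["model"]["use_lstm"] = True                      # default LSTM wrapper
config["model"]["lstm_use_prev_action_reward"] = True   # the previously broken flag
config["model"]["lstm_cell_size"] = 10
config["model"]["max_seq_len"] = 20
trainer = ppo.PPOTrainer(config=config, env="CartPole-v0")
trainer.train()
trainer.stop()
ray.shutdown()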

View file

@ -492,8 +492,7 @@ py_test(
name = "test_ppo",
tags = ["agents_dir"],
size = "large",
srcs = ["agents/ppo/tests/test_ppo.py",
"agents/ppo/tests/test.py"] # TODO(sven): Move down once PR 6889 merged
srcs = ["agents/ppo/tests/test_ppo.py"]
)
# PPO: DDPPO

View file

@ -11,11 +11,11 @@ tf = try_import_tf()
class TestIMPALA(unittest.TestCase):
@classmethod
def setUpClass(cls):
ray.init(local_mode=True)
def setUpClass(cls) -> None:
ray.init()
@classmethod
def tearDownClass(cls):
def tearDownClass(cls) -> None:
ray.shutdown()
def test_impala_compilation(self):
@ -40,11 +40,15 @@ class TestIMPALA(unittest.TestCase):
# Test w/ LSTM.
print("w/ LSTM")
local_cfg["model"]["use_lstm"] = True
local_cfg["model"]["lstm_use_prev_action_reward"] = True
local_cfg["num_aggregation_workers"] = 2
trainer = impala.ImpalaTrainer(config=local_cfg, env=env)
for i in range(num_iterations):
print(trainer.train())
check_compute_single_action(trainer, include_state=True)
check_compute_single_action(
trainer,
include_state=True,
include_prev_action_reward=True)
trainer.stop()

View file

@ -1,4 +1,5 @@
import logging
import numpy as np
import ray
from ray.rllib.agents.a3c.a3c_torch_policy import apply_grad_clipping
@ -10,7 +11,8 @@ from ray.rllib.policy.torch_policy import EntropyCoeffSchedule, \
LearningRateSchedule
from ray.rllib.policy.torch_policy_template import build_torch_policy
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.torch_ops import explained_variance, sequence_mask
from ray.rllib.utils.torch_ops import convert_to_torch_tensor, \
explained_variance, sequence_mask
torch, nn = try_import_torch()
@ -185,14 +187,15 @@ class ValueNetworkMixin:
def value(ob, prev_action, prev_reward, *state):
model_out, _ = self.model({
SampleBatch.CUR_OBS: self._convert_to_tensor([ob]),
SampleBatch.PREV_ACTIONS: self._convert_to_tensor(
[prev_action]),
SampleBatch.PREV_REWARDS: self._convert_to_tensor(
[prev_reward]),
SampleBatch.CUR_OBS: convert_to_torch_tensor(
np.asarray([ob])),
SampleBatch.PREV_ACTIONS: convert_to_torch_tensor(
np.asarray([prev_action])),
SampleBatch.PREV_REWARDS: convert_to_torch_tensor(
np.asarray([prev_reward])),
"is_training": False,
}, [self._convert_to_tensor(s) for s in state],
self._convert_to_tensor([1]))
}, [convert_to_torch_tensor(np.asarray(s)) for s in state],
convert_to_torch_tensor(np.asarray([1])))
return self.model.value_function()[0]
else:

View file

@ -1,43 +0,0 @@
import unittest
import numpy as np
from numpy.testing import assert_allclose
from ray.rllib.agents.ppo.utils import flatten, concatenate
from ray.rllib.utils.framework import try_import_tf
tf = try_import_tf()
# TODO(sven): Move to utils/tests/.
class UtilsTest(unittest.TestCase):
def testFlatten(self):
d = {
"s": np.array([[[1, -1], [2, -2]], [[3, -3], [4, -4]]]),
"a": np.array([[[5], [-5]], [[6], [-6]]])
}
flat = flatten(d.copy(), start=0, stop=2)
assert_allclose(d["s"][0][0][:], flat["s"][0][:])
assert_allclose(d["s"][0][1][:], flat["s"][1][:])
assert_allclose(d["s"][1][0][:], flat["s"][2][:])
assert_allclose(d["s"][1][1][:], flat["s"][3][:])
assert_allclose(d["a"][0][0], flat["a"][0])
assert_allclose(d["a"][0][1], flat["a"][1])
assert_allclose(d["a"][1][0], flat["a"][2])
assert_allclose(d["a"][1][1], flat["a"][3])
def testConcatenate(self):
d1 = {"s": np.array([0, 1]), "a": np.array([2, 3])}
d2 = {"s": np.array([4, 5]), "a": np.array([6, 7])}
d = concatenate([d1, d2])
assert_allclose(d["s"], np.array([0, 1, 4, 5]))
assert_allclose(d["a"], np.array([2, 3, 6, 7]))
D = concatenate([d])
assert_allclose(D["s"], np.array([0, 1, 4, 5]))
assert_allclose(D["a"], np.array([2, 3, 6, 7]))
if __name__ == "__main__":
import pytest
import sys
sys.exit(pytest.main(["-v", __file__]))

View file

@ -47,17 +47,31 @@ class TestPPO(unittest.TestCase):
ray.shutdown()
def test_ppo_compilation(self):
"""Test whether a PPOTrainer can be built with both frameworks."""
"""Test whether a PPOTrainer can be built with all frameworks."""
config = copy.deepcopy(ppo.DEFAULT_CONFIG)
config["num_workers"] = 1
config["num_sgd_iter"] = 2
# Settings in case we use an LSTM.
config["model"]["lstm_cell_size"] = 10
config["model"]["max_seq_len"] = 20
config["train_batch_size"] = 128
num_iterations = 2
for _ in framework_iterator(config):
trainer = ppo.PPOTrainer(config=config, env="CartPole-v0")
for i in range(num_iterations):
trainer.train()
check_compute_single_action(
trainer, include_prev_action_reward=True)
for env in ["CartPole-v0", "MsPacmanNoFrameskip-v4"]:
print("Env={}".format(env))
for lstm in [True, False]:
print("LSTM={}".format(lstm))
config["model"]["use_lstm"] = lstm
config["model"]["lstm_use_prev_action_reward"] = lstm
trainer = ppo.PPOTrainer(config=config, env=env)
for i in range(num_iterations):
trainer.train()
check_compute_single_action(
trainer,
include_prev_action_reward=True,
include_state=lstm)
trainer.stop()
def test_ppo_fake_multi_gpu_learning(self):
"""Test whether PPOTrainer can learn CartPole w/ faked multi-GPU."""
@ -86,6 +100,7 @@ class TestPPO(unittest.TestCase):
break
print(results)
assert learnt, "PPO multi-GPU (with fake-GPUs) did not learn CartPole!"
trainer.stop()
def test_ppo_exploration_setup(self):
"""Tests, whether PPO runs with different exploration setups."""
@ -125,6 +140,7 @@ class TestPPO(unittest.TestCase):
prev_action=np.array(2),
prev_reward=np.array(1.0)))
check(np.mean(actions), 1.5, atol=0.2)
trainer.stop()
def test_ppo_free_log_std(self):
"""Tests the free log std option works."""
@ -176,6 +192,7 @@ class TestPPO(unittest.TestCase):
# Check the variable is updated.
post_std = get_value()
assert post_std != 0.0, post_std
trainer.stop()
def test_ppo_loss_function(self):
"""Tests the PPO loss function math."""
@ -272,6 +289,7 @@ class TestPPO(unittest.TestCase):
check(
policy.loss_obj.mean_vf_loss, np.mean(vf_loss), decimals=4)
check(policy.loss_obj.loss, overall_loss, decimals=4)
trainer.stop()
def _ppo_loss_helper(self,
policy,

View file

@ -1,32 +0,0 @@
import numpy as np
def flatten(weights, start=0, stop=2):
"""This methods reshapes all values in a dictionary.
The indices from start to stop will be flattened into a single index.
Args:
weights: A dictionary mapping keys to numpy arrays.
start: The starting index.
stop: The ending index.
"""
for key, val in weights.items():
new_shape = val.shape[0:start] + (-1, ) + val.shape[stop:]
weights[key] = val.reshape(new_shape)
return weights
def concatenate(weights_list):
keys = weights_list[0].keys()
result = {}
for key in keys:
result[key] = np.concatenate([l[key] for l in weights_list])
return result
def shuffle(trajectory):
permutation = np.random.permutation(trajectory["actions"].shape[0])
for key, val in trajectory.items():
trajectory[key] = val[permutation]
return trajectory

View file

@ -36,8 +36,8 @@ def build_sac_model(policy, obs_space, action_space, config):
logger.warning(
"When not using a state-preprocessor with SAC, `fcnet_hiddens`"
" will be set to an empty list! Any hidden layer sizes are "
"defined via `policy_model.hidden_layer_sizes` and "
"`Q_model.hidden_layer_sizes`.")
"defined via `policy_model.fcnet_hiddens` and "
"`Q_model.fcnet_hiddens`.")
config["model"]["fcnet_hiddens"] = []
# Force-ignore any additionally provided hidden layer sizes.
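A hedged sketch of what the corrected warning asks for: with no state-preprocessor, SAC hidden layers are configured under these two keys rather than the top-level model config. The layer sizes here are placeholders.

sac_config = {
    "Q_model": {"fcnet_hiddens": [256, 256]},
    "policy_model": {"fcnet_hiddens": [256, 256]},
    # model.fcnet_hiddens is force-set to [] in this case, as the code above shows.
    "model": {"fcnet_hiddens": []},
}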

View file

@ -138,6 +138,8 @@ class MultiAgentEpisode:
else:
policy = self._policies[self.policy_for(agent_id)]
flat = flatten_to_single_ndarray(policy.action_space.sample())
if hasattr(policy.action_space, "dtype"):
return np.zeros_like(flat, dtype=policy.action_space.dtype)
return np.zeros_like(flat)
@DeveloperAPI
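The added dtype handling makes the dummy "previous action" match the action space's dtype instead of whatever the flattening step produced. A small numpy-only illustration; the flat array and dtype below are stand-ins for flatten_to_single_ndarray(policy.action_space.sample()) and policy.action_space.dtype.

import numpy as np

flat = np.array([3.0, 1.0, 2.0])            # flattening may yield float64
space_dtype = np.int64                      # e.g. a discrete space's dtype

zero_action = np.zeros_like(flat, dtype=space_dtype)
assert zero_action.dtype == np.int64        # dummy prev-action now matches the space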

View file

@ -22,7 +22,7 @@ parser.add_argument(
parser.add_argument("--eager", action="store_true")
if __name__ == "__main__":
ray.init()
ray.init(local_mode=True)
args = parser.parse_args()
if args.framework == "torch":
ModelCatalog.register_custom_model("my_model", CustomTorchRPGModel)

View file

@ -70,7 +70,7 @@ if __name__ == "__main__":
"episode_reward_mean": args.stop_reward,
}
results = tune.run(args.run, stop=stop, config=config)
results = tune.run(args.run, stop=stop, config=config, verbose=1)
if args.as_test:
check_learning_achieved(results, args.stop_reward)

View file

@ -335,7 +335,7 @@ _cache = {}
def _unpack_obs(obs, space, tensorlib=tf):
"""Unpack a flattened Dict or Tuple observation array/tensor.
Arguments:
Args:
obs: The flattened observation tensor, with last dimension equal to
the flat size and any number of batch dimensions. For example, for
Box(4,), the obs may have shape [B, 4], or [B, N, M, 4] in case

View file

@ -64,7 +64,7 @@ class RepeatedValues:
>>> print(max(len(x) for x in items) <= N)
True
>>> print(items)
... [[<Tensor_1 shape=(K)>, ..., <Tensor_N, shape=(K)>],
... [<Tensor_1 shape=(K)>, ..., <Tensor_N, shape=(K)>],
... ...
... [<Tensor_1 shape=(K)>, <Tensor_2 shape=(K)>],
... ...

View file

@ -3,6 +3,7 @@ import numpy as np
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.policy.rnn_sequencing import add_time_dimension
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.annotations import override, DeveloperAPI
from ray.rllib.utils.framework import try_import_tf
@ -109,6 +110,12 @@ class LSTMWrapper(RecurrentNetwork):
model_config, name)
self.cell_size = model_config["lstm_cell_size"]
self.use_prev_action_reward = model_config[
"lstm_use_prev_action_reward"]
self.action_dim = int(np.product(action_space.shape))
# Add prev-action/reward nodes to input to LSTM.
if self.use_prev_action_reward:
self.num_outputs += 1 + self.action_dim
# Define input layers.
input_layer = tf.keras.layers.Input(
@ -151,6 +158,21 @@ class LSTMWrapper(RecurrentNetwork):
# Push obs through "unwrapped" net's `forward()` first.
wrapped_out, _ = self._wrapped_forward(input_dict, [], None)
# Concat. prev-action/reward if required.
if self.model_config["lstm_use_prev_action_reward"]:
wrapped_out = tf.concat(
[
wrapped_out,
tf.reshape(
tf.cast(input_dict[SampleBatch.PREV_ACTIONS],
tf.float32), [-1, self.action_dim]),
tf.reshape(
tf.cast(input_dict[SampleBatch.PREV_REWARDS],
tf.float32), [-1, 1]),
],
axis=1)
# Then through our LSTM.
input_dict["obs_flat"] = wrapped_out
return super().forward(input_dict, state, seq_lens)
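A rough numpy sketch (shapes simplified, not the wrapper itself) of what the two additions above do: the LSTM's input width grows by action_dim + 1, and the flattened prev-action/reward columns are concatenated onto the wrapped model's output before the recurrent pass.

import numpy as np

feature_dim = 256     # output size of the wrapped (non-recurrent) model
action_dim = 1        # int(np.product(action_space.shape)); 1 for a Discrete space
batch = 4

features = np.random.randn(batch, feature_dim).astype(np.float32)
prev_actions = np.array([0, 1, 1, 0], dtype=np.float32).reshape(-1, action_dim)
prev_rewards = np.array([0.0, 1.0, 0.5, -0.2], dtype=np.float32).reshape(-1, 1)

lstm_in = np.concatenate([features, prev_actions, prev_rewards], axis=1)
assert lstm_in.shape == (batch, feature_dim + action_dim + 1)   # num_outputs += 1 + action_dim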

View file

@ -4,6 +4,7 @@ from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.torch.misc import SlimFC
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.policy.rnn_sequencing import add_time_dimension
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.annotations import override, DeveloperAPI
from ray.rllib.utils.framework import try_import_torch
@ -101,6 +102,12 @@ class LSTMWrapper(RecurrentNetwork, nn.Module):
super().__init__(obs_space, action_space, None, model_config, name)
self.cell_size = model_config["lstm_cell_size"]
self.use_prev_action_reward = model_config[
"lstm_use_prev_action_reward"]
self.action_dim = int(np.product(action_space.shape))
# Add prev-action/reward nodes to input to LSTM.
if self.use_prev_action_reward:
self.num_outputs += 1 + self.action_dim
self.lstm = nn.LSTM(self.num_outputs, self.cell_size, batch_first=True)
self.num_outputs = num_outputs
@ -123,6 +130,18 @@ class LSTMWrapper(RecurrentNetwork, nn.Module):
# Push obs through "unwrapped" net's `forward()` first.
wrapped_out, _ = self._wrapped_forward(input_dict, [], None)
# Concat. prev-action/reward if required.
if self.model_config["lstm_use_prev_action_reward"]:
wrapped_out = torch.cat(
[
wrapped_out,
torch.reshape(input_dict[SampleBatch.PREV_ACTIONS].float(),
[-1, self.action_dim]),
torch.reshape(input_dict[SampleBatch.PREV_REWARDS],
[-1, 1]),
],
dim=1)
# Then through our LSTM.
input_dict["obs_flat"] = wrapped_out
return super().forward(input_dict, state, seq_lens)
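The torch wrapper performs the same concatenation with torch.cat. A small self-contained sketch (sizes are placeholders) of how the widened input flows through an nn.LSTM once the time dimension is added.

import torch
import torch.nn as nn

feature_dim, action_dim, cell_size = 256, 1, 64
B, T = 4, 5

lstm = nn.LSTM(feature_dim + action_dim + 1, cell_size, batch_first=True)

wrapped_out = torch.randn(B * T, feature_dim)
prev_a = torch.zeros(B * T, dtype=torch.int64)          # e.g. Discrete actions
prev_r = torch.zeros(B * T)

lstm_in = torch.cat(
    [wrapped_out,
     prev_a.float().reshape(-1, action_dim),
     prev_r.reshape(-1, 1)],
    dim=1).reshape(B, T, -1)                             # add the time dimension
out, (h, c) = lstm(lstm_in)
assert out.shape == (B, T, cell_size)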

View file

@ -153,8 +153,8 @@ class Policy(metaclass=ABCMeta):
episodes = [episode]
if state is not None:
state_batch = [
s.unsqueeze(0)
if torch and isinstance(s, torch.Tensor) else [s]
s.unsqueeze(0) if torch and isinstance(s, torch.Tensor) else
np.expand_dims(s, 0)
for s in state
]
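The change above gives numpy state inputs a real batch axis of size 1 instead of wrapping them in a Python list, mirroring what .unsqueeze(0) does for torch tensors. For example:

import numpy as np

s = np.zeros(64, dtype=np.float32)              # one recurrent state slot (cell_size,)
assert np.expand_dims(s, 0).shape == (1, 64)    # batched shape expected downstream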

View file

@ -114,15 +114,17 @@ class TorchPolicy(Policy):
with torch.no_grad():
seq_lens = torch.ones(len(obs_batch), dtype=torch.int32)
input_dict = self._lazy_tensor_dict({
SampleBatch.CUR_OBS: obs_batch,
SampleBatch.CUR_OBS: np.asarray(obs_batch),
"is_training": False,
})
if prev_action_batch is not None:
input_dict[SampleBatch.PREV_ACTIONS] = prev_action_batch
input_dict[SampleBatch.PREV_ACTIONS] = \
np.asarray(prev_action_batch)
if prev_reward_batch is not None:
input_dict[SampleBatch.PREV_REWARDS] = prev_reward_batch
input_dict[SampleBatch.PREV_REWARDS] = \
np.asarray(prev_reward_batch)
state_batches = [
self._convert_to_tensor(s) for s in (state_batches or [])
convert_to_torch_tensor(s) for s in (state_batches or [])
]
if self.action_sampler_fn:
@ -411,17 +413,9 @@ class TorchPolicy(Policy):
def _lazy_tensor_dict(self, postprocessed_batch):
train_batch = UsageTrackingDict(postprocessed_batch)
train_batch.set_get_interceptor(self._convert_to_tensor)
train_batch.set_get_interceptor(convert_to_torch_tensor)
return train_batch
def _convert_to_tensor(self, arr):
if torch.is_tensor(arr):
return arr.to(self.device)
tensor = torch.from_numpy(np.asarray(arr))
if tensor.dtype == torch.double:
tensor = tensor.float()
return tensor.to(self.device)
@override(Policy)
def export_model(self, export_dir):
"""TODO(sven): implement for torch.

View file

@ -251,6 +251,8 @@ def check_compute_single_action(trainer,
Args:
trainer (Trainer): The Trainer object to test.
include_state (bool): Whether to include the initial state of the
Policy's Model in the `compute_action` call.
include_prev_action_reward (bool): Whether to include the prev-action
and -reward in the `compute_action` call.
@ -266,17 +268,16 @@ def check_compute_single_action(trainer,
action_space = pol.action_space
for what in [pol, trainer]:
print("what={}".format(what))
method_to_test = trainer.compute_action if what is trainer else \
pol.compute_single_action
for explore in [True, False]:
print("explore={}".format(explore))
for full_fetch in ([False, True] if what is trainer else [False]):
print("full-fetch={}".format(full_fetch))
call_kwargs = {}
if what is trainer:
call_kwargs["full_fetch"] = full_fetch
else:
call_kwargs["clip_actions"] = True
obs = np.clip(obs_space.sample(), -1.0, 1.0)
state_in = None

View file

@ -1,5 +1,6 @@
import numpy as np
from ray.rllib.models.repeated_values import RepeatedValues
from ray.rllib.utils import try_import_tree
from ray.rllib.utils.framework import try_import_torch
@ -123,8 +124,14 @@ def convert_to_torch_tensor(stats, device=None):
"""
def mapping(item):
# Already torch tensor -> make sure it's on right device.
if torch.is_tensor(item):
return item if device is None else item.to(device)
# Special handling of "Repeated" values.
elif isinstance(item, RepeatedValues):
return RepeatedValues(
tree.map_structure(mapping, item.values),
item.lengths, item.max_len)
tensor = torch.from_numpy(np.asarray(item))
# Floatify all float64 tensors.
if tensor.dtype == torch.double:
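To make the shared conversion behavior concrete, here is a simplified stand-in (not the actual ray.rllib.utils.torch_ops.convert_to_torch_tensor) covering the cases visible in the diff: pass-through for existing tensors and float64 narrowing for numpy inputs; the recursion into RepeatedValues is omitted.

import numpy as np
import torch

def to_torch(item, device=None):
    # Already a torch tensor -> only move it to the right device.
    if torch.is_tensor(item):
        return item if device is None else item.to(device)
    # (The real util also recurses into RepeatedValues here.)
    tensor = torch.from_numpy(np.asarray(item))
    # Floatify all float64 tensors.
    if tensor.dtype == torch.double:
        tensor = tensor.float()
    return tensor if device is None else tensor.to(device)

t = to_torch(np.array([1.0, 2.0]))          # float64 input
assert t.dtype == torch.float32             # narrowed to float32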

View file

@ -17,6 +17,11 @@ class UsageTrackingDict(dict):
def set_get_interceptor(self, fn):
self.get_interceptor = fn
def copy(self):
copy = UsageTrackingDict(**dict.copy(self))
copy.set_get_interceptor(self.get_interceptor)
return copy
def __getitem__(self, key):
self.accessed_keys.add(key)
value = dict.__getitem__(self, key)
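The new copy() matters because the dict's get-interceptor (set to convert_to_torch_tensor in the TorchPolicy change above) would otherwise be lost on copies. A self-contained stand-in class (not RLlib's UsageTrackingDict) showing the idea:

class TrackingDict(dict):
    """Toy version: records accessed keys and lazily transforms values."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.accessed_keys = set()
        self.get_interceptor = None

    def set_get_interceptor(self, fn):
        self.get_interceptor = fn

    def copy(self):
        new = TrackingDict(**dict.copy(self))
        new.set_get_interceptor(self.get_interceptor)   # keep the interceptor alive
        return new

    def __getitem__(self, key):
        self.accessed_keys.add(key)
        value = dict.__getitem__(self, key)
        return self.get_interceptor(value) if self.get_interceptor else value

d = TrackingDict(obs=[1.0, 2.0])
d.set_get_interceptor(sum)
assert d.copy()["obs"] == 3.0    # a plain dict.copy() would have dropped the interceptor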