This PR fixes the currently broken lstm_use_prev_action_reward flag for default lstm models (model.use_lstm=True). (#8970)
This commit is contained in: commit 5c6d5d4ab1 (parent d7549d6184)
19 changed files with 118 additions and 119 deletions
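Before the per-file diffs, a minimal usage sketch of the flag this PR repairs, assuming the RLlib API of this commit's era (ppo.PPOTrainer, DEFAULT_CONFIG), as exercised by the test changes below:

# Minimal sketch, not part of this diff: enable the default LSTM wrapper and
# feed prev. actions/rewards into it (the flag fixed by this PR).
import copy

import ray
from ray.rllib.agents import ppo

ray.init()

config = copy.deepcopy(ppo.DEFAULT_CONFIG)
config["num_workers"] = 1
config["model"]["use_lstm"] = True
config["model"]["lstm_use_prev_action_reward"] = True  # previously broken
config["model"]["lstm_cell_size"] = 10
config["model"]["max_seq_len"] = 20

trainer = ppo.PPOTrainer(config=config, env="CartPole-v0")
print(trainer.train())
trainer.stop()
ray.shutdown()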
@@ -492,8 +492,7 @@ py_test(
     name = "test_ppo",
     tags = ["agents_dir"],
     size = "large",
-    srcs = ["agents/ppo/tests/test_ppo.py",
-        "agents/ppo/tests/test.py"]  # TODO(sven): Move down once PR 6889 merged
+    srcs = ["agents/ppo/tests/test_ppo.py"]
 )

 # PPO: DDPPO
@@ -11,11 +11,11 @@ tf = try_import_tf()

 class TestIMPALA(unittest.TestCase):
     @classmethod
-    def setUpClass(cls):
-        ray.init(local_mode=True)
+    def setUpClass(cls) -> None:
+        ray.init()

     @classmethod
-    def tearDownClass(cls):
+    def tearDownClass(cls) -> None:
         ray.shutdown()

     def test_impala_compilation(self):
@@ -40,11 +40,15 @@ class TestIMPALA(unittest.TestCase):
                # Test w/ LSTM.
                print("w/ LSTM")
                local_cfg["model"]["use_lstm"] = True
+               local_cfg["model"]["lstm_use_prev_action_reward"] = True
                local_cfg["num_aggregation_workers"] = 2
                trainer = impala.ImpalaTrainer(config=local_cfg, env=env)
                for i in range(num_iterations):
                    print(trainer.train())
-               check_compute_single_action(trainer, include_state=True)
+               check_compute_single_action(
+                   trainer,
+                   include_state=True,
+                   include_prev_action_reward=True)
                trainer.stop()
@@ -1,4 +1,5 @@
 import logging
+import numpy as np

 import ray
 from ray.rllib.agents.a3c.a3c_torch_policy import apply_grad_clipping
@@ -10,7 +11,8 @@ from ray.rllib.policy.torch_policy import EntropyCoeffSchedule, \
     LearningRateSchedule
 from ray.rllib.policy.torch_policy_template import build_torch_policy
 from ray.rllib.utils.framework import try_import_torch
-from ray.rllib.utils.torch_ops import explained_variance, sequence_mask
+from ray.rllib.utils.torch_ops import convert_to_torch_tensor, \
+    explained_variance, sequence_mask

 torch, nn = try_import_torch()
@@ -185,14 +187,15 @@ class ValueNetworkMixin:

            def value(ob, prev_action, prev_reward, *state):
                model_out, _ = self.model({
-                   SampleBatch.CUR_OBS: self._convert_to_tensor([ob]),
-                   SampleBatch.PREV_ACTIONS: self._convert_to_tensor(
-                       [prev_action]),
-                   SampleBatch.PREV_REWARDS: self._convert_to_tensor(
-                       [prev_reward]),
+                   SampleBatch.CUR_OBS: convert_to_torch_tensor(
+                       np.asarray([ob])),
+                   SampleBatch.PREV_ACTIONS: convert_to_torch_tensor(
+                       np.asarray([prev_action])),
+                   SampleBatch.PREV_REWARDS: convert_to_torch_tensor(
+                       np.asarray([prev_reward])),
                    "is_training": False,
-               }, [self._convert_to_tensor(s) for s in state],
-                  self._convert_to_tensor([1]))
+               }, [convert_to_torch_tensor(np.asarray(s)) for s in state],
+                  convert_to_torch_tensor(np.asarray([1])))
                return self.model.value_function()[0]

        else:
@@ -1,43 +0,0 @@
-import unittest
-import numpy as np
-from numpy.testing import assert_allclose
-
-from ray.rllib.agents.ppo.utils import flatten, concatenate
-from ray.rllib.utils.framework import try_import_tf
-
-tf = try_import_tf()
-
-
-# TODO(sven): Move to utils/tests/.
-class UtilsTest(unittest.TestCase):
-    def testFlatten(self):
-        d = {
-            "s": np.array([[[1, -1], [2, -2]], [[3, -3], [4, -4]]]),
-            "a": np.array([[[5], [-5]], [[6], [-6]]])
-        }
-        flat = flatten(d.copy(), start=0, stop=2)
-        assert_allclose(d["s"][0][0][:], flat["s"][0][:])
-        assert_allclose(d["s"][0][1][:], flat["s"][1][:])
-        assert_allclose(d["s"][1][0][:], flat["s"][2][:])
-        assert_allclose(d["s"][1][1][:], flat["s"][3][:])
-        assert_allclose(d["a"][0][0], flat["a"][0])
-        assert_allclose(d["a"][0][1], flat["a"][1])
-        assert_allclose(d["a"][1][0], flat["a"][2])
-        assert_allclose(d["a"][1][1], flat["a"][3])
-
-    def testConcatenate(self):
-        d1 = {"s": np.array([0, 1]), "a": np.array([2, 3])}
-        d2 = {"s": np.array([4, 5]), "a": np.array([6, 7])}
-        d = concatenate([d1, d2])
-        assert_allclose(d["s"], np.array([0, 1, 4, 5]))
-        assert_allclose(d["a"], np.array([2, 3, 6, 7]))
-
-        D = concatenate([d])
-        assert_allclose(D["s"], np.array([0, 1, 4, 5]))
-        assert_allclose(D["a"], np.array([2, 3, 6, 7]))
-
-
-if __name__ == "__main__":
-    import pytest
-    import sys
-    sys.exit(pytest.main(["-v", __file__]))
@@ -47,17 +47,31 @@ class TestPPO(unittest.TestCase):
        ray.shutdown()

    def test_ppo_compilation(self):
-       """Test whether a PPOTrainer can be built with both frameworks."""
+       """Test whether a PPOTrainer can be built with all frameworks."""
        config = copy.deepcopy(ppo.DEFAULT_CONFIG)
        config["num_workers"] = 1
        config["num_sgd_iter"] = 2
+       # Settings in case we use an LSTM.
+       config["model"]["lstm_cell_size"] = 10
+       config["model"]["max_seq_len"] = 20
+       config["train_batch_size"] = 128
        num_iterations = 2

        for _ in framework_iterator(config):
-           trainer = ppo.PPOTrainer(config=config, env="CartPole-v0")
-           for i in range(num_iterations):
-               trainer.train()
-           check_compute_single_action(
-               trainer, include_prev_action_reward=True)
+           for env in ["CartPole-v0", "MsPacmanNoFrameskip-v4"]:
+               print("Env={}".format(env))
+               for lstm in [True, False]:
+                   print("LSTM={}".format(lstm))
+                   config["model"]["use_lstm"] = lstm
+                   config["model"]["lstm_use_prev_action_reward"] = lstm
+                   trainer = ppo.PPOTrainer(config=config, env=env)
+                   for i in range(num_iterations):
+                       trainer.train()
+                   check_compute_single_action(
+                       trainer,
+                       include_prev_action_reward=True,
+                       include_state=lstm)
+                   trainer.stop()

    def test_ppo_fake_multi_gpu_learning(self):
        """Test whether PPOTrainer can learn CartPole w/ faked multi-GPU."""
@@ -86,6 +100,7 @@ class TestPPO(unittest.TestCase):
                break
            print(results)
        assert learnt, "PPO multi-GPU (with fake-GPUs) did not learn CartPole!"
+       trainer.stop()

    def test_ppo_exploration_setup(self):
        """Tests, whether PPO runs with different exploration setups."""
@@ -125,6 +140,7 @@ class TestPPO(unittest.TestCase):
                    prev_action=np.array(2),
                    prev_reward=np.array(1.0)))
        check(np.mean(actions), 1.5, atol=0.2)
+       trainer.stop()

    def test_ppo_free_log_std(self):
        """Tests the free log std option works."""
@@ -176,6 +192,7 @@ class TestPPO(unittest.TestCase):
            # Check the variable is updated.
            post_std = get_value()
            assert post_std != 0.0, post_std
+           trainer.stop()

    def test_ppo_loss_function(self):
        """Tests the PPO loss function math."""
@@ -272,6 +289,7 @@ class TestPPO(unittest.TestCase):
            check(
                policy.loss_obj.mean_vf_loss, np.mean(vf_loss), decimals=4)
            check(policy.loss_obj.loss, overall_loss, decimals=4)
+           trainer.stop()

    def _ppo_loss_helper(self,
                         policy,
@@ -1,32 +0,0 @@
-import numpy as np
-
-
-def flatten(weights, start=0, stop=2):
-    """This methods reshapes all values in a dictionary.
-
-    The indices from start to stop will be flattened into a single index.
-
-    Args:
-        weights: A dictionary mapping keys to numpy arrays.
-        start: The starting index.
-        stop: The ending index.
-    """
-    for key, val in weights.items():
-        new_shape = val.shape[0:start] + (-1, ) + val.shape[stop:]
-        weights[key] = val.reshape(new_shape)
-    return weights
-
-
-def concatenate(weights_list):
-    keys = weights_list[0].keys()
-    result = {}
-    for key in keys:
-        result[key] = np.concatenate([l[key] for l in weights_list])
-    return result
-
-
-def shuffle(trajectory):
-    permutation = np.random.permutation(trajectory["actions"].shape[0])
-    for key, val in trajectory.items():
-        trajectory[key] = val[permutation]
-    return trajectory
@@ -36,8 +36,8 @@ def build_sac_model(policy, obs_space, action_space, config):
        logger.warning(
            "When not using a state-preprocessor with SAC, `fcnet_hiddens`"
            " will be set to an empty list! Any hidden layer sizes are "
-           "defined via `policy_model.hidden_layer_sizes` and "
-           "`Q_model.hidden_layer_sizes`.")
+           "defined via `policy_model.fcnet_hiddens` and "
+           "`Q_model.fcnet_hiddens`.")
        config["model"]["fcnet_hiddens"] = []

    # Force-ignore any additionally provided hidden layer sizes.
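For reference, a hedged sketch of where the corrected keys live in a SAC config; the key names come from the warning text above, while the layer sizes are placeholders rather than confirmed defaults:

# Sketch only: separate hidden-layer settings for SAC's policy- and Q-networks.
sac_config_overrides = {
    "Q_model": {"fcnet_hiddens": [256, 256]},
    "policy_model": {"fcnet_hiddens": [256, 256]},
    # With no state-preprocessor, `model.fcnet_hiddens` is forced to [].
    "model": {"fcnet_hiddens": []},
}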
@@ -138,6 +138,8 @@ class MultiAgentEpisode:
        else:
            policy = self._policies[self.policy_for(agent_id)]
            flat = flatten_to_single_ndarray(policy.action_space.sample())
+           if hasattr(policy.action_space, "dtype"):
+               return np.zeros_like(flat, dtype=policy.action_space.dtype)
            return np.zeros_like(flat)

    @DeveloperAPI
@@ -22,7 +22,7 @@ parser.add_argument(
 parser.add_argument("--eager", action="store_true")

 if __name__ == "__main__":
-    ray.init()
+    ray.init(local_mode=True)
     args = parser.parse_args()
     if args.framework == "torch":
         ModelCatalog.register_custom_model("my_model", CustomTorchRPGModel)
@@ -70,7 +70,7 @@ if __name__ == "__main__":
        "episode_reward_mean": args.stop_reward,
    }

-   results = tune.run(args.run, stop=stop, config=config)
+   results = tune.run(args.run, stop=stop, config=config, verbose=1)

    if args.as_test:
        check_learning_achieved(results, args.stop_reward)
@@ -335,7 +335,7 @@ _cache = {}
 def _unpack_obs(obs, space, tensorlib=tf):
     """Unpack a flattened Dict or Tuple observation array/tensor.

-    Arguments:
+    Args:
         obs: The flattened observation tensor, with last dimension equal to
             the flat size and any number of batch dimensions. For example, for
             Box(4,), the obs may have shape [B, 4], or [B, N, M, 4] in case
@@ -64,7 +64,7 @@ class RepeatedValues:
        >>> print(max(len(x) for x in items) <= N)
        True
        >>> print(items)
-       ... [[<Tensor_1 shape=(K)>, ..., <Tensor_N, shape=(K)>],
+       ... [[<Tensor_1 shape=(K)>, ..., <Tensor_N, shape=(K)>],
        ...  [<Tensor_1 shape=(K)>, ..., <Tensor_N, shape=(K)>],
        ...  ...
        ...  [<Tensor_1 shape=(K)>, <Tensor_2 shape=(K)>],
        ...  ...
@@ -3,6 +3,7 @@ import numpy as np
 from ray.rllib.models.modelv2 import ModelV2
 from ray.rllib.models.tf.tf_modelv2 import TFModelV2
 from ray.rllib.policy.rnn_sequencing import add_time_dimension
+from ray.rllib.policy.sample_batch import SampleBatch
 from ray.rllib.utils.annotations import override, DeveloperAPI
 from ray.rllib.utils.framework import try_import_tf
@@ -109,6 +110,12 @@ class LSTMWrapper(RecurrentNetwork):
                         model_config, name)

        self.cell_size = model_config["lstm_cell_size"]
+       self.use_prev_action_reward = model_config[
+           "lstm_use_prev_action_reward"]
+       self.action_dim = int(np.product(action_space.shape))
+       # Add prev-action/reward nodes to input to LSTM.
+       if self.use_prev_action_reward:
+           self.num_outputs += 1 + self.action_dim

        # Define input layers.
        input_layer = tf.keras.layers.Input(
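A small, illustrative sketch of the size bookkeeping introduced above; the 256 is an assumed wrapped-model output width, and only the np.product over the action-space shape mirrors the code:

# Illustrative only: how the LSTM input width grows when prev-action/reward
# are concatenated (self.num_outputs += 1 + action_dim).
import numpy as np
from gym.spaces import Box

action_space = Box(-1.0, 1.0, shape=(2, ))
wrapped_num_outputs = 256                             # assumed wrapped-net output width
action_dim = int(np.product(action_space.shape))      # 2
lstm_in_size = wrapped_num_outputs + 1 + action_dim   # 256 + 1 (reward) + 2 (action) = 259
print(lstm_in_size)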
@@ -151,6 +158,21 @@ class LSTMWrapper(RecurrentNetwork):
        # Push obs through "unwrapped" net's `forward()` first.
        wrapped_out, _ = self._wrapped_forward(input_dict, [], None)

+       # Concat. prev-action/reward if required.
+       if self.model_config["lstm_use_prev_action_reward"]:
+           wrapped_out = tf.concat(
+               [
+                   wrapped_out,
+                   tf.reshape(
+                       tf.cast(input_dict[SampleBatch.PREV_ACTIONS],
+                               tf.float32), [-1, self.action_dim]),
+                   tf.reshape(
+                       tf.cast(input_dict[SampleBatch.PREV_REWARDS],
+                               tf.float32), [-1, 1]),
+               ],
+               axis=1)
+
        # Then through our LSTM.
        input_dict["obs_flat"] = wrapped_out
        return super().forward(input_dict, state, seq_lens)
@@ -4,6 +4,7 @@ from ray.rllib.models.modelv2 import ModelV2
 from ray.rllib.models.torch.misc import SlimFC
 from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
 from ray.rllib.policy.rnn_sequencing import add_time_dimension
+from ray.rllib.policy.sample_batch import SampleBatch
 from ray.rllib.utils.annotations import override, DeveloperAPI
 from ray.rllib.utils.framework import try_import_torch
@@ -101,6 +102,12 @@ class LSTMWrapper(RecurrentNetwork, nn.Module):
        super().__init__(obs_space, action_space, None, model_config, name)

        self.cell_size = model_config["lstm_cell_size"]
+       self.use_prev_action_reward = model_config[
+           "lstm_use_prev_action_reward"]
+       self.action_dim = int(np.product(action_space.shape))
+       # Add prev-action/reward nodes to input to LSTM.
+       if self.use_prev_action_reward:
+           self.num_outputs += 1 + self.action_dim
        self.lstm = nn.LSTM(self.num_outputs, self.cell_size, batch_first=True)

        self.num_outputs = num_outputs
@@ -123,6 +130,18 @@ class LSTMWrapper(RecurrentNetwork, nn.Module):
        # Push obs through "unwrapped" net's `forward()` first.
        wrapped_out, _ = self._wrapped_forward(input_dict, [], None)

+       # Concat. prev-action/reward if required.
+       if self.model_config["lstm_use_prev_action_reward"]:
+           wrapped_out = torch.cat(
+               [
+                   wrapped_out,
+                   torch.reshape(input_dict[SampleBatch.PREV_ACTIONS].float(),
+                                 [-1, self.action_dim]),
+                   torch.reshape(input_dict[SampleBatch.PREV_REWARDS],
+                                 [-1, 1]),
+               ],
+               dim=1)
+
        # Then through our LSTM.
        input_dict["obs_flat"] = wrapped_out
        return super().forward(input_dict, state, seq_lens)
@@ -153,8 +153,8 @@ class Policy(metaclass=ABCMeta):
            episodes = [episode]
        if state is not None:
            state_batch = [
-               s.unsqueeze(0)
-               if torch and isinstance(s, torch.Tensor) else [s]
+               s.unsqueeze(0) if torch and isinstance(s, torch.Tensor) else
+               np.expand_dims(s, 0)
                for s in state
            ]
@@ -114,15 +114,17 @@ class TorchPolicy(Policy):
        with torch.no_grad():
            seq_lens = torch.ones(len(obs_batch), dtype=torch.int32)
            input_dict = self._lazy_tensor_dict({
-               SampleBatch.CUR_OBS: obs_batch,
+               SampleBatch.CUR_OBS: np.asarray(obs_batch),
                "is_training": False,
            })
            if prev_action_batch is not None:
-               input_dict[SampleBatch.PREV_ACTIONS] = prev_action_batch
+               input_dict[SampleBatch.PREV_ACTIONS] = \
+                   np.asarray(prev_action_batch)
            if prev_reward_batch is not None:
-               input_dict[SampleBatch.PREV_REWARDS] = prev_reward_batch
+               input_dict[SampleBatch.PREV_REWARDS] = \
+                   np.asarray(prev_reward_batch)
            state_batches = [
-               self._convert_to_tensor(s) for s in (state_batches or [])
+               convert_to_torch_tensor(s) for s in (state_batches or [])
            ]

            if self.action_sampler_fn:
@@ -411,17 +413,9 @@ class TorchPolicy(Policy):

    def _lazy_tensor_dict(self, postprocessed_batch):
        train_batch = UsageTrackingDict(postprocessed_batch)
-       train_batch.set_get_interceptor(self._convert_to_tensor)
+       train_batch.set_get_interceptor(convert_to_torch_tensor)
        return train_batch

-   def _convert_to_tensor(self, arr):
-       if torch.is_tensor(arr):
-           return arr.to(self.device)
-       tensor = torch.from_numpy(np.asarray(arr))
-       if tensor.dtype == torch.double:
-           tensor = tensor.float()
-       return tensor.to(self.device)
-
    @override(Policy)
    def export_model(self, export_dir):
        """TODO(sven): implement for torch.
@@ -251,6 +251,8 @@ def check_compute_single_action(trainer,

    Args:
        trainer (Trainer): The Trainer object to test.
+       include_state (bool): Whether to include the initial state of the
+           Policy's Model in the `compute_action` call.
        include_prev_action_reward (bool): Whether to include the prev-action
            and -reward in the `compute_action` call.
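A hedged call sketch matching the documented arguments; it assumes an already-built Trainer and that the helper is importable from RLlib's test utilities:

# Sketch: exercise a trainer with both the recurrent state and the
# prev-action/reward inputs, as the updated docstring describes.
from ray.rllib.utils.test_utils import check_compute_single_action  # import path assumed

check_compute_single_action(
    trainer,                           # an already-built Trainer (assumed to exist)
    include_state=True,                # pass the model's initial RNN state
    include_prev_action_reward=True)   # pass prev. action and reward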
@@ -266,17 +268,16 @@ def check_compute_single_action(trainer,
    action_space = pol.action_space

    for what in [pol, trainer]:
+       print("what={}".format(what))
        method_to_test = trainer.compute_action if what is trainer else \
            pol.compute_single_action

        for explore in [True, False]:
+           print("explore={}".format(explore))
            for full_fetch in ([False, True] if what is trainer else [False]):
+               print("full-fetch={}".format(full_fetch))
                call_kwargs = {}
                if what is trainer:
                    call_kwargs["full_fetch"] = full_fetch
                else:
                    call_kwargs["clip_actions"] = True

                obs = np.clip(obs_space.sample(), -1.0, 1.0)
                state_in = None
@@ -1,5 +1,6 @@
 import numpy as np

+from ray.rllib.models.repeated_values import RepeatedValues
 from ray.rllib.utils import try_import_tree
 from ray.rllib.utils.framework import try_import_torch
@@ -123,8 +124,14 @@ def convert_to_torch_tensor(stats, device=None):
    """

    def mapping(item):
+       # Already torch tensor -> make sure it's on right device.
        if torch.is_tensor(item):
            return item if device is None else item.to(device)
+       # Special handling of "Repeated" values.
+       elif isinstance(item, RepeatedValues):
+           return RepeatedValues(
+               tree.map_structure(mapping, item.values),
+               item.lengths, item.max_len)
        tensor = torch.from_numpy(np.asarray(item))
        # Floatify all float64 tensors.
        if tensor.dtype == torch.double:
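For illustration, a hedged usage sketch of convert_to_torch_tensor on a nested batch; the signature is taken from the hunk above, the structure contents are invented, and the nested mapping plus float64-to-float32 cast follow the body shown there:

# Sketch only: convert a nested numpy structure to torch tensors.
import numpy as np
from ray.rllib.utils.torch_ops import convert_to_torch_tensor

batch = {
    "obs": np.zeros((4, 8), dtype=np.float64),        # expected to be cast to float32
    "prev_rewards": np.zeros((4, ), dtype=np.float32),
}
tensors = convert_to_torch_tensor(batch, device=None)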
@@ -17,6 +17,11 @@ class UsageTrackingDict(dict):
    def set_get_interceptor(self, fn):
        self.get_interceptor = fn

+   def copy(self):
+       copy = UsageTrackingDict(**dict.copy(self))
+       copy.set_get_interceptor(self.get_interceptor)
+       return copy
+
    def __getitem__(self, key):
        self.accessed_keys.add(key)
        value = dict.__getitem__(self, key)
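Finally, a minimal standalone sketch (not the RLlib class itself) of the get-interceptor pattern whose copy() support is added above:

# Illustrative re-implementation of the pattern, independent of RLlib.
class InterceptingDict(dict):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.accessed_keys = set()
        self.get_interceptor = None

    def set_get_interceptor(self, fn):
        self.get_interceptor = fn

    def copy(self):
        # Preserve the interceptor on copies, mirroring the change above.
        new = InterceptingDict(dict.copy(self))
        new.set_get_interceptor(self.get_interceptor)
        return new

    def __getitem__(self, key):
        self.accessed_keys.add(key)
        value = dict.__getitem__(self, key)
        if self.get_interceptor is not None:
            value = self.get_interceptor(value)
        return value


d = InterceptingDict(a=1)
d.set_get_interceptor(lambda v: v * 10)
assert d.copy()["a"] == 10  # the copy keeps intercepting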