mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
[RLlib] Issue 22625: MultiAgentBatch.timeslices()
does not behave as expected. (#22657)
This commit is contained in:
parent
4576f53fe3
commit
c0ade5f0b7
5 changed files with 335 additions and 7 deletions
|
@ -1344,6 +1344,13 @@ py_test(
|
|||
srcs = ["policy/tests/test_compute_log_likelihoods.py"]
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "policy/tests/test_multi_agent_batch",
|
||||
tags = ["team:ml", "policy"],
|
||||
size = "small",
|
||||
srcs = ["policy/tests/test_multi_agent_batch.py"]
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "policy/tests/test_policy",
|
||||
tags = ["team:ml", "policy"],
|
||||
|
|
|
@ -312,7 +312,7 @@ class SampleBatch(dict):
|
|||
def rows(self) -> Iterator[Dict[str, TensorType]]:
|
||||
"""Returns an iterator over data rows, i.e. dicts with column values.
|
||||
|
||||
Note that if `seq_lens` is set in self, we set it to [1] in the rows.
|
||||
Note that if `seq_lens` is set in self, we set it to 1 in the rows.
|
||||
|
||||
Yields:
|
||||
The column values of the row in this iteration.
|
||||
|
@ -325,13 +325,12 @@ class SampleBatch(dict):
|
|||
... })
|
||||
>>> for row in batch.rows():
|
||||
print(row)
|
||||
{"a": 1, "b": 4, "seq_lens": [1]}
|
||||
{"a": 2, "b": 5, "seq_lens": [1]}
|
||||
{"a": 3, "b": 6, "seq_lens": [1]}
|
||||
{"a": 1, "b": 4, "seq_lens": 1}
|
||||
{"a": 2, "b": 5, "seq_lens": 1}
|
||||
{"a": 3, "b": 6, "seq_lens": 1}
|
||||
"""
|
||||
|
||||
# Do we add seq_lens=[1] to each row?
|
||||
seq_lens = None if self.get(SampleBatch.SEQ_LENS) is None else np.array([1])
|
||||
seq_lens = None if self.get(SampleBatch.SEQ_LENS, 1) is None else 1
|
||||
|
||||
self_as_dict = {k: v for k, v in self.items()}
|
||||
|
||||
|
@ -1182,6 +1181,7 @@ class MultiAgentBatch:
|
|||
{k: v.build_and_reset() for k, v in cur_slice.items()}, cur_slice_size
|
||||
)
|
||||
cur_slice_size = 0
|
||||
cur_slice.clear()
|
||||
finished_slices.append(batch)
|
||||
|
||||
# For each unique env timestep.
|
||||
|
|
241
rllib/policy/tests/test_multi_agent_batch.py
Normal file
241
rllib/policy/tests/test_multi_agent_batch.py
Normal file
|
@ -0,0 +1,241 @@
|
|||
import unittest
|
||||
|
||||
from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch
|
||||
from ray.rllib.utils.test_utils import check_same_batch
|
||||
|
||||
|
||||
class TestMultiAgentBatch(unittest.TestCase):
|
||||
def test_timeslices_non_overlapping_experiences(self):
|
||||
"""Tests if timeslices works as expected on a MultiAgentBatch
|
||||
consisting of two non-overlapping SampleBatches.
|
||||
"""
|
||||
|
||||
def _generate_data(agent_idx):
|
||||
batch = SampleBatch(
|
||||
{
|
||||
SampleBatch.T: [0, 1],
|
||||
SampleBatch.EPS_ID: 2 * [agent_idx],
|
||||
SampleBatch.AGENT_INDEX: 2 * [agent_idx],
|
||||
SampleBatch.SEQ_LENS: [2],
|
||||
}
|
||||
)
|
||||
return batch
|
||||
|
||||
policy_batches = {str(idx): _generate_data(idx) for idx in (range(2))}
|
||||
ma_batch = MultiAgentBatch(policy_batches, 4)
|
||||
sliced_ma_batches = ma_batch.timeslices(1)
|
||||
|
||||
[
|
||||
check_same_batch(i, j)
|
||||
for i, j in zip(
|
||||
sliced_ma_batches,
|
||||
[
|
||||
MultiAgentBatch(
|
||||
{
|
||||
"0": SampleBatch(
|
||||
{
|
||||
SampleBatch.T: [0],
|
||||
SampleBatch.EPS_ID: [0],
|
||||
SampleBatch.AGENT_INDEX: [0],
|
||||
SampleBatch.SEQ_LENS: [1],
|
||||
}
|
||||
)
|
||||
},
|
||||
1,
|
||||
),
|
||||
MultiAgentBatch(
|
||||
{
|
||||
"0": SampleBatch(
|
||||
{
|
||||
SampleBatch.T: [1],
|
||||
SampleBatch.EPS_ID: [0],
|
||||
SampleBatch.AGENT_INDEX: [0],
|
||||
SampleBatch.SEQ_LENS: [1],
|
||||
}
|
||||
)
|
||||
},
|
||||
1,
|
||||
),
|
||||
MultiAgentBatch(
|
||||
{
|
||||
"1": SampleBatch(
|
||||
{
|
||||
SampleBatch.T: [0],
|
||||
SampleBatch.EPS_ID: [1],
|
||||
SampleBatch.AGENT_INDEX: [1],
|
||||
SampleBatch.SEQ_LENS: [1],
|
||||
}
|
||||
)
|
||||
},
|
||||
1,
|
||||
),
|
||||
MultiAgentBatch(
|
||||
{
|
||||
"1": SampleBatch(
|
||||
{
|
||||
SampleBatch.T: [1],
|
||||
SampleBatch.EPS_ID: [1],
|
||||
SampleBatch.AGENT_INDEX: [1],
|
||||
SampleBatch.SEQ_LENS: [1],
|
||||
}
|
||||
)
|
||||
},
|
||||
1,
|
||||
),
|
||||
],
|
||||
)
|
||||
]
|
||||
|
||||
def test_timeslices_partially_overlapping_experiences(self):
|
||||
"""Tests if timeslices works as expected on a MultiAgentBatch
|
||||
consisting of two partially overlapping SampleBatches.
|
||||
"""
|
||||
|
||||
def _generate_data(agent_idx, t_start):
|
||||
batch = SampleBatch(
|
||||
{
|
||||
SampleBatch.T: [t_start, t_start + 1],
|
||||
SampleBatch.EPS_ID: [0, 0],
|
||||
SampleBatch.AGENT_INDEX: 2 * [agent_idx],
|
||||
SampleBatch.SEQ_LENS: [2],
|
||||
}
|
||||
)
|
||||
return batch
|
||||
|
||||
policy_batches = {str(idx): _generate_data(idx, idx) for idx in (range(2))}
|
||||
ma_batch = MultiAgentBatch(policy_batches, 4)
|
||||
sliced_ma_batches = ma_batch.timeslices(1)
|
||||
|
||||
[
|
||||
check_same_batch(i, j)
|
||||
for i, j in zip(
|
||||
sliced_ma_batches,
|
||||
[
|
||||
MultiAgentBatch(
|
||||
{
|
||||
"0": SampleBatch(
|
||||
{
|
||||
SampleBatch.T: [0],
|
||||
SampleBatch.EPS_ID: [0],
|
||||
SampleBatch.AGENT_INDEX: [0],
|
||||
SampleBatch.SEQ_LENS: [1],
|
||||
}
|
||||
)
|
||||
},
|
||||
1,
|
||||
),
|
||||
MultiAgentBatch(
|
||||
{
|
||||
"0": SampleBatch(
|
||||
{
|
||||
SampleBatch.T: [1],
|
||||
SampleBatch.EPS_ID: [0],
|
||||
SampleBatch.AGENT_INDEX: [0],
|
||||
SampleBatch.SEQ_LENS: [1],
|
||||
}
|
||||
),
|
||||
"1": SampleBatch(
|
||||
{
|
||||
SampleBatch.T: [1],
|
||||
SampleBatch.EPS_ID: [0],
|
||||
SampleBatch.AGENT_INDEX: [1],
|
||||
SampleBatch.SEQ_LENS: [1],
|
||||
}
|
||||
),
|
||||
},
|
||||
1,
|
||||
),
|
||||
MultiAgentBatch(
|
||||
{
|
||||
"1": SampleBatch(
|
||||
{
|
||||
SampleBatch.T: [2],
|
||||
SampleBatch.EPS_ID: [0],
|
||||
SampleBatch.AGENT_INDEX: [1],
|
||||
SampleBatch.SEQ_LENS: [1],
|
||||
}
|
||||
)
|
||||
},
|
||||
1,
|
||||
),
|
||||
],
|
||||
)
|
||||
]
|
||||
|
||||
def test_timeslices_fully_overlapping_experiences(self):
|
||||
"""Tests if timeslices works as expected on a MultiAgentBatch
|
||||
consisting of two fully overlapping SampleBatches.
|
||||
"""
|
||||
|
||||
def _generate_data(agent_idx):
|
||||
batch = SampleBatch(
|
||||
{
|
||||
SampleBatch.T: [0, 1],
|
||||
SampleBatch.EPS_ID: [0, 0],
|
||||
SampleBatch.AGENT_INDEX: 2 * [agent_idx],
|
||||
SampleBatch.SEQ_LENS: [2],
|
||||
}
|
||||
)
|
||||
return batch
|
||||
|
||||
policy_batches = {str(idx): _generate_data(idx) for idx in (range(2))}
|
||||
ma_batch = MultiAgentBatch(policy_batches, 4)
|
||||
sliced_ma_batches = ma_batch.timeslices(1)
|
||||
|
||||
[
|
||||
check_same_batch(i, j)
|
||||
for i, j in zip(
|
||||
sliced_ma_batches,
|
||||
[
|
||||
MultiAgentBatch(
|
||||
{
|
||||
"0": SampleBatch(
|
||||
{
|
||||
SampleBatch.T: [0],
|
||||
SampleBatch.EPS_ID: [0],
|
||||
SampleBatch.AGENT_INDEX: [0],
|
||||
SampleBatch.SEQ_LENS: [1],
|
||||
}
|
||||
),
|
||||
"1": SampleBatch(
|
||||
{
|
||||
SampleBatch.T: [0],
|
||||
SampleBatch.EPS_ID: [0],
|
||||
SampleBatch.AGENT_INDEX: [1],
|
||||
SampleBatch.SEQ_LENS: [1],
|
||||
}
|
||||
),
|
||||
},
|
||||
1,
|
||||
),
|
||||
MultiAgentBatch(
|
||||
{
|
||||
"0": SampleBatch(
|
||||
{
|
||||
SampleBatch.T: [1],
|
||||
SampleBatch.EPS_ID: [0],
|
||||
SampleBatch.AGENT_INDEX: [0],
|
||||
SampleBatch.SEQ_LENS: [1],
|
||||
}
|
||||
),
|
||||
"1": SampleBatch(
|
||||
{
|
||||
SampleBatch.T: [1],
|
||||
SampleBatch.EPS_ID: [0],
|
||||
SampleBatch.AGENT_INDEX: [1],
|
||||
SampleBatch.SEQ_LENS: [1],
|
||||
}
|
||||
),
|
||||
},
|
||||
1,
|
||||
),
|
||||
],
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import pytest
|
||||
import sys
|
||||
|
||||
sys.exit(pytest.main(["-v", __file__]))
|
|
@ -136,7 +136,7 @@ class TestSampleBatch(unittest.TestCase):
|
|||
)
|
||||
check(
|
||||
next(s1.rows()),
|
||||
{"a": [1, 1], "b": {"c": [4, 4]}, SampleBatch.SEQ_LENS: [1]},
|
||||
{"a": [1, 1], "b": {"c": [4, 4]}, SampleBatch.SEQ_LENS: 1},
|
||||
)
|
||||
|
||||
def test_compression(self):
|
||||
|
|
|
@ -820,3 +820,83 @@ def run_learning_tests_from_yaml(
|
|||
}
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def check_same_batch(batch1, batch2) -> None:
|
||||
"""Check if both batches are (almost) identical.
|
||||
|
||||
For MultiAgentBatches, the step count and individual policy's
|
||||
SampleBatches are checked for identity. For SampleBatches, identity is
|
||||
checked as the almost numerical key-value-pair identity between batches
|
||||
with ray.rllib.utils.test_utils.check(). unroll_id is compared only if
|
||||
both batches have an unroll_id.
|
||||
|
||||
Args:
|
||||
batch1: Batch to compare against batch2
|
||||
batch2: Batch to compare against batch1
|
||||
"""
|
||||
# Avoids circular import
|
||||
from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch
|
||||
|
||||
assert type(batch1) == type(
|
||||
batch2
|
||||
), "Input batches are of different " "types {} and {}".format(
|
||||
str(type(batch1)), str(type(batch2))
|
||||
)
|
||||
|
||||
def check_sample_batches(_batch1, _batch2, _policy_id=None):
|
||||
unroll_id_1 = _batch1.get("unroll_id", None)
|
||||
unroll_id_2 = _batch2.get("unroll_id", None)
|
||||
# unroll IDs only have to fit if both batches have them
|
||||
if unroll_id_1 is not None and unroll_id_2 is not None:
|
||||
assert unroll_id_1 == unroll_id_2
|
||||
|
||||
batch1_keys = set()
|
||||
for k, v in _batch1.items():
|
||||
# unroll_id is compared above already
|
||||
if k == "unroll_id":
|
||||
continue
|
||||
check(v, _batch2[k])
|
||||
batch1_keys.add(k)
|
||||
|
||||
batch2_keys = set(_batch2.keys())
|
||||
# unroll_id is compared above already
|
||||
batch2_keys.discard("unroll_id")
|
||||
_difference = batch1_keys.symmetric_difference(batch2_keys)
|
||||
|
||||
# Cases where one batch has info and the other has not
|
||||
if _policy_id:
|
||||
assert not _difference, (
|
||||
"SampleBatches for policy with ID {} "
|
||||
"don't share information on the "
|
||||
"following information: \n{}"
|
||||
"".format(_policy_id, _difference)
|
||||
)
|
||||
else:
|
||||
assert not _difference, (
|
||||
"SampleBatches don't share information "
|
||||
"on the following information: \n{}"
|
||||
"".format(_difference)
|
||||
)
|
||||
|
||||
if type(batch1) == SampleBatch:
|
||||
check_sample_batches(batch1, batch2)
|
||||
elif type(batch1) == MultiAgentBatch:
|
||||
assert batch1.count == batch2.count
|
||||
batch1_ids = set()
|
||||
for policy_id, policy_batch in batch1.policy_batches.items():
|
||||
check_sample_batches(
|
||||
policy_batch, batch2.policy_batches[policy_id], policy_id
|
||||
)
|
||||
batch1_ids.add(policy_id)
|
||||
|
||||
# Case where one ma batch has info on a policy the other has not
|
||||
batch2_ids = set(batch2.policy_batches.keys())
|
||||
difference = batch1_ids.symmetric_difference(batch2_ids)
|
||||
assert (
|
||||
not difference
|
||||
), "MultiAgentBatches don't share the following" "information: \n{}.".format(
|
||||
difference
|
||||
)
|
||||
else:
|
||||
raise ValueError("Unsupported batch type " + str(type(batch1)))
|
||||
|
|
Loading…
Add table
Reference in a new issue