# ray/rllib/policy/tests/test_multi_agent_batch.py
#
# 241 lines
# 8.8 KiB
# Python

import unittest
from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch
from ray.rllib.utils.test_utils import check_same_batch
class TestMultiAgentBatch(unittest.TestCase):
    """Tests for `MultiAgentBatch.timeslices` called with a slice size of 1.

    Every test builds one two-row SampleBatch per policy id ("0" and "1"),
    wraps both in a MultiAgentBatch, calls `timeslices(1)` and compares the
    returned slices with hand-written expected MultiAgentBatches.
    """

    @staticmethod
    def _single_step_batch(t, eps_id, agent_idx):
        """Build the expected SampleBatch for one agent at one timestep.

        Args:
            t: Value stored (as a one-element list) under `SampleBatch.T`.
            eps_id: Value stored under `SampleBatch.EPS_ID`.
            agent_idx: Value stored under `SampleBatch.AGENT_INDEX`.

        Returns:
            A SampleBatch whose columns each hold a single element and whose
            `SampleBatch.SEQ_LENS` is `[1]`.
        """
        return SampleBatch(
            {
                SampleBatch.T: [t],
                SampleBatch.EPS_ID: [eps_id],
                SampleBatch.AGENT_INDEX: [agent_idx],
                SampleBatch.SEQ_LENS: [1],
            }
        )

    def _assert_same_slices(self, actual_slices, expected_slices):
        """Assert that two sequences of batches match element by element.

        Args:
            actual_slices: Batches returned by the code under test.
            expected_slices: Hand-written batches, in the expected order.
        """
        # `zip` stops at the shorter input, so compare the lengths first;
        # otherwise missing or surplus slices would go unnoticed.
        self.assertEqual(len(actual_slices), len(expected_slices))
        for actual, expected in zip(actual_slices, expected_slices):
            check_same_batch(actual, expected)

    def test_timeslices_non_overlapping_experiences(self):
        """Tests if timeslices works as expected on a MultiAgentBatch
        consisting of two non-overlapping SampleBatches.
        """

        def _generate_data(agent_idx):
            # Both rows share T values with the other agent, but each agent
            # gets its own EPS_ID (equal to its index).
            return SampleBatch(
                {
                    SampleBatch.T: [0, 1],
                    SampleBatch.EPS_ID: 2 * [agent_idx],
                    SampleBatch.AGENT_INDEX: 2 * [agent_idx],
                    SampleBatch.SEQ_LENS: [2],
                }
            )

        policy_batches = {str(idx): _generate_data(idx) for idx in range(2)}
        ma_batch = MultiAgentBatch(policy_batches, 4)
        sliced_ma_batches = ma_batch.timeslices(1)

        step = self._single_step_batch
        # Expected: one slice per (episode, timestep), each holding one policy.
        expected = [
            MultiAgentBatch({"0": step(0, 0, 0)}, 1),
            MultiAgentBatch({"0": step(1, 0, 0)}, 1),
            MultiAgentBatch({"1": step(0, 1, 1)}, 1),
            MultiAgentBatch({"1": step(1, 1, 1)}, 1),
        ]
        self._assert_same_slices(sliced_ma_batches, expected)

    def test_timeslices_partially_overlapping_experiences(self):
        """Tests if timeslices works as expected on a MultiAgentBatch
        consisting of two partially overlapping SampleBatches.
        """

        def _generate_data(agent_idx, t_start):
            # Same EPS_ID for both agents; the T ranges are shifted by
            # `t_start`, so only one timestep is shared between the agents.
            return SampleBatch(
                {
                    SampleBatch.T: [t_start, t_start + 1],
                    SampleBatch.EPS_ID: [0, 0],
                    SampleBatch.AGENT_INDEX: 2 * [agent_idx],
                    SampleBatch.SEQ_LENS: [2],
                }
            )

        policy_batches = {
            str(idx): _generate_data(idx, idx) for idx in range(2)
        }
        ma_batch = MultiAgentBatch(policy_batches, 4)
        sliced_ma_batches = ma_batch.timeslices(1)

        step = self._single_step_batch
        # Expected: the shared timestep (T == 1) carries both policies.
        expected = [
            MultiAgentBatch({"0": step(0, 0, 0)}, 1),
            MultiAgentBatch({"0": step(1, 0, 0), "1": step(1, 0, 1)}, 1),
            MultiAgentBatch({"1": step(2, 0, 1)}, 1),
        ]
        self._assert_same_slices(sliced_ma_batches, expected)

    def test_timeslices_fully_overlapping_experiences(self):
        """Tests if timeslices works as expected on a MultiAgentBatch
        consisting of two fully overlapping SampleBatches.
        """

        def _generate_data(agent_idx):
            # Same EPS_ID and same T values for both agents.
            return SampleBatch(
                {
                    SampleBatch.T: [0, 1],
                    SampleBatch.EPS_ID: [0, 0],
                    SampleBatch.AGENT_INDEX: 2 * [agent_idx],
                    SampleBatch.SEQ_LENS: [2],
                }
            )

        policy_batches = {str(idx): _generate_data(idx) for idx in range(2)}
        ma_batch = MultiAgentBatch(policy_batches, 4)
        sliced_ma_batches = ma_batch.timeslices(1)

        step = self._single_step_batch
        # Expected: every slice carries both policies.
        expected = [
            MultiAgentBatch({"0": step(0, 0, 0), "1": step(0, 0, 1)}, 1),
            MultiAgentBatch({"0": step(1, 0, 0), "1": step(1, 0, 1)}, 1),
        ]
        self._assert_same_slices(sliced_ma_batches, expected)
if __name__ == "__main__":
    # Script entry point: run this file's tests verbosely through pytest and
    # use pytest's return value as the process exit status.
    import sys

    import pytest

    exit_status = pytest.main(["-v", __file__])
    sys.exit(exit_status)