import unittest from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch from ray.rllib.utils.test_utils import check_same_batch class TestMultiAgentBatch(unittest.TestCase): def test_timeslices_non_overlapping_experiences(self): """Tests if timeslices works as expected on a MultiAgentBatch consisting of two non-overlapping SampleBatches. """ def _generate_data(agent_idx): batch = SampleBatch( { SampleBatch.T: [0, 1], SampleBatch.EPS_ID: 2 * [agent_idx], SampleBatch.AGENT_INDEX: 2 * [agent_idx], SampleBatch.SEQ_LENS: [2], } ) return batch policy_batches = {str(idx): _generate_data(idx) for idx in (range(2))} ma_batch = MultiAgentBatch(policy_batches, 4) sliced_ma_batches = ma_batch.timeslices(1) [ check_same_batch(i, j) for i, j in zip( sliced_ma_batches, [ MultiAgentBatch( { "0": SampleBatch( { SampleBatch.T: [0], SampleBatch.EPS_ID: [0], SampleBatch.AGENT_INDEX: [0], SampleBatch.SEQ_LENS: [1], } ) }, 1, ), MultiAgentBatch( { "0": SampleBatch( { SampleBatch.T: [1], SampleBatch.EPS_ID: [0], SampleBatch.AGENT_INDEX: [0], SampleBatch.SEQ_LENS: [1], } ) }, 1, ), MultiAgentBatch( { "1": SampleBatch( { SampleBatch.T: [0], SampleBatch.EPS_ID: [1], SampleBatch.AGENT_INDEX: [1], SampleBatch.SEQ_LENS: [1], } ) }, 1, ), MultiAgentBatch( { "1": SampleBatch( { SampleBatch.T: [1], SampleBatch.EPS_ID: [1], SampleBatch.AGENT_INDEX: [1], SampleBatch.SEQ_LENS: [1], } ) }, 1, ), ], ) ] def test_timeslices_partially_overlapping_experiences(self): """Tests if timeslices works as expected on a MultiAgentBatch consisting of two partially overlapping SampleBatches. """ def _generate_data(agent_idx, t_start): batch = SampleBatch( { SampleBatch.T: [t_start, t_start + 1], SampleBatch.EPS_ID: [0, 0], SampleBatch.AGENT_INDEX: 2 * [agent_idx], SampleBatch.SEQ_LENS: [2], } ) return batch policy_batches = {str(idx): _generate_data(idx, idx) for idx in (range(2))} ma_batch = MultiAgentBatch(policy_batches, 4) sliced_ma_batches = ma_batch.timeslices(1) [ check_same_batch(i, j) for i, j in zip( sliced_ma_batches, [ MultiAgentBatch( { "0": SampleBatch( { SampleBatch.T: [0], SampleBatch.EPS_ID: [0], SampleBatch.AGENT_INDEX: [0], SampleBatch.SEQ_LENS: [1], } ) }, 1, ), MultiAgentBatch( { "0": SampleBatch( { SampleBatch.T: [1], SampleBatch.EPS_ID: [0], SampleBatch.AGENT_INDEX: [0], SampleBatch.SEQ_LENS: [1], } ), "1": SampleBatch( { SampleBatch.T: [1], SampleBatch.EPS_ID: [0], SampleBatch.AGENT_INDEX: [1], SampleBatch.SEQ_LENS: [1], } ), }, 1, ), MultiAgentBatch( { "1": SampleBatch( { SampleBatch.T: [2], SampleBatch.EPS_ID: [0], SampleBatch.AGENT_INDEX: [1], SampleBatch.SEQ_LENS: [1], } ) }, 1, ), ], ) ] def test_timeslices_fully_overlapping_experiences(self): """Tests if timeslices works as expected on a MultiAgentBatch consisting of two fully overlapping SampleBatches. """ def _generate_data(agent_idx): batch = SampleBatch( { SampleBatch.T: [0, 1], SampleBatch.EPS_ID: [0, 0], SampleBatch.AGENT_INDEX: 2 * [agent_idx], SampleBatch.SEQ_LENS: [2], } ) return batch policy_batches = {str(idx): _generate_data(idx) for idx in (range(2))} ma_batch = MultiAgentBatch(policy_batches, 4) sliced_ma_batches = ma_batch.timeslices(1) [ check_same_batch(i, j) for i, j in zip( sliced_ma_batches, [ MultiAgentBatch( { "0": SampleBatch( { SampleBatch.T: [0], SampleBatch.EPS_ID: [0], SampleBatch.AGENT_INDEX: [0], SampleBatch.SEQ_LENS: [1], } ), "1": SampleBatch( { SampleBatch.T: [0], SampleBatch.EPS_ID: [0], SampleBatch.AGENT_INDEX: [1], SampleBatch.SEQ_LENS: [1], } ), }, 1, ), MultiAgentBatch( { "0": SampleBatch( { SampleBatch.T: [1], SampleBatch.EPS_ID: [0], SampleBatch.AGENT_INDEX: [0], SampleBatch.SEQ_LENS: [1], } ), "1": SampleBatch( { SampleBatch.T: [1], SampleBatch.EPS_ID: [0], SampleBatch.AGENT_INDEX: [1], SampleBatch.SEQ_LENS: [1], } ), }, 1, ), ], ) ] if __name__ == "__main__": import pytest import sys sys.exit(pytest.main(["-v", __file__]))