mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
241 lines
8.8 KiB
Python
241 lines
8.8 KiB
Python
import unittest
|
|
|
|
from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch
|
|
from ray.rllib.utils.test_utils import check_same_batch
|
|
|
|
|
|
class TestMultiAgentBatch(unittest.TestCase):
|
|
def test_timeslices_non_overlapping_experiences(self):
|
|
"""Tests if timeslices works as expected on a MultiAgentBatch
|
|
consisting of two non-overlapping SampleBatches.
|
|
"""
|
|
|
|
def _generate_data(agent_idx):
|
|
batch = SampleBatch(
|
|
{
|
|
SampleBatch.T: [0, 1],
|
|
SampleBatch.EPS_ID: 2 * [agent_idx],
|
|
SampleBatch.AGENT_INDEX: 2 * [agent_idx],
|
|
SampleBatch.SEQ_LENS: [2],
|
|
}
|
|
)
|
|
return batch
|
|
|
|
policy_batches = {str(idx): _generate_data(idx) for idx in (range(2))}
|
|
ma_batch = MultiAgentBatch(policy_batches, 4)
|
|
sliced_ma_batches = ma_batch.timeslices(1)
|
|
|
|
[
|
|
check_same_batch(i, j)
|
|
for i, j in zip(
|
|
sliced_ma_batches,
|
|
[
|
|
MultiAgentBatch(
|
|
{
|
|
"0": SampleBatch(
|
|
{
|
|
SampleBatch.T: [0],
|
|
SampleBatch.EPS_ID: [0],
|
|
SampleBatch.AGENT_INDEX: [0],
|
|
SampleBatch.SEQ_LENS: [1],
|
|
}
|
|
)
|
|
},
|
|
1,
|
|
),
|
|
MultiAgentBatch(
|
|
{
|
|
"0": SampleBatch(
|
|
{
|
|
SampleBatch.T: [1],
|
|
SampleBatch.EPS_ID: [0],
|
|
SampleBatch.AGENT_INDEX: [0],
|
|
SampleBatch.SEQ_LENS: [1],
|
|
}
|
|
)
|
|
},
|
|
1,
|
|
),
|
|
MultiAgentBatch(
|
|
{
|
|
"1": SampleBatch(
|
|
{
|
|
SampleBatch.T: [0],
|
|
SampleBatch.EPS_ID: [1],
|
|
SampleBatch.AGENT_INDEX: [1],
|
|
SampleBatch.SEQ_LENS: [1],
|
|
}
|
|
)
|
|
},
|
|
1,
|
|
),
|
|
MultiAgentBatch(
|
|
{
|
|
"1": SampleBatch(
|
|
{
|
|
SampleBatch.T: [1],
|
|
SampleBatch.EPS_ID: [1],
|
|
SampleBatch.AGENT_INDEX: [1],
|
|
SampleBatch.SEQ_LENS: [1],
|
|
}
|
|
)
|
|
},
|
|
1,
|
|
),
|
|
],
|
|
)
|
|
]
|
|
|
|
def test_timeslices_partially_overlapping_experiences(self):
|
|
"""Tests if timeslices works as expected on a MultiAgentBatch
|
|
consisting of two partially overlapping SampleBatches.
|
|
"""
|
|
|
|
def _generate_data(agent_idx, t_start):
|
|
batch = SampleBatch(
|
|
{
|
|
SampleBatch.T: [t_start, t_start + 1],
|
|
SampleBatch.EPS_ID: [0, 0],
|
|
SampleBatch.AGENT_INDEX: 2 * [agent_idx],
|
|
SampleBatch.SEQ_LENS: [2],
|
|
}
|
|
)
|
|
return batch
|
|
|
|
policy_batches = {str(idx): _generate_data(idx, idx) for idx in (range(2))}
|
|
ma_batch = MultiAgentBatch(policy_batches, 4)
|
|
sliced_ma_batches = ma_batch.timeslices(1)
|
|
|
|
[
|
|
check_same_batch(i, j)
|
|
for i, j in zip(
|
|
sliced_ma_batches,
|
|
[
|
|
MultiAgentBatch(
|
|
{
|
|
"0": SampleBatch(
|
|
{
|
|
SampleBatch.T: [0],
|
|
SampleBatch.EPS_ID: [0],
|
|
SampleBatch.AGENT_INDEX: [0],
|
|
SampleBatch.SEQ_LENS: [1],
|
|
}
|
|
)
|
|
},
|
|
1,
|
|
),
|
|
MultiAgentBatch(
|
|
{
|
|
"0": SampleBatch(
|
|
{
|
|
SampleBatch.T: [1],
|
|
SampleBatch.EPS_ID: [0],
|
|
SampleBatch.AGENT_INDEX: [0],
|
|
SampleBatch.SEQ_LENS: [1],
|
|
}
|
|
),
|
|
"1": SampleBatch(
|
|
{
|
|
SampleBatch.T: [1],
|
|
SampleBatch.EPS_ID: [0],
|
|
SampleBatch.AGENT_INDEX: [1],
|
|
SampleBatch.SEQ_LENS: [1],
|
|
}
|
|
),
|
|
},
|
|
1,
|
|
),
|
|
MultiAgentBatch(
|
|
{
|
|
"1": SampleBatch(
|
|
{
|
|
SampleBatch.T: [2],
|
|
SampleBatch.EPS_ID: [0],
|
|
SampleBatch.AGENT_INDEX: [1],
|
|
SampleBatch.SEQ_LENS: [1],
|
|
}
|
|
)
|
|
},
|
|
1,
|
|
),
|
|
],
|
|
)
|
|
]
|
|
|
|
def test_timeslices_fully_overlapping_experiences(self):
|
|
"""Tests if timeslices works as expected on a MultiAgentBatch
|
|
consisting of two fully overlapping SampleBatches.
|
|
"""
|
|
|
|
def _generate_data(agent_idx):
|
|
batch = SampleBatch(
|
|
{
|
|
SampleBatch.T: [0, 1],
|
|
SampleBatch.EPS_ID: [0, 0],
|
|
SampleBatch.AGENT_INDEX: 2 * [agent_idx],
|
|
SampleBatch.SEQ_LENS: [2],
|
|
}
|
|
)
|
|
return batch
|
|
|
|
policy_batches = {str(idx): _generate_data(idx) for idx in (range(2))}
|
|
ma_batch = MultiAgentBatch(policy_batches, 4)
|
|
sliced_ma_batches = ma_batch.timeslices(1)
|
|
|
|
[
|
|
check_same_batch(i, j)
|
|
for i, j in zip(
|
|
sliced_ma_batches,
|
|
[
|
|
MultiAgentBatch(
|
|
{
|
|
"0": SampleBatch(
|
|
{
|
|
SampleBatch.T: [0],
|
|
SampleBatch.EPS_ID: [0],
|
|
SampleBatch.AGENT_INDEX: [0],
|
|
SampleBatch.SEQ_LENS: [1],
|
|
}
|
|
),
|
|
"1": SampleBatch(
|
|
{
|
|
SampleBatch.T: [0],
|
|
SampleBatch.EPS_ID: [0],
|
|
SampleBatch.AGENT_INDEX: [1],
|
|
SampleBatch.SEQ_LENS: [1],
|
|
}
|
|
),
|
|
},
|
|
1,
|
|
),
|
|
MultiAgentBatch(
|
|
{
|
|
"0": SampleBatch(
|
|
{
|
|
SampleBatch.T: [1],
|
|
SampleBatch.EPS_ID: [0],
|
|
SampleBatch.AGENT_INDEX: [0],
|
|
SampleBatch.SEQ_LENS: [1],
|
|
}
|
|
),
|
|
"1": SampleBatch(
|
|
{
|
|
SampleBatch.T: [1],
|
|
SampleBatch.EPS_ID: [0],
|
|
SampleBatch.AGENT_INDEX: [1],
|
|
SampleBatch.SEQ_LENS: [1],
|
|
}
|
|
),
|
|
},
|
|
1,
|
|
),
|
|
],
|
|
)
|
|
]
|
|
|
|
|
|
if __name__ == "__main__":
|
|
import pytest
|
|
import sys
|
|
|
|
sys.exit(pytest.main(["-v", __file__]))
|