mirror of
https://github.com/vale981/ray
synced 2025-03-06 18:41:40 -05:00
242 lines
8.8 KiB
Python
242 lines
8.8 KiB
Python
![]() |
import unittest
|
||
|
|
||
|
from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch
|
||
|
from ray.rllib.utils.test_utils import check_same_batch
|
||
|
|
||
|
|
||
|
class TestMultiAgentBatch(unittest.TestCase):
    """Tests for ``MultiAgentBatch.timeslices`` using a slice size of 1.

    Every test builds two per-policy ``SampleBatch`` objects (policy ids
    ``"0"`` and ``"1"``), wraps them in a ``MultiAgentBatch``, slices it and
    compares each resulting slice against a hand-written expectation.
    """

    @staticmethod
    def _two_step_batch(agent_idx, eps_id, t_start=0):
        """Return a two-row ``SampleBatch`` used as test input.

        Args:
            agent_idx: Value stored in both rows of ``AGENT_INDEX``.
            eps_id: Value stored in both rows of ``EPS_ID``.
            t_start: Timestep of the first row; the second row is
                ``t_start + 1``.
        """
        return SampleBatch(
            {
                SampleBatch.T: [t_start, t_start + 1],
                SampleBatch.EPS_ID: 2 * [eps_id],
                SampleBatch.AGENT_INDEX: 2 * [agent_idx],
                SampleBatch.SEQ_LENS: [2],
            }
        )

    @staticmethod
    def _one_step_batch(t, eps_id, agent_idx):
        """Return a single-row ``SampleBatch`` used as an expected slice part.

        Args:
            t: The single value stored under ``T``.
            eps_id: The single value stored under ``EPS_ID``.
            agent_idx: The single value stored under ``AGENT_INDEX``.
        """
        return SampleBatch(
            {
                SampleBatch.T: [t],
                SampleBatch.EPS_ID: [eps_id],
                SampleBatch.AGENT_INDEX: [agent_idx],
                SampleBatch.SEQ_LENS: [1],
            }
        )

    def _assert_timeslices(self, policy_batches, expected_slices):
        """Slice ``policy_batches`` with ``timeslices(1)`` and verify the result.

        Args:
            policy_batches: Mapping of policy id to the input ``SampleBatch``.
            expected_slices: One dict per expected slice, mapping policy id
                to the single-row ``SampleBatch`` expected in that slice.
        """
        # NOTE(review): the second positional argument is presumably the
        # env-step count of the batch; 4 is the value these tests always used.
        ma_batch = MultiAgentBatch(policy_batches, 4)
        sliced_ma_batches = ma_batch.timeslices(1)

        # zip() stops at the shorter input, so without this check a result
        # with missing or surplus slices would pass unnoticed.
        self.assertEqual(len(sliced_ma_batches), len(expected_slices))

        for actual, expected_policy_batches in zip(
            sliced_ma_batches, expected_slices
        ):
            check_same_batch(actual, MultiAgentBatch(expected_policy_batches, 1))

    def test_timeslices_non_overlapping_experiences(self):
        """Tests if timeslices works as expected on a MultiAgentBatch
        consisting of two non-overlapping SampleBatches.
        """
        # Each policy gets its own episode id, so no rows share (eps_id, t).
        policy_batches = {
            str(idx): self._two_step_batch(agent_idx=idx, eps_id=idx)
            for idx in range(2)
        }
        expected_slices = [
            {"0": self._one_step_batch(t=0, eps_id=0, agent_idx=0)},
            {"0": self._one_step_batch(t=1, eps_id=0, agent_idx=0)},
            {"1": self._one_step_batch(t=0, eps_id=1, agent_idx=1)},
            {"1": self._one_step_batch(t=1, eps_id=1, agent_idx=1)},
        ]
        self._assert_timeslices(policy_batches, expected_slices)

    def test_timeslices_partially_overlapping_experiences(self):
        """Tests if timeslices works as expected on a MultiAgentBatch
        consisting of two partially overlapping SampleBatches.
        """
        # Same episode id; policy "1" starts one timestep later, so only
        # t == 1 is shared between both policies.
        policy_batches = {
            str(idx): self._two_step_batch(agent_idx=idx, eps_id=0, t_start=idx)
            for idx in range(2)
        }
        expected_slices = [
            {"0": self._one_step_batch(t=0, eps_id=0, agent_idx=0)},
            {
                "0": self._one_step_batch(t=1, eps_id=0, agent_idx=0),
                "1": self._one_step_batch(t=1, eps_id=0, agent_idx=1),
            },
            {"1": self._one_step_batch(t=2, eps_id=0, agent_idx=1)},
        ]
        self._assert_timeslices(policy_batches, expected_slices)

    def test_timeslices_fully_overlapping_experiences(self):
        """Tests if timeslices works as expected on a MultiAgentBatch
        consisting of two fully overlapping SampleBatches.
        """
        # Same episode id and same timesteps for both policies.
        policy_batches = {
            str(idx): self._two_step_batch(agent_idx=idx, eps_id=0)
            for idx in range(2)
        }
        expected_slices = [
            {
                "0": self._one_step_batch(t=0, eps_id=0, agent_idx=0),
                "1": self._one_step_batch(t=0, eps_id=0, agent_idx=1),
            },
            {
                "0": self._one_step_batch(t=1, eps_id=0, agent_idx=0),
                "1": self._one_step_batch(t=1, eps_id=0, agent_idx=1),
            },
        ]
        self._assert_timeslices(policy_batches, expected_slices)
|
||
|
|
||
|
|
||
|
if __name__ == "__main__":
    # Script entry point: run this file's tests verbosely through pytest and
    # use pytest's return value as the process exit status.
    import sys

    import pytest

    exit_status = pytest.main(["-v", __file__])
    sys.exit(exit_status)
|