ray/rllib/evaluation/tests/test_postprocessing.py
Balaji Veeramani 7f1bacc7dc
[CI] Format Python code with Black (#21975)
See #21316 and #21311 for the motivation behind these changes.
2022-01-29 18:41:57 -08:00

180 lines
6.7 KiB
Python

import numpy as np
import unittest
import ray
from ray.rllib.evaluation.postprocessing import adjust_nstep, discount_cumsum
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.test_utils import check
class TestPostprocessing(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
ray.init()
@classmethod
def tearDownClass(cls) -> None:
ray.shutdown()
def test_n_step_3(self):
"""Tests, whether n-step adjustments of trajectories work."""
# n-step = 3
gamma = 0.9
obs = [1, 2, 3, 4, 5, 6, 7]
actions = ["ac1", "ac2", "ac1", "ac1", "ac1", "ac2", "ac1"]
rewards = [10.0, 0.0, 100.0, 100.0, 100.0, 100.0, 100.0]
dones = [0, 0, 0, 0, 0, 0, 1]
next_obs = [2, 3, 4, 5, 6, 7, 8]
batch = SampleBatch(
{
SampleBatch.OBS: obs,
SampleBatch.ACTIONS: actions,
SampleBatch.REWARDS: rewards,
SampleBatch.DONES: dones,
SampleBatch.NEXT_OBS: next_obs,
}
)
adjust_nstep(3, gamma, batch)
check(batch[SampleBatch.OBS], [1, 2, 3, 4, 5, 6, 7])
check(
batch[SampleBatch.ACTIONS],
["ac1", "ac2", "ac1", "ac1", "ac1", "ac2", "ac1"],
)
check(batch[SampleBatch.NEXT_OBS], [4, 5, 6, 7, 8, 8, 8])
check(batch[SampleBatch.DONES], [0, 0, 0, 0, 1, 1, 1])
check(
batch[SampleBatch.REWARDS], [91.0, 171.0, 271.0, 271.0, 271.0, 190.0, 100.0]
)
def test_n_step_4(self):
"""Tests, whether n-step adjustments of trajectories work."""
# n-step = 4
gamma = 0.99
obs = np.arange(0, 7)
actions = np.random.randint(-1, 3, size=(7,))
check_actions = actions.copy()
rewards = [10.0, 0.0, 100.0, 50.0, 60.0, 10.0, 100.0]
dones = [False, False, False, False, False, False, True]
next_obs = np.arange(1, 8)
batch = SampleBatch(
{
SampleBatch.OBS: obs,
SampleBatch.ACTIONS: actions,
SampleBatch.REWARDS: rewards,
SampleBatch.DONES: dones,
SampleBatch.NEXT_OBS: next_obs,
}
)
adjust_nstep(4, gamma, batch)
check(batch[SampleBatch.OBS], [0, 1, 2, 3, 4, 5, 6])
check(batch[SampleBatch.ACTIONS], check_actions)
check(batch[SampleBatch.NEXT_OBS], [4, 5, 6, 7, 7, 7, 7])
check(batch[SampleBatch.DONES], [False, False, False, True, True, True, True])
check(
batch[SampleBatch.REWARDS],
[
discount_cumsum(np.array(rewards[0:4]), gamma)[0],
discount_cumsum(np.array(rewards[1:5]), gamma)[0],
discount_cumsum(np.array(rewards[2:6]), gamma)[0],
discount_cumsum(np.array(rewards[3:7]), gamma)[0],
discount_cumsum(np.array(rewards[4:]), gamma)[0],
discount_cumsum(np.array(rewards[5:]), gamma)[0],
discount_cumsum(np.array(rewards[6:]), gamma)[0],
],
)
def test_n_step_malformed_dones(self):
# Test bad input (trajectory has dones in middle).
# Re-use same batch, but change dones.
gamma = 1.0
obs = np.arange(0, 7)
actions = np.random.randint(-1, 3, size=(7,))
rewards = [10.0, 0.0, 100.0, 50.0, 60.0, 10.0, 100.0]
next_obs = np.arange(1, 8)
batch = SampleBatch(
{
SampleBatch.OBS: obs,
SampleBatch.ACTIONS: actions,
SampleBatch.REWARDS: rewards,
SampleBatch.DONES: [False, False, True, False, False, False, True],
SampleBatch.NEXT_OBS: next_obs,
}
)
self.assertRaisesRegex(
AssertionError,
"Unexpected done in middle",
lambda: adjust_nstep(5, gamma, batch),
)
def test_n_step_very_short_trajectory(self):
"""Tests, whether n-step also works for very small trajectories."""
gamma = 1.0
obs = np.arange(0, 2)
actions = np.random.randint(-100, 300, size=(2,))
check_actions = actions.copy()
rewards = [10.0, 100.0]
next_obs = np.arange(1, 3)
batch = SampleBatch(
{
SampleBatch.OBS: obs,
SampleBatch.ACTIONS: actions,
SampleBatch.REWARDS: rewards,
SampleBatch.DONES: [False, False],
SampleBatch.NEXT_OBS: next_obs,
}
)
adjust_nstep(3, gamma, batch)
check(batch[SampleBatch.OBS], [0, 1])
check(batch[SampleBatch.ACTIONS], check_actions)
check(batch[SampleBatch.DONES], [False, False])
check(batch[SampleBatch.REWARDS], [10.0 + gamma * 100.0, 100.0])
check(batch[SampleBatch.NEXT_OBS], [2, 2])
def test_n_step_from_same_obs_source_array(self):
"""Tests, whether n-step also works on a shared obs/new-obs array."""
gamma = 0.99
# The underlying observation data. Both obs and next_obs will
# be references into that same np.array.
underlying_obs = np.arange(0, 8)
obs = underlying_obs[:7]
next_obs = underlying_obs[1:]
actions = np.random.randint(-1, 3, size=(7,))
check_actions = actions.copy()
rewards = [10.0, 0.0, 100.0, 50.0, 60.0, 10.0, 100.0]
dones = [False, False, False, False, False, False, True]
batch = SampleBatch(
{
SampleBatch.OBS: obs,
SampleBatch.ACTIONS: actions,
SampleBatch.REWARDS: rewards,
SampleBatch.DONES: dones,
SampleBatch.NEXT_OBS: next_obs,
}
)
adjust_nstep(4, gamma, batch)
check(batch[SampleBatch.OBS], [0, 1, 2, 3, 4, 5, 6])
check(batch[SampleBatch.ACTIONS], check_actions)
check(batch[SampleBatch.NEXT_OBS], [4, 5, 6, 7, 7, 7, 7])
check(batch[SampleBatch.DONES], [False, False, False, True, True, True, True])
check(
batch[SampleBatch.REWARDS],
[
discount_cumsum(np.array(rewards[0:4]), gamma)[0],
discount_cumsum(np.array(rewards[1:5]), gamma)[0],
discount_cumsum(np.array(rewards[2:6]), gamma)[0],
discount_cumsum(np.array(rewards[3:7]), gamma)[0],
discount_cumsum(np.array(rewards[4:]), gamma)[0],
discount_cumsum(np.array(rewards[5:]), gamma)[0],
discount_cumsum(np.array(rewards[6:]), gamma)[0],
],
)
if __name__ == "__main__":
import pytest
import sys
sys.exit(pytest.main(["-v", __file__]))