# Mirror of https://github.com/vale981/ray
# Synced 2025-03-11 13:46:40 -04:00
# 346 lines, 13 KiB, Python
import gym
|
|
import numpy as np
|
|
import unittest
|
|
import ray
|
|
import math
|
|
|
|
from ray.rllib.policy.sample_batch import SampleBatch
|
|
from ray.rllib.policy.view_requirement import ViewRequirement
|
|
from ray.rllib.utils.test_utils import check
|
|
from ray.rllib.evaluation.collectors.agent_collector import AgentCollector
|
|
|
|
|
|
class TestAgentCollector(unittest.TestCase):
    """Tests for the batches ``AgentCollector`` builds from view requirements.

    Each test feeds random 4-dimensional observations into a collector and
    then checks the columns returned by ``build_for_inference`` and/or
    ``build_for_training`` for one particular combination of ``shift`` and
    ``batch_repeat_value``.
    """

    @classmethod
    def setUpClass(cls) -> None:
        """Start Ray once before the tests of this class run."""
        ray.init()

    @classmethod
    def tearDownClass(cls) -> None:
        """Shut Ray down after the tests of this class have run."""
        ray.shutdown()

    def _simulate_env_steps(self, ac, n_steps=1):
        """Feed one initial observation plus ``n_steps`` transitions into ``ac``.

        Args:
            ac: The collector to fill; must offer ``add_init_obs`` and
                ``add_action_reward_next_obs``.
            n_steps: Number of transitions added after the initial observation.

        Returns:
            A list with the ``n_steps + 1`` random observations (each of
            shape ``(4,)``) in the order they were handed to ``ac``.
        """
        first_obs = np.random.rand(4)
        obses = [first_obs]
        ac.add_init_obs(
            episode_id=0,
            agent_index=1,
            env_id=0,
            t=-1,
            init_obs=first_obs,
        )

        for t in range(n_steps):
            next_obs = np.random.rand(4)
            obses.append(next_obs)
            ac.add_action_reward_next_obs(
                {SampleBatch.NEXT_OBS: next_obs, SampleBatch.T: t}
            )

        return obses

    def test_inference_vs_training_batch(self):
        """Test whether build_for_inference and build_for_training return the same
        batch when they have to."""
        ctx_len = 5
        n_steps = 100
        obs_space = gym.spaces.Box(-np.ones(4), np.ones(4))
        view_reqs = {
            SampleBatch.T: ViewRequirement(SampleBatch.T),
            SampleBatch.OBS: ViewRequirement("obs", space=obs_space),
            # the window ends at shift 0, i.e. it includes the current obs
            "prev_obses": ViewRequirement("obs", shift=f"-{ctx_len - 1}:0"),
        }
        obses = np.random.rand(n_steps, 4)

        for training_mode in (False, True):
            ac = AgentCollector(
                view_reqs=view_reqs,
                is_policy_recurrent=True,
                max_seq_len=20,  # default max_seq_len in lstm
                is_training=training_mode,
            )
            # sliding window with the last ctx_len obses (current one included)
            obses_ctx = []
            for t, obs in enumerate(obses):
                if t > 0:
                    # e.g. next_state = env.step()
                    ac.add_action_reward_next_obs(
                        {SampleBatch.NEXT_OBS: obs, SampleBatch.T: t - 1}
                    )
                    # drop the oldest obs and add the newest one at the end
                    obses_ctx = obses_ctx[1:] + [obs]
                else:
                    # e.g. state = env.reset()
                    ac.add_init_obs(
                        episode_id=0,
                        agent_index=1,
                        env_id=0,
                        t=-1,
                        init_obs=obs,
                    )
                    # before any step was taken the window is the first obs
                    # repeated ctx_len times
                    obses_ctx = [obs] * ctx_len

                eval_batch = ac.build_for_inference()
                # batch size should always be one
                self.assertEqual(eval_batch.count, 1)
                # shape of prev_obses should be (1, ctx_len, 4)
                self.assertEqual(eval_batch["prev_obses"].shape, (1, ctx_len, 4))
                # obs should always be the last time step obs added
                check(eval_batch["obs"], obs[None])
                # prev_obses should always be the last ctx_len obses added
                # (including the current time step)
                check(eval_batch["prev_obses"], np.stack(obses_ctx, 0)[None])

            # in inference mode the buffer length at the end should be just
            # ctx_len, otherwise it should be n_steps + ctx_len - 1
            expected_len = n_steps + ctx_len - 1 if training_mode else ctx_len
            check(len(ac.buffers[SampleBatch.OBS][0]), expected_len)

        # `ac` is the collector of the last loop iteration (training mode).
        self.assertTrue(ac.training, "Training mode should be True.")
        train_batch = ac.build_for_training(view_reqs)
        self.assertEqual(
            len(train_batch["seq_lens"]), math.ceil(n_steps / ac.max_seq_len)
        )
        self.assertEqual(train_batch["prev_obses"].shape, (n_steps - 1, ctx_len, 4))
        self.assertEqual(train_batch[SampleBatch.OBS].shape, (n_steps - 1, 4))

    def test_inference_respects_causality(self):
        """Check that inference rejects, or leaves out, views of future steps."""
        obs_space = gym.spaces.Box(-np.ones(4), np.ones(4))
        view_reqs = {
            SampleBatch.T: ViewRequirement(SampleBatch.T),
            SampleBatch.OBS: ViewRequirement("obs", space=obs_space),
            "future_obs": ViewRequirement("obs", shift=1),
            "past_obs": ViewRequirement("obs", shift=-1),
        }
        ac = AgentCollector(view_reqs=view_reqs, is_policy_recurrent=True)

        self._simulate_env_steps(ac, n_steps=10)

        # build_for_training should return all keys; asserting key by key
        # makes a failure name the missing column
        train_batch = ac.build_for_training(view_reqs)
        for key in view_reqs:
            self.assertIn(key, train_batch.keys())

        # should error out since future_obs has used_for_compute_actions=True but
        # depends on future
        with self.assertRaises(ValueError):
            ac.build_for_inference()

        # NOTE(review): this swaps the entry in the very dict that was passed
        # to the collector; the test relies on the collector seeing the change.
        view_reqs["future_obs"] = ViewRequirement(
            "obs", shift=1, used_for_compute_actions=False
        )
        eval_batch = ac.build_for_inference()
        # every view that is used for computing actions must be present ...
        for key, view_req in view_reqs.items():
            if view_req.used_for_compute_actions:
                self.assertIn(key, eval_batch.keys())
        # ... and since future_obs should not be used in inference, it should
        # not be in the batch
        self.assertNotIn("future_obs", eval_batch.keys())

    def test_slice_with_repeat_value_1(self):
        """Check a ``-ctx_len:-1`` slice view with the default repeat value."""
        obs_space = gym.spaces.Box(-np.ones(4), np.ones(4))
        ctx_len = 5
        view_reqs = {
            SampleBatch.T: ViewRequirement(SampleBatch.T),
            SampleBatch.OBS: ViewRequirement("obs", space=obs_space),
            "prev_obses": ViewRequirement("obs", shift=f"-{ctx_len}:-1"),
        }
        ac = AgentCollector(view_reqs=view_reqs, is_policy_recurrent=True)

        obses = self._simulate_env_steps(ac, n_steps=10)

        sample_batch = ac.build_for_training(view_reqs)
        # exclude the last one since these are the next_obses
        expected_obses = np.stack(obses[:-1])

        check(expected_obses, sample_batch[SampleBatch.OBS])

        for t in range(10):
            window = sample_batch["prev_obses"][t]
            if t >= ctx_len:
                # no padding: the window is the ctx_len obses right before t
                check(window, expected_obses[t - ctx_len : t])
                continue

            # with padding: the first ctx_len - t slots repeat the first obs
            n_pad = ctx_len - t
            for offset in range(n_pad):
                check(window[offset], expected_obses[0])
            # the remaining slots (none at t == 0) hold the obses before t
            if t > 0:
                check(window[n_pad:], expected_obses[:t])

    def test_slice_with_repeat_value_larger_1(self):
        """Check a ``-ctx_len:-1`` slice view with ``batch_repeat_value=ctx_len``."""
        obs_space = gym.spaces.Box(-np.ones(4), np.ones(4))
        ctx_len = 5
        view_reqs = {
            SampleBatch.T: ViewRequirement(SampleBatch.T),
            SampleBatch.OBS: ViewRequirement("obs", space=obs_space),
            "prev_obses": ViewRequirement(
                "obs", shift=f"-{ctx_len}:-1", batch_repeat_value=ctx_len
            ),
        }
        ac = AgentCollector(view_reqs=view_reqs, is_policy_recurrent=True)
        obses = self._simulate_env_steps(ac, n_steps=10)

        sample_batch = ac.build_for_training(view_reqs)
        # exclude the last one since these are the next_obses
        expected_obses = np.stack(obses[:-1])

        check(expected_obses, sample_batch[SampleBatch.OBS])
        prev_obses = sample_batch["prev_obses"]
        self.assertEqual(prev_obses.shape, (2, ctx_len, 4))

        # the first prev_obses should be just the first obses repeated ctx_len times
        check(prev_obses[0], np.stack([expected_obses[0]] * ctx_len))
        # the second prev_obses should be ctx_len slice of obses started at index 0
        check(prev_obses[1], expected_obses[:ctx_len])

    def test_shift_by_one_with_repeat_value_larger_1(self):
        """Check a ``shift=-1`` view with ``batch_repeat_value=ctx_len``."""
        obs_space = gym.spaces.Box(-np.ones(4), np.ones(4))
        ctx_len = 5
        view_reqs = {
            SampleBatch.T: ViewRequirement(SampleBatch.T),
            SampleBatch.OBS: ViewRequirement("obs", space=obs_space),
            "prev_obses": ViewRequirement("obs", shift=-1, batch_repeat_value=ctx_len),
        }
        ac = AgentCollector(view_reqs=view_reqs, is_policy_recurrent=True)

        obses = self._simulate_env_steps(ac, n_steps=10)

        sample_batch = ac.build_for_training(view_reqs)
        # exclude the last one since these are the next_obses
        expected_obses = np.stack(obses[:-1])

        prev_obses = sample_batch["prev_obses"]
        self.assertEqual(prev_obses.shape, (2, 4))

        # should be the same as padding
        check(prev_obses[0], expected_obses[0])

        # should be the same as index ctx_len - 1
        check(prev_obses[1], expected_obses[ctx_len - 1])

    def test_shift_by_one_with_repeat_1(self):
        """Check a ``shift=-1`` view with the default repeat value."""
        obs_space = gym.spaces.Box(-np.ones(4), np.ones(4))
        view_reqs = {
            SampleBatch.T: ViewRequirement(SampleBatch.T),
            SampleBatch.OBS: ViewRequirement("obs", space=obs_space),
            "prev_obses": ViewRequirement("obs", shift=-1),
        }
        ac = AgentCollector(view_reqs=view_reqs, is_policy_recurrent=True)

        obses = self._simulate_env_steps(ac, n_steps=10)

        sample_batch = ac.build_for_training(view_reqs)
        # exclude the last one since these are the next_obses
        expected_obses = np.stack(obses[:-1])
        prev_obses = sample_batch["prev_obses"]

        # check the padding
        check(prev_obses[0], expected_obses[0])

        # check the data
        check(prev_obses[1:], expected_obses[:-1])

    def test_shift_positive_one_with_repeat_1(self):
        """Check a ``shift=1`` view with the default repeat value."""
        obs_space = gym.spaces.Box(-np.ones(4), np.ones(4))
        view_reqs = {
            SampleBatch.T: ViewRequirement(SampleBatch.T),
            SampleBatch.OBS: ViewRequirement("obs", space=obs_space),
            SampleBatch.NEXT_OBS: ViewRequirement("obs", shift=1),
        }
        ac = AgentCollector(view_reqs=view_reqs, is_policy_recurrent=True)

        obses = self._simulate_env_steps(ac, n_steps=10)
        sample_batch = ac.build_for_training(view_reqs)
        # the next obses are all added obses except the very first one
        check(sample_batch[SampleBatch.NEXT_OBS], np.stack(obses[1:]))

    def test_shift_positive_one_with_repeat_larger_1(self):
        """Check a ``shift=1`` view with ``batch_repeat_value=ctx_len``."""
        obs_space = gym.spaces.Box(-np.ones(4), np.ones(4))
        ctx_len = 5
        view_reqs = {
            SampleBatch.T: ViewRequirement(SampleBatch.T),
            SampleBatch.OBS: ViewRequirement("obs", space=obs_space),
            SampleBatch.NEXT_OBS: ViewRequirement(
                "obs", shift=1, batch_repeat_value=ctx_len
            ),
        }
        ac = AgentCollector(view_reqs=view_reqs, is_policy_recurrent=True)

        obses = self._simulate_env_steps(ac, n_steps=10)

        sample_batch = ac.build_for_training(view_reqs)
        expected_obses = np.stack(obses)

        next_obses = sample_batch[SampleBatch.NEXT_OBS]
        self.assertEqual(next_obses.shape, (2, 4))

        # next_obs at index = 0 should be equal to obs at index = 1
        check(next_obses[0], expected_obses[1])

        # next_obs at index = 1 should be the next obs of time step ctx_len,
        # which is obs at index = ctx_len + 1
        check(next_obses[1], expected_obses[ctx_len + 1])

    def test_slice_with_array(self):
        """Check a view whose shift is given as the list ``[-3, -1]``."""
        obs_space = gym.spaces.Box(-np.ones(4), np.ones(4))
        view_reqs = {
            SampleBatch.T: ViewRequirement(SampleBatch.T),
            SampleBatch.OBS: ViewRequirement("obs", space=obs_space),
            "prev_obses": ViewRequirement("obs", shift=[-3, -1]),
        }

        ac = AgentCollector(view_reqs=view_reqs, is_policy_recurrent=True)

        obses = self._simulate_env_steps(ac, n_steps=10)

        sample_batch = ac.build_for_training(view_reqs)
        # exclude the last one since these are the next_obses
        expected_obses = np.stack(obses[:-1])

        prev_obses = sample_batch["prev_obses"]
        self.assertEqual(prev_obses.shape, (10, 2, 4))

        # check if the last time step is correct
        check(prev_obses[-1], expected_obses[-4:-1:2])

        # check if the padding in the beginning is correct
        check(prev_obses[0], np.stack([expected_obses[0]] * 2))

    def test_view_requirement_with_shfit_step(self):
        """Check a slice view with a step, ``shift="-5:-1:2"``."""
        obs_space = gym.spaces.Box(-np.ones(4), np.ones(4))
        view_reqs = {
            SampleBatch.T: ViewRequirement(SampleBatch.T),
            SampleBatch.OBS: ViewRequirement("obs", space=obs_space),
            "prev_obses": ViewRequirement("obs", shift="-5:-1:2"),  # [-5, -3, -1]
        }

        ac = AgentCollector(view_reqs=view_reqs, is_policy_recurrent=True)

        obses = self._simulate_env_steps(ac, n_steps=10)

        sample_batch = ac.build_for_training(view_reqs)

        # exclude the last one since these are the next_obses
        expected_obses = np.stack(obses[:-1])

        prev_obses = sample_batch["prev_obses"]
        self.assertEqual(prev_obses.shape, (10, 3, 4))

        # check if the last time step is correct
        check(prev_obses[-1], expected_obses[-6:-1:2])

        # check if the padding in the beginning is correct
        check(prev_obses[0], np.stack([expected_obses[0]] * 3))
|
|
|
|
if __name__ == "__main__":
    # Run this file's tests verbosely through pytest when executed as a
    # script, and hand pytest's result to the shell as the exit status.
    import sys

    import pytest

    exit_code = pytest.main(["-v", __file__])
    sys.exit(exit_code)