ray/rllib/connectors/tests/test_agent.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

425 lines
16 KiB
Python
Raw Normal View History

import gym
import numpy as np
import unittest
import pytest
from ray.rllib.connectors.agent.clip_reward import ClipRewardAgentConnector
from ray.rllib.connectors.agent.lambdas import FlattenDataAgentConnector
from ray.rllib.connectors.agent.obs_preproc import ObsPreprocessorConnector
from ray.rllib.connectors.agent.pipeline import AgentConnectorPipeline
from ray.rllib.connectors.agent.state_buffer import StateBufferConnector
from ray.rllib.connectors.agent.view_requirement import ViewRequirementAgentConnector
from ray.rllib.connectors.connector import ConnectorContext, get_connector
from ray.rllib.policy.view_requirement import ViewRequirement
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.typing import (
ActionConnectorDataType,
AgentConnectorDataType,
AgentConnectorsOutput,
)
from ray.rllib.utils.test_utils import check
class TestAgentConnector(unittest.TestCase):
    """Unit tests for the individual agent connectors and their pipeline."""

    def test_connector_pipeline(self):
        """A pipeline round-trips through to_config() / get_connector()."""
        ctx = ConnectorContext()
        connectors = [ClipRewardAgentConnector(ctx, False, 1.0)]
        pipeline = AgentConnectorPipeline(ctx, connectors)

        name, params = pipeline.to_config()
        restored = get_connector(ctx, name, params)

        # assertIsInstance reports the actual type on failure, unlike
        # assertTrue(isinstance(...)).
        self.assertIsInstance(restored, AgentConnectorPipeline)
        self.assertIsInstance(restored.connectors[0], ClipRewardAgentConnector)

    def test_obs_preprocessor_connector(self):
        """A nested Dict/Tuple observation comes out as one flat vector."""
        obs_space = gym.spaces.Dict(
            {
                "a": gym.spaces.Box(low=0, high=1, shape=(1,)),
                "b": gym.spaces.Tuple(
                    [gym.spaces.Discrete(2), gym.spaces.MultiDiscrete(nvec=[2, 3])]
                ),
            }
        )
        ctx = ConnectorContext(config={}, observation_space=obs_space)

        c = ObsPreprocessorConnector(ctx)
        name, params = c.to_config()

        restored = get_connector(ctx, name, params)
        self.assertIsInstance(restored, ObsPreprocessorConnector)

        obs = obs_space.sample()
        # Fake deterministic data.
        obs["a"][0] = 0.5
        obs["b"] = (1, np.array([0, 2]))

        d = AgentConnectorDataType(
            0,
            1,
            {
                SampleBatch.OBS: obs,
            },
        )
        preprocessed = c([d])

        # obs is completely flattened.
        self.assertTrue(
            (preprocessed[0].data[SampleBatch.OBS] == [0.5, 0, 1, 1, 0, 0, 0, 1]).all()
        )

    def test_clip_reward_connector(self):
        """A reward above the configured limit is clipped down to the limit."""
        ctx = ConnectorContext()

        c = ClipRewardAgentConnector(ctx, limit=2.0)
        name, params = c.to_config()

        self.assertEqual(name, "ClipRewardAgentConnector")
        self.assertAlmostEqual(params["limit"], 2.0)

        restored = get_connector(ctx, name, params)
        self.assertIsInstance(restored, ClipRewardAgentConnector)

        d = AgentConnectorDataType(
            0,
            1,
            {
                SampleBatch.REWARDS: 5.8,
            },
        )
        # The restored connector (not the original) is exercised here.
        clipped = restored([d])

        self.assertEqual(len(clipped), 1)
        self.assertEqual(clipped[0].data[SampleBatch.REWARDS], 2.0)

    def test_flatten_data_connector(self):
        """NEXT_OBS is flattened; ACTIONS and INFOS are passed through."""
        ctx = ConnectorContext()

        c = FlattenDataAgentConnector(ctx)

        name, params = c.to_config()
        restored = get_connector(ctx, name, params)
        self.assertIsInstance(restored, FlattenDataAgentConnector)

        sample_batch = {
            SampleBatch.NEXT_OBS: {
                "sensor1": [[1, 1], [2, 2]],
                "sensor2": 8.8,
            },
            SampleBatch.REWARDS: 5.8,
            SampleBatch.ACTIONS: [[1, 1], [2]],
            SampleBatch.INFOS: {"random": "info"},
        }

        d = AgentConnectorDataType(
            0,
            1,
            # FlattenDataAgentConnector does NOT touch for_training dict,
            # so simply pass None here.
            AgentConnectorsOutput(None, sample_batch),
        )

        flattened = c([d])
        self.assertEqual(len(flattened), 1)

        batch = flattened[0].data.for_action
        self.assertTrue((batch[SampleBatch.NEXT_OBS] == [1, 1, 2, 2, 8.8]).all())
        self.assertEqual(batch[SampleBatch.REWARDS][0], 5.8)
        # Not flattened.
        self.assertEqual(len(batch[SampleBatch.ACTIONS]), 2)
        self.assertEqual(batch[SampleBatch.INFOS]["random"], "info")

    def test_state_buffer_connector(self):
        """Buffered actions start at zero and follow the last policy output."""
        ctx = ConnectorContext(
            action_space=gym.spaces.Box(low=-1.0, high=1.0, shape=(3,)),
        )
        c = StateBufferConnector(ctx)

        # Reset without any buffered data should do nothing.
        c.reset(env_id=0)

        d = AgentConnectorDataType(
            0,
            1,
            {
                SampleBatch.NEXT_OBS: {
                    "sensor1": [[1, 1], [2, 2]],
                    "sensor2": 8.8,
                },
            },
        )

        # No policy output has been recorded yet: expect an all-zero action
        # with the 3 entries of the action space declared above.
        with_buffered = c([d])
        self.assertEqual(len(with_buffered), 1)
        self.assertTrue((with_buffered[0].data[SampleBatch.ACTIONS] == [0, 0, 0]).all())

        # After feeding a policy output, the next call carries that action.
        c.on_policy_output(ActionConnectorDataType(0, 1, ([1, 2, 3], [], {})))

        with_buffered = c([d])
        self.assertEqual(len(with_buffered), 1)
        self.assertEqual(with_buffered[0].data[SampleBatch.ACTIONS], [1, 2, 3])

    def test_view_requirement_connector(self):
        """View-requirement output can be chained into the flatten connector."""
        # TODO: @kourosh remove this test when we have a better way to test
        view_requirements = {
            "obs": ViewRequirement(
                used_for_training=True, used_for_compute_actions=True
            ),
            "prev_actions": ViewRequirement(
                data_col="actions",
                shift=-1,
                used_for_training=True,
                used_for_compute_actions=True,
            ),
        }
        ctx = ConnectorContext(view_requirements=view_requirements)

        c = ViewRequirementAgentConnector(ctx)
        f = FlattenDataAgentConnector(ctx)

        d = AgentConnectorDataType(
            0,
            1,
            {
                SampleBatch.NEXT_OBS: {
                    "sensor1": [[1, 1], [2, 2]],
                    "sensor2": 8.8,
                },
                SampleBatch.ACTIONS: np.array(0),
            },
        )
        # ViewRequirementAgentConnector then FlattenAgentConnector.
        processed = f(c([d]))

        # assertIn prints the container on failure, unlike assertTrue(x in y).
        self.assertIn("obs", processed[0].data.for_action)
        self.assertIn("prev_actions", processed[0].data.for_action)
@pytest.mark.skip(reason="activate when view_requirement is fully implemented.")
class TestViewRequirementConnector(unittest.TestCase):
    """Tests for ViewRequirementAgentConnector (skipped; see the decorator)."""

    @staticmethod
    def _expected_window(history, t, shifts):
        """Build the zero-padded window of ``history`` expected at step ``t``.

        Stating the padding rule once avoids the per-step copy-pasted
        expectations that previously drifted out of sync.

        Args:
            history: Array of per-timestep values, timestep on axis 0.
            t: Index of the current timestep.
            shifts: Non-positive offsets relative to ``t`` (e.g. ``[-2, -1]``).

        Returns:
            ``history[t + s]`` stacked along a new leading axis for each shift
            ``s``; offsets that fall before timestep 0 are filled with zeros.
        """
        padding = np.zeros_like(history[0])
        return np.stack(
            [history[t + s] if t + s >= 0 else padding for s in shifts]
        )

    def test_vr_connector_respects_training_or_inference_vr_flags(self):
        """Tests that the connector respects the flags within view_requirements (i.e.
        used_for_training, used_for_compute_actions) under different is_training modes.
        For inference,
            the returned data should be state -> obs
        For training,
            the returned data should be the data itself. The higher level policy
            collector in env_runner will construct the proper data structure.
        """
        view_rq_dict = {
            "both": ViewRequirement(
                "obs", used_for_training=True, used_for_compute_actions=True
            ),
            "only_inference": ViewRequirement(
                "obs", used_for_training=False, used_for_compute_actions=True
            ),
            "none": ViewRequirement(
                "obs", used_for_training=False, used_for_compute_actions=False
            ),
            "only_training": ViewRequirement(
                "obs", used_for_training=True, used_for_compute_actions=False
            ),
        }

        obs_arr = np.array([0, 1, 2, 3])
        agent_data = dict(obs=obs_arr)
        data = AgentConnectorDataType(0, 1, agent_data)

        ctx = ConnectorContext(view_requirements=view_rq_dict)

        # TODO @jun What is the expected behavior of this test?
        # Expectations are keyed by the is_training flag (instead of indexing
        # a list with a bool) so the mapping is explicit.
        for_action_expected_by_mode = {
            False: SampleBatch({"both": obs_arr, "only_inference": obs_arr}),
            True: SampleBatch({"both": obs_arr, "only_inference": obs_arr}),
        }
        for_training_expected_by_mode = {
            False: None,
            True: agent_data,
        }

        for is_training in [True, False]:
            c = ViewRequirementAgentConnector(ctx)
            c.is_training(is_training)
            processed = c([data])

            for_training = processed[0].data.for_training
            for_action = processed[0].data.for_action

            print("-" * 30)
            print(f"is_training = {is_training}")
            print("for action:")
            print(for_action)
            print("for training:")
            print(for_training)

            # TODO @jun is for_training expected to always be equal to data?
            check(for_training, for_training_expected_by_mode[is_training])
            check(for_action, for_action_expected_by_mode[is_training])

    def test_vr_connector_shift_by_one(self):
        """Test that the ViewRequirementConnector can handle shift by one correctly and
        can ignore future referencing view_requirements to respect causality"""
        view_rq_dict = {
            "state": ViewRequirement("obs"),
            "next_state": ViewRequirement(
                "obs", shift=1, used_for_compute_actions=False
            ),
            "prev_state": ViewRequirement("obs", shift=-1),
        }

        obs_arrs = np.arange(10)[:, None] + 1
        ctx = ConnectorContext(view_requirements=view_rq_dict)

        for is_training in [True, False]:
            # Fresh connector per mode, matching the flags test in this class.
            # NOTE(review): assumes the connector buffers history across calls,
            # in which case a shared instance would start the second pass with
            # the first pass's observations instead of padding -- confirm.
            c = ViewRequirementAgentConnector(ctx)
            c.is_training(is_training)
            for i, obs_arr in enumerate(obs_arrs):
                data = AgentConnectorDataType(0, 1, dict(obs=obs_arr))
                processed = c([data])
                for_action = processed[0].data.for_action

                # Future-referencing columns must never reach the action input.
                self.assertNotIn("next_state", for_action)
                check(for_action["state"], obs_arrs[i])
                if i == 0:
                    # Nothing precedes the first step: expect zero padding.
                    check(for_action["prev_state"], np.array([0]))
                else:
                    check(for_action["prev_state"], obs_arrs[i - 1])

    def test_vr_connector_causal_slice(self):
        """Test that the ViewRequirementConnector can handle slice shifts correctly.
        This includes things like `-2:0:1`. `start:end:step` should be interpreted as
        np.arange(start, end, step). Both start and end have to be specified when using
        this format. If step is not specified it defaults to 1.
        """
        view_rq_dict = {
            "state": ViewRequirement("obs"),
            # shift array should be [-2, -1]
            "prev_states": ViewRequirement("obs", shift="-2:0"),
            # shift array should be [-4, -2]
            "prev_strided_states_even": ViewRequirement("obs", shift="-4:0:2"),
            # shift array should be [-3, -1]
            "prev_strided_states_odd": ViewRequirement("obs", shift="-3:0:2"),
        }
        # Offsets each column is expected to gather, relative to the current
        # step; these mirror the shift strings above via np.arange.
        expected_shifts = {
            "prev_states": np.arange(-2, 0),
            "prev_strided_states_even": np.arange(-4, 0, 2),
            "prev_strided_states_odd": np.arange(-3, 0, 2),
        }

        obs_arrs = np.arange(10)[:, None] + 1
        ctx = ConnectorContext(view_requirements=view_rq_dict)

        for is_training in [True, False]:
            # Fresh connector per mode, matching the flags test in this class.
            # NOTE(review): assumes the connector buffers history across calls
            # -- confirm.
            c = ViewRequirementAgentConnector(ctx)
            c.is_training(is_training)
            for i, obs_arr in enumerate(obs_arrs):
                data = AgentConnectorDataType(0, 1, dict(obs=obs_arr))
                processed = c([data])
                for_action = processed[0].data.for_action

                check(for_action["state"], obs_arrs[i])

                # One rule covers every step: take obs at i + shift, or zero
                # padding when that index is before the first step. This
                # replaces the former per-step ladder, which referenced a
                # non-existent "prev_state" key at i == 1 and verified the
                # "even" column twice (never the "odd" one) for i >= 4.
                for column, shifts in expected_shifts.items():
                    check(
                        for_action[column],
                        self._expected_window(obs_arrs, i, shifts),
                    )

    def test_vr_connector_with_multiple_buffers(self):
        """Test that the ViewRequirementConnector can handle slice shifts correctly
        when it has multiple buffers to shift."""
        context_len = 5
        # This view requirement simulates the use-case of a decision transformer
        # without reward-to-go.
        view_rq_dict = {
            # obs[t-context_len:t+1]
            "context_obs": ViewRequirement("obs", shift=f"{-context_len}:1"),
            # act[t-context_len:t]
            "context_act": ViewRequirement("act", shift=f"{-context_len}:0"),
        }

        obs_arrs = np.arange(10)[:, None] + 1
        act_arrs = np.arange(10)[:, None] * 100 + 1
        n_steps = obs_arrs.shape[0]
        ctx = ConnectorContext(view_requirements=view_rq_dict)

        # Offsets matching the two shift strings above.
        obs_shifts = np.arange(-context_len, 1)
        act_shifts = np.arange(-context_len, 0)

        for is_training in [True, False]:
            # Fresh connector per mode, matching the flags test in this class.
            # NOTE(review): assumes the connector buffers history across calls
            # -- confirm.
            c = ViewRequirementAgentConnector(ctx)
            c.is_training(is_training)
            for i in range(n_steps):
                data = AgentConnectorDataType(
                    0, 1, dict(obs=obs_arrs[i], act=act_arrs[i])
                )
                processed = c([data])
                for_action = processed[0].data.for_action

                # The previous padding, np.array([[0] * i]), had shape (1, i):
                # it could not be concatenated with the (k, 1) histories and
                # did not supply the (context_len - i) missing rows. The
                # helper pads one zero row per offset before step 0 and
                # reduces to a plain slice once i >= context_len.
                check(
                    for_action["context_obs"],
                    self._expected_window(obs_arrs, i, obs_shifts),
                )
                check(
                    for_action["context_act"],
                    self._expected_window(act_arrs, i, act_shifts),
                )
if __name__ == "__main__":
    import sys
    import pytest

    # Run this file's tests verbosely and hand pytest's status back to the shell.
    exit_code = pytest.main(["-v", __file__])
    sys.exit(exit_code)