# ray/rllib/connectors/tests/test_agent.py


import gym
import numpy as np
import unittest
from ray.rllib.algorithms.ppo.ppo import PPO, PPOConfig
from ray.rllib.connectors.agent.clip_reward import ClipRewardAgentConnector
from ray.rllib.connectors.agent.lambdas import FlattenDataAgentConnector
from ray.rllib.connectors.agent.obs_preproc import ObsPreprocessorConnector
from ray.rllib.connectors.agent.pipeline import AgentConnectorPipeline
from ray.rllib.connectors.agent.state_buffer import StateBufferConnector
from ray.rllib.connectors.agent.view_requirement import ViewRequirementAgentConnector
from ray.rllib.connectors.connector import ConnectorContext, get_connector
from ray.rllib.policy.view_requirement import ViewRequirement
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.test_utils import check
from ray.rllib.utils.typing import (
ActionConnectorDataType,
AgentConnectorDataType,
AgentConnectorsOutput,
)
class TestAgentConnector(unittest.TestCase):
def test_connector_pipeline(self):
ctx = ConnectorContext()
connectors = [ClipRewardAgentConnector(ctx, False, 1.0)]
pipeline = AgentConnectorPipeline(ctx, connectors)
name, params = pipeline.to_config()
restored = get_connector(ctx, name, params)
self.assertTrue(isinstance(restored, AgentConnectorPipeline))
self.assertTrue(isinstance(restored.connectors[0], ClipRewardAgentConnector))
def test_obs_preprocessor_connector(self):
obs_space = gym.spaces.Dict(
{
"a": gym.spaces.Box(low=0, high=1, shape=(1,)),
"b": gym.spaces.Tuple(
[gym.spaces.Discrete(2), gym.spaces.MultiDiscrete(nvec=[2, 3])]
),
}
)
ctx = ConnectorContext(config={}, observation_space=obs_space)
c = ObsPreprocessorConnector(ctx)
name, params = c.to_config()
restored = get_connector(ctx, name, params)
self.assertTrue(isinstance(restored, ObsPreprocessorConnector))
obs = obs_space.sample()
# Fake deterministic data.
obs["a"][0] = 0.5
obs["b"] = (1, np.array([0, 2]))
d = AgentConnectorDataType(
0,
1,
{
SampleBatch.OBS: obs,
},
)
preprocessed = c([d])
# obs is completely flattened.
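# Expected layout (Dict keys in sorted order, discrete components one-hot
# encoded): "a" -> [0.5], "b" Discrete(2)=1 -> [0, 1],
# MultiDiscrete([2, 3])=[0, 2] -> [1, 0] + [0, 0, 1].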
self.assertTrue(
(preprocessed[0].data[SampleBatch.OBS] == [0.5, 0, 1, 1, 0, 0, 0, 1]).all()
)
def test_clip_reward_connector(self):
ctx = ConnectorContext()
c = ClipRewardAgentConnector(ctx, limit=2.0)
name, params = c.to_config()
self.assertEqual(name, "ClipRewardAgentConnector")
self.assertAlmostEqual(params["limit"], 2.0)
restored = get_connector(ctx, name, params)
self.assertTrue(isinstance(restored, ClipRewardAgentConnector))
d = AgentConnectorDataType(
0,
1,
{
SampleBatch.REWARDS: 5.8,
},
)
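# A reward of 5.8 exceeds the configured limit and should be clipped to 2.0.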
clipped = restored([d])
self.assertEqual(len(clipped), 1)
self.assertEqual(clipped[0].data[SampleBatch.REWARDS], 2.0)
def test_flatten_data_connector(self):
ctx = ConnectorContext()
c = FlattenDataAgentConnector(ctx)
name, params = c.to_config()
restored = get_connector(ctx, name, params)
self.assertTrue(isinstance(restored, FlattenDataAgentConnector))
sample_batch = {
SampleBatch.NEXT_OBS: {
"sensor1": [[1, 1], [2, 2]],
"sensor2": 8.8,
},
SampleBatch.REWARDS: 5.8,
SampleBatch.ACTIONS: [[1, 1], [2]],
SampleBatch.INFOS: {"random": "info"},
}
d = AgentConnectorDataType(
0,
1,
# FlattenDataAgentConnector does NOT touch the for_training dict,
# so simply pass None here.
AgentConnectorsOutput(None, sample_batch),
)
flattened = c([d])
self.assertEqual(len(flattened), 1)
batch = flattened[0].data.for_action
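# The NEXT_OBS dict is flattened into a single array: "sensor1" values
# first, then "sensor2", i.e. [1, 1, 2, 2, 8.8].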
self.assertTrue((batch[SampleBatch.NEXT_OBS] == [1, 1, 2, 2, 8.8]).all())
self.assertEqual(batch[SampleBatch.REWARDS][0], 5.8)
# Not flattened.
self.assertEqual(len(batch[SampleBatch.ACTIONS]), 2)
self.assertEqual(batch[SampleBatch.INFOS]["random"], "info")
def test_state_buffer_connector(self):
ctx = ConnectorContext(
action_space=gym.spaces.Box(low=-1.0, high=1.0, shape=(3,)),
)
c = StateBufferConnector(ctx)
# Reset without any buffered data should do nothing.
c.reset(env_id=0)
d = AgentConnectorDataType(
0,
1,
{
SampleBatch.NEXT_OBS: {
"sensor1": [[1, 1], [2, 2]],
"sensor2": 8.8,
},
},
)
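# No policy output has been buffered yet, so the connector should fill in
# dummy zero actions matching the Box(3,)-shaped action space.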
with_buffered = c([d])
self.assertEqual(len(with_buffered), 1)
self.assertTrue((with_buffered[0].data[SampleBatch.ACTIONS] == [0, 0, 0]).all())
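# Buffer a policy output tuple of (actions, state_outs, extra_fetches);
# the next pass should return the buffered action [1, 2, 3].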
c.on_policy_output(ActionConnectorDataType(0, 1, ([1, 2, 3], [], {})))
with_buffered = c([d])
self.assertEqual(len(with_buffered), 1)
self.assertEqual(with_buffered[0].data[SampleBatch.ACTIONS], [1, 2, 3])
class TestViewRequirementConnector(unittest.TestCase):
def test_vr_connector_respects_training_or_inference_vr_flags(self):
"""Tests that the connector respects the flags within view_requirements (i.e.
used_for_training, used_for_compute_actions) under different is_training modes.
is_training = False (inference mode)
the returned data is a SampleBatch that can be used to run corresponding
policy.
is_training = True (training mode)
the returned data is the input dict itself, which the policy collector in
env_runner will use to construct the episode, and a SampleBatch that can be
used to run corresponding policy.
"""
view_rq_dict = {
"both": ViewRequirement(
"obs", used_for_training=True, used_for_compute_actions=True
),
"only_inference": ViewRequirement(
"obs", used_for_training=False, used_for_compute_actions=True
),
"none": ViewRequirement(
"obs", used_for_training=False, used_for_compute_actions=False
),
"only_training": ViewRequirement(
"obs", used_for_training=True, used_for_compute_actions=False
),
}
obs_arr = np.array([0, 1, 2, 3])
agent_data = {SampleBatch.NEXT_OBS: obs_arr}
data = AgentConnectorDataType(0, 1, agent_data)
config = PPOConfig().to_dict()
ctx = ConnectorContext(
view_requirements=view_rq_dict,
config=config,
is_policy_recurrent=True,
)
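# Only keys with used_for_compute_actions=True should appear in the
# for_action batch; "seq_lens" is added because the policy is recurrent.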
for_action_expected = SampleBatch(
{
"both": obs_arr[None],
"only_inference": obs_arr[None],
"seq_lens": np.array([1]),
}
)
for_training_expected_list = [
# is_training = False
None,
# is_training = True
agent_data,
]
for is_training in [True, False]:
c = ViewRequirementAgentConnector(ctx)
c.is_training(is_training)
processed = c([data])
for_training = processed[0].data.for_training
for_training_expected = for_training_expected_list[is_training]
for_action = processed[0].data.for_action
print("-" * 30)
print(f"is_training = {is_training}")
print("for action:")
print(for_action)
print("for training:")
print(for_training)
check(for_training, for_training_expected)
check(for_action, for_action_expected)
def test_vr_connector_shift_by_one(self):
"""Test that the ViewRequirementConnector can handle shift by one correctly and
can ignore future referencing view_requirements to respect causality"""
view_rq_dict = {
"state": ViewRequirement("obs"),
"next_state": ViewRequirement(
"obs", shift=1, used_for_compute_actions=False
),
"prev_state": ViewRequirement("obs", shift=-1),
}
obs_arrs = np.arange(10)[:, None] + 1
config = PPOConfig().to_dict()
ctx = ConnectorContext(
view_requirements=view_rq_dict, config=config, is_policy_recurrent=True
)
c = ViewRequirementAgentConnector(ctx)
# keep a running list of observations
obs_list = []
for t, obs in enumerate(obs_arrs):
# t=0 is the next state of t=-1
data = AgentConnectorDataType(
0, 1, {SampleBatch.NEXT_OBS: obs, SampleBatch.T: t - 1}
)
processed = c([data]) # env.reset() for t == -1 else env.step()
for_action = processed[0].data.for_action
# add cur obs to the list
obs_list.append(obs)
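# At t=0 there is no previous observation, so "prev_state" should fall
# back to the current observation.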
if t == 0:
check(for_action["prev_state"], for_action["state"])
else:
# prev state should be equal to the prev time step obs
check(for_action["prev_state"], obs_list[-2][None])
def test_vr_connector_causal_slice(self):
"""Test that the ViewRequirementConnector can handle slice shifts correctly."""
view_rq_dict = {
"state": ViewRequirement("obs"),
# shift array should be [-2, -1, 0]
"prev_states": ViewRequirement("obs", shift="-2:0"),
# shift array should be [-4, -2, 0]
"prev_strided_states_even": ViewRequirement("obs", shift="-4:0:2"),
# shift array should be [-3, -1]
"prev_strided_states_odd": ViewRequirement("obs", shift="-3:0:2"),
}
obs_arrs = np.arange(10)[:, None] + 1
config = PPOConfig().to_dict()
ctx = ConnectorContext(
view_requirements=view_rq_dict, config=config, is_policy_recurrent=True
)
c = ViewRequirementAgentConnector(ctx)
# keep a queue of observations
obs_list = []
for t, obs in enumerate(obs_arrs):
# t=0 is the next state of t=-1
data = AgentConnectorDataType(
0, 1, {SampleBatch.NEXT_OBS: obs, SampleBatch.T: t - 1}
)
processed = c([data])
for_action = processed[0].data.for_action
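# At t=0, pad the reference buffer with copies of the initial obs,
# mirroring how the connector fills missing history at episode start.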
if t == 0:
obs_list.extend([obs for _ in range(5)])
else:
# remove the first obs and add the current obs to the end
obs_list.pop(0)
obs_list.append(obs)
# check state
check(for_action["state"], obs[None])
# check prev_states
check(
for_action["prev_states"],
np.stack(obs_list)[np.array([-3, -2, -1])][None],
)
# check prev_strided_states_even
check(
for_action["prev_strided_states_even"],
np.stack(obs_list)[np.array([-5, -3, -1])][None],
)
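# check prev_strided_states_odd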
check(
for_action["prev_strided_states_odd"],
np.stack(obs_list)[np.array([-4, -2])][None],
)
def test_vr_connector_with_multiple_buffers(self):
"""Test that the ViewRequirementConnector can handle slice shifts correctly
when it has multiple buffers to shift."""
context_len = 5
# This view requirement simulates the use-case of a decision transformer
# without reward-to-go.
view_rq_dict = {
# obs[t-context_len+1:t]
"context_obs": ViewRequirement("obs", shift=f"-{context_len-1}:0"),
# next_obs[t-context_len+1:t]
"context_next_obs": ViewRequirement(
"obs", shift=f"-{context_len}:1", used_for_compute_actions=False
),
# act[t-context_len+1:t]
"context_act": ViewRequirement(
SampleBatch.ACTIONS, shift=f"-{context_len-1}:-1"
),
}
obs_arrs = np.arange(10)[:, None] + 1
act_arrs = (np.arange(10)[:, None] + 1) * 100
n_steps = obs_arrs.shape[0]
config = PPOConfig().to_dict()
ctx = ConnectorContext(
view_requirements=view_rq_dict, config=config, is_policy_recurrent=True
)
c = ViewRequirementAgentConnector(ctx)
# keep queues of length context_len for observations and actions
obs_list, act_list = [], []
for t in range(n_steps):
# next state and action at time t-1 are the following
timestep_data = {
SampleBatch.NEXT_OBS: obs_arrs[t],
SampleBatch.ACTIONS: (
np.zeros_like(act_arrs[0]) if t == 0 else act_arrs[t - 1]
),
SampleBatch.T: t - 1,
}
data = AgentConnectorDataType(0, 1, timestep_data)
processed = c([data])
for_action = processed[0].data.for_action
if t == 0:
obs_list.extend([obs_arrs[0] for _ in range(context_len)])
act_list.extend(
[np.zeros_like(act_arrs[0]) for _ in range(context_len)]
)
else:
obs_list.pop(0)
act_list.pop(0)
obs_list.append(obs_arrs[t])
act_list.append(act_arrs[t - 1])
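# "context_next_obs" has used_for_compute_actions=False, so it must not
# show up in the for_action batch.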
self.assertTrue("context_next_obs" not in for_action)
check(for_action["context_obs"], np.stack(obs_list)[None])
check(for_action["context_act"], np.stack(act_list[:-1])[None])
def test_connector_pipeline_with_view_requirement(self):
"""A very minimal test that checks whether pipeline connectors work in a
simulation rollout."""
# TODO: make this test beefier and more comprehensive
config = (
PPOConfig()
.framework("torch")
.environment(env="CartPole-v0")
.rollouts(create_env_on_local_worker=True)
)
algo = PPO(config)
rollout_worker = algo.workers.local_worker()
policy = rollout_worker.get_policy()
env = rollout_worker.env
# create a connector context
ctx = ConnectorContext(
view_requirements=policy.view_requirements,
config=policy.config,
initial_states=policy.get_initial_state(),
is_policy_recurrent=policy.is_recurrent(),
observation_space=policy.observation_space,
action_space=policy.action_space,
)
# build chain of connectors
connectors = [
ObsPreprocessorConnector(ctx),
StateBufferConnector(ctx),
ViewRequirementAgentConnector(ctx),
]
agent_connector = AgentConnectorPipeline(ctx, connectors)
name, params = agent_connector.to_config()
restored = get_connector(ctx, name, params)
self.assertTrue(isinstance(restored, AgentConnectorPipeline))
for cidx, c in enumerate(connectors):
check(restored.connectors[cidx].to_config(), c.to_config())
# simulate a rollout
n_steps = 10
obs = env.reset()
env_out = AgentConnectorDataType(
0, 1, {SampleBatch.NEXT_OBS: obs, SampleBatch.T: -1}
)
agent_obs = agent_connector([env_out])[0]
t = 0
total_rewards = 0
while t < n_steps:
policy_output = policy.compute_actions_from_input_dict(
agent_obs.data.for_action
)
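# Feed the raw policy output back through the pipeline so that
# StateBufferConnector can buffer actions and states for the next step.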
agent_connector.on_policy_output(
ActionConnectorDataType(0, 1, policy_output)
)
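# compute_actions_from_input_dict() returns (actions, state_outs,
# extra_fetches); take the single action from the batched output.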
action = policy_output[0][0]
next_obs, rewards, dones, info = env.step(action)
env_out_dict = {
SampleBatch.NEXT_OBS: next_obs,
SampleBatch.REWARDS: rewards,
SampleBatch.DONES: dones,
SampleBatch.INFOS: info,
SampleBatch.ACTIONS: action,
SampleBatch.T: t,
# state_out is omitted here; the StateBufferConnector in the pipeline
# buffers the recurrent state between steps.
}
env_out = AgentConnectorDataType(0, 1, env_out_dict)
agent_obs = agent_connector([env_out])[0]
total_rewards += rewards
t += 1
print(total_rewards)
if __name__ == "__main__":
import sys
import pytest
sys.exit(pytest.main(["-v", __file__]))