# ray/rllib/evaluation/tests/test_trajectory_view_api.py

import copy
import gym
from gym.spaces import Box, Discrete
import time
import unittest
import ray
import ray.rllib.agents.dqn as dqn
import ray.rllib.agents.ppo as ppo
from ray.rllib.examples.env.debug_counter_env import MultiAgentDebugCounterEnv
from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.examples.policy.episode_env_aware_policy import \
EpisodeEnvAwareLSTMPolicy
from ray.rllib.policy.rnn_sequencing import pad_batch_to_sequences_of_same_size
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.view_requirement import ViewRequirement
from ray.rllib.utils.test_utils import framework_iterator, check
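
# ViewRequirement quick reference (a sketch based only on the usage in
# this file): a view column is tied to an underlying data column plus a
# relative time offset, e.g.
#   ViewRequirement(SampleBatch.ACTIONS, shift=1)  # a "next action" view
# where the offset passed as `shift` shows up on the built requirement as
# `data_rel_pos` (see the assertions in the tests below).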


class TestTrajectoryViewAPI(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
ray.init()

    @classmethod
    def tearDownClass(cls) -> None:
        ray.shutdown()

    def test_traj_view_normal_case(self):
        """Tests whether Model and Policy return correct ViewRequirements."""
config = dqn.DEFAULT_CONFIG.copy()
config["num_envs_per_worker"] = 10
config["rollout_fragment_length"] = 4
for _ in framework_iterator(config):
trainer = dqn.DQNTrainer(
config,
env="ray.rllib.examples.env.debug_counter_env.DebugCounterEnv")
policy = trainer.get_policy()
view_req_model = policy.model.inference_view_requirements
view_req_policy = policy.view_requirements
assert len(view_req_model) == 1, view_req_model
assert len(view_req_policy) == 8, view_req_policy
for key in [
SampleBatch.OBS,
SampleBatch.ACTIONS,
SampleBatch.REWARDS,
SampleBatch.DONES,
SampleBatch.NEXT_OBS,
SampleBatch.EPS_ID,
SampleBatch.AGENT_INDEX,
"weights",
]:
assert key in view_req_policy
                # None of the view columns has a special underlying
                # data_col, except NEXT_OBS (which maps back to OBS).
if key != SampleBatch.NEXT_OBS:
assert view_req_policy[key].data_col is None
else:
assert view_req_policy[key].data_col == SampleBatch.OBS
assert view_req_policy[key].data_rel_pos == 1
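                    # Illustrative sketch (not asserted here): with an
                    # underlying obs column [o0, o1, o2, ...], the
                    # NEXT_OBS view resolves to [o1, o2, o3, ...], i.e.
                    # OBS shifted by data_rel_pos=+1.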
rollout_worker = trainer.workers.local_worker()
sample_batch = rollout_worker.sample()
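            # 10 envs x 4-step fragments -> 40 timesteps per sample() call.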
expected_count = \
config["num_envs_per_worker"] * \
config["rollout_fragment_length"]
assert sample_batch.count == expected_count
for v in sample_batch.data.values():
assert len(v) == expected_count
trainer.stop()

    def test_traj_view_lstm_prev_actions_and_rewards(self):
        """Tests whether Policy/Model return correct LSTM ViewRequirements."""
config = ppo.DEFAULT_CONFIG.copy()
config["model"] = config["model"].copy()
# Activate LSTM + prev-action + rewards.
config["model"]["use_lstm"] = True
config["model"]["lstm_use_prev_action"] = True
config["model"]["lstm_use_prev_reward"] = True
for _ in framework_iterator(config):
trainer = ppo.PPOTrainer(config, env="CartPole-v0")
policy = trainer.get_policy()
view_req_model = policy.model.inference_view_requirements
view_req_policy = policy.view_requirements
# 7=obs, prev-a + r, 2x state-in, 2x state-out.
assert len(view_req_model) == 7, view_req_model
assert len(view_req_policy) == 19, view_req_policy
for key in [
SampleBatch.OBS, SampleBatch.ACTIONS, SampleBatch.REWARDS,
SampleBatch.DONES, SampleBatch.NEXT_OBS,
SampleBatch.VF_PREDS, SampleBatch.PREV_ACTIONS,
SampleBatch.PREV_REWARDS, "advantages", "value_targets",
SampleBatch.ACTION_DIST_INPUTS, SampleBatch.ACTION_LOGP
]:
assert key in view_req_policy
if key == SampleBatch.PREV_ACTIONS:
assert view_req_policy[key].data_col == SampleBatch.ACTIONS
assert view_req_policy[key].data_rel_pos == -1
elif key == SampleBatch.PREV_REWARDS:
assert view_req_policy[key].data_col == SampleBatch.REWARDS
assert view_req_policy[key].data_rel_pos == -1
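                    # Sketch of the -1 shift: rewards [r0, r1, r2] yield
                    # a PREV_REWARDS view of [0.0, r0, r1] (0-filled at
                    # the episode start).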
elif key not in [
SampleBatch.NEXT_OBS, SampleBatch.PREV_ACTIONS,
SampleBatch.PREV_REWARDS
]:
assert view_req_policy[key].data_col is None
else:
assert view_req_policy[key].data_col == SampleBatch.OBS
assert view_req_policy[key].data_rel_pos == 1
trainer.stop()

    def test_traj_view_simple_performance(self):
        """Tests whether PPOTrainer runs faster w/ `_use_trajectory_view_api`."""
config = copy.deepcopy(ppo.DEFAULT_CONFIG)
action_space = Discrete(2)
obs_space = Box(-1.0, 1.0, shape=(700, ))
from ray.rllib.examples.env.random_env import RandomMultiAgentEnv
from ray.tune import register_env
register_env("ma_env", lambda c: RandomMultiAgentEnv({
"num_agents": 2,
"p_done": 0.0,
"max_episode_len": 104,
"action_space": action_space,
"observation_space": obs_space
}))
config["num_workers"] = 3
config["num_envs_per_worker"] = 8
config["num_sgd_iter"] = 1 # Put less weight on training.
policies = {
"pol0": (None, obs_space, action_space, {}),
}

        def policy_fn(agent_id):
            return "pol0"

config["multiagent"] = {
"policies": policies,
"policy_mapping_fn": policy_fn,
}
num_iterations = 2
for _ in framework_iterator(config, frameworks="torch"):
print("w/ traj. view API")
config["_use_trajectory_view_api"] = True
trainer = ppo.PPOTrainer(config=config, env="ma_env")
learn_time_w = 0.0
sampler_perf_w = {}
start = time.time()
for i in range(num_iterations):
out = trainer.train()
ts = out["timesteps_total"]
sampler_perf_ = out["sampler_perf"]
                sampler_perf_w = {
                    k: sampler_perf_w.get(k, 0.0) +
                    (sampler_perf_[k] * 1000 / ts)
                    for k, v in sampler_perf_.items()
                }
delta = out["timers"]["learn_time_ms"] / ts
learn_time_w += delta
print("{}={}s".format(i, delta))
sampler_perf_w = {
k: sampler_perf_w[k] / (num_iterations if "mean_" in k else 1)
for k, v in sampler_perf_w.items()
}
duration_w = time.time() - start
print("Duration: {}s "
"sampler-perf.={} learn-time/iter={}s".format(
duration_w, sampler_perf_w,
learn_time_w / num_iterations))
trainer.stop()
print("w/o traj. view API")
config["_use_trajectory_view_api"] = False
trainer = ppo.PPOTrainer(config=config, env="ma_env")
learn_time_wo = 0.0
sampler_perf_wo = {}
start = time.time()
for i in range(num_iterations):
out = trainer.train()
ts = out["timesteps_total"]
sampler_perf_ = out["sampler_perf"]
sampler_perf_wo = {
k: sampler_perf_wo.get(k, 0.0) +
(sampler_perf_[k] * 1000 / ts)
for k, v in sampler_perf_.items()
}
delta = out["timers"]["learn_time_ms"] / ts
learn_time_wo += delta
print("{}={}s".format(i, delta))
sampler_perf_wo = {
k: sampler_perf_wo[k] / (num_iterations if "mean_" in k else 1)
for k, v in sampler_perf_wo.items()
}
duration_wo = time.time() - start
print("Duration: {}s "
"sampler-perf.={} learn-time/iter={}s".format(
duration_wo, sampler_perf_wo,
learn_time_wo / num_iterations))
trainer.stop()
# Assert `_use_trajectory_view_api` is faster.
self.assertLess(sampler_perf_w["mean_raw_obs_processing_ms"],
sampler_perf_wo["mean_raw_obs_processing_ms"])
self.assertLess(sampler_perf_w["mean_action_processing_ms"],
sampler_perf_wo["mean_action_processing_ms"])
self.assertLess(duration_w, duration_wo)

    def test_traj_view_next_action(self):
action_space = Discrete(2)
rollout_worker_w_api = RolloutWorker(
env_creator=lambda _: gym.make("CartPole-v0"),
policy_config=ppo.DEFAULT_CONFIG,
rollout_fragment_length=200,
policy_spec=ppo.PPOTorchPolicy,
policy_mapping_fn=None,
num_envs=1,
)
# Add the next action to the view reqs of the policy.
# This should be visible then in postprocessing and train batches.
rollout_worker_w_api.policy_map["default_policy"].view_requirements[
"next_actions"] = ViewRequirement(
SampleBatch.ACTIONS, shift=1, space=action_space)
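        # With `shift=1`, each timestep t views ACTIONS at t+1; e.g. for
        # actions [a0, a1, a2], next_actions resolves to [a1, a2, 0]
        # (the loop below verifies the 0-fill right after a done).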
        # Make sure we also collect DONEs.
rollout_worker_w_api.policy_map["default_policy"].view_requirements[
"dones"] = ViewRequirement()
batch = rollout_worker_w_api.sample()
self.assertTrue("next_actions" in batch.data)
expected_a_ = None # expected next action
for i in range(len(batch["actions"])):
a, d, a_ = batch["actions"][i], batch["dones"][i], \
batch["next_actions"][i]
if not d and expected_a_ is not None:
check(a, expected_a_)
elif d:
check(a_, 0)
expected_a_ = None
continue
expected_a_ = a_

    def test_traj_view_lstm_functionality(self):
action_space = Box(-float("inf"), float("inf"), shape=(3, ))
obs_space = Box(float("-inf"), float("inf"), (4, ))
max_seq_len = 50
rollout_fragment_length = 200
assert rollout_fragment_length % max_seq_len == 0
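        # (Assumed rationale: the divisibility keeps whole max_seq_len
        # sequences inside a single rollout fragment.)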
policies = {
"pol0": (EpisodeEnvAwareLSTMPolicy, obs_space, action_space, {}),
}

        def policy_fn(agent_id):
            return "pol0"

        config = {
            "multiagent": {
                "policies": policies,
                "policy_mapping_fn": policy_fn,
            },
            "model": {
                "use_lstm": True,
                "max_seq_len": max_seq_len,
            },
        }
rollout_worker_w_api = RolloutWorker(
env_creator=lambda _: MultiAgentDebugCounterEnv({"num_agents": 4}),
policy_config=dict(config, **{"_use_trajectory_view_api": True}),
rollout_fragment_length=rollout_fragment_length,
policy_spec=policies,
policy_mapping_fn=policy_fn,
num_envs=1,
)
rollout_worker_wo_api = RolloutWorker(
env_creator=lambda _: MultiAgentDebugCounterEnv({"num_agents": 4}),
policy_config=dict(config, **{"_use_trajectory_view_api": False}),
rollout_fragment_length=rollout_fragment_length,
policy_spec=policies,
policy_mapping_fn=policy_fn,
num_envs=1,
)
for iteration in range(20):
result = rollout_worker_w_api.sample()
check(result.count, rollout_fragment_length)
pol_batch_w = result.policy_batches["pol0"]
assert pol_batch_w.count >= rollout_fragment_length
analyze_rnn_batch(pol_batch_w, max_seq_len)
result = rollout_worker_wo_api.sample()
pol_batch_wo = result.policy_batches["pol0"]
check(pol_batch_w.data, pol_batch_wo.data)


def analyze_rnn_batch(batch, max_seq_len):
count = batch.count
# Check prev_reward/action, next_obs consistency.
for idx in range(count):
        # If the timestep is tracked in the batch, use it directly.
        if "t" in batch:
            ts = batch["t"][idx]
        # Else, the (debug counter) env stores the ts at obs index 3.
        else:
            ts = batch["obs"][idx][3]
obs_t = batch["obs"][idx]
a_t = batch["actions"][idx]
r_t = batch["rewards"][idx]
state_in_0 = batch["state_in_0"][idx]
state_in_1 = batch["state_in_1"][idx]
# Check postprocessing outputs.
if "2xobs" in batch:
postprocessed_col_t = batch["2xobs"][idx]
assert (obs_t == postprocessed_col_t / 2.0).all()
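            # ("2xobs" is a column written by the policy's postprocessing
            # and should equal 2 * obs, hence the / 2.0 above.)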
# Check state-in/out and next-obs values.
if idx > 0:
next_obs_t_m_1 = batch["new_obs"][idx - 1]
state_out_0_t_m_1 = batch["state_out_0"][idx - 1]
state_out_1_t_m_1 = batch["state_out_1"][idx - 1]
# Same trajectory as for t-1 -> Should be able to match.
if (batch[SampleBatch.AGENT_INDEX][idx] ==
batch[SampleBatch.AGENT_INDEX][idx - 1]
and batch[SampleBatch.EPS_ID][idx] ==
batch[SampleBatch.EPS_ID][idx - 1]):
assert batch["unroll_id"][idx - 1] == batch["unroll_id"][idx]
assert (obs_t == next_obs_t_m_1).all()
assert (state_in_0 == state_out_0_t_m_1).all()
assert (state_in_1 == state_out_1_t_m_1).all()
# Different trajectory.
else:
assert batch["unroll_id"][idx - 1] != batch["unroll_id"][idx]
assert not (obs_t == next_obs_t_m_1).all()
assert not (state_in_0 == state_out_0_t_m_1).all()
assert not (state_in_1 == state_out_1_t_m_1).all()
                # Check initial 0-internal states (new trajectory).
if ts == 0:
assert (state_in_0 == 0.0).all()
assert (state_in_1 == 0.0).all()
# Check initial 0-internal states (at ts=0).
if ts == 0:
assert (state_in_0 == 0.0).all()
assert (state_in_1 == 0.0).all()
# Check prev. a/r values.
if idx < count - 1:
prev_actions_t_p_1 = batch["prev_actions"][idx + 1]
prev_rewards_t_p_1 = batch["prev_rewards"][idx + 1]
# Same trajectory as for t+1 -> Should be able to match.
if batch[SampleBatch.AGENT_INDEX][idx] == \
batch[SampleBatch.AGENT_INDEX][idx + 1] and \
batch[SampleBatch.EPS_ID][idx] == \
batch[SampleBatch.EPS_ID][idx + 1]:
assert (a_t == prev_actions_t_p_1).all()
assert r_t == prev_rewards_t_p_1
            # Different (new) trajectory: prev-a/r at an episode start
            # (ts==0; obs[3] holds the ts) must be zero-filled.
elif ts == 0:
assert (prev_actions_t_p_1 == 0).all()
assert prev_rewards_t_p_1 == 0.0
pad_batch_to_sequences_of_same_size(
batch,
max_seq_len=max_seq_len,
shuffle=False,
batch_divisibility_req=1)
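    # Worked example of the padding above (a sketch): with max_seq_len=4
    # and seq_lens=[3, 2], the row-major layout becomes
    #   [t0, t1, t2, <0-pad>, t0, t1, <0-pad>, <0-pad>],
    # i.e. each sequence occupies exactly max_seq_len rows (the cursor
    # arithmetic below relies on this).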
# Check after seq-len 0-padding.
cursor = 0
for i, seq_len in enumerate(batch["seq_lens"]):
state_in_0 = batch["state_in_0"][i]
state_in_1 = batch["state_in_1"][i]
for j in range(seq_len):
k = cursor + j
ts = batch["t"][k]
obs_t = batch["obs"][k]
a_t = batch["actions"][k]
r_t = batch["rewards"][k]
# Check postprocessing outputs.
if "2xobs" in batch:
postprocessed_col_t = batch["2xobs"][k]
assert (obs_t == postprocessed_col_t / 2.0).all()
# Check state-in/out and next-obs values.
if j > 0:
next_obs_t_m_1 = batch["new_obs"][k - 1]
# state_out_0_t_m_1 = batch["state_out_0"][k - 1]
# state_out_1_t_m_1 = batch["state_out_1"][k - 1]
# Always same trajectory as for t-1.
assert batch["unroll_id"][k - 1] == batch["unroll_id"][k]
assert (obs_t == next_obs_t_m_1).all()
                # assert (state_in_0 == state_out_0_t_m_1).all()
                # assert (state_in_1 == state_out_1_t_m_1).all()
# Check initial 0-internal states.
elif ts == 0:
assert (state_in_0 == 0.0).all()
assert (state_in_1 == 0.0).all()
for j in range(seq_len, max_seq_len):
k = cursor + j
obs_t = batch["obs"][k]
a_t = batch["actions"][k]
r_t = batch["rewards"][k]
assert (obs_t == 0.0).all()
assert (a_t == 0.0).all()
assert (r_t == 0.0).all()
cursor += max_seq_len


if __name__ == "__main__":
    import pytest
    import sys
    sys.exit(pytest.main(["-v", __file__]))