import copy
import gym
from gym.spaces import Box, Discrete
import time
import unittest

import ray
import ray.rllib.agents.dqn as dqn
import ray.rllib.agents.ppo as ppo
from ray.rllib.examples.env.debug_counter_env import MultiAgentDebugCounterEnv
from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.examples.policy.episode_env_aware_policy import \
    EpisodeEnvAwareLSTMPolicy
from ray.rllib.policy.rnn_sequencing import pad_batch_to_sequences_of_same_size
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.view_requirement import ViewRequirement
from ray.rllib.utils.test_utils import framework_iterator, check


class TestTrajectoryViewAPI(unittest.TestCase):
    @classmethod
    def setUpClass(cls) -> None:
        ray.init()

    @classmethod
    def tearDownClass(cls) -> None:
        ray.shutdown()

    def test_traj_view_normal_case(self):
        """Tests whether Model and Policy return the correct ViewRequirements.
        """
        config = dqn.DEFAULT_CONFIG.copy()
        config["num_envs_per_worker"] = 10
        config["rollout_fragment_length"] = 4

        for _ in framework_iterator(config):
            trainer = dqn.DQNTrainer(
                config,
                env="ray.rllib.examples.env.debug_counter_env.DebugCounterEnv")
            policy = trainer.get_policy()
            view_req_model = policy.model.inference_view_requirements
            view_req_policy = policy.view_requirements
            assert len(view_req_model) == 1, view_req_model
            assert len(view_req_policy) == 8, view_req_policy
            for key in [
                    SampleBatch.OBS, SampleBatch.ACTIONS, SampleBatch.REWARDS,
                    SampleBatch.DONES, SampleBatch.NEXT_OBS,
                    SampleBatch.EPS_ID, SampleBatch.AGENT_INDEX, "weights",
            ]:
                assert key in view_req_policy
                # None of the view cols has a special underlying data_col,
                # except next-obs.
                if key != SampleBatch.NEXT_OBS:
                    assert view_req_policy[key].data_col is None
                else:
                    assert view_req_policy[key].data_col == SampleBatch.OBS
                    assert view_req_policy[key].data_rel_pos == 1
            rollout_worker = trainer.workers.local_worker()
            sample_batch = rollout_worker.sample()
            expected_count = \
                config["num_envs_per_worker"] * \
                config["rollout_fragment_length"]
            assert sample_batch.count == expected_count
            for v in sample_batch.data.values():
                assert len(v) == expected_count
            trainer.stop()

    def test_traj_view_lstm_prev_actions_and_rewards(self):
        """Tests whether Policy/Model return correct LSTM ViewRequirements.
        """
        config = ppo.DEFAULT_CONFIG.copy()
        config["model"] = config["model"].copy()
        # Activate LSTM + prev-action + rewards.
        config["model"]["use_lstm"] = True
        config["model"]["lstm_use_prev_action"] = True
        config["model"]["lstm_use_prev_reward"] = True

        for _ in framework_iterator(config):
            trainer = ppo.PPOTrainer(config, env="CartPole-v0")
            policy = trainer.get_policy()
            view_req_model = policy.model.inference_view_requirements
            view_req_policy = policy.view_requirements
            # 7 model view cols: obs, prev-a, prev-r, 2x state-in,
            # 2x state-out.
            assert len(view_req_model) == 7, view_req_model
            assert len(view_req_policy) == 19, view_req_policy
            for key in [
                    SampleBatch.OBS, SampleBatch.ACTIONS, SampleBatch.REWARDS,
                    SampleBatch.DONES, SampleBatch.NEXT_OBS,
                    SampleBatch.VF_PREDS, SampleBatch.PREV_ACTIONS,
                    SampleBatch.PREV_REWARDS, "advantages", "value_targets",
                    SampleBatch.ACTION_DIST_INPUTS, SampleBatch.ACTION_LOGP
            ]:
                assert key in view_req_policy

                if key == SampleBatch.PREV_ACTIONS:
                    assert view_req_policy[key].data_col == SampleBatch.ACTIONS
                    assert view_req_policy[key].data_rel_pos == -1
                elif key == SampleBatch.PREV_REWARDS:
                    assert view_req_policy[key].data_col == SampleBatch.REWARDS
                    assert view_req_policy[key].data_rel_pos == -1
                elif key not in [
                        SampleBatch.NEXT_OBS, SampleBatch.PREV_ACTIONS,
                        SampleBatch.PREV_REWARDS
                ]:
                    assert view_req_policy[key].data_col is None
                else:
                    assert view_req_policy[key].data_col == SampleBatch.OBS
                    assert view_req_policy[key].data_rel_pos == 1
            trainer.stop()

    def test_traj_view_simple_performance(self):
        """Test whether PPOTrainer runs faster w/ `_use_trajectory_view_api`.
        """
        config = copy.deepcopy(ppo.DEFAULT_CONFIG)
        action_space = Discrete(2)
        obs_space = Box(-1.0, 1.0, shape=(700, ))

        from ray.rllib.examples.env.random_env import RandomMultiAgentEnv
        from ray.tune import register_env
        register_env(
            "ma_env",
            lambda c: RandomMultiAgentEnv({
                "num_agents": 2,
                "p_done": 0.0,
                "max_episode_len": 104,
                "action_space": action_space,
                "observation_space": obs_space,
            }))

        config["num_workers"] = 3
        config["num_envs_per_worker"] = 8
        config["num_sgd_iter"] = 1  # Put less weight on training.

        policies = {
            "pol0": (None, obs_space, action_space, {}),
        }

        def policy_fn(agent_id):
            return "pol0"

        config["multiagent"] = {
            "policies": policies,
            "policy_mapping_fn": policy_fn,
        }
        num_iterations = 2

        for _ in framework_iterator(config, frameworks="torch"):
            print("w/ traj. view API")
            config["_use_trajectory_view_api"] = True
            trainer = ppo.PPOTrainer(config=config, env="ma_env")
            learn_time_w = 0.0
            sampler_perf_w = {}
            start = time.time()
            for i in range(num_iterations):
                out = trainer.train()
                ts = out["timesteps_total"]
                sampler_perf_ = out["sampler_perf"]
                sampler_perf_w = {
                    k: sampler_perf_w.get(k, 0.0) +
                    (sampler_perf_[k] * 1000 / ts)
                    for k, v in sampler_perf_.items()
                }
                delta = out["timers"]["learn_time_ms"] / ts
                learn_time_w += delta
                print("{}={}s".format(i, delta))
            sampler_perf_w = {
                k: sampler_perf_w[k] / (num_iterations if "mean_" in k else 1)
                for k, v in sampler_perf_w.items()
            }
            duration_w = time.time() - start
            print("Duration: {}s "
                  "sampler-perf.={} learn-time/iter={}s".format(
                      duration_w, sampler_perf_w,
                      learn_time_w / num_iterations))
            trainer.stop()

            print("w/o traj. view API")
            config["_use_trajectory_view_api"] = False
            trainer = ppo.PPOTrainer(config=config, env="ma_env")
            learn_time_wo = 0.0
            sampler_perf_wo = {}
            start = time.time()
            for i in range(num_iterations):
                out = trainer.train()
                ts = out["timesteps_total"]
                sampler_perf_ = out["sampler_perf"]
                sampler_perf_wo = {
                    k: sampler_perf_wo.get(k, 0.0) +
                    (sampler_perf_[k] * 1000 / ts)
                    for k, v in sampler_perf_.items()
                }
                delta = out["timers"]["learn_time_ms"] / ts
                learn_time_wo += delta
                print("{}={}s".format(i, delta))
            sampler_perf_wo = {
                k: sampler_perf_wo[k] / (num_iterations if "mean_" in k else 1)
                for k, v in sampler_perf_wo.items()
            }
            duration_wo = time.time() - start
            print("Duration: {}s "
                  "sampler-perf.={} learn-time/iter={}s".format(
                      duration_wo, sampler_perf_wo,
                      learn_time_wo / num_iterations))
            trainer.stop()

            # Assert `_use_trajectory_view_api` is faster.
            self.assertLess(sampler_perf_w["mean_raw_obs_processing_ms"],
                            sampler_perf_wo["mean_raw_obs_processing_ms"])
            self.assertLess(sampler_perf_w["mean_action_processing_ms"],
                            sampler_perf_wo["mean_action_processing_ms"])
            self.assertLess(duration_w, duration_wo)

    def test_traj_view_next_action(self):
        action_space = Discrete(2)
        rollout_worker_w_api = RolloutWorker(
            env_creator=lambda _: gym.make("CartPole-v0"),
            policy_config=ppo.DEFAULT_CONFIG,
            rollout_fragment_length=200,
            policy_spec=ppo.PPOTorchPolicy,
            policy_mapping_fn=None,
            num_envs=1,
        )
        # Add the next action to the view reqs of the policy.
        # This should be visible then in postprocessing and train batches.
        rollout_worker_w_api.policy_map["default_policy"].view_requirements[
            "next_actions"] = ViewRequirement(
                SampleBatch.ACTIONS, shift=1, space=action_space)
        # Make sure we have DONEs as well.
        rollout_worker_w_api.policy_map["default_policy"].view_requirements[
            "dones"] = ViewRequirement()
        batch = rollout_worker_w_api.sample()
        self.assertTrue("next_actions" in batch.data)

        expected_a_ = None  # Expected "next action" (from previous timestep).
        for i in range(len(batch["actions"])):
            a, d, a_ = batch["actions"][i], batch["dones"][i], \
                batch["next_actions"][i]
            # The current action must match the "next action" recorded at
            # the previous timestep (if we have one to compare against).
            if not d and expected_a_ is not None:
                check(a, expected_a_)
            # At the end of an episode, the recorded "next action" is the
            # dummy 0-action (no real next action exists).
            elif d:
                check(a_, 0)
                expected_a_ = None
                continue
            expected_a_ = a_

    def test_traj_view_lstm_functionality(self):
        action_space = Box(-float("inf"), float("inf"), shape=(3, ))
        obs_space = Box(float("-inf"), float("inf"), (4, ))
        max_seq_len = 50
        rollout_fragment_length = 200
        assert rollout_fragment_length % max_seq_len == 0
        policies = {
            "pol0": (EpisodeEnvAwareLSTMPolicy, obs_space, action_space, {}),
        }

        def policy_fn(agent_id):
            return "pol0"

        config = {
            "multiagent": {
                "policies": policies,
                "policy_mapping_fn": policy_fn,
            },
            "model": {
                "use_lstm": True,
                "max_seq_len": max_seq_len,
            },
        }

        rollout_worker_w_api = RolloutWorker(
            env_creator=lambda _: MultiAgentDebugCounterEnv({"num_agents": 4}),
            policy_config=dict(config, **{"_use_trajectory_view_api": True}),
            rollout_fragment_length=rollout_fragment_length,
            policy_spec=policies,
            policy_mapping_fn=policy_fn,
            num_envs=1,
        )
        rollout_worker_wo_api = RolloutWorker(
            env_creator=lambda _: MultiAgentDebugCounterEnv({"num_agents": 4}),
            policy_config=dict(config, **{"_use_trajectory_view_api": False}),
            rollout_fragment_length=rollout_fragment_length,
            policy_spec=policies,
            policy_mapping_fn=policy_fn,
            num_envs=1,
        )

        for iteration in range(20):
            result = rollout_worker_w_api.sample()
            check(result.count, rollout_fragment_length)
            pol_batch_w = result.policy_batches["pol0"]
            assert pol_batch_w.count >= rollout_fragment_length
            analyze_rnn_batch(pol_batch_w, max_seq_len)

            result = rollout_worker_wo_api.sample()
            pol_batch_wo = result.policy_batches["pol0"]
            check(pol_batch_w.data, pol_batch_wo.data)


def analyze_rnn_batch(batch, max_seq_len):
    count = batch.count

    # Check prev_reward/action, next_obs consistency.
    for idx in range(count):
        # If the timestep is tracked by the batch, use it directly ...
        if "t" in batch:
            ts = batch["t"][idx]
        # ... otherwise, the timestep is stored at index 3 of the observation.
        else:
            ts = batch["obs"][idx][3]

        obs_t = batch["obs"][idx]
        a_t = batch["actions"][idx]
        r_t = batch["rewards"][idx]
        state_in_0 = batch["state_in_0"][idx]
        state_in_1 = batch["state_in_1"][idx]

        # Check postprocessing outputs.
        if "2xobs" in batch:
            postprocessed_col_t = batch["2xobs"][idx]
            assert (obs_t == postprocessed_col_t / 2.0).all()

        # Check state-in/out and next-obs values.
        if idx > 0:
            next_obs_t_m_1 = batch["new_obs"][idx - 1]
            state_out_0_t_m_1 = batch["state_out_0"][idx - 1]
            state_out_1_t_m_1 = batch["state_out_1"][idx - 1]
            # Same trajectory as for t-1 -> Should be able to match.
            if (batch[SampleBatch.AGENT_INDEX][idx] ==
                    batch[SampleBatch.AGENT_INDEX][idx - 1]
                    and batch[SampleBatch.EPS_ID][idx] ==
                    batch[SampleBatch.EPS_ID][idx - 1]):
                assert batch["unroll_id"][idx - 1] == batch["unroll_id"][idx]
                assert (obs_t == next_obs_t_m_1).all()
                assert (state_in_0 == state_out_0_t_m_1).all()
                assert (state_in_1 == state_out_1_t_m_1).all()
            # Different trajectory.
            else:
                assert batch["unroll_id"][idx - 1] != batch["unroll_id"][idx]
                assert not (obs_t == next_obs_t_m_1).all()
                assert not (state_in_0 == state_out_0_t_m_1).all()
                assert not (state_in_1 == state_out_1_t_m_1).all()

        # Check initial 0-internal states (at ts=0).
        if ts == 0:
            assert (state_in_0 == 0.0).all()
            assert (state_in_1 == 0.0).all()

        # Check prev. a/r values.
        if idx < count - 1:
            prev_actions_t_p_1 = batch["prev_actions"][idx + 1]
            prev_rewards_t_p_1 = batch["prev_rewards"][idx + 1]
            # Same trajectory as for t+1 -> Should be able to match.
            if batch[SampleBatch.AGENT_INDEX][idx] == \
                    batch[SampleBatch.AGENT_INDEX][idx + 1] and \
                    batch[SampleBatch.EPS_ID][idx] == \
                    batch[SampleBatch.EPS_ID][idx + 1]:
                assert (a_t == prev_actions_t_p_1).all()
                assert r_t == prev_rewards_t_p_1
            # Different (new) trajectory: at its first timestep (ts=0),
            # prev-a/r are always 0.0s.
            elif ts == 0:
                assert (prev_actions_t_p_1 == 0).all()
                assert prev_rewards_t_p_1 == 0.0

    pad_batch_to_sequences_of_same_size(
        batch,
        max_seq_len=max_seq_len,
        shuffle=False,
        batch_divisibility_req=1)

    # Check after seq-len 0-padding.
    cursor = 0
    for i, seq_len in enumerate(batch["seq_lens"]):
        state_in_0 = batch["state_in_0"][i]
        state_in_1 = batch["state_in_1"][i]
        for j in range(seq_len):
            k = cursor + j
            ts = batch["t"][k]
            obs_t = batch["obs"][k]
            a_t = batch["actions"][k]
            r_t = batch["rewards"][k]

            # Check postprocessing outputs.
            if "2xobs" in batch:
                postprocessed_col_t = batch["2xobs"][k]
                assert (obs_t == postprocessed_col_t / 2.0).all()

            # Check state-in/out and next-obs values.
            if j > 0:
                next_obs_t_m_1 = batch["new_obs"][k - 1]
                # state_out_0_t_m_1 = batch["state_out_0"][k - 1]
                # state_out_1_t_m_1 = batch["state_out_1"][k - 1]
                # Within a sequence, always the same trajectory as for t-1.
                assert batch["unroll_id"][k - 1] == batch["unroll_id"][k]
                assert (obs_t == next_obs_t_m_1).all()
                # assert (state_in_0 == state_out_0_t_m_1).all()
                # assert (state_in_1 == state_out_1_t_m_1).all()
            # Check initial 0-internal states.
            elif ts == 0:
                assert (state_in_0 == 0.0).all()
                assert (state_in_1 == 0.0).all()

        # The zero-padded tail of each sequence must contain only 0.0s.
        for j in range(seq_len, max_seq_len):
            k = cursor + j
            obs_t = batch["obs"][k]
            a_t = batch["actions"][k]
            r_t = batch["rewards"][k]
            assert (obs_t == 0.0).all()
            assert (a_t == 0.0).all()
            assert (r_t == 0.0).all()
        cursor += max_seq_len


if __name__ == "__main__":
    import pytest
    import sys
    sys.exit(pytest.main(["-v", __file__]))