import unittest import ray import ray.rllib.agents.ppo as ppo from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.utils.test_utils import framework_iterator class TestTrajectoryViewAPI(unittest.TestCase): @classmethod def setUpClass(cls) -> None: ray.init() @classmethod def tearDownClass(cls) -> None: ray.shutdown() def test_plain(self): config = ppo.DEFAULT_CONFIG.copy() for _ in framework_iterator(config, frameworks="torch"): trainer = ppo.PPOTrainer(config, env="CartPole-v0") policy = trainer.get_policy() view_req_model = policy.model.inference_view_requirements() view_req_policy = policy.training_view_requirements() assert len(view_req_model) == 1 assert len(view_req_policy) == 6 for key in [ SampleBatch.OBS, SampleBatch.ACTIONS, SampleBatch.REWARDS, SampleBatch.DONES, SampleBatch.NEXT_OBS, SampleBatch.VF_PREDS ]: assert key in view_req_policy # None of the view cols has a special underlying data_col, # except next-obs. if key != SampleBatch.NEXT_OBS: assert view_req_policy[key].data_col is None else: assert view_req_policy[key].data_col == SampleBatch.OBS assert view_req_policy[key].shift == 1 trainer.stop() def test_lstm_prev_actions_and_rewards(self): config = ppo.DEFAULT_CONFIG.copy() config["model"] = config["model"].copy() # Activate LSTM + prev-action + rewards. config["model"]["use_lstm"] = True config["model"]["lstm_use_prev_action_reward"] = True for _ in framework_iterator(config, frameworks="torch"): trainer = ppo.PPOTrainer(config, env="CartPole-v0") policy = trainer.get_policy() view_req_model = policy.model.inference_view_requirements() view_req_policy = policy.training_view_requirements() assert len(view_req_model) == 3 # obs, prev_a, prev_r assert len(view_req_policy) == 8 for key in [ SampleBatch.OBS, SampleBatch.ACTIONS, SampleBatch.REWARDS, SampleBatch.DONES, SampleBatch.NEXT_OBS, SampleBatch.VF_PREDS, SampleBatch.PREV_ACTIONS, SampleBatch.PREV_REWARDS ]: assert key in view_req_policy if key == SampleBatch.PREV_ACTIONS: assert view_req_policy[key].data_col == SampleBatch.ACTIONS assert view_req_policy[key].shift == -1 elif key == SampleBatch.PREV_REWARDS: assert view_req_policy[key].data_col == SampleBatch.REWARDS assert view_req_policy[key].shift == -1 elif key not in [SampleBatch.NEXT_OBS, SampleBatch.PREV_ACTIONS, SampleBatch.PREV_REWARDS]: assert view_req_policy[key].data_col is None else: assert view_req_policy[key].data_col == SampleBatch.OBS assert view_req_policy[key].shift == 1 trainer.stop() if __name__ == "__main__": import pytest import sys sys.exit(pytest.main(["-v", __file__]))