mirror of
https://github.com/vale981/ray
synced 2025-03-11 21:56:39 -04:00
81 lines
3.3 KiB
Python
81 lines
3.3 KiB
Python
import unittest
|
|
|
|
import ray
|
|
import ray.rllib.agents.ppo as ppo
|
|
from ray.rllib.policy.sample_batch import SampleBatch
|
|
from ray.rllib.utils.test_utils import framework_iterator
|
|
|
|
|
|
class TestTrajectoryViewAPI(unittest.TestCase):
    """Checks the view requirements exposed by PPO's torch policy/model.

    Each test builds a PPOTrainer on CartPole-v0 and inspects
    `policy.model.inference_view_requirements()` and
    `policy.training_view_requirements()`: how many entries there are and,
    per column, which underlying column (`data_col`) it is read from and
    with which `shift`.
    """

    @classmethod
    def setUpClass(cls) -> None:
        # A single Ray instance is shared by all tests in this class.
        ray.init()

    @classmethod
    def tearDownClass(cls) -> None:
        ray.shutdown()

    @staticmethod
    def _check_view_requirements(view_req_policy, expected_keys,
                                 shifted_cols):
        """Assert that `view_req_policy` holds exactly the expected columns.

        Args:
            view_req_policy: Mapping from column name to a view-requirement
                object with `data_col` and `shift` attributes.
            expected_keys: The columns `view_req_policy` must contain, and
                nothing else (must not contain duplicates).
            shifted_cols: Maps each column that is a view onto another
                column to its expected `(data_col, shift)` pair. Every other
                expected column must have `data_col is None`; its `shift`
                is not checked.

        Raises:
            AssertionError: If any of the checks fails.
        """
        assert len(view_req_policy) == len(expected_keys), \
            (sorted(view_req_policy), expected_keys)
        for key in expected_keys:
            assert key in view_req_policy, key
            req = view_req_policy[key]
            if key in shifted_cols:
                data_col, shift = shifted_cols[key]
                assert req.data_col == data_col, (key, req.data_col)
                assert req.shift == shift, (key, req.shift)
            else:
                # Plain column: no special underlying data_col.
                assert req.data_col is None, (key, req.data_col)

    def test_plain(self):
        config = ppo.DEFAULT_CONFIG.copy()
        for _ in framework_iterator(config, frameworks="torch"):
            trainer = ppo.PPOTrainer(config, env="CartPole-v0")
            try:
                policy = trainer.get_policy()
                view_req_model = policy.model.inference_view_requirements()
                view_req_policy = policy.training_view_requirements()
                assert len(view_req_model) == 1
                self._check_view_requirements(
                    view_req_policy,
                    expected_keys=[
                        SampleBatch.OBS, SampleBatch.ACTIONS,
                        SampleBatch.REWARDS, SampleBatch.DONES,
                        SampleBatch.NEXT_OBS, SampleBatch.VF_PREDS,
                    ],
                    # Only next-obs is a view onto another column: obs,
                    # read with shift=1.
                    shifted_cols={
                        SampleBatch.NEXT_OBS: (SampleBatch.OBS, 1),
                    },
                )
            finally:
                # Stop the trainer even if a check fails, so it does not
                # linger in the Ray instance shared by the other tests.
                trainer.stop()

    def test_lstm_prev_actions_and_rewards(self):
        config = ppo.DEFAULT_CONFIG.copy()
        # The top-level .copy() is shallow: copy the nested model dict too,
        # so the changes below don't leak into ppo.DEFAULT_CONFIG.
        config["model"] = config["model"].copy()
        # Activate LSTM + prev-action + rewards.
        config["model"]["use_lstm"] = True
        config["model"]["lstm_use_prev_action_reward"] = True

        for _ in framework_iterator(config, frameworks="torch"):
            trainer = ppo.PPOTrainer(config, env="CartPole-v0")
            try:
                policy = trainer.get_policy()
                view_req_model = policy.model.inference_view_requirements()
                view_req_policy = policy.training_view_requirements()
                assert len(view_req_model) == 3  # obs, prev_a, prev_r
                self._check_view_requirements(
                    view_req_policy,
                    expected_keys=[
                        SampleBatch.OBS, SampleBatch.ACTIONS,
                        SampleBatch.REWARDS, SampleBatch.DONES,
                        SampleBatch.NEXT_OBS, SampleBatch.VF_PREDS,
                        SampleBatch.PREV_ACTIONS, SampleBatch.PREV_REWARDS,
                    ],
                    # next-obs looks one step ahead into obs; prev-actions
                    # and prev-rewards look one step back into their
                    # source columns.
                    shifted_cols={
                        SampleBatch.NEXT_OBS: (SampleBatch.OBS, 1),
                        SampleBatch.PREV_ACTIONS: (SampleBatch.ACTIONS, -1),
                        SampleBatch.PREV_REWARDS: (SampleBatch.REWARDS, -1),
                    },
                )
            finally:
                # Stop the trainer even if a check fails, so it does not
                # linger in the Ray instance shared by the other tests.
                trainer.stop()
|
|
|
|
|
|
if __name__ == "__main__":
    import sys

    import pytest

    # Run this file's tests verbosely and hand pytest's exit status back
    # to the calling shell.
    pytest_args = ["-v", __file__]
    exit_code = pytest.main(pytest_args)
    sys.exit(exit_code)
|