2021-09-29 21:31:34 +02:00
|
|
|
import unittest
|
|
|
|
|
|
|
|
import ray
|
2022-05-22 18:58:47 +01:00
|
|
|
from ray.rllib.algorithms import sac
|
2021-09-29 21:31:34 +02:00
|
|
|
from ray.rllib.utils.framework import try_import_tf, try_import_torch
|
|
|
|
from ray.rllib.utils.test_utils import check_compute_single_action, framework_iterator
|
|
|
|
|
|
|
|
tf1, tf, tfv = try_import_tf()
|
|
|
|
torch, nn = try_import_torch()
|
|
|
|
|
|
|
|
|
|
|
|
class TestRNNSAC(unittest.TestCase):
    """Smoke test for the RNNSAC algorithm configuration."""

    @classmethod
    def setUpClass(cls) -> None:
        """Start Ray once for every test in this class."""
        ray.init()

    @classmethod
    def tearDownClass(cls) -> None:
        """Shut Ray down after the last test in this class."""
        ray.shutdown()

    def test_rnnsac_compilation(self):
        """Test whether RNNSAC can be built, trained and queried on torch.

        Builds the algorithm from an ``RNNSACConfig`` with LSTM-wrapped
        policy and Q models, runs ``num_iterations`` training iterations and
        then runs ``check_compute_single_action`` with states and previous
        actions/rewards included.
        """
        config = (
            sac.RNNSACConfig()
            .rollouts(num_rollout_workers=0)
            .training(
                # Wrap with an LSTM and use a very simple base-model.
                model={"max_seq_len": 20},
                policy_model_config={
                    "use_lstm": True,
                    "lstm_cell_size": 64,
                    "fcnet_hiddens": [10],
                    "lstm_use_prev_action": True,
                    "lstm_use_prev_reward": True,
                },
                q_model_config={
                    "use_lstm": True,
                    "lstm_cell_size": 64,
                    "fcnet_hiddens": [10],
                    "lstm_use_prev_action": True,
                    "lstm_use_prev_reward": True,
                },
                replay_buffer_config={
                    "type": "MultiAgentPrioritizedReplayBuffer",
                    "replay_burn_in": 20,
                    "zero_init_states": True,
                },
                lr=5e-4,
            )
        )
        num_iterations = 1

        # Only the torch framework is exercised here.
        for _ in framework_iterator(config, frameworks="torch"):
            algo = config.build(env="CartPole-v0")
            try:
                for _iteration in range(num_iterations):
                    results = algo.train()
                    print(results)

                check_compute_single_action(
                    algo,
                    include_state=True,
                    include_prev_action_reward=True,
                )
            finally:
                # Always stop the algorithm so a failing train step or
                # action check does not leave it running until
                # ray.shutdown() in tearDownClass.
                algo.stop()
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Running this file directly hands it to pytest in verbose mode and
    # uses pytest's return value as the process exit status.
    import sys

    import pytest

    exit_code = pytest.main(["-v", __file__])
    sys.exit(exit_code)
|