ray/rllib/models/tests/test_lstms.py
Balaji Veeramani 7f1bacc7dc
[CI] Format Python code with Black (#21975)
See #21316 and #21311 for the motivation behind these changes.
2022-01-29 18:41:57 -08:00

62 lines
1.9 KiB
Python

from gym.spaces import Box, Dict, Discrete, MultiDiscrete, Tuple
import unittest
import ray
from ray import tune
from ray.rllib.examples.env.random_env import RandomEnv
from ray.rllib.utils.test_utils import framework_iterator
class TestLSTMs(unittest.TestCase):
    """LSTM-model test case that trains PPO on ``RandomEnv``."""

    @classmethod
    def setUpClass(cls) -> None:
        """Initialize Ray with 5 CPUs before the tests of this class run."""
        ray.init(num_cpus=5)

    @classmethod
    def tearDownClass(cls) -> None:
        """Shut Ray down after the tests of this class have run."""
        ray.shutdown()

    def test_lstm_w_prev_action_and_prev_reward(self):
        """Tests LSTM prev-a/r input insertions using complex actions.

        Builds a nested ``Dict`` action space (boxes plus a ``Tuple`` of
        discrete, multi-discrete and box sub-spaces), enables the LSTM
        wrapper with previous-action and previous-reward inputs, and runs
        PPO through ``tune.run`` for one training iteration per item
        yielded by ``framework_iterator``.
        """
        # Components of the nested action space, created in the order in
        # which they appear inside the final Dict space.
        scalar_box = Box(-1.0, 1.0, ())
        vector_box = Box(-1.0, 1.0, (2,))
        nested_tuple = Tuple(
            [
                Discrete(2),
                MultiDiscrete([2, 3]),
                Box(-1.0, 1.0, (3,)),
            ]
        )
        action_space = Dict(
            {
                "a": scalar_box,
                "b": vector_box,
                "c": nested_tuple,
            }
        )

        # Small network with an LSTM that is also fed the previous action
        # and the previous reward.
        model_config = {
            "fcnet_hiddens": [10],
            "use_lstm": True,
            "lstm_cell_size": 16,
            "lstm_use_prev_action": True,
            "lstm_use_prev_reward": True,
        }

        config = {
            "env": RandomEnv,
            "env_config": {
                "action_space": action_space,
            },
            # Must be True so that complex (prev.) actions can be used as
            # inputs to the LSTM.
            "_disable_action_flattening": True,
            "model": model_config,
            "num_sgd_iter": 1,
            "train_batch_size": 200,
            "sgd_minibatch_size": 50,
            "rollout_fragment_length": 100,
            "num_workers": 1,
        }

        # The yielded value itself is not needed; `config` is passed to
        # `tune.run` on every pass of the iterator.
        for _framework in framework_iterator(config):
            stop_criteria = {"training_iteration": 1}
            tune.run("PPO", config=config, stop=stop_criteria, verbose=1)
# Script entry point: run this file's tests through pytest in verbose mode
# and hand pytest's return value to the process exit.
if __name__ == "__main__":
    import sys

    import pytest

    exit_code = pytest.main(["-v", __file__])
    sys.exit(exit_code)