from gym.spaces import Box, Dict, Discrete, MultiDiscrete, Tuple
import unittest

import ray
from ray import tune
from ray.rllib.examples.env.random_env import RandomEnv
from ray.rllib.utils.test_utils import framework_iterator


class TestLSTMs(unittest.TestCase):
    """Test case exercising PPO with an LSTM-enabled model config."""

    @classmethod
    def setUpClass(cls) -> None:
        # One Ray instance is shared by every test in this class.
        ray.init(num_cpus=5)

    @classmethod
    def tearDownClass(cls) -> None:
        ray.shutdown()

    def test_lstm_w_prev_action_and_prev_reward(self):
        """Runs one PPO iteration with prev-action/prev-reward LSTM inputs.

        The environment is a `RandomEnv` configured with a nested
        (Dict -> Box/Tuple) action space, and the model config switches on
        both `lstm_use_prev_action` and `lstm_use_prev_reward`.
        """
        # Build the action space leaf by leaf: a 0-dim Box, a 2-dim Box and
        # a Tuple mixing discrete, multi-discrete and continuous components.
        scalar_box = Box(-1.0, 1.0, ())
        vector_box = Box(-1.0, 1.0, (2,))
        mixed_tuple = Tuple(
            [
                Discrete(2),
                MultiDiscrete([2, 3]),
                Box(-1.0, 1.0, (3,)),
            ]
        )
        complex_action_space = Dict(
            {
                "a": scalar_box,
                "b": vector_box,
                "c": mixed_tuple,
            }
        )

        lstm_model_settings = {
            "fcnet_hiddens": [10],
            "use_lstm": True,
            "lstm_cell_size": 16,
            "lstm_use_prev_action": True,
            "lstm_use_prev_reward": True,
        }

        config = {
            "env": RandomEnv,
            "env_config": {
                "action_space": complex_action_space,
            },
            # Need to set this to True to enable complex (prev.) actions
            # as inputs to the LSTM.
            "_disable_action_flattening": True,
            "model": lstm_model_settings,
            "num_sgd_iter": 1,
            "train_batch_size": 200,
            "sgd_minibatch_size": 50,
            "rollout_fragment_length": 100,
            "num_workers": 1,
        }

        # `framework_iterator` is handed the config and drives one pass of
        # the loop per framework; each pass trains for a single iteration.
        stop_criteria = {"training_iteration": 1}
        for _framework in framework_iterator(config):
            tune.run("PPO", config=config, stop=stop_criteria, verbose=1)


if __name__ == "__main__":
    import pytest
    import sys

    # Run this file's tests verbosely and propagate pytest's exit status.
    exit_status = pytest.main(["-v", __file__])
    sys.exit(exit_status)