# ray/rllib/tests/test_dnc.py
#
# 74 lines
# 2.1 KiB
# Python
# Raw Normal View History

import gym
import unittest
import ray
from ray import tune
from ray.rllib.examples.env.stateless_cartpole import StatelessCartPole
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.examples.models.neural_computer import DNCMemory
from ray.rllib.utils.framework import try_import_torch
# Obtain the torch module through RLlib's import helper; the second value
# returned by the helper is not needed in this file and is discarded.
torch, _ = try_import_torch()
class TestDNC(unittest.TestCase):
    """Tests for the ``DNCMemory`` example model.

    Contains a pack/unpack round-trip check of the model state and an A2C
    training run on ``StatelessCartPole`` that uses ``DNCMemory`` as a
    registered custom model.
    """

    # Stop criteria handed to ``tune.run`` in ``test_dnc_learning``.
    stop = {
        "episode_reward_mean": 100.0,
        "timesteps_total": 10000000,
    }

    @classmethod
    def setUpClass(cls) -> None:
        """Initialize Ray once for all tests in this class."""
        ray.init(num_cpus=4, ignore_reinit_error=True)

    @classmethod
    def tearDownClass(cls) -> None:
        """Shut Ray down once all tests in this class have run."""
        ray.shutdown()

    def test_pack_unpack(self):
        """Check that packing an unpacked state reproduces the input state."""
        model = DNCMemory(
            gym.spaces.Discrete(1), gym.spaces.Discrete(1), 1, {}, "")
        # Add batch dim
        packed_state = [m.unsqueeze(0) for m in model.get_initial_state()]
        # Overwrite every state tensor with random values; the reference
        # copies are cloned afterwards so they hold the randomized data.
        for state_tensor in packed_state:
            state_tensor.random_()
        original_packed = [m.clone() for m in packed_state]

        unpacked = model.unpack_state(packed_state)
        packed = model.pack_state(*unpacked)

        self.assertGreater(len(packed), 0)
        self.assertEqual(len(packed), len(original_packed))
        pairs = zip(packed, original_packed)
        for idx, (restored, expected) in enumerate(pairs):
            self.assertTrue(
                torch.all(restored == expected),
                msg="state tensor {} changed in round trip".format(idx))

    def test_dnc_learning(self):
        """Run A2C on ``StatelessCartPole`` with the DNC custom model."""
        ModelCatalog.register_custom_model("dnc", DNCMemory)
        config = {
            "env": StatelessCartPole,
            "gamma": 0.99,
            "num_envs_per_worker": 5,
            "framework": "torch",
            "num_workers": 1,
            "num_cpus_per_worker": 2.0,
            "lr": 0.01,
            "entropy_coeff": 0.0005,
            "vf_loss_coeff": 1e-5,
            "model": {
                "custom_model": "dnc",
                "max_seq_len": 64,
                "custom_model_config": {
                    "nr_cells": 10,
                    "cell_size": 8,
                },
            },
        }
        # NOTE(review): the value returned by tune.run is discarded, so this
        # test does not itself assert that the reward target in ``stop`` was
        # reached -- confirm whether that is intended.
        tune.run("A2C", config=config, stop=self.stop, verbose=1)
if __name__ == "__main__":
    # When executed as a script, hand this file to pytest in verbose mode
    # and use the value pytest returns as the process exit status.
    import sys

    import pytest

    exit_code = pytest.main(["-v", __file__])
    sys.exit(exit_code)