ray/rllib/tests/test_dnc.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

74 lines
2.1 KiB
Python
Raw Permalink Normal View History

import gym
import unittest
import ray
from ray import tune
from ray.rllib.examples.env.stateless_cartpole import StatelessCartPole
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.examples.models.neural_computer import DNCMemory
from ray.rllib.utils.framework import try_import_torch
# try_import_torch() yields two values; only the first, bound to `torch`, is kept.
torch, _ = try_import_torch()
class TestDNC(unittest.TestCase):
    """Tests for the ``DNCMemory`` custom model.

    Covers a pack/unpack round trip of the model state and an A2C training
    run on ``StatelessCartPole`` with ``DNCMemory`` registered as the custom
    model.
    """

    # Stopping criteria handed to ``tune.run`` in ``test_dnc_learning``.
    stop = {
        "episode_reward_mean": 100.0,
        "timesteps_total": 10000000,
    }

    @classmethod
    def setUpClass(cls) -> None:
        """Initialize Ray once before any test in this class runs."""
        ray.init(num_cpus=4, ignore_reinit_error=True)

    @classmethod
    def tearDownClass(cls) -> None:
        """Shut Ray down after the tests in this class have run."""
        ray.shutdown()

    def test_pack_unpack(self):
        """Check that ``pack_state(*unpack_state(s))`` reproduces ``s``."""
        d = DNCMemory(gym.spaces.Discrete(1), gym.spaces.Discrete(1), 1, {}, "")
        # Add batch dim
        packed_state = [m.unsqueeze(0) for m in d.get_initial_state()]
        # Overwrite every state tensor in place with random values before
        # taking the reference snapshot.
        for m in packed_state:
            m.random_()
        # Clones keep the reference values independent of the tensors that
        # are handed to unpack_state() below.
        original_packed = [m.clone() for m in packed_state]

        unpacked = d.unpack_state(packed_state)
        packed = d.pack_state(*unpacked)

        self.assertGreater(len(packed), 0)
        self.assertEqual(len(packed), len(original_packed))
        # Lengths were asserted equal above, so zip() drops nothing.
        for actual, expected in zip(packed, original_packed):
            self.assertTrue(torch.all(actual == expected))

    def test_dnc_learning(self):
        """Run A2C on ``StatelessCartPole`` using the ``dnc`` custom model."""
        ModelCatalog.register_custom_model("dnc", DNCMemory)
        config = {
            "env": StatelessCartPole,
            "gamma": 0.99,
            "num_envs_per_worker": 5,
            "framework": "torch",
            "num_workers": 1,
            "num_cpus_per_worker": 2.0,
            "lr": 0.01,
            "entropy_coeff": 0.0005,
            "vf_loss_coeff": 1e-5,
            "model": {
                "custom_model": "dnc",
                "max_seq_len": 64,
                "custom_model_config": {
                    "nr_cells": 10,
                    "cell_size": 8,
                },
            },
        }
        # NOTE(review): the value returned by tune.run is discarded, so no
        # assertion is made here on the reward the run actually reached.
        tune.run("A2C", config=config, stop=self.stop, verbose=1)
if __name__ == "__main__":
    import sys

    import pytest

    # Run this file's tests verbosely and hand pytest's status to the shell.
    exit_code = pytest.main(["-v", __file__])
    sys.exit(exit_code)