2021-07-25 16:04:52 +02:00
|
|
|
import gym
|
2021-05-19 00:15:39 -07:00
|
|
|
import unittest
|
|
|
|
|
|
|
|
import ray
|
|
|
|
from ray import tune
|
|
|
|
from ray.rllib.examples.env.stateless_cartpole import StatelessCartPole
|
|
|
|
from ray.rllib.models.catalog import ModelCatalog
|
|
|
|
from ray.rllib.examples.models.neural_computer import DNCMemory
|
2021-07-25 16:04:52 +02:00
|
|
|
from ray.rllib.utils.framework import try_import_torch
|
|
|
|
|
|
|
|
# try_import_torch() returns a pair; only the first element (the torch module)
# is used by the tests below. Presumably it is None when PyTorch is not
# installed -- NOTE(review): confirm against ray.rllib.utils.framework.
torch, _ = try_import_torch()
|
2021-05-19 00:15:39 -07:00
|
|
|
|
|
|
|
|
|
|
|
class TestDNC(unittest.TestCase):
    """Tests for the ``DNCMemory`` example model: state round-trip and learning."""

    # Tune stopping criteria for ``test_dnc_learning``. With a dict, Tune stops
    # the trial as soon as ANY of these thresholds is reached, so the trial may
    # also end because the timestep budget ran out.
    stop = {
        "episode_reward_mean": 100.0,
        "timesteps_total": 10000000,
    }

    @classmethod
    def setUpClass(cls) -> None:
        ray.init(num_cpus=4, ignore_reinit_error=True)

    @classmethod
    def tearDownClass(cls) -> None:
        ray.shutdown()

    def test_pack_unpack(self):
        """``pack_state(*unpack_state(s))`` must reproduce ``s`` exactly."""
        d = DNCMemory(gym.spaces.Discrete(1), gym.spaces.Discrete(1), 1, {}, "")
        # Add batch dim
        packed_state = [m.unsqueeze(0) for m in d.get_initial_state()]
        # Randomize each state tensor in place so the round-trip is checked
        # against arbitrary values rather than only the initial state.
        for m in packed_state:
            m.random_()
        original_packed = [m.clone() for m in packed_state]

        unpacked = d.unpack_state(packed_state)
        packed = d.pack_state(*unpacked)

        self.assertGreater(len(packed), 0)
        self.assertEqual(len(packed), len(original_packed))

        for m_idx, (actual, expected) in enumerate(zip(packed, original_packed)):
            with self.subTest(m_idx=m_idx):
                # Compare shapes explicitly: `==` broadcasts, so an
                # element-wise check alone could accept a mis-shaped tensor.
                self.assertEqual(actual.shape, expected.shape)
                self.assertTrue(torch.all(actual == expected))

    def test_dnc_learning(self):
        """A2C with the DNC model must reach the target reward on StatelessCartPole."""
        ModelCatalog.register_custom_model("dnc", DNCMemory)

        config = {
            "env": StatelessCartPole,
            "gamma": 0.99,
            "num_envs_per_worker": 5,
            "framework": "torch",
            "num_workers": 1,
            "num_cpus_per_worker": 2.0,
            "lr": 0.01,
            "entropy_coeff": 0.0005,
            "vf_loss_coeff": 1e-5,
            "model": {
                "custom_model": "dnc",
                "max_seq_len": 64,
                "custom_model_config": {
                    "nr_cells": 10,
                    "cell_size": 8,
                },
            },
        }
        analysis = tune.run("A2C", config=config, stop=self.stop, verbose=1)

        # tune.run also returns when the timestep budget is exhausted, so its
        # returning does not by itself prove the model learned: check the
        # final reward of every trial against the target explicitly.
        reward_target = self.stop["episode_reward_mean"]
        for trial in analysis.trials:
            self.assertGreaterEqual(
                trial.last_result["episode_reward_mean"], reward_target
            )
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    import sys

    import pytest

    # Running this file directly delegates to pytest in verbose mode and
    # forwards pytest's exit status to the calling process.
    pytest_args = ["-v", __file__]
    sys.exit(pytest.main(pytest_args))
|