# Mirror of https://github.com/vale981/ray
# Synced 2025-03-10 13:26:39 -04:00
# 87 lines, 2.7 KiB, Python
import unittest
|
|
|
|
import ray
|
|
import ray.rllib.agents.dqn as dqn
|
|
from ray.rllib.utils.framework import try_import_tf, try_import_torch
|
|
from ray.rllib.utils.metrics.learner_info import LEARNER_INFO
|
|
from ray.rllib.utils.test_utils import (
|
|
check_compute_single_action,
|
|
check_train_results,
|
|
framework_iterator,
|
|
)
|
|
|
|
# Unpack the three values returned by the TensorFlow import helper.
# NOTE(review): presumably these are None when TensorFlow is not installed
# -- confirm against try_import_tf.
tf1, tf, tfv = try_import_tf()
|
|
# Unpack the two values returned by the PyTorch import helper.
# NOTE(review): presumably these are None when PyTorch is not installed
# -- confirm against try_import_torch.
torch, nn = try_import_torch()
|
|
|
|
|
|
def check_batch_sizes(train_results):
    """Assert that learner batch sizes agree with the trainer config.

    For every entry in the learner info (the "batch_count" entry is
    skipped), the leading dimension of the entry's "td_error" is compared
    against the configured "train_batch_size". When the configured size
    exceeds the actual size by more than 10% (relative to the actual size),
    the actual size must equal the configured size divided by
    ``max_seq_len + replay_burn_in``.

    Args:
        train_results: Result dict providing the "info" and "config" keys
            read below.

    Raises:
        AssertionError: If a batch size deviates and does not match the
            sequence-adjusted expectation.
    """
    per_policy_stats = train_results["info"][LEARNER_INFO]

    for policy_id, stats in per_policy_stats.items():
        if policy_id == "batch_count":
            continue

        # Expect td-errors to be per batch-item.
        cfg = train_results["config"]
        expected_size = cfg["train_batch_size"]
        observed_size = stats["td_error"].shape[0]

        relative_gap = (expected_size - observed_size) / observed_size
        if relative_gap > 0.1:
            # The batch is noticeably smaller than configured: it must then
            # be the configured size divided by the full sequence span.
            seq_span = (
                cfg["model"]["max_seq_len"]
                + cfg["replay_buffer_config"]["replay_burn_in"]
            )
            assert expected_size / seq_span == observed_size
|
|
|
|
|
|
class TestR2D2(unittest.TestCase):
    """Compilation test for the R2D2 trainer."""

    @classmethod
    def setUpClass(cls) -> None:
        """Start Ray once before any test in this class runs."""
        ray.init()

    @classmethod
    def tearDownClass(cls) -> None:
        """Shut Ray down after all tests in this class have run."""
        ray.shutdown()

    def test_r2d2_compilation(self):
        """Build an R2D2 trainer per framework and run one training step."""
        # Wrap with an LSTM and use a very simple base-model.
        lstm_model = {
            "use_lstm": True,
            "max_seq_len": 20,
            "fcnet_hiddens": [32],
            "lstm_cell_size": 64,
        }

        config = dqn.r2d2.R2D2Config()
        config = config.rollouts(num_rollout_workers=0)
        config = config.training(
            model=lstm_model,
            dueling=False,
            lr=5e-4,
            zero_init_states=True,
            replay_buffer_config={"replay_burn_in": 20},
        )
        config = config.exploration(exploration_config={"epsilon_timesteps": 100000})

        train_steps = 1

        # Test building an R2D2 agent in all frameworks.
        for _ in framework_iterator(config, with_eager_tracing=True):
            algo = config.build(env="CartPole-v0")
            for _step in range(train_steps):
                train_out = algo.train()
                check_train_results(train_out)
                check_batch_sizes(train_out)
                print(train_out)

            # The trainer uses an LSTM, so state is included in the check.
            check_compute_single_action(algo, include_state=True)
|
|
|
|
|
|
if __name__ == "__main__":
    # Run this file's tests verbosely through pytest when executed directly,
    # and use pytest's return value as the process exit status.
    import pytest
    import sys

    exit_status = pytest.main(["-v", __file__])
    sys.exit(exit_status)
|