ray/rllib/agents/dqn/tests/test_simple_q.py
Sven Mika 22ccc43670 · [RLlib] DQN torch version. (#7597)
Co-authored-by: Eric Liang <ekhliang@gmail.com>
2020-04-06 11:56:16 -07:00


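"""Unit tests for RLlib's SimpleQTrainer (added overview): compilation across
frameworks, plus a numeric cross-check of the Simple-Q loss function against
a pure-NumPy re-computation of the expected loss."""
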
import numpy as np
import unittest

import ray.rllib.agents.dqn as dqn
from ray.rllib.agents.dqn.simple_q_tf_policy import build_q_losses as loss_tf
from ray.rllib.agents.dqn.simple_q_torch_policy import build_q_losses as \
    loss_torch
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.numpy import fc, one_hot, huber_loss
from ray.rllib.utils.test_utils import check, framework_iterator

tf = try_import_tf()


class TestSimpleQ(unittest.TestCase):
    def test_simple_q_compilation(self):
        """Test whether a SimpleQTrainer can be built on all frameworks."""
        config = dqn.SIMPLE_Q_DEFAULT_CONFIG.copy()
        config["num_workers"] = 0  # Run locally.

        for _ in framework_iterator(config):
            trainer = dqn.SimpleQTrainer(config=config, env="CartPole-v0")
            num_iterations = 2
            for i in range(num_iterations):
                results = trainer.train()
                print(results)
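            # Added sanity check (not part of the original test): run one
            # inference step through the built trainer. Assumes this RLlib
            # version's Trainer.compute_action accepts a raw observation
            # array for CartPole's Box(4,) observation space.
            action = trainer.compute_action(np.zeros(4, dtype=np.float32))
            self.assertIn(action, [0, 1])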

    def test_simple_q_loss_function(self):
        """Tests the Simple-Q loss function results on all frameworks."""
        config = dqn.SIMPLE_Q_DEFAULT_CONFIG.copy()
        # Run locally.
        config["num_workers"] = 0
        # Use a very simple net (layer0=10 nodes, q-layer=2 nodes; 2 actions).
        config["model"]["fcnet_hiddens"] = [10]
        config["model"]["fcnet_activation"] = "linear"

        for fw in framework_iterator(config):
            # Generate Trainer and get its default Policy object.
            trainer = dqn.SimpleQTrainer(config=config, env="CartPole-v0")
            policy = trainer.get_policy()
            # Batch of size=2 (CartPole's obs space is Box(4,)).
            input_ = {
                SampleBatch.CUR_OBS: np.random.random(size=(2, 4)),
                SampleBatch.ACTIONS: np.array([0, 1]),
                SampleBatch.REWARDS: np.array([0.4, -1.23]),
                SampleBatch.DONES: np.array([False, False]),
                SampleBatch.NEXT_OBS: np.random.random(size=(2, 4)),
            }
            # Get model vars for computing expected model outs (q-vals).
            # 0=layer-kernel; 1=layer-bias; 2=q-val-kernel; 3=q-val-bias.
            vars = policy.get_weights()
            if isinstance(vars, dict):
                vars = list(vars.values())
            vars_t = policy.target_q_func_vars
            if fw == "tf":
                vars_t = policy.get_session().run(vars_t)
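
            # Note (added): fc() from ray.rllib.utils.numpy emulates a dense
            # layer in pure NumPy (matmul + bias); its framework flag accounts
            # for kernel-layout differences between tf and torch weights.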
            # Q(s,a) outputs.
            q_t = np.sum(
                one_hot(input_[SampleBatch.ACTIONS], 2) * fc(
                    fc(input_[SampleBatch.CUR_OBS],
                       vars[0],
                       vars[1],
                       framework=fw),
                    vars[2],
                    vars[3],
                    framework=fw), 1)
            # max[a'](Qtarget(s',a')) outputs.
            q_target_tp1 = np.max(
                fc(fc(input_[SampleBatch.NEXT_OBS],
                      vars_t[0],
                      vars_t[1],
                      framework=fw),
                   vars_t[2],
                   vars_t[3],
                   framework=fw), 1)
            # TD-errors (Bellman equation):
            # Q(s,a) - (r + gamma * max[a'](Qtarget(s',a'))).
            # All dones are False here, so no terminal masking is needed.
            td_error = q_t - (input_[SampleBatch.REWARDS] +
                              config["gamma"] * q_target_tp1)
            # Huber/Square loss on TD-error.
            expected_loss = huber_loss(td_error).mean()
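
            # Added cross-check (assuming RLlib's huber_loss keeps its default
            # delta=1.0): huber(x) = 0.5*x^2 for |x| <= 1.0, else |x| - 0.5.
            manual_huber = np.where(
                np.abs(td_error) <= 1.0,
                0.5 * np.square(td_error),
                np.abs(td_error) - 0.5)
            check(expected_loss, manual_huber.mean())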

            # For torch, convert the input batch into (lazy) tensors first.
            if fw == "torch":
                input_ = policy._lazy_tensor_dict(input_)
            # Get actual out and compare.
            if fw == "tf":
                out = policy.get_session().run(
                    policy._loss,
                    feed_dict=policy._get_loss_inputs_dict(
                        input_, shuffle=False))
            else:
                out = (loss_torch if fw == "torch" else
                       loss_tf)(policy, policy.model, None, input_)
            check(out, expected_loss, decimals=1)


if __name__ == "__main__":
    import pytest
    import sys

    sys.exit(pytest.main(["-v", __file__]))