ray/rllib/agents/dqn/tests/test_simple_q.py
Sven Mika 428516056a
[RLlib] SAC Torch (incl. Atari learning) (#7984)
* Policy-classes cleanup and torch/tf unification.
- Make Policy abstract.
- Add `action_dist` to call to `extra_action_out_fn` (necessary for PPO torch).
- Move some methods and vars to base Policy
  (from TFPolicy): num_state_tensors, ACTION_PROB, ACTION_LOGP and some more.

* Fix `clip_action` import from Policy (should probably be moved into utils altogether).

* - Move `is_recurrent()` and `num_state_tensors()` into TFPolicy (from DynamicTFPolicy).
- Add config to all Policy c'tor calls (as 3rd arg after obs and action spaces).

* Add `config` to c'tor call to TFPolicy.

* Add missing `config` to c'tor call to TFPolicy in marvil_policy.py.

* Fix test_rollout_worker.py::MockPolicy and BadPolicy classes (Policy base class is now abstract).

* Fix LINT errors in Policy classes.

* Implement StatefulPolicy abstract methods in test cases: test_multi_agent_env.py.

* policy.py LINT errors.

* Create a simple TestPolicy to sub-class from when testing Policies (reduces code in some test cases).

* policy.py
- Remove abstractmethod from `apply_gradients` and `compute_gradients` (these are not required if `learn_on_batch` is implemented).
- Fix docstring of `num_state_tensors`.

* Make QMIX torch Policy a child of TorchPolicy (instead of Policy).

* QMixPolicy add empty implementations of abstract Policy methods.

* Store Policy's config in self.config in base Policy c'tor.

* - Make only `compute_actions` in the base Policy an abstractmethod and provide a pass
implementation for all other methods if not defined.
- Fix state_batches=None (most Policies don't have internal states).

* Cartpole tf learning.

* Cartpole tf AND torch learning (in ~ same ts).

* Cartpole tf AND torch learning (in ~ same ts). 2

* Cartpole tf (torch syntax-broken) learning (in ~ same ts). 3

* Cartpole tf AND torch learning (in ~ same ts). 4

* Cartpole tf AND torch learning (in ~ same ts). 5

* Cartpole tf AND torch learning (in ~ same ts). 6

* Cartpole tf AND torch learning (in ~ same ts). Pendulum tf learning.

* WIP.

* WIP.

* SAC torch learning Pendulum.

* WIP.

* SAC torch and tf learning Pendulum and Cartpole after cleanup.

* WIP.

* LINT.

* LINT.

* SAC: Move policy.target_model to policy.device as well.

* Fixes and cleanup.

* Fix data-format of tf keras Conv2d layers (broken for some tf-versions which have data_format="channels_first" as default).

* Fixes and LINT.

* Fixes and LINT.

* Fix and LINT.

* WIP.

* Test fixes and LINT.

* Fixes and LINT.

Co-authored-by: Sven Mika <sven@Svens-MacBook-Pro.local>
2020-04-15 13:25:16 +02:00

102 lines
4.1 KiB
Python

import numpy as np
import unittest
import ray.rllib.agents.dqn as dqn
from ray.rllib.agents.dqn.simple_q_tf_policy import build_q_losses as loss_tf
from ray.rllib.agents.dqn.simple_q_torch_policy import build_q_losses as \
loss_torch
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.numpy import fc, one_hot, huber_loss
from ray.rllib.utils.test_utils import check, framework_iterator
# Module-level `tf` is whatever `try_import_tf()` returns.
# NOTE(review): presumably the TensorFlow module, or None when TensorFlow is
# not installed — confirm against ray.rllib.utils.framework.
tf = try_import_tf()
class TestSimpleQ(unittest.TestCase):
    """Compilation and loss-function tests for `dqn.SimpleQTrainer`."""

    def test_simple_q_compilation(self):
        """Test whether a SimpleQTrainer can be built on all frameworks."""
        config = dqn.SIMPLE_Q_DEFAULT_CONFIG.copy()
        config["num_workers"] = 0  # Run locally.

        for _ in framework_iterator(config):
            trainer = dqn.SimpleQTrainer(config=config, env="CartPole-v0")
            num_iterations = 2
            for _ in range(num_iterations):
                results = trainer.train()
                print(results)

    @staticmethod
    def _q_values(obs, weights, fw):
        """Pass `obs` through the hidden layer and the q-value layer.

        Args:
            obs: Batch of observations fed to the first `fc` call.
            weights: Indexable of four arrays (two kernels, two biases).
            fw: Framework string; forwarded to `fc` and used to pick the
                weight ordering.

        Returns:
            The output of the second `fc` call (one q-value per action).
        """
        # Non-torch order: 0=layer-kernel; 1=layer-bias; 2=q-val-kernel;
        # 3=q-val-bias. For torch the two (kernel, bias) pairs are swapped.
        is_torch = fw == "torch"
        hidden_k, hidden_b = (2, 3) if is_torch else (0, 1)
        q_k, q_b = (0, 1) if is_torch else (2, 3)
        hidden = fc(obs, weights[hidden_k], weights[hidden_b], framework=fw)
        return fc(hidden, weights[q_k], weights[q_b], framework=fw)

    def test_simple_q_loss_function(self):
        """Tests the Simple-Q loss function results on all frameworks."""
        config = dqn.SIMPLE_Q_DEFAULT_CONFIG.copy()
        # Run locally.
        config["num_workers"] = 0
        # Use very simple net (layer0=10 nodes, q-layer=2 nodes (2 actions)).
        # The `.copy()` above is shallow, so build a fresh model dict rather
        # than writing into the nested one shared with the default config.
        config["model"] = dict(
            config["model"], fcnet_hiddens=[10], fcnet_activation="linear")

        for fw in framework_iterator(config):
            # Generate Trainer and get its default Policy object.
            trainer = dqn.SimpleQTrainer(config=config, env="CartPole-v0")
            policy = trainer.get_policy()

            # Batch of size=2.
            input_ = {
                SampleBatch.CUR_OBS: np.random.random(size=(2, 4)),
                SampleBatch.ACTIONS: np.array([0, 1]),
                SampleBatch.REWARDS: np.array([0.4, -1.23]),
                SampleBatch.DONES: np.array([False, False]),
                SampleBatch.NEXT_OBS: np.random.random(size=(2, 4)),
            }

            # Get model vars for computing expected model outs (q-vals).
            weights = policy.get_weights()
            if isinstance(weights, dict):
                weights = list(weights.values())
            target_weights = policy.target_q_func_vars
            if fw == "tf":
                # In tf, evaluate the target vars through the session.
                target_weights = policy.get_session().run(target_weights)

            # Q(s,a) outputs.
            q_t = np.sum(
                one_hot(input_[SampleBatch.ACTIONS], 2) * self._q_values(
                    input_[SampleBatch.CUR_OBS], weights, fw), 1)
            # max[a'](Qtarget(s',a')) outputs.
            q_target_tp1 = np.max(
                self._q_values(
                    input_[SampleBatch.NEXT_OBS], target_weights, fw), 1)

            # TD-errors (Bellman equation): Q(s,a) - (r + gamma * max Q').
            # All DONES in this batch are False, so no terminal masking of
            # the bootstrap term is applied here.
            td_error = q_t - (
                input_[SampleBatch.REWARDS] + config["gamma"] * q_target_tp1)
            # Huber/Square loss on TD-error.
            expected_loss = huber_loss(td_error).mean()

            if fw == "torch":
                input_ = policy._lazy_tensor_dict(input_)

            # Get actual out and compare.
            if fw == "tf":
                out = policy.get_session().run(
                    policy._loss,
                    feed_dict=policy._get_loss_inputs_dict(
                        input_, shuffle=False))
            else:
                loss_fn = loss_torch if fw == "torch" else loss_tf
                out = loss_fn(policy, policy.model, None, input_)
            check(out, expected_loss, decimals=1)
if __name__ == "__main__":
    # Script entry point: run this file through pytest in verbose mode and
    # hand pytest's return value to sys.exit.
    import sys

    import pytest

    exit_code = pytest.main(["-v", __file__])
    sys.exit(exit_code)