ray/rllib/policy/tests/test_compute_log_likelihoods.py
Sven Mika 428516056a
[RLlib] SAC Torch (incl. Atari learning) (#7984)
* Policy-classes cleanup and torch/tf unification.
- Make Policy abstract.
- Add `action_dist` to call to `extra_action_out_fn` (necessary for PPO torch).
- Move some methods and vars to base Policy
  (from TFPolicy): num_state_tensors, ACTION_PROB, ACTION_LOGP and some more.

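To make the `action_dist` addition concrete, here is a hypothetical sketch of a torch `extra_action_out_fn` that uses the passed-in distribution object. The function name is made up and the exact parameter order is an assumption (the bullet above only states that `action_dist` is now part of the call):

    def entropy_fetches(policy, input_dict, state_batches, model, action_dist):
        # With `action_dist` passed into the callback, extra per-step outputs
        # can be derived from the live distribution object instead of
        # rebuilding it from model outputs.
        return {
            "action_entropy": action_dist.entropy().detach().cpu().numpy(),
        }
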
* Fix `clip_action` import from Policy (should probably be moved into utils altogether).

* - Move `is_recurrent()` and `num_state_tensors()` into TFPolicy (from DynamicTFPolicy).
- Add config to all Policy c'tor calls (as 3rd arg after obs and action spaces).

* Add `config` to c'tor call to TFPolicy.

* Add missing `config` to c'tor call to TFPolicy in marvil_policy.py.

* Fix test_rollout_worker.py::MockPolicy and BadPolicy classes (Policy base class is now abstract).

* Fix LINT errors in Policy classes.

* Implement StatefulPolicy abstract methods in test cases: test_multi_agent_env.py.

* policy.py LINT errors.

* Create a simple TestPolicy to sub-class from when testing Policies (reduces code in some test cases).

* policy.py
- Remove abstractmethod from `apply_gradients` and `compute_gradients` (these are not required iff `learn_on_batch` implemented).
- Fix docstring of `num_state_tensors`.

* Make QMIX torch Policy a child of TorchPolicy (instead of Policy).

* QMixPolicy add empty implementations of abstract Policy methods.

* Store Policy's config in self.config in base Policy c'tor.

* - Make only `compute_actions` in the base Policy an abstractmethod and provide a pass
implementation for all other methods if not defined.
- Fix state_batches=None (most Policies don't have internal states).

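Taken together, the constructor and override contract described in the bullets above (config as third c'tor argument, `compute_actions` as the only required override, `learn_on_batch` in place of `compute_gradients`/`apply_gradients`) could look roughly like this for a trivial custom policy. The class itself is hypothetical and only illustrates the signatures:

    from ray.rllib.policy.policy import Policy


    class RandomPolicy(Policy):
        """Hypothetical minimal policy; only compute_actions() is required."""

        def __init__(self, observation_space, action_space, config):
            # `config` is now the 3rd c'tor arg and is stored on the base class.
            super().__init__(observation_space, action_space, config)

        def compute_actions(self,
                            obs_batch,
                            state_batches=None,  # most policies are stateless
                            prev_action_batch=None,
                            prev_reward_batch=None,
                            **kwargs):
            # Return (actions, rnn state-outs, extra action fetches).
            return [self.action_space.sample() for _ in obs_batch], [], {}

        def learn_on_batch(self, samples):
            # With learn_on_batch() implemented, compute_gradients()/
            # apply_gradients() need not be overridden; all other methods
            # fall back to the base class' pass implementations.
            return {}
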
* Cartpole tf learning.

* Cartpole tf AND torch learning (in ~ same ts).

* Cartpole tf AND torch learning (in ~ same ts). 2

* Cartpole tf (torch syntax-broken) learning (in ~ same ts). 3

* Cartpole tf AND torch learning (in ~ same ts). 4

* Cartpole tf AND torch learning (in ~ same ts). 5

* Cartpole tf AND torch learning (in ~ same ts). 6

* Cartpole tf AND torch learning (in ~ same ts). Pendulum tf learning.

* WIP.

* WIP.

* SAC torch learning Pendulum.

* WIP.

* SAC torch and tf learning Pendulum and Cartpole after cleanup.

* WIP.

* LINT.

* LINT.

* SAC: Move policy.target_model to policy.device as well.

* Fixes and cleanup.

* Fix data-format of tf keras Conv2d layers (broken for some tf-versions which have data_format="channels_first" as default).

* Fixes and LINT.

* Fixes and LINT.

* Fix and LINT.

* WIP.

* Test fixes and LINT.

* Fixes and LINT.

Co-authored-by: Sven Mika <sven@Svens-MacBook-Pro.local>
2020-04-15 13:25:16 +02:00


import numpy as np
from scipy.stats import norm
import unittest

import ray.rllib.agents.dqn as dqn
import ray.rllib.agents.pg as pg
import ray.rllib.agents.ppo as ppo
import ray.rllib.agents.sac as sac
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.test_utils import check, framework_iterator
from ray.rllib.utils.numpy import one_hot, fc, MIN_LOG_NN_OUTPUT, \
    MAX_LOG_NN_OUTPUT

tf = try_import_tf()

def do_test_log_likelihood(run,
                           config,
                           prev_a=None,
                           continuous=False,
                           layer_key=("fc", (0, 4),
                                      ("_hidden_layers.0.", "_logits.")),
                           logp_func=None):
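    """Tests compute_log_likelihoods() of `run`'s policy against expectations.

    `run` is the Trainer class under test. `layer_key` bundles the
    framework-specific weight lookups: a tf variable-name prefix, the indices
    of the hidden/output kernels in the tf weights list, and the torch
    state-dict key prefixes. `logp_func`, if given, replaces the default
    Gaussian logp used for the continuous-action check.
    """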
    config = config.copy()
    # Run locally.
    config["num_workers"] = 0

    # Env setup.
    if continuous:
        env = "Pendulum-v0"
        obs_batch = preprocessed_obs_batch = np.array([[0.0, 0.1, -0.1]])
    else:
        env = "FrozenLake-v0"
        config["env_config"] = {"is_slippery": False, "map_name": "4x4"}
        obs_batch = np.array([0])
        preprocessed_obs_batch = one_hot(obs_batch, depth=16)

    prev_r = None if prev_a is None else np.array(0.0)

    # Test against all frameworks.
    for fw in framework_iterator(config):
        if run in [sac.SACTrainer] and fw == "eager":
            continue

        trainer = run(config=config, env=env)
        policy = trainer.get_policy()
        vars = policy.get_weights()

        # Sample n actions, then roughly check their logp against their
        # counts.
        num_actions = 1000 if not continuous else 50
        actions = []
        for _ in range(num_actions):
            # Single action from single obs.
            actions.append(
                trainer.compute_action(
                    obs_batch[0],
                    prev_action=prev_a,
                    prev_reward=prev_r,
                    explore=True))

        # Test all taken actions for their log-likelihoods vs expected values.
        if continuous:
            for idx in range(num_actions):
                a = actions[idx]
                if fw == "tf" or fw == "eager":
                    if isinstance(vars, list):
                        expected_mean_logstd = fc(
                            fc(obs_batch, vars[layer_key[1][0]]),
                            vars[layer_key[1][1]])
                    else:
                        expected_mean_logstd = fc(
                            fc(
                                obs_batch,
                                vars["default_policy/{}_1/kernel".format(
                                    layer_key[0])]),
                            vars["default_policy/{}_out/kernel".format(
                                layer_key[0])])
                else:
                    expected_mean_logstd = fc(
                        fc(obs_batch,
                           vars["{}_model.0.weight".format(layer_key[2][0])],
                           framework=fw),
                        vars["{}_model.0.weight".format(layer_key[2][1])],
                        framework=fw)
                mean, log_std = np.split(expected_mean_logstd, 2, axis=-1)
                if logp_func is None:
                    expected_logp = np.log(norm.pdf(a, mean, np.exp(log_std)))
                else:
                    expected_logp = logp_func(mean, log_std, a)
                logp = policy.compute_log_likelihoods(
                    np.array([a]),
                    preprocessed_obs_batch,
                    prev_action_batch=np.array([prev_a]),
                    prev_reward_batch=np.array([prev_r]))
                check(logp, expected_logp[0], rtol=0.2)
        # Test all available actions for their logp values.
        else:
            for a in [0, 1, 2, 3]:
                count = actions.count(a)
                expected_prob = count / num_actions
                logp = policy.compute_log_likelihoods(
                    np.array([a]),
                    preprocessed_obs_batch,
                    prev_action_batch=np.array([prev_a]),
                    prev_reward_batch=np.array([prev_r]))
                check(np.exp(logp), expected_prob, atol=0.2)


class TestComputeLogLikelihood(unittest.TestCase):
    def test_dqn(self):
        """Tests whether DQN correctly computes logp in soft-Q mode."""
        config = dqn.DEFAULT_CONFIG.copy()
        # Soft-Q for DQN.
        config["exploration_config"] = {"type": "SoftQ", "temperature": 0.5}
        do_test_log_likelihood(dqn.DQNTrainer, config)

    def test_pg_cont(self):
        """Tests PG's (cont. actions) compute_log_likelihoods method."""
        config = pg.DEFAULT_CONFIG.copy()
        config["model"]["fcnet_hiddens"] = [10]
        config["model"]["fcnet_activation"] = "linear"
        prev_a = np.array([0.0])
        do_test_log_likelihood(
            pg.PGTrainer,
            config,
            prev_a,
            continuous=True,
            layer_key=("fc", (0, 2), ("_hidden_layers.0.", "_logits.")))

    def test_pg_discr(self):
        """Tests PG's (discr. actions) compute_log_likelihoods method."""
        config = pg.DEFAULT_CONFIG.copy()
        prev_a = np.array(0)
        do_test_log_likelihood(pg.PGTrainer, config, prev_a)

    def test_ppo_cont(self):
        """Tests PPO's (cont. actions) compute_log_likelihoods method."""
        config = ppo.DEFAULT_CONFIG.copy()
        config["model"]["fcnet_hiddens"] = [10]
        config["model"]["fcnet_activation"] = "linear"
        prev_a = np.array([0.0])
        do_test_log_likelihood(ppo.PPOTrainer, config, prev_a, continuous=True)

    def test_ppo_discr(self):
        """Tests PPO's (discr. actions) compute_log_likelihoods method."""
        prev_a = np.array(0)
        do_test_log_likelihood(ppo.PPOTrainer, ppo.DEFAULT_CONFIG, prev_a)

    def test_sac_cont(self):
        """Tests SAC's (cont. actions) compute_log_likelihoods method."""
        config = sac.DEFAULT_CONFIG.copy()
        config["policy_model"]["fcnet_hiddens"] = [10]
        config["policy_model"]["fcnet_activation"] = "linear"
        prev_a = np.array([0.0])

        # SAC cont uses a squashed normal distribution. Implement its logp
        # logic here in numpy for comparing results.
        def logp_func(means, log_stds, values, low=-1.0, high=1.0):
            stds = np.exp(
                np.clip(log_stds, MIN_LOG_NN_OUTPUT, MAX_LOG_NN_OUTPUT))
            # Undo the tanh squashing: map values from [low, high] back into
            # the unsquashed (Gaussian) space.
            unsquashed_values = np.arctanh((values - low) /
                                           (high - low) * 2.0 - 1.0)
            # Gaussian logp in the unsquashed space ...
            log_prob_unsquashed = \
                np.sum(np.log(norm.pdf(unsquashed_values, means, stds)), -1)
            # ... corrected by the tanh change-of-variables term:
            # log p(a) = log p(u) - sum_i log(1 - tanh(u_i)^2).
            return log_prob_unsquashed - \
                np.sum(np.log(1 - np.tanh(unsquashed_values) ** 2),
                       axis=-1)

        do_test_log_likelihood(
            sac.SACTrainer,
            config,
            prev_a,
            continuous=True,
            layer_key=("sequential/action", (0, 2),
                       ("action_model.action_0.", "action_model.action_out.")),
            logp_func=logp_func)

    def test_sac_discr(self):
        """Tests SAC's (discrete actions) compute_log_likelihoods method."""
        config = sac.DEFAULT_CONFIG.copy()
        config["policy_model"]["fcnet_hiddens"] = [10]
        config["policy_model"]["fcnet_activation"] = "linear"
        prev_a = np.array(0)
        do_test_log_likelihood(
            sac.SACTrainer,
            config,
            prev_a,
            layer_key=("sequential/action", (0, 2),
                       ("action_model.action_0.", "action_model.action_out.")))


if __name__ == "__main__":
    import pytest
    import sys
    sys.exit(pytest.main(["-v", __file__]))