2020-02-15 23:50:44 +01:00
import numpy as np
2020-02-22 23:19:49 +01:00
from gym.spaces import Box
2020-04-15 13:25:16 +02:00
from scipy.stats import norm, beta
2020-02-22 23:19:49 +01:00
import unittest
2020-02-15 23:50:44 +01:00
2020-03-04 09:41:40 +01:00
from ray.rllib.models.tf.tf_action_dist import Categorical, MultiCategorical, \
2020-03-06 19:37:12 +01:00
SquashedGaussian, GumbelSoftmax
2020-04-15 13:25:16 +02:00
from ray.rllib.models.torch.torch_action_dist import TorchMultiCategorical, \
TorchSquashedGaussian, TorchBeta
2020-03-04 09:41:40 +01:00
from ray.rllib.utils import try_import_tf, try_import_torch
2020-04-15 13:25:16 +02:00
from ray.rllib.utils.numpy import MIN_LOG_NN_OUTPUT, MAX_LOG_NN_OUTPUT, \
2020-04-03 21:24:25 +02:00
from ray.rllib.utils.test_utils import check, framework_iterator
2020-02-15 23:50:44 +01:00
# Framework handles are obtained through RLlib's import helpers instead of
# plain `import` statements; the tests below refer to them as `tf` and `torch`.
# NOTE(review): presumably these helpers tolerate a framework that is not
# installed - confirm against ray.rllib.utils.
tf = try_import_tf()
2020-03-04 09:41:40 +01:00
# Only the first element of the returned pair is kept.
torch, _ = try_import_torch()
2020-02-15 23:50:44 +01:00
class TestDistributions(unittest.TestCase):
2020-02-22 23:19:49 +01:00
"""Tests ActionDistribution classes."""
2020-02-15 23:50:44 +01:00
def test_categorical(self):
    """Tests the Categorical ActionDistribution (tf only).

    Feeds the same random logits row `num_samples` times, draws one
    sample per row and asserts that the empirical category frequencies
    deviate from the softmax of the logits by at most 0.01 in total
    absolute error.
    """
    num_samples = 100000
    num_categories = 10
    logits = tf.placeholder(tf.float32, shape=(None, num_categories))
    # Random logits, uniform in [-4, 4).
    z = 8 * (np.random.rand(num_categories) - 0.5)
    data = np.tile(z, (num_samples, 1))
    c = Categorical(logits, {})  # dummy config dict
    sample_op = c.sample()
    # Use the session as a context manager so it is closed even if `run`
    # raises; the original never released it.
    with tf.Session() as sess:
        samples = sess.run(sample_op, feed_dict={logits: data})
    # Histogram of the sampled category indices in one vectorized call
    # (replaces a Python loop over all samples).
    counts = np.bincount(
        np.asarray(samples).ravel(), minlength=num_categories)
    exp_z = np.exp(z)
    probs = exp_z / np.sum(exp_z)
    total_abs_error = np.sum(np.abs(probs - counts / num_samples))
    # assertLessEqual reports both values on failure.
    self.assertLessEqual(total_abs_error, 0.01)
2020-02-19 21:18:45 +01:00
2020-03-04 09:41:40 +01:00
def test_multi_categorical(self):
# Exercises MultiCategorical / TorchMultiCategorical for every framework
# yielded by framework_iterator(): deterministic sampling, random sampling,
# log-likelihoods and entropy are each compared against NumPy expectations.
# NOTE(review): this listing has lost its indentation, carries stray
# timestamp lines, and several statements below stop mid-expression.
# Restore the method from version control before relying on it.
batch_size = 100
num_categories = 3
num_sub_distributions = 5
# Create 5 categorical distributions of 3 categories each.
# NOTE(review): no low/high bounds are visible in this Box(...) call - only
# the `shape` keyword survives.
inputs_space = Box(
shape=(batch_size, num_sub_distributions * num_categories))
# NOTE(review): this Box(...) call is never closed and only one positional
# bound is visible; the remaining arguments are not shown here.
values_space = Box(
num_categories - 1,
shape=(num_sub_distributions, batch_size),
inputs = inputs_space.sample()
# One length entry per sub-distribution.
input_lengths = [num_categories] * num_sub_distributions
# Split the flat inputs along axis 1 into one chunk per sub-distribution.
inputs_split = np.split(inputs, num_sub_distributions, axis=1)
2020-04-03 21:24:25 +02:00
for fw in framework_iterator():
2020-03-06 19:37:12 +01:00
# Create the correct distribution object.
2020-03-04 09:41:40 +01:00
cls = MultiCategorical if fw != "torch" else TorchMultiCategorical
multi_categorical = cls(inputs, None, input_lengths)
# Deterministic draw: expect the argmax of each sub-distribution's chunk.
expected = np.transpose(np.argmax(inputs_split, axis=-1))
# Sample, expect always max value
# (max likelihood for deterministic draw).
out = multi_categorical.deterministic_sample()
check(out, expected)
# Non-deterministic draw -> expect roughly the mean.
# NOTE(review): the comparison that consumed this sample appears cut off;
# only the torch branch of a conditional expression remains on the next line.
out = multi_categorical.sample()
if fw != "torch" else torch.mean(out.float()),
# Test log-likelihood outputs.
# NOTE(review): `softmax` is not among the visible imports (the
# ray.rllib.utils.numpy import ends in a dangling continuation).
probs = softmax(inputs_split)
values = values_space.sample()
# The torch class is handed a list with one tensor per sub-distribution;
# the other frameworks receive the raw array.
out = multi_categorical.logp(values if fw != "torch" else [
torch.Tensor(values[i]) for i in range(num_sub_distributions)
]) # v in np.stack(values, 1)])
expected = []
# NOTE(review): the loop that fills `expected` is incomplete as shown (the
# inner `for` has neither colon nor body).
for i in range(batch_size):
for j in range(num_sub_distributions)
check(out, expected, decimals=4)
# Test entropy outputs.
out = multi_categorical.entropy()
# Sum -p * log(p) over the sub-distribution axis (0) and the category axis (-1).
expected_entropy = -np.sum(np.sum(probs * np.log(probs), 0), -1)
check(out, expected_entropy)
2020-02-22 23:19:49 +01:00
def test_squashed_gaussian(self):
2020-04-15 13:25:16 +02:00
"""Tests the SquashedGaussian ActionDistribution for all frameworks."""
# NOTE(review): this listing has lost its indentation, carries stray
# timestamp lines, and several statements below stop mid-expression.
# Restore the method from version control before relying on it.
input_space = Box(-2.0, 2.0, shape=(200, 10))
# Bounds of the squashed output range.
low, high = -2.0, 1.0
for fw, sess in framework_iterator(session=True):
cls = SquashedGaussian if fw != "torch" else TorchSquashedGaussian
2020-02-22 23:19:49 +01:00
# Batch of size=n and deterministic.
inputs = input_space.sample()
# First half of the last axis is used as the means; the second half is
# discarded for this check.
means, _ = np.split(inputs, 2, axis=-1)
2020-04-15 13:25:16 +02:00
squashed_distribution = cls(inputs, {}, low=low, high=high)
2020-02-22 23:19:49 +01:00
# tanh(mean) rescaled from (-1, 1) to (low, high).
expected = ((np.tanh(means) + 1.0) / 2.0) * (high - low) + low
# Sample n times, expect always mean value (deterministic draw).
out = squashed_distribution.deterministic_sample()
check(out, expected)
# Batch of size=n and non-deterministic -> expect roughly the mean.
inputs = input_space.sample()
means, log_stds = np.split(inputs, 2, axis=-1)
2020-04-15 13:25:16 +02:00
squashed_distribution = cls(inputs, {}, low=low, high=high)
2020-02-22 23:19:49 +01:00
expected = ((np.tanh(means) + 1.0) / 2.0) * (high - low) + low
values = squashed_distribution.sample()
2020-04-03 21:24:25 +02:00
# NOTE(review): an `else:` seems to be missing between the session run and
# the `.numpy()` conversion below - confirm against version control.
if sess:
values = sess.run(values)
2020-04-15 13:25:16 +02:00
values = values.numpy()
2020-02-22 23:19:49 +01:00
# Samples must lie strictly inside (low, high).
self.assertTrue(np.max(values) < high)
self.assertTrue(np.min(values) > low)
check(np.mean(values), expected.mean(), decimals=1)
# Test log-likelihood outputs.
2020-04-15 13:25:16 +02:00
sampled_action_logp = squashed_distribution.logp(
values if fw != "torch" else torch.Tensor(values))
2020-04-03 21:24:25 +02:00
# NOTE(review): same apparent missing `else:` as for `values` above.
if sess:
sampled_action_logp = sess.run(sampled_action_logp)
2020-04-15 13:25:16 +02:00
sampled_action_logp = sampled_action_logp.numpy()
2020-02-22 23:19:49 +01:00
# Convert to parameters for distr.
# log-stds are clipped to [MIN_LOG_NN_OUTPUT, MAX_LOG_NN_OUTPUT] before exp.
stds = np.exp(
np.clip(log_stds, MIN_LOG_NN_OUTPUT, MAX_LOG_NN_OUTPUT))
# Unsquash values, then get log-llh from regular gaussian.
2020-04-15 13:25:16 +02:00
# NOTE(review): the next line is a partial, commented-out clipping variant.
# atanh_in = np.clip((values - low) / (high - low) * 2.0 - 1.0,
# Map values from (low, high) back to (-1, 1), then invert the tanh.
atanh_in = (values - low) / (high - low) * 2.0 - 1.0
unsquashed_values = np.arctanh(atanh_in)
# NOTE(review): parentheses below are unbalanced and `SMALL_NUMBER` is not
# among the visible imports - part of this expression appears to be missing.
log_prob_unsquashed = np.sum(
norm.pdf(unsquashed_values, means, stds) + SMALL_NUMBER),
2020-02-22 23:19:49 +01:00
# Subtract the log of (1 - tanh(x)^2), summed over the unsquashed values.
# NOTE(review): the np.sum(...) call below is not closed as shown.
log_prob = log_prob_unsquashed - \
np.sum(np.log(1 - np.tanh(unsquashed_values) ** 2),
2020-04-15 13:25:16 +02:00
# Only the totals are compared, with a 5% relative tolerance.
check(np.sum(sampled_action_logp), np.sum(log_prob), rtol=0.05)
2020-02-22 23:19:49 +01:00
# NN output.
# Hand-written inputs; note the outlier mean of 50.0 in the first row.
means = np.array([[0.1, 0.2, 0.3, 0.4, 50.0],
[-0.1, -0.2, -0.3, -0.4, -1.0]])
log_stds = np.array([[0.8, -0.2, 0.3, -1.0, 2.0],
[0.7, -0.3, 0.4, -0.9, 2.0]])
2020-04-15 13:25:16 +02:00
# NOTE(review): this constructor call is never closed; any arguments after
# `inputs` are not visible here.
squashed_distribution = cls(
inputs=np.concatenate([means, log_stds], axis=-1),
2020-02-22 23:19:49 +01:00
# Convert to parameters for distr.
stds = np.exp(log_stds)
# Values to get log-likelihoods for.
values = np.array([[0.9, 0.2, 0.4, -0.1, -1.05],
[-0.9, -0.2, 0.4, -0.1, -1.05]])
# Unsquash values, then get log-llh from regular gaussian.
unsquashed_values = np.arctanh((values - low) /
(high - low) * 2.0 - 1.0)
# Gaussian log-densities summed over the last axis.
log_prob_unsquashed = \
np.sum(np.log(norm.pdf(unsquashed_values, means, stds)), -1)
# NOTE(review): the np.sum(...) call in the next statement is not closed as
# shown.
log_prob = log_prob_unsquashed - \
np.sum(np.log(1 - np.tanh(unsquashed_values) ** 2),
2020-04-15 13:25:16 +02:00
# NOTE(review): the logp(...) call below is truncated after `else`.
outs = squashed_distribution.logp(values if fw != "torch" else
2020-04-03 21:24:25 +02:00
if sess:
outs = sess.run(outs)
2020-04-15 13:25:16 +02:00
check(outs, log_prob, decimals=4)
def test_beta(self):
# Exercises TorchBeta: deterministic sample vs. the rescaled Beta mean,
# bounds and mean of random samples, and log-likelihoods vs. scipy's beta.pdf.
# NOTE(review): this listing has lost its indentation and the final
# comparison is truncated; restore the method from version control before
# relying on it.
input_space = Box(-2.0, 1.0, shape=(200, 10))
# Bounds of the scaled output range.
low, high = -1.0, 2.0
# Value space on the unscaled [0, 1] support, used for the logp check.
plain_beta_value_space = Box(0.0, 1.0, shape=(200, 5))
# Only the "torch" framework is requested here.
for fw, sess in framework_iterator(frameworks="torch", session=True):
cls = TorchBeta
inputs = input_space.sample()
beta_distribution = cls(inputs, {}, low=low, high=high)
# Read the inputs back from the distribution object rather than reusing the
# raw sample.
# NOTE(review): presumably the distribution transforms its inputs on
# construction - confirm in TorchBeta.
inputs = beta_distribution.inputs
# First half of the last axis -> alpha, second half -> beta_.
alpha, beta_ = np.split(inputs.numpy(), 2, axis=-1)
# Mean for a Beta distribution: 1 / [1 + (beta/alpha)]
# ... rescaled from [0, 1] to [low, high].
expected = (1.0 / (1.0 + beta_ / alpha)) * (high - low) + low
# Sample n times, expect always mean value (deterministic draw).
out = beta_distribution.deterministic_sample()
check(out, expected, rtol=0.01)
# Batch of size=n and non-deterministic -> expect roughly the mean.
values = beta_distribution.sample()
# NOTE(review): an `else:` seems to be missing between the session run and
# the `.numpy()` conversion below - confirm against version control.
if sess:
values = sess.run(values)
values = values.numpy()
# Samples may touch the bounds (inclusive comparison).
self.assertTrue(np.max(values) <= high)
self.assertTrue(np.min(values) >= low)
check(np.mean(values), expected.mean(), decimals=1)
# Test log-likelihood outputs (against scipy).
inputs = input_space.sample()
beta_distribution = cls(inputs, {}, low=low, high=high)
inputs = beta_distribution.inputs
alpha, beta_ = np.split(inputs.numpy(), 2, axis=-1)
values = plain_beta_value_space.sample()
# The distribution is queried in [low, high]; scipy is queried on the
# unscaled [0, 1] values.
values_scaled = values * (high - low) + low
out = beta_distribution.logp(torch.Tensor(values_scaled))
# NOTE(review): `out` is not visibly compared to anything; the next line is
# a dangling expression, likely the tail of a truncated comparison call.
np.sum(np.log(beta.pdf(values, alpha, beta_)), -1),
# TODO(sven): Test entropy outputs (against scipy).
2020-02-22 23:19:49 +01:00
2020-03-06 19:37:12 +01:00
def test_gumbel_softmax(self):
    """Checks the GumbelSoftmax ActionDistribution under "tf" and "eager".

    Two properties are verified per framework:
      * `deterministic_sample()` equals the softmax of the inputs.
      * For random samples, the mean argmax index stays within a relative
        tolerance of 0.08 of the mean argmax index of the inputs.
    """
    for _fw, sess in framework_iterator(
            frameworks=["tf", "eager"], session=True):
        n_rows = 1000
        n_classes = 5
        logit_space = Box(-1.0, 1.0, shape=(n_rows, n_classes))

        # Deterministic draw: expect exactly the softmax of the inputs.
        logits = logit_space.sample()
        dist = GumbelSoftmax(logits, {}, temperature=1.0)
        softmax_of_logits = softmax(logits)
        check(dist.deterministic_sample(), softmax_of_logits)

        # Random draw: the argmax of the sample should match the argmax
        # of the inputs most of the time, so compare mean argmax indices.
        logits = logit_space.sample()
        dist = GumbelSoftmax(logits, {}, temperature=1.0)
        mean_input_argmax = np.mean(np.argmax(logits, -1)).astype(
            np.float32)
        drawn = dist.sample()
        if sess:
            # Evaluate the sampled tensor through the provided session.
            drawn = sess.run(drawn)
        mean_sample_argmax = np.mean(np.argmax(drawn, -1))
        check(mean_sample_argmax, mean_input_argmax, rtol=0.08)
2020-02-19 21:18:45 +01:00
if __name__ == "__main__":
    # Allow running this module directly: hand the file to pytest in verbose
    # mode and use pytest's result as the process exit status.
    import sys

    import pytest

    exit_status = pytest.main(["-v", __file__])
    sys.exit(exit_status)