Fixed bug in Dirichlet (#4440) (#4560)

This commit is contained in:
Federico Fontana 2019-04-04 22:33:09 +01:00 committed by Eric Liang
parent 5693cd1344
commit fb88f7efe6

View file

@ -261,17 +261,35 @@ TupleActions = namedtuple("TupleActions", ["batches"])
class Dirichlet(ActionDistribution):
"""Dirichlet distribution for countinuous actions that are between
"""Dirichlet distribution for continuous actions that are between
[0,1] and sum to 1.
e.g. actions that represent resource allocation."""
def __init__(self, inputs):
self.dist = tf.distributions.Dirichlet(concentration=inputs)
ActionDistribution.__init__(self, inputs)
"""Input is a tensor of logits. The exponential of logits is used to
parametrize the Dirichlet distribution as all parameters need to be
positive. An arbitrary small epsilon is added to the concentration
parameters to be zero due to numerical error.
See issue #4440 for more details.
"""
self.epsilon = 1e-7
concentration = tf.exp(inputs) + self.epsilon
self.dist = tf.distributions.Dirichlet(
concentration=concentration,
validate_args=True,
allow_nan_stats=False,
)
ActionDistribution.__init__(self, concentration)
@override(ActionDistribution)
def logp(self, x):
# Support of Dirichlet are positive real numbers. x is already be
# an array of positive number, but we clip to avoid zeros due to
# numerical errors.
x = tf.maximum(x, self.epsilon)
x = x / tf.reduce_sum(x, axis=-1, keepdims=True)
return self.dist.log_prob(x)
@override(ActionDistribution)