mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
parent
5693cd1344
commit
fb88f7efe6
1 changed files with 21 additions and 3 deletions
|
@ -261,17 +261,35 @@ TupleActions = namedtuple("TupleActions", ["batches"])
|
|||
|
||||
|
||||
class Dirichlet(ActionDistribution):
|
||||
"""Dirichlet distribution for countinuous actions that are between
|
||||
"""Dirichlet distribution for continuous actions that are between
|
||||
[0,1] and sum to 1.
|
||||
|
||||
e.g. actions that represent resource allocation."""
|
||||
|
||||
def __init__(self, inputs):
|
||||
self.dist = tf.distributions.Dirichlet(concentration=inputs)
|
||||
ActionDistribution.__init__(self, inputs)
|
||||
"""Input is a tensor of logits. The exponential of logits is used to
|
||||
parametrize the Dirichlet distribution as all parameters need to be
|
||||
positive. An arbitrary small epsilon is added to the concentration
|
||||
parameters to be zero due to numerical error.
|
||||
|
||||
See issue #4440 for more details.
|
||||
"""
|
||||
self.epsilon = 1e-7
|
||||
concentration = tf.exp(inputs) + self.epsilon
|
||||
self.dist = tf.distributions.Dirichlet(
|
||||
concentration=concentration,
|
||||
validate_args=True,
|
||||
allow_nan_stats=False,
|
||||
)
|
||||
ActionDistribution.__init__(self, concentration)
|
||||
|
||||
@override(ActionDistribution)
|
||||
def logp(self, x):
|
||||
# Support of Dirichlet are positive real numbers. x is already be
|
||||
# an array of positive number, but we clip to avoid zeros due to
|
||||
# numerical errors.
|
||||
x = tf.maximum(x, self.epsilon)
|
||||
x = x / tf.reduce_sum(x, axis=-1, keepdims=True)
|
||||
return self.dist.log_prob(x)
|
||||
|
||||
@override(ActionDistribution)
|
||||
|
|
Loading…
Add table
Reference in a new issue