mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
[rllib] Clip DDPG ou-noise to avoid exceeding action bounds (#3386)
Closes #2965
This commit is contained in:
parent
55fca828ce
commit
18a8dbfcfb
4 changed files with 15 additions and 7 deletions
|
@ -89,8 +89,10 @@ class ActionNetwork(object):
|
|||
exploration_value = tf.assign_add(
|
||||
exploration_sample,
|
||||
theta * (.0 - exploration_sample) + sigma * normal_sample)
|
||||
stochastic_actions = deterministic_actions + eps * (
|
||||
high_action - low_action) * exploration_value
|
||||
stochastic_actions = tf.clip_by_value(
|
||||
deterministic_actions +
|
||||
eps * (high_action - low_action) * exploration_value,
|
||||
low_action, high_action)
|
||||
|
||||
self.actions = tf.cond(stochastic, lambda: stochastic_actions,
|
||||
lambda: deterministic_actions)
|
||||
|
|
|
@ -78,7 +78,7 @@ class PGPolicyGraph(TFPolicyGraph):
|
|||
sample_batch,
|
||||
other_agent_batches=None,
|
||||
episode=None):
|
||||
# This ads the "advantages" column to the sample batch
|
||||
# This adds the "advantages" column to the sample batch
|
||||
return compute_advantages(
|
||||
sample_batch, 0.0, self.config["gamma"], use_gae=False)
|
||||
|
||||
|
|
|
@ -102,9 +102,9 @@ class DiagGaussian(ActionDistribution):
|
|||
self.low = low
|
||||
self.high = high
|
||||
|
||||
# Squash to range if specified.
|
||||
# TODO(ekl) might make sense to use a beta distribution instead:
|
||||
# http://proceedings.mlr.press/v70/chou17a/chou17a.pdf
|
||||
# Squash to range if specified. We use a sigmoid here this to avoid the
|
||||
# mean drifting too far past the bounds and causing nan outputs.
|
||||
# https://github.com/ray-project/ray/issues/1862
|
||||
if low is not None:
|
||||
self.mean = low + tf.sigmoid(self.mean) * (high - low)
|
||||
|
||||
|
|
|
@ -112,7 +112,13 @@ class ModelSupportedSpaces(unittest.TestCase):
|
|||
def testAll(self):
|
||||
stats = {}
|
||||
check_support("IMPALA", {"num_gpus": 0}, stats)
|
||||
check_support("DDPG", {"timesteps_per_iteration": 1}, stats)
|
||||
check_support(
|
||||
"DDPG", {
|
||||
"noise_scale": 100.0,
|
||||
"timesteps_per_iteration": 1
|
||||
},
|
||||
stats,
|
||||
check_bounds=True)
|
||||
check_support("DQN", {"timesteps_per_iteration": 1}, stats)
|
||||
check_support("A3C", {
|
||||
"num_workers": 1,
|
||||
|
|
Loading…
Add table
Reference in a new issue