[RLlib] Fix crash when using StochasticSampling exploration (most PG-style algos) w/ tf and numpy > 1.19.5 (#18366)

Sven Mika authored on 2021-09-06 12:14:00 +02:00, committed by GitHub
parent 5a89b47f56
commit 59f796edf3
7 changed files with 49 additions and 18 deletions
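At a glance, the fix replaces several hand-rolled "zero logp" constructions (tf.zeros_like(...)[:, 0] and per-callsite tf.zeros(shape=(batch_size, ))) with one shared helper, zero_logps_from_actions, added to ray.rllib.utils.tf_ops (last file below), and drops the numpy==1.19.5 pin from the release-test app config. As a hypothetical way to exercise the fixed code path (most PG-style trainers use StochasticSampling exploration by default; the trainer and config names below follow the Ray 1.x API and are not part of this diff):

import ray
from ray.rllib.agents.pg import PGTrainer  # Ray 1.x trainer API (assumption)

ray.init()
# PG's default exploration is StochasticSampling, so training under the tf
# framework goes through the code path fixed in this commit.
trainer = PGTrainer(config={
    "env": "CartPole-v0",
    "framework": "tf",
    "num_workers": 0,
})
print(trainer.train()["episode_reward_mean"])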

View file

@@ -8,7 +8,4 @@ python:
conda_packages: []
post_build_cmds:
- - pip uninstall -y numpy ray || true
- - sudo rm -rf /home/ray/anaconda3/lib/python3.7/site-packages/numpy
- - pip3 install numpy==1.19.5 || true
- pip3 install -U {{ env["RAY_WHEELS"] | default("ray") }}

View file

@@ -26,6 +26,10 @@ parser.add_argument(
action="store_true",
help="Whether this script should be run as a test: --stop-reward must "
"be achieved within --stop-timesteps AND --stop-iters.")
+ parser.add_argument(
+ "--local-mode",
+ action="store_true",
+ help="Init Ray in local mode for easier debugging.")
parser.add_argument(
"--stop-iters",
type=int,
@@ -44,7 +48,7 @@ parser.add_argument(
if __name__ == "__main__":
args = parser.parse_args()
- ray.init(num_cpus=args.num_cpus or None)
+ ray.init(num_cpus=args.num_cpus or None, local_mode=args.local_mode)
register_env("NestedSpaceRepeatAfterMeEnv",
lambda c: NestedSpaceRepeatAfterMeEnv(c))
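The new --local-mode flag just forwards to ray.init(local_mode=True), which runs all Ray tasks and actors sequentially inside the driver process, so breakpoints, prints, and stack traces behave like plain single-process Python. A quick standalone illustration of the flag's effect (not part of the diff):

import ray

# local_mode=True executes tasks/actors serially in the driver process,
# which makes interactive debugging (pdb, IDE breakpoints) much easier.
ray.init(num_cpus=2, local_mode=True)

@ray.remote
def f():
    return 42

# Runs inline in the driver, so a breakpoint inside f() is hit directly.
assert ray.get(f.remote()) == 42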

View file

@@ -12,6 +12,7 @@ from ray.rllib.utils.framework import try_import_tf, try_import_torch, \
from ray.rllib.utils.numpy import convert_to_numpy
from ray.rllib.utils.schedules import Schedule
from ray.rllib.utils.schedules.piecewise_schedule import PiecewiseSchedule
+ from ray.rllib.utils.tf_ops import zero_logps_from_actions
tf1, tf, tfv = try_import_tf()
torch, _ = try_import_torch()
@@ -131,7 +132,7 @@ class GaussianNoise(Exploration):
true_fn=lambda: stochastic_actions,
false_fn=lambda: deterministic_actions)
# Logp=always zero.
- logp = tf.zeros_like(deterministic_actions, dtype=tf.float32)[:, 0]
+ logp = zero_logps_from_actions(deterministic_actions)
# Increment `last_timestep` by 1 (or set to `timestep`).
if self.framework in ["tf2", "tfe"]:

View file

@@ -8,6 +8,7 @@ from ray.rllib.utils.framework import try_import_tf, try_import_torch, \
get_variable, TensorType
from ray.rllib.utils.numpy import convert_to_numpy
from ray.rllib.utils.schedules import Schedule
+ from ray.rllib.utils.tf_ops import zero_logps_from_actions
tf1, tf, tfv = try_import_tf()
torch, _ = try_import_torch()
@@ -134,7 +135,7 @@ class OrnsteinUhlenbeckNoise(GaussianNoise):
true_fn=lambda: exploration_actions,
false_fn=lambda: deterministic_actions)
# Logp=always zero.
- logp = tf.zeros_like(deterministic_actions, dtype=tf.float32)[:, 0]
+ logp = zero_logps_from_actions(deterministic_actions)
# Increment `last_timestep` by 1 (or set to `timestep`).
if self.framework in ["tf2", "tfe"]:

View file

@@ -12,6 +12,7 @@ from ray.rllib.utils.framework import try_import_tf, try_import_torch, \
TensorType
from ray.rllib.utils.spaces.simplex import Simplex
from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space
+ from ray.rllib.utils.tf_ops import zero_logps_from_actions
tf1, tf, tfv = try_import_tf()
torch, _ = try_import_torch()
@@ -72,6 +73,10 @@ class Random(Exploration):
# Function to produce random samples from primitive space
# components: (Multi)Discrete or Box.
def random_component(component):
+ # Have at least an additional shape of (1,), even if the
+ # component is Box(-1.0, 1.0, shape=()).
+ shape = component.shape or (1, )
if isinstance(component, Discrete):
return tf.random.uniform(
shape=(batch_size, ) + component.shape,
@@ -91,19 +96,19 @@ class Random(Exploration):
component.bounded_below.all():
if component.dtype.name.startswith("int"):
return tf.random.uniform(
- shape=(batch_size, ) + component.shape,
+ shape=(batch_size, ) + shape,
minval=component.low.flat[0],
maxval=component.high.flat[0],
dtype=component.dtype)
else:
return tf.random.uniform(
- shape=(batch_size, ) + component.shape,
+ shape=(batch_size, ) + shape,
minval=component.low,
maxval=component.high,
dtype=component.dtype)
else:
return tf.random.normal(
- shape=(batch_size, ) + component.shape,
+ shape=(batch_size, ) + shape,
dtype=component.dtype)
else:
assert isinstance(component, Simplex), \
@@ -111,7 +116,7 @@ class Random(Exploration):
"sampling!".format(component)
return tf.nn.softmax(
tf.random.uniform(
- shape=(batch_size, ) + component.shape,
+ shape=(batch_size, ) + shape,
minval=0.0,
maxval=1.0,
dtype=component.dtype))
@@ -129,8 +134,7 @@ class Random(Exploration):
true_fn=true_fn,
false_fn=false_fn)
# TODO(sven): Move into (deterministic_)sample(logp=True|False)
- logp = tf.zeros_like(tree.flatten(action)[0], dtype=tf.float32)[:1]
+ logp = zero_logps_from_actions(action)
return action, logp
def get_torch_exploration_action(self, action_dist: ActionDistribution,
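The new shape = component.shape or (1, ) line above works because a scalar Box(-1.0, 1.0, shape=()) has an empty, falsy shape tuple, so the "or" falls back to (1, ) and every random sample carries at least one trailing dimension. A small standalone sketch of that corner case (assumes gym and eager TF2; not part of the diff):

import tensorflow as tf
from gym.spaces import Box

component = Box(-1.0, 1.0, shape=())  # scalar Box -> component.shape == ()
batch_size = 4

shape = component.shape or (1, )      # empty tuple is falsy -> use (1, )
sample = tf.random.uniform(
    shape=(batch_size, ) + shape,     # (4, 1) instead of the former (4,)
    minval=component.low,             # 0-d array, broadcasts fine
    maxval=component.high,
    dtype=component.dtype)
print(sample.shape)                   # (4, 1)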

View file

@@ -1,6 +1,6 @@
+ import functools
import gym
import numpy as np
import tree # pip install dm_tree
from typing import Union
from ray.rllib.models.action_dist import ActionDistribution
@@ -10,6 +10,7 @@ from ray.rllib.utils.exploration.exploration import Exploration
from ray.rllib.utils.exploration.random import Random
from ray.rllib.utils.framework import get_variable, try_import_tf, \
try_import_torch, TensorType
+ from ray.rllib.utils.tf_ops import zero_logps_from_actions
tf1, tf, tfv = try_import_tf()
torch, _ = try_import_torch()
@@ -90,15 +91,12 @@ class StochasticSampling(Exploration):
true_fn=lambda: stochastic_actions,
false_fn=lambda: deterministic_actions)
- def logp_false_fn():
- batch_size = tf.shape(tree.flatten(action)[0])[0]
- return tf.zeros(shape=(batch_size, ), dtype=tf.float32)
logp = tf.cond(
tf.math.logical_and(
explore, tf.convert_to_tensor(ts >= self.random_timesteps)),
true_fn=lambda: action_dist.sampled_action_logp(),
- false_fn=logp_false_fn)
+ false_fn=functools.partial(zero_logps_from_actions,
+ deterministic_actions))
# Increment `last_timestep` by 1 (or set to `timestep`).
if self.framework in ["tf2", "tfe"]:
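Both branches passed to tf.cond must be zero-argument callables; the removed inner logp_false_fn achieved that with a closure, while the new code binds its argument up front via functools.partial(zero_logps_from_actions, deterministic_actions). A toy, standalone sketch of that pattern (not RLlib code; zero_logps below is a stand-in for the real helper):

import functools
import tensorflow as tf

def zero_logps(actions):
    # Stand-in for zero_logps_from_actions: one 0.0 per batch row.
    return tf.zeros_like(actions, dtype=tf.float32)[:, 0]

actions = tf.constant([[0.1, 0.2], [0.3, 0.4]])
explore = tf.constant(False)

# tf.cond only accepts 0-arg callables; partial(...) binds actions ahead of time.
logp = tf.cond(
    explore,
    true_fn=lambda: tf.math.log(tf.reduce_sum(actions, axis=-1)),
    false_fn=functools.partial(zero_logps, actions))
print(logp.numpy())  # [0. 0.]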

View file

@@ -4,6 +4,7 @@ import numpy as np
import tree # pip install dm_tree
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.typing import TensorStructType, TensorType
tf1, tf, tfv = try_import_tf()
@@ -110,6 +111,31 @@ def huber_loss(x, delta=1.0):
tf.math.square(x) * 0.5, delta * (tf.abs(x) - 0.5 * delta))
+ def zero_logps_from_actions(actions: TensorStructType) -> TensorType:
+     """Helper function useful for returning dummy logp's (0) for some actions.
+
+     Args:
+         actions (TensorStructType): The input actions. This can be any struct
+             of complex action components or a simple tensor of different
+             dimensions, e.g. [B], [B, 2], or {"a": [B, 4, 5], "b": [B]}.
+
+     Returns:
+         TensorType: A 1D tensor of 0.0 (dummy logp's) matching the batch
+             dim of `actions` (shape=[B]).
+     """
+     # Need to flatten `actions` in case we have a complex action space.
+     # Take the 0th component to extract the batch dim.
+     action_component = tree.flatten(actions)[0]
+     logp_ = tf.zeros_like(action_component, dtype=tf.float32)
+     # Logp's should be single values (but with the same batch dim as
+     # `deterministic_actions` or `stochastic_actions`). In case
+     # actions are just [B], zeros_like works just fine here, but if
+     # actions are [B, ...], we have to reduce logp back to just [B].
+     if len(logp_.shape) > 1:
+         logp_ = logp_[:, 0]
+     return logp_
def one_hot(x, space):
if isinstance(space, Discrete):
return tf.one_hot(x, space.n)
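
For reference, a quick usage sketch of the new helper (assumes eager TF2 with RLlib installed; the import path matches the "from ray.rllib.utils.tf_ops import zero_logps_from_actions" lines added above; the sketch itself is not part of the diff):

import tensorflow as tf
from ray.rllib.utils.tf_ops import zero_logps_from_actions

# Simple [B]-shaped actions -> a [B]-shaped tensor of 0.0 logps.
print(zero_logps_from_actions(tf.zeros([4])).shape)     # (4,)

# [B, 2]-shaped actions -> still [B]; the trailing dim is sliced away.
print(zero_logps_from_actions(tf.zeros([4, 2])).shape)  # (4,)

# Nested action structs: the batch dim is taken from the 0th flattened
# component ("a" here, since dm_tree flattens dicts in sorted key order).
actions = {"a": tf.zeros([4, 2]), "b": tf.zeros([4])}
print(zero_logps_from_actions(actions).shape)           # (4,)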