mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
[RLlib] Fix crash when using StochasticSampling exploration (most PG-style algos) w/ tf and numpy > 1.19.5 (#18366)
This commit is contained in:
parent
5a89b47f56
commit
59f796edf3
7 changed files with 49 additions and 18 deletions
|
@ -8,7 +8,4 @@ python:
|
|||
conda_packages: []
|
||||
|
||||
post_build_cmds:
|
||||
- pip uninstall -y numpy ray || true
|
||||
- sudo rm -rf /home/ray/anaconda3/lib/python3.7/site-packages/numpy
|
||||
- pip3 install numpy==1.19.5 || true
|
||||
- pip3 install -U {{ env["RAY_WHEELS"] | default("ray") }}
|
||||
|
|
|
@ -26,6 +26,10 @@ parser.add_argument(
|
|||
action="store_true",
|
||||
help="Whether this script should be run as a test: --stop-reward must "
|
||||
"be achieved within --stop-timesteps AND --stop-iters.")
|
||||
parser.add_argument(
|
||||
"--local-mode",
|
||||
action="store_true",
|
||||
help="Init Ray in local mode for easier debugging.")
|
||||
parser.add_argument(
|
||||
"--stop-iters",
|
||||
type=int,
|
||||
|
@ -44,7 +48,7 @@ parser.add_argument(
|
|||
|
||||
if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
ray.init(num_cpus=args.num_cpus or None)
|
||||
ray.init(num_cpus=args.num_cpus or None, local_mode=args.local_mode)
|
||||
register_env("NestedSpaceRepeatAfterMeEnv",
|
||||
lambda c: NestedSpaceRepeatAfterMeEnv(c))
|
||||
|
||||
|
|
|
@ -12,6 +12,7 @@ from ray.rllib.utils.framework import try_import_tf, try_import_torch, \
|
|||
from ray.rllib.utils.numpy import convert_to_numpy
|
||||
from ray.rllib.utils.schedules import Schedule
|
||||
from ray.rllib.utils.schedules.piecewise_schedule import PiecewiseSchedule
|
||||
from ray.rllib.utils.tf_ops import zero_logps_from_actions
|
||||
|
||||
tf1, tf, tfv = try_import_tf()
|
||||
torch, _ = try_import_torch()
|
||||
|
@ -131,7 +132,7 @@ class GaussianNoise(Exploration):
|
|||
true_fn=lambda: stochastic_actions,
|
||||
false_fn=lambda: deterministic_actions)
|
||||
# Logp=always zero.
|
||||
logp = tf.zeros_like(deterministic_actions, dtype=tf.float32)[:, 0]
|
||||
logp = zero_logps_from_actions(deterministic_actions)
|
||||
|
||||
# Increment `last_timestep` by 1 (or set to `timestep`).
|
||||
if self.framework in ["tf2", "tfe"]:
|
||||
|
|
|
@ -8,6 +8,7 @@ from ray.rllib.utils.framework import try_import_tf, try_import_torch, \
|
|||
get_variable, TensorType
|
||||
from ray.rllib.utils.numpy import convert_to_numpy
|
||||
from ray.rllib.utils.schedules import Schedule
|
||||
from ray.rllib.utils.tf_ops import zero_logps_from_actions
|
||||
|
||||
tf1, tf, tfv = try_import_tf()
|
||||
torch, _ = try_import_torch()
|
||||
|
@ -134,7 +135,7 @@ class OrnsteinUhlenbeckNoise(GaussianNoise):
|
|||
true_fn=lambda: exploration_actions,
|
||||
false_fn=lambda: deterministic_actions)
|
||||
# Logp=always zero.
|
||||
logp = tf.zeros_like(deterministic_actions, dtype=tf.float32)[:, 0]
|
||||
logp = zero_logps_from_actions(deterministic_actions)
|
||||
|
||||
# Increment `last_timestep` by 1 (or set to `timestep`).
|
||||
if self.framework in ["tf2", "tfe"]:
|
||||
|
|
|
@ -12,6 +12,7 @@ from ray.rllib.utils.framework import try_import_tf, try_import_torch, \
|
|||
TensorType
|
||||
from ray.rllib.utils.spaces.simplex import Simplex
|
||||
from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space
|
||||
from ray.rllib.utils.tf_ops import zero_logps_from_actions
|
||||
|
||||
tf1, tf, tfv = try_import_tf()
|
||||
torch, _ = try_import_torch()
|
||||
|
@ -72,6 +73,10 @@ class Random(Exploration):
|
|||
# Function to produce random samples from primitive space
|
||||
# components: (Multi)Discrete or Box.
|
||||
def random_component(component):
|
||||
# Have at least an additional shape of (1,), even if the
|
||||
# component is Box(-1.0, 1.0, shape=()).
|
||||
shape = component.shape or (1, )
|
||||
|
||||
if isinstance(component, Discrete):
|
||||
return tf.random.uniform(
|
||||
shape=(batch_size, ) + component.shape,
|
||||
|
@ -91,19 +96,19 @@ class Random(Exploration):
|
|||
component.bounded_below.all():
|
||||
if component.dtype.name.startswith("int"):
|
||||
return tf.random.uniform(
|
||||
shape=(batch_size, ) + component.shape,
|
||||
shape=(batch_size, ) + shape,
|
||||
minval=component.low.flat[0],
|
||||
maxval=component.high.flat[0],
|
||||
dtype=component.dtype)
|
||||
else:
|
||||
return tf.random.uniform(
|
||||
shape=(batch_size, ) + component.shape,
|
||||
shape=(batch_size, ) + shape,
|
||||
minval=component.low,
|
||||
maxval=component.high,
|
||||
dtype=component.dtype)
|
||||
else:
|
||||
return tf.random.normal(
|
||||
shape=(batch_size, ) + component.shape,
|
||||
shape=(batch_size, ) + shape,
|
||||
dtype=component.dtype)
|
||||
else:
|
||||
assert isinstance(component, Simplex), \
|
||||
|
@ -111,7 +116,7 @@ class Random(Exploration):
|
|||
"sampling!".format(component)
|
||||
return tf.nn.softmax(
|
||||
tf.random.uniform(
|
||||
shape=(batch_size, ) + component.shape,
|
||||
shape=(batch_size, ) + shape,
|
||||
minval=0.0,
|
||||
maxval=1.0,
|
||||
dtype=component.dtype))
|
||||
|
@ -129,8 +134,7 @@ class Random(Exploration):
|
|||
true_fn=true_fn,
|
||||
false_fn=false_fn)
|
||||
|
||||
# TODO(sven): Move into (deterministic_)sample(logp=True|False)
|
||||
logp = tf.zeros_like(tree.flatten(action)[0], dtype=tf.float32)[:1]
|
||||
logp = zero_logps_from_actions(action)
|
||||
return action, logp
|
||||
|
||||
def get_torch_exploration_action(self, action_dist: ActionDistribution,
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import functools
|
||||
import gym
|
||||
import numpy as np
|
||||
import tree # pip install dm_tree
|
||||
from typing import Union
|
||||
|
||||
from ray.rllib.models.action_dist import ActionDistribution
|
||||
|
@ -10,6 +10,7 @@ from ray.rllib.utils.exploration.exploration import Exploration
|
|||
from ray.rllib.utils.exploration.random import Random
|
||||
from ray.rllib.utils.framework import get_variable, try_import_tf, \
|
||||
try_import_torch, TensorType
|
||||
from ray.rllib.utils.tf_ops import zero_logps_from_actions
|
||||
|
||||
tf1, tf, tfv = try_import_tf()
|
||||
torch, _ = try_import_torch()
|
||||
|
@ -90,15 +91,12 @@ class StochasticSampling(Exploration):
|
|||
true_fn=lambda: stochastic_actions,
|
||||
false_fn=lambda: deterministic_actions)
|
||||
|
||||
def logp_false_fn():
|
||||
batch_size = tf.shape(tree.flatten(action)[0])[0]
|
||||
return tf.zeros(shape=(batch_size, ), dtype=tf.float32)
|
||||
|
||||
logp = tf.cond(
|
||||
tf.math.logical_and(
|
||||
explore, tf.convert_to_tensor(ts >= self.random_timesteps)),
|
||||
true_fn=lambda: action_dist.sampled_action_logp(),
|
||||
false_fn=logp_false_fn)
|
||||
false_fn=functools.partial(zero_logps_from_actions,
|
||||
deterministic_actions))
|
||||
|
||||
# Increment `last_timestep` by 1 (or set to `timestep`).
|
||||
if self.framework in ["tf2", "tfe"]:
|
||||
|
|
|
@ -4,6 +4,7 @@ import numpy as np
|
|||
import tree # pip install dm_tree
|
||||
|
||||
from ray.rllib.utils.framework import try_import_tf
|
||||
from ray.rllib.utils.typing import TensorStructType, TensorType
|
||||
|
||||
tf1, tf, tfv = try_import_tf()
|
||||
|
||||
|
@ -110,6 +111,31 @@ def huber_loss(x, delta=1.0):
|
|||
tf.math.square(x) * 0.5, delta * (tf.abs(x) - 0.5 * delta))
|
||||
|
||||
|
||||
def zero_logps_from_actions(actions: TensorStructType) -> TensorType:
|
||||
"""Helper function useful for returning dummy logp's (0) for some actions.
|
||||
|
||||
Args:
|
||||
actions (TensorStructType): The input actions. This can be any struct
|
||||
of complex action components or a simple tensor of different
|
||||
dimensions, e.g. [B], [B, 2], or {"a": [B, 4, 5], "b": [B]}.
|
||||
|
||||
Returns:
|
||||
TensorType: A 1D tensor of 0.0 (dummy logp's) matching the batch
|
||||
dim of `actions` (shape=[B]).
|
||||
"""
|
||||
# Need to flatten `actions` in case we have a complex action space.
|
||||
# Take the 0th component to extract the batch dim.
|
||||
action_component = tree.flatten(actions)[0]
|
||||
logp_ = tf.zeros_like(action_component, dtype=tf.float32)
|
||||
# Logp's should be single values (but with the same batch dim as
|
||||
# `deterministic_actions` or `stochastic_actions`). In case
|
||||
# actions are just [B], zeros_like works just fine here, but if
|
||||
# actions are [B, ...], we have to reduce logp back to just [B].
|
||||
if len(logp_.shape) > 1:
|
||||
logp_ = logp_[:, 0]
|
||||
return logp_
|
||||
|
||||
|
||||
def one_hot(x, space):
|
||||
if isinstance(space, Discrete):
|
||||
return tf.one_hot(x, space.n)
|
||||
|
|
Loading…
Add table
Reference in a new issue