diff --git a/release/rllib_tests/app_config.yaml b/release/rllib_tests/app_config.yaml index 9270dfe01..938f2d63a 100755 --- a/release/rllib_tests/app_config.yaml +++ b/release/rllib_tests/app_config.yaml @@ -8,7 +8,4 @@ python: conda_packages: [] post_build_cmds: - - pip uninstall -y numpy ray || true - - sudo rm -rf /home/ray/anaconda3/lib/python3.7/site-packages/numpy - - pip3 install numpy==1.19.5 || true - pip3 install -U {{ env["RAY_WHEELS"] | default("ray") }} diff --git a/rllib/examples/nested_action_spaces.py b/rllib/examples/nested_action_spaces.py index e80ac3a9c..1b00d3009 100644 --- a/rllib/examples/nested_action_spaces.py +++ b/rllib/examples/nested_action_spaces.py @@ -26,6 +26,10 @@ parser.add_argument( action="store_true", help="Whether this script should be run as a test: --stop-reward must " "be achieved within --stop-timesteps AND --stop-iters.") +parser.add_argument( + "--local-mode", + action="store_true", + help="Init Ray in local mode for easier debugging.") parser.add_argument( "--stop-iters", type=int, @@ -44,7 +48,7 @@ parser.add_argument( if __name__ == "__main__": args = parser.parse_args() - ray.init(num_cpus=args.num_cpus or None) + ray.init(num_cpus=args.num_cpus or None, local_mode=args.local_mode) register_env("NestedSpaceRepeatAfterMeEnv", lambda c: NestedSpaceRepeatAfterMeEnv(c)) diff --git a/rllib/utils/exploration/gaussian_noise.py b/rllib/utils/exploration/gaussian_noise.py index 3178d6532..3c1972d1e 100644 --- a/rllib/utils/exploration/gaussian_noise.py +++ b/rllib/utils/exploration/gaussian_noise.py @@ -12,6 +12,7 @@ from ray.rllib.utils.framework import try_import_tf, try_import_torch, \ from ray.rllib.utils.numpy import convert_to_numpy from ray.rllib.utils.schedules import Schedule from ray.rllib.utils.schedules.piecewise_schedule import PiecewiseSchedule +from ray.rllib.utils.tf_ops import zero_logps_from_actions tf1, tf, tfv = try_import_tf() torch, _ = try_import_torch() @@ -131,7 +132,7 @@ class GaussianNoise(Exploration): true_fn=lambda: stochastic_actions, false_fn=lambda: deterministic_actions) # Logp=always zero. - logp = tf.zeros_like(deterministic_actions, dtype=tf.float32)[:, 0] + logp = zero_logps_from_actions(deterministic_actions) # Increment `last_timestep` by 1 (or set to `timestep`). if self.framework in ["tf2", "tfe"]: diff --git a/rllib/utils/exploration/ornstein_uhlenbeck_noise.py b/rllib/utils/exploration/ornstein_uhlenbeck_noise.py index 4daae937b..ba7582903 100644 --- a/rllib/utils/exploration/ornstein_uhlenbeck_noise.py +++ b/rllib/utils/exploration/ornstein_uhlenbeck_noise.py @@ -8,6 +8,7 @@ from ray.rllib.utils.framework import try_import_tf, try_import_torch, \ get_variable, TensorType from ray.rllib.utils.numpy import convert_to_numpy from ray.rllib.utils.schedules import Schedule +from ray.rllib.utils.tf_ops import zero_logps_from_actions tf1, tf, tfv = try_import_tf() torch, _ = try_import_torch() @@ -134,7 +135,7 @@ class OrnsteinUhlenbeckNoise(GaussianNoise): true_fn=lambda: exploration_actions, false_fn=lambda: deterministic_actions) # Logp=always zero. - logp = tf.zeros_like(deterministic_actions, dtype=tf.float32)[:, 0] + logp = zero_logps_from_actions(deterministic_actions) # Increment `last_timestep` by 1 (or set to `timestep`). if self.framework in ["tf2", "tfe"]: diff --git a/rllib/utils/exploration/random.py b/rllib/utils/exploration/random.py index 3653abadf..d1d6c4d0a 100644 --- a/rllib/utils/exploration/random.py +++ b/rllib/utils/exploration/random.py @@ -12,6 +12,7 @@ from ray.rllib.utils.framework import try_import_tf, try_import_torch, \ TensorType from ray.rllib.utils.spaces.simplex import Simplex from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space +from ray.rllib.utils.tf_ops import zero_logps_from_actions tf1, tf, tfv = try_import_tf() torch, _ = try_import_torch() @@ -72,6 +73,10 @@ class Random(Exploration): # Function to produce random samples from primitive space # components: (Multi)Discrete or Box. def random_component(component): + # Have at least an additional shape of (1,), even if the + # component is Box(-1.0, 1.0, shape=()). + shape = component.shape or (1, ) + if isinstance(component, Discrete): return tf.random.uniform( shape=(batch_size, ) + component.shape, @@ -91,19 +96,19 @@ class Random(Exploration): component.bounded_below.all(): if component.dtype.name.startswith("int"): return tf.random.uniform( - shape=(batch_size, ) + component.shape, + shape=(batch_size, ) + shape, minval=component.low.flat[0], maxval=component.high.flat[0], dtype=component.dtype) else: return tf.random.uniform( - shape=(batch_size, ) + component.shape, + shape=(batch_size, ) + shape, minval=component.low, maxval=component.high, dtype=component.dtype) else: return tf.random.normal( - shape=(batch_size, ) + component.shape, + shape=(batch_size, ) + shape, dtype=component.dtype) else: assert isinstance(component, Simplex), \ @@ -111,7 +116,7 @@ class Random(Exploration): "sampling!".format(component) return tf.nn.softmax( tf.random.uniform( - shape=(batch_size, ) + component.shape, + shape=(batch_size, ) + shape, minval=0.0, maxval=1.0, dtype=component.dtype)) @@ -129,8 +134,7 @@ class Random(Exploration): true_fn=true_fn, false_fn=false_fn) - # TODO(sven): Move into (deterministic_)sample(logp=True|False) - logp = tf.zeros_like(tree.flatten(action)[0], dtype=tf.float32)[:1] + logp = zero_logps_from_actions(action) return action, logp def get_torch_exploration_action(self, action_dist: ActionDistribution, diff --git a/rllib/utils/exploration/stochastic_sampling.py b/rllib/utils/exploration/stochastic_sampling.py index 67a192cf2..daa6089d4 100644 --- a/rllib/utils/exploration/stochastic_sampling.py +++ b/rllib/utils/exploration/stochastic_sampling.py @@ -1,6 +1,6 @@ +import functools import gym import numpy as np -import tree # pip install dm_tree from typing import Union from ray.rllib.models.action_dist import ActionDistribution @@ -10,6 +10,7 @@ from ray.rllib.utils.exploration.exploration import Exploration from ray.rllib.utils.exploration.random import Random from ray.rllib.utils.framework import get_variable, try_import_tf, \ try_import_torch, TensorType +from ray.rllib.utils.tf_ops import zero_logps_from_actions tf1, tf, tfv = try_import_tf() torch, _ = try_import_torch() @@ -90,15 +91,12 @@ class StochasticSampling(Exploration): true_fn=lambda: stochastic_actions, false_fn=lambda: deterministic_actions) - def logp_false_fn(): - batch_size = tf.shape(tree.flatten(action)[0])[0] - return tf.zeros(shape=(batch_size, ), dtype=tf.float32) - logp = tf.cond( tf.math.logical_and( explore, tf.convert_to_tensor(ts >= self.random_timesteps)), true_fn=lambda: action_dist.sampled_action_logp(), - false_fn=logp_false_fn) + false_fn=functools.partial(zero_logps_from_actions, + deterministic_actions)) # Increment `last_timestep` by 1 (or set to `timestep`). if self.framework in ["tf2", "tfe"]: diff --git a/rllib/utils/tf_ops.py b/rllib/utils/tf_ops.py index 36fed0808..29ec25e6b 100644 --- a/rllib/utils/tf_ops.py +++ b/rllib/utils/tf_ops.py @@ -4,6 +4,7 @@ import numpy as np import tree # pip install dm_tree from ray.rllib.utils.framework import try_import_tf +from ray.rllib.utils.typing import TensorStructType, TensorType tf1, tf, tfv = try_import_tf() @@ -110,6 +111,31 @@ def huber_loss(x, delta=1.0): tf.math.square(x) * 0.5, delta * (tf.abs(x) - 0.5 * delta)) +def zero_logps_from_actions(actions: TensorStructType) -> TensorType: + """Helper function useful for returning dummy logp's (0) for some actions. + + Args: + actions (TensorStructType): The input actions. This can be any struct + of complex action components or a simple tensor of different + dimensions, e.g. [B], [B, 2], or {"a": [B, 4, 5], "b": [B]}. + + Returns: + TensorType: A 1D tensor of 0.0 (dummy logp's) matching the batch + dim of `actions` (shape=[B]). + """ + # Need to flatten `actions` in case we have a complex action space. + # Take the 0th component to extract the batch dim. + action_component = tree.flatten(actions)[0] + logp_ = tf.zeros_like(action_component, dtype=tf.float32) + # Logp's should be single values (but with the same batch dim as + # `deterministic_actions` or `stochastic_actions`). In case + # actions are just [B], zeros_like works just fine here, but if + # actions are [B, ...], we have to reduce logp back to just [B]. + if len(logp_.shape) > 1: + logp_ = logp_[:, 0] + return logp_ + + def one_hot(x, space): if isinstance(space, Discrete): return tf.one_hot(x, space.n)