import functools
import gym
import numpy as np
from typing import Optional, Union

from ray.rllib.models.action_dist import ActionDistribution
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.utils.annotations import override
from ray.rllib.utils.exploration.exploration import Exploration
from ray.rllib.utils.exploration.random import Random
from ray.rllib.utils.framework import (
    get_variable,
    try_import_tf,
    try_import_torch,
    TensorType,
)
from ray.rllib.utils.tf_utils import zero_logps_from_actions

tf1, tf, tfv = try_import_tf()
torch, _ = try_import_torch()


class StochasticSampling(Exploration):
    """Exploration strategy that draws actions from the action distribution.

    Passing `explore=False` to `get_exploration_action` switches to the
    distribution's deterministic sample instead. During the first
    `random_timesteps` timesteps, actions come from a `Random` exploration
    module rather than from the distribution.
    """

    def __init__(
        self,
        action_space: gym.spaces.Space,
        *,
        framework: str,
        model: ModelV2,
        random_timesteps: int = 0,
        **kwargs
    ):
        """Builds the exploration object and its helper state.

        Args:
            action_space: The gym action space used by the environment.
            framework: The deep-learning framework identifier; must not be
                None. "torch" selects the torch code path, anything else the
                tf code path.
            model: The ModelV2 handed on to the parent class.
            random_timesteps: Number of initial timesteps during which
                actions are taken from the `Random` exploration module
                instead of being sampled from the action distribution.
            **kwargs: Forwarded to both the parent constructor and the
                `Random` exploration module.
        """
        assert framework is not None
        super().__init__(action_space, model=model, framework=framework, **kwargs)

        # Fallback module that provides the actions for the initial
        # `random_timesteps` steps.
        self.random_timesteps = random_timesteps
        self.random_exploration = Random(
            action_space,
            model=self.model,
            framework=self.framework,
            **kwargs
        )

        # Timestep counter, created through `get_variable` so that it has
        # the representation appropriate for the chosen framework.
        initial_timestep = np.array(0, np.int64)
        self.last_timestep = get_variable(
            initial_timestep,
            framework=self.framework,
            tf_name="timestep",
            dtype=np.int64,
        )

    @override(Exploration)
    def get_exploration_action(
        self,
        *,
        action_distribution: ActionDistribution,
        timestep: Optional[Union[int, TensorType]] = None,
        explore: bool = True
    ):
        """Dispatches to the torch or tf implementation.

        Args:
            action_distribution: Distribution to draw the action from.
            timestep: Optional current timestep; when None, the internal
                counter is advanced by one instead.
            explore: Whether to explore (True) or act deterministically.

        Returns:
            Whatever the framework-specific helper returns: an
            `(action, logp)` pair.
        """
        if self.framework == "torch":
            handler = self._get_torch_exploration_action
        else:
            handler = self._get_tf_exploration_action_op
        return handler(action_distribution, timestep, explore)

    def _get_tf_exploration_action_op(self, action_dist, timestep, explore):
        """Builds the tf ops producing the action and its log-probability.

        Args:
            action_dist: Distribution to draw the action from.
            timestep: Optional timestep used to overwrite the internal
                counter (non-eager branch only, see note below).
            explore: Python bool or boolean tensor selecting exploration.

        Returns:
            Tuple `(action, logp)`.
        """
        # NOTE(review): the random-phase check is always based on the stored
        # counter + 1; a passed-in `timestep` is not consulted here.
        next_ts = self.last_timestep + 1

        def _random_phase_actions():
            # Only the action (index 0) of the Random module's result is used.
            return self.random_exploration.get_tf_exploration_action_op(
                action_dist, explore=True
            )[0]

        def _sampled_actions():
            return action_dist.sample()

        in_random_phase = tf.convert_to_tensor(next_ts < self.random_timesteps)
        stochastic_actions = tf.cond(
            pred=in_random_phase,
            true_fn=_random_phase_actions,
            false_fn=_sampled_actions,
        )
        deterministic_actions = action_dist.deterministic_sample()

        # `tf.cond` gets a tensor predicate: wrap plain Python bools.
        if isinstance(explore, bool):
            explore_pred = tf.constant(explore)
        else:
            explore_pred = explore

        action = tf.cond(
            explore_pred,
            true_fn=lambda: stochastic_actions,
            false_fn=lambda: deterministic_actions,
        )

        # The distribution's log-prob is used only when exploring AND past
        # the random phase; otherwise `zero_logps_from_actions` is applied to
        # the deterministic actions.
        past_random_phase = tf.convert_to_tensor(next_ts >= self.random_timesteps)
        use_dist_logp = tf.math.logical_and(explore, past_random_phase)
        fallback_logp_fn = functools.partial(
            zero_logps_from_actions, deterministic_actions
        )
        logp = tf.cond(
            use_dist_logp,
            true_fn=lambda: action_dist.sampled_action_logp(),
            false_fn=fallback_logp_fn,
        )

        # Advance the stored timestep counter.
        if self.framework in ("tf2", "tfe"):
            # NOTE(review): this branch always increments by one and does not
            # use `timestep`, unlike the branch below.
            self.last_timestep.assign_add(1)
            return action, logp

        if timestep is None:
            assign_op = tf1.assign_add(self.last_timestep, 1)
        else:
            assign_op = tf1.assign(self.last_timestep, timestep)
        # Return from inside the control-dependency scope on the counter
        # update, exactly where the results are handed back.
        with tf1.control_dependencies([assign_op]):
            return action, logp

    def _get_torch_exploration_action(
        self,
        action_dist: ActionDistribution,
        timestep: Union[TensorType, int],
        explore: Union[TensorType, bool],
    ):
        """Computes the action and its log-probability for torch.

        Args:
            action_dist: Distribution to draw the action from.
            timestep: Timestep to store as the current one; when None, the
                stored counter is increased by one.
            explore: Whether to explore (True) or act deterministically.

        Returns:
            Tuple `(action, logp)`.
        """
        # Update the stored timestep first; the random-phase check below
        # reads the updated value.
        if timestep is None:
            self.last_timestep = self.last_timestep + 1
        else:
            self.last_timestep = timestep

        # Deterministic path: log-probs are zeros shaped like the
        # distribution's sampled-action log-probs.
        if not explore:
            action = action_dist.deterministic_sample()
            logp = torch.zeros_like(action_dist.sampled_action_logp())
            return action, logp

        # Still inside the initial random phase.
        if self.last_timestep < self.random_timesteps:
            action, logp = self.random_exploration.get_torch_exploration_action(
                action_dist, explore=True
            )
            return action, logp

        # Regular exploration: sample from the distribution.
        action = action_dist.sample()
        logp = action_dist.sampled_action_logp()
        return action, logp