# ray/rllib/utils/exploration/slate_epsilon_greedy.py

from typing import Union
from ray.rllib.models.action_dist import ActionDistribution
from ray.rllib.utils.annotations import override
from ray.rllib.utils.exploration.epsilon_greedy import EpsilonGreedy
from ray.rllib.utils.exploration.exploration import TensorType
from ray.rllib.utils.framework import try_import_tf, try_import_torch

# Framework modules are imported lazily; each may be None if the respective
# framework is not installed.
tf1, tf, tfv = try_import_tf()
torch, _ = try_import_torch()


class SlateEpsilonGreedy(EpsilonGreedy):
    """Epsilon-greedy exploration for slate-based action spaces (e.g. SlateQ).

    With probability epsilon, an entire random candidate slate is picked;
    otherwise, the greedy slate (highest predicted Q-value) is returned.
    """

    @override(EpsilonGreedy)
    def _get_tf_exploration_action_op(
        self,
        action_distribution: ActionDistribution,
        explore: Union[bool, TensorType],
        timestep: Union[int, TensorType],
    ) -> "tf.Tensor":
        # Q-values per candidate slate and the set of all candidate slates.
        per_slate_q_values = action_distribution.inputs
        all_slates = action_distribution.all_slates

        # The greedy (highest Q-value) slate for each batch item.
        exploit_action = action_distribution.deterministic_sample()

        batch_size, num_slates = (
            tf.shape(per_slate_q_values)[0],
            tf.shape(per_slate_q_values)[1],
        )
        action_logp = tf.zeros(batch_size, dtype=tf.float32)

        # Get the current epsilon.
        epsilon = self.epsilon_schedule(
            timestep if timestep is not None else self.last_timestep
        )

        # A random slate per batch item, drawn uniformly from all candidate
        # slates.
        random_indices = tf.random.uniform(
            (batch_size,),
            minval=0,
            maxval=num_slates,
            dtype=tf.dtypes.int32,
        )
        random_actions = tf.gather(all_slates, random_indices)

        # Per-batch-item coin flip: True -> use the random slate.
        choose_random = (
            tf.random.uniform(
                tf.stack([batch_size]), minval=0, maxval=1, dtype=tf.float32
            )
            < epsilon
        )

        # Pick either random or greedy.
        action = tf.cond(
            pred=tf.constant(explore, dtype=tf.bool)
            if isinstance(explore, bool)
            else explore,
            true_fn=(lambda: tf.where(choose_random, random_actions, exploit_action)),
            false_fn=lambda: exploit_action,
        )

        # Eager (non-traced) TF2: a plain Python assignment suffices.
        if (
            self.framework in ["tf2", "tfe"]
            and not self.policy_config["eager_tracing"]
        ):
            self.last_timestep = timestep
            return action, action_logp
        else:
            # Graph (or traced) mode: update `last_timestep` via an assign op
            # that the returned action depends on.
            assign_op = tf1.assign(self.last_timestep, tf.cast(timestep, tf.int64))
            with tf1.control_dependencies([assign_op]):
                return action, action_logp

    @override(EpsilonGreedy)
    def _get_torch_exploration_action(
        self,
        action_distribution: ActionDistribution,
        explore: bool,
        timestep: Union[int, TensorType],
    ) -> "torch.Tensor":
        # Q-values per candidate slate and the set of all candidate slates.
        per_slate_q_values = action_distribution.inputs
        all_slates = self.model.slates

        # The greedy (highest Q-value) slate for each batch item.
        exploit_indices = action_distribution.deterministic_sample()
        exploit_action = all_slates[exploit_indices]

        batch_size = per_slate_q_values.size()[0]
        action_logp = torch.zeros(batch_size, dtype=torch.float)

        self.last_timestep = timestep

        # Explore.
        if explore:
            # Get the current epsilon.
            epsilon = self.epsilon_schedule(self.last_timestep)

            # A random slate per batch item, drawn uniformly from all
            # candidate slates.
            random_indices = torch.randint(
                0, per_slate_q_values.shape[1], (per_slate_q_values.shape[0],)
            )
            random_actions = all_slates[random_indices]

            # Per-batch-item coin flip, shaped (batch_size, 1) so it
            # broadcasts across the slate dimension in `torch.where`.
            choose_random = torch.empty((batch_size, 1)).uniform_() < epsilon

            # Pick either random or greedy.
            action = torch.where(choose_random, random_actions, exploit_action)

            return action, action_logp
        # Return the deterministic "sample" (argmax) over the logits.
        else:
            return exploit_action, action_logp
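

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the upstream module): in RLlib,
# SlateEpsilonGreedy is normally selected via an algorithm's
# ``exploration_config`` rather than instantiated directly. The schedule
# values below are assumptions for demonstration only; the parameter names
# follow the parent EpsilonGreedy class.
if __name__ == "__main__":
    example_exploration_config = {
        "type": "SlateEpsilonGreedy",
        # Assumed schedule: linearly decay epsilon from 1.0 to 0.01 over
        # 100k timesteps.
        "initial_epsilon": 1.0,
        "final_epsilon": 0.01,
        "epsilon_timesteps": 100000,
    }
    print(example_exploration_config)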