2020-03-26 13:41:16 -07:00
|
|
|
from typing import Union
|
|
|
|
|
2022-05-24 22:14:25 -07:00
|
|
|
from ray.rllib.utils.annotations import PublicAPI
|
2020-04-01 09:43:21 +02:00
|
|
|
from ray.rllib.models.action_dist import ActionDistribution
|
2020-03-26 13:41:16 -07:00
|
|
|
from ray.rllib.utils.annotations import override
|
|
|
|
from ray.rllib.utils.exploration.exploration import Exploration
|
2022-03-21 08:55:55 -07:00
|
|
|
from ray.rllib.utils.framework import (
|
|
|
|
TensorType,
|
|
|
|
try_import_tf,
|
|
|
|
)
|
|
|
|
|
|
|
|
tf1, tf, tfv = try_import_tf()
|
2020-03-26 13:41:16 -07:00
|
|
|
|
|
|
|
|
2022-05-24 22:14:25 -07:00
|
|
|
@PublicAPI
class ThompsonSampling(Exploration):
    """Greedy action selection for Thompson-sampling style models.

    When exploring, the chosen action is the argmax (over the last axis) of
    ``action_distribution.inputs`` -- presumably scores already sampled from
    the model's posterior; confirm against the bandit model producing them.
    When not exploring, the chosen action is the argmax of
    ``self.model.predict(self.model.current_obs())``.

    Only the "torch" and "tf2" frameworks are supported.
    """

    @override(Exploration)
    def get_exploration_action(
        self,
        action_distribution: ActionDistribution,
        timestep: Union[int, TensorType],
        explore: bool = True,
    ):
        """Select actions for the current batch.

        Args:
            action_distribution: Distribution whose ``inputs`` are arg-maxed
                when ``explore`` is truthy.
            timestep: Not used by this strategy; accepted to match the
                ``Exploration`` interface.
            explore: If truthy, pick from ``action_distribution.inputs``;
                otherwise pick from the model's predicted scores. On the
                "tf2" path this may also be a scalar boolean tensor.

        Returns:
            A tuple ``(actions, None)`` where ``actions`` are argmax indices
            along the last axis. The second element is always ``None``.

        Raises:
            NotImplementedError: If ``self.framework`` is neither "torch"
                nor "tf2".
        """
        if self.framework == "torch":
            return self._get_torch_exploration_action(action_distribution, explore)
        if self.framework == "tf2":
            return self._get_tf_exploration_action(action_distribution, explore)
        # Name the offending framework so misconfiguration is easy to diagnose.
        raise NotImplementedError(
            f"ThompsonSampling does not support framework={self.framework!r}; "
            "supported frameworks are 'torch' and 'tf2'."
        )

    def _get_torch_exploration_action(self, action_dist, explore):
        """Torch path: argmax over sampled inputs or model-predicted scores."""
        if explore:
            scores = action_dist.inputs
        else:
            scores = self.model.predict(self.model.current_obs())
        return scores.argmax(dim=-1), None

    def _get_tf_exploration_action(self, action_dist, explore):
        """TF2 path: argmax over sampled inputs or model-predicted scores.

        A plain Python ``bool`` for ``explore`` is branched on directly:
        ``tf.cond`` rejects Python bools while building a graph (e.g. inside
        ``tf.function``), and it would also trace the unused
        ``model.predict`` branch. Tensor-valued ``explore`` still goes
        through ``tf.cond`` so it can be decided at run time.
        """
        if isinstance(explore, bool):
            if explore:
                scores = action_dist.inputs
            else:
                scores = self.model.predict(self.model.current_obs())
        else:
            scores = tf.cond(
                pred=explore,
                true_fn=lambda: action_dist.inputs,
                false_fn=lambda: self.model.predict(self.model.current_obs()),
            )
        return tf.argmax(scores, axis=-1), None
|