mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00

* Rollback. * WIP. * WIP. * LINT. * WIP. * Fix. * Fix. * Fix. * LINT. * Fix (SAC does currently not support eager). * Fix. * WIP. * LINT. * Update rllib/evaluation/sampler.py Co-Authored-By: Eric Liang <ekhliang@gmail.com> * Update rllib/evaluation/sampler.py Co-Authored-By: Eric Liang <ekhliang@gmail.com> * Update rllib/utils/exploration/exploration.py Co-Authored-By: Eric Liang <ekhliang@gmail.com> * Update rllib/utils/exploration/exploration.py Co-Authored-By: Eric Liang <ekhliang@gmail.com> * WIP. * WIP. * Fix. * LINT. * LINT. * Fix and LINT. * WIP. * WIP. * WIP. * WIP. * Fix. * LINT. * Fix. * Fix and LINT. * Update rllib/utils/exploration/exploration.py * Update rllib/policy/dynamic_tf_policy.py Co-Authored-By: Eric Liang <ekhliang@gmail.com> * Update rllib/policy/dynamic_tf_policy.py Co-Authored-By: Eric Liang <ekhliang@gmail.com> * Update rllib/policy/dynamic_tf_policy.py Co-Authored-By: Eric Liang <ekhliang@gmail.com> * Fixes. * LINT. * WIP. Co-authored-by: Eric Liang <ekhliang@gmail.com>
46 lines
1.8 KiB
Python
46 lines
1.8 KiB
Python
from typing import Union
|
|
|
|
from ray.rllib.models.action_dist import ActionDistribution
|
|
from ray.rllib.utils.annotations import override
|
|
from ray.rllib.utils.exploration.exploration import Exploration
|
|
from ray.rllib.utils.framework import TensorType
|
|
|
|
|
|
class ThompsonSampling(Exploration):
    """Exploration component that selects the arg-max action per batch row.

    In explore mode the arg-max is taken over the action distribution's
    ``inputs``; otherwise it is taken over ``self.model.predict(...)``
    evaluated on the model's current observation.
    NOTE(review): presumably ``inputs`` holds posterior-sampled scores
    produced by the bandit model -- confirm against the model code.
    """

    @override(Exploration)
    def get_exploration_action(
            self,
            action_distribution: ActionDistribution,
            timestep: Union[int, TensorType],
            explore: bool = True):
        """Return the action(s) to take for the given distribution.

        Args:
            action_distribution: Distribution whose ``inputs`` are read
                when ``explore`` is truthy.
            timestep: Not used by this implementation.
            explore: Selects between the explore and non-explore branch.

        Returns:
            A ``(actions, None)`` tuple.

        Raises:
            NotImplementedError: If ``self.framework`` is not "torch".
        """
        # Only the torch code path exists; bail out for anything else.
        if self.framework == "torch":
            return self._get_torch_exploration_action(
                action_distribution, explore)
        raise NotImplementedError

    def _get_torch_exploration_action(self, action_dist, explore):
        """Torch implementation: arg-max along dim 1 of the chosen scores."""
        if explore:
            chosen = action_dist.inputs.argmax(dim=1)
        else:
            # Re-score the model's current observation via `predict`.
            predicted = self.model.predict(self.model.current_obs())
            chosen = predicted.argmax(dim=1)
        # Second tuple slot is always None here.
        return chosen, None
|
|
|
|
|
|
class UCB(Exploration):
    """Exploration component that selects the arg-max action per batch row.

    In explore mode the arg-max is taken over the action distribution's
    ``inputs``; otherwise it is taken over ``self.model.value_function()``.
    NOTE(review): presumably ``inputs`` already contains upper-confidence
    scores computed by the bandit model -- confirm against the model code.
    """

    @override(Exploration)
    def get_exploration_action(
            self,
            action_distribution: ActionDistribution,
            timestep: Union[int, TensorType],
            explore: bool = True):
        """Return the action(s) to take for the given distribution.

        Args:
            action_distribution: Distribution whose ``inputs`` are read
                when ``explore`` is truthy.
            timestep: Not used by this implementation.
            explore: Selects between the explore and non-explore branch.

        Returns:
            A ``(actions, None)`` tuple.

        Raises:
            NotImplementedError: If ``self.framework`` is not "torch".
        """
        # Only the torch code path exists; bail out for anything else.
        if self.framework == "torch":
            return self._get_torch_exploration_action(
                action_distribution, explore)
        raise NotImplementedError

    def _get_torch_exploration_action(self, action_dist, explore):
        """Torch implementation: arg-max along dim 1 of the chosen scores."""
        if explore:
            chosen = action_dist.inputs.argmax(dim=1)
        else:
            # Fall back to the model's value-function output.
            values = self.model.value_function()
            chosen = values.argmax(dim=1)
        # Second tuple slot is always None here.
        return chosen, None
|