mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
77 lines
2.4 KiB
Python
import logging
|
|
from typing import Type
|
|
|
|
from ray.rllib.agents.bandit.bandit_tf_policy import BanditTFPolicy
|
|
from ray.rllib.agents.bandit.bandit_torch_policy import BanditTorchPolicy
|
|
from ray.rllib.agents.trainer import Trainer, with_common_config
|
|
from ray.rllib.policy.policy import Policy
|
|
from ray.rllib.utils.annotations import override
|
|
from ray.rllib.utils.typing import TrainerConfigDict
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
# fmt: off
# __sphinx_doc_begin__
# Default configuration shared by the Bandit Trainers below. Built on top of
# the common Trainer config via `with_common_config`.
DEFAULT_CONFIG = with_common_config({
    # No remote workers by default.
    "num_workers": 0,
    # Default to the torch framework (a tf2 policy also exists; see
    # `get_default_policy_class` below).
    "framework": "torch",

    # Do online learning one step at a time.
    "rollout_fragment_length": 1,
    "train_batch_size": 1,

    # Make sure, a `train()` call performs at least 100 env sampling timesteps, before
    # reporting results. Not setting this (default is 0) would significantly slow down
    # the Bandit Trainer.
    "min_sample_timesteps_per_reporting": 100,
})
# __sphinx_doc_end__
# fmt: on
|
|
|
|
|
|
class BanditLinTSTrainer(Trainer):
    """Bandit Trainer using ThompsonSampling exploration."""

    @classmethod
    @override(Trainer)
    def get_default_config(cls) -> TrainerConfigDict:
        """Return the default config, with ThompsonSampling exploration set.

        Returns:
            DEFAULT_CONFIG merged with a ThompsonSampling `exploration_config`.
        """
        return Trainer.merge_trainer_configs(
            DEFAULT_CONFIG,
            {
                # Use ThompsonSampling exploration.
                "exploration_config": {"type": "ThompsonSampling"}
            },
        )

    @override(Trainer)
    def get_default_policy_class(self, config: TrainerConfigDict) -> Type[Policy]:
        """Return the policy class matching the configured framework.

        Args:
            config: The Trainer's config dict; only `config["framework"]`
                is inspected here.

        Returns:
            BanditTorchPolicy for "torch", BanditTFPolicy for "tf2".

        Raises:
            NotImplementedError: If `config["framework"]` is neither
                "torch" nor "tf2".
        """
        if config["framework"] == "torch":
            return BanditTorchPolicy
        elif config["framework"] == "tf2":
            return BanditTFPolicy
        else:
            # Give the user an actionable message instead of a bare error.
            raise NotImplementedError(
                f"Framework '{config['framework']}' not supported by "
                "BanditLinTSTrainer! Use `framework: torch` or "
                "`framework: tf2` instead."
            )
|
|
|
|
|
|
class BanditLinUCBTrainer(Trainer):
    """Bandit Trainer using UpperConfidenceBound exploration."""

    @classmethod
    @override(Trainer)
    def get_default_config(cls) -> TrainerConfigDict:
        """Return the default config, with UpperConfidenceBound exploration set.

        Returns:
            DEFAULT_CONFIG merged with an UpperConfidenceBound
            `exploration_config`.
        """
        return Trainer.merge_trainer_configs(
            DEFAULT_CONFIG,
            {
                # Use UpperConfidenceBound exploration.
                "exploration_config": {"type": "UpperConfidenceBound"}
            },
        )

    @override(Trainer)
    def get_default_policy_class(self, config: TrainerConfigDict) -> Type[Policy]:
        """Return the policy class matching the configured framework.

        Args:
            config: The Trainer's config dict; only `config["framework"]`
                is inspected here.

        Returns:
            BanditTorchPolicy for "torch", BanditTFPolicy for "tf2".

        Raises:
            NotImplementedError: If `config["framework"]` is neither
                "torch" nor "tf2".
        """
        if config["framework"] == "torch":
            return BanditTorchPolicy
        elif config["framework"] == "tf2":
            return BanditTFPolicy
        else:
            # Give the user an actionable message instead of a bare error.
            raise NotImplementedError(
                f"Framework '{config['framework']}' not supported by "
                "BanditLinUCBTrainer! Use `framework: torch` or "
                "`framework: tf2` instead."
            )
|