# Source mirrored from https://github.com/vale981/ray (synced 2025-03-06).
import logging
|
|
from typing import Type, Union
|
|
|
|
from ray.rllib.algorithms.bandit.bandit_tf_policy import BanditTFPolicy
|
|
from ray.rllib.algorithms.bandit.bandit_torch_policy import BanditTorchPolicy
|
|
from ray.rllib.agents.trainer import Trainer
|
|
from ray.rllib.policy.policy import Policy
|
|
from ray.rllib.utils.annotations import override
|
|
from ray.rllib.utils.typing import TrainerConfigDict
|
|
from ray.rllib.agents.trainer_config import TrainerConfig
|
|
from ray.rllib.utils.deprecation import Deprecated
|
|
|
|
# Module-level logger, keyed by this module's import path (standard RLlib
# convention). Not referenced elsewhere in this chunk.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class BanditConfig(TrainerConfig):
    """Base configuration for RLlib's contextual bandit trainers.

    This config is shared between BanditLinUCB and BanditLinTS; callers
    normally use one of the child classes (``BanditLinTSConfig`` or
    ``BanditLinUCBConfig``) rather than this class directly.
    """

    def __init__(self, trainer_class: Union["BanditLinTS", "BanditLinUCB"] = None):
        super().__init__(trainer_class=trainer_class)
        # fmt: off
        # __sphinx_doc_begin__
        # Bandit-specific overrides of TrainerConfig defaults:
        # single local worker, one-step rollouts/batches, torch framework.
        self.num_workers = 0
        self.rollout_fragment_length = 1
        self.train_batch_size = 1
        self.framework_str = "torch"
        # Require at least 100 sampled env timesteps per `train()` call before
        # results are reported. The default (0) would make the Bandit Trainer
        # report after every single timestep and slow it down significantly.
        self.min_sample_timesteps_per_reporting = 100
        # __sphinx_doc_end__
        # fmt: on
|
|
|
|
|
|
class BanditLinTSConfig(BanditConfig):
    """Config from which a linear Thompson-sampling bandit Trainer can be built.

    Example:
        >>> from ray.rllib.algorithms.bandit import BanditLinTSConfig
        >>> from ray.rllib.examples.env.bandit_envs_discrete import WheelBanditEnv
        >>> config = BanditLinTSConfig().rollouts(num_rollout_workers=4)
        >>> print(config.to_dict())
        >>> # Build a Trainer object from the config and run 1 training iteration.
        >>> trainer = config.build(env=WheelBanditEnv)
        >>> trainer.train()
    """

    def __init__(self):
        # Bind this config to the BanditLinTS Trainer class.
        super().__init__(trainer_class=BanditLinTS)
        # fmt: off
        # __sphinx_doc_begin__
        # Thompson sampling is the exploration strategy defining this algorithm.
        self.exploration_config = {"type": "ThompsonSampling"}
        # __sphinx_doc_end__
        # fmt: on
|
|
|
|
|
|
class BanditLinUCBConfig(BanditConfig):
    """Config from which an upper-confidence-bound bandit Trainer can be built.

    Example:
        >>> from ray.rllib.algorithms.bandit import BanditLinUCBConfig
        >>> from ray.rllib.examples.env.bandit_envs_discrete import WheelBanditEnv
        >>> config = BanditLinUCBConfig().rollouts(num_rollout_workers=4)
        >>> print(config.to_dict())
        >>> # Build a Trainer object from the config and run 1 training iteration.
        >>> trainer = config.build(env=WheelBanditEnv)
        >>> trainer.train()
    """

    def __init__(self):
        # Bind this config to the BanditLinUCB Trainer class.
        super().__init__(trainer_class=BanditLinUCB)
        # fmt: off
        # __sphinx_doc_begin__
        # UCB exploration is the strategy defining this algorithm.
        self.exploration_config = {"type": "UpperConfidenceBound"}
        # __sphinx_doc_end__
        # fmt: on
|
|
|
|
|
|
class BanditLinTS(Trainer):
    """Bandit Trainer using ThompsonSampling exploration."""

    @classmethod
    @override(Trainer)
    def get_default_config(cls) -> TrainerConfigDict:
        # Fix: this returns `BanditLinTSConfig().to_dict()` — a plain config
        # dict (legacy Trainer API) — so the return annotation must be
        # TrainerConfigDict, not BanditLinTSConfig.
        return BanditLinTSConfig().to_dict()

    @override(Trainer)
    def get_default_policy_class(self, config: TrainerConfigDict) -> Type[Policy]:
        """Return the bandit policy class matching ``config["framework"]``.

        Args:
            config: Full trainer config dict; only the "framework" key is read.

        Raises:
            NotImplementedError: If the framework is neither "torch" nor "tf2".
        """
        if config["framework"] == "torch":
            return BanditTorchPolicy
        elif config["framework"] == "tf2":
            return BanditTFPolicy
        else:
            raise NotImplementedError("Only `framework=[torch|tf2]` supported!")
|
|
|
|
|
|
class BanditLinUCB(Trainer):
    """Bandit Trainer using UpperConfidenceBound exploration."""

    @classmethod
    @override(Trainer)
    def get_default_config(cls) -> TrainerConfigDict:
        # Fix: this returns `BanditLinUCBConfig().to_dict()` — a plain config
        # dict (legacy Trainer API) — so the return annotation must be
        # TrainerConfigDict, not BanditLinUCBConfig.
        return BanditLinUCBConfig().to_dict()

    @override(Trainer)
    def get_default_policy_class(self, config: TrainerConfigDict) -> Type[Policy]:
        """Return the bandit policy class matching ``config["framework"]``.

        Args:
            config: Full trainer config dict; only the "framework" key is read.

        Raises:
            NotImplementedError: If the framework is neither "torch" nor "tf2".
        """
        if config["framework"] == "torch":
            return BanditTorchPolicy
        elif config["framework"] == "tf2":
            return BanditTFPolicy
        else:
            raise NotImplementedError("Only `framework=[torch|tf2]` supported!")
|
|
|
|
|
|
# Deprecated: Use ray.rllib.algorithms.bandit.BanditLinUCBConfig instead!
class _deprecated_default_config(dict):
    """Backward-compatibility shim exposing the old-style DEFAULT_CONFIG dict.

    Behaves like a plain dict seeded from ``BanditLinUCBConfig``; item access
    additionally emits a deprecation warning via the ``@Deprecated`` decorator.
    """

    def __init__(self):
        legacy_config = BanditLinUCBConfig().to_dict()
        super().__init__(legacy_config)

    @Deprecated(
        old="ray.rllib.algorithms.bandit.bandit.DEFAULT_CONFIG",
        new="ray.rllib.algorithms.bandit.bandit.BanditLin[UCB|TS]Config(...)",
        error=False,
    )
    def __getitem__(self, item):
        # Plain dict lookup; the decorator handles the deprecation warning.
        return super().__getitem__(item)
|
|
|
|
|
|
# Legacy module-level config dict; accessing items triggers a deprecation
# warning pointing users at BanditLin[UCB|TS]Config.
DEFAULT_CONFIG = _deprecated_default_config()