diff --git a/rllib/agents/bandit/__init__.py b/rllib/agents/bandit/__init__.py
index 4bb5de213..a910ffa1c 100644
--- a/rllib/agents/bandit/__init__.py
+++ b/rllib/agents/bandit/__init__.py
@@ -1,3 +1,13 @@
-from ray.rllib.agents.bandit.bandit import BanditLinTSTrainer, BanditLinUCBTrainer
+from ray.rllib.agents.bandit.bandit import (
+    BanditLinTSTrainer,
+    BanditLinUCBTrainer,
+    BanditLinTSConfig,
+    BanditLinUCBConfig,
+)
 
-__all__ = ["BanditLinTSTrainer", "BanditLinUCBTrainer"]
+__all__ = [
+    "BanditLinTSTrainer",
+    "BanditLinUCBTrainer",
+    "BanditLinTSConfig",
+    "BanditLinUCBConfig",
+]
diff --git a/rllib/agents/bandit/bandit.py b/rllib/agents/bandit/bandit.py
index 78f63f58a..19ba77d46 100644
--- a/rllib/agents/bandit/bandit.py
+++ b/rllib/agents/bandit/bandit.py
@@ -1,33 +1,89 @@
 import logging
-from typing import Type
+from typing import Type, Union
 
 from ray.rllib.agents.bandit.bandit_tf_policy import BanditTFPolicy
 from ray.rllib.agents.bandit.bandit_torch_policy import BanditTorchPolicy
-from ray.rllib.agents.trainer import Trainer, with_common_config
+from ray.rllib.agents.trainer import Trainer
 from ray.rllib.policy.policy import Policy
 from ray.rllib.utils.annotations import override
 from ray.rllib.utils.typing import TrainerConfigDict
+from ray.rllib.agents.trainer_config import TrainerConfig
+from ray.rllib.utils.deprecation import Deprecated
 
 logger = logging.getLogger(__name__)
 
-# fmt: off
-# __sphinx_doc_begin__
-DEFAULT_CONFIG = with_common_config({
-    # No remote workers by default.
-    "num_workers": 0,
-    "framework": "torch",
-    # Do online learning one step at a time.
-    "rollout_fragment_length": 1,
-    "train_batch_size": 1,
-
-    # Make sure, a `train()` call performs at least 100 env sampling timesteps, before
-    # reporting results. Not setting this (default is 0) would significantly slow down
-    # the Bandit Trainer.
-    "min_sample_timesteps_per_reporting": 100,
-})
-# __sphinx_doc_end__
-# fmt: on
+
+class BanditConfig(TrainerConfig):
+    """Defines a contextual bandit configuration class from which
+    a contextual bandit algorithm can be built. Note that this config is
+    shared between BanditLinUCBTrainer and BanditLinTSTrainer. You likely
+    want to use the child classes BanditLinTSConfig or BanditLinUCBConfig
+    instead.
+    """
+
+    def __init__(
+        self, trainer_class: Union["BanditLinTSTrainer", "BanditLinUCBTrainer"] = None
+    ):
+        super().__init__(trainer_class=trainer_class)
+        # fmt: off
+        # __sphinx_doc_begin__
+        # Override some of TrainerConfig's default values with bandit-specific values.
+        self.framework_str = "torch"
+        self.num_workers = 0
+        self.rollout_fragment_length = 1
+        self.train_batch_size = 1
+        # Make sure a `train()` call performs at least 100 env sampling
+        # timesteps before reporting results. Not setting this (default is 0)
+        # would significantly slow down the Bandit Trainer.
+        self.min_sample_timesteps_per_reporting = 100
+        # __sphinx_doc_end__
+        # fmt: on
+
+
+class BanditLinTSConfig(BanditConfig):
+    """Defines a configuration class from which a Thompson-sampling bandit can be built.
+
+    Example:
+        >>> from ray.rllib.agents.bandit import BanditLinTSConfig
+        >>> from ray.rllib.examples.env.bandit_envs_discrete import WheelBanditEnv
+        >>> config = BanditLinTSConfig().rollouts(num_rollout_workers=4)
+        >>> print(config.to_dict())
+        >>> # Build a Trainer object from the config and run 1 training iteration.
+        >>> trainer = config.build(env=WheelBanditEnv)
+        >>> trainer.train()
+    """
+
+    def __init__(self):
+        super().__init__(trainer_class=BanditLinTSTrainer)
+        # fmt: off
+        # __sphinx_doc_begin__
+        # Override some of TrainerConfig's default values with bandit-specific values.
+        self.exploration_config = {"type": "ThompsonSampling"}
+        # __sphinx_doc_end__
+        # fmt: on
+
+
+class BanditLinUCBConfig(BanditConfig):
+    """Defines a config class from which an upper confidence bound bandit can be built.
+
+    Example:
+        >>> from ray.rllib.agents.bandit import BanditLinUCBConfig
+        >>> from ray.rllib.examples.env.bandit_envs_discrete import WheelBanditEnv
+        >>> config = BanditLinUCBConfig().rollouts(num_rollout_workers=4)
+        >>> print(config.to_dict())
+        >>> # Build a Trainer object from the config and run 1 training iteration.
+        >>> trainer = config.build(env=WheelBanditEnv)
+        >>> trainer.train()
+    """
+
+    def __init__(self):
+        super().__init__(trainer_class=BanditLinUCBTrainer)
+        # fmt: off
+        # __sphinx_doc_begin__
+        # Override some of TrainerConfig's default values with bandit-specific values.
+        self.exploration_config = {"type": "UpperConfidenceBound"}
+        # __sphinx_doc_end__
+        # fmt: on
 
 
 class BanditLinTSTrainer(Trainer):
@@ -35,15 +91,8 @@ class BanditLinTSTrainer(Trainer):
 
     @classmethod
     @override(Trainer)
-    def get_default_config(cls) -> TrainerConfigDict:
-        config = Trainer.merge_trainer_configs(
-            DEFAULT_CONFIG,
-            {
-                # Use ThompsonSampling exploration.
-                "exploration_config": {"type": "ThompsonSampling"}
-            },
-        )
-        return config
+    def get_default_config(cls) -> TrainerConfigDict:
+        return BanditLinTSConfig().to_dict()
 
     @override(Trainer)
     def get_default_policy_class(self, config: TrainerConfigDict) -> Type[Policy]:
@@ -52,20 +101,14 @@
         elif config["framework"] == "tf2":
             return BanditTFPolicy
         else:
-            raise NotImplementedError()
+            raise NotImplementedError("Only `framework=[torch|tf2]` supported!")
 
 
 class BanditLinUCBTrainer(Trainer):
     @classmethod
     @override(Trainer)
-    def get_default_config(cls) -> TrainerConfigDict:
-        return Trainer.merge_trainer_configs(
-            DEFAULT_CONFIG,
-            {
-                # Use UpperConfidenceBound exploration.
-                "exploration_config": {"type": "UpperConfidenceBound"}
-            },
-        )
+    def get_default_config(cls) -> TrainerConfigDict:
+        return BanditLinUCBConfig().to_dict()
 
     @override(Trainer)
     def get_default_policy_class(self, config: TrainerConfigDict) -> Type[Policy]:
@@ -74,4 +117,21 @@
         elif config["framework"] == "tf2":
             return BanditTFPolicy
         else:
-            raise NotImplementedError()
+            raise NotImplementedError("Only `framework=[torch|tf2]` supported!")
+
+
+# Deprecated: Use ray.rllib.agents.bandit.bandit.BanditLin[UCB|TS]Config instead!
+class _deprecated_default_config(dict):
+    def __init__(self):
+        super().__init__(BanditLinUCBConfig().to_dict())
+
+    @Deprecated(
+        old="ray.rllib.agents.bandit.bandit.DEFAULT_CONFIG",
+        new="ray.rllib.agents.bandit.bandit.BanditLin[UCB|TS]Config(...)",
+        error=False,
+    )
+    def __getitem__(self, item):
+        return super().__getitem__(item)
+
+
+DEFAULT_CONFIG = _deprecated_default_config()
diff --git a/rllib/agents/bandit/tests/test_bandits.py b/rllib/agents/bandit/tests/test_bandits.py
index 84f8625ae..63bf61aea 100644
--- a/rllib/agents/bandit/tests/test_bandits.py
+++ b/rllib/agents/bandit/tests/test_bandits.py
@@ -1,7 +1,7 @@
 import unittest
 
 import ray
-import ray.rllib.agents.bandit.bandit as bandit
+from ray.rllib.agents.bandit import bandit
 from ray.rllib.examples.env.bandit_envs_discrete import SimpleContextualBandit
 from ray.rllib.utils.test_utils import check_train_results, framework_iterator
 
@@ -17,19 +17,17 @@ class TestBandits(unittest.TestCase):
 
     def test_bandit_lin_ts_compilation(self):
         """Test whether a BanditLinTSTrainer can be built on all frameworks."""
-        config = {
-            # Use a simple bandit-friendly env.
-            "env": SimpleContextualBandit,
-            "num_envs_per_worker": 2,  # Test batched inference.
-            "num_workers": 2,  # Test distributed bandits.
-        }
-
+        config = (
+            bandit.BanditLinTSConfig()
+            .environment(env=SimpleContextualBandit)
+            .rollouts(num_rollout_workers=2, num_envs_per_worker=2)
+        )
         num_iterations = 5
 
         for _ in framework_iterator(config, frameworks="torch"):
             for train_batch_size in [1, 10]:
-                config["train_batch_size"] = train_batch_size
-                trainer = bandit.BanditLinTSTrainer(config=config)
+                config.training(train_batch_size=train_batch_size)
+                trainer = config.build()
                 results = None
                 for i in range(num_iterations):
                     results = trainer.train()
@@ -41,18 +39,18 @@
 
     def test_bandit_lin_ucb_compilation(self):
         """Test whether a BanditLinUCBTrainer can be built on all frameworks."""
-        config = {
-            # Use a simple bandit-friendly env.
-            "env": SimpleContextualBandit,
-            "num_envs_per_worker": 2,  # Test batched inference.
-        }
+        config = (
+            bandit.BanditLinUCBConfig()
+            .environment(env=SimpleContextualBandit)
+            .rollouts(num_envs_per_worker=2)
+        )
 
         num_iterations = 5
 
         for _ in framework_iterator(config, frameworks="torch"):
             for train_batch_size in [1, 10]:
-                config["train_batch_size"] = train_batch_size
-                trainer = bandit.BanditLinUCBTrainer(config=config)
+                config.training(train_batch_size=train_batch_size)
+                trainer = config.build()
                 results = None
                 for i in range(num_iterations):
                     results = trainer.train()
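Usage sketch (not part of the diff): a minimal, untested example of how the new config classes introduced above are meant to be used end to end. It relies only on the API added in this PR (`environment()`, `rollouts()`, `training()`, `build()`) plus the `SimpleContextualBandit` example env the tests already import; the `episode_reward_mean` result key is the standard RLlib training metric and is an assumption here.

    from ray.rllib.agents.bandit import BanditLinUCBConfig
    from ray.rllib.examples.env.bandit_envs_discrete import SimpleContextualBandit

    # Build a config via the fluent API introduced in this PR, then turn it
    # into a Trainer and run a few online-learning iterations.
    config = (
        BanditLinUCBConfig()
        .environment(env=SimpleContextualBandit)
        .rollouts(num_rollout_workers=0)  # Learn in the local worker only.
        .training(train_batch_size=1)  # One step at a time, matching the defaults.
    )
    trainer = config.build()
    for _ in range(3):
        results = trainer.train()
        print(results["episode_reward_mean"])  # Assumed standard RLlib result key.
    trainer.stop()

Swapping `BanditLinUCBConfig` for `BanditLinTSConfig` switches the exploration strategy from UpperConfidenceBound to ThompsonSampling; everything else stays the same, since both share `BanditConfig`.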