[RLlib] Bandits use TrainerConfig objects. (#24687)

Steven Morad 2022-05-12 21:02:15 +01:00 committed by GitHub
parent 2fd888ac9d
commit ebe6ab0afc
3 changed files with 124 additions and 56 deletions
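
In a nutshell, the change replaces the module-level `DEFAULT_CONFIG` dict with chainable `TrainerConfig` objects for the bandit trainers. A minimal usage sketch of the new style (assuming a Ray build that includes this commit; `SimpleContextualBandit` is the example env used in the tests below):

import ray
from ray.rllib.agents.bandit import BanditLinUCBConfig
from ray.rllib.examples.env.bandit_envs_discrete import SimpleContextualBandit

ray.init()

# New style: chain setter methods on the config object, then build the
# trainer directly from it instead of passing a raw dict to the Trainer.
config = (
    BanditLinUCBConfig()
    .environment(env=SimpleContextualBandit)
    .rollouts(num_rollout_workers=2, num_envs_per_worker=2)
    .training(train_batch_size=10)
)
trainer = config.build()
print(trainer.train())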

View file

@@ -1,3 +1,13 @@
from ray.rllib.agents.bandit.bandit import BanditLinTSTrainer, BanditLinUCBTrainer
from ray.rllib.agents.bandit.bandit import (
BanditLinTSTrainer,
BanditLinUCBTrainer,
BanditLinTSConfig,
BanditLinUCBConfig,
)
__all__ = ["BanditLinTSTrainer", "BanditLinUCBTrainer"]
__all__ = [
"BanditLinTSTrainer",
"BanditLinUCBTrainer",
"BanditLinTSConfig",
"BanditLinUCBConfig",
]
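
With the expanded `__all__`, the new config classes can be imported straight from the package. A small sketch of what they carry (values come from the defaults set in `bandit.py` below):

from ray.rllib.agents.bandit import BanditLinTSConfig, BanditLinUCBConfig

# Both child configs inherit the shared bandit defaults from BanditConfig
# (framework="torch", num_workers=0, rollout_fragment_length=1,
# train_batch_size=1); they differ only in their exploration strategy.
print(BanditLinTSConfig().exploration_config)   # {"type": "ThompsonSampling"}
print(BanditLinUCBConfig().exploration_config)  # {"type": "UpperConfidenceBound"}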

View file

@@ -1,33 +1,89 @@
import logging
from typing import Type
from typing import Type, Union
from ray.rllib.agents.bandit.bandit_tf_policy import BanditTFPolicy
from ray.rllib.agents.bandit.bandit_torch_policy import BanditTorchPolicy
from ray.rllib.agents.trainer import Trainer, with_common_config
from ray.rllib.agents.trainer import Trainer
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.annotations import override
from ray.rllib.utils.typing import TrainerConfigDict
from ray.rllib.agents.trainer_config import TrainerConfig
from ray.rllib.utils.deprecation import Deprecated
logger = logging.getLogger(__name__)
# fmt: off
# __sphinx_doc_begin__
DEFAULT_CONFIG = with_common_config({
# No remote workers by default.
"num_workers": 0,
"framework": "torch",
# Do online learning one step at a time.
"rollout_fragment_length": 1,
"train_batch_size": 1,
class BanditConfig(TrainerConfig):
"""Defines a contextual bandit configuration class from which
a contextual bandit algorithm can be built. Note that this config is shared
between BanditLinUCBTrainer and BanditLinTSTrainer. You likely
want to use the child classes BanditLinTSConfig or BanditLinUCBConfig
instead.
"""
# Make sure, a `train()` call performs at least 100 env sampling timesteps, before
# reporting results. Not setting this (default is 0) would significantly slow down
# the Bandit Trainer.
"min_sample_timesteps_per_reporting": 100,
})
# __sphinx_doc_end__
# fmt: on
def __init__(
self, trainer_class: Union["BanditLinTSTrainer", "BanditLinUCBTrainer"] = None
):
super().__init__(trainer_class=trainer_class)
# fmt: off
# __sphinx_doc_begin__
# Override some of TrainerConfig's default values with bandit-specific values.
self.framework_str = "torch"
self.num_workers = 0
self.rollout_fragment_length = 1
self.train_batch_size = 1
# Make sure a `train()` call performs at least 100 env sampling
# timesteps before reporting results. Not setting this (default is 0)
# would significantly slow down the Bandit Trainer.
self.min_sample_timesteps_per_reporting = 100
# __sphinx_doc_end__
# fmt: on
class BanditLinTSConfig(BanditConfig):
"""Defines a configuration class from which a Thompson-sampling bandit can be built.
Example:
>>> from ray.rllib.agents.bandit import BanditLinTSConfig
>>> from ray.rllib.examples.env.bandit_envs_discrete import WheelBanditEnv
>>> config = BanditLinTSConfig().rollouts(num_rollout_workers=4)
>>> print(config.to_dict())
>>> # Build a Trainer object from the config and run 1 training iteration.
>>> trainer = config.build(env=WheelBanditEnv)
>>> trainer.train()
"""
def __init__(self):
super().__init__(trainer_class=BanditLinTSTrainer)
# fmt: off
# __sphinx_doc_begin__
# Override some of TrainerConfig's default values with bandit-specific values.
self.exploration_config = {"type": "ThompsonSampling"}
# __sphinx_doc_end__
# fmt: on
class BanditLinUCBConfig(BanditConfig):
"""Defines a config class from which an upper confidence bound bandit can be built.
Example:
>>> from ray.rllib.agents.bandit import BanditLinUCBConfig
>>> from ray.rllib.examples.env.bandit_envs_discrete import WheelBanditEnv
>>> config = BanditLinUCBConfig().rollouts(num_rollout_workers=4)
>>> print(config.to_dict())
>>> # Build a Trainer object from the config and run 1 training iteration.
>>> trainer = config.build(env=WheelBanditEnv)
>>> trainer.train()
"""
def __init__(self):
super().__init__(trainer_class=BanditLinUCBTrainer)
# fmt: off
# __sphinx_doc_begin__
# Override some of TrainerConfig's default values with bandit-specific values.
self.exploration_config = {"type": "UpperConfidenceBound"}
# __sphinx_doc_end__
# fmt: on
class BanditLinTSTrainer(Trainer):
@@ -35,15 +91,8 @@ class BanditLinTSTrainer(Trainer):
@classmethod
@override(Trainer)
def get_default_config(cls) -> TrainerConfigDict:
config = Trainer.merge_trainer_configs(
DEFAULT_CONFIG,
{
# Use ThompsonSampling exploration.
"exploration_config": {"type": "ThompsonSampling"}
},
)
return config
def get_default_config(cls) -> BanditLinTSConfig:
return BanditLinTSConfig().to_dict()
@override(Trainer)
def get_default_policy_class(self, config: TrainerConfigDict) -> Type[Policy]:
@@ -52,20 +101,14 @@ class BanditLinTSTrainer(Trainer):
elif config["framework"] == "tf2":
return BanditTFPolicy
else:
raise NotImplementedError()
raise NotImplementedError("Only `framework=[torch|tf2]` supported!")
class BanditLinUCBTrainer(Trainer):
@classmethod
@override(Trainer)
def get_default_config(cls) -> TrainerConfigDict:
return Trainer.merge_trainer_configs(
DEFAULT_CONFIG,
{
# Use UpperConfidenceBound exploration.
"exploration_config": {"type": "UpperConfidenceBound"}
},
)
def get_default_config(cls) -> BanditLinUCBConfig:
return BanditLinUCBConfig().to_dict()
@override(Trainer)
def get_default_policy_class(self, config: TrainerConfigDict) -> Type[Policy]:
@@ -74,4 +117,21 @@ class BanditLinUCBTrainer(Trainer):
elif config["framework"] == "tf2":
return BanditTFPolicy
else:
raise NotImplementedError()
raise NotImplementedError("Only `framework=[torch|tf2]` supported!")
# Deprecated: Use ray.rllib.agents.bandit.BanditLinUCBConfig instead!
class _deprecated_default_config(dict):
def __init__(self):
super().__init__(BanditLinUCBConfig().to_dict())
@Deprecated(
old="ray.rllib.agents.bandit.bandit.DEFAULT_CONFIG",
new="ray.rllib.agents.bandit.bandit.BanditLin[UCB|TS]Config(...)",
error=False,
)
def __getitem__(self, item):
return super().__getitem__(item)
DEFAULT_CONFIG = _deprecated_default_config()
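
For clarity, a short sketch of how the deprecation shim behaves (assuming deprecation warnings are logged rather than raised, per `error=False`): reads on the old `DEFAULT_CONFIG` dict still resolve against `BanditLinUCBConfig().to_dict()`, but every `__getitem__` call goes through the `@Deprecated` wrapper and points users at the new config classes.

from ray.rllib.agents.bandit.bandit import DEFAULT_CONFIG

# Lookups still succeed (values come from BanditLinUCBConfig().to_dict()),
# but each access logs a deprecation warning naming BanditLin[UCB|TS]Config.
assert DEFAULT_CONFIG["train_batch_size"] == 1
assert DEFAULT_CONFIG["exploration_config"]["type"] == "UpperConfidenceBound"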

View file

@@ -1,7 +1,7 @@
import unittest
import ray
import ray.rllib.agents.bandit.bandit as bandit
from ray.rllib.agents.bandit import bandit
from ray.rllib.examples.env.bandit_envs_discrete import SimpleContextualBandit
from ray.rllib.utils.test_utils import check_train_results, framework_iterator
@@ -17,19 +17,17 @@ class TestBandits(unittest.TestCase):
def test_bandit_lin_ts_compilation(self):
"""Test whether a BanditLinTSTrainer can be built on all frameworks."""
config = {
# Use a simple bandit-friendly env.
"env": SimpleContextualBandit,
"num_envs_per_worker": 2, # Test batched inference.
"num_workers": 2, # Test distributed bandits.
}
config = (
bandit.BanditLinTSConfig()
.environment(env=SimpleContextualBandit)
.rollouts(num_rollout_workers=2, num_envs_per_worker=2)
)
num_iterations = 5
for _ in framework_iterator(config, frameworks="torch"):
for train_batch_size in [1, 10]:
config["train_batch_size"] = train_batch_size
trainer = bandit.BanditLinTSTrainer(config=config)
config.training(train_batch_size=train_batch_size)
trainer = config.build()
results = None
for i in range(num_iterations):
results = trainer.train()
@@ -41,18 +39,18 @@ class TestBandits(unittest.TestCase):
def test_bandit_lin_ucb_compilation(self):
"""Test whether a BanditLinUCBTrainer can be built on all frameworks."""
config = {
# Use a simple bandit-friendly env.
"env": SimpleContextualBandit,
"num_envs_per_worker": 2, # Test batched inference.
}
config = (
bandit.BanditLinUCBConfig()
.environment(env=SimpleContextualBandit)
.rollouts(num_envs_per_worker=2)
)
num_iterations = 5
for _ in framework_iterator(config, frameworks="torch"):
for train_batch_size in [1, 10]:
config["train_batch_size"] = train_batch_size
trainer = bandit.BanditLinUCBTrainer(config=config)
config.training(train_batch_size=train_batch_size)
trainer = config.build()
results = None
for i in range(num_iterations):
results = trainer.train()