[RLlib] Bandits use TrainerConfig objects. (#24687)
This commit is contained in:
parent
2fd888ac9d
commit
ebe6ab0afc
3 changed files with 124 additions and 56 deletions
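In short, the bandit trainers' dict-based DEFAULT_CONFIG is replaced with TrainerConfig subclasses (BanditConfig, plus BanditLinTSConfig and BanditLinUCBConfig). A minimal sketch of the new usage, assembled from the docstring examples in the diff below (WheelBanditEnv is the example env used there):

    from ray.rllib.agents.bandit import BanditLinTSConfig
    from ray.rllib.examples.env.bandit_envs_discrete import WheelBanditEnv

    # Build the config fluently, then build a Trainer from it.
    config = (
        BanditLinTSConfig()
        .environment(env=WheelBanditEnv)
        .training(train_batch_size=10)
    )
    trainer = config.build()
    print(trainer.train())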
@@ -1,3 +1,13 @@
-from ray.rllib.agents.bandit.bandit import BanditLinTSTrainer, BanditLinUCBTrainer
+from ray.rllib.agents.bandit.bandit import (
+    BanditLinTSTrainer,
+    BanditLinUCBTrainer,
+    BanditLinTSConfig,
+    BanditLinUCBConfig,
+)
 
-__all__ = ["BanditLinTSTrainer", "BanditLinUCBTrainer"]
+__all__ = [
+    "BanditLinTSTrainer",
+    "BanditLinUCBTrainer",
+    "BanditLinTSConfig",
+    "BanditLinUCBConfig",
+]
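With the package __init__ updated as above, the new config classes are importable next to the trainers; a quick sketch:

    from ray.rllib.agents.bandit import BanditLinTSConfig, BanditLinUCBConfig

    # Each subclass pre-sets its exploration strategy (see the bandit.py diff below).
    print(BanditLinTSConfig().to_dict()["exploration_config"])   # ThompsonSampling
    print(BanditLinUCBConfig().to_dict()["exploration_config"])  # UpperConfidenceBound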
@@ -1,33 +1,89 @@
 import logging
-from typing import Type
+from typing import Type, Union
 
 from ray.rllib.agents.bandit.bandit_tf_policy import BanditTFPolicy
 from ray.rllib.agents.bandit.bandit_torch_policy import BanditTorchPolicy
-from ray.rllib.agents.trainer import Trainer, with_common_config
+from ray.rllib.agents.trainer import Trainer
 from ray.rllib.policy.policy import Policy
 from ray.rllib.utils.annotations import override
 from ray.rllib.utils.typing import TrainerConfigDict
+from ray.rllib.agents.trainer_config import TrainerConfig
+from ray.rllib.utils.deprecation import Deprecated
 
 logger = logging.getLogger(__name__)
 
-# fmt: off
-# __sphinx_doc_begin__
-DEFAULT_CONFIG = with_common_config({
-    # No remote workers by default.
-    "num_workers": 0,
-    "framework": "torch",
-
-    # Do online learning one step at a time.
-    "rollout_fragment_length": 1,
-    "train_batch_size": 1,
-
-    # Make sure, a `train()` call performs at least 100 env sampling timesteps, before
-    # reporting results. Not setting this (default is 0) would significantly slow down
-    # the Bandit Trainer.
-    "min_sample_timesteps_per_reporting": 100,
-})
-# __sphinx_doc_end__
-# fmt: on
+
+class BanditConfig(TrainerConfig):
+    """Defines a contextual bandit configuration class from which
+    a contextual bandit algorithm can be built. Note this config is shared
+    between BanditLinUCBTrainer and BanditLinTSTrainer. You likely
+    want to use the child classes BanditLinTSConfig or BanditLinUCBConfig
+    instead.
+    """
+
+    def __init__(
+        self, trainer_class: Union["BanditLinTSTrainer", "BanditLinUCBTrainer"] = None
+    ):
+        super().__init__(trainer_class=trainer_class)
+        # fmt: off
+        # __sphinx_doc_begin__
+        # Override some of TrainerConfig's default values with bandit-specific values.
+        self.framework_str = "torch"
+        self.num_workers = 0
+        self.rollout_fragment_length = 1
+        self.train_batch_size = 1
+        # Make sure a `train()` call performs at least 100 env sampling
+        # timesteps, before reporting results. Not setting this (default is 0)
+        # would significantly slow down the Bandit Trainer.
+        self.min_sample_timesteps_per_reporting = 100
+        # __sphinx_doc_end__
+        # fmt: on
+
+
+class BanditLinTSConfig(BanditConfig):
+    """Defines a configuration class from which a Thompson-sampling bandit can be built.
+
+    Example:
+        >>> from ray.rllib.agents.bandit import BanditLinTSConfig
+        >>> from ray.rllib.examples.env.bandit_envs_discrete import WheelBanditEnv
+        >>> config = BanditLinTSConfig().rollouts(num_rollout_workers=4)
+        >>> print(config.to_dict())
+        >>> # Build a Trainer object from the config and run 1 training iteration.
+        >>> trainer = config.build(env=WheelBanditEnv)
+        >>> trainer.train()
+    """
+
+    def __init__(self):
+        super().__init__(trainer_class=BanditLinTSTrainer)
+        # fmt: off
+        # __sphinx_doc_begin__
+        # Override some of TrainerConfig's default values with bandit-specific values.
+        self.exploration_config = {"type": "ThompsonSampling"}
+        # __sphinx_doc_end__
+        # fmt: on
+
+
+class BanditLinUCBConfig(BanditConfig):
+    """Defines a config class from which an upper confidence bound bandit can be built.
+
+    Example:
+        >>> from ray.rllib.agents.bandit import BanditLinUCBConfig
+        >>> from ray.rllib.examples.env.bandit_envs_discrete import WheelBanditEnv
+        >>> config = BanditLinUCBConfig().rollouts(num_rollout_workers=4)
+        >>> print(config.to_dict())
+        >>> # Build a Trainer object from the config and run 1 training iteration.
+        >>> trainer = config.build(env=WheelBanditEnv)
+        >>> trainer.train()
+    """
+
+    def __init__(self):
+        super().__init__(trainer_class=BanditLinUCBTrainer)
+        # fmt: off
+        # __sphinx_doc_begin__
+        # Override some of TrainerConfig's default values with bandit-specific values.
+        self.exploration_config = {"type": "UpperConfidenceBound"}
+        # __sphinx_doc_end__
+        # fmt: on
 
 
 class BanditLinTSTrainer(Trainer):
@@ -35,15 +91,8 @@ class BanditLinTSTrainer(Trainer):
 
     @classmethod
     @override(Trainer)
-    def get_default_config(cls) -> TrainerConfigDict:
-        config = Trainer.merge_trainer_configs(
-            DEFAULT_CONFIG,
-            {
-                # Use ThompsonSampling exploration.
-                "exploration_config": {"type": "ThompsonSampling"}
-            },
-        )
-        return config
+    def get_default_config(cls) -> BanditLinTSConfig:
+        return BanditLinTSConfig().to_dict()
 
     @override(Trainer)
     def get_default_policy_class(self, config: TrainerConfigDict) -> Type[Policy]:
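The override now returns the default config in dict form (via to_dict()), so downstream code that indexes into it like a plain dict keeps working; a small sketch:

    from ray.rllib.agents.bandit import BanditLinTSTrainer

    default_config = BanditLinTSTrainer.get_default_config()
    # to_dict() output is a plain dict, so legacy key access still works:
    assert default_config["exploration_config"]["type"] == "ThompsonSampling"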
@@ -52,20 +101,14 @@ class BanditLinTSTrainer(Trainer):
         elif config["framework"] == "tf2":
             return BanditTFPolicy
         else:
-            raise NotImplementedError()
+            raise NotImplementedError("Only `framework=[torch|tf2]` supported!")
 
 
 class BanditLinUCBTrainer(Trainer):
     @classmethod
     @override(Trainer)
-    def get_default_config(cls) -> TrainerConfigDict:
-        return Trainer.merge_trainer_configs(
-            DEFAULT_CONFIG,
-            {
-                # Use UpperConfidenceBound exploration.
-                "exploration_config": {"type": "UpperConfidenceBound"}
-            },
-        )
+    def get_default_config(cls) -> BanditLinUCBConfig:
+        return BanditLinUCBConfig().to_dict()
 
     @override(Trainer)
     def get_default_policy_class(self, config: TrainerConfigDict) -> Type[Policy]:
@@ -74,4 +117,21 @@ class BanditLinUCBTrainer(Trainer):
         elif config["framework"] == "tf2":
             return BanditTFPolicy
         else:
-            raise NotImplementedError()
+            raise NotImplementedError("Only `framework=[torch|tf2]` supported!")
+
+
+# Deprecated: Use ray.rllib.agents.bandit.BanditLinUCBConfig instead!
+class _deprecated_default_config(dict):
+    def __init__(self):
+        super().__init__(BanditLinUCBConfig().to_dict())
+
+    @Deprecated(
+        old="ray.rllib.agents.bandit.bandit.DEFAULT_CONFIG",
+        new="ray.rllib.agents.bandit.bandit.BanditLin[UCB|TS]Config(...)",
+        error=False,
+    )
+    def __getitem__(self, item):
+        return super().__getitem__(item)
+
+
+DEFAULT_CONFIG = _deprecated_default_config()
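The module-level DEFAULT_CONFIG survives as a dict subclass whose __getitem__ is wrapped in @Deprecated(error=False): old dict-style access keeps working but logs a warning pointing at the new config classes. Roughly:

    from ray.rllib.agents.bandit import bandit

    # Works, but each key access emits a deprecation warning
    # (error=False means warn rather than raise).
    batch_size = bandit.DEFAULT_CONFIG["train_batch_size"]  # -> 1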
@@ -1,7 +1,7 @@
 import unittest
 
 import ray
-import ray.rllib.agents.bandit.bandit as bandit
+from ray.rllib.agents.bandit import bandit
 from ray.rllib.examples.env.bandit_envs_discrete import SimpleContextualBandit
 from ray.rllib.utils.test_utils import check_train_results, framework_iterator
@@ -17,19 +17,17 @@ class TestBandits(unittest.TestCase):
 
     def test_bandit_lin_ts_compilation(self):
        """Test whether a BanditLinTSTrainer can be built on all frameworks."""
-        config = {
-            # Use a simple bandit-friendly env.
-            "env": SimpleContextualBandit,
-            "num_envs_per_worker": 2,  # Test batched inference.
-            "num_workers": 2,  # Test distributed bandits.
-        }
-
+        config = (
+            bandit.BanditLinTSConfig()
+            .environment(env=SimpleContextualBandit)
+            .rollouts(num_rollout_workers=2, num_envs_per_worker=2)
+        )
         num_iterations = 5
 
         for _ in framework_iterator(config, frameworks="torch"):
             for train_batch_size in [1, 10]:
-                config["train_batch_size"] = train_batch_size
-                trainer = bandit.BanditLinTSTrainer(config=config)
+                config.training(train_batch_size=train_batch_size)
+                trainer = config.build()
                 results = None
                 for i in range(num_iterations):
                     results = trainer.train()
@@ -41,18 +39,18 @@ class TestBandits(unittest.TestCase):
 
     def test_bandit_lin_ucb_compilation(self):
        """Test whether a BanditLinUCBTrainer can be built on all frameworks."""
-        config = {
-            # Use a simple bandit-friendly env.
-            "env": SimpleContextualBandit,
-            "num_envs_per_worker": 2,  # Test batched inference.
-        }
+        config = (
+            bandit.BanditLinUCBConfig()
+            .environment(env=SimpleContextualBandit)
+            .rollouts(num_envs_per_worker=2)
+        )
 
         num_iterations = 5
 
         for _ in framework_iterator(config, frameworks="torch"):
             for train_batch_size in [1, 10]:
-                config["train_batch_size"] = train_batch_size
-                trainer = bandit.BanditLinUCBTrainer(config=config)
+                config.training(train_batch_size=train_batch_size)
+                trainer = config.build()
                 results = None
                 for i in range(num_iterations):
                     results = trainer.train()
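End to end, the updated tests follow the same fluent pattern; a standalone sketch of that flow (the ray.init arguments and the printed result key are assumptions, not part of this diff):

    import ray
    from ray.rllib.agents.bandit import bandit
    from ray.rllib.examples.env.bandit_envs_discrete import SimpleContextualBandit

    ray.init(num_cpus=4)  # assumption: enough CPUs for 2 workers x 2 envs each
    config = (
        bandit.BanditLinTSConfig()
        .environment(env=SimpleContextualBandit)
        .rollouts(num_rollout_workers=2, num_envs_per_worker=2)
    )
    trainer = config.build()
    for _ in range(5):
        results = trainer.train()
    print(results["episode_reward_mean"])  # standard RLlib result key
    ray.shutdown()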