[RLlib] Bandits use TrainerConfig objects. (#24687)
Commit ebe6ab0afc (parent 2fd888ac9d)
3 changed files with 124 additions and 56 deletions
rllib/agents/bandit/__init__.py
@@ -1,3 +1,13 @@
-from ray.rllib.agents.bandit.bandit import BanditLinTSTrainer, BanditLinUCBTrainer
+from ray.rllib.agents.bandit.bandit import (
+    BanditLinTSTrainer,
+    BanditLinUCBTrainer,
+    BanditLinTSConfig,
+    BanditLinUCBConfig,
+)
 
-__all__ = ["BanditLinTSTrainer", "BanditLinUCBTrainer"]
+__all__ = [
+    "BanditLinTSTrainer",
+    "BanditLinUCBTrainer",
+    "BanditLinTSConfig",
+    "BanditLinUCBConfig",
+]
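The two config classes join the trainers in the package's public API. A minimal usage sketch of the new exports, assuming the Ray 1.x-era layout this diff targets (`ray.rllib.agents.bandit`):

# Sketch: using the newly exported config classes (assumes Ray 1.x-era paths).
from ray.rllib.agents.bandit import BanditLinTSConfig, BanditLinUCBConfig

# Both inherit BanditConfig's bandit-friendly defaults (torch, batch size 1, ...)
# and differ only in their exploration strategy.
print(BanditLinTSConfig().to_dict()["exploration_config"])   # {"type": "ThompsonSampling"}
print(BanditLinUCBConfig().to_dict()["exploration_config"])  # {"type": "UpperConfidenceBound"}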
rllib/agents/bandit/bandit.py
@@ -1,33 +1,89 @@
 import logging
-from typing import Type
+from typing import Type, Union
 
 from ray.rllib.agents.bandit.bandit_tf_policy import BanditTFPolicy
 from ray.rllib.agents.bandit.bandit_torch_policy import BanditTorchPolicy
-from ray.rllib.agents.trainer import Trainer, with_common_config
+from ray.rllib.agents.trainer import Trainer
 from ray.rllib.policy.policy import Policy
 from ray.rllib.utils.annotations import override
 from ray.rllib.utils.typing import TrainerConfigDict
+from ray.rllib.agents.trainer_config import TrainerConfig
+from ray.rllib.utils.deprecation import Deprecated
 
 logger = logging.getLogger(__name__)
 
-# fmt: off
-# __sphinx_doc_begin__
-DEFAULT_CONFIG = with_common_config({
-    # No remote workers by default.
-    "num_workers": 0,
-    "framework": "torch",
-
-    # Do online learning one step at a time.
-    "rollout_fragment_length": 1,
-    "train_batch_size": 1,
-
-    # Make sure a `train()` call performs at least 100 env sampling timesteps, before
-    # reporting results. Not setting this (default is 0) would significantly slow down
-    # the Bandit Trainer.
-    "min_sample_timesteps_per_reporting": 100,
-})
-# __sphinx_doc_end__
-# fmt: on
+
+class BanditConfig(TrainerConfig):
+    """Defines a contextual bandit configuration class from which
+    a contextual bandit algorithm can be built. Note that this config is
+    shared between BanditLinUCBTrainer and BanditLinTSTrainer. You likely
+    want to use the child classes BanditLinTSConfig or BanditLinUCBConfig
+    instead.
+    """
+
+    def __init__(
+        self, trainer_class: Union["BanditLinTSTrainer", "BanditLinUCBTrainer"] = None
+    ):
+        super().__init__(trainer_class=trainer_class)
+        # fmt: off
+        # __sphinx_doc_begin__
+        # Override some of TrainerConfig's default values with bandit-specific values.
+        self.framework_str = "torch"
+        self.num_workers = 0
+        self.rollout_fragment_length = 1
+        self.train_batch_size = 1
+        # Make sure a `train()` call performs at least 100 env sampling
+        # timesteps before reporting results. Not setting this (default is 0)
+        # would significantly slow down the Bandit Trainer.
+        self.min_sample_timesteps_per_reporting = 100
+        # __sphinx_doc_end__
+        # fmt: on
+
+
+class BanditLinTSConfig(BanditConfig):
+    """Defines a configuration class from which a Thompson-sampling bandit can be built.
+
+    Example:
+        >>> from ray.rllib.agents.bandit import BanditLinTSConfig
+        >>> from ray.rllib.examples.env.bandit_envs_discrete import WheelBanditEnv
+        >>> config = BanditLinTSConfig().rollouts(num_rollout_workers=4)
+        >>> print(config.to_dict())
+        >>> # Build a Trainer object from the config and run 1 training iteration.
+        >>> trainer = config.build(env=WheelBanditEnv)
+        >>> trainer.train()
+    """
+
+    def __init__(self):
+        super().__init__(trainer_class=BanditLinTSTrainer)
+        # fmt: off
+        # __sphinx_doc_begin__
+        # Override some of TrainerConfig's default values with bandit-specific values.
+        self.exploration_config = {"type": "ThompsonSampling"}
+        # __sphinx_doc_end__
+        # fmt: on
+
+
+class BanditLinUCBConfig(BanditConfig):
+    """Defines a config class from which an upper confidence bound bandit can be built.
+
+    Example:
+        >>> from ray.rllib.agents.bandit import BanditLinUCBConfig
+        >>> from ray.rllib.examples.env.bandit_envs_discrete import WheelBanditEnv
+        >>> config = BanditLinUCBConfig().rollouts(num_rollout_workers=4)
+        >>> print(config.to_dict())
+        >>> # Build a Trainer object from the config and run 1 training iteration.
+        >>> trainer = config.build(env=WheelBanditEnv)
+        >>> trainer.train()
+    """
+
+    def __init__(self):
+        super().__init__(trainer_class=BanditLinUCBTrainer)
+        # fmt: off
+        # __sphinx_doc_begin__
+        # Override some of TrainerConfig's default values with bandit-specific values.
+        self.exploration_config = {"type": "UpperConfidenceBound"}
+        # __sphinx_doc_end__
+        # fmt: on
 
 
 class BanditLinTSTrainer(Trainer):
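The docstring examples above rely on TrainerConfig's fluent setters (`rollouts()`, `training()`, `environment()`), each of which updates a group of related attributes and returns `self`, so calls can be chained before `build()`. A minimal sketch of that builder pattern in isolation; `SketchConfig` and its fields are illustrative, not RLlib API:

# Sketch of the fluent-setter ("builder") pattern used by TrainerConfig.
# SketchConfig and its fields are illustrative, not RLlib API.
class SketchConfig:
    def __init__(self):
        self.num_rollout_workers = 0
        self.train_batch_size = 1

    def rollouts(self, *, num_rollout_workers=None):
        # Each setter only touches the fields it was given ...
        if num_rollout_workers is not None:
            self.num_rollout_workers = num_rollout_workers
        # ... and returns self, so setters can be chained.
        return self

    def training(self, *, train_batch_size=None):
        if train_batch_size is not None:
            self.train_batch_size = train_batch_size
        return self


config = SketchConfig().rollouts(num_rollout_workers=4).training(train_batch_size=10)
assert config.num_rollout_workers == 4 and config.train_batch_size == 10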
@@ -35,15 +91,8 @@ class BanditLinTSTrainer(Trainer):
 
     @classmethod
     @override(Trainer)
-    def get_default_config(cls) -> TrainerConfigDict:
-        config = Trainer.merge_trainer_configs(
-            DEFAULT_CONFIG,
-            {
-                # Use ThompsonSampling exploration.
-                "exploration_config": {"type": "ThompsonSampling"}
-            },
-        )
-        return config
+    def get_default_config(cls) -> BanditLinTSConfig:
+        return BanditLinTSConfig().to_dict()
 
     @override(Trainer)
     def get_default_policy_class(self, config: TrainerConfigDict) -> Type[Policy]:
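One wrinkle worth noting: although the new `get_default_config()` is annotated as returning `BanditLinTSConfig`, its body returns `BanditLinTSConfig().to_dict()`, a plain dict, presumably so existing callers that index into the default config keep working. A short sketch of what a caller sees, assuming `get_default_config` is invoked as the classmethod shown above:

# Sketch: despite the class-typed annotation, callers receive a plain dict.
default = BanditLinTSTrainer.get_default_config()
assert isinstance(default, dict)
print(default["exploration_config"])  # {"type": "ThompsonSampling"}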
@@ -52,20 +101,14 @@ class BanditLinTSTrainer(Trainer):
         elif config["framework"] == "tf2":
             return BanditTFPolicy
         else:
-            raise NotImplementedError()
+            raise NotImplementedError("Only `framework=[torch|tf2]` supported!")
 
 
 class BanditLinUCBTrainer(Trainer):
     @classmethod
     @override(Trainer)
-    def get_default_config(cls) -> TrainerConfigDict:
-        return Trainer.merge_trainer_configs(
-            DEFAULT_CONFIG,
-            {
-                # Use UpperConfidenceBound exploration.
-                "exploration_config": {"type": "UpperConfidenceBound"}
-            },
-        )
+    def get_default_config(cls) -> BanditLinUCBConfig:
+        return BanditLinUCBConfig().to_dict()
 
     @override(Trainer)
     def get_default_policy_class(self, config: TrainerConfigDict) -> Type[Policy]:
@@ -74,4 +117,21 @@ class BanditLinUCBTrainer(Trainer):
         elif config["framework"] == "tf2":
             return BanditTFPolicy
         else:
-            raise NotImplementedError()
+            raise NotImplementedError("Only `framework=[torch|tf2]` supported!")
+
+
+# Deprecated: Use ray.rllib.agents.bandit.BanditLinUCBConfig instead!
+class _deprecated_default_config(dict):
+    def __init__(self):
+        super().__init__(BanditLinUCBConfig().to_dict())
+
+    @Deprecated(
+        old="ray.rllib.agents.bandit.bandit.DEFAULT_CONFIG",
+        new="ray.rllib.agents.bandit.bandit.BanditLin[UCB|TS]Config(...)",
+        error=False,
+    )
+    def __getitem__(self, item):
+        return super().__getitem__(item)
+
+
+DEFAULT_CONFIG = _deprecated_default_config()
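The `_deprecated_default_config` trick keeps the module-level `DEFAULT_CONFIG` importable while warning on use: it is a dict pre-filled from the new config class, whose `__getitem__` is wrapped by RLlib's `@Deprecated` decorator so each key access emits a deprecation message. A self-contained sketch of the same pattern using only the standard library, with `@Deprecated` replaced by an explicit `warnings.warn`:

import warnings

# Sketch of the deprecation-wrapper pattern above, with RLlib's @Deprecated
# decorator replaced by an explicit warnings.warn call.
class _DeprecatedDefaultConfig(dict):
    def __init__(self, replacement):
        super().__init__(replacement)

    def __getitem__(self, item):
        # Warn on every key access, then behave like a normal dict.
        warnings.warn(
            "DEFAULT_CONFIG is deprecated; use BanditLin[UCB|TS]Config(...) instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        return super().__getitem__(item)


DEFAULT_CONFIG = _DeprecatedDefaultConfig({"train_batch_size": 1})
assert DEFAULT_CONFIG["train_batch_size"] == 1  # emits a DeprecationWarning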
rllib/agents/bandit/tests/test_bandits.py
@@ -1,7 +1,7 @@
 import unittest
 
 import ray
-import ray.rllib.agents.bandit.bandit as bandit
+from ray.rllib.agents.bandit import bandit
 from ray.rllib.examples.env.bandit_envs_discrete import SimpleContextualBandit
 from ray.rllib.utils.test_utils import check_train_results, framework_iterator
 
@@ -17,19 +17,17 @@ class TestBandits(unittest.TestCase):
 
     def test_bandit_lin_ts_compilation(self):
        """Test whether a BanditLinTSTrainer can be built on all frameworks."""
-        config = {
-            # Use a simple bandit-friendly env.
-            "env": SimpleContextualBandit,
-            "num_envs_per_worker": 2,  # Test batched inference.
-            "num_workers": 2,  # Test distributed bandits.
-        }
+        config = (
+            bandit.BanditLinTSConfig()
+            .environment(env=SimpleContextualBandit)
+            .rollouts(num_rollout_workers=2, num_envs_per_worker=2)
+        )
 
         num_iterations = 5
 
         for _ in framework_iterator(config, frameworks="torch"):
             for train_batch_size in [1, 10]:
-                config["train_batch_size"] = train_batch_size
-                trainer = bandit.BanditLinTSTrainer(config=config)
+                config.training(train_batch_size=train_batch_size)
+                trainer = config.build()
                 results = None
                 for i in range(num_iterations):
                     results = trainer.train()
@@ -41,18 +39,18 @@ class TestBandits(unittest.TestCase):
 
     def test_bandit_lin_ucb_compilation(self):
        """Test whether a BanditLinUCBTrainer can be built on all frameworks."""
-        config = {
-            # Use a simple bandit-friendly env.
-            "env": SimpleContextualBandit,
-            "num_envs_per_worker": 2,  # Test batched inference.
-        }
+        config = (
+            bandit.BanditLinUCBConfig()
+            .environment(env=SimpleContextualBandit)
+            .rollouts(num_envs_per_worker=2)
+        )
 
         num_iterations = 5
 
         for _ in framework_iterator(config, frameworks="torch"):
             for train_batch_size in [1, 10]:
-                config["train_batch_size"] = train_batch_size
-                trainer = bandit.BanditLinUCBTrainer(config=config)
+                config.training(train_batch_size=train_batch_size)
+                trainer = config.build()
                 results = None
                 for i in range(num_iterations):
                     results = trainer.train()
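The test changes capture the migration pattern for any user code still on the dict API: move each key into the matching fluent setter and replace the trainer constructor with `config.build()`. A sketch of that before/after, using the names from the tests above:

# Before (dict config, now deprecated):
#   config = {"env": SimpleContextualBandit, "num_envs_per_worker": 2}
#   trainer = bandit.BanditLinUCBTrainer(config=config)

# After (config object):
config = (
    bandit.BanditLinUCBConfig()
    .environment(env=SimpleContextualBandit)
    .rollouts(num_envs_per_worker=2)
    .training(train_batch_size=10)
)
trainer = config.build()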