[RLlib] Bandits use TrainerConfig objects. (#24687)

This commit is contained in:
Steven Morad 2022-05-12 21:02:15 +01:00 committed by GitHub
parent 2fd888ac9d
commit ebe6ab0afc
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 124 additions and 56 deletions

View file

@@ -1,3 +1,13 @@
from ray.rllib.agents.bandit.bandit import BanditLinTSTrainer, BanditLinUCBTrainer from ray.rllib.agents.bandit.bandit import (
BanditLinTSTrainer,
BanditLinUCBTrainer,
BanditLinTSConfig,
BanditLinUCBConfig,
)
__all__ = ["BanditLinTSTrainer", "BanditLinUCBTrainer"] __all__ = [
"BanditLinTSTrainer",
"BanditLinUCBTrainer",
"BanditLinTSConfig",
"BanditLinUCBConfig",
]

View file

@@ -1,33 +1,89 @@
import logging import logging
from typing import Type from typing import Type, Union
from ray.rllib.agents.bandit.bandit_tf_policy import BanditTFPolicy from ray.rllib.agents.bandit.bandit_tf_policy import BanditTFPolicy
from ray.rllib.agents.bandit.bandit_torch_policy import BanditTorchPolicy from ray.rllib.agents.bandit.bandit_torch_policy import BanditTorchPolicy
from ray.rllib.agents.trainer import Trainer, with_common_config from ray.rllib.agents.trainer import Trainer
from ray.rllib.policy.policy import Policy from ray.rllib.policy.policy import Policy
from ray.rllib.utils.annotations import override from ray.rllib.utils.annotations import override
from ray.rllib.utils.typing import TrainerConfigDict from ray.rllib.utils.typing import TrainerConfigDict
from ray.rllib.agents.trainer_config import TrainerConfig
from ray.rllib.utils.deprecation import Deprecated
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# fmt: off
# __sphinx_doc_begin__
DEFAULT_CONFIG = with_common_config({
# No remote workers by default.
"num_workers": 0,
"framework": "torch",
# Do online learning one step at a time. class BanditConfig(TrainerConfig):
"rollout_fragment_length": 1, """Defines a contextual bandit configuration class from which
"train_batch_size": 1, a contextual bandit algorithm can be built. Note this config is shared
between BanditLinUCBTrainer and BanditLinTSTrainer. You likely
want to use the child classes BanditLinTSConfig or BanditLinUCBConfig
instead.
"""
# Make sure, a `train()` call performs at least 100 env sampling timesteps, before def __init__(
# reporting results. Not setting this (default is 0) would significantly slow down self, trainer_class: Union["BanditLinTSTrainer", "BanditLinUCBTrainer"] = None
# the Bandit Trainer. ):
"min_sample_timesteps_per_reporting": 100, super().__init__(trainer_class=trainer_class)
}) # fmt: off
# __sphinx_doc_end__ # __sphinx_doc_begin__
# fmt: on # Override some of TrainerConfig's default values with bandit-specific values.
self.framework_str = "torch"
self.num_workers = 0
self.rollout_fragment_length = 1
self.train_batch_size = 1
# Make sure a `train()` call performs at least 100 env sampling
# timesteps, before reporting results. Not setting this (default is 0)
# would significantly slow down the Bandit Trainer.
self.min_sample_timesteps_per_reporting = 100
# __sphinx_doc_end__
# fmt: on
class BanditLinTSConfig(BanditConfig):
"""Defines a configuration class from which a Thompson-sampling bandit can be built.
Example:
>>> from ray.rllib.agents.bandit import BanditLinTSConfig
>>> from ray.rllib.examples.env.bandit_envs_discrete import WheelBanditEnv
>>> config = BanditLinTSConfig().rollouts(num_rollout_workers=4)
>>> print(config.to_dict())
>>> # Build a Trainer object from the config and run 1 training iteration.
>>> trainer = config.build(env=WheelBanditEnv)
>>> trainer.train()
"""
def __init__(self):
super().__init__(trainer_class=BanditLinTSTrainer)
# fmt: off
# __sphinx_doc_begin__
# Override some of TrainerConfig's default values with bandit-specific values.
self.exploration_config = {"type": "ThompsonSampling"}
# __sphinx_doc_end__
# fmt: on
class BanditLinUCBConfig(BanditConfig):
"""Defines a config class from which an upper confidence bound bandit can be built.
Example:
>>> from ray.rllib.agents.bandit import BanditLinUCBConfig
>>> from ray.rllib.examples.env.bandit_envs_discrete import WheelBanditEnv
>>> config = BanditLinUCBConfig().rollouts(num_rollout_workers=4)
>>> print(config.to_dict())
>>> # Build a Trainer object from the config and run 1 training iteration.
>>> trainer = config.build(env=WheelBanditEnv)
>>> trainer.train()
"""
def __init__(self):
super().__init__(trainer_class=BanditLinUCBTrainer)
# fmt: off
# __sphinx_doc_begin__
# Override some of TrainerConfig's default values with bandit-specific values.
self.exploration_config = {"type": "UpperConfidenceBound"}
# __sphinx_doc_end__
# fmt: on
class BanditLinTSTrainer(Trainer): class BanditLinTSTrainer(Trainer):
@@ -35,15 +91,8 @@ class BanditLinTSTrainer(Trainer):
@classmethod @classmethod
@override(Trainer) @override(Trainer)
def get_default_config(cls) -> TrainerConfigDict: def get_default_config(cls) -> BanditLinTSConfig:
config = Trainer.merge_trainer_configs( return BanditLinTSConfig().to_dict()
DEFAULT_CONFIG,
{
# Use ThompsonSampling exploration.
"exploration_config": {"type": "ThompsonSampling"}
},
)
return config
@override(Trainer) @override(Trainer)
def get_default_policy_class(self, config: TrainerConfigDict) -> Type[Policy]: def get_default_policy_class(self, config: TrainerConfigDict) -> Type[Policy]:
@@ -52,20 +101,14 @@ class BanditLinTSTrainer(Trainer):
elif config["framework"] == "tf2": elif config["framework"] == "tf2":
return BanditTFPolicy return BanditTFPolicy
else: else:
raise NotImplementedError() raise NotImplementedError("Only `framework=[torch|tf2]` supported!")
class BanditLinUCBTrainer(Trainer): class BanditLinUCBTrainer(Trainer):
@classmethod @classmethod
@override(Trainer) @override(Trainer)
def get_default_config(cls) -> TrainerConfigDict: def get_default_config(cls) -> BanditLinUCBConfig:
return Trainer.merge_trainer_configs( return BanditLinUCBConfig().to_dict()
DEFAULT_CONFIG,
{
# Use UpperConfidenceBound exploration.
"exploration_config": {"type": "UpperConfidenceBound"}
},
)
@override(Trainer) @override(Trainer)
def get_default_policy_class(self, config: TrainerConfigDict) -> Type[Policy]: def get_default_policy_class(self, config: TrainerConfigDict) -> Type[Policy]:
@@ -74,4 +117,21 @@ class BanditLinUCBTrainer(Trainer):
elif config["framework"] == "tf2": elif config["framework"] == "tf2":
return BanditTFPolicy return BanditTFPolicy
else: else:
raise NotImplementedError() raise NotImplementedError("Only `framework=[torch|tf2]` supported!")
# Deprecated: Use ray.rllib.agents.bandit.BanditLinUCBConfig instead!
class _deprecated_default_config(dict):
def __init__(self):
super().__init__(BanditLinUCBConfig().to_dict())
@Deprecated(
old="ray.rllib.agents.bandit.bandit.DEFAULT_CONFIG",
new="ray.rllib.agents.bandit.bandit.BanditLin[UCB|TS]Config(...)",
error=False,
)
def __getitem__(self, item):
return super().__getitem__(item)
DEFAULT_CONFIG = _deprecated_default_config()

View file

@@ -1,7 +1,7 @@
import unittest import unittest
import ray import ray
import ray.rllib.agents.bandit.bandit as bandit from ray.rllib.agents.bandit import bandit
from ray.rllib.examples.env.bandit_envs_discrete import SimpleContextualBandit from ray.rllib.examples.env.bandit_envs_discrete import SimpleContextualBandit
from ray.rllib.utils.test_utils import check_train_results, framework_iterator from ray.rllib.utils.test_utils import check_train_results, framework_iterator
@@ -17,19 +17,17 @@ class TestBandits(unittest.TestCase):
def test_bandit_lin_ts_compilation(self): def test_bandit_lin_ts_compilation(self):
"""Test whether a BanditLinTSTrainer can be built on all frameworks.""" """Test whether a BanditLinTSTrainer can be built on all frameworks."""
config = { config = (
# Use a simple bandit-friendly env. bandit.BanditLinTSConfig()
"env": SimpleContextualBandit, .environment(env=SimpleContextualBandit)
"num_envs_per_worker": 2, # Test batched inference. .rollouts(num_rollout_workers=2, num_envs_per_worker=2)
"num_workers": 2, # Test distributed bandits. )
}
num_iterations = 5 num_iterations = 5
for _ in framework_iterator(config, frameworks="torch"): for _ in framework_iterator(config, frameworks="torch"):
for train_batch_size in [1, 10]: for train_batch_size in [1, 10]:
config["train_batch_size"] = train_batch_size config.training(train_batch_size=train_batch_size)
trainer = bandit.BanditLinTSTrainer(config=config) trainer = config.build()
results = None results = None
for i in range(num_iterations): for i in range(num_iterations):
results = trainer.train() results = trainer.train()
@@ -41,18 +39,18 @@ class TestBandits(unittest.TestCase):
def test_bandit_lin_ucb_compilation(self): def test_bandit_lin_ucb_compilation(self):
"""Test whether a BanditLinUCBTrainer can be built on all frameworks.""" """Test whether a BanditLinUCBTrainer can be built on all frameworks."""
config = { config = (
# Use a simple bandit-friendly env. bandit.BanditLinUCBConfig()
"env": SimpleContextualBandit, .environment(env=SimpleContextualBandit)
"num_envs_per_worker": 2, # Test batched inference. .rollouts(num_envs_per_worker=2)
} )
num_iterations = 5 num_iterations = 5
for _ in framework_iterator(config, frameworks="torch"): for _ in framework_iterator(config, frameworks="torch"):
for train_batch_size in [1, 10]: for train_batch_size in [1, 10]:
config["train_batch_size"] = train_batch_size config.training(train_batch_size=train_batch_size)
trainer = bandit.BanditLinUCBTrainer(config=config) trainer = config.build()
results = None results = None
for i in range(num_iterations): for i in range(num_iterations):
results = trainer.train() results = trainer.train()