Mirror of https://github.com/vale981/ray, synced 2025-03-05 18:11:42 -05:00
[RLlib] Bandits (torch) Policy sub-class. (#25254)
Co-authored-by: Steven Morad <smorad@anyscale.com>
This commit is contained in:
parent 6fe91885b0, commit f781622f86
1 changed file with 60 additions and 61 deletions
@@ -1,8 +1,8 @@
 import logging
 import time
 
 from gym import spaces
-from ray.rllib.algorithms.bandit.bandit_tf_policy import validate_spaces
 
 import ray
 from ray.rllib.algorithms.bandit.bandit_torch_model import (
     DiscreteLinearModelThompsonSampling,
     DiscreteLinearModelUCB,
@@ -12,18 +12,72 @@ from ray.rllib.algorithms.bandit.bandit_torch_model import (
 )
 from ray.rllib.models.catalog import ModelCatalog
 from ray.rllib.models.modelv2 import restore_original_dimensions
-from ray.rllib.policy.policy_template import build_policy_class
 from ray.rllib.policy.sample_batch import SampleBatch
-from ray.rllib.policy.torch_policy import TorchPolicy
 from ray.rllib.utils.annotations import override
 from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY
 from ray.util.debug import log_once
+from ray.rllib.policy.torch_policy_v2 import TorchPolicyV2
 
 logger = logging.getLogger(__name__)
 
 
-class BanditPolicyOverrides:
-    @override(TorchPolicy)
+class BanditTorchPolicy(TorchPolicyV2):
+    def __init__(self, observation_space, action_space, config):
+        config = dict(ray.rllib.algorithms.bandit.bandit.DEFAULT_CONFIG, **config)
+
+        TorchPolicyV2.__init__(
+            self,
+            observation_space,
+            action_space,
+            config,
+            max_seq_len=config["model"]["max_seq_len"],
+        )
+        self.regrets = []
+
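An aside on the dict(DEFAULT_CONFIG, **config) line in __init__ above: it is the plain-Python idiom for "copy the defaults, let user keys win." A minimal self-contained sketch, using a stand-in DEFAULT_CONFIG rather than RLlib's real bandit defaults:

# Stand-in defaults; RLlib's actual DEFAULT_CONFIG has many more keys.
DEFAULT_CONFIG = {"lr": 0.001, "model": {"max_seq_len": 20}}
user_config = {"lr": 0.01}

# dict(defaults, **overrides) copies the defaults, then applies overrides.
config = dict(DEFAULT_CONFIG, **user_config)
assert config["lr"] == 0.01                  # user override wins
assert config["model"]["max_seq_len"] == 20  # untouched default survives

Note the merge is shallow: a user key like {"model": {}} would replace the whole "model" sub-dict rather than merge into it. The merged dict is what feeds max_seq_len=config["model"]["max_seq_len"] in the constructor call above.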
+    @override(TorchPolicyV2)
+    def make_model_and_action_dist(self):
+        dist_class, logit_dim = ModelCatalog.get_action_dist(
+            self.action_space, self.config["model"], framework="torch"
+        )
+        model_cls = DiscreteLinearModel
+
+        if hasattr(self.observation_space, "original_space"):
+            original_space = self.observation_space.original_space
+        else:
+            original_space = self.observation_space
+
+        exploration_config = self.config.get("exploration_config")
+        # Model is dependent on exploration strategy because of its implicitness
+
+        # TODO: Have a separate model catalogue for bandits
+        if exploration_config:
+            if exploration_config["type"] == "ThompsonSampling":
+                if isinstance(original_space, spaces.Dict):
+                    assert (
+                        "item" in original_space.spaces
+                    ), "Cannot find 'item' key in observation space"
+                    model_cls = ParametricLinearModelThompsonSampling
+                else:
+                    model_cls = DiscreteLinearModelThompsonSampling
+            elif exploration_config["type"] == "UpperConfidenceBound":
+                if isinstance(original_space, spaces.Dict):
+                    assert (
+                        "item" in original_space.spaces
+                    ), "Cannot find 'item' key in observation space"
+                    model_cls = ParametricLinearModelUCB
+                else:
+                    model_cls = DiscreteLinearModelUCB
+
+        model = model_cls(
+            self.observation_space,
+            self.action_space,
+            logit_dim,
+            self.config["model"],
+            name="LinearModel",
+        )
+        return model, dist_class
+
+    @override(TorchPolicyV2)
     def learn_on_batch(self, postprocessed_batch):
         train_batch = self._lazy_tensor_dict(postprocessed_batch)
         unflattened_obs = restore_original_dimensions(
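The branch ladder in make_model_and_action_dist is the heart of the method: the exploration strategy plus the shape of the original, pre-flattening observation space decide which bandit model class gets built. A runnable sketch of that dispatch, using gym for the space types; the returned strings stand in for the classes imported from bandit_torch_model:

from gym import spaces

def select_model_cls(original_space, exploration_config):
    # Maps exploration type -> (parametric variant, discrete variant).
    table = {
        "ThompsonSampling": (
            "ParametricLinearModelThompsonSampling",
            "DiscreteLinearModelThompsonSampling",
        ),
        "UpperConfidenceBound": (
            "ParametricLinearModelUCB",
            "DiscreteLinearModelUCB",
        ),
    }
    model_cls = "DiscreteLinearModel"  # fallback when nothing matches
    if exploration_config and exploration_config["type"] in table:
        parametric_cls, discrete_cls = table[exploration_config["type"]]
        if isinstance(original_space, spaces.Dict):
            # Dict spaces must carry the per-arm features under "item".
            assert "item" in original_space.spaces, (
                "Cannot find 'item' key in observation space"
            )
            model_cls = parametric_cls
        else:
            model_cls = discrete_cls
    return model_cls

obs = spaces.Dict({"item": spaces.Box(low=-1.0, high=1.0, shape=(4,))})
print(select_model_cls(obs, {"type": "ThompsonSampling"}))
# -> ParametricLinearModelThompsonSampling

The table-driven form is equivalent to the if/elif chain in the diff; the original_space unwrapping (observation_space.original_space when present) is what lets the Dict check see the pre-flattened space.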
@@ -52,58 +106,3 @@ class BanditPolicyOverrides:
         )
         info["update_latency"] = time.time() - start
         return {LEARNER_STATS_KEY: info}
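Those three context lines are the tail of learn_on_batch, which times the model update and reports the latency through the learner stats. A hedged sketch of the pattern, with the plain string "learner_stats" standing in for RLlib's LEARNER_STATS_KEY constant:

import time

def timed_update(update_fn):
    # Run the update, measuring wall-clock latency for the stats report.
    info = {}
    start = time.time()
    update_fn()
    info["update_latency"] = time.time() - start
    return {"learner_stats": info}

stats = timed_update(lambda: time.sleep(0.01))
assert stats["learner_stats"]["update_latency"] >= 0.01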
-
-
-def make_model_and_action_dist(policy, obs_space, action_space, config):
-    dist_class, logit_dim = ModelCatalog.get_action_dist(
-        action_space, config["model"], framework="torch"
-    )
-    model_cls = DiscreteLinearModel
-
-    if hasattr(obs_space, "original_space"):
-        original_space = obs_space.original_space
-    else:
-        original_space = obs_space
-
-    exploration_config = config.get("exploration_config")
-    # Model is dependent on exploration strategy because of its implicitness
-
-    # TODO: Have a separate model catalogue for bandits
-    if exploration_config:
-        if exploration_config["type"] == "ThompsonSampling":
-            if isinstance(original_space, spaces.Dict):
-                assert (
-                    "item" in original_space.spaces
-                ), "Cannot find 'item' key in observation space"
-                model_cls = ParametricLinearModelThompsonSampling
-            else:
-                model_cls = DiscreteLinearModelThompsonSampling
-        elif exploration_config["type"] == "UpperConfidenceBound":
-            if isinstance(original_space, spaces.Dict):
-                assert (
-                    "item" in original_space.spaces
-                ), "Cannot find 'item' key in observation space"
-                model_cls = ParametricLinearModelUCB
-            else:
-                model_cls = DiscreteLinearModelUCB
-
-    model = model_cls(
-        obs_space, action_space, logit_dim, config["model"], name="LinearModel"
-    )
-    return model, dist_class
-
-
-def init_cum_regret(policy, *args):
-    policy.regrets = []
-
-
-BanditTorchPolicy = build_policy_class(
-    name="BanditTorchPolicy",
-    framework="torch",
-    validate_spaces=validate_spaces,
-    loss_fn=None,
-    after_init=init_cum_regret,
-    make_model_and_action_dist=make_model_and_action_dist,
-    optimizer_fn=lambda policy, config: None,  # Pass a dummy optimizer
-    mixins=[BanditPolicyOverrides],
-)
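For contrast with the subclass above, here is a toy of what a functional policy builder in the style of build_policy_class does under the hood: assemble a class from mixins plus hooks such as after_init, which is the slot init_cum_regret plugged into. This is a hedged sketch of the general pattern, not RLlib's implementation; all names below are illustrative:

def build_policy(name, mixins, after_init):
    # Compose a new class from the given mixins, then wrap __init__ so the
    # after_init hook runs once construction finishes.
    cls = type(name, tuple(mixins), {})
    original_init = cls.__init__

    def __init__(self, *args, **kwargs):
        original_init(self, *args, **kwargs)
        after_init(self)

    cls.__init__ = __init__
    return cls


class Overrides:
    # Plays the role of BanditPolicyOverrides: supplies methods via mixin.
    def learn_on_batch(self, batch):
        return {"trained_on": len(batch)}


def init_cum_regret(policy):
    policy.regrets = []


ToyPolicy = build_policy("ToyPolicy", [Overrides], init_cum_regret)
policy = ToyPolicy()
assert policy.regrets == []
assert policy.learn_on_batch([1, 2, 3]) == {"trained_on": 3}

The removed code wires BanditPolicyOverrides, init_cum_regret, and the module-level make_model_and_action_dist through exactly these kinds of hooks; the new BanditTorchPolicy expresses the same composition as ordinary methods on a TorchPolicyV2 subclass.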