import logging
import time

from gym import spaces

import ray
from ray.rllib.algorithms.bandit.bandit_torch_model import (
    DiscreteLinearModelThompsonSampling,
    DiscreteLinearModelUCB,
    DiscreteLinearModel,
    ParametricLinearModelThompsonSampling,
    ParametricLinearModelUCB,
)
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.models.modelv2 import restore_original_dimensions
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.torch_policy_v2 import TorchPolicyV2
from ray.rllib.utils.annotations import override
from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY
from ray.util.debug import log_once

logger = logging.getLogger(__name__)


class BanditTorchPolicy(TorchPolicyV2):
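    """PyTorch policy for contextual linear bandits.

    Picks the linear bandit model matching the configured exploration
    strategy (Thompson sampling or UCB, each with a discrete and a
    parametric variant) and fits it online on every training batch.
    """
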
    def __init__(self, observation_space, action_space, config):
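        """Initializes a BanditTorchPolicy instance.

        Args:
            observation_space: The env's observation space.
            action_space: The env's action space.
            config: The policy config dict; it is merged over the
                bandit `DEFAULT_CONFIG` before being passed on.
        """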
        # Merge the user-supplied config on top of the bandit defaults.
        config = dict(ray.rllib.algorithms.bandit.bandit.DEFAULT_CONFIG, **config)
        TorchPolicyV2.__init__(
            self,
            observation_space,
            action_space,
            config,
            max_seq_len=config["model"]["max_seq_len"],
        )
        # Per-batch regrets; accumulated into a cumulative-regret metric
        # in `learn_on_batch()`.
        self.regrets = []

    @override(TorchPolicyV2)
    def make_model_and_action_dist(self):
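        """Chooses and builds the bandit model for the exploration config.

        Returns:
            Tuple of the instantiated linear bandit model and the action
            distribution class obtained from the `ModelCatalog`.
        """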
        dist_class, logit_dim = ModelCatalog.get_action_dist(
            self.action_space, self.config["model"], framework="torch"
        )
        # Fall back to a plain (greedy) linear model if no explicit
        # exploration strategy is configured.
        model_cls = DiscreteLinearModel

        if hasattr(self.observation_space, "original_space"):
            original_space = self.observation_space.original_space
        else:
            original_space = self.observation_space

        exploration_config = self.config.get("exploration_config")
        # The model class depends on the exploration strategy, because
        # exploration is implicit in these bandit models.
        # TODO: Have a separate model catalog for bandits.
        if exploration_config:
            if exploration_config["type"] == "ThompsonSampling":
                if isinstance(original_space, spaces.Dict):
                    assert (
                        "item" in original_space.spaces
                    ), "Cannot find 'item' key in observation space"
                    model_cls = ParametricLinearModelThompsonSampling
                else:
                    model_cls = DiscreteLinearModelThompsonSampling
            elif exploration_config["type"] == "UpperConfidenceBound":
                if isinstance(original_space, spaces.Dict):
                    assert (
                        "item" in original_space.spaces
                    ), "Cannot find 'item' key in observation space"
                    model_cls = ParametricLinearModelUCB
                else:
                    model_cls = DiscreteLinearModelUCB

        model = model_cls(
            self.observation_space,
            self.action_space,
            logit_dim,
            self.config["model"],
            name="LinearModel",
        )
        return model, dist_class

    @override(TorchPolicyV2)
    def learn_on_batch(self, postprocessed_batch):
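        """Performs one online (partial) fit of the bandit model.

        Returns:
            Learner stats with the update latency and, if the env reports
            per-step regrets in its `info` dicts, the cumulative regret.
        """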
        train_batch = self._lazy_tensor_dict(postprocessed_batch)
        # Undo RLlib's automatic flattening so the model sees the original
        # (possibly Dict) observation structure.
        unflattened_obs = restore_original_dimensions(
            train_batch[SampleBatch.CUR_OBS], self.observation_space, self.framework
        )

        info = {}

        start = time.time()
        self.model.partial_fit(
            unflattened_obs,
            train_batch[SampleBatch.REWARDS],
            train_batch[SampleBatch.ACTIONS],
        )

        infos = postprocessed_batch["infos"]
        if "regret" in infos[0]:
            regret = sum(row["infos"]["regret"] for row in postprocessed_batch.rows())
            self.regrets.append(regret)
            info["cumulative_regret"] = sum(self.regrets)
        elif log_once("no_regrets"):
            logger.warning(
                "The env did not report `regret` values in "
                "its `info` return, ignoring."
            )
        info["update_latency"] = time.time() - start
        return {LEARNER_STATS_KEY: info}
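

# Minimal usage sketch (not part of the original module). It assumes the
# `BanditLinUCBConfig` algorithm config and the example `WheelBanditEnv`
# that ship with RLlib; adjust the names to your installed Ray version.
#
#     from ray.rllib.algorithms.bandit import BanditLinUCBConfig
#     from ray.rllib.examples.env.bandit_envs_discrete import WheelBanditEnv
#
#     config = BanditLinUCBConfig().environment(WheelBanditEnv).framework("torch")
#     algo = config.build()
#     for _ in range(10):
#         results = algo.train()
#         print(results["info"]["learner"])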