[RLlib] Issue 24075: Better error message for Bandit MultiDiscrete (suggest using our wrapper). (#24385)

This commit is contained in:
Sven Mika 2022-05-02 21:14:08 +02:00 committed by GitHub
parent fbbc9c33d6
commit 0c5ac3b9e8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 43 additions and 2 deletions

View file

@@ -1,3 +1,4 @@
import gym
import logging import logging
import time import time
from typing import Dict from typing import Dict
@@ -13,11 +14,13 @@ from ray.rllib.agents.bandit.bandit_tf_model import (
) )
from ray.rllib.models.catalog import ModelCatalog from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.models.modelv2 import restore_original_dimensions from ray.rllib.models.modelv2 import restore_original_dimensions
from ray.rllib.policy.tf_policy_template import build_tf_policy from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.tf_policy_template import build_tf_policy
from ray.rllib.utils.error import UnsupportedSpaceException
from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY
from ray.rllib.utils.tf_utils import make_tf_callable from ray.rllib.utils.tf_utils import make_tf_callable
from ray.rllib.utils.typing import TensorType from ray.rllib.utils.typing import TensorType, TrainerConfigDict
from ray.util.debug import log_once from ray.util.debug import log_once
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -64,6 +67,41 @@ class BanditPolicyOverrides:
self.learn_on_batch = learn_on_batch self.learn_on_batch = learn_on_batch
def validate_spaces(
    policy: Policy,
    observation_space: gym.spaces.Space,
    action_space: gym.spaces.Space,
    config: TrainerConfigDict,
) -> None:
    """Checks that the given spaces can be used with a Bandit policy.

    Args:
        policy: The policy, whose spaces are being validated.
        observation_space: The observation space to validate.
        action_space: The action space to validate.
        config: The Policy's config dict.

    Raises:
        UnsupportedSpaceException: If one of the spaces is not supported.
    """
    # A single `Discrete` action space is the only supported layout;
    # anything else is rejected below.
    if isinstance(action_space, gym.spaces.Discrete):
        return

    error_message = (
        f"Action space ({action_space}) of {policy} is not supported for "
        f"Bandit algorithms. Must be `Discrete`."
    )
    # MultiDiscrete can be converted via RLlib's wrapper class -> append a
    # hint telling the user how to apply it.
    if isinstance(action_space, gym.spaces.MultiDiscrete):
        error_message += (
            " Try to wrap your environment with the "
            "`ray.rllib.env.wrappers.recsim::"
            "MultiDiscreteToDiscreteActionWrapper` class: `tune.register_env("
            "[some str], lambda ctx: MultiDiscreteToDiscreteActionWrapper("
            "[your gym env])); config = {'env': [some str]}`"
        )
    raise UnsupportedSpaceException(error_message)
def make_model(policy, obs_space, action_space, config): def make_model(policy, obs_space, action_space, config):
_, logit_dim = ModelCatalog.get_action_dist( _, logit_dim = ModelCatalog.get_action_dist(
action_space, config["model"], framework="tf" action_space, config["model"], framework="tf"
@@ -112,6 +150,7 @@ def after_init(policy, *args):
BanditTFPolicy = build_tf_policy( BanditTFPolicy = build_tf_policy(
name="BanditTFPolicy", name="BanditTFPolicy",
get_default_config=lambda: ray.rllib.agents.bandit.bandit.DEFAULT_CONFIG, get_default_config=lambda: ray.rllib.agents.bandit.bandit.DEFAULT_CONFIG,
validate_spaces=validate_spaces,
make_model=make_model, make_model=make_model,
loss_fn=None, loss_fn=None,
mixins=[BanditPolicyOverrides], mixins=[BanditPolicyOverrides],

View file

@@ -2,6 +2,7 @@ import logging
import time import time
from gym import spaces from gym import spaces
from ray.rllib.agents.bandit.bandit_tf_policy import validate_spaces
from ray.rllib.agents.bandit.bandit_torch_model import ( from ray.rllib.agents.bandit.bandit_torch_model import (
DiscreteLinearModelThompsonSampling, DiscreteLinearModelThompsonSampling,
DiscreteLinearModelUCB, DiscreteLinearModelUCB,
@@ -99,6 +100,7 @@ def init_cum_regret(policy, *args):
BanditTorchPolicy = build_policy_class( BanditTorchPolicy = build_policy_class(
name="BanditTorchPolicy", name="BanditTorchPolicy",
framework="torch", framework="torch",
validate_spaces=validate_spaces,
loss_fn=None, loss_fn=None,
after_init=init_cum_regret, after_init=init_cum_regret,
make_model_and_action_dist=make_model_and_action_dist, make_model_and_action_dist=make_model_and_action_dist,