mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
[RLlib] Issue 24075: Better error message for Bandit MultiDiscrete (suggest using our wrapper). (#24385)
This commit is contained in:
parent
fbbc9c33d6
commit
0c5ac3b9e8
2 changed files with 43 additions and 2 deletions
|
@ -1,3 +1,4 @@
|
||||||
|
import gym
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
|
@ -13,11 +14,13 @@ from ray.rllib.agents.bandit.bandit_tf_model import (
|
||||||
)
|
)
|
||||||
from ray.rllib.models.catalog import ModelCatalog
|
from ray.rllib.models.catalog import ModelCatalog
|
||||||
from ray.rllib.models.modelv2 import restore_original_dimensions
|
from ray.rllib.models.modelv2 import restore_original_dimensions
|
||||||
from ray.rllib.policy.tf_policy_template import build_tf_policy
|
from ray.rllib.policy.policy import Policy
|
||||||
from ray.rllib.policy.sample_batch import SampleBatch
|
from ray.rllib.policy.sample_batch import SampleBatch
|
||||||
|
from ray.rllib.policy.tf_policy_template import build_tf_policy
|
||||||
|
from ray.rllib.utils.error import UnsupportedSpaceException
|
||||||
from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY
|
from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY
|
||||||
from ray.rllib.utils.tf_utils import make_tf_callable
|
from ray.rllib.utils.tf_utils import make_tf_callable
|
||||||
from ray.rllib.utils.typing import TensorType
|
from ray.rllib.utils.typing import TensorType, TrainerConfigDict
|
||||||
from ray.util.debug import log_once
|
from ray.util.debug import log_once
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -64,6 +67,41 @@ class BanditPolicyOverrides:
|
||||||
self.learn_on_batch = learn_on_batch
|
self.learn_on_batch = learn_on_batch
|
||||||
|
|
||||||
|
|
||||||
|
def validate_spaces(
    policy: Policy,
    observation_space: gym.spaces.Space,
    action_space: gym.spaces.Space,
    config: TrainerConfigDict,
) -> None:
    """Validates the observation- and action spaces used for the Policy.

    Args:
        policy: The policy, whose spaces are being validated.
        observation_space: The observation space to validate.
        action_space: The action space to validate.
        config: The Policy's config dict.

    Raises:
        UnsupportedSpaceException: If one of the spaces is not supported.
    """
    # A plain `Discrete` action space is the only supported kind for
    # Bandit algorithms -- accept it and return right away.
    if isinstance(action_space, gym.spaces.Discrete):
        return

    error = (
        f"Action space ({action_space}) of {policy} is not supported for "
        f"Bandit algorithms. Must be `Discrete`."
    )
    # A MultiDiscrete space can be flattened into a single Discrete one,
    # so point the user at the provided wrapper class for that case.
    if isinstance(action_space, gym.spaces.MultiDiscrete):
        error += (
            " Try to wrap your environment with the "
            "`ray.rllib.env.wrappers.recsim::"
            "MultiDiscreteToDiscreteActionWrapper` class: `tune.register_env("
            "[some str], lambda ctx: MultiDiscreteToDiscreteActionWrapper("
            "[your gym env])); config = {'env': [some str]}`"
        )
    raise UnsupportedSpaceException(error)
|
||||||
|
|
||||||
|
|
||||||
def make_model(policy, obs_space, action_space, config):
|
def make_model(policy, obs_space, action_space, config):
|
||||||
_, logit_dim = ModelCatalog.get_action_dist(
|
_, logit_dim = ModelCatalog.get_action_dist(
|
||||||
action_space, config["model"], framework="tf"
|
action_space, config["model"], framework="tf"
|
||||||
|
@ -112,6 +150,7 @@ def after_init(policy, *args):
|
||||||
BanditTFPolicy = build_tf_policy(
|
BanditTFPolicy = build_tf_policy(
|
||||||
name="BanditTFPolicy",
|
name="BanditTFPolicy",
|
||||||
get_default_config=lambda: ray.rllib.agents.bandit.bandit.DEFAULT_CONFIG,
|
get_default_config=lambda: ray.rllib.agents.bandit.bandit.DEFAULT_CONFIG,
|
||||||
|
validate_spaces=validate_spaces,
|
||||||
make_model=make_model,
|
make_model=make_model,
|
||||||
loss_fn=None,
|
loss_fn=None,
|
||||||
mixins=[BanditPolicyOverrides],
|
mixins=[BanditPolicyOverrides],
|
||||||
|
|
|
@ -2,6 +2,7 @@ import logging
|
||||||
import time
|
import time
|
||||||
|
|
||||||
from gym import spaces
|
from gym import spaces
|
||||||
|
from ray.rllib.agents.bandit.bandit_tf_policy import validate_spaces
|
||||||
from ray.rllib.agents.bandit.bandit_torch_model import (
|
from ray.rllib.agents.bandit.bandit_torch_model import (
|
||||||
DiscreteLinearModelThompsonSampling,
|
DiscreteLinearModelThompsonSampling,
|
||||||
DiscreteLinearModelUCB,
|
DiscreteLinearModelUCB,
|
||||||
|
@ -99,6 +100,7 @@ def init_cum_regret(policy, *args):
|
||||||
BanditTorchPolicy = build_policy_class(
|
BanditTorchPolicy = build_policy_class(
|
||||||
name="BanditTorchPolicy",
|
name="BanditTorchPolicy",
|
||||||
framework="torch",
|
framework="torch",
|
||||||
|
validate_spaces=validate_spaces,
|
||||||
loss_fn=None,
|
loss_fn=None,
|
||||||
after_init=init_cum_regret,
|
after_init=init_cum_regret,
|
||||||
make_model_and_action_dist=make_model_and_action_dist,
|
make_model_and_action_dist=make_model_and_action_dist,
|
||||||
|
|
Loading…
Add table
Reference in a new issue