"""Tools and utils to create RLlib-ready recommender system envs using RecSim.
|
|
|
|
For examples on how to generate a RecSim env class (usable in RLlib):
|
|
See ray.rllib.examples.env.recommender_system_envs_with_recsim.py
|
|
|
|
For more information on google's RecSim itself:
|
|
https://github.com/google-research/recsim
|
|
"""
from collections import OrderedDict
from typing import Callable, List, Optional, Type

import gym
import numpy as np
from gym.spaces import Dict, Discrete, MultiDiscrete
from recsim.document import AbstractDocumentSampler
from recsim.simulator import environment, recsim_gym
from recsim.user import AbstractUserModel, AbstractResponse

from ray.rllib.env.env_context import EnvContext
from ray.rllib.utils.error import UnsupportedSpaceException
from ray.rllib.utils.spaces.space_utils import convert_element_to_space_type


class RecSimObservationSpaceWrapper(gym.ObservationWrapper):
    """Fix RecSim environment's observation space.

    In RecSim's observation spaces, the "doc" field is a dictionary keyed by
    document IDs. Those IDs change every step, generating a different
    observation space each time. This causes issues for RLlib because it
    expects the observation space to remain the same across steps.

    This environment wrapper fixes that by re-indexing the documents by
    their positions in the list.
    """

    def __init__(self, env: gym.Env):
        super().__init__(env)
        obs_space = self.env.observation_space
        doc_space = Dict(
            OrderedDict(
                [
                    (str(k), doc)
                    for k, (_, doc) in enumerate(obs_space["doc"].spaces.items())
                ]
            )
        )
        self.observation_space = Dict(
            OrderedDict(
                [
                    ("user", obs_space["user"]),
                    ("doc", doc_space),
                    ("response", obs_space["response"]),
                ]
            )
        )
        self._sampled_obs = self.observation_space.sample()

    def observation(self, obs):
        new_obs = OrderedDict()
        new_obs["user"] = obs["user"]
        new_obs["doc"] = {str(k): v for k, (_, v) in enumerate(obs["doc"].items())}
        new_obs["response"] = obs["response"]
        new_obs = convert_element_to_space_type(new_obs, self._sampled_obs)
        return new_obs
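

# A concrete illustration (document IDs hypothetical) of the re-indexing
# above: a raw RecSim observation like
#   {"user": u, "doc": {"21": d0, "57": d1}, "response": r}
# is mapped to
#   {"user": u, "doc": {"0": d0, "1": d1}, "response": r},
# so the "doc" keys (and hence the observation space) stay stable across
# steps, no matter which document IDs RecSim sampled.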


class RecSimObservationBanditWrapper(gym.ObservationWrapper):
    """Fix RecSim environment's observation format.

    RecSim's observations are nested under the "doc" key and keyed by
    document ID. Our Bandits agent expects the observations to be a flat 2D
    array under the "item" key instead.

    This environment wrapper converts the observations into that format.
    """

    def __init__(self, env: gym.Env):
        super().__init__(env)
        obs_space = self.env.observation_space

        num_items = len(obs_space["doc"])
        embedding_dim = next(iter(obs_space["doc"].values())).shape[-1]
        self.observation_space = Dict(
            OrderedDict(
                [
                    (
                        "item",
                        gym.spaces.Box(
                            low=-1.0, high=1.0, shape=(num_items, embedding_dim)
                        ),
                    ),
                ]
            )
        )
        self._sampled_obs = self.observation_space.sample()

    def observation(self, obs):
        new_obs = OrderedDict()
        new_obs["item"] = np.vstack(list(obs["doc"].values()))
        new_obs = convert_element_to_space_type(new_obs, self._sampled_obs)
        return new_obs
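

# Shape illustration (sizes hypothetical): with 10 candidate docs, each
# embedded as a length-4 float vector, the wrapped observation's "item"
# entry is a single np.ndarray of shape (10, 4), the per-document vectors
# stacked via np.vstack, instead of a dict of 10 separate arrays.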


class RecSimResetWrapper(gym.Wrapper):
    """Fix RecSim environment's reset() and close() functions.

    RecSim's reset() function returns an observation without the "response"
    field, which breaks RLlib's check. This wrapper fixes that by assigning
    a randomly sampled "response".

    RecSim's close() function raises NotImplementedError. We change it to a
    no-op.
    """

    def __init__(self, env: gym.Env):
        super().__init__(env)
        self._sampled_obs = self.env.observation_space.sample()

    def reset(self):
        obs = super().reset()
        obs["response"] = self.env.observation_space["response"].sample()
        obs = convert_element_to_space_type(obs, self._sampled_obs)
        return obs

    def close(self):
        pass
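

# After this wrapper, the initial observation from reset() has the same
# structure as every step() observation, e.g. (keys per the spaces above):
#   {"user": ..., "doc": ..., "response": <randomly sampled>}
# so RLlib's space-containment check on the first observation passes.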


class MultiDiscreteToDiscreteActionWrapper(gym.ActionWrapper):
    """Convert the action space from MultiDiscrete to Discrete.

    At this moment, RLlib's DQN algorithms only work on Discrete action
    spaces. This wrapper allows us to apply DQN algorithms to the RecSim
    environment.
    """

    def __init__(self, env: gym.Env):
        super().__init__(env)

        if not isinstance(env.action_space, MultiDiscrete):
            raise UnsupportedSpaceException(
                f"Action space {env.action_space} "
                f"is not supported by {self.__class__.__name__}"
            )
        self.action_space_dimensions = env.action_space.nvec
        self.action_space = Discrete(np.prod(self.action_space_dimensions))

    def action(self, action: int) -> List[int]:
        """Convert a Discrete action to a MultiDiscrete action."""
        multi_action = [None] * len(self.action_space_dimensions)
        for idx, n in enumerate(self.action_space_dimensions):
            action, dim_action = divmod(action, n)
            multi_action[idx] = dim_action
        return multi_action
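

# Worked decoding example (sizes hypothetical): for a MultiDiscrete([3, 4])
# space, the wrapper exposes Discrete(12). The Discrete action 7 decodes in
# mixed radix as divmod(7, 3) == (2, 1), then divmod(2, 4) == (0, 2), i.e.
# the MultiDiscrete action [1, 2].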


def recsim_gym_wrapper(
    recsim_gym_env: gym.Env,
    convert_to_discrete_action_space: bool = False,
    wrap_for_bandits: bool = False,
) -> gym.Env:
    """Makes sure a RecSim gym.Env can be handled by RLlib.

    In RecSim's observation spaces, the "doc" field is a dictionary keyed by
    document IDs. Those IDs change every step, generating a different
    observation space each time. This causes issues for RLlib because it
    expects the observation space to remain the same across steps.

    Also, RecSim's reset() function returns an observation without the
    "response" field, breaking RLlib's check. This wrapper fixes that by
    assigning a random "response".

    Args:
        recsim_gym_env: The RecSim gym.Env instance. Usually resulting from a
            raw RecSim env having been passed through RecSim's utility
            function: `recsim.simulator.recsim_gym.RecSimGymEnv()`.
        convert_to_discrete_action_space: Optional bool indicating whether
            the action space of the created env class should be Discrete
            (rather than MultiDiscrete, even if slate size > 1). This is
            useful for algorithms that don't support MultiDiscrete action
            spaces, such as RLlib's DQN. If None,
            `convert_to_discrete_action_space` may also be provided via the
            EnvContext (config) when creating an actual env instance.
        wrap_for_bandits: Bool indicating whether this RecSim env should be
            wrapped for use with our Bandits agent.

    Returns:
        An RLlib-ready gym.Env instance.
    """
    env = RecSimResetWrapper(recsim_gym_env)
    env = RecSimObservationSpaceWrapper(env)
    if convert_to_discrete_action_space:
        env = MultiDiscreteToDiscreteActionWrapper(env)
    if wrap_for_bandits:
        env = RecSimObservationBanditWrapper(env)
    return env
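

# Usage sketch (hedged): RecSim's bundled environments expose a
# `create_environment(env_config)` helper that returns a
# `recsim_gym.RecSimGymEnv`; any such env can be made RLlib-ready here.
#
#   from recsim.environments import interest_evolution
#
#   recsim_env = interest_evolution.create_environment({
#       "num_candidates": 10,
#       "slate_size": 2,
#       "resample_documents": True,
#       "seed": 0,
#   })
#   rllib_env = recsim_gym_wrapper(
#       recsim_env, convert_to_discrete_action_space=True)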


def make_recsim_env(
    recsim_user_model_creator: Callable[[EnvContext], AbstractUserModel],
    recsim_document_sampler_creator: Callable[[EnvContext], AbstractDocumentSampler],
    reward_aggregator: Callable[[List[AbstractResponse]], float],
) -> Type[gym.Env]:
    """Creates an RLlib-ready gym.Env class given RecSim user and doc models.

    See https://github.com/google-research/recsim for more information on
    how to build the required components from scratch in python using RecSim.

    Args:
        recsim_user_model_creator: A callable taking an EnvContext and
            returning a RecSim AbstractUserModel instance to use.
        recsim_document_sampler_creator: A callable taking an EnvContext and
            returning a RecSim AbstractDocumentSampler instance to use. This
            will include an AbstractDocument as well.
        reward_aggregator: Callable taking a list of RecSim AbstractResponse
            instances and returning a float (aggregated reward).

    Returns:
        An RLlib-ready gym.Env class to use inside a Trainer.
    """

    class _RecSimEnv(gym.Wrapper):
        def __init__(self, config: Optional[EnvContext] = None):

            # Override with default values, in case they are not set by the
            # user.
            default_config = {
                "num_candidates": 10,
                "slate_size": 2,
                "resample_documents": True,
                "seed": 0,
                "convert_to_discrete_action_space": False,
                "wrap_for_bandits": False,
            }
            if config is None or isinstance(config, dict):
                config = EnvContext(config or default_config, worker_index=0)
            config.set_defaults(default_config)

            # Create the RecSim user model instance.
            recsim_user_model = recsim_user_model_creator(config)
            # Create the RecSim document sampler instance.
            recsim_document_sampler = recsim_document_sampler_creator(config)

            # Create a raw RecSim environment (not yet a gym.Env!).
            raw_recsim_env = environment.SingleUserEnvironment(
                recsim_user_model,
                recsim_document_sampler,
                config["num_candidates"],
                config["slate_size"],
                resample_documents=config["resample_documents"],
            )
            # Convert the raw RecSim env to a gym.Env.
            gym_env = recsim_gym.RecSimGymEnv(raw_recsim_env, reward_aggregator)

            # Fix the observation space and - if necessary - convert to a
            # discrete action space (from multi-discrete).
            env = recsim_gym_wrapper(
                gym_env,
                config["convert_to_discrete_action_space"],
                config["wrap_for_bandits"],
            )
            # Call the super (Wrapper constructor) passing it the created env.
            super().__init__(env=env)

    return _RecSimEnv
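

# Usage sketch (hypothetical model classes; any concrete RecSim
# AbstractUserModel / AbstractDocumentSampler pair will do, see the example
# file referenced in the module docstring). Also assumes responses expose a
# `clicked` field, as e.g. RecSim's interest_evolution responses do:
#
#   MyRecSimEnv = make_recsim_env(
#       recsim_user_model_creator=lambda config: MyUserModel(
#           config["slate_size"], seed=config["seed"]),
#       recsim_document_sampler_creator=lambda config: MyDocSampler(
#           seed=config["seed"]),
#       reward_aggregator=lambda responses: float(
#           sum(r.clicked for r in responses)),
#   )
#   env = MyRecSimEnv({"num_candidates": 20, "slate_size": 3})
#   obs = env.reset()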