[RLlib] Add an env wrapper so RecSim works with our Bandits agent. (#22028)
parent 87fe033f7b
commit 9c95b9a5fa
4 changed files with 127 additions and 2 deletions
@@ -2767,6 +2767,14 @@ py_test(
     srcs = ["examples/bandit/tune_lin_ucb_train_recommendation.py"],
 )
 
+py_test(
+    name = "examples/bandit/tune_lin_ucb_train_recsim_env",
+    main = "examples/bandit/tune_lin_ucb_train_recsim_env.py",
+    tags = ["team:ml", "examples", ],
+    size = "small",
+    srcs = ["examples/bandit/tune_lin_ucb_train_recsim_env.py"],
+)
+
 # --------------------------------------------------------------------
 # examples/documentation directory
 #
rllib/env/wrappers/recsim.py (56 changed lines)
@@ -65,6 +65,49 @@ class RecSimObservationSpaceWrapper(gym.ObservationWrapper):
         return new_obs
 
 
+class RecSimObservationBanditWrapper(gym.ObservationWrapper):
+    """Fix RecSim environment's observation format.
+
+    RecSim's observations are keyed by document IDs and nested under the
+    "doc" key. Our Bandits agent expects the observations to be a flat 2D
+    array under the "item" key.
+
+    This environment wrapper converts obs into the right format.
+    """
+
+    def __init__(self, env: gym.Env):
+        super().__init__(env)
+        obs_space = self.env.observation_space
+
+        num_items = len(obs_space["doc"])
+        embedding_dim = next(iter(obs_space["doc"].values())).shape[-1]
+        self.observation_space = Dict(
+            OrderedDict(
+                [
+                    ("user", obs_space["user"]),
+                    (
+                        "item",
+                        gym.spaces.Box(
+                            low=-np.ones((num_items, embedding_dim)),
+                            high=np.ones((num_items, embedding_dim)),
+                        ),
+                    ),
+                    ("response", obs_space["response"]),
+                ]
+            )
+        )
+        self._sampled_obs = self.observation_space.sample()
+
+    def observation(self, obs):
+        new_obs = OrderedDict()
+        new_obs["user"] = obs["user"]
+        new_obs["item"] = np.vstack(list(obs["doc"].values()))
+        new_obs["response"] = obs["response"]
+        new_obs = convert_element_to_space_type(new_obs, self._sampled_obs)
+        return new_obs
+
+
 class RecSimResetWrapper(gym.Wrapper):
     """Fix RecSim environment's reset() and close() function
 
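A toy illustration (made-up shapes and values, not taken from the commit) of the conversion this wrapper performs: the per-document feature vectors keyed by doc ID under "doc" are stacked into a single (num_items, embedding_dim) array under "item".

from collections import OrderedDict

import numpy as np

# RecSim-style observation: one feature vector per document, keyed by doc ID.
recsim_style_obs = OrderedDict(
    user=np.zeros(4, dtype=np.float32),
    doc=OrderedDict(
        doc_0=np.array([0.1, 0.2, 0.3], dtype=np.float32),
        doc_1=np.array([0.4, 0.5, 0.6], dtype=np.float32),
    ),
    response=(),
)

# Bandits-style observation: the same vectors stacked into one 2D "item" array.
bandit_style_obs = OrderedDict(
    user=recsim_style_obs["user"],
    item=np.vstack(list(recsim_style_obs["doc"].values())),
    response=recsim_style_obs["response"],
)
print(bandit_style_obs["item"].shape)  # -> (2, 3), i.e. (num_items, embedding_dim)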
@@ -118,7 +161,9 @@ class MultiDiscreteToDiscreteActionWrapper(gym.ActionWrapper):
 
 
 def recsim_gym_wrapper(
-    recsim_gym_env: gym.Env, convert_to_discrete_action_space: bool = False
+    recsim_gym_env: gym.Env,
+    convert_to_discrete_action_space: bool = False,
+    wrap_for_bandits: bool = False,
 ) -> gym.Env:
     """Makes sure a RecSim gym.Env can be handled by RLlib.
 
@@ -142,6 +187,8 @@ def recsim_gym_wrapper(
             such as RLlib's DQN. If None, `convert_to_discrete_action_space`
             may also be provided via the EnvContext (config) when creating an
             actual env instance.
+        wrap_for_bandits: Bool indicating whether this RecSim env should be
+            wrapped for use with our Bandits agent.
 
     Returns:
         An RLlib-ready gym.Env instance.
@@ -150,6 +197,8 @@ def recsim_gym_wrapper(
     env = RecSimObservationSpaceWrapper(env)
     if convert_to_discrete_action_space:
         env = MultiDiscreteToDiscreteActionWrapper(env)
+    if wrap_for_bandits:
+        env = RecSimObservationBanditWrapper(env)
     return env
 
 
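A hedged usage sketch of the extended function (not part of the commit). It assumes the recsim package is installed; interest_evolution.create_environment and its config keys (num_candidates, slate_size, etc.) are RecSim's own API and are an assumption here, not something this diff defines.

from recsim.environments import interest_evolution  # assumes `recsim` is installed

from ray.rllib.env.wrappers.recsim import recsim_gym_wrapper

# Build a raw RecSim gym.Env first (config keys follow RecSim's
# interest_evolution example; adjust as needed).
raw_env = interest_evolution.create_environment(
    {
        "num_candidates": 10,
        "slate_size": 1,
        "resample_documents": True,
        "seed": 0,
    }
)

# Apply the full wrapper chain, including the new Bandits observation wrapper.
env = recsim_gym_wrapper(
    raw_env,
    convert_to_discrete_action_space=True,
    wrap_for_bandits=True,
)
obs = env.reset()
print(obs["item"].shape)  # (num_candidates, embedding_dim)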
@@ -186,6 +235,7 @@ def make_recsim_env(
         "resample_documents": True,
         "seed": 0,
         "convert_to_discrete_action_space": False,
+        "wrap_for_bandits": False,
     }
     if env_ctx is None or isinstance(env_ctx, dict):
         env_ctx = EnvContext(env_ctx or default_config, worker_index=0)
@@ -210,7 +260,9 @@ def make_recsim_env(
             # Fix observation space and - if necessary - convert to discrete
             # action space (from multi-discrete).
             self.env = recsim_gym_wrapper(
-                gym_env, env_ctx["convert_to_discrete_action_space"]
+                gym_env,
+                env_ctx["convert_to_discrete_action_space"],
+                env_ctx["wrap_for_bandits"],
             )
             self.observation_space = self.env.observation_space
             self.action_space = self.env.action_space
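A hedged sketch (not part of the commit) of instantiating one of the prebuilt RecSim env classes with the new flag; InterestEvolutionRecSimEnv and the config keys are the ones referenced by the tests and the example below, and unspecified keys are expected to fall back to the defaults above, as the tests' partial dicts suggest.

from ray.rllib.examples.env.recsim_recommender_system_envs import (
    InterestEvolutionRecSimEnv,
)

# Only the keys we care about need to be passed.
env = InterestEvolutionRecSimEnv({"wrap_for_bandits": True})

obs = env.reset()
# With wrap_for_bandits=True, the per-document features arrive as a single
# 2D array under the "item" key.
print(obs["item"].shape)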
@@ -28,6 +28,11 @@ class TestRecSimWrapper(unittest.TestCase):
         new_obs, _, _, _ = env.step(action)
         self.assertTrue(env.observation_space.contains(new_obs))
 
+    def test_bandits_observation_space_conversion(self):
+        env = InterestEvolutionRecSimEnv({"wrap_for_bandits": True})
+        # "item" of observation space is a Box space.
+        self.assertIsInstance(env.observation_space["item"], gym.spaces.Box)
+
     def test_double_action_space_conversion_raises_exception(self):
         env = InterestEvolutionRecSimEnv({"convert_to_discrete_action_space": True})
         with self.assertRaises(UnsupportedSpaceException):
rllib/examples/bandit/tune_lin_ucb_train_recsim_env.py (new file, 60 lines)
@@ -0,0 +1,60 @@
+"""Example of using LinUCB on a RecSim environment. """
+
+from matplotlib import pyplot as plt
+import pandas as pd
+import time
+
+from ray import tune
+import ray.rllib.examples.env.recsim_recommender_system_envs  # noqa
+
+
+if __name__ == "__main__":
+    ray.init()
+
+    config = {
+        # "RecSim-v1" is a pre-registered RecSim env.
+        # Alternatively, you can do:
+        # `from ray.rllib.examples.env.recsim_recommender_system_envs import ...`
+        # - LongTermSatisfactionRecSimEnv
+        # - InterestExplorationRecSimEnv
+        # - InterestEvolutionRecSimEnv
+        # Then: "env": [the imported RecSim class]
+        "env": "RecSim-v1",
+        "env_config": {
+            "convert_to_discrete_action_space": True,
+            "wrap_for_bandits": True,
+        },
+    }
+
+    # Actual training_iterations will be 10 * timesteps_per_iteration
+    # (100 by default) = 2,000
+    training_iterations = 10
+
+    print("Running training for %s time steps" % training_iterations)
+
+    start_time = time.time()
+    analysis = tune.run(
+        "BanditLinUCB",
+        config=config,
+        stop={"training_iteration": training_iterations},
+        num_samples=1,
+        checkpoint_at_end=False,
+    )
+
+    print("The trials took", time.time() - start_time, "seconds\n")
+
+    # Analyze cumulative regrets of the trials
+    frame = pd.DataFrame()
+    for key, df in analysis.trial_dataframes.items():
+        frame = frame.append(df, ignore_index=True)
+    x = frame.groupby("agent_timesteps_total")["episode_reward_mean"].aggregate(
+        ["mean", "max", "min", "std"]
+    )
+
+    plt.plot(x["mean"])
+    plt.fill_between(
+        x.index, x["mean"] - x["std"], x["mean"] + x["std"], color="b", alpha=0.2
+    )
+    plt.title("Episode reward mean")
+    plt.xlabel("Training steps")
+    plt.show()
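The comment block in the config above mentions passing an imported RecSim env class instead of the registered "RecSim-v1" string; a minimal hedged variant of that config (the rest of the script stays the same):

from ray.rllib.examples.env.recsim_recommender_system_envs import (
    InterestEvolutionRecSimEnv,
)

config = {
    # Pass the env class itself rather than the "RecSim-v1" registry string.
    "env": InterestEvolutionRecSimEnv,
    "env_config": {
        "convert_to_discrete_action_space": True,
        "wrap_for_bandits": True,
    },
}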