[RLlib] Add an env wrapper so RecSim works with our Bandits agent. (#22028)

Jun Gong 2022-02-02 03:15:38 -08:00 committed by GitHub
parent 87fe033f7b
commit 9c95b9a5fa
4 changed files with 127 additions and 2 deletions


@@ -2767,6 +2767,14 @@ py_test(
srcs = ["examples/bandit/tune_lin_ucb_train_recommendation.py"],
)
py_test(
name = "examples/bandit/tune_lin_ucb_train_recsim_env",
main = "examples/bandit/tune_lin_ucb_train_recsim_env.py",
tags = ["team:ml", "examples", ],
size = "small",
srcs = ["examples/bandit/tune_lin_ucb_train_recsim_env.py"],
)
# --------------------------------------------------------------------
# examples/documentation directory
#


@@ -65,6 +65,49 @@ class RecSimObservationSpaceWrapper(gym.ObservationWrapper):
return new_obs
class RecSimObservationBanditWrapper(gym.ObservationWrapper):
"""Fix RecSim environment's observation format
RecSim's observations are keyed by document IDs, and nested under
"doc" key.
Our Bandits agent expects the observations to be flat 2D array
and under "item" key.
This environment wrapper converts obs into the right format.
"""
def __init__(self, env: gym.Env):
super().__init__(env)
obs_space = self.env.observation_space
num_items = len(obs_space["doc"])
embedding_dim = next(iter(obs_space["doc"].values())).shape[-1]
self.observation_space = Dict(
OrderedDict(
[
("user", obs_space["user"]),
(
"item",
gym.spaces.Box(
low=-np.ones((num_items, embedding_dim)),
high=np.ones((num_items, embedding_dim)),
),
),
("response", obs_space["response"]),
]
)
)
self._sampled_obs = self.observation_space.sample()
def observation(self, obs):
new_obs = OrderedDict()
new_obs["user"] = obs["user"]
new_obs["item"] = np.vstack(list(obs["doc"].values()))
new_obs["response"] = obs["response"]
new_obs = convert_element_to_space_type(new_obs, self._sampled_obs)
return new_obs
class RecSimResetWrapper(gym.Wrapper):
"""Fix RecSim environment's reset() and close() function
@@ -118,7 +161,9 @@ class MultiDiscreteToDiscreteActionWrapper(gym.ActionWrapper):
def recsim_gym_wrapper(
recsim_gym_env: gym.Env, convert_to_discrete_action_space: bool = False
recsim_gym_env: gym.Env,
convert_to_discrete_action_space: bool = False,
wrap_for_bandits: bool = False,
) -> gym.Env:
"""Makes sure a RecSim gym.Env can ba handled by RLlib.
@@ -142,6 +187,8 @@ def recsim_gym_wrapper(
such as RLlib's DQN. If None, `convert_to_discrete_action_space`
may also be provided via the EnvContext (config) when creating an
actual env instance.
wrap_for_bandits: Bool indicating whether this RecSim env should be
wrapped for use with our Bandits agent.
Returns:
An RLlib-ready gym.Env instance.
@@ -150,6 +197,8 @@ def recsim_gym_wrapper(
env = RecSimObservationSpaceWrapper(env)
if convert_to_discrete_action_space:
env = MultiDiscreteToDiscreteActionWrapper(env)
if wrap_for_bandits:
env = RecSimObservationBanditWrapper(env)
return env
@@ -186,6 +235,7 @@ def make_recsim_env(
"resample_documents": True,
"seed": 0,
"convert_to_discrete_action_space": False,
"wrap_for_bandits": False,
}
if env_ctx is None or isinstance(env_ctx, dict):
env_ctx = EnvContext(env_ctx or default_config, worker_index=0)
@@ -210,7 +260,9 @@ def make_recsim_env(
# Fix observation space and - if necessary - convert to discrete
# action space (from multi-discrete).
self.env = recsim_gym_wrapper(
gym_env, env_ctx["convert_to_discrete_action_space"]
gym_env,
env_ctx["convert_to_discrete_action_space"],
env_ctx["wrap_for_bandits"],
)
self.observation_space = self.env.observation_space
self.action_space = self.env.action_space
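A brief usage sketch of the new env_config flag (it assumes the recsim package is installed; InterestEvolutionRecSimEnv is one of the pre-registered example envs also used in the test below):

from ray.rllib.examples.env.recsim_recommender_system_envs import (
    InterestEvolutionRecSimEnv,
)

# Build a bandit-ready RecSim env directly from its env_config.
env = InterestEvolutionRecSimEnv(
    {"convert_to_discrete_action_space": True, "wrap_for_bandits": True}
)
obs = env.reset()
# With `wrap_for_bandits`, all candidate documents end up in a single
# 2D "item" array of shape (num_items, embedding_dim).
print(obs["item"].shape)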


@@ -28,6 +28,11 @@ class TestRecSimWrapper(unittest.TestCase):
new_obs, _, _, _ = env.step(action)
self.assertTrue(env.observation_space.contains(new_obs))
def test_bandits_observation_space_conversion(self):
env = InterestEvolutionRecSimEnv({"wrap_for_bandits": True})
# "item" of observation space is a Box space.
self.assertIsInstance(env.observation_space["item"], gym.spaces.Box)
def test_double_action_space_conversion_raises_exception(self):
env = InterestEvolutionRecSimEnv({"convert_to_discrete_action_space": True})
with self.assertRaises(UnsupportedSpaceException):


@@ -0,0 +1,60 @@
"""Example of using LinUCB on a RecSim environment. """
from matplotlib import pyplot as plt
import pandas as pd
import time
import ray
from ray import tune

# The following import registers the example RecSim envs (e.g. "RecSim-v1").
import ray.rllib.examples.env.recsim_recommender_system_envs  # noqa
if __name__ == "__main__":
ray.init()
config = {
# "RecSim-v1" is a pre-registered RecSim env.
# Alternatively, you can do:
# `from ray.rllib.examples.env.recsim_recommender_system_envs import ...`
# - LongTermSatisfactionRecSimEnv
# - InterestExplorationRecSimEnv
# - InterestEvolutionRecSimEnv
# Then: "env": [the imported RecSim class]
"env": "RecSim-v1",
"env_config": {
"convert_to_discrete_action_space": True,
"wrap_for_bandits": True,
},
}
# The actual number of env steps will be
# 10 * timesteps_per_iteration (100 by default) = 1,000.
training_iterations = 10
print("Running training for %s time steps" % training_iterations)
start_time = time.time()
analysis = tune.run(
"BanditLinUCB",
config=config,
stop={"training_iteration": training_iterations},
num_samples=1,
checkpoint_at_end=False,
)
print("The trials took", time.time() - start_time, "seconds\n")
# Aggregate the per-trial results and plot the mean episode reward.
frame = pd.DataFrame()
for key, df in analysis.trial_dataframes.items():
frame = frame.append(df, ignore_index=True)
x = frame.groupby("agent_timesteps_total")["episode_reward_mean"].aggregate(
["mean", "max", "min", "std"]
)
plt.plot(x["mean"])
plt.fill_between(
x.index, x["mean"] - x["std"], x["mean"] + x["std"], color="b", alpha=0.2
)
plt.title("Episode reward mean")
plt.xlabel("Training steps")
plt.show()
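As a side note, the class-based alternative mentioned in the env comments above would look roughly like this (a sketch; the rest of the config and the tune.run() call stay the same):

from ray.rllib.examples.env.recsim_recommender_system_envs import (
    InterestEvolutionRecSimEnv,
)

config = {
    # Pass the env class directly instead of the registered "RecSim-v1" string.
    "env": InterestEvolutionRecSimEnv,
    "env_config": {
        "convert_to_discrete_action_space": True,
        "wrap_for_bandits": True,
    },
}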