"""Example of using LinUCB on a recommendation environment with parametric
actions.
"""

import argparse
import os
import time

import pandas as pd
from matplotlib import pyplot as plt

import ray
from ray import tune
from ray.tune import register_env
from ray.rllib.env.wrappers.recsim import (
    MultiDiscreteToDiscreteActionWrapper,
    RecSimObservationBanditWrapper,
)
from ray.rllib.examples.env.bandit_envs_recommender_system import (
    ParametricRecSys,
)

# Because ParametricRecSys follows RecSim's API, we have to wrap it before
# it can work with our Bandits agent.
register_env(
    "ParametricRecSysEnv",
    lambda cfg: MultiDiscreteToDiscreteActionWrapper(
        RecSimObservationBanditWrapper(ParametricRecSys(**cfg))
    ),
)
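
# For illustration, the wrapper chain can also be built by hand. This is a
# minimal sketch (the exact observation layout produced by
# RecSimObservationBanditWrapper is an assumption here, not verified output):
#
#   env = MultiDiscreteToDiscreteActionWrapper(
#       RecSimObservationBanditWrapper(ParametricRecSys())
#   )
#   obs = env.reset()
#   # RecSimObservationBanditWrapper exposes the candidate documents'
#   # feature vectors to the bandit agent, while the action wrapper
#   # flattens RecSim's MultiDiscrete slate action into a single
#   # Discrete choice over the candidate documents.
#   print(env.action_space, env.observation_space)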

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--framework",
        choices=["tf2", "torch"],
        default="torch",
        help="The DL framework specifier.",
    )
    args = parser.parse_args()
    print(f"Running with following CLI args: {args}")

    # Temporary fix to avoid an OpenMP runtime conflict (allows duplicate
    # libomp copies to coexist).
    os.environ["KMP_DUPLICATE_LIB_OK"] = "True"

    ray.init()

    config = {
        "framework": args.framework,
        # Eager tracing only applies to (and speeds up) tf2.
        "eager_tracing": (args.framework == "tf2"),
        "env": "ParametricRecSysEnv",
        "env_config": {
            # Dimensionality of the document/user feature embeddings.
            "embedding_size": 20,
            # Number of candidate documents presented per step.
            "num_docs_to_select_from": 10,
            # Recommend exactly one document at a time.
            "slate_size": 1,
            "num_docs_in_db": 100,
            "num_users_in_db": 1,
            "user_time_budget": 1.0,
        },
        "num_envs_per_worker": 2,  # Test with batched inference.
        "evaluation_interval": 20,
        "evaluation_duration": 100,
        "evaluation_duration_unit": "episodes",
        "simple_optimizer": True,
    }
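
    # With the env_config above, every step presents 10 candidate documents,
    # each described by a 20-dim feature embedding, and the slate_size=1
    # action selects exactly one of them: the contextual-bandit setup that
    # LinUCB expects.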

    # Each `train()` call (= one training iteration) collects at least
    # `timesteps_per_iteration` timesteps (100 by default), so each trial
    # runs roughly 10 * 100 = 1,000 env steps in total.
    training_iterations = 10

    print("Running training for %s iterations." % training_iterations)

    start_time = time.time()
    analysis = tune.run(
        "BanditLinUCB",
        config=config,
        stop={"training_iteration": training_iterations},
        num_samples=2,
        checkpoint_at_end=False,
    )
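
    # `num_samples=2` launches two independent trials of the identical
    # config; the spread between them feeds the std band plotted below.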

    print("The trials took", time.time() - start_time, "seconds\n")

    # Analyze the mean episode rewards across the trials.
    # (`DataFrame.append` was removed in pandas 2.0; use `pd.concat`.)
    frame = pd.concat(analysis.trial_dataframes.values(), ignore_index=True)
    x = frame.groupby("agent_timesteps_total")["episode_reward_mean"].aggregate(
        ["mean", "max", "min", "std"]
    )
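
    # `x` now holds one row per `agent_timesteps_total` value, with the
    # ["mean", "max", "min", "std"] of `episode_reward_mean` taken across
    # the trials. The plot below shows the mean with a +/-1 std band.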

    plt.plot(x["mean"])
    plt.fill_between(
        x.index, x["mean"] - x["std"], x["mean"] + x["std"], color="b", alpha=0.2
    )
    plt.title("Episode reward mean")
    plt.xlabel("Training steps")
    plt.show()
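
    # Note: `plt.show()` requires a display. On a headless machine, saving
    # the figure instead is a simple alternative (a sketch; the file name
    # is arbitrary):
    #
    #   plt.savefig("linucb_recsim_rewards.png")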