# ray/rllib/examples/bandit/tune_lin_ts_train_wheel_env.py
""" Example of using Linear Thompson Sampling on WheelBandit environment.
For more information on WheelBandit, see https://arxiv.org/abs/1802.09127 .
"""
import argparse
import time

from matplotlib import pyplot as plt
import numpy as np
import pandas as pd

import ray
from ray import tune
from ray.rllib.algorithms.bandit.bandit import BanditLinTSTrainer
from ray.rllib.examples.env.bandit_envs_discrete import WheelBanditEnv


def plot_model_weights(means, covs, ax):
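    """Scatter samples from each arm's weight posterior onto `ax`.

    For each arm, 5,000 points are drawn from a 2-d Gaussian parameterized by
    the arm's (down-scaled) posterior mean and covariance, so the point clouds
    visualize how concentrated each arm's learned weights are.
    """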
    fmts = ["bo", "ro", "yx", "k+", "gx"]
    labels = ["arm{}".format(i) for i in range(5)]
    ax.set_title("Weights distributions of arms")
    for i in range(0, 5):
        x, y = np.random.multivariate_normal(means[i] / 30, covs[i], 5000).T
        ax.plot(x, y, fmts[i], label=labels[i])
    ax.set_aspect("equal")
    ax.grid(True, which="both")
    ax.axhline(y=0, color="k")
    ax.axvline(x=0, color="k")
    ax.legend(loc="best")
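

# A conceptual sketch (hypothetical helper, not RLlib's internal implementation)
# of the Linear Thompson Sampling decision rule this example trains: draw one
# weight vector per arm from its posterior N(theta_i, cov_i), score each arm by
# the dot product with the context, and play the highest-scoring arm.
def lin_ts_select_arm_sketch(context, means, covs, rng=np.random):
    sampled_scores = [
        context @ rng.multivariate_normal(means[i], covs[i])
        for i in range(len(means))
    ]
    return int(np.argmax(sampled_scores))

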
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--framework",
        choices=["tf2", "torch"],
        default="torch",
        help="The DL framework specifier.",
    )
    args = parser.parse_args()
    print(f"Running with following CLI args: {args}")

    ray.init(num_cpus=2)

    config = {
        "env": WheelBanditEnv,
"framework": args.framework,
"eager_tracing": (args.framework == "tf2"),
}
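
    # "BanditLinTS" passed to `tune.run()` below is the registered algorithm
    # name that resolves to the BanditLinTSTrainer imported above; the config
    # dict is forwarded to it unchanged.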

    # Actual env steps per `train()` call will be
    # 10 * `min_sample_timesteps_per_reporting` (100 by default) = 1,000
    training_iterations = 10

    print("Running training for %s training iterations." % training_iterations)

    start_time = time.time()
    analysis = tune.run(
        "BanditLinTS",
        config=config,
        stop={"training_iteration": training_iterations},
        num_samples=1,
        checkpoint_at_end=True,
    )

    print("The trials took", time.time() - start_time, "seconds\n")

    # Aggregate the per-trial progress dataframes and plot the mean episode
    # reward over time.
    frame = pd.DataFrame()
    for key, df in analysis.trial_dataframes.items():
        frame = pd.concat([frame, df], ignore_index=True)
    x = frame.groupby("agent_timesteps_total")["episode_reward_mean"].aggregate(
        ["mean", "max", "min", "std"]
    )

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 4))
    ax1.plot(x["mean"])
    ax1.set_title("Episode reward mean")
    ax1.set_xlabel("Training steps")

    # Restore trainer from checkpoint
    trial = analysis.trials[0]
    trainer = BanditLinTSTrainer(config=config)
    trainer.restore(trial.checkpoint.value)

    # Get model to plot arm weights distribution
    model = trainer.get_policy().model
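    # Each arm of the LinTS model keeps a Bayesian linear-regression posterior
    # over its reward weights: `theta` is the posterior mean and `covariance`
    # the posterior covariance (2-dimensional here, matching the wheel context).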
    means = [model.arms[i].theta.numpy() for i in range(5)]
    covs = [model.arms[i].covariance.numpy() for i in range(5)]

    # Plot weight distributions for different arms
    plot_model_weights(means, covs, ax2)
    fig.tight_layout()
    plt.show()