""" Example of using Linear Thompson Sampling on WheelBandit environment.
|
|
For more information on WheelBandit, see https://arxiv.org/abs/1802.09127 .
|
|
"""
|
|
|
|
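# This example can be run directly, e.g.
#   python <path-to-this-script> --framework=torch
# Besides `ray[rllib]`, it requires `matplotlib` for plotting the learned
# per-arm weight distributions.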
import argparse
import time

from matplotlib import pyplot as plt
import numpy as np

import ray
from ray import air, tune
from ray.rllib.algorithms.bandit.bandit import BanditLinTS
from ray.rllib.examples.env.bandit_envs_discrete import WheelBanditEnv

def plot_model_weights(means, covs, ax):
    """Scatter-plot samples from each arm's learned weight distribution."""
    fmts = ["bo", "ro", "yx", "k+", "gx"]
    labels = ["arm{}".format(i) for i in range(5)]

    ax.set_title("Weights distributions of arms")

    for i in range(0, 5):
        x, y = np.random.multivariate_normal(means[i] / 30, covs[i], 5000).T
        ax.plot(x, y, fmts[i], label=labels[i])

    ax.set_aspect("equal")
    ax.grid(True, which="both")
    ax.axhline(y=0, color="k")
    ax.axvline(x=0, color="k")
    ax.legend(loc="best")

if __name__ == "__main__":
|
|
parser = argparse.ArgumentParser()
|
|
parser.add_argument(
|
|
"--framework",
|
|
choices=["tf2", "torch"],
|
|
default="torch",
|
|
help="The DL framework specifier.",
|
|
)
|
|
args = parser.parse_args()
|
|
print(f"Running with following CLI args: {args}")
|
|
|
|
ray.init(num_cpus=2)
|
|
|
|
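    # Configure BanditLinTS to train on the WheelBandit environment.
    # Linear Thompson Sampling keeps a Bayesian linear model per arm and
    # samples model weights to trade off exploration and exploitation.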
    config = {
        "env": WheelBanditEnv,
        "framework": args.framework,
        "eager_tracing": (args.framework == "tf2"),
    }

    # Actual env steps per `train()` call will be
    # 10 * `min_sample_timesteps_per_iteration` (100 by default) = 1,000.
    training_iterations = 10

    print("Running training for %s training iterations" % training_iterations)

    start_time = time.time()
    tuner = tune.Tuner(
        "BanditLinTS",
        param_space=config,
        run_config=air.RunConfig(
            stop={"training_iteration": training_iterations},
            checkpoint_config=air.CheckpointConfig(
                checkpoint_at_end=True,
            ),
        ),
        tune_config=tune.TuneConfig(
            num_samples=1,
        ),
    )
    results = tuner.fit()

    print("The trials took", time.time() - start_time, "seconds\n")

    # Analyze the mean episode reward over time.
    # There is only one trial, so `get_best_result()` simply returns it.
    result = results.get_best_result()
    x = result.metrics_dataframe.groupby("agent_timesteps_total")[
        "episode_reward_mean"
    ].aggregate(["mean", "max", "min", "std"])

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 4))

    ax1.plot(x["mean"])
    ax1.set_title("Episode reward mean")
    ax1.set_xlabel("Training steps")

    # Restore trainer from checkpoint.
    checkpoint = results.get_best_result().checkpoint
    trainer = BanditLinTS(config=config)
    with checkpoint.as_directory() as checkpoint_dir:
        trainer.restore(checkpoint_dir)

    # Get the model to plot the arm weight distributions.
    model = trainer.get_policy().model
    means = [model.arms[i].theta.numpy() for i in range(5)]
    covs = [model.arms[i].covariance.numpy() for i in range(5)]

    # Plot weight distributions for the different arms.
    plot_model_weights(means, covs, ax2)
    fig.tight_layout()
    plt.show()