ray/rllib/examples/bandit/tune_lin_ts_train_wheel_env.py
Kai Fricke 8affbc7be6
[tune/train] Consolidate checkpoint manager 3: Ray Tune (#24430)
**Update**: This PR is now part 3 of a three-PR group that consolidates the checkpoint managers.

1. Part 1 adds the common checkpoint management class #24771 
2. Part 2 adds the integration for Ray Train #24772
3. This PR builds on #24772 and includes all of its changes. It moves the Ray Tune integration to the new common checkpoint manager class.

Old PR description:

This PR consolidates the Ray Train and Ray Tune checkpoint managers. These classes previously implemented very similar logic in separate modules. To simplify future maintenance, the shared logic has been consolidated into a common core.

- This PR keeps full compatibility with the previous interfaces and implementations. This means that, for now, Train and Tune will have separate CheckpointManagers that both extend the common core (a minimal sketch follows this list)
- This PR prepares Tune to move to a CheckpointStrategy object
- In follow-up PRs, we can further unify interfacing with the common core, possibly removing any Train- or Tune-specific adjustments (e.g. moving to setup on init rather than at runtime for Ray Train)
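A minimal sketch of the intended structure, for illustration only: the names `CheckpointStrategy`, `_CommonCheckpointManager`, `TrainCheckpointManager`, `TuneCheckpointManager`, and `register_checkpoint` below are simplified stand-ins, not the exact classes or methods added in #24771/#24772.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class CheckpointStrategy:
    """How many checkpoints to keep, and which metric ranks them (sketch)."""

    num_to_keep: Optional[int] = None
    checkpoint_score_attribute: str = "training_iteration"
    checkpoint_score_order: str = "max"


class _CommonCheckpointManager:
    """Bookkeeping shared by Train and Tune in this sketch."""

    def __init__(self, checkpoint_strategy: CheckpointStrategy):
        self._strategy = checkpoint_strategy
        self._checkpoints: List[dict] = []

    def register_checkpoint(self, checkpoint: dict) -> None:
        # Track the new checkpoint, then prune to the configured budget.
        self._checkpoints.append(checkpoint)
        self._prune()

    def _prune(self) -> None:
        if self._strategy.num_to_keep is None:
            return
        key = self._strategy.checkpoint_score_attribute
        best_first = self._strategy.checkpoint_score_order == "max"
        self._checkpoints.sort(key=lambda c: c.get(key, 0), reverse=best_first)
        del self._checkpoints[self._strategy.num_to_keep:]


class TrainCheckpointManager(_CommonCheckpointManager):
    # Train-specific behavior (e.g. setup on init) would hook in here.
    pass


class TuneCheckpointManager(_CommonCheckpointManager):
    # Tune-specific behavior (e.g. per-trial directories) would hook in here.
    pass
```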

Co-authored-by: Antoni Baum <antoni.baum@protonmail.com>
2022-06-08 12:05:34 +01:00


""" Example of using Linear Thompson Sampling on WheelBandit environment.
For more information on WheelBandit, see https://arxiv.org/abs/1802.09127 .
"""
import argparse
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd
import time

import ray
from ray import tune
from ray.rllib.algorithms.bandit.bandit import BanditLinTS
from ray.rllib.examples.env.bandit_envs_discrete import WheelBanditEnv

def plot_model_weights(means, covs, ax):
    # Scatter-plot samples drawn from each arm's learned weight distribution.
    fmts = ["bo", "ro", "yx", "k+", "gx"]
    labels = ["arm{}".format(i) for i in range(5)]

    ax.set_title("Weights distributions of arms")

    for i in range(0, 5):
        x, y = np.random.multivariate_normal(means[i] / 30, covs[i], 5000).T
        ax.plot(x, y, fmts[i], label=labels[i])

    ax.set_aspect("equal")
    ax.grid(True, which="both")
    ax.axhline(y=0, color="k")
    ax.axvline(x=0, color="k")
    ax.legend(loc="best")

if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--framework",
choices=["tf2", "torch"],
default="torch",
help="The DL framework specifier.",
)
args = parser.parse_args()
print(f"Running with following CLI args: {args}")
ray.init(num_cpus=2)
config = {
"env": WheelBanditEnv,
"framework": args.framework,
"eager_tracing": (args.framework == "tf2"),
}
# Actual env steps per `train()` call will be
# 10 * `min_sample_timesteps_per_reporting` (100 by default) = 1,000
training_iterations = 10
print("Running training for %s time steps" % training_iterations)
start_time = time.time()
analysis = tune.run(
"BanditLinTS",
config=config,
stop={"training_iteration": training_iterations},
num_samples=1,
checkpoint_at_end=True,
)
print("The trials took", time.time() - start_time, "seconds\n")
# Analyze cumulative regrets of the trials.
frame = pd.DataFrame()
for key, df in analysis.trial_dataframes.items():
frame = frame.append(df, ignore_index=True)
x = frame.groupby("agent_timesteps_total")["episode_reward_mean"].aggregate(
["mean", "max", "min", "std"]
)
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 4))
ax1.plot(x["mean"])
ax1.set_title("Episode reward mean")
ax1.set_xlabel("Training steps")
# Restore trainer from checkpoint
trial = analysis.trials[0]
trainer = BanditLinTS(config=config)
trainer.restore(trial.checkpoint.dir_or_data)
# Get model to plot arm weights distribution
model = trainer.get_policy().model
means = [model.arms[i].theta.numpy() for i in range(5)]
covs = [model.arms[i].covariance.numpy() for i in range(5)]
# Plot weight distributions for different arms
plot_model_weights(means, covs, ax2)
fig.tight_layout()
plt.show()