mirror of
https://github.com/vale981/ray
synced 2025-03-10 13:26:39 -04:00

**Update**: This PR is now part 3 of a three-PR group to consolidate the checkpoints:

1. Part 1 adds the common checkpoint management class (#24771).
2. Part 2 adds the integration for Ray Train (#24772).
3. This PR builds on #24772 and includes all changes. It moves the Ray Tune integration to use the new common checkpoint manager class.

Old PR description: This PR consolidates the Ray Train and Tune checkpoint managers. These concepts previously did something very similar, but in different modules. To simplify maintenance in the future, we've consolidated the common core.

- This PR keeps full compatibility with the previous interfaces and implementations. This means that, for now, Train and Tune will have separate CheckpointManagers that both extend the common core (see the sketch below).
- This PR prepares Tune to move to a CheckpointStrategy object.
- In follow-up PRs, we can further unify interfacing with the common core, possibly removing any Train- or Tune-specific adjustments (e.g. moving to setup on init rather than at runtime for Ray Train).

Co-authored-by: Antoni Baum <antoni.baum@protonmail.com>
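To make the layering concrete, here is a minimal sketch of the structure the description outlines: a shared core that registers checkpoints and enforces a keep-N/scoring policy, with thin Train- and Tune-specific subclasses on top. All names below (`CheckpointStrategy` and its fields, `_CommonCheckpointManager`, the subclass names) are illustrative placeholders, not necessarily the identifiers used in the actual PRs.

```python
# Illustrative sketch only -- class and attribute names are placeholders,
# not Ray's actual API.
from dataclasses import dataclass
from typing import Any, Dict, List, Optional


@dataclass
class CheckpointStrategy:
    """How many checkpoints to keep and how to rank them."""

    num_to_keep: Optional[int] = None          # None = keep every checkpoint
    checkpoint_score_attribute: str = "loss"   # metric used for ranking
    checkpoint_score_order: str = "min"        # "min" or "max"


class _CommonCheckpointManager:
    """Shared core: tracks checkpoints and drops the lowest-ranked ones."""

    def __init__(self, strategy: CheckpointStrategy):
        self._strategy = strategy
        self._checkpoints: List[Dict[str, Any]] = []

    def register_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        self._checkpoints.append(checkpoint)
        self._enforce_strategy()

    def _enforce_strategy(self) -> None:
        keep = self._strategy.num_to_keep
        if keep is None or len(self._checkpoints) <= keep:
            return
        # Rank by the configured metric and drop everything beyond the budget.
        best_first = self._strategy.checkpoint_score_order == "max"
        self._checkpoints.sort(
            key=lambda c: c[self._strategy.checkpoint_score_attribute],
            reverse=best_first,
        )
        self._checkpoints = self._checkpoints[:keep]


class TrainCheckpointManager(_CommonCheckpointManager):
    """Train-specific behavior (e.g. setup at runtime vs. on init) lives here."""


class TuneCheckpointManager(_CommonCheckpointManager):
    """Tune-specific behavior lives here."""
```

In this picture, the part-1 PR would correspond to the shared core, part 2 to the Train subclass, and this PR to the Tune subclass together with its move toward a CheckpointStrategy-style configuration.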
99 lines · 2.8 KiB · Python
""" Example of using Linear Thompson Sampling on WheelBandit environment.
|
|
For more information on WheelBandit, see https://arxiv.org/abs/1802.09127 .
|
|
"""
|
|
|
|
import argparse
import time

from matplotlib import pyplot as plt
import numpy as np
import pandas as pd

import ray
from ray import tune
from ray.rllib.algorithms.bandit.bandit import BanditLinTS
from ray.rllib.examples.env.bandit_envs_discrete import WheelBanditEnv


def plot_model_weights(means, covs, ax):
    """Scatter samples from each arm's weight distribution onto the given axis."""
    fmts = ["bo", "ro", "yx", "k+", "gx"]
    labels = ["arm{}".format(i) for i in range(5)]

    ax.set_title("Weights distributions of arms")

    for i in range(0, 5):
        x, y = np.random.multivariate_normal(means[i] / 30, covs[i], 5000).T
        ax.plot(x, y, fmts[i], label=labels[i])

    ax.set_aspect("equal")
    ax.grid(True, which="both")
    ax.axhline(y=0, color="k")
    ax.axvline(x=0, color="k")
    ax.legend(loc="best")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--framework",
        choices=["tf2", "torch"],
        default="torch",
        help="The DL framework specifier.",
    )
    args = parser.parse_args()
    print(f"Running with following CLI args: {args}")

    ray.init(num_cpus=2)

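    # Minimal algorithm config: the WheelBandit env plus the chosen DL framework.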
    config = {
        "env": WheelBanditEnv,
        "framework": args.framework,
        "eager_tracing": (args.framework == "tf2"),
    }

    # Actual env steps per `train()` call will be
    # 10 * `min_sample_timesteps_per_reporting` (100 by default) = 1,000
    training_iterations = 10

    print("Running training for %s training iterations" % training_iterations)

    start_time = time.time()
    analysis = tune.run(
        "BanditLinTS",
        config=config,
        stop={"training_iteration": training_iterations},
        num_samples=1,
        checkpoint_at_end=True,
    )

    print("The trials took", time.time() - start_time, "seconds\n")

    # Analyze cumulative regrets of the trials.
    frame = pd.DataFrame()
    for key, df in analysis.trial_dataframes.items():
        # `DataFrame.append` was removed in pandas 2.0; concatenate instead.
        frame = pd.concat([frame, df], ignore_index=True)

    x = frame.groupby("agent_timesteps_total")["episode_reward_mean"].aggregate(
        ["mean", "max", "min", "std"]
    )

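    # Two-panel figure: mean episode reward on the left; arm weight
    # distributions (plotted below, after restoring the checkpoint) on the right.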
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 4))

    ax1.plot(x["mean"])
    ax1.set_title("Episode reward mean")
    ax1.set_xlabel("Training steps")

    # Restore trainer from checkpoint
    trial = analysis.trials[0]
    trainer = BanditLinTS(config=config)
    trainer.restore(trial.checkpoint.dir_or_data)

    # Get model to plot arm weights distribution
    model = trainer.get_policy().model
    means = [model.arms[i].theta.numpy() for i in range(5)]
    covs = [model.arms[i].covariance.numpy() for i in range(5)]

    # Plot weight distributions for different arms
    plot_model_weights(means, covs, ax2)
    fig.tight_layout()
    plt.show()