[air] Example to track runs with Weights & Biases (#24459)

This PR:
- Adds an example showing how to run Ray Train and log results to Weights & Biases
- Adds functionality to the W&B plugin to store checkpoints
- Fixes a bug introduced in #24017
- Adds a CI utility script to set up credentials
- Adds a CI utility script to remove test state from external services cc @simon-mo
This commit is contained in:
Kai Fricke 2022-05-06 15:52:37 +01:00 committed by GitHub
parent fee35444ab
commit 5d9bf4234a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 184 additions and 19 deletions

View file

@ -3,7 +3,12 @@
commands:
- cleanup() { if [ "${BUILDKITE_PULL_REQUEST}" = "false" ]; then ./ci/build/upload_build_info.sh; fi }; trap cleanup EXIT
- DATA_PROCESSING_TESTING=1 INSTALL_HOROVOD=1 ./ci/env/install-dependencies.sh
- bazel test --config=ci $(./ci/run/bazel_export_options) --build_tests_only --test_tag_filters=-gpu python/ray/ml/...
- bazel test --config=ci $(./ci/run/bazel_export_options) --build_tests_only --test_tag_filters=-gpu,-needs_credentials python/ray/ml/...
# Only set up credentials in branch builds; pull request builds stop here.
- if [ "$BUILDKITE_PULL_REQUEST" != "false" ]; then exit 0; fi
- python ./ci/env/setup_credentials.py wandb
- bazel test --config=ci $(./ci/run/bazel_export_options) --build_tests_only --test_tag_filters=-gpu,needs_credentials python/ray/ml/...
- python ./ci/env/cleanup_test_state.py wandb
- label: ":brain: RLlib: Learning discr. actions TF2-static-graph"
conditions: ["RAY_CI_RLLIB_AFFECTED"]

32
ci/env/cleanup_test_state.py vendored Normal file
View file

@ -0,0 +1,32 @@
import sys
def clear_wandb_project(wandb_project: str = "ray_air_example") -> None:
    """Delete every run stored in a Weights & Biases project.

    Args:
        wandb_project: Name of the project to clear. The default is the
            project name hardcoded in the
            `ray/ml/examples/upload_to_wandb.py` example.
    """
    # Imported inside the function, so wandb is only required when this
    # cleanup actually runs.
    import wandb

    api = wandb.Api()
    # Take a snapshot of the listing before deleting anything, so that
    # removing runs cannot interfere with iterating over the API result.
    for run in list(api.runs(wandb_project)):
        run.delete()
# Maps each service name accepted on the command line to its cleanup function.
SERVICES = {"wandb": clear_wandb_project}


def main(argv):
    """Clean up the test state of every service named on the command line.

    Args:
        argv: Full argument vector; ``argv[0]`` is the script name and the
            remaining entries are service names (keys of ``SERVICES``).

    Returns:
        The process exit status: 0 on success, 1 when no service was given.

    Raises:
        RuntimeError: If any requested service is not a key of ``SERVICES``.
    """
    if len(argv) < 2:
        # A missing argument is a usage error: report it on stderr and fail,
        # so that a misconfigured CI step does not pass silently.
        print(f"Usage: python {argv[0]} <service1> [service2] ...", file=sys.stderr)
        return 1

    services = argv[1:]
    # Validate everything up front so that no cleanup runs on a bad call.
    if any(service not in SERVICES for service in services):
        raise RuntimeError(
            f"All services must be included in {list(SERVICES.keys())}. "
            f"Got: {services}"
        )

    for service in services:
        SERVICES[service]()
    return 0


if __name__ == "__main__":
    sys.exit(main(sys.argv))

35
ci/env/setup_credentials.py vendored Normal file
View file

@ -0,0 +1,35 @@
import os
import sys
import boto3
# ARN of the AWS Secrets Manager secret that holds the W&B API key.
AWS_WANDB_SECRET_ARN = (
    "arn:aws:secretsmanager:us-west-2:029272617770:secret:oss-ci/wandb-key-V8UeE5"
)


def get_and_write_wandb_api_key(client):
    """Fetch the W&B API key and store it in ``~/.netrc``.

    The file is overwritten and restricted to owner read/write (0600),
    because it contains a credential.

    Args:
        client: Object exposing ``get_secret_value(SecretId=...)`` that
            returns a mapping with a ``"SecretString"`` entry (the script
            passes a boto3 ``secretsmanager`` client).
    """
    api_key = client.get_secret_value(SecretId=AWS_WANDB_SECRET_ARN)["SecretString"]
    netrc_path = os.path.expanduser("~/.netrc")
    # Create the file with owner-only permissions instead of relying on the
    # process umask.
    fd = os.open(netrc_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
    with os.fdopen(fd, "w") as fp:
        # The mode passed to os.open() only applies to newly created files;
        # tighten a pre-existing file before the secret is written to it.
        os.chmod(netrc_path, 0o600)
        fp.write(f"machine api.wandb.ai\n login user\n password {api_key}\n")
# Maps each service name accepted on the command line to the function that
# fetches and installs its credentials.
SERVICES = {"wandb": get_and_write_wandb_api_key}


def main(argv):
    """Set up credentials for every service named on the command line.

    Args:
        argv: Full argument vector; ``argv[0]`` is the script name and the
            remaining entries are service names (keys of ``SERVICES``).

    Returns:
        The process exit status: 0 on success, 1 when no service was given.

    Raises:
        RuntimeError: If any requested service is not a key of ``SERVICES``.
    """
    if len(argv) < 2:
        # A missing argument is a usage error: report it on stderr and fail,
        # so that a misconfigured CI step does not pass silently.
        print(f"Usage: python {argv[0]} <service1> [service2] ...", file=sys.stderr)
        return 1

    services = argv[1:]
    # Validate everything up front so that no secret is fetched on a bad call.
    if any(service not in SERVICES for service in services):
        raise RuntimeError(
            f"All services must be included in {list(SERVICES.keys())}. "
            f"Got: {services}"
        )

    # One Secrets Manager client is shared by all requested services.
    client = boto3.client("secretsmanager", region_name="us-west-2")
    for service in services:
        SERVICES[service](client)
    return 0


if __name__ == "__main__":
    sys.exit(main(sys.argv))

View file

@ -145,6 +145,14 @@ py_test(
deps = [":ml_lib"]
)
# Runs the W&B upload example as a test. The "needs_credentials" tag is what
# the CI pipeline filters on: it is excluded from the regular test run and
# selected by a separate run that follows the credential setup step.
py_test(
name = "upload_to_wandb",
size = "medium",
srcs = ["examples/upload_to_wandb.py"],
tags = ["team:ml", "exclusive", "needs_credentials"],
deps = [":ml_lib"]
)
py_test(
name = "xgboost_example",
size = "medium",

View file

@ -0,0 +1,49 @@
"""
In this example, we train a simple XGBoost model and log the training
results to Weights & Biases. We also save the resulting model checkpoints
as artifacts.
"""
import ray
from ray.ml import RunConfig
from ray.ml.result import Result
from ray.ml.train.integrations.xgboost import XGBoostTrainer
from ray.tune.integration.wandb import WandbLoggerCallback
from sklearn.datasets import load_breast_cancer
def get_train_dataset() -> ray.data.Dataset:
    """Load scikit-learn's "Breast cancer" data as a Ray dataset.

    The labels are stored in a ``target`` column next to the feature columns.
    """
    bunch = load_breast_cancer(as_frame=True)
    frame = bunch["data"]
    # Attach the labels to the feature frame before handing it to Ray.
    frame["target"] = bunch["target"]
    return ray.data.from_pandas(frame)
def train_model(train_dataset: ray.data.Dataset, wandb_project: str) -> Result:
    """Fit a small XGBoost model and report the run to Weights & Biases.

    Args:
        train_dataset: Training data; its ``target`` column is the label.
        wandb_project: W&B project the results and checkpoints are logged to.

    Returns:
        The result returned by ``XGBoostTrainer.fit()``.
    """
    # This is the part needed to enable logging to Weights & Biases.
    # It assumes you've logged in before, e.g. with `wandb login`.
    wandb_callback = WandbLoggerCallback(
        project=wandb_project,
        save_checkpoints=True,
    )
    xgboost_trainer = XGBoostTrainer(
        scaling_config={"num_workers": 2},
        params={"tree_method": "auto"},
        label_column="target",
        datasets={"train": train_dataset},
        num_boost_round=10,
        run_config=RunConfig(callbacks=[wandb_callback]),
    )
    return xgboost_trainer.fit()
# W&B project the example logs to. The CI cleanup script
# (`ci/env/cleanup_test_state.py`) hardcodes the same name.
wandb_project = "ray_air_example"
# These statements run whenever the module is executed or imported: load the
# data, then train the model and log the run.
train_dataset = get_train_dataset()
result = train_model(train_dataset=train_dataset, wandb_project=wandb_project)

View file

@ -1,3 +1,4 @@
import enum
import os
import pickle
from collections.abc import Sequence
@ -24,7 +25,6 @@ except ImportError:
wandb = None
WANDB_ENV_VAR = "WANDB_API_KEY"
_WANDB_QUEUE_END = (None,)
_VALID_TYPES = (Number, wandb.data_types.Video, wandb.data_types.Image)
_VALID_ITERABLE_TYPES = (wandb.data_types.Video, wandb.data_types.Image)
@ -181,30 +181,53 @@ def _set_api_key(api_key_file: Optional[str] = None, api_key: Optional[str] = No
)
@enum.unique
class _QueueItem(enum.Enum):
    """Tag identifying the kind of item placed on a logging queue."""

    # Tells the logging process to stop consuming the queue.
    END = 1
    # The accompanying content is a result to be logged.
    RESULT = 2
    # The accompanying content is a checkpoint to be saved.
    CHECKPOINT = 3
class _WandbLoggingProcess(Process):
"""
We need a `multiprocessing.Process` to allow multiple concurrent
wandb logging instances locally.
We use a queue for the driver to communicate with the logging process.
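The queue accepts ``(item_type, item_content)`` tuples, where ``item_type``
is a ``_QueueItem``:
- ``_QueueItem.RESULT``: the content is a result dict and will be logged using
``wandb.log()``
- ``_QueueItem.CHECKPOINT``: the content is a checkpoint path and will be saved
using ``wandb.log_artifact()``
- ``_QueueItem.END``: the content is ignored and the logging loop stops.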
"""
def __init__(
self, queue: Queue, exclude: List[str], to_config: List[str], *args, **kwargs
):
super(_WandbLoggingProcess, self).__init__()
self.queue = queue
self._exclude = set(exclude)
self._to_config = set(to_config)
self.args = args
self.kwargs = kwargs
self._trial_name = self.kwargs.get("name", "unknown")
def run(self):
# Since we're running in a separate process already, use threads.
wandb.require("service")
wandb.init(*self.args, **self.kwargs)
wandb.setup()
while True:
result = self.queue.get()
if result == _WANDB_QUEUE_END:
item_type, item_content = self.queue.get()
if item_type == _QueueItem.END:
break
log, config_update = self._handle_result(result)
if item_type == _QueueItem.CHECKPOINT:
self._handle_checkpoint(item_content)
continue
assert item_type == _QueueItem.RESULT
log, config_update = self._handle_result(item_content)
try:
wandb.config.update(config_update, allow_val_change=True)
wandb.log(log)
@ -214,6 +237,11 @@ class _WandbLoggingProcess(Process):
logger.warn("Failed to log result to w&b: {}".format(str(e)))
wandb.finish()
def _handle_checkpoint(self, checkpoint_path: str):
    """Log a checkpoint to W&B as a ``model`` artifact.

    The artifact is named ``checkpoint_<trial name>``.

    Args:
        checkpoint_path: Path handed to ``Artifact.add_dir()``.
    """
    artifact_name = f"checkpoint_{self._trial_name}"
    checkpoint_artifact = wandb.Artifact(name=artifact_name, type="model")
    checkpoint_artifact.add_dir(checkpoint_path)
    wandb.log_artifact(checkpoint_artifact)
def _handle_result(self, result: Dict) -> Tuple[Dict, Dict]:
config_update = result.get("config", {}).copy()
log = {}
@ -255,6 +283,8 @@ class WandbLoggerCallback(LoggerCallback):
the ``results`` dict should be logged. This makes sense if
parameters will change during training, e.g. with
PopulationBasedTraining. Defaults to False.
save_checkpoints: If ``True``, model checkpoints will be saved to
Wandb as artifacts. Defaults to ``False``.
**kwargs: The keyword arguments will be passed to ``wandb.init()``.
Wandb's ``group``, ``run_id`` and ``run_name`` are automatically selected
@ -308,7 +338,8 @@ class WandbLoggerCallback(LoggerCallback):
api_key: Optional[str] = None,
excludes: Optional[List[str]] = None,
log_config: bool = False,
**kwargs
save_checkpoints: bool = False,
**kwargs,
):
self.project = project
self.group = group
@ -316,18 +347,17 @@ class WandbLoggerCallback(LoggerCallback):
self.api_key = api_key
self.excludes = excludes or []
self.log_config = log_config
self.save_checkpoints = save_checkpoints
self.kwargs = kwargs
self._trial_processes: Dict["Trial", _WandbLoggingProcess] = {}
self._trial_queues: Dict["Trial", Queue] = {}
def setup(self):
def setup(self, *args, **kwargs):
self.api_key_file = (
os.path.expanduser(self.api_key_path) if self.api_key_path else None
)
_set_api_key(self.api_key_file, self.api_key)
wandb.require("service")
wandb.setup()
def log_trial_start(self, trial: "Trial"):
config = trial.config.copy()
@ -373,7 +403,7 @@ class WandbLoggerCallback(LoggerCallback):
queue=self._trial_queues[trial],
exclude=exclude_results,
to_config=self._config_results,
**wandb_init_kwargs
**wandb_init_kwargs,
)
self._trial_processes[trial].start()
@ -382,10 +412,16 @@ class WandbLoggerCallback(LoggerCallback):
self.log_trial_start(trial)
result = _clean_log(result)
self._trial_queues[trial].put(result)
self._trial_queues[trial].put((_QueueItem.RESULT, result))
def log_trial_save(self, trial: "Trial"):
    """Queue the trial's checkpoint for upload when checkpoint saving is on.

    Nothing is queued if ``save_checkpoints`` is false or the trial has no
    checkpoint.
    """
    if not self.save_checkpoints or not trial.checkpoint:
        return
    trial_queue = self._trial_queues[trial]
    trial_queue.put((_QueueItem.CHECKPOINT, trial.checkpoint.value))
def log_trial_end(self, trial: "Trial", failed: bool = False):
self._trial_queues[trial].put(_WANDB_QUEUE_END)
self._trial_queues[trial].put((_QueueItem.END, None))
self._trial_processes[trial].join(timeout=10)
del self._trial_queues[trial]
@ -394,7 +430,7 @@ class WandbLoggerCallback(LoggerCallback):
def __del__(self):
    """Shut down any logging processes that are still registered.

    For every remaining trial, the end-of-queue item is sent (if the trial
    still has a queue), the process is joined with a short timeout, and the
    trial is removed from the bookkeeping dicts.
    """
    # Iterate over a snapshot of the keys: entries are deleted inside the
    # loop, and removing from a dict while iterating over it raises
    # RuntimeError.
    for trial in list(self._trial_processes):
        if trial in self._trial_queues:
            self._trial_queues[trial].put((_QueueItem.END, None))
            del self._trial_queues[trial]
        self._trial_processes[trial].join(timeout=2)
        del self._trial_processes[trial]

View file

@ -11,11 +11,11 @@ from ray.tune.function_runner import wrap_function
from ray.tune.integration.wandb import (
WandbLoggerCallback,
_WandbLoggingProcess,
_WANDB_QUEUE_END,
WandbLogger,
WANDB_ENV_VAR,
WandbTrainableMixin,
wandb_mixin,
_QueueItem,
)
from ray.tune.result import TRIAL_INFO
from ray.tune.trial import TrialInfo
@ -52,10 +52,10 @@ class _MockWandbLoggingProcess(_WandbLoggingProcess):
def run(self):
while True:
result = self.queue.get()
if result == _WANDB_QUEUE_END:
result_type, result_content = self.queue.get()
if result_type == _QueueItem.END:
break
log, config_update = self._handle_result(result)
log, config_update = self._handle_result(result_content)
self.config_updates.put(config_update)
self.logs.put(log)

View file

@ -83,7 +83,7 @@ pytest-lazy-fixture
pytest-timeout
pytest-virtualenv
redis >= 3.5.0, < 4.0.0
scikit-learn==0.22.2
scikit-learn==0.24.2
testfixtures
werkzeug
xlrd