[air] Example to track runs with Weights & Biases (#24459)

This PR 
- adds an example on how to run Ray Train and log results to Weights & Biases
- adds functionality to the W&B plugin to store checkpoints
- fixes a bug introduced in #24017
- adds a CI utility script to set up credentials
- adds a CI utility script to remove test state from external services cc @simon-mo
Kai Fricke 2022-05-06 15:52:37 +01:00 committed by GitHub
parent fee35444ab
commit 5d9bf4234a
8 changed files with 184 additions and 19 deletions


@@ -3,7 +3,12 @@
  commands:
    - cleanup() { if [ "${BUILDKITE_PULL_REQUEST}" = "false" ]; then ./ci/build/upload_build_info.sh; fi }; trap cleanup EXIT
    - DATA_PROCESSING_TESTING=1 INSTALL_HOROVOD=1 ./ci/env/install-dependencies.sh
-   - bazel test --config=ci $(./ci/run/bazel_export_options) --build_tests_only --test_tag_filters=-gpu python/ray/ml/...
    - bazel test --config=ci $(./ci/run/bazel_export_options) --build_tests_only --test_tag_filters=-gpu,-needs_credentials python/ray/ml/...
    # Only setup credentials in branch builds
    - if [ "$BUILDKITE_PULL_REQUEST" != "false" ]; then exit 0; fi
    - python ./ci/env/setup_credentials.py wandb
    - bazel test --config=ci $(./ci/run/bazel_export_options) --build_tests_only --test_tag_filters=-gpu,needs_credentials python/ray/ml/...
    - python ./ci/env/cleanup_test_state.py wandb
- label: ":brain: RLlib: Learning discr. actions TF2-static-graph"
  conditions: ["RAY_CI_RLLIB_AFFECTED"]
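
In short: pull-request builds still run everything except tests tagged `needs_credentials`, while branch builds additionally set up W&B credentials, run the credentialed tests, and then wipe the test state from the external service using the two new scripts below.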

ci/env/cleanup_test_state.py (new file)

@@ -0,0 +1,32 @@
import sys


def clear_wandb_project():
    import wandb

    # This is hardcoded in the `ray/ml/examples/upload_to_wandb.py` example
    wandb_project = "ray_air_example"

    api = wandb.Api()
    for run in api.runs(wandb_project):
        run.delete()


SERVICES = {"wandb": clear_wandb_project}


if __name__ == "__main__":
    if len(sys.argv) < 2:
        print(f"Usage: python {sys.argv[0]} <service1> [service2] ...")
        sys.exit(0)

    services = sys.argv[1:]

    if any(service not in SERVICES for service in services):
        raise RuntimeError(
            f"All services must be included in {list(SERVICES.keys())}. "
            f"Got: {services}"
        )

    for service in services:
        SERVICES[service]()

ci/env/setup_credentials.py (new file)

@@ -0,0 +1,35 @@
import os
import sys

import boto3

AWS_WANDB_SECRET_ARN = (
    "arn:aws:secretsmanager:us-west-2:029272617770:secret:oss-ci/wandb-key-V8UeE5"
)


def get_and_write_wandb_api_key(client):
    api_key = client.get_secret_value(SecretId=AWS_WANDB_SECRET_ARN)["SecretString"]
    with open(os.path.expanduser("~/.netrc"), "w") as fp:
        fp.write(f"machine api.wandb.ai\n" f" login user\n" f" password {api_key}\n")


SERVICES = {"wandb": get_and_write_wandb_api_key}


if __name__ == "__main__":
    if len(sys.argv) < 2:
        print(f"Usage: python {sys.argv[0]} <service1> [service2] ...")
        sys.exit(0)

    services = sys.argv[1:]

    if any(service not in SERVICES for service in services):
        raise RuntimeError(
            f"All services must be included in {list(SERVICES.keys())}. "
            f"Got: {services}"
        )

    client = boto3.client("secretsmanager", region_name="us-west-2")
    for service in services:
        SERVICES[service](client)
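
Not part of the PR, but a quick way to sanity-check the credentials this script writes (a minimal sketch assuming the wandb client picks up the key from `~/.netrc` when `WANDB_API_KEY` is unset):

import wandb

# wandb resolves the API key from ~/.netrc (machine api.wandb.ai) if no
# WANDB_API_KEY environment variable is set.
api = wandb.Api()
print(api.default_entity)  # raises an authentication error if the key is invalid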


@@ -145,6 +145,14 @@ py_test(
    deps = [":ml_lib"]
)

py_test(
    name = "upload_to_wandb",
    size = "medium",
    srcs = ["examples/upload_to_wandb.py"],
    tags = ["team:ml", "exclusive", "needs_credentials"],
    deps = [":ml_lib"]
)

py_test(
    name = "xgboost_example",
    size = "medium",
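
Because of the `needs_credentials` tag, the new example test is excluded from the regular CI run and only executes in the credentialed Buildkite step added above; locally it can be targeted directly (for example with `bazel test //python/ray/ml:upload_to_wandb`, assuming W&B credentials are configured).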


@@ -0,0 +1,49 @@
"""
In this example, we train a simple XGBoost model and log the training
results to Weights & Biases. We also save the resulting model checkpoints
as artifacts.
"""
import ray

from ray.ml import RunConfig
from ray.ml.result import Result
from ray.ml.train.integrations.xgboost import XGBoostTrainer
from ray.tune.integration.wandb import WandbLoggerCallback

from sklearn.datasets import load_breast_cancer


def get_train_dataset() -> ray.data.Dataset:
    """Return the "Breast cancer" dataset as a Ray dataset."""
    data_raw = load_breast_cancer(as_frame=True)
    df = data_raw["data"]
    df["target"] = data_raw["target"]
    return ray.data.from_pandas(df)


def train_model(train_dataset: ray.data.Dataset, wandb_project: str) -> Result:
    """Train a simple XGBoost model and return the result."""
    trainer = XGBoostTrainer(
        scaling_config={"num_workers": 2},
        params={"tree_method": "auto"},
        label_column="target",
        datasets={"train": train_dataset},
        num_boost_round=10,
        run_config=RunConfig(
            callbacks=[
                # This is the part needed to enable logging to Weights & Biases.
                # It assumes you've logged in before, e.g. with `wandb login`.
                WandbLoggerCallback(
                    project=wandb_project,
                    save_checkpoints=True,
                )
            ]
        ),
    )
    result = trainer.fit()
    return result


wandb_project = "ray_air_example"

train_dataset = get_train_dataset()
result = train_model(train_dataset=train_dataset, wandb_project=wandb_project)
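
Not part of the example file above, but a quick way to confirm that runs landed in the project (a minimal sketch reusing the same `wandb.Api()` calls as the cleanup script):

import wandb

api = wandb.Api()
# List the runs logged to the example's hardcoded project.
for run in api.runs("ray_air_example"):
    print(run.name, run.state)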


@@ -1,3 +1,4 @@
import enum
import os
import pickle
from collections.abc import Sequence

@@ -24,7 +25,6 @@ except ImportError:
    wandb = None

WANDB_ENV_VAR = "WANDB_API_KEY"
-_WANDB_QUEUE_END = (None,)
_VALID_TYPES = (Number, wandb.data_types.Video, wandb.data_types.Image)
_VALID_ITERABLE_TYPES = (wandb.data_types.Video, wandb.data_types.Image)
@@ -181,30 +181,53 @@ def _set_api_key(api_key_file: Optional[str] = None, api_key: Optional[str] = No
        )


class _QueueItem(enum.Enum):
    END = enum.auto()
    RESULT = enum.auto()
    CHECKPOINT = enum.auto()


class _WandbLoggingProcess(Process):
    """
    We need a `multiprocessing.Process` to allow multiple concurrent
    wandb logging instances locally.

    We use a queue for the driver to communicate with the logging process.

    The queue accepts the following items:

    - If it's a dict, it is assumed to be a result and will be logged using
      ``wandb.log()``
    - If it's a checkpoint object, it will be saved using ``wandb.log_artifact()``.
    """

    def __init__(
        self, queue: Queue, exclude: List[str], to_config: List[str], *args, **kwargs
    ):
        super(_WandbLoggingProcess, self).__init__()

        self.queue = queue
        self._exclude = set(exclude)
        self._to_config = set(to_config)
        self.args = args
        self.kwargs = kwargs

        self._trial_name = self.kwargs.get("name", "unknown")

    def run(self):
-       # Since we're running in a separate process already, use threads.
        wandb.require("service")
        wandb.init(*self.args, **self.kwargs)
        wandb.setup()
        while True:
-           result = self.queue.get()
-           if result == _WANDB_QUEUE_END:
            item_type, item_content = self.queue.get()
            if item_type == _QueueItem.END:
                break

-           log, config_update = self._handle_result(result)
            if item_type == _QueueItem.CHECKPOINT:
                self._handle_checkpoint(item_content)
                continue

            assert item_type == _QueueItem.RESULT
            log, config_update = self._handle_result(item_content)
            try:
                wandb.config.update(config_update, allow_val_change=True)
                wandb.log(log)
@@ -214,6 +237,11 @@ class _WandbLoggingProcess(Process):
                logger.warn("Failed to log result to w&b: {}".format(str(e)))
        wandb.finish()

    def _handle_checkpoint(self, checkpoint_path: str):
        artifact = wandb.Artifact(name=f"checkpoint_{self._trial_name}", type="model")
        artifact.add_dir(checkpoint_path)
        wandb.log_artifact(artifact)

    def _handle_result(self, result: Dict) -> Tuple[Dict, Dict]:
        config_update = result.get("config", {}).copy()
        log = {}
@@ -255,6 +283,8 @@ class WandbLoggerCallback(LoggerCallback):
            the ``results`` dict should be logged. This makes sense if
            parameters will change during training, e.g. with
            PopulationBasedTraining. Defaults to False.
        save_checkpoints: If ``True``, model checkpoints will be saved to
            Wandb as artifacts. Defaults to ``False``.
        **kwargs: The keyword arguments will be passed to ``wandb.init()``.

    Wandb's ``group``, ``run_id`` and ``run_name`` are automatically selected
@@ -308,7 +338,8 @@ class WandbLoggerCallback(LoggerCallback):
        api_key: Optional[str] = None,
        excludes: Optional[List[str]] = None,
        log_config: bool = False,
-       **kwargs
        save_checkpoints: bool = False,
        **kwargs,
    ):
        self.project = project
        self.group = group
@@ -316,18 +347,17 @@ class WandbLoggerCallback(LoggerCallback):
        self.api_key = api_key
        self.excludes = excludes or []
        self.log_config = log_config
        self.save_checkpoints = save_checkpoints
        self.kwargs = kwargs

        self._trial_processes: Dict["Trial", _WandbLoggingProcess] = {}
        self._trial_queues: Dict["Trial", Queue] = {}

-   def setup(self):
    def setup(self, *args, **kwargs):
        self.api_key_file = (
            os.path.expanduser(self.api_key_path) if self.api_key_path else None
        )
        _set_api_key(self.api_key_file, self.api_key)
-       wandb.require("service")
-       wandb.setup()

    def log_trial_start(self, trial: "Trial"):
        config = trial.config.copy()
@@ -373,7 +403,7 @@ class WandbLoggerCallback(LoggerCallback):
            queue=self._trial_queues[trial],
            exclude=exclude_results,
            to_config=self._config_results,
-           **wandb_init_kwargs
            **wandb_init_kwargs,
        )
        self._trial_processes[trial].start()
@@ -382,10 +412,16 @@ class WandbLoggerCallback(LoggerCallback):
            self.log_trial_start(trial)

        result = _clean_log(result)
-       self._trial_queues[trial].put(result)
        self._trial_queues[trial].put((_QueueItem.RESULT, result))

    def log_trial_save(self, trial: "Trial"):
        if self.save_checkpoints and trial.checkpoint:
            self._trial_queues[trial].put(
                (_QueueItem.CHECKPOINT, trial.checkpoint.value)
            )

    def log_trial_end(self, trial: "Trial", failed: bool = False):
-       self._trial_queues[trial].put(_WANDB_QUEUE_END)
        self._trial_queues[trial].put((_QueueItem.END, None))
        self._trial_processes[trial].join(timeout=10)

        del self._trial_queues[trial]
@@ -394,7 +430,7 @@ class WandbLoggerCallback(LoggerCallback):
    def __del__(self):
        for trial in self._trial_processes:
            if trial in self._trial_queues:
-               self._trial_queues[trial].put(_WANDB_QUEUE_END)
                self._trial_queues[trial].put((_QueueItem.END, None))
                del self._trial_queues[trial]
            self._trial_processes[trial].join(timeout=2)
            del self._trial_processes[trial]
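
A minimal driver-side sketch (not from this PR) of how the new flag is used with Ray Tune directly; `train_fn` and the config are made-up placeholders, and checkpoints are only uploaded if the trials actually write them:

from ray import tune
from ray.tune.integration.wandb import WandbLoggerCallback


def train_fn(config):
    # Toy objective so the example is self-contained; real trainables would
    # also save checkpoints for the callback to upload.
    tune.report(loss=config["x"] ** 2)


tune.run(
    train_fn,
    config={"x": tune.uniform(-1.0, 1.0)},
    callbacks=[
        WandbLoggerCallback(
            project="ray_air_example",
            # New in this PR: upload trial checkpoints to W&B as artifacts.
            save_checkpoints=True,
        )
    ],
)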


@@ -11,11 +11,11 @@ from ray.tune.function_runner import wrap_function
from ray.tune.integration.wandb import (
    WandbLoggerCallback,
    _WandbLoggingProcess,
-   _WANDB_QUEUE_END,
    WandbLogger,
    WANDB_ENV_VAR,
    WandbTrainableMixin,
    wandb_mixin,
    _QueueItem,
)
from ray.tune.result import TRIAL_INFO
from ray.tune.trial import TrialInfo
@@ -52,10 +52,10 @@ class _MockWandbLoggingProcess(_WandbLoggingProcess):
    def run(self):
        while True:
-           result = self.queue.get()
-           if result == _WANDB_QUEUE_END:
            result_type, result_content = self.queue.get()
            if result_type == _QueueItem.END:
                break
-           log, config_update = self._handle_result(result)
            log, config_update = self._handle_result(result_content)
            self.config_updates.put(config_update)
            self.logs.put(log)


@@ -83,7 +83,7 @@ pytest-lazy-fixture
pytest-timeout
pytest-virtualenv
redis >= 3.5.0, < 4.0.0
-scikit-learn==0.22.2
scikit-learn==0.24.2
testfixtures
werkzeug
xlrd