Mirror of https://github.com/vale981/ray (synced 2025-03-06 02:21:39 -05:00)
[air] Example to track runs with Weights & Biases (#24459)

This PR:
- adds an example showing how to run Ray Train and log results to Weights & Biases
- adds functionality to the W&B plugin to store checkpoints
- fixes a bug introduced in #24017
- adds a CI utility script to set up credentials
- adds a CI utility script to remove test state from external services

cc @simon-mo
parent fee35444ab
commit 5d9bf4234a

8 changed files with 184 additions and 19 deletions
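The user-facing piece of the change is the new `save_checkpoints` option on `WandbLoggerCallback`. A minimal sketch of how it is wired into a run, mirroring the `upload_to_wandb.py` example added below (the project name here is just a placeholder):

    from ray.ml import RunConfig
    from ray.tune.integration.wandb import WandbLoggerCallback

    run_config = RunConfig(
        callbacks=[
            # Log metrics to W&B and, with save_checkpoints=True, upload
            # model checkpoints as W&B artifacts.
            WandbLoggerCallback(project="my-project", save_checkpoints=True)
        ]
    )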
@@ -3,7 +3,12 @@
   commands:
     - cleanup() { if [ "${BUILDKITE_PULL_REQUEST}" = "false" ]; then ./ci/build/upload_build_info.sh; fi }; trap cleanup EXIT
     - DATA_PROCESSING_TESTING=1 INSTALL_HOROVOD=1 ./ci/env/install-dependencies.sh
-    - bazel test --config=ci $(./ci/run/bazel_export_options) --build_tests_only --test_tag_filters=-gpu python/ray/ml/...
+    - bazel test --config=ci $(./ci/run/bazel_export_options) --build_tests_only --test_tag_filters=-gpu,-needs_credentials python/ray/ml/...
+    # Only setup credentials in branch builds
+    - if [ "$BUILDKITE_PULL_REQUEST" != "false" ]; then exit 0; fi
+    - python ./ci/env/setup_credentials.py wandb
+    - bazel test --config=ci $(./ci/run/bazel_export_options) --build_tests_only --test_tag_filters=-gpu,needs_credentials python/ray/ml/...
+    - python ./ci/env/cleanup_test_state.py wandb

 - label: ":brain: RLlib: Learning discr. actions TF2-static-graph"
   conditions: ["RAY_CI_RLLIB_AFFECTED"]
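In effect, pull-request builds now run only the ML tests that need no external credentials (`--test_tag_filters=-gpu,-needs_credentials`). Branch builds additionally fetch the W&B key via `setup_credentials.py`, run the credential-gated tests (`--test_tag_filters=-gpu,needs_credentials`, matching the `needs_credentials` tag on the new `upload_to_wandb` target in the BUILD file below), and finally clear the external state again with `cleanup_test_state.py`.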
ci/env/cleanup_test_state.py (new file, 32 lines)

import sys


def clear_wandb_project():
    import wandb

    # This is hardcoded in the `ray/ml/examples/upload_to_wandb.py` example
    wandb_project = "ray_air_example"

    api = wandb.Api()
    for run in api.runs(wandb_project):
        run.delete()


SERVICES = {"wandb": clear_wandb_project}


if __name__ == "__main__":
    if len(sys.argv) < 2:
        print(f"Usage: python {sys.argv[0]} <service1> [service2] ...")
        sys.exit(0)

    services = sys.argv[1:]

    if any(service not in SERVICES for service in services):
        raise RuntimeError(
            f"All services must be included in {list(SERVICES.keys())}. "
            f"Got: {services}"
        )

    for service in services:
        SERVICES[service]()
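The script above backs the `python ./ci/env/cleanup_test_state.py wandb` step added to the pipeline. A small sketch of the same dispatch done programmatically (assuming the module is importable, e.g. when running from `ci/env/`; `wandb` is the only service registered by this PR):

    from cleanup_test_state import SERVICES

    # Equivalent to `python ci/env/cleanup_test_state.py wandb`: deletes every run
    # in the hard-coded "ray_air_example" project so credentialed tests start clean.
    SERVICES["wandb"]()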
ci/env/setup_credentials.py (new file, 35 lines)

import os
import sys

import boto3

AWS_WANDB_SECRET_ARN = (
    "arn:aws:secretsmanager:us-west-2:029272617770:secret:oss-ci/wandb-key-V8UeE5"
)


def get_and_write_wandb_api_key(client):
    api_key = client.get_secret_value(SecretId=AWS_WANDB_SECRET_ARN)["SecretString"]
    with open(os.path.expanduser("~/.netrc"), "w") as fp:
        fp.write(f"machine api.wandb.ai\n" f" login user\n" f" password {api_key}\n")


SERVICES = {"wandb": get_and_write_wandb_api_key}


if __name__ == "__main__":
    if len(sys.argv) < 2:
        print(f"Usage: python {sys.argv[0]} <service1> [service2] ...")
        sys.exit(0)

    services = sys.argv[1:]

    if any(service not in SERVICES for service in services):
        raise RuntimeError(
            f"All services must be included in {list(SERVICES.keys())}. "
            f"Got: {services}"
        )

    client = boto3.client("secretsmanager", region_name="us-west-2")
    for service in services:
        SERVICES[service](client)
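For reference, the `fp.write(...)` call above writes a `~/.netrc` entry of roughly this shape, which is how the wandb client authenticates non-interactively (the password value comes from the AWS Secrets Manager secret; `user` is the literal login name the script writes):

    machine api.wandb.ai
     login user
     password <api_key>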
@@ -145,6 +145,14 @@ py_test(
     deps = [":ml_lib"]
 )

+py_test(
+    name = "upload_to_wandb",
+    size = "medium",
+    srcs = ["examples/upload_to_wandb.py"],
+    tags = ["team:ml", "exclusive", "needs_credentials"],
+    deps = [":ml_lib"]
+)
+
 py_test(
     name = "xgboost_example",
     size = "medium",
python/ray/ml/examples/upload_to_wandb.py (new file, 49 lines)

"""
In this example, we train a simple XGBoost model and log the training
results to Weights & Biases. We also save the resulting model checkpoints
as artifacts.
"""
import ray

from ray.ml import RunConfig
from ray.ml.result import Result
from ray.ml.train.integrations.xgboost import XGBoostTrainer
from ray.tune.integration.wandb import WandbLoggerCallback
from sklearn.datasets import load_breast_cancer


def get_train_dataset() -> ray.data.Dataset:
    """Return the "Breast cancer" dataset as a Ray dataset."""
    data_raw = load_breast_cancer(as_frame=True)
    df = data_raw["data"]
    df["target"] = data_raw["target"]
    return ray.data.from_pandas(df)


def train_model(train_dataset: ray.data.Dataset, wandb_project: str) -> Result:
    """Train a simple XGBoost model and return the result."""
    trainer = XGBoostTrainer(
        scaling_config={"num_workers": 2},
        params={"tree_method": "auto"},
        label_column="target",
        datasets={"train": train_dataset},
        num_boost_round=10,
        run_config=RunConfig(
            callbacks=[
                # This is the part needed to enable logging to Weights & Biases.
                # It assumes you've logged in before, e.g. with `wandb login`.
                WandbLoggerCallback(
                    project=wandb_project,
                    save_checkpoints=True,
                )
            ]
        ),
    )
    result = trainer.fit()
    return result


wandb_project = "ray_air_example"

train_dataset = get_train_dataset()
result = train_model(train_dataset=train_dataset, wandb_project=wandb_project)
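Running the example end to end needs W&B credentials: either log in beforehand with `wandb login` (as the in-code comment notes) or supply the key through the `WANDB_API_KEY` environment variable that the integration checks (see `WANDB_ENV_VAR` in the plugin changes below). In CI, the credentials come from `setup_credentials.py`, and the test only runs in branch builds via the `needs_credentials` tag added in the BUILD file above.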
@@ -1,3 +1,4 @@
+import enum
 import os
 import pickle
 from collections.abc import Sequence
@@ -24,7 +25,6 @@ except ImportError:
     wandb = None

 WANDB_ENV_VAR = "WANDB_API_KEY"
-_WANDB_QUEUE_END = (None,)
 _VALID_TYPES = (Number, wandb.data_types.Video, wandb.data_types.Image)
 _VALID_ITERABLE_TYPES = (wandb.data_types.Video, wandb.data_types.Image)

@@ -181,30 +181,53 @@ def _set_api_key(api_key_file: Optional[str] = None, api_key: Optional[str] = No
         )


+class _QueueItem(enum.Enum):
+    END = enum.auto()
+    RESULT = enum.auto()
+    CHECKPOINT = enum.auto()
+
+
 class _WandbLoggingProcess(Process):
     """
     We need a `multiprocessing.Process` to allow multiple concurrent
     wandb logging instances locally.
+
+    We use a queue for the driver to communicate with the logging process.
+    The queue accepts the following items:
+
+    - If it's a dict, it is assumed to be a result and will be logged using
+      ``wandb.log()``
+    - If it's a checkpoint object, it will be saved using ``wandb.log_artifact()``.
     """

     def __init__(
         self, queue: Queue, exclude: List[str], to_config: List[str], *args, **kwargs
     ):
         super(_WandbLoggingProcess, self).__init__()

         self.queue = queue
         self._exclude = set(exclude)
         self._to_config = set(to_config)
         self.args = args
         self.kwargs = kwargs
+
+        self._trial_name = self.kwargs.get("name", "unknown")

     def run(self):
+        # Since we're running in a separate process already, use threads.
+        wandb.require("service")
         wandb.init(*self.args, **self.kwargs)
+        wandb.setup()
         while True:
-            result = self.queue.get()
-            if result == _WANDB_QUEUE_END:
+            item_type, item_content = self.queue.get()
+            if item_type == _QueueItem.END:
                 break
-            log, config_update = self._handle_result(result)
+
+            if item_type == _QueueItem.CHECKPOINT:
+                self._handle_checkpoint(item_content)
+                continue
+
+            assert item_type == _QueueItem.RESULT
+            log, config_update = self._handle_result(item_content)
             try:
                 wandb.config.update(config_update, allow_val_change=True)
                 wandb.log(log)
@@ -214,6 +237,11 @@ class _WandbLoggingProcess(Process):
                 logger.warn("Failed to log result to w&b: {}".format(str(e)))
         wandb.finish()

+    def _handle_checkpoint(self, checkpoint_path: str):
+        artifact = wandb.Artifact(name=f"checkpoint_{self._trial_name}", type="model")
+        artifact.add_dir(checkpoint_path)
+        wandb.log_artifact(artifact)
+
     def _handle_result(self, result: Dict) -> Tuple[Dict, Dict]:
         config_update = result.get("config", {}).copy()
         log = {}
@@ -255,6 +283,8 @@ class WandbLoggerCallback(LoggerCallback):
             the ``results`` dict should be logged. This makes sense if
             parameters will change during training, e.g. with
             PopulationBasedTraining. Defaults to False.
+        save_checkpoints: If ``True``, model checkpoints will be saved to
+            Wandb as artifacts. Defaults to ``False``.
         **kwargs: The keyword arguments will be pased to ``wandb.init()``.

     Wandb's ``group``, ``run_id`` and ``run_name`` are automatically selected
@@ -308,7 +338,8 @@ class WandbLoggerCallback(LoggerCallback):
         api_key: Optional[str] = None,
         excludes: Optional[List[str]] = None,
         log_config: bool = False,
-        **kwargs
+        save_checkpoints: bool = False,
+        **kwargs,
     ):
         self.project = project
         self.group = group
@@ -316,18 +347,17 @@ class WandbLoggerCallback(LoggerCallback):
         self.api_key = api_key
         self.excludes = excludes or []
         self.log_config = log_config
+        self.save_checkpoints = save_checkpoints
         self.kwargs = kwargs

         self._trial_processes: Dict["Trial", _WandbLoggingProcess] = {}
         self._trial_queues: Dict["Trial", Queue] = {}

-    def setup(self):
+    def setup(self, *args, **kwargs):
         self.api_key_file = (
             os.path.expanduser(self.api_key_path) if self.api_key_path else None
         )
         _set_api_key(self.api_key_file, self.api_key)
-        wandb.require("service")
-        wandb.setup()

     def log_trial_start(self, trial: "Trial"):
         config = trial.config.copy()
@@ -373,7 +403,7 @@ class WandbLoggerCallback(LoggerCallback):
             queue=self._trial_queues[trial],
             exclude=exclude_results,
             to_config=self._config_results,
-            **wandb_init_kwargs
+            **wandb_init_kwargs,
         )
         self._trial_processes[trial].start()

@@ -382,10 +412,16 @@ class WandbLoggerCallback(LoggerCallback):
             self.log_trial_start(trial)

         result = _clean_log(result)
-        self._trial_queues[trial].put(result)
+        self._trial_queues[trial].put((_QueueItem.RESULT, result))
+
+    def log_trial_save(self, trial: "Trial"):
+        if self.save_checkpoints and trial.checkpoint:
+            self._trial_queues[trial].put(
+                (_QueueItem.CHECKPOINT, trial.checkpoint.value)
+            )

     def log_trial_end(self, trial: "Trial", failed: bool = False):
-        self._trial_queues[trial].put(_WANDB_QUEUE_END)
+        self._trial_queues[trial].put((_QueueItem.END, None))
         self._trial_processes[trial].join(timeout=10)

         del self._trial_queues[trial]
@@ -394,7 +430,7 @@ class WandbLoggerCallback(LoggerCallback):
     def __del__(self):
         for trial in self._trial_processes:
             if trial in self._trial_queues:
-                self._trial_queues[trial].put(_WANDB_QUEUE_END)
+                self._trial_queues[trial].put((_QueueItem.END, None))
                 del self._trial_queues[trial]
             self._trial_processes[trial].join(timeout=2)
             del self._trial_processes[trial]
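Taken together, the changes above replace the old `_WANDB_QUEUE_END` sentinel with typed queue items. A minimal sketch of the resulting driver-to-logging-process protocol, using only names introduced in this diff (`queue` stands in for the per-trial `Queue` the callback creates; the metric values and checkpoint path are illustrative):

    from multiprocessing import Queue

    from ray.tune.integration.wandb import _QueueItem

    queue = Queue()

    # A result dict: consumed by _WandbLoggingProcess via wandb.log()/wandb.config.update().
    queue.put((_QueueItem.RESULT, {"training_iteration": 1, "loss": 0.5}))

    # A checkpoint directory: uploaded as a W&B artifact via wandb.log_artifact().
    queue.put((_QueueItem.CHECKPOINT, "/tmp/checkpoint_000001"))

    # Sentinel: tells the logging process to stop and call wandb.finish().
    queue.put((_QueueItem.END, None))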
@@ -11,11 +11,11 @@ from ray.tune.function_runner import wrap_function
 from ray.tune.integration.wandb import (
     WandbLoggerCallback,
     _WandbLoggingProcess,
-    _WANDB_QUEUE_END,
     WandbLogger,
     WANDB_ENV_VAR,
     WandbTrainableMixin,
     wandb_mixin,
+    _QueueItem,
 )
 from ray.tune.result import TRIAL_INFO
 from ray.tune.trial import TrialInfo
@@ -52,10 +52,10 @@ class _MockWandbLoggingProcess(_WandbLoggingProcess):

     def run(self):
         while True:
-            result = self.queue.get()
-            if result == _WANDB_QUEUE_END:
+            result_type, result_content = self.queue.get()
+            if result_type == _QueueItem.END:
                 break
-            log, config_update = self._handle_result(result)
+            log, config_update = self._handle_result(result_content)
             self.config_updates.put(config_update)
             self.logs.put(log)

@@ -83,7 +83,7 @@ pytest-lazy-fixture
 pytest-timeout
 pytest-virtualenv
 redis >= 3.5.0, < 4.0.0
-scikit-learn==0.22.2
+scikit-learn==0.24.2
 testfixtures
 werkzeug
 xlrd