Mirror of https://github.com/vale981/ray, synced 2025-03-06 02:21:39 -05:00
[air] Example to track runs with Weights & Biases (#24459)
This PR:

- adds an example on how to run Ray Train and log results to Weights & Biases
- adds functionality to the W&B plugin to store checkpoints
- fixes a bug introduced in #24017
- adds a CI utility script to set up credentials
- adds a CI utility script to remove test state from external services

cc @simon-mo
Commit 5d9bf4234a (parent fee35444ab)
8 changed files with 184 additions and 19 deletions
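For orientation, here is a minimal sketch of the user-facing pattern this change enables: attaching WandbLoggerCallback with the new save_checkpoints option so that trial checkpoints are uploaded to W&B as artifacts. The objective function and project name below are placeholders, not part of this PR; the complete, runnable example added by the PR is python/ray/ml/examples/upload_to_wandb.py further down.

# Usage sketch only: `my_objective` and the project name are placeholders.
from ray import tune
from ray.tune.integration.wandb import WandbLoggerCallback


def my_objective(config):
    tune.report(loss=config["x"] ** 2)


tune.run(
    my_objective,
    config={"x": tune.grid_search([1, 2, 3])},
    callbacks=[
        # save_checkpoints=True is the new option added in this PR.
        WandbLoggerCallback(project="my_wandb_project", save_checkpoints=True),
    ],
)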
@@ -3,7 +3,12 @@
   commands:
     - cleanup() { if [ "${BUILDKITE_PULL_REQUEST}" = "false" ]; then ./ci/build/upload_build_info.sh; fi }; trap cleanup EXIT
     - DATA_PROCESSING_TESTING=1 INSTALL_HOROVOD=1 ./ci/env/install-dependencies.sh
-    - bazel test --config=ci $(./ci/run/bazel_export_options) --build_tests_only --test_tag_filters=-gpu python/ray/ml/...
+    - bazel test --config=ci $(./ci/run/bazel_export_options) --build_tests_only --test_tag_filters=-gpu,-needs_credentials python/ray/ml/...
+    # Only setup credentials in branch builds
+    - if [ "$BUILDKITE_PULL_REQUEST" != "false" ]; then exit 0; fi
+    - python ./ci/env/setup_credentials.py wandb
+    - bazel test --config=ci $(./ci/run/bazel_export_options) --build_tests_only --test_tag_filters=-gpu,needs_credentials python/ray/ml/...
+    - python ./ci/env/cleanup_test_state.py wandb

 - label: ":brain: RLlib: Learning discr. actions TF2-static-graph"
   conditions: ["RAY_CI_RLLIB_AFFECTED"]
ci/env/cleanup_test_state.py (new file, 32 lines)
@@ -0,0 +1,32 @@
import sys


def clear_wandb_project():
    import wandb

    # This is hardcoded in the `ray/ml/examples/upload_to_wandb.py` example
    wandb_project = "ray_air_example"

    api = wandb.Api()
    for run in api.runs(wandb_project):
        run.delete()


SERVICES = {"wandb": clear_wandb_project}


if __name__ == "__main__":
    if len(sys.argv) < 2:
        print(f"Usage: python {sys.argv[0]} <service1> [service2] ...")
        sys.exit(0)

    services = sys.argv[1:]

    if any(service not in SERVICES for service in services):
        raise RuntimeError(
            f"All services must be included in {list(SERVICES.keys())}. "
            f"Got: {services}"
        )

    for service in services:
        SERVICES[service]()
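The script keeps a small registry mapping service names to cleanup functions, so additional external services could be wired in later. A hypothetical extension (the service name and cleanup logic below are illustrative only, not part of this PR) would follow the same pattern:

# Hypothetical extension of the SERVICES registry above; illustrative only.
def clear_some_other_service_state():
    # Delete whatever state the credential-gated tests left behind.
    pass


SERVICES["some_other_service"] = clear_some_other_service_state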
ci/env/setup_credentials.py (new file, 35 lines)
@@ -0,0 +1,35 @@
import os
import sys

import boto3

AWS_WANDB_SECRET_ARN = (
    "arn:aws:secretsmanager:us-west-2:029272617770:secret:oss-ci/wandb-key-V8UeE5"
)


def get_and_write_wandb_api_key(client):
    api_key = client.get_secret_value(SecretId=AWS_WANDB_SECRET_ARN)["SecretString"]
    with open(os.path.expanduser("~/.netrc"), "w") as fp:
        fp.write(f"machine api.wandb.ai\n" f" login user\n" f" password {api_key}\n")


SERVICES = {"wandb": get_and_write_wandb_api_key}


if __name__ == "__main__":
    if len(sys.argv) < 2:
        print(f"Usage: python {sys.argv[0]} <service1> [service2] ...")
        sys.exit(0)

    services = sys.argv[1:]

    if any(service not in SERVICES for service in services):
        raise RuntimeError(
            f"All services must be included in {list(SERVICES.keys())}. "
            f"Got: {services}"
        )

    client = boto3.client("secretsmanager", region_name="us-west-2")
    for service in services:
        SERVICES[service](client)
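The script writes the key in netrc format, which the wandb client reads for the api.wandb.ai machine entry. As a side note (not part of the diff), the entry it produces can be read back with Python's standard-library netrc module:

# Sanity-check sketch: read back the entry written above. Standard library only.
import netrc
import os

auth = netrc.netrc(os.path.expanduser("~/.netrc")).authenticators("api.wandb.ai")
if auth is not None:
    login, _account, password = auth
    print(f"Found W&B credentials for login {login!r}")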
@@ -145,6 +145,14 @@ py_test(
     deps = [":ml_lib"]
 )

+py_test(
+    name = "upload_to_wandb",
+    size = "medium",
+    srcs = ["examples/upload_to_wandb.py"],
+    tags = ["team:ml", "exclusive", "needs_credentials"],
+    deps = [":ml_lib"]
+)
+
 py_test(
     name = "xgboost_example",
     size = "medium",
python/ray/ml/examples/upload_to_wandb.py (new file, 49 lines)
@@ -0,0 +1,49 @@
"""
In this example, we train a simple XGBoost model and log the training
results to Weights & Biases. We also save the resulting model checkpoints
as artifacts.
"""
import ray

from ray.ml import RunConfig
from ray.ml.result import Result
from ray.ml.train.integrations.xgboost import XGBoostTrainer
from ray.tune.integration.wandb import WandbLoggerCallback
from sklearn.datasets import load_breast_cancer


def get_train_dataset() -> ray.data.Dataset:
    """Return the "Breast cancer" dataset as a Ray dataset."""
    data_raw = load_breast_cancer(as_frame=True)
    df = data_raw["data"]
    df["target"] = data_raw["target"]
    return ray.data.from_pandas(df)


def train_model(train_dataset: ray.data.Dataset, wandb_project: str) -> Result:
    """Train a simple XGBoost model and return the result."""
    trainer = XGBoostTrainer(
        scaling_config={"num_workers": 2},
        params={"tree_method": "auto"},
        label_column="target",
        datasets={"train": train_dataset},
        num_boost_round=10,
        run_config=RunConfig(
            callbacks=[
                # This is the part needed to enable logging to Weights & Biases.
                # It assumes you've logged in before, e.g. with `wandb login`.
                WandbLoggerCallback(
                    project=wandb_project,
                    save_checkpoints=True,
                )
            ]
        ),
    )
    result = trainer.fit()
    return result


wandb_project = "ray_air_example"

train_dataset = get_train_dataset()
result = train_model(train_dataset=train_dataset, wandb_project=wandb_project)
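Once the example has run, the uploaded checkpoint can be pulled back out of W&B. This is a follow-up sketch, not part of the diff: the entity and trial name are placeholders to be read off the W&B UI, and the artifact name follows the checkpoint_<trial name> convention used by the logging process (see the _handle_checkpoint hunk below).

# Retrieval sketch; "my-entity" and "<trial_name>" are placeholders.
import wandb

api = wandb.Api()
artifact = api.artifact("my-entity/ray_air_example/checkpoint_<trial_name>:latest")
local_dir = artifact.download()
print(f"Checkpoint downloaded to {local_dir}")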
@@ -1,3 +1,4 @@
+import enum
 import os
 import pickle
 from collections.abc import Sequence
@@ -24,7 +25,6 @@ except ImportError:
     wandb = None

 WANDB_ENV_VAR = "WANDB_API_KEY"
-_WANDB_QUEUE_END = (None,)
 _VALID_TYPES = (Number, wandb.data_types.Video, wandb.data_types.Image)
 _VALID_ITERABLE_TYPES = (wandb.data_types.Video, wandb.data_types.Image)

@@ -181,30 +181,53 @@ def _set_api_key(api_key_file: Optional[str] = None, api_key: Optional[str] = No
         )


+class _QueueItem(enum.Enum):
+    END = enum.auto()
+    RESULT = enum.auto()
+    CHECKPOINT = enum.auto()
+
+
 class _WandbLoggingProcess(Process):
     """
     We need a `multiprocessing.Process` to allow multiple concurrent
     wandb logging instances locally.
+
+    We use a queue for the driver to communicate with the logging process.
+    The queue accepts the following items:
+
+    - If it's a dict, it is assumed to be a result and will be logged using
+      ``wandb.log()``
+    - If it's a checkpoint object, it will be saved using ``wandb.log_artifact()``.
     """

     def __init__(
         self, queue: Queue, exclude: List[str], to_config: List[str], *args, **kwargs
     ):
         super(_WandbLoggingProcess, self).__init__()
+
         self.queue = queue
         self._exclude = set(exclude)
         self._to_config = set(to_config)
         self.args = args
         self.kwargs = kwargs

+        self._trial_name = self.kwargs.get("name", "unknown")
+
     def run(self):
-        # Since we're running in a separate process already, use threads.
+        wandb.require("service")
         wandb.init(*self.args, **self.kwargs)
+        wandb.setup()
         while True:
-            result = self.queue.get()
-            if result == _WANDB_QUEUE_END:
+            item_type, item_content = self.queue.get()
+            if item_type == _QueueItem.END:
                 break
-            log, config_update = self._handle_result(result)
+
+            if item_type == _QueueItem.CHECKPOINT:
+                self._handle_checkpoint(item_content)
+                continue
+
+            assert item_type == _QueueItem.RESULT
+            log, config_update = self._handle_result(item_content)
             try:
                 wandb.config.update(config_update, allow_val_change=True)
                 wandb.log(log)
@@ -214,6 +237,11 @@ class _WandbLoggingProcess(Process):
             logger.warn("Failed to log result to w&b: {}".format(str(e)))
         wandb.finish()

+    def _handle_checkpoint(self, checkpoint_path: str):
+        artifact = wandb.Artifact(name=f"checkpoint_{self._trial_name}", type="model")
+        artifact.add_dir(checkpoint_path)
+        wandb.log_artifact(artifact)
+
     def _handle_result(self, result: Dict) -> Tuple[Dict, Dict]:
         config_update = result.get("config", {}).copy()
         log = {}
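Taken together, the hunks above define a small message protocol between the Tune driver and the per-trial logging process: the driver puts (item_type, payload) tuples on a multiprocessing queue, and the process dispatches on the _QueueItem enum. The following stripped-down sketch (not Ray code; names and payloads are illustrative only) shows the same pattern in isolation:

# Standalone illustration of the queue protocol used above (illustrative only).
import enum
from multiprocessing import Process, Queue


class QueueItem(enum.Enum):
    END = enum.auto()
    RESULT = enum.auto()
    CHECKPOINT = enum.auto()


def logging_worker(queue: Queue) -> None:
    while True:
        item_type, item_content = queue.get()
        if item_type == QueueItem.END:
            break
        if item_type == QueueItem.CHECKPOINT:
            # In the real integration this becomes a wandb.log_artifact() call.
            print(f"would upload checkpoint from {item_content}")
            continue
        # In the real integration this becomes wandb.log()/wandb.config.update().
        print(f"would log result {item_content}")


if __name__ == "__main__":
    q = Queue()
    worker = Process(target=logging_worker, args=(q,))
    worker.start()
    q.put((QueueItem.RESULT, {"loss": 0.1}))
    q.put((QueueItem.CHECKPOINT, "/tmp/checkpoint_000001"))
    q.put((QueueItem.END, None))
    worker.join()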
@@ -255,6 +283,8 @@ class WandbLoggerCallback(LoggerCallback):
             the ``results`` dict should be logged. This makes sense if
             parameters will change during training, e.g. with
             PopulationBasedTraining. Defaults to False.
+        save_checkpoints: If ``True``, model checkpoints will be saved to
+            Wandb as artifacts. Defaults to ``False``.
         **kwargs: The keyword arguments will be passed to ``wandb.init()``.

     Wandb's ``group``, ``run_id`` and ``run_name`` are automatically selected
@@ -308,7 +338,8 @@ class WandbLoggerCallback(LoggerCallback):
         api_key: Optional[str] = None,
         excludes: Optional[List[str]] = None,
         log_config: bool = False,
-        **kwargs
+        save_checkpoints: bool = False,
+        **kwargs,
     ):
         self.project = project
         self.group = group
@@ -316,18 +347,17 @@ class WandbLoggerCallback(LoggerCallback):
         self.api_key = api_key
         self.excludes = excludes or []
         self.log_config = log_config
+        self.save_checkpoints = save_checkpoints
         self.kwargs = kwargs

         self._trial_processes: Dict["Trial", _WandbLoggingProcess] = {}
         self._trial_queues: Dict["Trial", Queue] = {}

-    def setup(self):
+    def setup(self, *args, **kwargs):
         self.api_key_file = (
             os.path.expanduser(self.api_key_path) if self.api_key_path else None
         )
         _set_api_key(self.api_key_file, self.api_key)
-        wandb.require("service")
-        wandb.setup()

     def log_trial_start(self, trial: "Trial"):
         config = trial.config.copy()
@@ -373,7 +403,7 @@ class WandbLoggerCallback(LoggerCallback):
             queue=self._trial_queues[trial],
             exclude=exclude_results,
             to_config=self._config_results,
-            **wandb_init_kwargs
+            **wandb_init_kwargs,
         )
         self._trial_processes[trial].start()

@@ -382,10 +412,16 @@ class WandbLoggerCallback(LoggerCallback):
             self.log_trial_start(trial)

         result = _clean_log(result)
-        self._trial_queues[trial].put(result)
+        self._trial_queues[trial].put((_QueueItem.RESULT, result))
+
+    def log_trial_save(self, trial: "Trial"):
+        if self.save_checkpoints and trial.checkpoint:
+            self._trial_queues[trial].put(
+                (_QueueItem.CHECKPOINT, trial.checkpoint.value)
+            )

     def log_trial_end(self, trial: "Trial", failed: bool = False):
-        self._trial_queues[trial].put(_WANDB_QUEUE_END)
+        self._trial_queues[trial].put((_QueueItem.END, None))
         self._trial_processes[trial].join(timeout=10)

         del self._trial_queues[trial]
@@ -394,7 +430,7 @@ class WandbLoggerCallback(LoggerCallback):
     def __del__(self):
         for trial in self._trial_processes:
             if trial in self._trial_queues:
-                self._trial_queues[trial].put(_WANDB_QUEUE_END)
+                self._trial_queues[trial].put((_QueueItem.END, None))
                 del self._trial_queues[trial]
             self._trial_processes[trial].join(timeout=2)
             del self._trial_processes[trial]
@@ -11,11 +11,11 @@ from ray.tune.function_runner import wrap_function
 from ray.tune.integration.wandb import (
     WandbLoggerCallback,
     _WandbLoggingProcess,
-    _WANDB_QUEUE_END,
     WandbLogger,
     WANDB_ENV_VAR,
     WandbTrainableMixin,
     wandb_mixin,
+    _QueueItem,
 )
 from ray.tune.result import TRIAL_INFO
 from ray.tune.trial import TrialInfo
@@ -52,10 +52,10 @@ class _MockWandbLoggingProcess(_WandbLoggingProcess):

     def run(self):
         while True:
-            result = self.queue.get()
-            if result == _WANDB_QUEUE_END:
+            result_type, result_content = self.queue.get()
+            if result_type == _QueueItem.END:
                 break
-            log, config_update = self._handle_result(result)
+            log, config_update = self._handle_result(result_content)
             self.config_updates.put(config_update)
             self.logs.put(log)

@@ -83,7 +83,7 @@ pytest-lazy-fixture
 pytest-timeout
 pytest-virtualenv
 redis >= 3.5.0, < 4.0.0
-scikit-learn==0.22.2
+scikit-learn==0.24.2
 testfixtures
 werkzeug
 xlrd