diff --git a/.buildkite/pipeline.ml.yml b/.buildkite/pipeline.ml.yml index 5600f5ff0..864219f53 100644 --- a/.buildkite/pipeline.ml.yml +++ b/.buildkite/pipeline.ml.yml @@ -3,7 +3,12 @@ commands: - cleanup() { if [ "${BUILDKITE_PULL_REQUEST}" = "false" ]; then ./ci/build/upload_build_info.sh; fi }; trap cleanup EXIT - DATA_PROCESSING_TESTING=1 INSTALL_HOROVOD=1 ./ci/env/install-dependencies.sh - - bazel test --config=ci $(./ci/run/bazel_export_options) --build_tests_only --test_tag_filters=-gpu python/ray/ml/... + - bazel test --config=ci $(./ci/run/bazel_export_options) --build_tests_only --test_tag_filters=-gpu,-needs_credentials python/ray/ml/... + # Only setup credentials in branch builds + - if [ "$BUILDKITE_PULL_REQUEST" != "false" ]; then exit 0; fi + - python ./ci/env/setup_credentials.py wandb + - bazel test --config=ci $(./ci/run/bazel_export_options) --build_tests_only --test_tag_filters=-gpu,needs_credentials python/ray/ml/... + - python ./ci/env/cleanup_test_state.py wandb - label: ":brain: RLlib: Learning discr. actions TF2-static-graph" conditions: ["RAY_CI_RLLIB_AFFECTED"] diff --git a/ci/env/cleanup_test_state.py b/ci/env/cleanup_test_state.py new file mode 100644 index 000000000..0fa651f66 --- /dev/null +++ b/ci/env/cleanup_test_state.py @@ -0,0 +1,32 @@ +import sys + + +def clear_wandb_project(): + import wandb + + # This is hardcoded in the `ray/ml/examples/upload_to_wandb.py` example + wandb_project = "ray_air_example" + + api = wandb.Api() + for run in api.runs(wandb_project): + run.delete() + + +SERVICES = {"wandb": clear_wandb_project} + + +if __name__ == "__main__": + if len(sys.argv) < 2: + print(f"Usage: python {sys.argv[0]} [service2] ...") + sys.exit(0) + + services = sys.argv[1:] + + if any(service not in SERVICES for service in services): + raise RuntimeError( + f"All services must be included in {list(SERVICES.keys())}. " + f"Got: {services}" + ) + + for service in services: + SERVICES[service]() diff --git a/ci/env/setup_credentials.py b/ci/env/setup_credentials.py new file mode 100644 index 000000000..aea87df28 --- /dev/null +++ b/ci/env/setup_credentials.py @@ -0,0 +1,35 @@ +import os +import sys + +import boto3 + +AWS_WANDB_SECRET_ARN = ( + "arn:aws:secretsmanager:us-west-2:029272617770:secret:oss-ci/wandb-key-V8UeE5" +) + + +def get_and_write_wandb_api_key(client): + api_key = client.get_secret_value(SecretId=AWS_WANDB_SECRET_ARN)["SecretString"] + with open(os.path.expanduser("~/.netrc"), "w") as fp: + fp.write(f"machine api.wandb.ai\n" f" login user\n" f" password {api_key}\n") + + +SERVICES = {"wandb": get_and_write_wandb_api_key} + + +if __name__ == "__main__": + if len(sys.argv) < 2: + print(f"Usage: python {sys.argv[0]} [service2] ...") + sys.exit(0) + + services = sys.argv[1:] + + if any(service not in SERVICES for service in services): + raise RuntimeError( + f"All services must be included in {list(SERVICES.keys())}. " + f"Got: {services}" + ) + + client = boto3.client("secretsmanager", region_name="us-west-2") + for service in services: + SERVICES[service](client) diff --git a/python/ray/ml/BUILD b/python/ray/ml/BUILD index ea0a3a158..df7e917d8 100644 --- a/python/ray/ml/BUILD +++ b/python/ray/ml/BUILD @@ -145,6 +145,14 @@ py_test( deps = [":ml_lib"] ) +py_test( + name = "upload_to_wandb", + size = "medium", + srcs = ["examples/upload_to_wandb.py"], + tags = ["team:ml", "exclusive", "needs_credentials"], + deps = [":ml_lib"] +) + py_test( name = "xgboost_example", size = "medium", diff --git a/python/ray/ml/examples/upload_to_wandb.py b/python/ray/ml/examples/upload_to_wandb.py new file mode 100644 index 000000000..32e306a36 --- /dev/null +++ b/python/ray/ml/examples/upload_to_wandb.py @@ -0,0 +1,49 @@ +""" +In this example, we train a simple XGBoost model and log the training +results to Weights & Biases. We also save the resulting model checkpoints +as artifacts. +""" +import ray + +from ray.ml import RunConfig +from ray.ml.result import Result +from ray.ml.train.integrations.xgboost import XGBoostTrainer +from ray.tune.integration.wandb import WandbLoggerCallback +from sklearn.datasets import load_breast_cancer + + +def get_train_dataset() -> ray.data.Dataset: + """Return the "Breast cancer" dataset as a Ray dataset.""" + data_raw = load_breast_cancer(as_frame=True) + df = data_raw["data"] + df["target"] = data_raw["target"] + return ray.data.from_pandas(df) + + +def train_model(train_dataset: ray.data.Dataset, wandb_project: str) -> Result: + """Train a simple XGBoost model and return the result.""" + trainer = XGBoostTrainer( + scaling_config={"num_workers": 2}, + params={"tree_method": "auto"}, + label_column="target", + datasets={"train": train_dataset}, + num_boost_round=10, + run_config=RunConfig( + callbacks=[ + # This is the part needed to enable logging to Weights & Biases. + # It assumes you've logged in before, e.g. with `wandb login`. + WandbLoggerCallback( + project=wandb_project, + save_checkpoints=True, + ) + ] + ), + ) + result = trainer.fit() + return result + + +wandb_project = "ray_air_example" + +train_dataset = get_train_dataset() +result = train_model(train_dataset=train_dataset, wandb_project=wandb_project) diff --git a/python/ray/tune/integration/wandb.py b/python/ray/tune/integration/wandb.py index 772b31ecd..b03d1e4f8 100644 --- a/python/ray/tune/integration/wandb.py +++ b/python/ray/tune/integration/wandb.py @@ -1,3 +1,4 @@ +import enum import os import pickle from collections.abc import Sequence @@ -24,7 +25,6 @@ except ImportError: wandb = None WANDB_ENV_VAR = "WANDB_API_KEY" -_WANDB_QUEUE_END = (None,) _VALID_TYPES = (Number, wandb.data_types.Video, wandb.data_types.Image) _VALID_ITERABLE_TYPES = (wandb.data_types.Video, wandb.data_types.Image) @@ -181,30 +181,53 @@ def _set_api_key(api_key_file: Optional[str] = None, api_key: Optional[str] = No ) +class _QueueItem(enum.Enum): + END = enum.auto() + RESULT = enum.auto() + CHECKPOINT = enum.auto() + + class _WandbLoggingProcess(Process): """ We need a `multiprocessing.Process` to allow multiple concurrent wandb logging instances locally. + + We use a queue for the driver to communicate with the logging process. + The queue accepts the following items: + + - If it's a dict, it is assumed to be a result and will be logged using + ``wandb.log()`` + - If it's a checkpoint object, it will be saved using ``wandb.log_artifact()``. """ def __init__( self, queue: Queue, exclude: List[str], to_config: List[str], *args, **kwargs ): super(_WandbLoggingProcess, self).__init__() + self.queue = queue self._exclude = set(exclude) self._to_config = set(to_config) self.args = args self.kwargs = kwargs + self._trial_name = self.kwargs.get("name", "unknown") + def run(self): - # Since we're running in a separate process already, use threads. + wandb.require("service") wandb.init(*self.args, **self.kwargs) + wandb.setup() while True: - result = self.queue.get() - if result == _WANDB_QUEUE_END: + item_type, item_content = self.queue.get() + if item_type == _QueueItem.END: break - log, config_update = self._handle_result(result) + + if item_type == _QueueItem.CHECKPOINT: + self._handle_checkpoint(item_content) + continue + + assert item_type == _QueueItem.RESULT + log, config_update = self._handle_result(item_content) try: wandb.config.update(config_update, allow_val_change=True) wandb.log(log) @@ -214,6 +237,11 @@ class _WandbLoggingProcess(Process): logger.warn("Failed to log result to w&b: {}".format(str(e))) wandb.finish() + def _handle_checkpoint(self, checkpoint_path: str): + artifact = wandb.Artifact(name=f"checkpoint_{self._trial_name}", type="model") + artifact.add_dir(checkpoint_path) + wandb.log_artifact(artifact) + def _handle_result(self, result: Dict) -> Tuple[Dict, Dict]: config_update = result.get("config", {}).copy() log = {} @@ -255,6 +283,8 @@ class WandbLoggerCallback(LoggerCallback): the ``results`` dict should be logged. This makes sense if parameters will change during training, e.g. with PopulationBasedTraining. Defaults to False. + save_checkpoints: If ``True``, model checkpoints will be saved to + Wandb as artifacts. Defaults to ``False``. **kwargs: The keyword arguments will be pased to ``wandb.init()``. Wandb's ``group``, ``run_id`` and ``run_name`` are automatically selected @@ -308,7 +338,8 @@ class WandbLoggerCallback(LoggerCallback): api_key: Optional[str] = None, excludes: Optional[List[str]] = None, log_config: bool = False, - **kwargs + save_checkpoints: bool = False, + **kwargs, ): self.project = project self.group = group @@ -316,18 +347,17 @@ class WandbLoggerCallback(LoggerCallback): self.api_key = api_key self.excludes = excludes or [] self.log_config = log_config + self.save_checkpoints = save_checkpoints self.kwargs = kwargs self._trial_processes: Dict["Trial", _WandbLoggingProcess] = {} self._trial_queues: Dict["Trial", Queue] = {} - def setup(self): + def setup(self, *args, **kwargs): self.api_key_file = ( os.path.expanduser(self.api_key_path) if self.api_key_path else None ) _set_api_key(self.api_key_file, self.api_key) - wandb.require("service") - wandb.setup() def log_trial_start(self, trial: "Trial"): config = trial.config.copy() @@ -373,7 +403,7 @@ class WandbLoggerCallback(LoggerCallback): queue=self._trial_queues[trial], exclude=exclude_results, to_config=self._config_results, - **wandb_init_kwargs + **wandb_init_kwargs, ) self._trial_processes[trial].start() @@ -382,10 +412,16 @@ class WandbLoggerCallback(LoggerCallback): self.log_trial_start(trial) result = _clean_log(result) - self._trial_queues[trial].put(result) + self._trial_queues[trial].put((_QueueItem.RESULT, result)) + + def log_trial_save(self, trial: "Trial"): + if self.save_checkpoints and trial.checkpoint: + self._trial_queues[trial].put( + (_QueueItem.CHECKPOINT, trial.checkpoint.value) + ) def log_trial_end(self, trial: "Trial", failed: bool = False): - self._trial_queues[trial].put(_WANDB_QUEUE_END) + self._trial_queues[trial].put((_QueueItem.END, None)) self._trial_processes[trial].join(timeout=10) del self._trial_queues[trial] @@ -394,7 +430,7 @@ class WandbLoggerCallback(LoggerCallback): def __del__(self): for trial in self._trial_processes: if trial in self._trial_queues: - self._trial_queues[trial].put(_WANDB_QUEUE_END) + self._trial_queues[trial].put((_QueueItem.END, None)) del self._trial_queues[trial] self._trial_processes[trial].join(timeout=2) del self._trial_processes[trial] diff --git a/python/ray/tune/tests/test_integration_wandb.py b/python/ray/tune/tests/test_integration_wandb.py index 68f050dd4..e955282ca 100644 --- a/python/ray/tune/tests/test_integration_wandb.py +++ b/python/ray/tune/tests/test_integration_wandb.py @@ -11,11 +11,11 @@ from ray.tune.function_runner import wrap_function from ray.tune.integration.wandb import ( WandbLoggerCallback, _WandbLoggingProcess, - _WANDB_QUEUE_END, WandbLogger, WANDB_ENV_VAR, WandbTrainableMixin, wandb_mixin, + _QueueItem, ) from ray.tune.result import TRIAL_INFO from ray.tune.trial import TrialInfo @@ -52,10 +52,10 @@ class _MockWandbLoggingProcess(_WandbLoggingProcess): def run(self): while True: - result = self.queue.get() - if result == _WANDB_QUEUE_END: + result_type, result_content = self.queue.get() + if result_type == _QueueItem.END: break - log, config_update = self._handle_result(result) + log, config_update = self._handle_result(result_content) self.config_updates.put(config_update) self.logs.put(log) diff --git a/python/requirements.txt b/python/requirements.txt index abcaec8a1..0240df495 100644 --- a/python/requirements.txt +++ b/python/requirements.txt @@ -83,7 +83,7 @@ pytest-lazy-fixture pytest-timeout pytest-virtualenv redis >= 3.5.0, < 4.0.0 -scikit-learn==0.22.2 +scikit-learn==0.24.2 testfixtures werkzeug xlrd