[AIR] Move integration logging callbacks to AIR (#26126)

As the integration logging callbacks are commonly used with AIR Trainers, they are moved from the tune package to the air package. The old imports will still work, but will raise a deprecation warning.
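
For quick reference, a minimal, hypothetical sketch of the import migration (the new module paths are taken from the diff below; the toy train_fn, its config, and the experiment name are illustrative only, not part of this change):

# New, preferred import locations after this change:
from ray.air.callbacks.comet import CometLoggerCallback
from ray.air.callbacks.mlflow import MLflowLoggerCallback
from ray.air.callbacks.wandb import WandbLoggerCallback

# The old paths still work for now but emit a deprecation warning, e.g.:
# from ray.tune.integration.mlflow import MLflowLoggerCallback

from ray import tune

def train_fn(config):
    # Toy objective: report a single metric so the callback has something to log.
    tune.report(loss=config["parameter_1"] ** 2)

# Callbacks are passed to tune.run() exactly as before; only the import path changes.
tune.run(
    train_fn,
    config={"parameter_1": tune.choice([1, 2, 3])},
    callbacks=[MLflowLoggerCallback(experiment_name="experiment1")],
)
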
This commit is contained in:
Antoni Baum 2022-06-28 17:25:19 -07:00 committed by GitHub
parent c9be251b7a
commit 128f9e5664
18 changed files with 1034 additions and 874 deletions

View file

@ -799,7 +799,7 @@
"source": [
"from ray.train.huggingface import HuggingFaceTrainer\n",
"from ray.air import RunConfig\n",
"from ray.tune.integration.mlflow import MLflowLoggerCallback\n",
"from ray.air.callbacks.mlflow import MLflowLoggerCallback\n",
"\n",
"trainer = HuggingFaceTrainer(\n",
" trainer_init_per_worker=trainer_init_per_worker,\n",

View file

@ -49,7 +49,7 @@
"from ray.air import RunConfig\n",
"from ray.air.result import Result\n",
"from ray.train.xgboost import XGBoostTrainer\n",
"from ray.tune.integration.comet import CometLoggerCallback"
"from ray.air.callbacks.comet import CometLoggerCallback"
]
},
{

View file

@ -49,7 +49,7 @@
"from ray.air import RunConfig\n",
"from ray.air.result import Result\n",
"from ray.train.xgboost import XGBoostTrainer\n",
"from ray.tune.integration.wandb import WandbLoggerCallback"
"from ray.air.callbacks.wandb import WandbLoggerCallback"
]
},
{

View file

@ -25,7 +25,7 @@ MLflow (tune.integration.mlflow)
:ref:`See also here <tune-mlflow-ref>`.
.. autoclass:: ray.tune.integration.mlflow.MLflowLoggerCallback
.. autoclass:: ray.air.callbacks.mlflow.MLflowLoggerCallback
.. autofunction:: ray.tune.integration.mlflow.mlflow_mixin
@ -56,7 +56,7 @@ Weights and Biases (tune.integration.wandb)
:ref:`See also here <tune-wandb-ref>`.
.. autoclass:: ray.tune.integration.wandb.WandbLoggerCallback
.. autoclass:: ray.air.callbacks.wandb.WandbLoggerCallback
.. autofunction:: ray.tune.integration.wandb.wandb_mixin

View file

@ -80,7 +80,7 @@
"source": [
"# This cell is hidden from the rendered notebook. It makes the \n",
"from unittest.mock import MagicMock\n",
"from ray.tune.integration.comet import CometLoggerCallback\n",
"from ray.air.callbacks.comet import CometLoggerCallback\n",
"\n",
"CometLoggerCallback._logger_process_cls = MagicMock\n",
"api_key = \"abc\"\n",
@ -103,7 +103,7 @@
"metadata": {},
"outputs": [],
"source": [
"from ray.tune.integration.comet import CometLoggerCallback\n",
"from ray.air.callbacks.comet import CometLoggerCallback\n",
"\n",
"analysis = tune.run(\n",
" train_function,\n",
@ -133,7 +133,7 @@
"Click on the following dropdown to see this callback API in detail:\n",
"\n",
"```{eval-rst}\n",
".. autoclass:: ray.tune.integration.comet.CometLoggerCallback\n",
".. autoclass:: ray.air.callbacks.comet.CometLoggerCallback\n",
" :noindex:\n",
"```"
]

View file

@ -52,7 +52,11 @@
"cell_type": "code",
"execution_count": null,
"id": "b0e47339",
"metadata": {},
"metadata": {
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"import os\n",
@ -62,25 +66,35 @@
"import mlflow\n",
"\n",
"from ray import tune\n",
"from ray.tune.integration.mlflow import MLflowLoggerCallback, mlflow_mixin"
"from ray.air.callbacks.mlflow import MLflowLoggerCallback\n",
"from ray.tune.integration.mlflow import mlflow_mixin"
]
},
{
"cell_type": "markdown",
"source": [
"Next, let's define an easy objective function (a Tune `Trainable`) that iteratively computes steps and evaluates\n",
"intermediate scores that we report to Tune."
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% md\n"
}
}
},
"source": [
"Next, let's define an easy objective function (a Tune `Trainable`) that iteratively computes steps and evaluates\n",
"intermediate scores that we report to Tune."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
},
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"def evaluation_fn(step, width, height):\n",
@ -96,30 +110,33 @@
" # Feed the score back to Tune.\n",
" tune.report(iterations=step, mean_loss=intermediate_score)\n",
" time.sleep(0.1)"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
]
},
{
"cell_type": "markdown",
"source": [
"Given an MLFlow tracking URI, you can now simply use the `MLflowLoggerCallback` as a `callback` argument to\n",
"your `tune.run()` call:"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% md\n"
}
}
},
"source": [
"Given an MLFlow tracking URI, you can now simply use the `MLflowLoggerCallback` as a `callback` argument to\n",
"your `tune.run()` call:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
},
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"def tune_function(mlflow_tracking_uri, finish_fast=False):\n",
@ -140,28 +157,31 @@
" \"steps\": 5 if finish_fast else 100,\n",
" },\n",
" )"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"To use the `mlflow_mixin` decorator, you can simply decorate the objective function from earlier.\n",
"Note that we also use `mlflow.log_metrics(...)` to log metrics to MLflow.\n",
"Otherwise, the decorated version of our objective is identical to its original."
],
"metadata": {
"collapsed": false
}
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
},
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"@mlflow_mixin\n",
@ -177,26 +197,29 @@
" # Feed the score back to Tune.\n",
" tune.report(iterations=step, mean_loss=intermediate_score)\n",
" time.sleep(0.1)"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
]
},
{
"cell_type": "markdown",
"source": [
"With this new objective function ready, you can now create a Tune run with it as follows:"
],
"metadata": {
"collapsed": false
}
},
"source": [
"With this new objective function ready, you can now create a Tune run with it as follows:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
},
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"def tune_decorated(mlflow_tracking_uri, finish_fast=False):\n",
@ -217,28 +240,31 @@
" },\n",
" },\n",
" )"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"If you hapen to have an MLFlow tracking URI, you can set it below in the `mlflow_tracking_uri` variable and set\n",
"`smoke_test=False`.\n",
"Otherwise, you can just run a quick test of the `tune_function` and `tune_decorated` functions without using MLflow."
],
"metadata": {
"collapsed": false
}
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
},
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"smoke_test = True\n",
@ -261,13 +287,7 @@
" [mlflow.get_experiment_by_name(\"mixin_example\").experiment_id]\n",
" )\n",
" print(df)"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
]
},
{
"cell_type": "markdown",
@ -287,7 +307,7 @@
"(tune-mlflow-logger)=\n",
"\n",
"```{eval-rst}\n",
".. autoclass:: ray.tune.integration.mlflow.MLflowLoggerCallback\n",
".. autoclass:: ray.air.callbacks.mlflow.MLflowLoggerCallback\n",
" :noindex:\n",
"```\n",
"\n",
@ -317,4 +337,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
}
}

View file

@ -47,7 +47,11 @@
"cell_type": "code",
"execution_count": null,
"id": "100bcf8a",
"metadata": {},
"metadata": {
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"import numpy as np\n",
@ -55,8 +59,8 @@
"\n",
"from ray import tune\n",
"from ray.tune import Trainable\n",
"from ray.air.callbacks.wandb import WandbLoggerCallback\n",
"from ray.tune.integration.wandb import (\n",
" WandbLoggerCallback,\n",
" WandbTrainableMixin,\n",
" wandb_mixin,\n",
")"
@ -64,45 +68,57 @@
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"Next, let's define an easy `objective` function (a Tune `Trainable`) that reports a random loss to Tune.\n",
"The objective function itself is not important for this example, since we want to focus on the Weights & Biases\n",
"integration primarily."
],
"metadata": {
"collapsed": false
}
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
},
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"def objective(config, checkpoint_dir=None):\n",
" for i in range(30):\n",
" loss = config[\"mean\"] + config[\"sd\"] * np.random.randn()\n",
" tune.report(loss=loss)"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"Given that you provide an `api_key_file` pointing to your Weights & Biases API key, you cna define a\n",
"simple grid-search Tune run using the `WandbLoggerCallback` as follows:"
],
"metadata": {
"collapsed": false
}
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
},
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"def tune_function(api_key_file):\n",
@ -120,28 +136,31 @@
" ],\n",
" )\n",
" return analysis.best_config"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"To use the `wandb_mixin` decorator, you can simply decorate the objective function from earlier.\n",
"Note that we also use `wandb.log(...)` to log the `loss` to Weights & Biases as a dictionary.\n",
"Otherwise, the decorated version of our objective is identical to its original."
],
"metadata": {
"collapsed": false
}
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
},
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"@wandb_mixin\n",
@ -150,27 +169,30 @@
" loss = config[\"mean\"] + config[\"sd\"] * np.random.randn()\n",
" tune.report(loss=loss)\n",
" wandb.log(dict(loss=loss))"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"With the `decorated_objective` defined, running a Tune experiment is as simple as providing this objective and\n",
"passing the `api_key_file` to the `wandb` key of your Tune `config`:"
],
"metadata": {
"collapsed": false
}
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
},
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"def tune_decorated(api_key_file):\n",
@ -186,26 +208,29 @@
" },\n",
" )\n",
" return analysis.best_config"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
]
},
{
"cell_type": "markdown",
"source": [
"Finally, you can also define a class-based Tune `Trainable` by using the `WandbTrainableMixin` to define your objective:"
],
"metadata": {
"collapsed": false
}
},
"source": [
"Finally, you can also define a class-based Tune `Trainable` by using the `WandbTrainableMixin` to define your objective:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
},
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"class WandbTrainable(WandbTrainableMixin, Trainable):\n",
@ -214,28 +239,31 @@
" loss = self.config[\"mean\"] + self.config[\"sd\"] * np.random.randn()\n",
" wandb.log({\"loss\": loss})\n",
" return {\"loss\": loss, \"done\": True}"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"Running Tune with this `WandbTrainable` works exactly the same as with the function API.\n",
"The below `tune_trainable` function differs from `tune_decorated` above only in the first argument we pass to\n",
"`tune.run()`:"
],
"metadata": {
"collapsed": false
}
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
},
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"def tune_trainable(api_key_file):\n",
@ -251,28 +279,31 @@
" },\n",
" )\n",
" return analysis.best_config"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"Since you may not have an API key for Wandb, we can _mock_ the Wandb logger and test all three of our training\n",
"functions as follows.\n",
"If you do have an API key file, make sure to set `mock_api` to `False` and pass in the right `api_key_file` below."
],
"metadata": {
"collapsed": false
}
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
},
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"import tempfile\n",
@ -298,13 +329,7 @@
"\n",
"if mock_api:\n",
" temp_file.close()"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
]
},
{
"cell_type": "markdown",
@ -321,7 +346,7 @@
"(tune-wandb-logger)=\n",
"\n",
"```{eval-rst}\n",
".. autoclass:: ray.tune.integration.wandb.WandbLoggerCallback\n",
".. autoclass:: ray.air.callbacks.wandb.WandbLoggerCallback\n",
" :noindex:\n",
"```\n",
"\n",
@ -346,4 +371,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
}
}

View file

@ -0,0 +1,249 @@
import os
from typing import Dict, List
from ray.tune.logger import LoggerCallback
from ray.tune.experiment import Trial
from ray.tune.utils import flatten_dict
def _import_comet():
"""Try importing comet_ml.
Used to check if comet_ml is installed and, otherwise, pass an informative
error message.
"""
if "COMET_DISABLE_AUTO_LOGGING" not in os.environ:
os.environ["COMET_DISABLE_AUTO_LOGGING"] = "1"
try:
import comet_ml # noqa: F401
except ImportError:
raise RuntimeError("pip install 'comet-ml' to use CometLoggerCallback")
return comet_ml
class CometLoggerCallback(LoggerCallback):
"""CometLoggerCallback for logging Tune results to Comet.
Comet (https://comet.ml/site/) is a tool to manage and optimize the
entire ML lifecycle, from experiment tracking, model optimization
and dataset versioning to model production monitoring.
This Ray Tune ``LoggerCallback`` sends metrics and parameters to
Comet for tracking.
In order to use the CometLoggerCallback you must first install Comet
via ``pip install comet_ml``
Then set the following environment variables
``export COMET_API_KEY=<Your API Key>``
Alternatively, you can also pass in your API Key as an argument to the
CometLoggerCallback constructor.
``CometLoggerCallback(api_key=<Your API Key>)``
Args:
online: Whether to make use of an Online or
Offline Experiment. Defaults to True.
tags: Tags to add to the logged Experiment.
Defaults to None.
save_checkpoints: If ``True``, model checkpoints will be saved to
Comet ML as artifacts. Defaults to ``False``.
**experiment_kwargs: Other keyword arguments will be passed to the
constructor for comet_ml.Experiment (or OfflineExperiment if
online=False).
Please consult the Comet ML documentation for more information on the
Experiment and OfflineExperiment classes: https://comet.ml/site/
Example:
.. code-block:: python
from ray.air.callbacks.comet import CometLoggerCallback
tune.run(
train,
config=config
callbacks=[CometLoggerCallback(
True,
['tag1', 'tag2'],
workspace='my_workspace',
project_name='my_project_name'
)]
)
"""
# Do not enable these auto log options unless overridden
_exclude_autolog = [
"auto_output_logging",
"log_git_metadata",
"log_git_patch",
"log_env_cpu",
"log_env_gpu",
]
# Do not log these metrics.
_exclude_results = ["done", "should_checkpoint"]
# These values should be logged as system info instead of metrics.
_system_results = ["node_ip", "hostname", "pid", "date"]
# These values should be logged as "Other" instead of as metrics.
_other_results = ["trial_id", "experiment_id", "experiment_tag"]
_episode_results = ["hist_stats/episode_reward", "hist_stats/episode_lengths"]
def __init__(
self,
online: bool = True,
tags: List[str] = None,
save_checkpoints: bool = False,
**experiment_kwargs,
):
_import_comet()
self.online = online
self.tags = tags
self.save_checkpoints = save_checkpoints
self.experiment_kwargs = experiment_kwargs
# Disable the specific autologging features that cause throttling.
self._configure_experiment_defaults()
# Mapping from trial to experiment object.
self._trial_experiments = {}
self._to_exclude = self._exclude_results.copy()
self._to_system = self._system_results.copy()
self._to_other = self._other_results.copy()
self._to_episodes = self._episode_results.copy()
def _configure_experiment_defaults(self):
"""Disable the specific autologging features that cause throttling."""
for option in self._exclude_autolog:
if not self.experiment_kwargs.get(option):
self.experiment_kwargs[option] = False
def _check_key_name(self, key: str, item: str) -> bool:
"""
Check if key argument is equal to item argument or starts with item and
a forward slash. Used for parsing trial result dictionary into ignored
keys, system metrics, episode logs, etc.
"""
return key.startswith(item + "/") or key == item
def log_trial_start(self, trial: "Trial"):
"""
Initialize an Experiment (or OfflineExperiment if self.online=False)
and start logging to Comet.
Args:
trial: Trial object.
"""
_import_comet() # is this necessary?
from comet_ml import Experiment, OfflineExperiment
from comet_ml.config import set_global_experiment
if trial not in self._trial_experiments:
experiment_cls = Experiment if self.online else OfflineExperiment
experiment = experiment_cls(**self.experiment_kwargs)
self._trial_experiments[trial] = experiment
# Set global experiment to None to allow for multiple experiments.
set_global_experiment(None)
else:
experiment = self._trial_experiments[trial]
experiment.set_name(str(trial))
experiment.add_tags(self.tags)
experiment.log_other("Created from", "Ray")
config = trial.config.copy()
config.pop("callbacks", None)
experiment.log_parameters(config)
def log_trial_result(self, iteration: int, trial: "Trial", result: Dict):
"""
Log the current result of a Trial upon each iteration.
"""
if trial not in self._trial_experiments:
self.log_trial_start(trial)
experiment = self._trial_experiments[trial]
step = result["training_iteration"]
config_update = result.pop("config", {}).copy()
config_update.pop("callbacks", None) # Remove callbacks
for k, v in config_update.items():
if isinstance(v, dict):
experiment.log_parameters(flatten_dict({k: v}, "/"), step=step)
else:
experiment.log_parameter(k, v, step=step)
other_logs = {}
metric_logs = {}
system_logs = {}
episode_logs = {}
flat_result = flatten_dict(result, delimiter="/")
for k, v in flat_result.items():
if any(self._check_key_name(k, item) for item in self._to_exclude):
continue
if any(self._check_key_name(k, item) for item in self._to_other):
other_logs[k] = v
elif any(self._check_key_name(k, item) for item in self._to_system):
system_logs[k] = v
elif any(self._check_key_name(k, item) for item in self._to_episodes):
episode_logs[k] = v
else:
metric_logs[k] = v
experiment.log_others(other_logs)
experiment.log_metrics(metric_logs, step=step)
for k, v in system_logs.items():
experiment.log_system_info(k, v)
for k, v in episode_logs.items():
experiment.log_curve(k, x=range(len(v)), y=v, step=step)
def log_trial_save(self, trial: "Trial"):
comet_ml = _import_comet()
if self.save_checkpoints and trial.checkpoint:
experiment = self._trial_experiments[trial]
artifact = comet_ml.Artifact(
name=f"checkpoint_{(str(trial))}", artifact_type="model"
)
# Walk through checkpoint directory and add all files to artifact
checkpoint_root = trial.checkpoint.dir_or_data
for root, dirs, files in os.walk(checkpoint_root):
rel_root = os.path.relpath(root, checkpoint_root)
for file in files:
local_file = os.path.join(checkpoint_root, rel_root, file)
logical_path = os.path.join(rel_root, file)
# Strip leading `./`
if logical_path.startswith("./"):
logical_path = logical_path[2:]
artifact.add(local_file, logical_path=logical_path)
experiment.log_artifact(artifact)
def log_trial_end(self, trial: "Trial", failed: bool = False):
self._trial_experiments[trial].end()
del self._trial_experiments[trial]
def __del__(self):
for trial, experiment in self._trial_experiments.items():
experiment.end()
self._trial_experiments = {}

View file

@ -0,0 +1,134 @@
import logging
from typing import Dict, Optional
import ray
from ray.tune.logger import LoggerCallback
from ray.tune.result import TIMESTEPS_TOTAL, TRAINING_ITERATION
from ray.tune.experiment import Trial
from ray.util.ml_utils.mlflow import _MLflowLoggerUtil
logger = logging.getLogger(__name__)
class MLflowLoggerCallback(LoggerCallback):
"""MLflow Logger to automatically log Tune results and config to MLflow.
MLflow (https://mlflow.org) Tracking is an open source library for
recording and querying experiments. This Ray Tune ``LoggerCallback``
sends information (config parameters, training results & metrics,
and artifacts) to MLflow for automatic experiment tracking.
Args:
tracking_uri: The tracking URI for where to manage experiments
and runs. This can either be a local file path or a remote server.
This arg gets passed directly to mlflow
initialization. When using Tune in a multi-node setting, make sure
to set this to a remote server and not a local file path.
registry_uri: The registry URI that gets passed directly to
mlflow initialization.
experiment_name: The experiment name to use for this Tune run.
If the experiment with the name already exists with MLflow,
it will be reused. If not, a new experiment will be created with
that name.
tags: An optional dictionary of string keys and values to set
as tags on the run
save_artifact: If set to True, automatically save the entire
contents of the Tune local_dir as an artifact to the
corresponding run in MLflow.
Example:
.. code-block:: python
from ray.air.callbacks.mlflow import MLflowLoggerCallback
tags = { "user_name" : "John",
"git_commit_hash" : "abc123"}
tune.run(
train_fn,
config={
# define search space here
"parameter_1": tune.choice([1, 2, 3]),
"parameter_2": tune.choice([4, 5, 6]),
},
callbacks=[MLflowLoggerCallback(
experiment_name="experiment1",
tags=tags,
save_artifact=True)])
"""
def __init__(
self,
tracking_uri: Optional[str] = None,
registry_uri: Optional[str] = None,
experiment_name: Optional[str] = None,
tags: Optional[Dict] = None,
save_artifact: bool = False,
):
self.tracking_uri = tracking_uri
self.registry_uri = registry_uri
self.experiment_name = experiment_name
self.tags = tags
self.should_save_artifact = save_artifact
self.mlflow_util = _MLflowLoggerUtil()
if ray.util.client.ray.is_connected():
logger.warning(
"When using MLflowLoggerCallback with Ray Client, "
"it is recommended to use a remote tracking "
"server. If you are using a MLflow tracking server "
"backed by the local filesystem, then it must be "
"setup on the server side and not on the client "
"side."
)
def setup(self, *args, **kwargs):
# Setup the mlflow logging util.
self.mlflow_util.setup_mlflow(
tracking_uri=self.tracking_uri,
registry_uri=self.registry_uri,
experiment_name=self.experiment_name,
)
if self.tags is None:
# Create empty dictionary for tags if not given explicitly
self.tags = {}
self._trial_runs = {}
def log_trial_start(self, trial: "Trial"):
# Create run if not already exists.
if trial not in self._trial_runs:
# Set trial name in tags
tags = self.tags.copy()
tags["trial_name"] = str(trial)
run = self.mlflow_util.start_run(tags=tags, run_name=str(trial))
self._trial_runs[trial] = run.info.run_id
run_id = self._trial_runs[trial]
# Log the config parameters.
config = trial.config
self.mlflow_util.log_params(run_id=run_id, params_to_log=config)
def log_trial_result(self, iteration: int, trial: "Trial", result: Dict):
step = result.get(TIMESTEPS_TOTAL) or result[TRAINING_ITERATION]
run_id = self._trial_runs[trial]
self.mlflow_util.log_metrics(run_id=run_id, metrics_to_log=result, step=step)
def log_trial_end(self, trial: "Trial", failed: bool = False):
run_id = self._trial_runs[trial]
# Log the artifact if set_artifact is set to True.
if self.should_save_artifact:
self.mlflow_util.save_artifacts(run_id=run_id, dir=trial.logdir)
# Stop the run once trial finishes.
status = "FINISHED" if not failed else "FAILED"
self.mlflow_util.end_run(run_id=run_id, status=status)

View file

@ -0,0 +1,370 @@
import enum
import os
import pickle
from collections.abc import Sequence
from multiprocessing import Process, Queue
from numbers import Number
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
import urllib
from ray import logger
from ray.tune.logger import LoggerCallback
from ray.tune.utils import flatten_dict
from ray.tune.experiment import Trial
import yaml
try:
import wandb
except ImportError:
logger.error("pip install 'wandb' to use WandbLoggerCallback/WandbTrainableMixin.")
wandb = None
WANDB_ENV_VAR = "WANDB_API_KEY"
_VALID_TYPES = (Number, wandb.data_types.Video, wandb.data_types.Image)
_VALID_ITERABLE_TYPES = (wandb.data_types.Video, wandb.data_types.Image)
def _is_allowed_type(obj):
"""Return True if type is allowed for logging to wandb"""
if isinstance(obj, np.ndarray) and obj.size == 1:
return isinstance(obj.item(), Number)
if isinstance(obj, Sequence) and len(obj) > 0:
return isinstance(obj[0], _VALID_ITERABLE_TYPES)
return isinstance(obj, _VALID_TYPES)
def _clean_log(obj: Any):
# Fixes https://github.com/ray-project/ray/issues/10631
if isinstance(obj, dict):
return {k: _clean_log(v) for k, v in obj.items()}
elif isinstance(obj, list):
return [_clean_log(v) for v in obj]
elif isinstance(obj, tuple):
return tuple(_clean_log(v) for v in obj)
elif _is_allowed_type(obj):
return obj
# Else
try:
pickle.dumps(obj)
yaml.dump(
obj,
Dumper=yaml.SafeDumper,
default_flow_style=False,
allow_unicode=True,
encoding="utf-8",
)
return obj
except Exception:
# give up, similar to _SafeFallBackEncoder
fallback = str(obj)
# Try to convert to int
try:
fallback = int(fallback)
return fallback
except ValueError:
pass
# Try to convert to float
try:
fallback = float(fallback)
return fallback
except ValueError:
pass
# Else, return string
return fallback
def _set_api_key(api_key_file: Optional[str] = None, api_key: Optional[str] = None):
"""Set WandB API key from `wandb_config`. Will pop the
`api_key_file` and `api_key` keys from `wandb_config` parameter"""
if api_key_file:
if api_key:
raise ValueError("Both WandB `api_key_file` and `api_key` set.")
with open(api_key_file, "rt") as fp:
api_key = fp.readline().strip()
if api_key:
os.environ[WANDB_ENV_VAR] = api_key
elif not os.environ.get(WANDB_ENV_VAR):
try:
# Check if user is already logged into wandb.
wandb.ensure_configured()
if wandb.api.api_key:
logger.info("Already logged into W&B.")
return
except AttributeError:
pass
raise ValueError(
"No WandB API key found. Either set the {} environment "
"variable, pass `api_key` or `api_key_file` to the"
"`WandbLoggerCallback` class as arguments, "
"or run `wandb login` from the command line".format(WANDB_ENV_VAR)
)
class _QueueItem(enum.Enum):
END = enum.auto()
RESULT = enum.auto()
CHECKPOINT = enum.auto()
class _WandbLoggingProcess(Process):
"""
We need a `multiprocessing.Process` to allow multiple concurrent
wandb logging instances locally.
We use a queue for the driver to communicate with the logging process.
The queue accepts the following items:
- If it's a dict, it is assumed to be a result and will be logged using
``wandb.log()``
- If it's a checkpoint object, it will be saved using ``wandb.log_artifact()``.
"""
def __init__(
self,
logdir: str,
queue: Queue,
exclude: List[str],
to_config: List[str],
*args,
**kwargs,
):
super(_WandbLoggingProcess, self).__init__()
os.chdir(logdir)
self.queue = queue
self._exclude = set(exclude)
self._to_config = set(to_config)
self.args = args
self.kwargs = kwargs
self._trial_name = self.kwargs.get("name", "unknown")
def run(self):
# Since we're running in a separate process already, use threads.
os.environ["WANDB_START_METHOD"] = "thread"
wandb.init(*self.args, **self.kwargs)
while True:
item_type, item_content = self.queue.get()
if item_type == _QueueItem.END:
break
if item_type == _QueueItem.CHECKPOINT:
self._handle_checkpoint(item_content)
continue
assert item_type == _QueueItem.RESULT
log, config_update = self._handle_result(item_content)
try:
wandb.config.update(config_update, allow_val_change=True)
wandb.log(log)
except urllib.error.HTTPError as e:
# Ignore HTTPError. Missing a few data points is not a
# big issue, as long as things eventually recover.
logger.warn("Failed to log result to w&b: {}".format(str(e)))
wandb.finish()
def _handle_checkpoint(self, checkpoint_path: str):
artifact = wandb.Artifact(name=f"checkpoint_{self._trial_name}", type="model")
artifact.add_dir(checkpoint_path)
wandb.log_artifact(artifact)
def _handle_result(self, result: Dict) -> Tuple[Dict, Dict]:
config_update = result.get("config", {}).copy()
log = {}
flat_result = flatten_dict(result, delimiter="/")
for k, v in flat_result.items():
if any(k.startswith(item + "/") or k == item for item in self._to_config):
config_update[k] = v
elif any(k.startswith(item + "/") or k == item for item in self._exclude):
continue
elif not _is_allowed_type(v):
continue
else:
log[k] = v
config_update.pop("callbacks", None) # Remove callbacks
return log, config_update
class WandbLoggerCallback(LoggerCallback):
"""WandbLoggerCallback
Weights & Biases (https://www.wandb.ai/) is a tool for experiment
tracking, model optimization, and dataset versioning. This Ray Tune
``LoggerCallback`` sends metrics to Wandb for automatic tracking and
visualization.
Args:
project: Name of the Wandb project. Mandatory.
group: Name of the Wandb group. Defaults to the trainable
name.
api_key_file: Path to file containing the Wandb API KEY. This
file only needs to be present on the node running the Tune script
if using the WandbLogger.
api_key: Wandb API Key. Alternative to setting ``api_key_file``.
excludes: List of metrics that should be excluded from
the log.
log_config: Boolean indicating if the ``config`` parameter of
the ``results`` dict should be logged. This makes sense if
parameters will change during training, e.g. with
PopulationBasedTraining. Defaults to False.
save_checkpoints: If ``True``, model checkpoints will be saved to
Wandb as artifacts. Defaults to ``False``.
**kwargs: The keyword arguments will be passed to ``wandb.init()``.
Wandb's ``group``, ``run_id`` and ``run_name`` are automatically selected
by Tune, but can be overwritten by filling out the respective configuration
values.
Please see here for all other valid configuration settings:
https://docs.wandb.ai/library/init
Example:
.. code-block:: python
from ray.tune.logger import DEFAULT_LOGGERS
from ray.air.callbacks.wandb import WandbLoggerCallback
tune.run(
train_fn,
config={
# define search space here
"parameter_1": tune.choice([1, 2, 3]),
"parameter_2": tune.choice([4, 5, 6]),
},
callbacks=[WandbLoggerCallback(
project="Optimization_Project",
api_key_file="/path/to/file",
log_config=True)])
"""
# Do not log these result keys
_exclude_results = ["done", "should_checkpoint"]
# Use these result keys to update `wandb.config`
_config_results = [
"trial_id",
"experiment_tag",
"node_ip",
"experiment_id",
"hostname",
"pid",
"date",
]
_logger_process_cls = _WandbLoggingProcess
def __init__(
self,
project: str,
group: Optional[str] = None,
api_key_file: Optional[str] = None,
api_key: Optional[str] = None,
excludes: Optional[List[str]] = None,
log_config: bool = False,
save_checkpoints: bool = False,
**kwargs,
):
self.project = project
self.group = group
self.api_key_path = api_key_file
self.api_key = api_key
self.excludes = excludes or []
self.log_config = log_config
self.save_checkpoints = save_checkpoints
self.kwargs = kwargs
self._trial_processes: Dict["Trial", _WandbLoggingProcess] = {}
self._trial_queues: Dict["Trial", Queue] = {}
def setup(self, *args, **kwargs):
self.api_key_file = (
os.path.expanduser(self.api_key_path) if self.api_key_path else None
)
_set_api_key(self.api_key_file, self.api_key)
def log_trial_start(self, trial: "Trial"):
config = trial.config.copy()
config.pop("callbacks", None) # Remove callbacks
exclude_results = self._exclude_results.copy()
# Additional excludes
exclude_results += self.excludes
# Log config keys on each result?
if not self.log_config:
exclude_results += ["config"]
# Fill trial ID and name
trial_id = trial.trial_id if trial else None
trial_name = str(trial) if trial else None
# Project name for Wandb
wandb_project = self.project
# Grouping
wandb_group = self.group or trial.trainable_name if trial else None
# remove unpickleable items!
config = _clean_log(config)
wandb_init_kwargs = dict(
id=trial_id,
name=trial_name,
resume=False,
reinit=True,
allow_val_change=True,
group=wandb_group,
project=wandb_project,
config=config,
)
wandb_init_kwargs.update(self.kwargs)
self._trial_queues[trial] = Queue()
self._trial_processes[trial] = self._logger_process_cls(
logdir=trial.logdir,
queue=self._trial_queues[trial],
exclude=exclude_results,
to_config=self._config_results,
**wandb_init_kwargs,
)
self._trial_processes[trial].start()
def log_trial_result(self, iteration: int, trial: "Trial", result: Dict):
if trial not in self._trial_processes:
self.log_trial_start(trial)
result = _clean_log(result)
self._trial_queues[trial].put((_QueueItem.RESULT, result))
def log_trial_save(self, trial: "Trial"):
if self.save_checkpoints and trial.checkpoint:
self._trial_queues[trial].put(
(_QueueItem.CHECKPOINT, trial.checkpoint.dir_or_data)
)
def log_trial_end(self, trial: "Trial", failed: bool = False):
self._trial_queues[trial].put((_QueueItem.END, None))
self._trial_processes[trial].join(timeout=10)
del self._trial_queues[trial]
del self._trial_processes[trial]
def __del__(self):
for trial in self._trial_processes:
if trial in self._trial_queues:
self._trial_queues[trial].put((_QueueItem.END, None))
del self._trial_queues[trial]
self._trial_processes[trial].join(timeout=2)
del self._trial_processes[trial]

View file

@ -8,7 +8,8 @@ import time
import mlflow
from ray import tune
from ray.tune.integration.mlflow import MLflowLoggerCallback, mlflow_mixin
from ray.air.callbacks.mlflow import MLflowLoggerCallback
from ray.tune.integration.mlflow import mlflow_mixin
def evaluation_fn(step, width, height):

View file

@ -7,8 +7,8 @@ import wandb
from ray import tune
from ray.tune import Trainable
from ray.air.callbacks.wandb import WandbLoggerCallback
from ray.tune.integration.wandb import (
WandbLoggerCallback,
WandbTrainableMixin,
wandb_mixin,
)

View file

@ -1,249 +1,28 @@
import os
from typing import Dict, List
from ray.air.callbacks.comet import CometLoggerCallback as _CometLoggerCallback
from typing import List
from ray.tune.logger import LoggerCallback
from ray.tune.experiment import Trial
from ray.tune.utils import flatten_dict
import logging
from ray.util.annotations import Deprecated
logger = logging.getLogger(__name__)
callback_deprecation_message = (
"`ray.tune.integration.comet.CometLoggerCallback` "
"is deprecated and will be removed in "
"the future. Please use `ray.air.callbacks.comet.CometLoggerCallback` "
"instead."
)
def _import_comet():
"""Try importing comet_ml.
Used to check if comet_ml is installed and, otherwise, pass an informative
error message.
"""
if "COMET_DISABLE_AUTO_LOGGING" not in os.environ:
os.environ["COMET_DISABLE_AUTO_LOGGING"] = "1"
try:
import comet_ml # noqa: F401
except ImportError:
raise RuntimeError("pip install 'comet-ml' to use CometLoggerCallback")
return comet_ml
class CometLoggerCallback(LoggerCallback):
"""CometLoggerCallback for logging Tune results to Comet.
Comet (https://comet.ml/site/) is a tool to manage and optimize the
entire ML lifecycle, from experiment tracking, model optimization
and dataset versioning to model production monitoring.
This Ray Tune ``LoggerCallback`` sends metrics and parameters to
Comet for tracking.
In order to use the CometLoggerCallback you must first install Comet
via ``pip install comet_ml``
Then set the following environment variables
``export COMET_API_KEY=<Your API Key>``
Alternatively, you can also pass in your API Key as an argument to the
CometLoggerCallback constructor.
``CometLoggerCallback(api_key=<Your API Key>)``
Args:
online: Whether to make use of an Online or
Offline Experiment. Defaults to True.
tags: Tags to add to the logged Experiment.
Defaults to None.
save_checkpoints: If ``True``, model checkpoints will be saved to
Comet ML as artifacts. Defaults to ``False``.
**experiment_kwargs: Other keyword arguments will be passed to the
constructor for comet_ml.Experiment (or OfflineExperiment if
online=False).
Please consult the Comet ML documentation for more information on the
Experiment and OfflineExperiment classes: https://comet.ml/site/
Example:
.. code-block:: python
from ray.tune.integration.comet import CometLoggerCallback
tune.run(
train,
config=config
callbacks=[CometLoggerCallback(
True,
['tag1', 'tag2'],
workspace='my_workspace',
project_name='my_project_name'
)]
)
"""
# Do not enable these auto log options unless overridden
_exclude_autolog = [
"auto_output_logging",
"log_git_metadata",
"log_git_patch",
"log_env_cpu",
"log_env_gpu",
]
# Do not log these metrics.
_exclude_results = ["done", "should_checkpoint"]
# These values should be logged as system info instead of metrics.
_system_results = ["node_ip", "hostname", "pid", "date"]
# These values should be logged as "Other" instead of as metrics.
_other_results = ["trial_id", "experiment_id", "experiment_tag"]
_episode_results = ["hist_stats/episode_reward", "hist_stats/episode_lengths"]
@Deprecated(message=callback_deprecation_message)
class CometLoggerCallback(_CometLoggerCallback):
def __init__(
self,
online: bool = True,
tags: List[str] = None,
save_checkpoints: bool = False,
**experiment_kwargs,
**experiment_kwargs
):
_import_comet()
self.online = online
self.tags = tags
self.save_checkpoints = save_checkpoints
self.experiment_kwargs = experiment_kwargs
# Disable the specific autologging features that cause throttling.
self._configure_experiment_defaults()
# Mapping from trial to experiment object.
self._trial_experiments = {}
self._to_exclude = self._exclude_results.copy()
self._to_system = self._system_results.copy()
self._to_other = self._other_results.copy()
self._to_episodes = self._episode_results.copy()
def _configure_experiment_defaults(self):
"""Disable the specific autologging features that cause throttling."""
for option in self._exclude_autolog:
if not self.experiment_kwargs.get(option):
self.experiment_kwargs[option] = False
def _check_key_name(self, key: str, item: str) -> bool:
"""
Check if key argument is equal to item argument or starts with item and
a forward slash. Used for parsing trial result dictionary into ignored
keys, system metrics, episode logs, etc.
"""
return key.startswith(item + "/") or key == item
def log_trial_start(self, trial: "Trial"):
"""
Initialize an Experiment (or OfflineExperiment if self.online=False)
and start logging to Comet.
Args:
trial: Trial object.
"""
_import_comet() # is this necessary?
from comet_ml import Experiment, OfflineExperiment
from comet_ml.config import set_global_experiment
if trial not in self._trial_experiments:
experiment_cls = Experiment if self.online else OfflineExperiment
experiment = experiment_cls(**self.experiment_kwargs)
self._trial_experiments[trial] = experiment
# Set global experiment to None to allow for multiple experiments.
set_global_experiment(None)
else:
experiment = self._trial_experiments[trial]
experiment.set_name(str(trial))
experiment.add_tags(self.tags)
experiment.log_other("Created from", "Ray")
config = trial.config.copy()
config.pop("callbacks", None)
experiment.log_parameters(config)
def log_trial_result(self, iteration: int, trial: "Trial", result: Dict):
"""
Log the current result of a Trial upon each iteration.
"""
if trial not in self._trial_experiments:
self.log_trial_start(trial)
experiment = self._trial_experiments[trial]
step = result["training_iteration"]
config_update = result.pop("config", {}).copy()
config_update.pop("callbacks", None) # Remove callbacks
for k, v in config_update.items():
if isinstance(v, dict):
experiment.log_parameters(flatten_dict({k: v}, "/"), step=step)
else:
experiment.log_parameter(k, v, step=step)
other_logs = {}
metric_logs = {}
system_logs = {}
episode_logs = {}
flat_result = flatten_dict(result, delimiter="/")
for k, v in flat_result.items():
if any(self._check_key_name(k, item) for item in self._to_exclude):
continue
if any(self._check_key_name(k, item) for item in self._to_other):
other_logs[k] = v
elif any(self._check_key_name(k, item) for item in self._to_system):
system_logs[k] = v
elif any(self._check_key_name(k, item) for item in self._to_episodes):
episode_logs[k] = v
else:
metric_logs[k] = v
experiment.log_others(other_logs)
experiment.log_metrics(metric_logs, step=step)
for k, v in system_logs.items():
experiment.log_system_info(k, v)
for k, v in episode_logs.items():
experiment.log_curve(k, x=range(len(v)), y=v, step=step)
def log_trial_save(self, trial: "Trial"):
comet_ml = _import_comet()
if self.save_checkpoints and trial.checkpoint:
experiment = self._trial_experiments[trial]
artifact = comet_ml.Artifact(
name=f"checkpoint_{(str(trial))}", artifact_type="model"
)
# Walk through checkpoint directory and add all files to artifact
checkpoint_root = trial.checkpoint.dir_or_data
for root, dirs, files in os.walk(checkpoint_root):
rel_root = os.path.relpath(root, checkpoint_root)
for file in files:
local_file = os.path.join(checkpoint_root, rel_root, file)
logical_path = os.path.join(rel_root, file)
# Strip leading `./`
if logical_path.startswith("./"):
logical_path = logical_path[2:]
artifact.add(local_file, logical_path=logical_path)
experiment.log_artifact(artifact)
def log_trial_end(self, trial: "Trial", failed: bool = False):
self._trial_experiments[trial].end()
del self._trial_experiments[trial]
def __del__(self):
for trial, experiment in self._trial_experiments.items():
experiment.end()
self._trial_experiments = {}
logging.warning(callback_deprecation_message)
super().__init__(online, tags, save_checkpoints, **experiment_kwargs)

View file

@ -1,65 +1,25 @@
from ray.air.callbacks.mlflow import MLflowLoggerCallback as _MLflowLoggerCallback
import logging
from typing import Callable, Dict, Optional
import ray
from ray.tune.logger import LoggerCallback
from ray.tune.result import TIMESTEPS_TOTAL, TRAINING_ITERATION
from ray.tune.trainable import Trainable
from ray.tune.experiment import Trial
from ray.util.annotations import Deprecated
from ray.util.ml_utils.mlflow import _MLflowLoggerUtil
logger = logging.getLogger(__name__)
callback_deprecation_message = (
"`ray.tune.integration.mlflow.MLflowLoggerCallback` "
"is deprecated and will be removed in "
"the future. Please use `ray.air.callbacks.mlflow.MLflowLoggerCallback` "
"instead."
)
class MLflowLoggerCallback(LoggerCallback):
"""MLflow Logger to automatically log Tune results and config to MLflow.
MLflow (https://mlflow.org) Tracking is an open source library for
recording and querying experiments. This Ray Tune ``LoggerCallback``
sends information (config parameters, training results & metrics,
and artifacts) to MLflow for automatic experiment tracking.
Args:
tracking_uri: The tracking URI for where to manage experiments
and runs. This can either be a local file path or a remote server.
This arg gets passed directly to mlflow
initialization. When using Tune in a multi-node setting, make sure
to set this to a remote server and not a local file path.
registry_uri: The registry URI that gets passed directly to
mlflow initialization.
experiment_name: The experiment name to use for this Tune run.
If the experiment with the name already exists with MLflow,
it will be reused. If not, a new experiment will be created with
that name.
tags: An optional dictionary of string keys and values to set
as tags on the run
save_artifact: If set to True, automatically save the entire
contents of the Tune local_dir as an artifact to the
corresponding run in MLflow.
Example:
.. code-block:: python
from ray.tune.integration.mlflow import MLflowLoggerCallback
tags = { "user_name" : "John",
"git_commit_hash" : "abc123"}
tune.run(
train_fn,
config={
# define search space here
"parameter_1": tune.choice([1, 2, 3]),
"parameter_2": tune.choice([4, 5, 6]),
},
callbacks=[MLflowLoggerCallback(
experiment_name="experiment1",
tags=tags,
save_artifact=True)])
"""
@Deprecated(message=callback_deprecation_message)
class MLflowLoggerCallback(_MLflowLoggerCallback):
def __init__(
self,
tracking_uri: Optional[str] = None,
@ -68,72 +28,11 @@ class MLflowLoggerCallback(LoggerCallback):
tags: Optional[Dict] = None,
save_artifact: bool = False,
):
self.tracking_uri = tracking_uri
self.registry_uri = registry_uri
self.experiment_name = experiment_name
self.tags = tags
self.should_save_artifact = save_artifact
self.mlflow_util = _MLflowLoggerUtil()
if ray.util.client.ray.is_connected():
logger.warning(
"When using MLflowLoggerCallback with Ray Client, "
"it is recommended to use a remote tracking "
"server. If you are using a MLflow tracking server "
"backed by the local filesystem, then it must be "
"setup on the server side and not on the client "
"side."
)
def setup(self, *args, **kwargs):
# Setup the mlflow logging util.
self.mlflow_util.setup_mlflow(
tracking_uri=self.tracking_uri,
registry_uri=self.registry_uri,
experiment_name=self.experiment_name,
logger.warning(callback_deprecation_message)
super().__init__(
tracking_uri, registry_uri, experiment_name, tags, save_artifact
)
if self.tags is None:
# Create empty dictionary for tags if not given explicitly
self.tags = {}
self._trial_runs = {}
def log_trial_start(self, trial: "Trial"):
# Create run if not already exists.
if trial not in self._trial_runs:
# Set trial name in tags
tags = self.tags.copy()
tags["trial_name"] = str(trial)
run = self.mlflow_util.start_run(tags=tags, run_name=str(trial))
self._trial_runs[trial] = run.info.run_id
run_id = self._trial_runs[trial]
# Log the config parameters.
config = trial.config
self.mlflow_util.log_params(run_id=run_id, params_to_log=config)
def log_trial_result(self, iteration: int, trial: "Trial", result: Dict):
step = result.get(TIMESTEPS_TOTAL) or result[TRAINING_ITERATION]
run_id = self._trial_runs[trial]
self.mlflow_util.log_metrics(run_id=run_id, metrics_to_log=result, step=step)
def log_trial_end(self, trial: "Trial", failed: bool = False):
run_id = self._trial_runs[trial]
# Log the artifact if set_artifact is set to True.
if self.should_save_artifact:
self.mlflow_util.save_artifacts(run_id=run_id, dir=trial.logdir)
# Stop the run once trial finishes.
status = "FINISHED" if not failed else "FAILED"
self.mlflow_util.end_run(run_id=run_id, status=status)
def mlflow_mixin(func: Callable):
"""mlflow_mixin

View file

@ -1,84 +1,54 @@
import enum
import os
import pickle
from collections.abc import Sequence
from multiprocessing import Process, Queue
from numbers import Number
from typing import Any, Callable, Dict, List, Optional, Tuple
import numpy as np
import urllib
from typing import List, Dict, Callable, Optional
from ray import logger
from ray.tune import Trainable
from ray.tune.trainable import FunctionTrainable
from ray.tune.logger import LoggerCallback
from ray.tune.utils import flatten_dict
from ray.tune.experiment import Trial
import yaml
from ray.air.callbacks.wandb import (
wandb,
_clean_log,
_set_api_key,
WandbLoggerCallback as _WandbLoggerCallback,
)
try:
import wandb
except ImportError:
logger.error("pip install 'wandb' to use WandbLoggerCallback/WandbTrainableMixin.")
wandb = None
import logging
WANDB_ENV_VAR = "WANDB_API_KEY"
_VALID_TYPES = (Number, wandb.data_types.Video, wandb.data_types.Image)
_VALID_ITERABLE_TYPES = (wandb.data_types.Video, wandb.data_types.Image)
from ray.util.annotations import Deprecated
logger = logging.getLogger(__name__)
callback_deprecation_message = (
"`ray.tune.integration.wandb.WandbLoggerCallback` "
"is deprecated and will be removed in "
"the future. Please use `ray.air.callbacks.wandb.WandbLoggerCallback` "
"instead."
)
def _is_allowed_type(obj):
"""Return True if type is allowed for logging to wandb"""
if isinstance(obj, np.ndarray) and obj.size == 1:
return isinstance(obj.item(), Number)
if isinstance(obj, Sequence) and len(obj) > 0:
return isinstance(obj[0], _VALID_ITERABLE_TYPES)
return isinstance(obj, _VALID_TYPES)
def _clean_log(obj: Any):
# Fixes https://github.com/ray-project/ray/issues/10631
if isinstance(obj, dict):
return {k: _clean_log(v) for k, v in obj.items()}
elif isinstance(obj, list):
return [_clean_log(v) for v in obj]
elif isinstance(obj, tuple):
return tuple(_clean_log(v) for v in obj)
elif _is_allowed_type(obj):
return obj
# Else
try:
pickle.dumps(obj)
yaml.dump(
obj,
Dumper=yaml.SafeDumper,
default_flow_style=False,
allow_unicode=True,
encoding="utf-8",
@Deprecated(message=callback_deprecation_message)
class WandbLoggerCallback(_WandbLoggerCallback):
def __init__(
self,
project: str,
group: Optional[str] = None,
api_key_file: Optional[str] = None,
api_key: Optional[str] = None,
excludes: Optional[List[str]] = None,
log_config: bool = False,
save_checkpoints: bool = False,
**kwargs
):
logger.warning(callback_deprecation_message)
super().__init__(
project,
group,
api_key_file,
api_key,
excludes,
log_config,
save_checkpoints,
**kwargs
)
return obj
except Exception:
# give up, similar to _SafeFallBackEncoder
fallback = str(obj)
# Try to convert to int
try:
fallback = int(fallback)
return fallback
except ValueError:
pass
# Try to convert to float
try:
fallback = float(fallback)
return fallback
except ValueError:
pass
# Else, return string
return fallback
def wandb_mixin(func: Callable):
@ -155,297 +125,6 @@ def wandb_mixin(func: Callable):
return func
def _set_api_key(api_key_file: Optional[str] = None, api_key: Optional[str] = None):
"""Set WandB API key from `wandb_config`. Will pop the
`api_key_file` and `api_key` keys from `wandb_config` parameter"""
if api_key_file:
if api_key:
raise ValueError("Both WandB `api_key_file` and `api_key` set.")
with open(api_key_file, "rt") as fp:
api_key = fp.readline().strip()
if api_key:
os.environ[WANDB_ENV_VAR] = api_key
elif not os.environ.get(WANDB_ENV_VAR):
try:
# Check if user is already logged into wandb.
wandb.ensure_configured()
if wandb.api.api_key:
logger.info("Already logged into W&B.")
return
except AttributeError:
pass
raise ValueError(
"No WandB API key found. Either set the {} environment "
"variable, pass `api_key` or `api_key_file` to the"
"`WandbLoggerCallback` class as arguments, "
"or run `wandb login` from the command line".format(WANDB_ENV_VAR)
)
class _QueueItem(enum.Enum):
END = enum.auto()
RESULT = enum.auto()
CHECKPOINT = enum.auto()
class _WandbLoggingProcess(Process):
"""
We need a `multiprocessing.Process` to allow multiple concurrent
wandb logging instances locally.
We use a queue for the driver to communicate with the logging process.
The queue accepts the following items:
- If it's a dict, it is assumed to be a result and will be logged using
``wandb.log()``
- If it's a checkpoint object, it will be saved using ``wandb.log_artifact()``.
"""
def __init__(
self,
logdir: str,
queue: Queue,
exclude: List[str],
to_config: List[str],
*args,
**kwargs,
):
super(_WandbLoggingProcess, self).__init__()
os.chdir(logdir)
self.queue = queue
self._exclude = set(exclude)
self._to_config = set(to_config)
self.args = args
self.kwargs = kwargs
self._trial_name = self.kwargs.get("name", "unknown")
def run(self):
# Since we're running in a separate process already, use threads.
os.environ["WANDB_START_METHOD"] = "thread"
wandb.init(*self.args, **self.kwargs)
while True:
item_type, item_content = self.queue.get()
if item_type == _QueueItem.END:
break
if item_type == _QueueItem.CHECKPOINT:
self._handle_checkpoint(item_content)
continue
assert item_type == _QueueItem.RESULT
log, config_update = self._handle_result(item_content)
try:
wandb.config.update(config_update, allow_val_change=True)
wandb.log(log)
except urllib.error.HTTPError as e:
# Ignore HTTPError. Missing a few data points is not a
# big issue, as long as things eventually recover.
logger.warn("Failed to log result to w&b: {}".format(str(e)))
wandb.finish()
def _handle_checkpoint(self, checkpoint_path: str):
artifact = wandb.Artifact(name=f"checkpoint_{self._trial_name}", type="model")
artifact.add_dir(checkpoint_path)
wandb.log_artifact(artifact)
def _handle_result(self, result: Dict) -> Tuple[Dict, Dict]:
config_update = result.get("config", {}).copy()
log = {}
flat_result = flatten_dict(result, delimiter="/")
for k, v in flat_result.items():
if any(k.startswith(item + "/") or k == item for item in self._to_config):
config_update[k] = v
elif any(k.startswith(item + "/") or k == item for item in self._exclude):
continue
elif not _is_allowed_type(v):
continue
else:
log[k] = v
config_update.pop("callbacks", None) # Remove callbacks
return log, config_update
class WandbLoggerCallback(LoggerCallback):
"""WandbLoggerCallback
Weights & Biases (https://www.wandb.ai/) is a tool for experiment
tracking, model optimization, and dataset versioning. This Ray Tune
``LoggerCallback`` sends metrics to Wandb for automatic tracking and
visualization.
Args:
project: Name of the Wandb project. Mandatory.
group: Name of the Wandb group. Defaults to the trainable
name.
api_key_file: Path to file containing the Wandb API KEY. This
file only needs to be present on the node running the Tune script
if using the WandbLogger.
api_key: Wandb API Key. Alternative to setting ``api_key_file``.
excludes: List of metrics that should be excluded from
the log.
log_config: Boolean indicating if the ``config`` parameter of
the ``results`` dict should be logged. This makes sense if
parameters will change during training, e.g. with
PopulationBasedTraining. Defaults to False.
save_checkpoints: If ``True``, model checkpoints will be saved to
Wandb as artifacts. Defaults to ``False``.
**kwargs: The keyword arguments will be passed to ``wandb.init()``.
Wandb's ``group``, ``run_id`` and ``run_name`` are automatically selected
by Tune, but can be overwritten by filling out the respective configuration
values.
Please see here for all other valid configuration settings:
https://docs.wandb.ai/library/init
Example:
.. code-block:: python
from ray.tune.logger import DEFAULT_LOGGERS
from ray.tune.integration.wandb import WandbLoggerCallback
tune.run(
train_fn,
config={
# define search space here
"parameter_1": tune.choice([1, 2, 3]),
"parameter_2": tune.choice([4, 5, 6]),
},
callbacks=[WandbLoggerCallback(
project="Optimization_Project",
api_key_file="/path/to/file",
log_config=True)])
"""
# Do not log these result keys
_exclude_results = ["done", "should_checkpoint"]
# Use these result keys to update `wandb.config`
_config_results = [
"trial_id",
"experiment_tag",
"node_ip",
"experiment_id",
"hostname",
"pid",
"date",
]
_logger_process_cls = _WandbLoggingProcess
def __init__(
self,
project: str,
group: Optional[str] = None,
api_key_file: Optional[str] = None,
api_key: Optional[str] = None,
excludes: Optional[List[str]] = None,
log_config: bool = False,
save_checkpoints: bool = False,
**kwargs,
):
self.project = project
self.group = group
self.api_key_path = api_key_file
self.api_key = api_key
self.excludes = excludes or []
self.log_config = log_config
self.save_checkpoints = save_checkpoints
self.kwargs = kwargs
self._trial_processes: Dict["Trial", _WandbLoggingProcess] = {}
self._trial_queues: Dict["Trial", Queue] = {}
def setup(self, *args, **kwargs):
self.api_key_file = (
os.path.expanduser(self.api_key_path) if self.api_key_path else None
)
_set_api_key(self.api_key_file, self.api_key)
def log_trial_start(self, trial: "Trial"):
config = trial.config.copy()
config.pop("callbacks", None) # Remove callbacks
exclude_results = self._exclude_results.copy()
# Additional excludes
exclude_results += self.excludes
# Log config keys on each result?
if not self.log_config:
exclude_results += ["config"]
# Fill trial ID and name
trial_id = trial.trial_id if trial else None
trial_name = str(trial) if trial else None
# Project name for Wandb
wandb_project = self.project
# Grouping
wandb_group = self.group or trial.trainable_name if trial else None
# remove unpickleable items!
config = _clean_log(config)
wandb_init_kwargs = dict(
id=trial_id,
name=trial_name,
resume=False,
reinit=True,
allow_val_change=True,
group=wandb_group,
project=wandb_project,
config=config,
)
wandb_init_kwargs.update(self.kwargs)
self._trial_queues[trial] = Queue()
self._trial_processes[trial] = self._logger_process_cls(
logdir=trial.logdir,
queue=self._trial_queues[trial],
exclude=exclude_results,
to_config=self._config_results,
**wandb_init_kwargs,
)
self._trial_processes[trial].start()
def log_trial_result(self, iteration: int, trial: "Trial", result: Dict):
if trial not in self._trial_processes:
self.log_trial_start(trial)
result = _clean_log(result)
self._trial_queues[trial].put((_QueueItem.RESULT, result))
def log_trial_save(self, trial: "Trial"):
if self.save_checkpoints and trial.checkpoint:
self._trial_queues[trial].put(
(_QueueItem.CHECKPOINT, trial.checkpoint.dir_or_data)
)
def log_trial_end(self, trial: "Trial", failed: bool = False):
self._trial_queues[trial].put((_QueueItem.END, None))
self._trial_processes[trial].join(timeout=10)
del self._trial_queues[trial]
del self._trial_processes[trial]
def __del__(self):
for trial in self._trial_processes:
if trial in self._trial_queues:
self._trial_queues[trial].put((_QueueItem.END, None))
del self._trial_queues[trial]
self._trial_processes[trial].join(timeout=2)
del self._trial_processes[trial]
class WandbTrainableMixin:
_wandb = wandb

View file

@ -1,6 +1,6 @@
import unittest
from unittest.mock import patch
from ray.tune.integration.comet import CometLoggerCallback
from ray.air.callbacks.comet import CometLoggerCallback
from collections import namedtuple

View file

@ -8,10 +8,12 @@ from mlflow.tracking import MlflowClient
from ray.tune.trainable import wrap_function
from ray.tune.integration.mlflow import (
MLflowLoggerCallback,
MLflowTrainableMixin,
mlflow_mixin,
)
from ray.air.callbacks.mlflow import (
MLflowLoggerCallback,
)
from ray.util.ml_utils.mlflow import _MLflowLoggerUtil
@ -136,7 +138,7 @@ class MLflowTest(unittest.TestCase):
logger.setup()
self.assertEqual(logger.tags, tags)
@patch("ray.tune.integration.mlflow._MLflowLoggerUtil", Mock_MLflowLoggerUtil)
@patch("ray.air.callbacks.mlflow._MLflowLoggerUtil", Mock_MLflowLoggerUtil)
def testMlFlowLoggerLogging(self):
clear_env_vars()
trial_config = {"par1": "a", "par2": "b"}

View file

@ -9,11 +9,13 @@ import numpy as np
from ray.tune import Trainable
from ray.tune.trainable import wrap_function
from ray.tune.integration.wandb import (
WandbTrainableMixin,
wandb_mixin,
)
from ray.air.callbacks.wandb import (
WandbLoggerCallback,
_WandbLoggingProcess,
WANDB_ENV_VAR,
WandbTrainableMixin,
wandb_mixin,
_QueueItem,
)
from ray.tune.result import TRIAL_INFO