mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
[AIR] Move integration logging callbacks to AIR (#26126)
As the integration logging callbacks are commonly used with AIR Trainers, they should be moved from the tune package to the air package. The old imports will still work, but raise a deprecation warning.
This commit is contained in:
parent
c9be251b7a
commit
128f9e5664
18 changed files with 1034 additions and 874 deletions
|
@ -799,7 +799,7 @@
|
|||
"source": [
|
||||
"from ray.train.huggingface import HuggingFaceTrainer\n",
|
||||
"from ray.air import RunConfig\n",
|
||||
"from ray.tune.integration.mlflow import MLflowLoggerCallback\n",
|
||||
"from ray.air.callbacks.mlflow import MLflowLoggerCallback\n",
|
||||
"\n",
|
||||
"trainer = HuggingFaceTrainer(\n",
|
||||
" trainer_init_per_worker=trainer_init_per_worker,\n",
|
||||
|
|
|
@ -49,7 +49,7 @@
|
|||
"from ray.air import RunConfig\n",
|
||||
"from ray.air.result import Result\n",
|
||||
"from ray.train.xgboost import XGBoostTrainer\n",
|
||||
"from ray.tune.integration.comet import CometLoggerCallback"
|
||||
"from ray.air.callbacks.comet import CometLoggerCallback"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
|
|
@ -49,7 +49,7 @@
|
|||
"from ray.air import RunConfig\n",
|
||||
"from ray.air.result import Result\n",
|
||||
"from ray.train.xgboost import XGBoostTrainer\n",
|
||||
"from ray.tune.integration.wandb import WandbLoggerCallback"
|
||||
"from ray.air.callbacks.wandb import WandbLoggerCallback"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
|
|
@ -25,7 +25,7 @@ MLflow (tune.integration.mlflow)
|
|||
|
||||
:ref:`See also here <tune-mlflow-ref>`.
|
||||
|
||||
.. autoclass:: ray.tune.integration.mlflow.MLflowLoggerCallback
|
||||
.. autoclass:: ray.air.callbacks.mlflow.MLflowLoggerCallback
|
||||
|
||||
.. autofunction:: ray.tune.integration.mlflow.mlflow_mixin
|
||||
|
||||
|
@ -56,7 +56,7 @@ Weights and Biases (tune.integration.wandb)
|
|||
|
||||
:ref:`See also here <tune-wandb-ref>`.
|
||||
|
||||
.. autoclass:: ray.tune.integration.wandb.WandbLoggerCallback
|
||||
.. autoclass:: ray.air.callbacks.wandb.WandbLoggerCallback
|
||||
|
||||
.. autofunction:: ray.tune.integration.wandb.wandb_mixin
|
||||
|
||||
|
|
|
@ -80,7 +80,7 @@
|
|||
"source": [
|
||||
"# This cell is hidden from the rendered notebook. It makes the \n",
|
||||
"from unittest.mock import MagicMock\n",
|
||||
"from ray.tune.integration.comet import CometLoggerCallback\n",
|
||||
"from ray.air.callbacks.comet import CometLoggerCallback\n",
|
||||
"\n",
|
||||
"CometLoggerCallback._logger_process_cls = MagicMock\n",
|
||||
"api_key = \"abc\"\n",
|
||||
|
@ -103,7 +103,7 @@
|
|||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from ray.tune.integration.comet import CometLoggerCallback\n",
|
||||
"from ray.air.callbacks.comet import CometLoggerCallback\n",
|
||||
"\n",
|
||||
"analysis = tune.run(\n",
|
||||
" train_function,\n",
|
||||
|
@ -133,7 +133,7 @@
|
|||
"Click on the following dropdown to see this callback API in detail:\n",
|
||||
"\n",
|
||||
"```{eval-rst}\n",
|
||||
".. autoclass:: ray.tune.integration.comet.CometLoggerCallback\n",
|
||||
".. autoclass:: ray.air.callbacks.comet.CometLoggerCallback\n",
|
||||
" :noindex:\n",
|
||||
"```"
|
||||
]
|
||||
|
|
|
@ -52,7 +52,11 @@
|
|||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "b0e47339",
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"vscode": {
|
||||
"languageId": "python"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
|
@ -62,25 +66,35 @@
|
|||
"import mlflow\n",
|
||||
"\n",
|
||||
"from ray import tune\n",
|
||||
"from ray.tune.integration.mlflow import MLflowLoggerCallback, mlflow_mixin"
|
||||
"from ray.air.callbacks.mlflow import MLflowLoggerCallback\n",
|
||||
"from ray.tune.integration.mlflow import mlflow_mixin"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"Next, let's define an easy objective function (a Tune `Trainable`) that iteratively computes steps and evaluates\n",
|
||||
"intermediate scores that we report to Tune."
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%% md\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
"source": [
|
||||
"Next, let's define an easy objective function (a Tune `Trainable`) that iteratively computes steps and evaluates\n",
|
||||
"intermediate scores that we report to Tune."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
},
|
||||
"vscode": {
|
||||
"languageId": "python"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def evaluation_fn(step, width, height):\n",
|
||||
|
@ -96,30 +110,33 @@
|
|||
" # Feed the score back to Tune.\n",
|
||||
" tune.report(iterations=step, mean_loss=intermediate_score)\n",
|
||||
" time.sleep(0.1)"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"Given an MLFlow tracking URI, you can now simply use the `MLflowLoggerCallback` as a `callback` argument to\n",
|
||||
"your `tune.run()` call:"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%% md\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
"source": [
|
||||
"Given an MLFlow tracking URI, you can now simply use the `MLflowLoggerCallback` as a `callback` argument to\n",
|
||||
"your `tune.run()` call:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
},
|
||||
"vscode": {
|
||||
"languageId": "python"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def tune_function(mlflow_tracking_uri, finish_fast=False):\n",
|
||||
|
@ -140,28 +157,31 @@
|
|||
" \"steps\": 5 if finish_fast else 100,\n",
|
||||
" },\n",
|
||||
" )"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"source": [
|
||||
"To use the `mlflow_mixin` decorator, you can simply decorate the objective function from earlier.\n",
|
||||
"Note that we also use `mlflow.log_metrics(...)` to log metrics to MLflow.\n",
|
||||
"Otherwise, the decorated version of our objective is identical to its original."
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
},
|
||||
"vscode": {
|
||||
"languageId": "python"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"@mlflow_mixin\n",
|
||||
|
@ -177,26 +197,29 @@
|
|||
" # Feed the score back to Tune.\n",
|
||||
" tune.report(iterations=step, mean_loss=intermediate_score)\n",
|
||||
" time.sleep(0.1)"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"With this new objective function ready, you can now create a Tune run with it as follows:"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
}
|
||||
},
|
||||
"source": [
|
||||
"With this new objective function ready, you can now create a Tune run with it as follows:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
},
|
||||
"vscode": {
|
||||
"languageId": "python"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def tune_decorated(mlflow_tracking_uri, finish_fast=False):\n",
|
||||
|
@ -217,28 +240,31 @@
|
|||
" },\n",
|
||||
" },\n",
|
||||
" )"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"source": [
|
||||
"If you happen to have an MLFlow tracking URI, you can set it below in the `mlflow_tracking_uri` variable and set\n",
|
||||
"`smoke_test=False`.\n",
|
||||
"Otherwise, you can just run a quick test of the `tune_function` and `tune_decorated` functions without using MLflow."
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
},
|
||||
"vscode": {
|
||||
"languageId": "python"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"smoke_test = True\n",
|
||||
|
@ -261,13 +287,7 @@
|
|||
" [mlflow.get_experiment_by_name(\"mixin_example\").experiment_id]\n",
|
||||
" )\n",
|
||||
" print(df)"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
|
@ -287,7 +307,7 @@
|
|||
"(tune-mlflow-logger)=\n",
|
||||
"\n",
|
||||
"```{eval-rst}\n",
|
||||
".. autoclass:: ray.tune.integration.mlflow.MLflowLoggerCallback\n",
|
||||
".. autoclass:: ray.air.callbacks.mlflow.MLflowLoggerCallback\n",
|
||||
" :noindex:\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
|
|
|
@ -47,7 +47,11 @@
|
|||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "100bcf8a",
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"vscode": {
|
||||
"languageId": "python"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import numpy as np\n",
|
||||
|
@ -55,8 +59,8 @@
|
|||
"\n",
|
||||
"from ray import tune\n",
|
||||
"from ray.tune import Trainable\n",
|
||||
"from ray.air.callbacks.wandb import WandbLoggerCallback\n",
|
||||
"from ray.tune.integration.wandb import (\n",
|
||||
" WandbLoggerCallback,\n",
|
||||
" WandbTrainableMixin,\n",
|
||||
" wandb_mixin,\n",
|
||||
")"
|
||||
|
@ -64,45 +68,57 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"source": [
|
||||
"Next, let's define an easy `objective` function (a Tune `Trainable`) that reports a random loss to Tune.\n",
|
||||
"The objective function itself is not important for this example, since we want to focus on the Weights & Biases\n",
|
||||
"integration primarily."
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
},
|
||||
"vscode": {
|
||||
"languageId": "python"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def objective(config, checkpoint_dir=None):\n",
|
||||
" for i in range(30):\n",
|
||||
" loss = config[\"mean\"] + config[\"sd\"] * np.random.randn()\n",
|
||||
" tune.report(loss=loss)"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"source": [
|
||||
"Given that you provide an `api_key_file` pointing to your Weights & Biases API key, you can define a\n",
|
||||
"simple grid-search Tune run using the `WandbLoggerCallback` as follows:"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
},
|
||||
"vscode": {
|
||||
"languageId": "python"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def tune_function(api_key_file):\n",
|
||||
|
@ -120,28 +136,31 @@
|
|||
" ],\n",
|
||||
" )\n",
|
||||
" return analysis.best_config"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"source": [
|
||||
"To use the `wandb_mixin` decorator, you can simply decorate the objective function from earlier.\n",
|
||||
"Note that we also use `wandb.log(...)` to log the `loss` to Weights & Biases as a dictionary.\n",
|
||||
"Otherwise, the decorated version of our objective is identical to its original."
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
},
|
||||
"vscode": {
|
||||
"languageId": "python"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"@wandb_mixin\n",
|
||||
|
@ -150,27 +169,30 @@
|
|||
" loss = config[\"mean\"] + config[\"sd\"] * np.random.randn()\n",
|
||||
" tune.report(loss=loss)\n",
|
||||
" wandb.log(dict(loss=loss))"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"source": [
|
||||
"With the `decorated_objective` defined, running a Tune experiment is as simple as providing this objective and\n",
|
||||
"passing the `api_key_file` to the `wandb` key of your Tune `config`:"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
},
|
||||
"vscode": {
|
||||
"languageId": "python"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def tune_decorated(api_key_file):\n",
|
||||
|
@ -186,26 +208,29 @@
|
|||
" },\n",
|
||||
" )\n",
|
||||
" return analysis.best_config"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"Finally, you can also define a class-based Tune `Trainable` by using the `WandbTrainableMixin` to define your objective:"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
}
|
||||
},
|
||||
"source": [
|
||||
"Finally, you can also define a class-based Tune `Trainable` by using the `WandbTrainableMixin` to define your objective:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
},
|
||||
"vscode": {
|
||||
"languageId": "python"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class WandbTrainable(WandbTrainableMixin, Trainable):\n",
|
||||
|
@ -214,28 +239,31 @@
|
|||
" loss = self.config[\"mean\"] + self.config[\"sd\"] * np.random.randn()\n",
|
||||
" wandb.log({\"loss\": loss})\n",
|
||||
" return {\"loss\": loss, \"done\": True}"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"source": [
|
||||
"Running Tune with this `WandbTrainable` works exactly the same as with the function API.\n",
|
||||
"The below `tune_trainable` function differs from `tune_decorated` above only in the first argument we pass to\n",
|
||||
"`tune.run()`:"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
},
|
||||
"vscode": {
|
||||
"languageId": "python"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def tune_trainable(api_key_file):\n",
|
||||
|
@ -251,28 +279,31 @@
|
|||
" },\n",
|
||||
" )\n",
|
||||
" return analysis.best_config"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"source": [
|
||||
"Since you may not have an API key for Wandb, we can _mock_ the Wandb logger and test all three of our training\n",
|
||||
"functions as follows.\n",
|
||||
"If you do have an API key file, make sure to set `mock_api` to `False` and pass in the right `api_key_file` below."
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
},
|
||||
"vscode": {
|
||||
"languageId": "python"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import tempfile\n",
|
||||
|
@ -298,13 +329,7 @@
|
|||
"\n",
|
||||
"if mock_api:\n",
|
||||
" temp_file.close()"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
|
@ -321,7 +346,7 @@
|
|||
"(tune-wandb-logger)=\n",
|
||||
"\n",
|
||||
"```{eval-rst}\n",
|
||||
".. autoclass:: ray.tune.integration.wandb.WandbLoggerCallback\n",
|
||||
".. autoclass:: ray.air.callbacks.wandb.WandbLoggerCallback\n",
|
||||
" :noindex:\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
|
|
249
python/ray/air/callbacks/comet.py
Normal file
249
python/ray/air/callbacks/comet.py
Normal file
|
@ -0,0 +1,249 @@
|
|||
import os
|
||||
from typing import Dict, List
|
||||
|
||||
from ray.tune.logger import LoggerCallback
|
||||
from ray.tune.experiment import Trial
|
||||
from ray.tune.utils import flatten_dict
|
||||
|
||||
|
||||
def _import_comet():
|
||||
"""Try importing comet_ml.
|
||||
|
||||
Used to check if comet_ml is installed and, otherwise, pass an informative
|
||||
error message.
|
||||
"""
|
||||
if "COMET_DISABLE_AUTO_LOGGING" not in os.environ:
|
||||
os.environ["COMET_DISABLE_AUTO_LOGGING"] = "1"
|
||||
|
||||
try:
|
||||
import comet_ml # noqa: F401
|
||||
except ImportError:
|
||||
raise RuntimeError("pip install 'comet-ml' to use CometLoggerCallback")
|
||||
|
||||
return comet_ml
|
||||
|
||||
|
||||
class CometLoggerCallback(LoggerCallback):
    """CometLoggerCallback for logging Tune results to Comet.

    Comet (https://comet.ml/site/) is a tool to manage and optimize the
    entire ML lifecycle, from experiment tracking, model optimization
    and dataset versioning to model production monitoring.

    This Ray Tune ``LoggerCallback`` sends metrics and parameters to
    Comet for tracking.

    In order to use the CometLoggerCallback you must first install Comet
    via ``pip install comet_ml``

    Then set the following environment variables
    ``export COMET_API_KEY=<Your API Key>``

    Alternatively, you can also pass in your API Key as an argument to the
    CometLoggerCallback constructor.

    ``CometLoggerCallback(api_key=<Your API Key>)``

    Args:
        online: Whether to make use of an Online or
            Offline Experiment. Defaults to True.
        tags: Tags to add to the logged Experiment.
            Defaults to None.
        save_checkpoints: If ``True``, model checkpoints will be saved to
            Comet ML as artifacts. Defaults to ``False``.
        **experiment_kwargs: Other keyword arguments will be passed to the
            constructor for comet_ml.Experiment (or OfflineExperiment if
            online=False).

    Please consult the Comet ML documentation for more information on the
    Experiment and OfflineExperiment classes: https://comet.ml/site/

    Example:

    .. code-block:: python

        from ray.air.callbacks.comet import CometLoggerCallback
        tune.run(
            train,
            config=config,
            callbacks=[CometLoggerCallback(
                True,
                ['tag1', 'tag2'],
                workspace='my_workspace',
                project_name='my_project_name'
                )]
        )

    """

    # Do not enable these auto log options unless overridden
    _exclude_autolog = [
        "auto_output_logging",
        "log_git_metadata",
        "log_git_patch",
        "log_env_cpu",
        "log_env_gpu",
    ]

    # Do not log these metrics.
    _exclude_results = ["done", "should_checkpoint"]

    # These values should be logged as system info instead of metrics.
    _system_results = ["node_ip", "hostname", "pid", "date"]

    # These values should be logged as "Other" instead of as metrics.
    _other_results = ["trial_id", "experiment_id", "experiment_tag"]

    # These values are logged as curves (one point per list element).
    _episode_results = ["hist_stats/episode_reward", "hist_stats/episode_lengths"]

    def __init__(
        self,
        online: bool = True,
        tags: List[str] = None,
        save_checkpoints: bool = False,
        **experiment_kwargs,
    ):
        # Fail early with an informative error if comet_ml is missing.
        _import_comet()
        self.online = online
        self.tags = tags
        self.save_checkpoints = save_checkpoints
        self.experiment_kwargs = experiment_kwargs

        # Disable the specific autologging features that cause throttling.
        self._configure_experiment_defaults()

        # Mapping from trial to experiment object.
        self._trial_experiments = {}

        # Per-instance copies so that changes never leak into the class-level
        # defaults shared by all instances.
        self._to_exclude = self._exclude_results.copy()
        self._to_system = self._system_results.copy()
        self._to_other = self._other_results.copy()
        self._to_episodes = self._episode_results.copy()

    def _configure_experiment_defaults(self):
        """Disable the specific autologging features that cause throttling.

        Options the user explicitly passed with a truthy value are kept.
        """
        for option in self._exclude_autolog:
            if not self.experiment_kwargs.get(option):
                self.experiment_kwargs[option] = False

    def _check_key_name(self, key: str, item: str) -> bool:
        """
        Check if key argument is equal to item argument or starts with item and
        a forward slash. Used for parsing trial result dictionary into ignored
        keys, system metrics, episode logs, etc.
        """
        return key.startswith(item + "/") or key == item

    def log_trial_start(self, trial: "Trial"):
        """
        Initialize an Experiment (or OfflineExperiment if self.online=False)
        and start logging to Comet.

        Args:
            trial: Trial object.

        """
        # Sets COMET_DISABLE_AUTO_LOGGING (if unset) before comet_ml is
        # imported below, and raises an informative error if it is missing.
        _import_comet()
        from comet_ml import Experiment, OfflineExperiment
        from comet_ml.config import set_global_experiment

        if trial not in self._trial_experiments:
            experiment_cls = Experiment if self.online else OfflineExperiment
            experiment = experiment_cls(**self.experiment_kwargs)
            self._trial_experiments[trial] = experiment
            # Set global experiment to None to allow for multiple experiments.
            set_global_experiment(None)
        else:
            experiment = self._trial_experiments[trial]

        experiment.set_name(str(trial))
        # ``tags`` defaults to None; only forward actual tags to Comet.
        if self.tags:
            experiment.add_tags(self.tags)
        experiment.log_other("Created from", "Ray")

        config = trial.config.copy()
        config.pop("callbacks", None)
        experiment.log_parameters(config)

    def log_trial_result(self, iteration: int, trial: "Trial", result: Dict):
        """
        Log the current result of a Trial upon each iteration.

        The ``result`` dict passed in is not modified.
        """
        if trial not in self._trial_experiments:
            self.log_trial_start(trial)
        experiment = self._trial_experiments[trial]
        step = result["training_iteration"]

        # Work on a shallow copy: removing "config" from the caller's dict
        # would hide it from any other consumer of the same result object.
        result = dict(result)
        config_update = result.pop("config", {}).copy()
        config_update.pop("callbacks", None)  # Remove callbacks
        for k, v in config_update.items():
            if isinstance(v, dict):
                experiment.log_parameters(flatten_dict({k: v}, "/"), step=step)
            else:
                experiment.log_parameter(k, v, step=step)

        other_logs = {}
        metric_logs = {}
        system_logs = {}
        episode_logs = {}

        # Route every flattened key to exactly one Comet category.
        flat_result = flatten_dict(result, delimiter="/")
        for k, v in flat_result.items():
            if any(self._check_key_name(k, item) for item in self._to_exclude):
                continue

            if any(self._check_key_name(k, item) for item in self._to_other):
                other_logs[k] = v
            elif any(self._check_key_name(k, item) for item in self._to_system):
                system_logs[k] = v
            elif any(self._check_key_name(k, item) for item in self._to_episodes):
                episode_logs[k] = v
            else:
                metric_logs[k] = v

        experiment.log_others(other_logs)
        experiment.log_metrics(metric_logs, step=step)

        for k, v in system_logs.items():
            experiment.log_system_info(k, v)

        for k, v in episode_logs.items():
            experiment.log_curve(k, x=range(len(v)), y=v, step=step)

    def log_trial_save(self, trial: "Trial"):
        """Upload the trial's checkpoint files to Comet as a model artifact.

        Does nothing unless ``save_checkpoints`` is enabled and the trial
        has a checkpoint.
        """
        comet_ml = _import_comet()

        if self.save_checkpoints and trial.checkpoint:
            experiment = self._trial_experiments[trial]

            artifact = comet_ml.Artifact(
                name=f"checkpoint_{(str(trial))}", artifact_type="model"
            )

            # Walk through checkpoint directory and add all files to artifact
            checkpoint_root = trial.checkpoint.dir_or_data
            for root, _, files in os.walk(checkpoint_root):
                for file in files:
                    local_file = os.path.join(root, file)
                    # Path relative to the checkpoint root; relpath never
                    # yields a leading "./", on any platform.
                    logical_path = os.path.relpath(local_file, checkpoint_root)
                    artifact.add(local_file, logical_path=logical_path)

            experiment.log_artifact(artifact)

    def log_trial_end(self, trial: "Trial", failed: bool = False):
        """End and forget the Comet experiment of ``trial``, if one exists."""
        # A trial may end without ever having been started here; that must
        # not raise a KeyError.
        experiment = self._trial_experiments.pop(trial, None)
        if experiment is not None:
            experiment.end()

    def __del__(self):
        # ``__init__`` can raise before ``_trial_experiments`` is assigned
        # (e.g. comet_ml missing); the finalizer still runs on that
        # half-built instance and must not fail.
        experiments = getattr(self, "_trial_experiments", None)
        if not experiments:
            return
        for experiment in experiments.values():
            experiment.end()
        self._trial_experiments = {}
|
134
python/ray/air/callbacks/mlflow.py
Normal file
134
python/ray/air/callbacks/mlflow.py
Normal file
|
@ -0,0 +1,134 @@
|
|||
import logging
|
||||
from typing import Dict, Optional
|
||||
|
||||
import ray
|
||||
from ray.tune.logger import LoggerCallback
|
||||
from ray.tune.result import TIMESTEPS_TOTAL, TRAINING_ITERATION
|
||||
from ray.tune.experiment import Trial
|
||||
from ray.util.ml_utils.mlflow import _MLflowLoggerUtil
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MLflowLoggerCallback(LoggerCallback):
    """Tune ``LoggerCallback`` that mirrors trial configs and results to MLflow.

    MLflow (https://mlflow.org) Tracking is an open source library for
    recording and querying experiments. For every trial this callback
    starts one MLflow run, logs the trial config as parameters, logs each
    reported result as metrics and, optionally, uploads the trial log
    directory as artifacts when the trial ends.

    Args:
        tracking_uri: The tracking URI for where to manage experiments
            and runs. This can either be a local file path or a remote server.
            This arg gets passed directly to mlflow
            initialization. When using Tune in a multi-node setting, make sure
            to set this to a remote server and not a local file path.
        registry_uri: The registry URI that gets passed directly to
            mlflow initialization.
        experiment_name: The experiment name to use for this Tune run.
            If the experiment with the name already exists with MLflow,
            it will be reused. If not, a new experiment will be created with
            that name.
        tags: An optional dictionary of string keys and values to set
            as tags on the run
        save_artifact: If set to True, automatically save the entire
            contents of the Tune local_dir as an artifact to the
            corresponding run in MlFlow.

    Example:

    .. code-block:: python

        from ray.air.callbacks.mlflow import MLflowLoggerCallback

        tags = { "user_name" : "John",
                 "git_commit_hash" : "abc123"}

        tune.run(
            train_fn,
            config={
                # define search space here
                "parameter_1": tune.choice([1, 2, 3]),
                "parameter_2": tune.choice([4, 5, 6]),
            },
            callbacks=[MLflowLoggerCallback(
                experiment_name="experiment1",
                tags=tags,
                save_artifact=True)])

    """

    def __init__(
        self,
        tracking_uri: Optional[str] = None,
        registry_uri: Optional[str] = None,
        experiment_name: Optional[str] = None,
        tags: Optional[Dict] = None,
        save_artifact: bool = False,
    ):
        self.tracking_uri = tracking_uri
        self.registry_uri = registry_uri
        self.experiment_name = experiment_name
        self.tags = tags
        self.should_save_artifact = save_artifact

        self.mlflow_util = _MLflowLoggerUtil()

        # Nothing more to do unless we are running through Ray Client.
        if not ray.util.client.ray.is_connected():
            return

        logger.warning(
            "When using MLflowLoggerCallback with Ray Client, "
            "it is recommended to use a remote tracking "
            "server. If you are using a MLflow tracking server "
            "backed by the local filesystem, then it must be "
            "setup on the server side and not on the client "
            "side."
        )

    def setup(self, *args, **kwargs):
        """Initialize MLflow and reset the trial -> run id bookkeeping."""
        self.mlflow_util.setup_mlflow(
            tracking_uri=self.tracking_uri,
            registry_uri=self.registry_uri,
            experiment_name=self.experiment_name,
        )

        # Fall back to an empty tag dictionary when none was given.
        if self.tags is None:
            self.tags = {}

        # Maps each trial to the id of its MLflow run.
        self._trial_runs = {}

    def log_trial_start(self, trial: "Trial"):
        """Start an MLflow run for ``trial`` if needed and log its config."""
        if trial not in self._trial_runs:
            trial_name = str(trial)
            # Per-run tags: the user tags plus the trial name.
            run_tags = {**self.tags, "trial_name": trial_name}
            new_run = self.mlflow_util.start_run(tags=run_tags, run_name=trial_name)
            self._trial_runs[trial] = new_run.info.run_id

        # Log the config parameters against the trial's run.
        self.mlflow_util.log_params(
            run_id=self._trial_runs[trial], params_to_log=trial.config
        )

    def log_trial_result(self, iteration: int, trial: "Trial", result: Dict):
        """Log ``result`` as metrics on the MLflow run of ``trial``."""
        # Prefer the total timestep count as the step; fall back to the
        # training iteration when it is missing or falsy.
        step = result.get(TIMESTEPS_TOTAL) or result[TRAINING_ITERATION]
        self.mlflow_util.log_metrics(
            run_id=self._trial_runs[trial], metrics_to_log=result, step=step
        )

    def log_trial_end(self, trial: "Trial", failed: bool = False):
        """Optionally upload artifacts, then close the trial's MLflow run."""
        run_id = self._trial_runs[trial]

        # Upload the trial log directory if save_artifact was set to True.
        if self.should_save_artifact:
            self.mlflow_util.save_artifacts(run_id=run_id, dir=trial.logdir)

        # Stop the run once the trial finishes.
        final_status = "FAILED" if failed else "FINISHED"
        self.mlflow_util.end_run(run_id=run_id, status=final_status)
|
370
python/ray/air/callbacks/wandb.py
Normal file
370
python/ray/air/callbacks/wandb.py
Normal file
|
@ -0,0 +1,370 @@
|
|||
import enum
|
||||
import os
|
||||
import pickle
|
||||
from collections.abc import Sequence
|
||||
from multiprocessing import Process, Queue
|
||||
from numbers import Number
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
import numpy as np
|
||||
import urllib
|
||||
|
||||
from ray import logger
|
||||
from ray.tune.logger import LoggerCallback
|
||||
from ray.tune.utils import flatten_dict
|
||||
from ray.tune.experiment import Trial
|
||||
|
||||
import yaml
|
||||
|
||||
try:
|
||||
import wandb
|
||||
except ImportError:
|
||||
logger.error("pip install 'wandb' to use WandbLoggerCallback/WandbTrainableMixin.")
|
||||
wandb = None
|
||||
|
||||
# Name of the environment variable from which wandb reads the API key.
WANDB_ENV_VAR = "WANDB_API_KEY"

# Types that may be logged to wandb. The import of `wandb` above is guarded
# and leaves `wandb = None` when the package is missing; dereferencing
# `wandb.data_types` unconditionally would then raise AttributeError at
# import time, so fall back to the plain numeric types in that case.
if wandb is not None:
    _VALID_TYPES = (Number, wandb.data_types.Video, wandb.data_types.Image)
    _VALID_ITERABLE_TYPES = (wandb.data_types.Video, wandb.data_types.Image)
else:
    _VALID_TYPES = (Number,)
    _VALID_ITERABLE_TYPES = ()
|
||||
|
||||
|
||||
def _is_allowed_type(obj):
|
||||
"""Return True if type is allowed for logging to wandb"""
|
||||
if isinstance(obj, np.ndarray) and obj.size == 1:
|
||||
return isinstance(obj.item(), Number)
|
||||
if isinstance(obj, Sequence) and len(obj) > 0:
|
||||
return isinstance(obj[0], _VALID_ITERABLE_TYPES)
|
||||
return isinstance(obj, _VALID_TYPES)
|
||||
|
||||
|
||||
def _clean_log(obj: Any):
|
||||
# Fixes https://github.com/ray-project/ray/issues/10631
|
||||
if isinstance(obj, dict):
|
||||
return {k: _clean_log(v) for k, v in obj.items()}
|
||||
elif isinstance(obj, list):
|
||||
return [_clean_log(v) for v in obj]
|
||||
elif isinstance(obj, tuple):
|
||||
return tuple(_clean_log(v) for v in obj)
|
||||
elif _is_allowed_type(obj):
|
||||
return obj
|
||||
|
||||
# Else
|
||||
try:
|
||||
pickle.dumps(obj)
|
||||
yaml.dump(
|
||||
obj,
|
||||
Dumper=yaml.SafeDumper,
|
||||
default_flow_style=False,
|
||||
allow_unicode=True,
|
||||
encoding="utf-8",
|
||||
)
|
||||
return obj
|
||||
except Exception:
|
||||
# give up, similar to _SafeFallBackEncoder
|
||||
fallback = str(obj)
|
||||
|
||||
# Try to convert to int
|
||||
try:
|
||||
fallback = int(fallback)
|
||||
return fallback
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# Try to convert to float
|
||||
try:
|
||||
fallback = float(fallback)
|
||||
return fallback
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# Else, return string
|
||||
return fallback
|
||||
|
||||
|
||||
def _set_api_key(api_key_file: Optional[str] = None, api_key: Optional[str] = None):
|
||||
"""Set WandB API key from `wandb_config`. Will pop the
|
||||
`api_key_file` and `api_key` keys from `wandb_config` parameter"""
|
||||
if api_key_file:
|
||||
if api_key:
|
||||
raise ValueError("Both WandB `api_key_file` and `api_key` set.")
|
||||
with open(api_key_file, "rt") as fp:
|
||||
api_key = fp.readline().strip()
|
||||
if api_key:
|
||||
os.environ[WANDB_ENV_VAR] = api_key
|
||||
elif not os.environ.get(WANDB_ENV_VAR):
|
||||
try:
|
||||
# Check if user is already logged into wandb.
|
||||
wandb.ensure_configured()
|
||||
if wandb.api.api_key:
|
||||
logger.info("Already logged into W&B.")
|
||||
return
|
||||
except AttributeError:
|
||||
pass
|
||||
raise ValueError(
|
||||
"No WandB API key found. Either set the {} environment "
|
||||
"variable, pass `api_key` or `api_key_file` to the"
|
||||
"`WandbLoggerCallback` class as arguments, "
|
||||
"or run `wandb login` from the command line".format(WANDB_ENV_VAR)
|
||||
)
|
||||
|
||||
|
||||
class _QueueItem(enum.Enum):
|
||||
END = enum.auto()
|
||||
RESULT = enum.auto()
|
||||
CHECKPOINT = enum.auto()
|
||||
|
||||
|
||||
class _WandbLoggingProcess(Process):
|
||||
"""
|
||||
We need a `multiprocessing.Process` to allow multiple concurrent
|
||||
wandb logging instances locally.
|
||||
|
||||
We use a queue for the driver to communicate with the logging process.
|
||||
The queue accepts the following items:
|
||||
|
||||
- If it's a dict, it is assumed to be a result and will be logged using
|
||||
``wandb.log()``
|
||||
- If it's a checkpoint object, it will be saved using ``wandb.log_artifact()``.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
logdir: str,
|
||||
queue: Queue,
|
||||
exclude: List[str],
|
||||
to_config: List[str],
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
super(_WandbLoggingProcess, self).__init__()
|
||||
|
||||
os.chdir(logdir)
|
||||
self.queue = queue
|
||||
self._exclude = set(exclude)
|
||||
self._to_config = set(to_config)
|
||||
self.args = args
|
||||
self.kwargs = kwargs
|
||||
|
||||
self._trial_name = self.kwargs.get("name", "unknown")
|
||||
|
||||
def run(self):
|
||||
# Since we're running in a separate process already, use threads.
|
||||
os.environ["WANDB_START_METHOD"] = "thread"
|
||||
wandb.init(*self.args, **self.kwargs)
|
||||
|
||||
while True:
|
||||
item_type, item_content = self.queue.get()
|
||||
if item_type == _QueueItem.END:
|
||||
break
|
||||
|
||||
if item_type == _QueueItem.CHECKPOINT:
|
||||
self._handle_checkpoint(item_content)
|
||||
continue
|
||||
|
||||
assert item_type == _QueueItem.RESULT
|
||||
log, config_update = self._handle_result(item_content)
|
||||
try:
|
||||
wandb.config.update(config_update, allow_val_change=True)
|
||||
wandb.log(log)
|
||||
except urllib.error.HTTPError as e:
|
||||
# Ignore HTTPError. Missing a few data points is not a
|
||||
# big issue, as long as things eventually recover.
|
||||
logger.warn("Failed to log result to w&b: {}".format(str(e)))
|
||||
wandb.finish()
|
||||
|
||||
def _handle_checkpoint(self, checkpoint_path: str):
|
||||
artifact = wandb.Artifact(name=f"checkpoint_{self._trial_name}", type="model")
|
||||
artifact.add_dir(checkpoint_path)
|
||||
wandb.log_artifact(artifact)
|
||||
|
||||
def _handle_result(self, result: Dict) -> Tuple[Dict, Dict]:
|
||||
config_update = result.get("config", {}).copy()
|
||||
log = {}
|
||||
flat_result = flatten_dict(result, delimiter="/")
|
||||
|
||||
for k, v in flat_result.items():
|
||||
if any(k.startswith(item + "/") or k == item for item in self._to_config):
|
||||
config_update[k] = v
|
||||
elif any(k.startswith(item + "/") or k == item for item in self._exclude):
|
||||
continue
|
||||
elif not _is_allowed_type(v):
|
||||
continue
|
||||
else:
|
||||
log[k] = v
|
||||
|
||||
config_update.pop("callbacks", None) # Remove callbacks
|
||||
return log, config_update
|
||||
|
||||
|
||||
class WandbLoggerCallback(LoggerCallback):
    """WandbLoggerCallback

    Weights and biases (https://www.wandb.ai/) is a tool for experiment
    tracking, model optimization, and dataset versioning. This Ray Tune
    ``LoggerCallback`` sends metrics to Wandb for automatic tracking and
    visualization.

    Args:
        project: Name of the Wandb project. Mandatory.
        group: Name of the Wandb group. Defaults to the trainable
            name.
        api_key_file: Path to file containing the Wandb API KEY. This
            file only needs to be present on the node running the Tune script
            if using the WandbLogger.
        api_key: Wandb API Key. Alternative to setting ``api_key_file``.
        excludes: List of metrics that should be excluded from
            the log.
        log_config: Boolean indicating if the ``config`` parameter of
            the ``results`` dict should be logged. This makes sense if
            parameters will change during training, e.g. with
            PopulationBasedTraining. Defaults to False.
        save_checkpoints: If ``True``, model checkpoints will be saved to
            Wandb as artifacts. Defaults to ``False``.
        **kwargs: The keyword arguments will be passed to ``wandb.init()``.

    Wandb's ``group``, ``run_id`` and ``run_name`` are automatically selected
    by Tune, but can be overwritten by filling out the respective configuration
    values.

    Please see here for all other valid configuration settings:
    https://docs.wandb.ai/library/init

    Example:

    .. code-block:: python

        from ray.air.callbacks.wandb import WandbLoggerCallback
        tune.run(
            train_fn,
            config={
                # define search space here
                "parameter_1": tune.choice([1, 2, 3]),
                "parameter_2": tune.choice([4, 5, 6]),
            },
            callbacks=[WandbLoggerCallback(
                project="Optimization_Project",
                api_key_file="/path/to/file",
                log_config=True)])

    """

    # Do not log these result keys
    _exclude_results = ["done", "should_checkpoint"]

    # Use these result keys to update `wandb.config`
    _config_results = [
        "trial_id",
        "experiment_tag",
        "node_ip",
        "experiment_id",
        "hostname",
        "pid",
        "date",
    ]

    # Class used to create the per-trial logging process.
    _logger_process_cls = _WandbLoggingProcess

    def __init__(
        self,
        project: str,
        group: Optional[str] = None,
        api_key_file: Optional[str] = None,
        api_key: Optional[str] = None,
        excludes: Optional[List[str]] = None,
        log_config: bool = False,
        save_checkpoints: bool = False,
        **kwargs,
    ):
        self.project = project
        self.group = group
        self.api_key_path = api_key_file
        self.api_key = api_key
        self.excludes = excludes or []
        self.log_config = log_config
        self.save_checkpoints = save_checkpoints
        self.kwargs = kwargs

        # One logging process and one queue per trial, keyed by trial.
        self._trial_processes: Dict["Trial", _WandbLoggingProcess] = {}
        self._trial_queues: Dict["Trial", Queue] = {}

    def setup(self, *args, **kwargs):
        """Resolve the API key file path and make the API key available."""
        self.api_key_file = (
            os.path.expanduser(self.api_key_path) if self.api_key_path else None
        )
        _set_api_key(self.api_key_file, self.api_key)

    def log_trial_start(self, trial: "Trial"):
        """Create and start the queue and logging process for ``trial``."""
        config = trial.config.copy()

        config.pop("callbacks", None)  # Remove callbacks

        exclude_results = self._exclude_results.copy()

        # Additional excludes
        exclude_results += self.excludes

        # Log config keys on each result?
        if not self.log_config:
            exclude_results += ["config"]

        # Fill trial ID and name
        trial_id = trial.trial_id if trial else None
        trial_name = str(trial) if trial else None

        # Project name for Wandb
        wandb_project = self.project

        # Grouping: an explicit group wins over the trainable name.
        wandb_group = (self.group or trial.trainable_name) if trial else None

        # remove unpickleable items!
        config = _clean_log(config)

        wandb_init_kwargs = dict(
            id=trial_id,
            name=trial_name,
            resume=False,
            reinit=True,
            allow_val_change=True,
            group=wandb_group,
            project=wandb_project,
            config=config,
        )
        # User-supplied keyword arguments override the defaults above.
        wandb_init_kwargs.update(self.kwargs)

        self._trial_queues[trial] = Queue()
        self._trial_processes[trial] = self._logger_process_cls(
            logdir=trial.logdir,
            queue=self._trial_queues[trial],
            exclude=exclude_results,
            to_config=self._config_results,
            **wandb_init_kwargs,
        )
        self._trial_processes[trial].start()

    def log_trial_result(self, iteration: int, trial: "Trial", result: Dict):
        """Queue a cleaned copy of ``result`` for the trial's logging process."""
        # Start the logging process lazily if the trial start was not seen.
        if trial not in self._trial_processes:
            self.log_trial_start(trial)

        result = _clean_log(result)
        self._trial_queues[trial].put((_QueueItem.RESULT, result))

    def log_trial_save(self, trial: "Trial"):
        """Queue the trial's checkpoint for upload if checkpoints are saved."""
        if self.save_checkpoints and trial.checkpoint:
            self._trial_queues[trial].put(
                (_QueueItem.CHECKPOINT, trial.checkpoint.dir_or_data)
            )

    def log_trial_end(self, trial: "Trial", failed: bool = False):
        """Ask the trial's logging process to stop and drop its bookkeeping."""
        self._trial_queues[trial].put((_QueueItem.END, None))
        self._trial_processes[trial].join(timeout=10)

        del self._trial_queues[trial]
        del self._trial_processes[trial]

    def __del__(self):
        # ``__del__`` can run on an instance whose ``__init__`` did not
        # finish, so do not assume the bookkeeping dicts exist.
        processes = getattr(self, "_trial_processes", {})
        queues = getattr(self, "_trial_queues", {})

        # Iterate over a snapshot of the keys: entries are deleted inside
        # the loop, and deleting from a dict while iterating it directly
        # raises RuntimeError.
        for trial in list(processes):
            if trial in queues:
                queues[trial].put((_QueueItem.END, None))
                del queues[trial]
            processes[trial].join(timeout=2)
            del processes[trial]
|
|
@ -8,7 +8,8 @@ import time
|
|||
import mlflow
|
||||
|
||||
from ray import tune
|
||||
from ray.tune.integration.mlflow import MLflowLoggerCallback, mlflow_mixin
|
||||
from ray.air.callbacks.mlflow import MLflowLoggerCallback
|
||||
from ray.tune.integration.mlflow import mlflow_mixin
|
||||
|
||||
|
||||
def evaluation_fn(step, width, height):
|
||||
|
|
|
@ -7,8 +7,8 @@ import wandb
|
|||
|
||||
from ray import tune
|
||||
from ray.tune import Trainable
|
||||
from ray.air.callbacks.wandb import WandbLoggerCallback
|
||||
from ray.tune.integration.wandb import (
|
||||
WandbLoggerCallback,
|
||||
WandbTrainableMixin,
|
||||
wandb_mixin,
|
||||
)
|
||||
|
|
|
@ -1,249 +1,28 @@
|
|||
import os
|
||||
from typing import Dict, List
|
||||
from ray.air.callbacks.comet import CometLoggerCallback as _CometLoggerCallback
|
||||
from typing import List
|
||||
|
||||
from ray.tune.logger import LoggerCallback
|
||||
from ray.tune.experiment import Trial
|
||||
from ray.tune.utils import flatten_dict
|
||||
import logging
|
||||
|
||||
from ray.util.annotations import Deprecated
|
||||
|
||||
def _import_comet():
|
||||
"""Try importing comet_ml.
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
Used to check if comet_ml is installed and, otherwise, pass an informative
|
||||
error message.
|
||||
"""
|
||||
if "COMET_DISABLE_AUTO_LOGGING" not in os.environ:
|
||||
os.environ["COMET_DISABLE_AUTO_LOGGING"] = "1"
|
||||
|
||||
try:
|
||||
import comet_ml # noqa: F401
|
||||
except ImportError:
|
||||
raise RuntimeError("pip install 'comet-ml' to use CometLoggerCallback")
|
||||
|
||||
return comet_ml
|
||||
|
||||
|
||||
class CometLoggerCallback(LoggerCallback):
|
||||
"""CometLoggerCallback for logging Tune results to Comet.
|
||||
|
||||
Comet (https://comet.ml/site/) is a tool to manage and optimize the
|
||||
entire ML lifecycle, from experiment tracking, model optimization
|
||||
and dataset versioning to model production monitoring.
|
||||
|
||||
This Ray Tune ``LoggerCallback`` sends metrics and parameters to
|
||||
Comet for tracking.
|
||||
|
||||
In order to use the CometLoggerCallback you must first install Comet
|
||||
via ``pip install comet_ml``
|
||||
|
||||
Then set the following environment variables
|
||||
``export COMET_API_KEY=<Your API Key>``
|
||||
|
||||
Alternatively, you can also pass in your API Key as an argument to the
|
||||
CometLoggerCallback constructor.
|
||||
|
||||
``CometLoggerCallback(api_key=<Your API Key>)``
|
||||
|
||||
Args:
|
||||
online: Whether to make use of an Online or
|
||||
Offline Experiment. Defaults to True.
|
||||
tags: Tags to add to the logged Experiment.
|
||||
Defaults to None.
|
||||
save_checkpoints: If ``True``, model checkpoints will be saved to
|
||||
Comet ML as artifacts. Defaults to ``False``.
|
||||
**experiment_kwargs: Other keyword arguments will be passed to the
|
||||
constructor for comet_ml.Experiment (or OfflineExperiment if
|
||||
online=False).
|
||||
|
||||
Please consult the Comet ML documentation for more information on the
|
||||
Experiment and OfflineExperiment classes: https://comet.ml/site/
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from ray.tune.integration.comet import CometLoggerCallback
|
||||
tune.run(
|
||||
train,
|
||||
config=config
|
||||
callbacks=[CometLoggerCallback(
|
||||
True,
|
||||
['tag1', 'tag2'],
|
||||
workspace='my_workspace',
|
||||
project_name='my_project_name'
|
||||
)]
|
||||
callback_deprecation_message = (
|
||||
"`ray.tune.integration.comet.CometLoggerCallback` "
|
||||
"is deprecated and will be removed in "
|
||||
"the future. Please use `ray.air.callbacks.comet.CometLoggerCallback` "
|
||||
"instead."
|
||||
)
|
||||
|
||||
"""
|
||||
|
||||
# Do not enable these auto log options unless overridden
|
||||
_exclude_autolog = [
|
||||
"auto_output_logging",
|
||||
"log_git_metadata",
|
||||
"log_git_patch",
|
||||
"log_env_cpu",
|
||||
"log_env_gpu",
|
||||
]
|
||||
|
||||
# Do not log these metrics.
|
||||
_exclude_results = ["done", "should_checkpoint"]
|
||||
|
||||
# These values should be logged as system info instead of metrics.
|
||||
_system_results = ["node_ip", "hostname", "pid", "date"]
|
||||
|
||||
# These values should be logged as "Other" instead of as metrics.
|
||||
_other_results = ["trial_id", "experiment_id", "experiment_tag"]
|
||||
|
||||
_episode_results = ["hist_stats/episode_reward", "hist_stats/episode_lengths"]
|
||||
|
||||
@Deprecated(message=callback_deprecation_message)
|
||||
class CometLoggerCallback(_CometLoggerCallback):
|
||||
def __init__(
|
||||
self,
|
||||
online: bool = True,
|
||||
tags: List[str] = None,
|
||||
save_checkpoints: bool = False,
|
||||
**experiment_kwargs,
|
||||
**experiment_kwargs
|
||||
):
|
||||
_import_comet()
|
||||
self.online = online
|
||||
self.tags = tags
|
||||
self.save_checkpoints = save_checkpoints
|
||||
self.experiment_kwargs = experiment_kwargs
|
||||
|
||||
# Disable the specific autologging features that cause throttling.
|
||||
self._configure_experiment_defaults()
|
||||
|
||||
# Mapping from trial to experiment object.
|
||||
self._trial_experiments = {}
|
||||
|
||||
self._to_exclude = self._exclude_results.copy()
|
||||
self._to_system = self._system_results.copy()
|
||||
self._to_other = self._other_results.copy()
|
||||
self._to_episodes = self._episode_results.copy()
|
||||
|
||||
def _configure_experiment_defaults(self):
|
||||
"""Disable the specific autologging features that cause throttling."""
|
||||
for option in self._exclude_autolog:
|
||||
if not self.experiment_kwargs.get(option):
|
||||
self.experiment_kwargs[option] = False
|
||||
|
||||
def _check_key_name(self, key: str, item: str) -> bool:
|
||||
"""
|
||||
Check if key argument is equal to item argument or starts with item and
|
||||
a forward slash. Used for parsing trial result dictionary into ignored
|
||||
keys, system metrics, episode logs, etc.
|
||||
"""
|
||||
return key.startswith(item + "/") or key == item
|
||||
|
||||
def log_trial_start(self, trial: "Trial"):
|
||||
"""
|
||||
Initialize an Experiment (or OfflineExperiment if self.online=False)
|
||||
and start logging to Comet.
|
||||
|
||||
Args:
|
||||
trial: Trial object.
|
||||
|
||||
"""
|
||||
_import_comet() # is this necessary?
|
||||
from comet_ml import Experiment, OfflineExperiment
|
||||
from comet_ml.config import set_global_experiment
|
||||
|
||||
if trial not in self._trial_experiments:
|
||||
experiment_cls = Experiment if self.online else OfflineExperiment
|
||||
experiment = experiment_cls(**self.experiment_kwargs)
|
||||
self._trial_experiments[trial] = experiment
|
||||
# Set global experiment to None to allow for multiple experiments.
|
||||
set_global_experiment(None)
|
||||
else:
|
||||
experiment = self._trial_experiments[trial]
|
||||
|
||||
experiment.set_name(str(trial))
|
||||
experiment.add_tags(self.tags)
|
||||
experiment.log_other("Created from", "Ray")
|
||||
|
||||
config = trial.config.copy()
|
||||
config.pop("callbacks", None)
|
||||
experiment.log_parameters(config)
|
||||
|
||||
def log_trial_result(self, iteration: int, trial: "Trial", result: Dict):
|
||||
"""
|
||||
Log the current result of a Trial upon each iteration.
|
||||
"""
|
||||
if trial not in self._trial_experiments:
|
||||
self.log_trial_start(trial)
|
||||
experiment = self._trial_experiments[trial]
|
||||
step = result["training_iteration"]
|
||||
|
||||
config_update = result.pop("config", {}).copy()
|
||||
config_update.pop("callbacks", None) # Remove callbacks
|
||||
for k, v in config_update.items():
|
||||
if isinstance(v, dict):
|
||||
experiment.log_parameters(flatten_dict({k: v}, "/"), step=step)
|
||||
|
||||
else:
|
||||
experiment.log_parameter(k, v, step=step)
|
||||
|
||||
other_logs = {}
|
||||
metric_logs = {}
|
||||
system_logs = {}
|
||||
episode_logs = {}
|
||||
|
||||
flat_result = flatten_dict(result, delimiter="/")
|
||||
for k, v in flat_result.items():
|
||||
if any(self._check_key_name(k, item) for item in self._to_exclude):
|
||||
continue
|
||||
|
||||
if any(self._check_key_name(k, item) for item in self._to_other):
|
||||
other_logs[k] = v
|
||||
|
||||
elif any(self._check_key_name(k, item) for item in self._to_system):
|
||||
system_logs[k] = v
|
||||
|
||||
elif any(self._check_key_name(k, item) for item in self._to_episodes):
|
||||
episode_logs[k] = v
|
||||
|
||||
else:
|
||||
metric_logs[k] = v
|
||||
|
||||
experiment.log_others(other_logs)
|
||||
experiment.log_metrics(metric_logs, step=step)
|
||||
|
||||
for k, v in system_logs.items():
|
||||
experiment.log_system_info(k, v)
|
||||
|
||||
for k, v in episode_logs.items():
|
||||
experiment.log_curve(k, x=range(len(v)), y=v, step=step)
|
||||
|
||||
def log_trial_save(self, trial: "Trial"):
|
||||
comet_ml = _import_comet()
|
||||
|
||||
if self.save_checkpoints and trial.checkpoint:
|
||||
experiment = self._trial_experiments[trial]
|
||||
|
||||
artifact = comet_ml.Artifact(
|
||||
name=f"checkpoint_{(str(trial))}", artifact_type="model"
|
||||
)
|
||||
|
||||
# Walk through checkpoint directory and add all files to artifact
|
||||
checkpoint_root = trial.checkpoint.dir_or_data
|
||||
for root, dirs, files in os.walk(checkpoint_root):
|
||||
rel_root = os.path.relpath(root, checkpoint_root)
|
||||
for file in files:
|
||||
local_file = os.path.join(checkpoint_root, rel_root, file)
|
||||
logical_path = os.path.join(rel_root, file)
|
||||
|
||||
# Strip leading `./`
|
||||
if logical_path.startswith("./"):
|
||||
logical_path = logical_path[2:]
|
||||
|
||||
artifact.add(local_file, logical_path=logical_path)
|
||||
|
||||
experiment.log_artifact(artifact)
|
||||
|
||||
def log_trial_end(self, trial: "Trial", failed: bool = False):
|
||||
self._trial_experiments[trial].end()
|
||||
del self._trial_experiments[trial]
|
||||
|
||||
def __del__(self):
|
||||
for trial, experiment in self._trial_experiments.items():
|
||||
experiment.end()
|
||||
self._trial_experiments = {}
|
||||
logging.warning(callback_deprecation_message)
|
||||
super().__init__(online, tags, save_checkpoints, **experiment_kwargs)
|
||||
|
|
|
@ -1,65 +1,25 @@
|
|||
from ray.air.callbacks.mlflow import MLflowLoggerCallback as _MLflowLoggerCallback
|
||||
|
||||
import logging
|
||||
from typing import Callable, Dict, Optional
|
||||
|
||||
import ray
|
||||
from ray.tune.logger import LoggerCallback
|
||||
from ray.tune.result import TIMESTEPS_TOTAL, TRAINING_ITERATION
|
||||
from ray.tune.trainable import Trainable
|
||||
from ray.tune.experiment import Trial
|
||||
from ray.util.annotations import Deprecated
|
||||
from ray.util.ml_utils.mlflow import _MLflowLoggerUtil
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
callback_deprecation_message = (
|
||||
"`ray.tune.integration.mlflow.MLflowLoggerCallback` "
|
||||
"is deprecated and will be removed in "
|
||||
"the future. Please use `ray.air.callbacks.mlflow.MLflowLoggerCallback` "
|
||||
"instead."
|
||||
)
|
||||
|
||||
class MLflowLoggerCallback(LoggerCallback):
|
||||
"""MLflow Logger to automatically log Tune results and config to MLflow.
|
||||
|
||||
MLflow (https://mlflow.org) Tracking is an open source library for
|
||||
recording and querying experiments. This Ray Tune ``LoggerCallback``
|
||||
sends information (config parameters, training results & metrics,
|
||||
and artifacts) to MLflow for automatic experiment tracking.
|
||||
|
||||
Args:
|
||||
tracking_uri: The tracking URI for where to manage experiments
|
||||
and runs. This can either be a local file path or a remote server.
|
||||
This arg gets passed directly to mlflow
|
||||
initialization. When using Tune in a multi-node setting, make sure
|
||||
to set this to a remote server and not a local file path.
|
||||
registry_uri: The registry URI that gets passed directly to
|
||||
mlflow initialization.
|
||||
experiment_name: The experiment name to use for this Tune run.
|
||||
If the experiment with the name already exists with MLflow,
|
||||
it will be reused. If not, a new experiment will be created with
|
||||
that name.
|
||||
tags: An optional dictionary of string keys and values to set
|
||||
as tags on the run
|
||||
save_artifact: If set to True, automatically save the entire
|
||||
contents of the Tune local_dir as an artifact to the
|
||||
corresponding run in MlFlow.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from ray.tune.integration.mlflow import MLflowLoggerCallback
|
||||
|
||||
tags = { "user_name" : "John",
|
||||
"git_commit_hash" : "abc123"}
|
||||
|
||||
tune.run(
|
||||
train_fn,
|
||||
config={
|
||||
# define search space here
|
||||
"parameter_1": tune.choice([1, 2, 3]),
|
||||
"parameter_2": tune.choice([4, 5, 6]),
|
||||
},
|
||||
callbacks=[MLflowLoggerCallback(
|
||||
experiment_name="experiment1",
|
||||
tags=tags,
|
||||
save_artifact=True)])
|
||||
|
||||
"""
|
||||
|
||||
@Deprecated(message=callback_deprecation_message)
|
||||
class MLflowLoggerCallback(_MLflowLoggerCallback):
|
||||
def __init__(
|
||||
self,
|
||||
tracking_uri: Optional[str] = None,
|
||||
|
@ -68,72 +28,11 @@ class MLflowLoggerCallback(LoggerCallback):
|
|||
tags: Optional[Dict] = None,
|
||||
save_artifact: bool = False,
|
||||
):
|
||||
|
||||
self.tracking_uri = tracking_uri
|
||||
self.registry_uri = registry_uri
|
||||
self.experiment_name = experiment_name
|
||||
self.tags = tags
|
||||
self.should_save_artifact = save_artifact
|
||||
|
||||
self.mlflow_util = _MLflowLoggerUtil()
|
||||
|
||||
if ray.util.client.ray.is_connected():
|
||||
logger.warning(
|
||||
"When using MLflowLoggerCallback with Ray Client, "
|
||||
"it is recommended to use a remote tracking "
|
||||
"server. If you are using a MLflow tracking server "
|
||||
"backed by the local filesystem, then it must be "
|
||||
"setup on the server side and not on the client "
|
||||
"side."
|
||||
logger.warning(callback_deprecation_message)
|
||||
super().__init__(
|
||||
tracking_uri, registry_uri, experiment_name, tags, save_artifact
|
||||
)
|
||||
|
||||
def setup(self, *args, **kwargs):
|
||||
# Setup the mlflow logging util.
|
||||
self.mlflow_util.setup_mlflow(
|
||||
tracking_uri=self.tracking_uri,
|
||||
registry_uri=self.registry_uri,
|
||||
experiment_name=self.experiment_name,
|
||||
)
|
||||
|
||||
if self.tags is None:
|
||||
# Create empty dictionary for tags if not given explicitly
|
||||
self.tags = {}
|
||||
|
||||
self._trial_runs = {}
|
||||
|
||||
def log_trial_start(self, trial: "Trial"):
|
||||
# Create run if not already exists.
|
||||
if trial not in self._trial_runs:
|
||||
|
||||
# Set trial name in tags
|
||||
tags = self.tags.copy()
|
||||
tags["trial_name"] = str(trial)
|
||||
|
||||
run = self.mlflow_util.start_run(tags=tags, run_name=str(trial))
|
||||
self._trial_runs[trial] = run.info.run_id
|
||||
|
||||
run_id = self._trial_runs[trial]
|
||||
|
||||
# Log the config parameters.
|
||||
config = trial.config
|
||||
self.mlflow_util.log_params(run_id=run_id, params_to_log=config)
|
||||
|
||||
def log_trial_result(self, iteration: int, trial: "Trial", result: Dict):
|
||||
step = result.get(TIMESTEPS_TOTAL) or result[TRAINING_ITERATION]
|
||||
run_id = self._trial_runs[trial]
|
||||
self.mlflow_util.log_metrics(run_id=run_id, metrics_to_log=result, step=step)
|
||||
|
||||
def log_trial_end(self, trial: "Trial", failed: bool = False):
|
||||
run_id = self._trial_runs[trial]
|
||||
|
||||
# Log the artifact if set_artifact is set to True.
|
||||
if self.should_save_artifact:
|
||||
self.mlflow_util.save_artifacts(run_id=run_id, dir=trial.logdir)
|
||||
|
||||
# Stop the run once trial finishes.
|
||||
status = "FINISHED" if not failed else "FAILED"
|
||||
self.mlflow_util.end_run(run_id=run_id, status=status)
|
||||
|
||||
|
||||
def mlflow_mixin(func: Callable):
|
||||
"""mlflow_mixin
|
||||
|
|
|
@ -1,84 +1,54 @@
|
|||
import enum
|
||||
import os
|
||||
import pickle
|
||||
from collections.abc import Sequence
|
||||
from multiprocessing import Process, Queue
|
||||
from numbers import Number
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||
import numpy as np
|
||||
import urllib
|
||||
from typing import List, Dict, Callable, Optional
|
||||
|
||||
from ray import logger
|
||||
from ray.tune import Trainable
|
||||
from ray.tune.trainable import FunctionTrainable
|
||||
from ray.tune.logger import LoggerCallback
|
||||
from ray.tune.utils import flatten_dict
|
||||
from ray.tune.experiment import Trial
|
||||
|
||||
import yaml
|
||||
|
||||
try:
|
||||
import wandb
|
||||
except ImportError:
|
||||
logger.error("pip install 'wandb' to use WandbLoggerCallback/WandbTrainableMixin.")
|
||||
wandb = None
|
||||
|
||||
WANDB_ENV_VAR = "WANDB_API_KEY"
|
||||
_VALID_TYPES = (Number, wandb.data_types.Video, wandb.data_types.Image)
|
||||
_VALID_ITERABLE_TYPES = (wandb.data_types.Video, wandb.data_types.Image)
|
||||
|
||||
|
||||
def _is_allowed_type(obj):
|
||||
"""Return True if type is allowed for logging to wandb"""
|
||||
if isinstance(obj, np.ndarray) and obj.size == 1:
|
||||
return isinstance(obj.item(), Number)
|
||||
if isinstance(obj, Sequence) and len(obj) > 0:
|
||||
return isinstance(obj[0], _VALID_ITERABLE_TYPES)
|
||||
return isinstance(obj, _VALID_TYPES)
|
||||
|
||||
|
||||
def _clean_log(obj: Any):
|
||||
# Fixes https://github.com/ray-project/ray/issues/10631
|
||||
if isinstance(obj, dict):
|
||||
return {k: _clean_log(v) for k, v in obj.items()}
|
||||
elif isinstance(obj, list):
|
||||
return [_clean_log(v) for v in obj]
|
||||
elif isinstance(obj, tuple):
|
||||
return tuple(_clean_log(v) for v in obj)
|
||||
elif _is_allowed_type(obj):
|
||||
return obj
|
||||
|
||||
# Else
|
||||
try:
|
||||
pickle.dumps(obj)
|
||||
yaml.dump(
|
||||
obj,
|
||||
Dumper=yaml.SafeDumper,
|
||||
default_flow_style=False,
|
||||
allow_unicode=True,
|
||||
encoding="utf-8",
|
||||
from ray.air.callbacks.wandb import (
|
||||
wandb,
|
||||
_clean_log,
|
||||
_set_api_key,
|
||||
WandbLoggerCallback as _WandbLoggerCallback,
|
||||
)
|
||||
return obj
|
||||
except Exception:
|
||||
# give up, similar to _SafeFallBackEncoder
|
||||
fallback = str(obj)
|
||||
|
||||
# Try to convert to int
|
||||
try:
|
||||
fallback = int(fallback)
|
||||
return fallback
|
||||
except ValueError:
|
||||
pass
|
||||
import logging
|
||||
|
||||
# Try to convert to float
|
||||
try:
|
||||
fallback = float(fallback)
|
||||
return fallback
|
||||
except ValueError:
|
||||
pass
|
||||
from ray.util.annotations import Deprecated
|
||||
|
||||
# Else, return string
|
||||
return fallback
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
callback_deprecation_message = (
|
||||
"`ray.tune.integration.wandb.WandbLoggerCallback` "
|
||||
"is deprecated and will be removed in "
|
||||
"the future. Please use `ray.air.callbacks.wandb.WandbLoggerCallback` "
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
@Deprecated(message=callback_deprecation_message)
class WandbLoggerCallback(_WandbLoggerCallback):
    # Backwards-compatibility shim: behaves like the parent callback, but
    # logs the deprecation message every time it is instantiated.
    def __init__(
        self,
        project: str,
        group: Optional[str] = None,
        api_key_file: Optional[str] = None,
        api_key: Optional[str] = None,
        excludes: Optional[List[str]] = None,
        log_config: bool = False,
        save_checkpoints: bool = False,
        **kwargs
    ):
        logger.warning(callback_deprecation_message)

        # Forward every argument to the parent in its declared order.
        forwarded = (
            project,
            group,
            api_key_file,
            api_key,
            excludes,
            log_config,
            save_checkpoints,
        )
        super().__init__(*forwarded, **kwargs)
|
||||
|
||||
|
||||
def wandb_mixin(func: Callable):
|
||||
|
@ -155,297 +125,6 @@ def wandb_mixin(func: Callable):
|
|||
return func
|
||||
|
||||
|
||||
def _set_api_key(api_key_file: Optional[str] = None, api_key: Optional[str] = None):
    """Ensure a W&B API key is available, exporting it if one is passed.

    A key given directly or read from ``api_key_file`` is written to the
    ``WANDB_ENV_VAR`` environment variable. If neither is given, an already
    set environment variable or an existing wandb login is accepted as-is.

    Args:
        api_key_file: Path to a file whose first line is the API key.
        api_key: The API key itself. Must not be combined with
            ``api_key_file``.

    Raises:
        ValueError: If both ``api_key_file`` and ``api_key`` are set, or if
            no API key can be found anywhere.
    """
    if api_key_file:
        if api_key:
            raise ValueError("Both WandB `api_key_file` and `api_key` set.")
        with open(api_key_file, "rt") as fp:
            # Only the first line of the file is used as the key.
            api_key = fp.readline().strip()

    if api_key:
        os.environ[WANDB_ENV_VAR] = api_key
        return

    if os.environ.get(WANDB_ENV_VAR):
        # A key is already exported; nothing to do.
        return

    try:
        # Check if user is already logged into wandb.
        wandb.ensure_configured()
        if wandb.api.api_key:
            logger.info("Already logged into W&B.")
            return
    except AttributeError:
        # The wandb object does not offer these attributes; fall through
        # to the error below.
        pass

    raise ValueError(
        "No WandB API key found. Either set the {} environment "
        "variable, pass `api_key` or `api_key_file` to the "
        "`WandbLoggerCallback` class as arguments, "
        "or run `wandb login` from the command line".format(WANDB_ENV_VAR)
    )
|
||||
|
||||
|
||||
class _QueueItem(enum.Enum):
|
||||
END = enum.auto()
|
||||
RESULT = enum.auto()
|
||||
CHECKPOINT = enum.auto()
|
||||
|
||||
|
||||
class _WandbLoggingProcess(Process):
    """Separate process that owns one wandb run.

    We need a `multiprocessing.Process` to allow multiple concurrent
    wandb logging instances locally.

    We use a queue for the driver to communicate with the logging process.
    Every queue item is a ``(item_type, content)`` tuple:

    - ``_QueueItem.RESULT``: ``content`` is a result dict; it is logged using
      ``wandb.log()`` and selected keys are sent to ``wandb.config``.
    - ``_QueueItem.CHECKPOINT``: ``content`` is a checkpoint path; it is
      saved using ``wandb.log_artifact()``.
    - ``_QueueItem.END``: stop consuming and finish the wandb run.

    Args:
        logdir: Directory the logging process switches into before
            initializing wandb.
        queue: Queue the items described above are read from.
        exclude: Result keys (or key prefixes) that are never logged.
        to_config: Result keys (or key prefixes) that are sent to
            ``wandb.config`` instead of being logged as metrics.
        *args: Positional arguments passed to ``wandb.init()``.
        **kwargs: Keyword arguments passed to ``wandb.init()``.
    """

    def __init__(
        self,
        logdir: str,
        queue: Queue,
        exclude: List[str],
        to_config: List[str],
        *args,
        **kwargs,
    ):
        super().__init__()

        # The constructor runs in the creating (driver) process, so the
        # directory is only remembered here; the actual chdir happens in
        # ``run()`` to avoid changing the driver's working directory.
        self._logdir = logdir
        self.queue = queue
        self._exclude = set(exclude)
        self._to_config = set(to_config)
        self.args = args
        self.kwargs = kwargs

        self._trial_name = self.kwargs.get("name", "unknown")

    def run(self):
        """Initialize wandb and consume queue items until ``END`` arrives."""
        os.chdir(self._logdir)

        # Since we're running in a separate process already, use threads.
        os.environ["WANDB_START_METHOD"] = "thread"
        wandb.init(*self.args, **self.kwargs)

        while True:
            item_type, item_content = self.queue.get()
            if item_type == _QueueItem.END:
                break

            if item_type == _QueueItem.CHECKPOINT:
                self._handle_checkpoint(item_content)
                continue

            assert item_type == _QueueItem.RESULT
            log, config_update = self._handle_result(item_content)
            try:
                wandb.config.update(config_update, allow_val_change=True)
                wandb.log(log)
            except urllib.error.HTTPError as e:
                # Ignore HTTPError. Missing a few data points is not a
                # big issue, as long as things eventually recover.
                logger.warning("Failed to log result to w&b: %s", e)
        wandb.finish()

    def _handle_checkpoint(self, checkpoint_path: str):
        """Upload the directory at ``checkpoint_path`` as a model artifact."""
        artifact = wandb.Artifact(name=f"checkpoint_{self._trial_name}", type="model")
        artifact.add_dir(checkpoint_path)
        wandb.log_artifact(artifact)

    @staticmethod
    def _matches(key: str, names) -> bool:
        """Return True if ``key`` equals, or is nested under, any of ``names``."""
        return any(key == name or key.startswith(name + "/") for name in names)

    def _handle_result(self, result: Dict) -> Tuple[Dict, Dict]:
        """Split a result into values to log and values for ``wandb.config``.

        Returns:
            A ``(log, config_update)`` tuple of dicts.
        """
        config_update = result.get("config", {}).copy()
        log = {}
        flat_result = flatten_dict(result, delimiter="/")

        for k, v in flat_result.items():
            # Config routing takes priority over exclusion.
            if self._matches(k, self._to_config):
                config_update[k] = v
            elif self._matches(k, self._exclude):
                continue
            elif not _is_allowed_type(v):
                continue
            else:
                log[k] = v

        config_update.pop("callbacks", None)  # Remove callbacks
        return log, config_update
|
||||
|
||||
|
||||
class WandbLoggerCallback(LoggerCallback):
    """WandbLoggerCallback

    Weights and biases (https://www.wandb.ai/) is a tool for experiment
    tracking, model optimization, and dataset versioning. This Ray Tune
    ``LoggerCallback`` sends metrics to Wandb for automatic tracking and
    visualization.

    Args:
        project: Name of the Wandb project. Mandatory.
        group: Name of the Wandb group. Defaults to the trainable
            name.
        api_key_file: Path to file containing the Wandb API KEY. This
            file only needs to be present on the node running the Tune script
            if using the WandbLogger.
        api_key: Wandb API Key. Alternative to setting ``api_key_file``.
        excludes: List of metrics that should be excluded from
            the log.
        log_config: Boolean indicating if the ``config`` parameter of
            the ``results`` dict should be logged. This makes sense if
            parameters will change during training, e.g. with
            PopulationBasedTraining. Defaults to False.
        save_checkpoints: If ``True``, model checkpoints will be saved to
            Wandb as artifacts. Defaults to ``False``.
        **kwargs: The keyword arguments will be passed to ``wandb.init()``.

    Wandb's ``group``, ``run_id`` and ``run_name`` are automatically selected
    by Tune, but can be overwritten by filling out the respective configuration
    values.

    Please see here for all other valid configuration settings:
    https://docs.wandb.ai/library/init

    Example:

    .. code-block:: python

        from ray.tune.integration.wandb import WandbLoggerCallback
        tune.run(
            train_fn,
            config={
                # define search space here
                "parameter_1": tune.choice([1, 2, 3]),
                "parameter_2": tune.choice([4, 5, 6]),
            },
            callbacks=[WandbLoggerCallback(
                project="Optimization_Project",
                api_key_file="/path/to/file",
                log_config=True)])

    """

    # Do not log these result keys
    _exclude_results = ["done", "should_checkpoint"]

    # Use these result keys to update `wandb.config`
    _config_results = [
        "trial_id",
        "experiment_tag",
        "node_ip",
        "experiment_id",
        "hostname",
        "pid",
        "date",
    ]

    # Class used to create the per-trial logging process.
    _logger_process_cls = _WandbLoggingProcess

    def __init__(
        self,
        project: str,
        group: Optional[str] = None,
        api_key_file: Optional[str] = None,
        api_key: Optional[str] = None,
        excludes: Optional[List[str]] = None,
        log_config: bool = False,
        save_checkpoints: bool = False,
        **kwargs,
    ):
        self.project = project
        self.group = group
        self.api_key_path = api_key_file
        self.api_key = api_key
        self.excludes = excludes or []
        self.log_config = log_config
        self.save_checkpoints = save_checkpoints
        self.kwargs = kwargs

        # One logging process and one queue per trial.
        self._trial_processes: Dict["Trial", _WandbLoggingProcess] = {}
        self._trial_queues: Dict["Trial", Queue] = {}

    def setup(self, *args, **kwargs):
        """Resolve the API key file path and make the API key available."""
        self.api_key_file = (
            os.path.expanduser(self.api_key_path) if self.api_key_path else None
        )
        _set_api_key(self.api_key_file, self.api_key)

    def log_trial_start(self, trial: "Trial"):
        """Create and start the logging process for ``trial``."""
        config = trial.config.copy()

        config.pop("callbacks", None)  # Remove callbacks

        # Copy so the class-level default list is never mutated.
        exclude_results = self._exclude_results.copy()

        # Additional excludes
        exclude_results += self.excludes

        # Log config keys on each result?
        if not self.log_config:
            exclude_results += ["config"]

        # Fill trial ID and name
        trial_id = trial.trial_id if trial else None
        trial_name = str(trial) if trial else None

        # Project name for Wandb
        wandb_project = self.project

        # Grouping: an explicitly configured group always wins; otherwise
        # fall back to the trainable name.
        wandb_group = self.group or (trial.trainable_name if trial else None)

        # remove unpickleable items!
        config = _clean_log(config)

        wandb_init_kwargs = dict(
            id=trial_id,
            name=trial_name,
            resume=False,
            reinit=True,
            allow_val_change=True,
            group=wandb_group,
            project=wandb_project,
            config=config,
        )
        # User-supplied keyword arguments override the defaults above.
        wandb_init_kwargs.update(self.kwargs)

        self._trial_queues[trial] = Queue()
        self._trial_processes[trial] = self._logger_process_cls(
            logdir=trial.logdir,
            queue=self._trial_queues[trial],
            exclude=exclude_results,
            to_config=self._config_results,
            **wandb_init_kwargs,
        )
        self._trial_processes[trial].start()

    def log_trial_result(self, iteration: int, trial: "Trial", result: Dict):
        """Queue a cleaned copy of ``result`` for the trial's logging process."""
        if trial not in self._trial_processes:
            self.log_trial_start(trial)

        result = _clean_log(result)
        self._trial_queues[trial].put((_QueueItem.RESULT, result))

    def log_trial_save(self, trial: "Trial"):
        """Queue the trial's checkpoint for upload if checkpoint saving is on."""
        if self.save_checkpoints and trial.checkpoint:
            self._trial_queues[trial].put(
                (_QueueItem.CHECKPOINT, trial.checkpoint.dir_or_data)
            )

    def log_trial_end(self, trial: "Trial", failed: bool = False):
        """Signal the trial's logging process to stop and forget the trial."""
        self._trial_queues[trial].put((_QueueItem.END, None))
        self._trial_processes[trial].join(timeout=10)

        del self._trial_queues[trial]
        del self._trial_processes[trial]

    def __del__(self):
        # Iterate over a snapshot of the keys: entries are deleted inside the
        # loop, and deleting from a dict while iterating it directly raises
        # a RuntimeError.
        for trial in list(self._trial_processes):
            if trial in self._trial_queues:
                self._trial_queues[trial].put((_QueueItem.END, None))
                del self._trial_queues[trial]
            self._trial_processes[trial].join(timeout=2)
            del self._trial_processes[trial]
|
||||
|
||||
|
||||
class WandbTrainableMixin:
|
||||
_wandb = wandb
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import unittest
|
||||
from unittest.mock import patch
|
||||
from ray.tune.integration.comet import CometLoggerCallback
|
||||
from ray.air.callbacks.comet import CometLoggerCallback
|
||||
from collections import namedtuple
|
||||
|
||||
|
||||
|
|
|
@ -8,10 +8,12 @@ from mlflow.tracking import MlflowClient
|
|||
|
||||
from ray.tune.trainable import wrap_function
|
||||
from ray.tune.integration.mlflow import (
|
||||
MLflowLoggerCallback,
|
||||
MLflowTrainableMixin,
|
||||
mlflow_mixin,
|
||||
)
|
||||
from ray.air.callbacks.mlflow import (
|
||||
MLflowLoggerCallback,
|
||||
)
|
||||
from ray.util.ml_utils.mlflow import _MLflowLoggerUtil
|
||||
|
||||
|
||||
|
@ -136,7 +138,7 @@ class MLflowTest(unittest.TestCase):
|
|||
logger.setup()
|
||||
self.assertEqual(logger.tags, tags)
|
||||
|
||||
@patch("ray.tune.integration.mlflow._MLflowLoggerUtil", Mock_MLflowLoggerUtil)
|
||||
@patch("ray.air.callbacks.mlflow._MLflowLoggerUtil", Mock_MLflowLoggerUtil)
|
||||
def testMlFlowLoggerLogging(self):
|
||||
clear_env_vars()
|
||||
trial_config = {"par1": "a", "par2": "b"}
|
||||
|
|
|
@ -9,11 +9,13 @@ import numpy as np
|
|||
from ray.tune import Trainable
|
||||
from ray.tune.trainable import wrap_function
|
||||
from ray.tune.integration.wandb import (
|
||||
WandbTrainableMixin,
|
||||
wandb_mixin,
|
||||
)
|
||||
from ray.air.callbacks.wandb import (
|
||||
WandbLoggerCallback,
|
||||
_WandbLoggingProcess,
|
||||
WANDB_ENV_VAR,
|
||||
WandbTrainableMixin,
|
||||
wandb_mixin,
|
||||
_QueueItem,
|
||||
)
|
||||
from ray.tune.result import TRIAL_INFO
|
||||
|
|
Loading…
Add table
Reference in a new issue