{
"cells": [
{
"cell_type": "markdown",
"id": "3b05af3b",
"metadata": {},
"source": [
"(tune-comet-ref)=\n",
"\n",
"# Using Comet with Tune\n",
"\n",
"[Comet](https://www.comet.ml/site/) is a tool to manage and optimize the\n",
"entire ML lifecycle, from experiment tracking, model optimization and dataset\n",
"versioning to model production monitoring.\n",
"\n",
"```{image} /images/comet_logo_full.png\n",
":align: center\n",
":alt: Comet\n",
":height: 120px\n",
":target: https://www.comet.ml/site/\n",
"```\n",
"\n",
"```{contents}\n",
":backlinks: none\n",
":local: true\n",
"```\n",
"\n",
"## Example\n",
"\n",
"To illustrate logging your trial results to Comet, we'll define a simple training function\n",
"that simulates a `loss` metric:"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "19e3c389",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from ray import air, tune\n",
"from ray.air import session\n",
"\n",
"\n",
"def train_function(config):\n",
"    # Simulate 30 training iterations, each reporting a noisy loss.\n",
"    for i in range(30):\n",
"        loss = config[\"mean\"] + config[\"sd\"] * np.random.randn()\n",
"        session.report({\"loss\": loss})"
]
},
{
"cell_type": "markdown",
"id": "6fb69a24",
"metadata": {},
"source": [
"Next, provide your Comet API key and project name:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "993d5be6",
"metadata": {},
"outputs": [],
"source": [
"api_key = \"YOUR_COMET_API_KEY\"\n",
"project_name = \"YOUR_COMET_PROJECT_NAME\""
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "e9ce0d76",
"metadata": {
"tags": [
"remove-cell"
]
},
"outputs": [],
"source": [
"# This cell is hidden from the rendered notebook. It mocks the Comet logger\n",
"# process so that the notebook can run without real Comet credentials.\n",
"from unittest.mock import MagicMock\n",
"from ray.air.callbacks.comet import CometLoggerCallback\n",
"\n",
"CometLoggerCallback._logger_process_cls = MagicMock\n",
"api_key = \"abc\"\n",
"project_name = \"test\""
]
},
{
"cell_type": "markdown",
"id": "d792a1b0",
"metadata": {},
"source": [
"You can then add a Comet logger by passing a `CometLoggerCallback` to the `callbacks` argument of your `RunConfig()`:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "dbb761e7",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2022-07-22 15:41:21,477\tINFO services.py:1483 -- View the Ray dashboard at \u001b[1m\u001b[32mhttp://127.0.0.1:8267\u001b[39m\u001b[22m\n",
"/Users/kai/coding/ray/python/ray/tune/trainable/function_trainable.py:643: DeprecationWarning: `checkpoint_dir` in `func(config, checkpoint_dir)` is being deprecated. To save and load checkpoint in trainable functions, please use the `ray.air.session` API:\n",
"\n",
"from ray.air import session\n",
"\n",
"def train(config):\n",
" # ...\n",
" session.report({\"metric\": metric}, checkpoint=checkpoint)\n",
"\n",
"For more information please see https://docs.ray.io/en/master/ray-air/key-concepts.html#session\n",
"\n",
" DeprecationWarning,\n"
]
},
{
"data": {
"text/html": [
"== Status ==
Current time: 2022-07-22 15:41:31 (running for 00:00:06.73)
Memory usage on this node: 9.9/16.0 GiB
Using FIFO scheduling algorithm.
Resources requested: 0/16 CPUs, 0/0 GPUs, 0.0/4.5 GiB heap, 0.0/2.0 GiB objects
Current best trial: 5bf98_00000 with loss=1.0234101880766688 and parameters={'mean': 1, 'sd': 0.40575843135279466}
Result logdir: /Users/kai/ray_results/train_function_2022-07-22_15-41-18
Number of trials: 3/3 (3 TERMINATED)
Trial name | status | loc | mean | sd | iter | total time (s) | loss |
---|---|---|---|---|---|---|---|
train_function_5bf98_00000 | TERMINATED | 127.0.0.1:48140 | 1 | 0.405758 | 30 | 2.11758 | 1.02341 |
train_function_5bf98_00001 | TERMINATED | 127.0.0.1:48147 | 2 | 0.647335 | 30 | 0.0770731 | 1.53993 |
train_function_5bf98_00002 | TERMINATED | 127.0.0.1:48151 | 3 | 0.256568 | 30 | 0.0728431 | 3.0393 |