{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Fine-tune a 🤗 Transformers model"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "VaFMt6AIhYbK"
},
"source": [
"This notebook is based on [an official 🤗 notebook - \"How to fine-tune a model on text classification\"](https://github.com/huggingface/notebooks/blob/6ca682955173cc9d36ffa431ddda505a048cbe80/examples/text_classification.ipynb). It shows how to convert vanilla 🤗 Transformers code to [Ray AIR](https://docs.ray.io/en/latest/ray-air/getting-started.html) while changing the training logic only where necessary.\n",
"\n",
"In this notebook, we will:\n",
"1. [Set up Ray](#setup)\n",
"2. [Load the dataset](#load)\n",
"3. [Preprocess the dataset](#preprocess)\n",
"4. [Run the training with Ray AIR](#train)\n",
"5. [Predict on test data with Ray AIR](#predict)\n",
"6. [Optionally, share the model with the community](#share)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "sQbdfyWQhYbO"
},
"source": [
"Uncomment and run the following line to install all the necessary dependencies:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "YajFzmkthYbO"
},
"outputs": [],
"source": [
"#! pip install \"datasets\" \"transformers>=4.19.0\" \"torch>=1.10.0\" \"mlflow\" \"ray[air]>=1.13\""
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "pvSRaEHChYbP"
},
"source": [
"## Set up Ray <a name=\"setup\"></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "LRdL3kWBhYbQ"
},
"source": [
"We will use `ray.init()` to initialize a local cluster. By default, this cluster will be composed of only the machine you are running this notebook on. You can also run this notebook on an Anyscale cluster.\n",
"\n",
"This notebook *will not* run in [Ray Client](https://docs.ray.io/en/latest/cluster/ray-client.html) mode."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "MOsHUjgdIrIW",
"outputId": "e527bdbb-2f28-4142-cca0-762e0566cbcd"
},
"outputs": [
{
"data": {
"text/plain": [
"RayContext(dashboard_url='', python_version='3.7.13', ray_version='2.0.0.dev0', ray_commit='e2ee2140f97ca08b70fd0f7561038b7f8d958d63', address_info={'node_ip_address': '172.28.0.2', 'raylet_ip_address': '172.28.0.2', 'redis_address': None, 'object_store_address': '/tmp/ray/session_2022-05-12_18-30-10_467499_75/sockets/plasma_store', 'raylet_socket_name': '/tmp/ray/session_2022-05-12_18-30-10_467499_75/sockets/raylet', 'webui_url': '', 'session_dir': '/tmp/ray/session_2022-05-12_18-30-10_467499_75', 'metrics_export_port': 64840, 'gcs_address': '172.28.0.2:58661', 'address': '172.28.0.2:58661', 'node_id': '65d091b8f504ccd72024fd0b1a8445a8f9ea43e86bcbf67868c22ba7'})"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from pprint import pprint\n",
"import ray\n",
"\n",
"ray.init()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "oJiSdWy2hYbR"
},
"source": [
"We can check the resources our cluster is composed of. If you are running this notebook on your local machine or Google Colab, you should see the number of CPU cores and GPUs available on that machine."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "KlMz0dt9hYbS",
"outputId": "2d485449-ee69-4334-fcba-47e0ceb63078"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'CPU': 2.0,\n",
" 'GPU': 1.0,\n",
" 'accelerator_type:T4': 1.0,\n",
" 'memory': 7855477556.0,\n",
" 'node:172.28.0.2': 1.0,\n",
" 'object_store_memory': 3927738777.0}\n"
]
}
],
"source": [
"pprint(ray.cluster_resources())"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "uS6oeJELhYbS"
},
"source": [
"In this notebook, we will see how to fine-tune one of the [🤗 Transformers](https://github.com/huggingface/transformers) models on a text classification task of the [GLUE Benchmark](https://gluebenchmark.com/). We will be running the training using [Ray AIR](https://docs.ray.io/en/latest/ray-air/getting-started.html).\n",
"\n",
"You can change the following two variables to control whether the training (which we will get to later) uses CPUs or GPUs, and how many workers should be spawned. Each worker will claim one CPU or GPU. Make sure not to request more resources than are available!\n",
"\n",
"By default, we will run the training with one GPU worker."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"id": "gAbhv9OqhYbT"
},
"outputs": [],
"source": [
"use_gpu = True # set this to False to run on CPUs\n",
"num_workers = 1 # set this to number of GPUs/CPUs you want to use"
]
},
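{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick sanity check can catch an oversized request before training starts. The following is a minimal sketch, not part of the Ray API: `fits_cluster` is a hypothetical helper that takes a dictionary shaped like the `ray.cluster_resources()` output and assumes that each worker claims exactly one GPU (or one CPU when `use_gpu` is `False`).\n",
"\n",
"```python\n",
"def fits_cluster(resources, num_workers, use_gpu):\n",
"    # Hypothetical helper: assumes one GPU (or one CPU) per worker.\n",
"    key = 'GPU' if use_gpu else 'CPU'\n",
"    # A missing key means the cluster has none of that resource.\n",
"    return num_workers <= resources.get(key, 0)\n",
"\n",
"# CPU and GPU counts copied from the example cluster output in this notebook.\n",
"resources = {'CPU': 2.0, 'GPU': 1.0}\n",
"print(fits_cluster(resources, num_workers=1, use_gpu=True))\n",
"print(fits_cluster(resources, num_workers=2, use_gpu=True))\n",
"```\n",
"\n",
"In a live session you could pass `ray.cluster_resources()` instead of the hard-coded dictionary."
]
},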
{
"cell_type": "markdown",
"metadata": {
"id": "rEJBSTyZIrIb"
},
"source": [
"## Fine-tuning a model on a text classification task"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "kTCFado4IrIc"
},
"source": [
"The GLUE Benchmark is a group of nine classification tasks on sentences or pairs of sentences. If you would like to learn more, refer to the [original notebook](https://github.com/huggingface/notebooks/blob/6ca682955173cc9d36ffa431ddda505a048cbe80/examples/text_classification.ipynb).\n",
"\n",
"Each task is named by its acronym, with `mnli-mm` standing for the mismatched version of MNLI (so same training set as `mnli` but different validation and test sets):"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"id": "YZbiBDuGIrId"
},
"outputs": [],
"source": [
"GLUE_TASKS = [\"cola\", \"mnli\", \"mnli-mm\", \"mrpc\", \"qnli\", \"qqp\", \"rte\", \"sst2\", \"stsb\", \"wnli\"]"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "4RRkXuteIrIh"
},
"source": [
"This notebook is built to run on any of the tasks in the list above, with any model checkpoint from the [Model Hub](https://huggingface.co/models) as long as that model has a version with a classification head. Depending on your model and the GPU you are using, you might need to adjust the batch size to avoid out-of-memory errors. Set the following three parameters, and the rest of the notebook should run smoothly:"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"id": "zVvslsfMIrIh"
},
"outputs": [],
"source": [
"task = \"cola\"\n",
"model_checkpoint = \"distilbert-base-uncased\"\n",
"batch_size = 16"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "whPRbBNbIrIl"
},
"source": [
"### Loading the dataset <a name=\"load\"></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "W7QYTpxXIrIl"
},
"source": [
"We will use the [🤗 Datasets](https://github.com/huggingface/datasets) library to download the data and get the metric we need to use for evaluation (to compare our model to the benchmark). This can be easily done with the functions `load_dataset` and `load_metric`.\n",
"\n",
"Apart from `mnli-mm` being a special code, we can directly pass our task name to those functions.\n",
"\n",
"As Ray AIR doesn't provide integrations for 🤗 Datasets yet, we will simply run the normal 🤗 Datasets code to load the dataset from the Hub."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 200
},
"id": "MwhAeEOuhYbV",
"outputId": "3aff8c73-d6eb-4784-890a-a419403b5bda"
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "bf499d18407642489b7f5acb9dc88ca8",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Downloading builder script: 0%| | 0.00/7.78k [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "032a4b0c60f04ad1839898524ffeb290",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Downloading metadata: 0%| | 0.00/4.47k [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Downloading and preparing dataset glue/cola (download: 368.14 KiB, generated: 596.73 KiB, post-processed: Unknown size, total: 964.86 KiB) to /root/.cache/huggingface/datasets/glue/cola/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad...\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "360558368bf64c35ab14378a2183c644",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Downloading data: 0%| | 0.00/377k [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "1a1ff1601285496b8fd00c40f0633720",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Generating train split: 0%| | 0/8551 [00:00<?, ? examples/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "16dde3df50d74f25adac0db6a210eef8",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Generating validation split: 0%| | 0/1043 [00:00<?, ? examples/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a4dd6698d1f54126b61f1fd0d0dde1f9",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Generating test split: 0%| | 0/1063 [00:00<?, ? examples/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Dataset glue downloaded and prepared to /root/.cache/huggingface/datasets/glue/cola/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad. Subsequent calls will reuse this data.\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "b1fa3ae216f64c0ab17b50ddc8e536b1",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/3 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from datasets import load_dataset\n",
"\n",
"actual_task = \"mnli\" if task == \"mnli-mm\" else task\n",
"datasets = load_dataset(\"glue\", actual_task)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "RzfPtOMoIrIu"
},
"source": [
"The `datasets` object itself is a [`DatasetDict`](https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasetdict), which contains one key each for the training, validation and test sets (with more keys for the mismatched validation and test sets in the special case of `mnli`)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_TOee7nohYbW"
},
"source": [
"We will also need the metric. To avoid serialization errors, we will load the metric inside the training workers later. For now, we only define the function that loads it."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"id": "FNE583uBhYbW"
},
"outputs": [],
"source": [
"from datasets import load_metric\n",
"\n",
"def load_metric_fn():\n",
" return load_metric('glue', actual_task)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "lnjDIuQ3IrI-"
},
"source": [
"The metric is an instance of [`datasets.Metric`](https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Metric)."
]
},
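{
"cell_type": "markdown",
"metadata": {},
"source": [
"For `cola`, the GLUE metric reports the Matthews correlation coefficient. As a rough illustration of what that number measures, here is a pure-Python sketch for binary labels. `matthews_corrcoef_binary` is a hypothetical helper written for this example; it is not the implementation that the `glue` metric uses internally.\n",
"\n",
"```python\n",
"import math\n",
"\n",
"def matthews_corrcoef_binary(predictions, references):\n",
"    # Count the four confusion-matrix cells for labels in {0, 1}.\n",
"    pairs = list(zip(predictions, references))\n",
"    tp = sum(p == 1 and r == 1 for p, r in pairs)\n",
"    tn = sum(p == 0 and r == 0 for p, r in pairs)\n",
"    fp = sum(p == 1 and r == 0 for p, r in pairs)\n",
"    fn = sum(p == 0 and r == 1 for p, r in pairs)\n",
"    denom = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))\n",
"    # By convention, return 0.0 when the denominator vanishes.\n",
"    return (tp * tn - fp * fn) / denom if denom else 0.0\n",
"\n",
"print(matthews_corrcoef_binary([1, 0, 1, 1], [1, 0, 0, 1]))\n",
"```\n",
"\n",
"A value of 1 means perfect agreement, 0 is no better than chance, and -1 means every prediction is inverted."
]
},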
{
"cell_type": "markdown",
"metadata": {
"id": "n9qywopnIrJH"
},
"source": [
"### Preprocessing the data <a name=\"preprocess\"></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "YVx71GdAIrJH"
},
"source": [
"Before we can feed those texts to our model, we need to preprocess them. This is done by a 🤗 Transformers `Tokenizer`, which will (as the name indicates) tokenize the inputs (including converting the tokens to their corresponding IDs in the pretrained vocabulary) and put them in a format the model expects, as well as generate the other inputs that the model requires.\n",
"\n",
"To do all of this, we instantiate our tokenizer with the `AutoTokenizer.from_pretrained` method, which will ensure:\n",
"\n",
"- we get a tokenizer that corresponds to the model architecture we want to use,\n",
"- we download the vocabulary used when pretraining this specific checkpoint."
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 145
},
"id": "eXNLu_-nIrJI",
"outputId": "f545a7a5-f341-4315-cd89-9942a657aa31"
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "8afaa1d7c12a41db8ad9f37c4067bfd4",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Downloading: 0%| | 0.00/28.0 [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "2c5849fe79464a3c990b1bdc140b3860",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Downloading: 0%| | 0.00/483 [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "173cb43e6d594a87bd7a8a0fc6888aeb",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Downloading: 0%| | 0.00/226k [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "6ea57663b5244adfa0780b8aca40a035",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Downloading: 0%| | 0.00/455k [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from transformers import AutoTokenizer\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Vl6IidfdIrJK"
},
"source": [
"We pass along `use_fast=True` to the call above to use one of the fast tokenizers (backed by Rust) from the 🤗 Tokenizers library. Those fast tokenizers are available for almost all models, but if you get an error with the previous call, remove that argument."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "qo_0B1M2IrJM"
},
"source": [
"To preprocess our dataset, we need the names of the columns containing the sentence(s). The following dictionary maps each task to its column names:"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"id": "fyGdtK9oIrJM"
},
"outputs": [],
"source": [
"task_to_keys = {\n",
" \"cola\": (\"sentence\", None),\n",
" \"mnli\": (\"premise\", \"hypothesis\"),\n",
" \"mnli-mm\": (\"premise\", \"hypothesis\"),\n",
" \"mrpc\": (\"sentence1\", \"sentence2\"),\n",
" \"qnli\": (\"question\", \"sentence\"),\n",
" \"qqp\": (\"question1\", \"question2\"),\n",
" \"rte\": (\"sentence1\", \"sentence2\"),\n",
" \"sst2\": (\"sentence\", None),\n",
" \"stsb\": (\"sentence1\", \"sentence2\"),\n",
" \"wnli\": (\"sentence1\", \"sentence2\"),\n",
"}"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "2C0hcmp9IrJQ"
},
"source": [
"We can then write the function that will preprocess our samples. We just feed them to the `tokenizer` with the argument `truncation=True`. This ensures that an input longer than what the selected model can handle will be truncated to the maximum length accepted by the model."
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"id": "vc0BSBLIIrJQ"
},
"outputs": [],
"source": [
"def preprocess_function(examples, *, tokenizer):\n",
" sentence1_key, sentence2_key = task_to_keys[task]\n",
" if sentence2_key is None:\n",
" return tokenizer(examples[sentence1_key], truncation=True)\n",
" return tokenizer(examples[sentence1_key], examples[sentence2_key], truncation=True)"
]
},
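{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see which arguments reach the tokenizer for single-sentence and sentence-pair tasks, the sketch below swaps in a stand-in tokenizer. `stub_tokenizer` and `preprocess` are hypothetical names used only for this illustration; a real 🤗 tokenizer returns token IDs and attention masks instead. The sketch assumes batched input, where `examples` maps each column name to a list of values.\n",
"\n",
"```python\n",
"task_to_keys = {'cola': ('sentence', None), 'mrpc': ('sentence1', 'sentence2')}\n",
"\n",
"def stub_tokenizer(first, second=None, truncation=False):\n",
"    # Stand-in only: report how the call was made instead of tokenizing.\n",
"    return {'batch_size': len(first), 'is_pair': second is not None, 'truncation': truncation}\n",
"\n",
"def preprocess(examples, task, tokenizer):\n",
"    first_key, second_key = task_to_keys[task]\n",
"    if second_key is None:\n",
"        # Single-sentence task: one list of texts.\n",
"        return tokenizer(examples[first_key], truncation=True)\n",
"    # Sentence-pair task: two parallel lists of texts.\n",
"    return tokenizer(examples[first_key], examples[second_key], truncation=True)\n",
"\n",
"single = preprocess({'sentence': ['A cat sat.', 'Sat cat a.']}, 'cola', stub_tokenizer)\n",
"pair = preprocess({'sentence1': ['It rains.'], 'sentence2': ['Rain falls.']}, 'mrpc', stub_tokenizer)\n",
"print(single)\n",
"print(pair)\n",
"```"
]
},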
{
"cell_type": "markdown",
"metadata": {
"id": "zS-6iXTkIrJT"
},
"source": [
"To apply this function to all the sentences (or pairs of sentences) in our dataset, we use the `map` method of the `datasets` object we created earlier. This applies the function to all the elements of all the splits in `datasets`, so our training, validation and testing data will be preprocessed in a single command."
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 113,
"referenced_widgets": [
"28ff97b9821d495088a0711191c3e12e",
"86872b991c15442584118f1d80fd0002",
"c9c3257fe113444684ecf1ae6f75c29e",
"b2d325c5482c438c85a70c2da36cd87f",
"dc0925463f3840c4abea541a92bb1ea2",
"9115f7abc8bc436593ba2fa467e8d5a6",
"161cce93c7bb46e5a4241a7d18b89684",
"b99ba51d60ce410fb7eda1077f62d682",
"d4a98c5d1c754f5ab0f9fcd077cf679e",
"f36ca6add4eb42e59b2942e13b10ab57",
"b442801df1ca42f48918520235707926",
"ed9f698c7c4f46ff9c520ed0597b6bf6",
"a92090bf5a004510bed17c915ac7ce0f",
"8602f09bbbda43d8846b0eccc72b4e3b",
"ed6adf5ad4154b7c958b91eb99944cd4",
"1af9d6e90a7443ec89afa9d97e887ab9",
"9344f70ece404d25a55280914809b9a0",
"ff92d4134be847aeb6119eb9a9c78954",
"a773b4ebad9f4c9695407a472c767bb0",
"7257356322214ebc80101b3348bea854",
"c5b34a2569c847ea846f29ca955b540f",
"60d3537f850a4fb5ac7cd1f1e65c3a95",
"4f73054b701f4684b3a44793d10d4a0f",
"b5a5d7e5f9bc40289acfaa955fe8055a",
"c26bb829a6a649fc87f0fbf7c881011f",
"c3b85ffc3f044f80b2ed5460570e22bb",
"33f6d5b837b44ff7a8baccc6d592643a",
"72e5b4bb569348209d574dd2777a26e3",
"8a72aa5ea9e14d2e9b16f1ec04590a32",
"6f33096dd60741af910b74719c209ec6",
"0bf58c047d08490da78ede70471f9af8",
"8bf7008eaecb4317b47d62cfbb673299",
"41cbe808a34e473eba315488d1a59624"
]
},
"id": "DDtsaJeVIrJT",
"outputId": "29e116d3-9c07-47a2-9728-4e151747b6f6"
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "28ff97b9821d495088a0711191c3e12e",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/9 [00:00<?, ?ba/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "ed9f698c7c4f46ff9c520ed0597b6bf6",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/2 [00:00<?, ?ba/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "4f73054b701f4684b3a44793d10d4a0f",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/2 [00:00<?, ?ba/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"encoded_datasets = datasets.map(preprocess_function, batched=True, fn_kwargs=dict(tokenizer=tokenizer))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "256fOuzjhYbY"
},
"source": [
"For Ray AIR, instead of using 🤗 Dataset objects directly, we will convert them to [Ray Datasets](https://docs.ray.io/en/latest/data/dataset.html). Both are backed by Arrow tables, so the conversion is straightforward. We will use the built-in `ray.data.from_huggingface` function."
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"id": "OaTDkPPMhYbY"
},
"outputs": [],
"source": [
"import ray.data\n",
"\n",
"ray_datasets = ray.data.from_huggingface(encoded_datasets)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "545PP3o8IrJV"
},
"source": [
"### Fine-tuning the model with Ray AIR <a name=\"train\"></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "FBiW8UpKIrJW"
},
"source": [
"Now that our data is ready, we can download the pretrained model and fine-tune it.\n",
"\n",
"Since all our tasks are about sentence classification, we use the `AutoModelForSequenceClassification` class.\n",
"\n",
"We will not go into details about each specific component of the training (see the [original notebook](https://github.com/huggingface/notebooks/blob/6ca682955173cc9d36ffa431ddda505a048cbe80/examples/text_classification.ipynb) for that). The tokenizer is the same one we used to encode the dataset earlier.\n",
"\n",
"The main difference when using Ray AIR is that we need to create our 🤗 Transformers `Trainer` inside a function (`trainer_init_per_worker`) and return it. That function will be passed to the `HuggingFaceTrainer` and run on every Ray worker. The training will then proceed by means of PyTorch DDP.\n",
"\n",
"Make sure that you initialize the model, metric and tokenizer inside that function. Otherwise, you may run into serialization errors.\n",
"\n",
"Furthermore, `push_to_hub=True` is not yet supported. Ray will, however, checkpoint the model at every epoch, allowing you to push it to the Hub manually. We will do that after the training.\n",
"\n",
"If you wish to use third-party logging libraries, such as MLflow or Weights & Biases, do not set them in `TrainingArguments` (they will be automatically disabled). Instead, pass Ray AIR callbacks to `HuggingFaceTrainer`'s `run_config`. In this example, we will use MLflow."
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"id": "TlqNaB8jIrJW"
},
"outputs": [],
"source": [
"from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer\n",
"import numpy as np\n",
"import torch\n",
"\n",
"num_labels = 3 if task.startswith(\"mnli\") else 1 if task==\"stsb\" else 2\n",
"metric_name = \"pearson\" if task == \"stsb\" else \"matthews_correlation\" if task == \"cola\" else \"accuracy\"\n",
"model_name = model_checkpoint.split(\"/\")[-1]\n",
"validation_key = \"validation_mismatched\" if task == \"mnli-mm\" else \"validation_matched\" if task == \"mnli\" else \"validation\"\n",
"name = f\"{model_name}-finetuned-{task}\"\n",
"\n",
"def trainer_init_per_worker(train_dataset, eval_dataset = None, **config):\n",
" print(f\"Is CUDA available: {torch.cuda.is_available()}\")\n",
" metric = load_metric_fn()\n",
" tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)\n",
" model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, num_labels=num_labels)\n",
" args = TrainingArguments(\n",
" name,\n",
" evaluation_strategy=\"epoch\",\n",
" save_strategy=\"epoch\",\n",
" learning_rate=2e-5,\n",
" per_device_train_batch_size=batch_size,\n",
" per_device_eval_batch_size=batch_size,\n",
" num_train_epochs=5,\n",
" weight_decay=0.01,\n",
" push_to_hub=False,\n",
" disable_tqdm=True, # declutter the output a little\n",
" no_cuda=not use_gpu, # you need to explicitly set no_cuda if you want CPUs\n",
" )\n",
"\n",
" def compute_metrics(eval_pred):\n",
" predictions, labels = eval_pred\n",
" if task != \"stsb\":\n",
" predictions = np.argmax(predictions, axis=1)\n",
" else:\n",
" predictions = predictions[:, 0]\n",
" return metric.compute(predictions=predictions, references=labels)\n",
"\n",
" trainer = Trainer(\n",
" model,\n",
" args,\n",
" train_dataset=train_dataset,\n",
" eval_dataset=eval_dataset,\n",
" tokenizer=tokenizer,\n",
" compute_metrics=compute_metrics\n",
" )\n",
"\n",
" print(\"Starting training\")\n",
" return trainer"
]
},
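{
"cell_type": "markdown",
"metadata": {},
"source": [
"The batch size and epoch count above determine how many optimizer steps the run takes. Here is a back-of-the-envelope sketch, assuming a single worker, no gradient accumulation, and that the final partial batch of each epoch is kept. `expected_steps` is a hypothetical helper, and 8551 is the size of the CoLA training split reported when the dataset was loaded.\n",
"\n",
"```python\n",
"import math\n",
"\n",
"def expected_steps(num_examples, batch_size, num_epochs):\n",
"    # One optimizer step per batch; the last batch of an epoch may be smaller.\n",
"    steps_per_epoch = math.ceil(num_examples / batch_size)\n",
"    return steps_per_epoch * num_epochs\n",
"\n",
"print(expected_steps(8551, batch_size=16, num_epochs=5))  # 2675\n",
"```\n",
"\n",
"With several workers the per-worker share of the data shrinks, so the count would differ."
]
},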
{
"cell_type": "markdown",
"metadata": {
"id": "CdzABDVcIrJg"
},
"source": [
"With our `trainer_init_per_worker` complete, we can now instantiate the `HuggingFaceTrainer`. Aside from the function, we set the `scaling_config`, which controls the number of workers and the resources used, and the `datasets` we will use for training and evaluation.\n",
"\n",
"We specify the `MLflowLoggerCallback` inside the `run_config`."
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"id": "RElw7OgLhYba"
},
"outputs": [],
"source": [
"from ray.train.huggingface import HuggingFaceTrainer\n",
"from ray.air import RunConfig\n",
"from ray.air.callbacks.mlflow import MLflowLoggerCallback\n",
"\n",
"trainer = HuggingFaceTrainer(\n",
" trainer_init_per_worker=trainer_init_per_worker,\n",
" scaling_config={\"num_workers\": num_workers, \"use_gpu\": use_gpu},\n",
" datasets={\"train\": ray_datasets[\"train\"], \"evaluation\": ray_datasets[validation_key]},\n",
" run_config=RunConfig(callbacks=[MLflowLoggerCallback(experiment_name=name)])\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "XvS136zKhYba"
},
"source": [
"Finally, we call the `fit` method to begin training with Ray AIR. We will save the `Result` object to a variable so we can access metrics and checkpoints."
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
},
"id": "uNx5pyRlIrJh",
"outputId": "8496fe4f-f1c3-48ad-a6d3-b16a65716135"
},
"outputs": [
{
"data": {
"text/html": [
"== Status ==<br>Current time: 2022-05-12 18:35:14 (running for 00:03:48.08)<br>Memory usage on this node: 5.7/12.7 GiB<br>Using FIFO scheduling algorithm.<br>Resources requested: 0/2 CPUs, 0/1 GPUs, 0.0/7.32 GiB heap, 0.0/3.66 GiB objects (0.0/1.0 accelerator_type:T4)<br>Result logdir: /root/ray_results/HuggingFaceTrainer_2022-05-12_18-31-26<br>Number of trials: 1/1 (1 TERMINATED)<br><table>\n",
"<thead>\n",
"<tr><th>Trial name </th><th>status </th><th>loc </th><th style=\"text-align: right;\"> iter</th><th style=\"text-align: right;\"> total time (s)</th><th style=\"text-align: right;\"> loss</th><th style=\"text-align: right;\"> learning_rate</th><th style=\"text-align: right;\"> epoch</th></tr>\n",
"</thead>\n",
"<tbody>\n",
"<tr><td>HuggingFaceTrainer_bb9dd_00000</td><td>TERMINATED</td><td>172.28.0.2:419</td><td style=\"text-align: right;\"> 5</td><td style=\"text-align: right;\"> 222.391</td><td style=\"text-align: right;\">0.1575</td><td style=\"text-align: right;\"> 1.30841e-06</td><td style=\"text-align: right;\"> 5</td></tr>\n",
"</tbody>\n",
"</table><br><br>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=455)\u001b[0m 2022-05-12 18:31:33,158\tINFO torch.py:347 -- Setting up process group for: env:// [rank=0, world_size=1]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=455)\u001b[0m Is CUDA available: True\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Downloading builder script: 5.76kB [00:00, 6.35MB/s] \n",
"Downloading: 0%| | 0.00/256M [00:00<?, ?B/s]\n",
"Downloading: 2%|▏ | 5.63M/256M [00:00<00:04, 59.1MB/s]\n",
"Downloading: 5%|▍ | 12.2M/256M [00:00<00:03, 65.0MB/s]\n",
"Downloading: 7%|▋ | 18.5M/256M [00:00<00:03, 65.6MB/s]\n",
"Downloading: 10%|▉ | 25.3M/256M [00:00<00:03, 67.5MB/s]\n",
"Downloading: 12%|█▏ | 31.7M/256M [00:00<00:03, 66.6MB/s]\n",
"Downloading: 15%|█▌ | 38.3M/256M [00:00<00:03, 67.6MB/s]\n",
"Downloading: 18%|█▊ | 44.8M/256M [00:00<00:03, 67.6MB/s]\n",
"Downloading: 20%|██ | 51.2M/256M [00:00<00:03, 66.6MB/s]\n",
"Downloading: 23%|██▎ | 57.9M/256M [00:00<00:03, 67.5MB/s]\n",
"Downloading: 25%|██▌ | 64.7M/256M [00:01<00:02, 68.6MB/s]\n",
"Downloading: 28%|██▊ | 71.2M/256M [00:01<00:02, 66.6MB/s]\n",
"Downloading: 31%|███ | 78.0M/256M [00:01<00:02, 67.9MB/s]\n",
"Downloading: 33%|███▎ | 84.5M/256M [00:01<00:02, 68.0MB/s]\n",
"Downloading: 36%|███▌ | 91.1M/256M [00:01<00:02, 68.2MB/s]\n",
"Downloading: 38%|███▊ | 97.7M/256M [00:01<00:02, 68.5MB/s]\n",
"Downloading: 41%|████ | 104M/256M [00:01<00:02, 62.8MB/s] \n",
"Downloading: 43%|████▎ | 110M/256M [00:01<00:02, 58.5MB/s]\n",
"Downloading: 46%|████▌ | 117M/256M [00:01<00:02, 60.5MB/s]\n",
"Downloading: 48%|████▊ | 123M/256M [00:01<00:02, 61.7MB/s]\n",
"Downloading: 50%|█████ | 129M/256M [00:02<00:02, 63.0MB/s]\n",
"Downloading: 53%|█████▎ | 135M/256M [00:02<00:01, 64.0MB/s]\n",
"Downloading: 55%|█████▌ | 142M/256M [00:02<00:01, 62.2MB/s]\n",
"Downloading: 58%|█████▊ | 148M/256M [00:02<00:01, 61.0MB/s]\n",
"Downloading: 60%|██████ | 154M/256M [00:02<00:01, 62.2MB/s]\n",
"Downloading: 62%|██████▏ | 160M/256M [00:02<00:01, 62.1MB/s]\n",
"Downloading: 65%|██████▌ | 166M/256M [00:02<00:01, 64.1MB/s]\n",
"Downloading: 67%|██████▋ | 172M/256M [00:02<00:01, 64.4MB/s]\n",
"Downloading: 73%|███████▎ | 186M/256M [00:02<00:01, 67.3MB/s]\n",
"Downloading: 75%|███████▌ | 192M/256M [00:03<00:00, 68.0MB/s]\n",
"Downloading: 78%|███████▊ | 199M/256M [00:03<00:00, 70.0MB/s]\n",
"Downloading: 81%|████████ | 206M/256M [00:03<00:00, 69.6MB/s]\n",
"Downloading: 83%|████████▎ | 213M/256M [00:03<00:00, 70.1MB/s]\n",
"Downloading: 86%|████████▌ | 220M/256M [00:03<00:00, 69.1MB/s]\n",
"Downloading: 89%|████████▊ | 226M/256M [00:03<00:00, 68.4MB/s]\n",
"Downloading: 91%|█████████ | 233M/256M [00:03<00:00, 62.3MB/s]\n",
"Downloading: 93%|█████████▎| 239M/256M [00:03<00:00, 60.2MB/s]\n",
"Downloading: 96%|█████████▌| 245M/256M [00:03<00:00, 61.8MB/s]\n",
"Downloading: 100%|██████████| 256M/256M [00:04<00:00, 65.0MB/s]\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=455)\u001b[0m Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_transform.weight', 'vocab_projector.bias', 'vocab_projector.weight', 'vocab_transform.bias', 'vocab_layer_norm.weight', 'vocab_layer_norm.bias']\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=455)\u001b[0m - This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=455)\u001b[0m - This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=455)\u001b[0m Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'pre_classifier.bias', 'pre_classifier.weight', 'classifier.weight']\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=455)\u001b[0m You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=455)\u001b[0m /usr/local/lib/python3.7/dist-packages/transformers/optimization.py:309: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=455)\u001b[0m FutureWarning,\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=455)\u001b[0m Starting training\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=455)\u001b[0m ***** Running training *****\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=455)\u001b[0m Num examples = 8551\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=455)\u001b[0m Num Epochs = 5\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=455)\u001b[0m Instantaneous batch size per device = 16\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=455)\u001b[0m Total train batch size (w. parallel, distributed & accumulation) = 16\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=455)\u001b[0m Gradient Accumulation steps = 1\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=455)\u001b[0m Total optimization steps = 2675\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=455)\u001b[0m The following columns in the training set don't have a corresponding argument in `DistilBertForSequenceClassification.forward` and have been ignored: idx, sentence. If idx, sentence are not expected by `DistilBertForSequenceClassification.forward`, you can safely ignore this message.\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=455)\u001b[0m [W reducer.cpp:1289] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator())\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=455)\u001b[0m {'loss': 0.5441, 'learning_rate': 1.6261682242990654e-05, 'epoch': 0.93}\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=455)\u001b[0m ***** Running Evaluation *****\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=455)\u001b[0m Num examples = 1043\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=455)\u001b[0m Batch size = 16\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=455)\u001b[0m The following columns in the evaluation set don't have a corresponding argument in `DistilBertForSequenceClassification.forward` and have been ignored: idx, sentence. If idx, sentence are not expected by `DistilBertForSequenceClassification.forward`, you can safely ignore this message.\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=455)\u001b[0m Saving model checkpoint to distilbert-base-uncased-finetuned-cola/checkpoint-535\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=455)\u001b[0m Configuration saved in distilbert-base-uncased-finetuned-cola/checkpoint-535/config.json\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=455)\u001b[0m {'eval_loss': 0.4999416470527649, 'eval_matthews_correlation': 0.3991733676966143, 'eval_runtime': 1.0378, 'eval_samples_per_second': 1004.976, 'eval_steps_per_second': 63.594, 'epoch': 1.0}\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=455)\u001b[0m Model weights saved in distilbert-base-uncased-finetuned-cola/checkpoint-535/pytorch_model.bin\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=455)\u001b[0m tokenizer config file saved in distilbert-base-uncased-finetuned-cola/checkpoint-535/tokenizer_config.json\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=455)\u001b[0m Special tokens file saved in distilbert-base-uncased-finetuned-cola/checkpoint-535/special_tokens_map.json\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Trial HuggingFaceTrainer_bb9dd_00000 reported loss=0.5441,learning_rate=1.6261682242990654e-05,epoch=1.0,step=535,eval_loss=0.4999416470527649,eval_matthews_correlation=0.3991733676966143,eval_runtime=1.0378,eval_samples_per_second=1004.976,eval_steps_per_second=63.594,_timestamp=1652380362,_time_this_iter_s=66.77899646759033,_training_iteration=1,should_checkpoint=True with parameters={}.\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=455)\u001b[0m {'loss': 0.3886, 'learning_rate': 1.2523364485981309e-05, 'epoch': 1.87}\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=455)\u001b[0m ***** Running Evaluation *****\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=455)\u001b[0m Num examples = 1043\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=455)\u001b[0m Batch size = 16\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=455)\u001b[0m The following columns in the evaluation set don't have a corresponding argument in `DistilBertForSequenceClassification.forward` and have been ignored: idx, sentence. If idx, sentence are not expected by `DistilBertForSequenceClassification.forward`, you can safely ignore this message.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=455)\u001b[0m {'eval_loss': 0.5397436618804932, 'eval_matthews_correlation': 0.5085739436587455, 'eval_runtime': 1.0792, 'eval_samples_per_second': 966.488, 'eval_steps_per_second': 61.158, 'epoch': 2.0}\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=455)\u001b[0m Saving model checkpoint to distilbert-base-uncased-finetuned-cola/checkpoint-1070\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=455)\u001b[0m Configuration saved in distilbert-base-uncased-finetuned-cola/checkpoint-1070/config.json\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=455)\u001b[0m Model weights saved in distilbert-base-uncased-finetuned-cola/checkpoint-1070/pytorch_model.bin\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=455)\u001b[0m tokenizer config file saved in distilbert-base-uncased-finetuned-cola/checkpoint-1070/tokenizer_config.json\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=455)\u001b[0m Special tokens file saved in distilbert-base-uncased-finetuned-cola/checkpoint-1070/special_tokens_map.json\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Trial HuggingFaceTrainer_bb9dd_00000 reported loss=0.3886,learning_rate=1.2523364485981309e-05,epoch=2.0,step=1070,eval_loss=0.5397436618804932,eval_matthews_correlation=0.5085739436587455,eval_runtime=1.0792,eval_samples_per_second=966.488,eval_steps_per_second=61.158,_timestamp=1652380400,_time_this_iter_s=37.84357762336731,_training_iteration=2,should_checkpoint=True with parameters={}.\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=455)\u001b[0m {'loss': 0.2746, 'learning_rate': 8.785046728971963e-06, 'epoch': 2.8}\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=455)\u001b[0m ***** Running Evaluation *****\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=455)\u001b[0m Num examples = 1043\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=455)\u001b[0m Batch size = 16\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=455)\u001b[0m The following columns in the evaluation set don't have a corresponding argument in `DistilBertForSequenceClassification.forward` and have been ignored: idx, sentence. If idx, sentence are not expected by `DistilBertForSequenceClassification.forward`, you can safely ignore this message.\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=455)\u001b[0m Saving model checkpoint to distilbert-base-uncased-finetuned-cola/checkpoint-1605\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=455)\u001b[0m Configuration saved in distilbert-base-uncased-finetuned-cola/checkpoint-1605/config.json\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=455)\u001b[0m {'eval_loss': 0.6648283004760742, 'eval_matthews_correlation': 0.5141951979542654, 'eval_runtime': 1.1148, 'eval_samples_per_second': 935.563, 'eval_steps_per_second': 59.202, 'epoch': 3.0}\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=455)\u001b[0m Model weights saved in distilbert-base-uncased-finetuned-cola/checkpoint-1605/pytorch_model.bin\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=455)\u001b[0m tokenizer config file saved in distilbert-base-uncased-finetuned-cola/checkpoint-1605/tokenizer_config.json\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=455)\u001b[0m Special tokens file saved in distilbert-base-uncased-finetuned-cola/checkpoint-1605/special_tokens_map.json\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Trial HuggingFaceTrainer_bb9dd_00000 reported loss=0.2746,learning_rate=8.785046728971963e-06,epoch=3.0,step=1605,eval_loss=0.6648283004760742,eval_matthews_correlation=0.5141951979542654,eval_runtime=1.1148,eval_samples_per_second=935.563,eval_steps_per_second=59.202,_timestamp=1652380437,_time_this_iter_s=36.976723432540894,_training_iteration=3,should_checkpoint=True with parameters={}.\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=455)\u001b[0m {'loss': 0.196, 'learning_rate': 5.046728971962617e-06, 'epoch': 3.74}\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=455)\u001b[0m ***** Running Evaluation *****\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=455)\u001b[0m Num examples = 1043\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=455)\u001b[0m Batch size = 16\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=455)\u001b[0m The following columns in the evaluation set don't have a corresponding argument in `DistilBertForSequenceClassification.forward` and have been ignored: idx, sentence. If idx, sentence are not expected by `DistilBertForSequenceClassification.forward`, you can safely ignore this message.\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=455)\u001b[0m Saving model checkpoint to distilbert-base-uncased-finetuned-cola/checkpoint-2140\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=455)\u001b[0m Configuration saved in distilbert-base-uncased-finetuned-cola/checkpoint-2140/config.json\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=455)\u001b[0m {'eval_loss': 0.7566447854042053, 'eval_matthews_correlation': 0.5518326707011334, 'eval_runtime': 1.1113, 'eval_samples_per_second': 938.535, 'eval_steps_per_second': 59.39, 'epoch': 4.0}\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=455)\u001b[0m Model weights saved in distilbert-base-uncased-finetuned-cola/checkpoint-2140/pytorch_model.bin\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=455)\u001b[0m tokenizer config file saved in distilbert-base-uncased-finetuned-cola/checkpoint-2140/tokenizer_config.json\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=455)\u001b[0m Special tokens file saved in distilbert-base-uncased-finetuned-cola/checkpoint-2140/special_tokens_map.json\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Trial HuggingFaceTrainer_bb9dd_00000 reported loss=0.196,learning_rate=5.046728971962617e-06,epoch=4.0,step=2140,eval_loss=0.7566447854042053,eval_matthews_correlation=0.5518326707011334,eval_runtime=1.1113,eval_samples_per_second=938.535,eval_steps_per_second=59.39,_timestamp=1652380474,_time_this_iter_s=36.68935775756836,_training_iteration=4,should_checkpoint=True with parameters={}.\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=455)\u001b[0m {'loss': 0.1575, 'learning_rate': 1.308411214953271e-06, 'epoch': 4.67}\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=455)\u001b[0m Saving model checkpoint to distilbert-base-uncased-finetuned-cola/checkpoint-2675\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=455)\u001b[0m Configuration saved in distilbert-base-uncased-finetuned-cola/checkpoint-2675/config.json\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=455)\u001b[0m Model weights saved in distilbert-base-uncased-finetuned-cola/checkpoint-2675/pytorch_model.bin\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=455)\u001b[0m tokenizer config file saved in distilbert-base-uncased-finetuned-cola/checkpoint-2675/tokenizer_config.json\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=455)\u001b[0m Special tokens file saved in distilbert-base-uncased-finetuned-cola/checkpoint-2675/special_tokens_map.json\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=455)\u001b[0m ***** Running Evaluation *****\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=455)\u001b[0m Num examples = 1043\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=455)\u001b[0m Batch size = 16\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=455)\u001b[0m The following columns in the evaluation set don't have a corresponding argument in `DistilBertForSequenceClassification.forward` and have been ignored: idx, sentence. If idx, sentence are not expected by `DistilBertForSequenceClassification.forward`, you can safely ignore this message.\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=455)\u001b[0m Saving model checkpoint to distilbert-base-uncased-finetuned-cola/checkpoint-2675\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=455)\u001b[0m Configuration saved in distilbert-base-uncased-finetuned-cola/checkpoint-2675/config.json\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=455)\u001b[0m {'eval_loss': 0.8616615533828735, 'eval_matthews_correlation': 0.5420036503219092, 'eval_runtime': 1.2577, 'eval_samples_per_second': 829.302, 'eval_steps_per_second': 52.477, 'epoch': 5.0}\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=455)\u001b[0m Model weights saved in distilbert-base-uncased-finetuned-cola/checkpoint-2675/pytorch_model.bin\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=455)\u001b[0m tokenizer config file saved in distilbert-base-uncased-finetuned-cola/checkpoint-2675/tokenizer_config.json\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=455)\u001b[0m Special tokens file saved in distilbert-base-uncased-finetuned-cola/checkpoint-2675/special_tokens_map.json\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=455)\u001b[0m \n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=455)\u001b[0m \n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=455)\u001b[0m Training completed. Do not forget to share your model on huggingface.co/models =)\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=455)\u001b[0m \n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=455)\u001b[0m \n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=455)\u001b[0m {'train_runtime': 187.8585, 'train_samples_per_second': 227.592, 'train_steps_per_second': 14.239, 'train_loss': 0.30010223103460865, 'epoch': 5.0}\n",
"Trial HuggingFaceTrainer_bb9dd_00000 reported loss=0.1575,learning_rate=1.308411214953271e-06,epoch=5.0,step=2675,eval_loss=0.8616615533828735,eval_matthews_correlation=0.5420036503219092,eval_runtime=1.2577,eval_samples_per_second=829.302,eval_steps_per_second=52.477,train_runtime=187.8585,train_samples_per_second=227.592,train_steps_per_second=14.239,train_loss=0.30010223103460865,_timestamp=1652380513,_time_this_iter_s=39.63672137260437,_training_iteration=5,should_checkpoint=True with parameters={}.\n",
"Trial HuggingFaceTrainer_bb9dd_00000 completed. Last result: loss=0.1575,learning_rate=1.308411214953271e-06,epoch=5.0,step=2675,eval_loss=0.8616615533828735,eval_matthews_correlation=0.5420036503219092,eval_runtime=1.2577,eval_samples_per_second=829.302,eval_steps_per_second=52.477,train_runtime=187.8585,train_samples_per_second=227.592,train_steps_per_second=14.239,train_loss=0.30010223103460865,_timestamp=1652380513,_time_this_iter_s=39.63672137260437,_training_iteration=5,should_checkpoint=True\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2022-05-12 18:35:14,803\tINFO tune.py:753 -- Total run time: 228.34 seconds (228.07 seconds for the tuning loop).\n"
]
}
],
"source": [
"result = trainer.fit()"
]
},
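{
"cell_type": "markdown",
"metadata": {},
"source": [
"The step counts in the log above follow from the dataset size and the batch size. As a minimal sketch in plain Python (no Ray required), and assuming a single worker with no gradient accumulation as reported in the log, the snippet below recomputes the steps per epoch, the total number of optimization steps, and the steps at which the per-epoch checkpoints are saved. The variable names are illustrative only.\n",
"\n",
"```python\n",
"import math\n",
"\n",
"# Values read from the training log above (single worker assumed).\n",
"num_examples = 8551\n",
"batch_size = 16\n",
"num_epochs = 5\n",
"\n",
"# The last, partially filled batch still counts as one optimization step.\n",
"steps_per_epoch = math.ceil(num_examples / batch_size)\n",
"total_steps = steps_per_epoch * num_epochs\n",
"\n",
"# One checkpoint is written at the end of every epoch.\n",
"checkpoint_steps = [steps_per_epoch * epoch for epoch in range(1, num_epochs + 1)]\n",
"\n",
"print(steps_per_epoch, total_steps, checkpoint_steps)\n",
"```\n",
"\n",
"With these values the result is 535 steps per epoch and 2675 steps in total, matching `Total optimization steps = 2675` and the `checkpoint-535` through `checkpoint-2675` directories in the log."
]
},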
{
"cell_type": "markdown",
"metadata": {
"id": "4cnWqUWmhYba"
},
"source": [
"You can use the returned `Result` object to access metrics and the Ray AIR `Checkpoint` associated with the last iteration."
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "AMN5qjUwhYba",
"outputId": "7b754c36-c58b-4ff4-d7a8-63ec9764bd0c"
},
"outputs": [
{
"data": {
"text/plain": [
"Result(metrics={'loss': 0.1575, 'learning_rate': 1.308411214953271e-06, 'epoch': 5.0, 'step': 2675, 'eval_loss': 0.8616615533828735, 'eval_matthews_correlation': 0.5420036503219092, 'eval_runtime': 1.2577, 'eval_samples_per_second': 829.302, 'eval_steps_per_second': 52.477, 'train_runtime': 187.8585, 'train_samples_per_second': 227.592, 'train_steps_per_second': 14.239, 'train_loss': 0.30010223103460865, '_timestamp': 1652380513, '_time_this_iter_s': 39.63672137260437, '_training_iteration': 5, 'time_this_iter_s': 39.64510202407837, 'should_checkpoint': True, 'done': True, 'timesteps_total': None, 'episodes_total': None, 'training_iteration': 5, 'trial_id': 'bb9dd_00000', 'experiment_id': 'db0c5ea784a44980819bf5e1bfb72c04', 'date': '2022-05-12_18-35-13', 'timestamp': 1652380513, 'time_total_s': 222.39091277122498, 'pid': 419, 'hostname': 'e618da00601e', 'node_ip': '172.28.0.2', 'config': {}, 'time_since_restore': 222.39091277122498, 'timesteps_since_restore': 0, 'iterations_since_restore': 5, 'warmup_time': 0.004034996032714844, 'experiment_tag': '0'}, checkpoint=<ray.air.checkpoint.Checkpoint object at 0x7f9ffd9d9c90>, error=None)"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"result"
]
},
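{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `Result` above holds the metrics of the last iteration only. As a minimal sketch in plain Python (independent of Ray), the snippet below copies the per-epoch `eval_loss` and `eval_matthews_correlation` values, rounded to four decimals, from the training log into a hypothetical `epoch_metrics` list and looks up the best epoch for each metric. `epoch_metrics` and `best_epoch` are illustrative names, not part of the Ray AIR API.\n",
"\n",
"```python\n",
"# Per-epoch evaluation metrics copied (rounded) from the training log above.\n",
"epoch_metrics = [\n",
"    {'epoch': 1, 'eval_loss': 0.4999, 'eval_matthews_correlation': 0.3992},\n",
"    {'epoch': 2, 'eval_loss': 0.5397, 'eval_matthews_correlation': 0.5086},\n",
"    {'epoch': 3, 'eval_loss': 0.6648, 'eval_matthews_correlation': 0.5142},\n",
"    {'epoch': 4, 'eval_loss': 0.7566, 'eval_matthews_correlation': 0.5518},\n",
"    {'epoch': 5, 'eval_loss': 0.8617, 'eval_matthews_correlation': 0.5420},\n",
"]\n",
"\n",
"\n",
"def best_epoch(metrics, key, mode='max'):\n",
"    # Return the entry with the highest (mode='max') or lowest (mode='min') value of `key`.\n",
"    pick = max if mode == 'max' else min\n",
"    return pick(metrics, key=lambda entry: entry[key])\n",
"\n",
"\n",
"best_by_mcc = best_epoch(epoch_metrics, 'eval_matthews_correlation')\n",
"best_by_loss = best_epoch(epoch_metrics, 'eval_loss', mode='min')\n",
"print(best_by_mcc['epoch'], best_by_loss['epoch'])\n",
"```\n",
"\n",
"In this run the Matthews correlation peaks at epoch 4, while the evaluation loss is lowest at epoch 1 and rises afterwards, so the last iteration is not the best one by either measure."
]
},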
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Predict on test data with Ray AIR <a name=\"predict\"></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Tfoyu1q7hYbb"
},
"source": [
"You can now use the checkpoint to run prediction with `HuggingFacePredictor`, which wraps [🤗 Pipelines](https://huggingface.co/docs/transformers/main_classes/pipelines). To distribute prediction, we use `BatchPredictor`. This is not necessary for the very small example here (you could use `HuggingFacePredictor` directly), but it scales well to large datasets."
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 262
},
"id": "UOUcBkX8IrJi",
"outputId": "4dc16812-1400-482d-8c3f-85991ce4b081"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Map Progress (2 actors 1 pending): 0%| | 0/1 [00:12<?, ?it/s]\u001b[2m\u001b[36m(BlockWorker pid=735)\u001b[0m 2022-05-12 18:36:08.491769: E tensorflow/stream_executor/cuda/cuda_driver.cc:271] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected\n",
"Map Progress (2 actors 1 pending): 100%|██████████| 1/1 [00:16<00:00, 16.63s/it]\n"
]
},
{
"data": {
"text/html": [
"\n",
" <div id=\"df-6bcebc1c-5de9-4e2b-802f-7d04902ab976\">\n",
" <div class=\"colab-df-container\">\n",
" <div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>label</th>\n",
" <th>score</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>LABEL_1</td>\n",
" <td>0.998539</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>LABEL_1</td>\n",
" <td>0.997706</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>LABEL_1</td>\n",
" <td>0.998476</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>LABEL_1</td>\n",
" <td>0.998498</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>LABEL_0</td>\n",
" <td>0.533578</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>\n",
" <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-6bcebc1c-5de9-4e2b-802f-7d04902ab976')\"\n",
" title=\"Convert this dataframe to an interactive table.\"\n",
" style=\"display:none;\">\n",
" \n",
" <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
" width=\"24px\">\n",
" <path d=\"M0 0h24v24H0V0z\" fill=\"none\"/>\n",
" <path d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/><path d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/>\n",
" </svg>\n",
" </button>\n",
" \n",
" <style>\n",
" .colab-df-container {\n",
" display:flex;\n",
" flex-wrap:wrap;\n",
" gap: 12px;\n",
" }\n",
"\n",
" .colab-df-convert {\n",
" background-color: #E8F0FE;\n",
" border: none;\n",
" border-radius: 50%;\n",
" cursor: pointer;\n",
" display: none;\n",
" fill: #1967D2;\n",
" height: 32px;\n",
" padding: 0 0 0 0;\n",
" width: 32px;\n",
" }\n",
"\n",
" .colab-df-convert:hover {\n",
" background-color: #E2EBFA;\n",
" box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
" fill: #174EA6;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-convert {\n",
" background-color: #3B4455;\n",
" fill: #D2E3FC;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-convert:hover {\n",
" background-color: #434B5C;\n",
" box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
" filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
" fill: #FFFFFF;\n",
" }\n",
" </style>\n",
"\n",
" <script>\n",
" const buttonEl =\n",
" document.querySelector('#df-6bcebc1c-5de9-4e2b-802f-7d04902ab976 button.colab-df-convert');\n",
" buttonEl.style.display =\n",
" google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
"\n",
" async function convertToInteractive(key) {\n",
" const element = document.querySelector('#df-6bcebc1c-5de9-4e2b-802f-7d04902ab976');\n",
" const dataTable =\n",
" await google.colab.kernel.invokeFunction('convertToInteractive',\n",
" [key], {});\n",
" if (!dataTable) return;\n",
"\n",
" const docLinkHtml = 'Like what you see? Visit the ' +\n",
" '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
" + ' to learn more about interactive tables.';\n",
" element.innerHTML = '';\n",
" dataTable['output_type'] = 'display_data';\n",
" await google.colab.output.renderOutput(dataTable, element);\n",
" const docLink = document.createElement('div');\n",
" docLink.innerHTML = docLinkHtml;\n",
" element.appendChild(docLink);\n",
" }\n",
" </script>\n",
" </div>\n",
" </div>\n",
" "
],
"text/plain": [
" label score\n",
"0 LABEL_1 0.998539\n",
"1 LABEL_1 0.997706\n",
"2 LABEL_1 0.998476\n",
"3 LABEL_1 0.998498\n",
"4 LABEL_0 0.533578"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from ray.train.huggingface import HuggingFacePredictor\n",
"from ray.train.batch_predictor import BatchPredictor\n",
"import pandas as pd\n",
"\n",
"# A few example sentences to classify.\n",
"sentences = ['Bill whistled past the house.',\n",
"             'The car honked its way down the road.',\n",
"             'Bill pushed Harry off the sofa.',\n",
"             'the kittens yawned awake and played.',\n",
"             'I demand that the more John eats, the more he pay.']\n",
"\n",
"# Build a distributed predictor from the checkpoint of the last iteration.\n",
"predictor = BatchPredictor.from_checkpoint(\n",
"    checkpoint=result.checkpoint,\n",
"    predictor_cls=HuggingFacePredictor,\n",
"    task=\"text-classification\",\n",
")\n",
"\n",
"# Wrap the sentences in a Ray Dataset and run batch prediction.\n",
"data = ray.data.from_pandas(pd.DataFrame(sentences, columns=[\"sentence\"]))\n",
"prediction = predictor.predict(data)\n",
"prediction.show()"
]
},
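{
"cell_type": "markdown",
"metadata": {},
"source": [
"The predictions use the generic names `LABEL_0` and `LABEL_1`. As a small sketch, assuming the GLUE CoLA convention that label 0 means *unacceptable* and label 1 means *acceptable* (the checkpoint name suggests the model was fine-tuned on CoLA), the snippet below maps the predictions shown in the table above to readable names. `COLA_LABELS` and `to_readable` are hypothetical helpers for illustration, not part of Ray AIR or 🤗 Transformers.\n",
"\n",
"```python\n",
"# Predictions copied from the output table above.\n",
"predictions = [\n",
"    {'label': 'LABEL_1', 'score': 0.998539},\n",
"    {'label': 'LABEL_1', 'score': 0.997706},\n",
"    {'label': 'LABEL_1', 'score': 0.998476},\n",
"    {'label': 'LABEL_1', 'score': 0.998498},\n",
"    {'label': 'LABEL_0', 'score': 0.533578},\n",
"]\n",
"\n",
"# Assumed mapping, following the GLUE CoLA convention.\n",
"COLA_LABELS = {'LABEL_0': 'unacceptable', 'LABEL_1': 'acceptable'}\n",
"\n",
"\n",
"def to_readable(prediction):\n",
"    # Replace the generic label with its readable name and keep the score.\n",
"    return {'label': COLA_LABELS[prediction['label']], 'score': prediction['score']}\n",
"\n",
"\n",
"readable_predictions = [to_readable(p) for p in predictions]\n",
"for index, prediction in enumerate(readable_predictions):\n",
"    print(index, prediction['label'], prediction['score'])\n",
"```\n",
"\n",
"Under this assumed mapping, the first four sentences are classified as acceptable and the last one as unacceptable, the latter with a score of only about 0.53."
]
},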
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Share the model <a name=\"share\"></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "mS8PId_NhYbb"
},
"source": [
"To be able to share your model with the community, there are a few more steps to follow.\n",
"\n",
"We ran the training on the Ray cluster, but we will share the model from the local environment, which makes it easier to authenticate.\n",
"\n",
"First, store your authentication token from the Hugging Face website (sign up [here](https://huggingface.co/join) if you haven't already!), then execute the following cell and enter your credentials when prompted:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "2LClXkN8hYbb",
"tags": ["remove-cell-ci"]
},
"outputs": [],
"source": [
"from huggingface_hub import notebook_login\n",
"\n",
"notebook_login()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "SybKUDryhYbb"
},
"source": [
"Then you need to install Git-LFS. Uncomment and run the following line:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "_wF6aT-0hYbb",
"tags": ["remove-cell-ci"]
},
"outputs": [],
"source": [
"# !apt install git-lfs"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5fr6E0e8hYbb"
},
"source": [
"Now, load the model and tokenizer locally, and recreate the 🤗 Transformers `Trainer`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "cjH2A8m6hYbc",
"tags": ["remove-cell-ci"]
},
"outputs": [],
"source": [
"from ray.train.huggingface import load_checkpoint\n",
"\n",
"hf_trainer = load_checkpoint(\n",
" checkpoint=result.checkpoint,\n",
" model=AutoModelForSequenceClassification,\n",
" tokenizer=AutoTokenizer\n",
")[0]"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "tgV2xKfFhYbc"
},
"source": [
"You can now upload the result of the training to the Hub by executing this instruction:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "XSkfJe3nhYbc",
"tags": ["remove-cell-ci"]
},
"outputs": [],
"source": [
"hf_trainer.push_to_hub()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "UL-Boc4dhYbc"
},
"source": [
"You can now share this model with all your friends, family, and favorite pets: they can all load it with the identifier `\"your-username/the-name-you-picked\"`. For instance:\n",
"\n",
"```python\n",
"from transformers import AutoModelForSequenceClassification\n",
"\n",
"model = AutoModelForSequenceClassification.from_pretrained(\"sgugger/my-awesome-model\")\n",
"```"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"collapsed_sections": [],
"name": "huggingface_text_classification.ipynb",
"provenance": []
},
"interpreter": {
"hash": "3c0d54d489a08ae47a06eae2fd00ff032d6cddb527c382959b7b2575f6a8167f"
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
}
},
"nbformat": 4,
"nbformat_minor": 0
}