{
"cells": [
{
"cell_type": "markdown",
"id": "3b05af3b",
"metadata": {},
"source": [
"(tune-huggingface-example)=\n",
"\n",
"# Using |:hugging_face:| Huggingface Transformers with Tune\n",
"\n",
"```{image} /images/hugging.png\n",
":align: center\n",
":alt: Huggingface Logo\n",
":height: 120px\n",
":target: https://huggingface.co\n",
"```\n",
"\n",
"```{contents}\n",
":backlinks: none\n",
":local: true\n",
"```\n",
"\n",
"## Example"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "19e3c389",
"metadata": {},
"outputs": [],
"source": [
"\"\"\"\n",
"This example is uses the official\n",
"huggingface transformers `hyperparameter_search` API.\n",
"\"\"\"\n",
"import os\n",
"\n",
"import ray\n",
"from ray import tune\n",
"from ray.tune import CLIReporter\n",
"from ray.tune.examples.pbt_transformers.utils import (\n",
" download_data,\n",
" build_compute_metrics_fn,\n",
")\n",
"from ray.tune.schedulers import PopulationBasedTraining\n",
"from transformers import (\n",
" glue_tasks_num_labels,\n",
" AutoConfig,\n",
" AutoModelForSequenceClassification,\n",
" AutoTokenizer,\n",
" Trainer,\n",
" GlueDataset,\n",
" GlueDataTrainingArguments,\n",
" TrainingArguments,\n",
")\n",
"\n",
"\n",
"def tune_transformer(num_samples=8, gpus_per_trial=0, smoke_test=False):\n",
" data_dir_name = \"./data\" if not smoke_test else \"./test_data\"\n",
" data_dir = os.path.abspath(os.path.join(os.getcwd(), data_dir_name))\n",
" if not os.path.exists(data_dir):\n",
" os.mkdir(data_dir, 0o755)\n",
"\n",
" # Change these as needed.\n",
" model_name = (\n",
" \"bert-base-uncased\" if not smoke_test else \"sshleifer/tiny-distilroberta-base\"\n",
" )\n",
" task_name = \"rte\"\n",
"\n",
" task_data_dir = os.path.join(data_dir, task_name.upper())\n",
"\n",
" num_labels = glue_tasks_num_labels[task_name]\n",
"\n",
" config = AutoConfig.from_pretrained(\n",
" model_name, num_labels=num_labels, finetuning_task=task_name\n",
" )\n",
"\n",
" # Download and cache tokenizer, model, and features\n",
" print(\"Downloading and caching Tokenizer\")\n",
" tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
"\n",
" # Triggers tokenizer download to cache\n",
" print(\"Downloading and caching pre-trained model\")\n",
" AutoModelForSequenceClassification.from_pretrained(\n",
" model_name,\n",
" config=config,\n",
" )\n",
"\n",
" def get_model():\n",
" return AutoModelForSequenceClassification.from_pretrained(\n",
" model_name,\n",
" config=config,\n",
" )\n",
"\n",
" # Download data.\n",
" download_data(task_name, data_dir)\n",
"\n",
" data_args = GlueDataTrainingArguments(task_name=task_name, data_dir=task_data_dir)\n",
"\n",
" train_dataset = GlueDataset(\n",
" data_args, tokenizer=tokenizer, mode=\"train\", cache_dir=task_data_dir\n",
" )\n",
" eval_dataset = GlueDataset(\n",
" data_args, tokenizer=tokenizer, mode=\"dev\", cache_dir=task_data_dir\n",
" )\n",
"\n",
" training_args = TrainingArguments(\n",
" output_dir=\".\",\n",
" learning_rate=1e-5, # config\n",
" do_train=True,\n",
" do_eval=True,\n",
" no_cuda=gpus_per_trial <= 0,\n",
" evaluation_strategy=\"epoch\",\n",
" save_strategy=\"epoch\",\n",
" load_best_model_at_end=True,\n",
" num_train_epochs=2, # config\n",
" max_steps=-1,\n",
" per_device_train_batch_size=16, # config\n",
" per_device_eval_batch_size=16, # config\n",
" warmup_steps=0,\n",
" weight_decay=0.1, # config\n",
" logging_dir=\"./logs\",\n",
" skip_memory_metrics=True,\n",
" report_to=\"none\",\n",
" )\n",
"\n",
" trainer = Trainer(\n",
" model_init=get_model,\n",
" args=training_args,\n",
" train_dataset=train_dataset,\n",
" eval_dataset=eval_dataset,\n",
" compute_metrics=build_compute_metrics_fn(task_name),\n",
" )\n",
"\n",
" tune_config = {\n",
" \"per_device_train_batch_size\": 32,\n",
" \"per_device_eval_batch_size\": 32,\n",
" \"num_train_epochs\": tune.choice([2, 3, 4, 5]),\n",
" \"max_steps\": 1 if smoke_test else -1, # Used for smoke test.\n",
" }\n",
"\n",
" scheduler = PopulationBasedTraining(\n",
" time_attr=\"training_iteration\",\n",
" metric=\"eval_acc\",\n",
" mode=\"max\",\n",
" perturbation_interval=1,\n",
" hyperparam_mutations={\n",
" \"weight_decay\": tune.uniform(0.0, 0.3),\n",
" \"learning_rate\": tune.uniform(1e-5, 5e-5),\n",
" \"per_device_train_batch_size\": [16, 32, 64],\n",
" },\n",
" )\n",
"\n",
" reporter = CLIReporter(\n",
" parameter_columns={\n",
" \"weight_decay\": \"w_decay\",\n",
" \"learning_rate\": \"lr\",\n",
" \"per_device_train_batch_size\": \"train_bs/gpu\",\n",
" \"num_train_epochs\": \"num_epochs\",\n",
" },\n",
" metric_columns=[\"eval_acc\", \"eval_loss\", \"epoch\", \"training_iteration\"],\n",
" )\n",
"\n",
" trainer.hyperparameter_search(\n",
" hp_space=lambda _: tune_config,\n",
" backend=\"ray\",\n",
" n_trials=num_samples,\n",
" resources_per_trial={\"cpu\": 1, \"gpu\": gpus_per_trial},\n",
" scheduler=scheduler,\n",
" keep_checkpoints_num=1,\n",
" checkpoint_score_attr=\"training_iteration\",\n",
" stop={\"training_iteration\": 1} if smoke_test else None,\n",
" progress_reporter=reporter,\n",
" local_dir=\"~/ray_results/\",\n",
" name=\"tune_transformer_pbt\",\n",
" log_to_file=True,\n",
" )\n",
"\n",
"\n",
"if __name__ == \"__main__\":\n",
" import argparse\n",
"\n",
" parser = argparse.ArgumentParser()\n",
" parser.add_argument(\n",
" \"--smoke-test\", default=True, action=\"store_true\", help=\"Finish quickly for testing\"\n",
" )\n",
" parser.add_argument(\n",
" \"--ray-address\",\n",
" type=str,\n",
" default=None,\n",
" help=\"Address to use for Ray. \"\n",
" 'Use \"auto\" for cluster. '\n",
" \"Defaults to None for local.\",\n",
" )\n",
" parser.add_argument(\n",
" \"--server-address\",\n",
" type=str,\n",
" default=None,\n",
" required=False,\n",
" help=\"The address of server to connect to if using \" \"Ray Client.\",\n",
" )\n",
"\n",
" args, _ = parser.parse_known_args()\n",
"\n",
" if args.smoke_test:\n",
" ray.init()\n",
" elif args.server_address:\n",
" ray.init(f\"ray://{args.server_address}\")\n",
" else:\n",
" ray.init(args.ray_address)\n",
"\n",
" if args.smoke_test:\n",
" tune_transformer(num_samples=1, gpus_per_trial=0, smoke_test=True)\n",
" else:\n",
" # You can change the number of GPUs here:\n",
" tune_transformer(num_samples=8, gpus_per_trial=1)"
]
},
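{
"cell_type": "markdown",
"id": "4d5e6f70",
"metadata": {},
"source": [
"## Utilities\n",
"\n",
"For reference, the next cell reproduces the helpers imported above from\n",
"`ray.tune.examples.pbt_transformers.utils`: `download_data` fetches and\n",
"extracts the GLUE RTE archive, and `build_compute_metrics_fn` builds the\n",
"metric function whose `eval_acc` value drives the PBT scheduler."
]
},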
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"\"\"\"Utilities to load and cache data.\"\"\"\n",
"\n",
"import os\n",
"from typing import Callable, Dict\n",
"import numpy as np\n",
"from transformers import EvalPrediction\n",
"from transformers import glue_compute_metrics, glue_output_modes\n",
"\n",
"\n",
"def build_compute_metrics_fn(task_name: str) -> Callable[[EvalPrediction], Dict]:\n",
" \"\"\"Function from transformers/examples/text-classification/run_glue.py\"\"\"\n",
" output_mode = glue_output_modes[task_name]\n",
"\n",
" def compute_metrics_fn(p: EvalPrediction):\n",
" if output_mode == \"classification\":\n",
" preds = np.argmax(p.predictions, axis=1)\n",
" elif output_mode == \"regression\":\n",
" preds = np.squeeze(p.predictions)\n",
" metrics = glue_compute_metrics(task_name, preds, p.label_ids)\n",
" return metrics\n",
"\n",
" return compute_metrics_fn\n",
"\n",
"\n",
"def download_data(task_name, data_dir=\"./data\"):\n",
" # Download RTE training data\n",
" print(\"Downloading dataset.\")\n",
" import urllib\n",
" import zipfile\n",
"\n",
" if task_name == \"rte\":\n",
" url = \"https://dl.fbaipublicfiles.com/glue/data/RTE.zip\"\n",
" else:\n",
" raise ValueError(\"Unknown task: {}\".format(task_name))\n",
" data_file = os.path.join(data_dir, \"{}.zip\".format(task_name))\n",
" if not os.path.exists(data_file):\n",
" urllib.request.urlretrieve(url, data_file)\n",
" with zipfile.ZipFile(data_file) as zip_ref:\n",
" zip_ref.extractall(data_dir)\n",
" print(\"Downloaded data for task {} to {}\".format(task_name, data_dir))\n",
" else:\n",
" print(\n",
" \"Data already exists. Using downloaded data for task {} from {}\".format(\n",
" task_name, data_dir\n",
" )\n",
" )"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"orphan": true
},
"nbformat": 4,
"nbformat_minor": 5
}