From 857e4dba2f4559c418cc142823f06992adf189a6 Mon Sep 17 00:00:00 2001 From: Richard Liaw Date: Fri, 17 Apr 2020 15:17:30 -0700 Subject: [PATCH] [sgd] HuggingFace GLUE Fine-tuning Example (#7792) * Init fp16 * fp16 and schedulers * scheduler linking and fp16 * to fp16 * loss scaling and documentation * more documentation * add tests, refactor config * moredocs * more docs * fix logo, add test mode, add fp16 flag * fix tests * fix scheduler * fix apex * improve safety * fix tests * fix tests * remove pin memory default * rm * fix * Update doc/examples/doc_code/raysgd_torch_signatures.py * fix * migrate changes from other PR * ok thanks * pass * signatures * lint' * Update python/ray/experimental/sgd/pytorch/utils.py * Apply suggestions from code review Co-Authored-By: Edward Oakes * should address most comments * comments * fix this ci * first_pass * add overrides * override * fixing up operators * format * sgd * constants * rm * revert * save * failures * fixes * trainer * run test * operator * code * op * ok done * operator * sgd test fixes * ok * trainer * format * Apply suggestions from code review Co-Authored-By: Edward Oakes * Update doc/source/raysgd/raysgd_pytorch.rst * docstring * dcgan * doc * commits * nit * testing * revert * Start renaming pytorch to torch * Rename PyTorchTrainer to TorchTrainer * Rename PyTorch runners to Torch runners * Finish renaming API * Rename to torch in tests * Finish renaming docs + tests * Run format + fix DeprecationWarning * fix * move tests up * benchmarks * rename * remove some args * better metrics output * fix up the benchmark * benchmark-yaml * horovod-benchmark * benchmarks * Remove benchmark code for cleanups * benchmark-code * nits * benchmark yamls * benchmark yaml * ok * ok * ok * benchmark * nit * finish_bench * makedatacreator * relax * metrics * autosetsampler * profile * movements * OK * smoothen * fix * nitdocs * loss * envflag * comments * nit * format * visible * images * move_images * fix * rernder * rrender * rest * multgpu * fix * nit * finish * extrra * setup * experimental * as_trainable * fix * ok * format * create_torch_pbt * setup_pbt * ok * format * ok * format * docs * ok * Draft head-is-worker * Fix missing concurrency between local and remote workers * Fix tqdm to work with head-is-worker * Cleanup * Implement state_dict and load_state_dict * Reserve resources on the head node for the local worker * Update the development cluster setup * Add spot block reservation to the development yaml * ok * Draft the fault tolerance fix * Small fixes to local-remote concurrency * Cleanup + fix typo * fixes * worker_counts * some formatting and asha * fix * okme * fixactorkill * unify * Revert the cluster mounts * Cut the handler-reporter API * Fix most tests * Rm tqdm_handler.py * Re-add tune test * Automatically force-shutdown on actor errors on shutdown * Formatting * fix_tune_test * Add timeout error verification * Rename tqdm to use_tqdm * fixtests * ok * remove_redundant * deprecated * deactivated * ok_try_this * lint * nice * done * retries * fixes * kill * retry * init_transformer * init * deployit * improve_example * trans * rename * formats * format-to-py37 * time_to_test * more_changes * ok * update_args_and_script * fp16_epoch * huggingface * training stats * distributed * Apply suggestions from code review * transformer Co-authored-by: Edward Oakes Co-authored-by: Maksim Smolin --- doc/source/raysgd/raysgd_pytorch.rst | 17 +- .../torch/examples/transformers/README.rst | 89 +++++ 
 .../torch/examples/transformers/__init__.py   |   0
 .../torch/examples/transformers/cluster.yaml  |  78 ++++
 .../transformers/transformers_example.py      | 371 ++++++++++++++++++
 .../sgd/torch/examples/transformers/utils.py  | 212 ++++++++++
 6 files changed, 760 insertions(+), 7 deletions(-)
 create mode 100644 python/ray/util/sgd/torch/examples/transformers/README.rst
 create mode 100644 python/ray/util/sgd/torch/examples/transformers/__init__.py
 create mode 100644 python/ray/util/sgd/torch/examples/transformers/cluster.yaml
 create mode 100644 python/ray/util/sgd/torch/examples/transformers/transformers_example.py
 create mode 100644 python/ray/util/sgd/torch/examples/transformers/utils.py

diff --git a/doc/source/raysgd/raysgd_pytorch.rst b/doc/source/raysgd/raysgd_pytorch.rst
index 5d85b27d6..40a20c2c1 100644
--- a/doc/source/raysgd/raysgd_pytorch.rst
+++ b/doc/source/raysgd/raysgd_pytorch.rst
@@ -708,23 +708,26 @@ TorchTrainer Examples
 Here are some examples of using RaySGD for training PyTorch models.
 If you'd like to contribute an example, feel free to create a `pull request here `_.
 
-- `Torch training example `__:
+- `Torch training example `__
   Simple example of using Ray's TorchTrainer.
 
-- `TorchTrainer and RayTune example `__:
+- `TorchTrainer and RayTune example `__
   Simple example of hyperparameter tuning with Ray's TorchTrainer.
 
-- `Semantic Segmentation example `__:
+- `Semantic Segmentation example `__
   Fine-tuning a ResNet50 model on VOC with Batch Norm.
 
-- `ImageNet Models example `__:
+- `HuggingFace Transformer GLUE fine-tuning example `__
+  Fine-tuning a pre-trained Transformer model on GLUE tasks. Based on the `huggingface/transformers `_ ``run_glue.py`` example.
+
+- `ImageNet Models example `__
   Training state-of-the-art ImageNet models.
 
-- `CIFAR10 example `__:
+- `CIFAR10 example `__
   Training a ResNet18 model on CIFAR10.
 
-- `CIFAR10 RayTune example `__:
+- `CIFAR10 RayTune example `__
   Tuning a ResNet18 model on CIFAR10 with Population-based training on RayTune.
 
-- `DCGAN example `__:
+- `DCGAN example `__
   Training a Deep Convolutional GAN on MNIST. It constructs two models and two
   optimizers and uses a custom training operator.

diff --git a/python/ray/util/sgd/torch/examples/transformers/README.rst b/python/ray/util/sgd/torch/examples/transformers/README.rst
new file mode 100644
index 000000000..2c0ac11d6
--- /dev/null
+++ b/python/ray/util/sgd/torch/examples/transformers/README.rst
@@ -0,0 +1,89 @@
+HuggingFace Transformers GLUE Fine-tuning Example
+=================================================
+
+We've ported the ``huggingface/transformers/examples/run_glue.py`` example to
+RaySGD. This example fine-tunes the library models for sequence classification
+on the GLUE (General Language Understanding Evaluation) benchmark.
+
+This script can fine-tune the following models: BERT, XLM, XLNet and RoBERTa.
+
+The information below can be found in the `HuggingFace repository `_ and is copied here for your convenience.
+
+Before running any of these GLUE tasks, you should download the
+`GLUE data `_ by running
+`this script `_
+and unpack it to some directory ``$GLUE_DIR``.
+
+.. code-block:: bash
+
+    export GLUE_DIR=/path/to/glue
+    export TASK_NAME=MRPC
+
+    python transformers_example.py \
+      --model_type bert \
+      --model_name_or_path bert-base-cased \
+      --task_name $TASK_NAME \
+      --do_train \
+      --do_eval \
+      --data_dir glue_data/$TASK_NAME \
+      --max_seq_length 128 \
+      --per_gpu_train_batch_size 32 \
+      --learning_rate 2e-5 \
+      --num_train_epochs 3.0 \
+      --output_dir /tmp/$TASK_NAME/
+
+where the task name can be one of CoLA, SST-2, MRPC, STS-B, QQP, MNLI, QNLI, RTE, WNLI.
+
+The dev set results will be written to the text file ``eval_results.txt`` in the specified ``output_dir``.
+In the case of MNLI, since there are two separate dev sets (matched and mismatched), there will be a separate
+output folder called ``/tmp/MNLI-MM/`` in addition to ``/tmp/MNLI/``.
+
+Multi-GPU training with Apex
+----------------------------
+
+To run an example that tunes MNLI on your local machine with 8 GPUs and Apex, first install `apex `_ and then run:
+
+.. code-block:: bash
+
+    python transformers_example.py \
+      --model_type bert \
+      --model_name_or_path bert-base-cased \
+      --task_name mnli \
+      --do_train \
+      --do_eval \
+      --data_dir glue_data/MNLI/ \
+      --max_seq_length 128 \
+      --per_gpu_train_batch_size 8 \
+      --learning_rate 2e-5 \
+      --num_train_epochs 3.0 \
+      --output_dir output_dir \
+      --num_workers 8 \
+      --fp16
+
+
+Multi-node training
+-------------------
+
+To run an example that tunes MNLI on AWS with 16 GPUs and Apex, run:
+
+.. code-block:: bash
+
+    ray up cluster.yaml
+    # Optionally,
+    # ray monitor cluster.yaml
+    ray submit cluster.yaml transformers_example.py -- --model_type bert \
+      --model_name_or_path bert-base-cased \
+      --task_name mnli \
+      --do_train \
+      --do_eval \
+      --data_dir /home/ubuntu/glue_data/MNLI/ \
+      --max_seq_length 128 \
+      --per_gpu_train_batch_size 8 \
+      --learning_rate 2e-5 \
+      --num_train_epochs 3.0 \
+      --output_dir /home/ubuntu/output/ \
+      --num_workers 16 \
+      --fp16 \
+      --address auto
+
+Note that with Apex, you can increase ``per_gpu_train_batch_size`` to 32, which
+should make each epoch take 10 minutes or less.
diff --git a/python/ray/util/sgd/torch/examples/transformers/__init__.py b/python/ray/util/sgd/torch/examples/transformers/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/python/ray/util/sgd/torch/examples/transformers/cluster.yaml b/python/ray/util/sgd/torch/examples/transformers/cluster.yaml
new file mode 100644
index 000000000..59b3877cf
--- /dev/null
+++ b/python/ray/util/sgd/torch/examples/transformers/cluster.yaml
@@ -0,0 +1,78 @@
+# A unique identifier for the head node and workers of this cluster.
+cluster_name: transformer-cluster
+
+# The maximum number of worker nodes to launch in addition to the head
+# node. This takes precedence over min_workers. min_workers defaults to 0.
+min_workers: 3
+initial_workers: 3
+max_workers: 3
+
+target_utilization_fraction: 0.9
+# Cloud-provider specific configuration.
+provider:
+    type: aws
+    region: us-east-1
+    availability_zone: us-east-1c
+
+# How Ray will authenticate with newly launched nodes.
+auth: + ssh_user: ubuntu + + +head_node: + InstanceType: p3.8xlarge + ImageId: ami-0698bcaf8bd9ef56d + InstanceMarketOptions: + MarketType: spot + BlockDeviceMappings: + - DeviceName: /dev/sda1 + Ebs: + VolumeSize: 300 + + +worker_nodes: + InstanceType: p3.8xlarge + ImageId: ami-0698bcaf8bd9ef56d + InstanceMarketOptions: + MarketType: spot + BlockDeviceMappings: + - DeviceName: /dev/sda1 + Ebs: + VolumeSize: 300 + # SpotOptions: + # MaxPrice: "9.0" + # # Run workers on spot by default. Comment this out to use on-demand. + # InstanceMarketOptions: + # MarketType: spot + +setup_commands: + # This replaces the standard anaconda Ray installation + - ray || pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.9.0.dev0-cp36-cp36m-manylinux1_x86_64.whl + - pip install -q tqdm + + # Installing this without -U to make sure we don't replace the existing Ray installation + - pip install ray[tune] + - pip install -U ipdb torch + # Install HuggingFace + - git clone https://github.com/huggingface/transformers || true + - cd transformers && + pip install . && + pip install -r ./examples/requirements.txt + # Download glue + - if [[ -e glue_data ]]; + then echo "not downloading glue"; + else wget https://gist.githubusercontent.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e/raw/17b8dd0d724281ed7c3b2aeeda662b92809aadd5/download_glue_data.py && python download_glue_data.py; + fi + + # Install Apex + - git clone https://github.com/NVIDIA/apex; + cd apex && + pip install -v --no-cache-dir ./ || + true + + +file_mounts: { +} + +# Custom commands that will be run on the head node after common setup. +head_setup_commands: [] diff --git a/python/ray/util/sgd/torch/examples/transformers/transformers_example.py b/python/ray/util/sgd/torch/examples/transformers/transformers_example.py new file mode 100644 index 000000000..2eb7fc146 --- /dev/null +++ b/python/ray/util/sgd/torch/examples/transformers/transformers_example.py @@ -0,0 +1,371 @@ +# coding=utf-8 +# This is a modified example originally from The Google AI Language Team +# Authors and The HuggingFace Inc. team. +# Modified by Richard Liaw. +# Copyright 2018 The Google AI Language Team Authors, +# The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
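+
+# Overview of this example:
+#   * model_creator / optimizer_creator / data_creator build the per-worker
+#     model, AdamW optimizer, and DataLoader that TorchTrainer hands to each
+#     worker.
+#   * TransformerOperator overrides setup() and train_batch() to add the
+#     linear warmup schedule, gradient accumulation, and optional Apex fp16.
+#   * main() merges the HuggingFace-style argument dataclasses and drives
+#     training and evaluation through TorchTrainer.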
+""" Finetuning the library models for sequence classification on GLUE ( +Bert, XLM, XLNet, RoBERTa, Albert, XLM-RoBERTa).""" + +import argparse +import logging +import json +import os +import time +from filelock import FileLock +from dataclasses import dataclass, field +from typing import Optional +import random + +import numpy as np +import torch +from torch.utils.data import DataLoader, RandomSampler +from tqdm import trange +import torch.distributed as dist + +from transformers import (MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING, AdamW, + AutoConfig, AutoModelForSequenceClassification, + AutoTokenizer, get_linear_schedule_with_warmup, + HfArgumentParser, TrainingArguments) +from transformers import glue_output_modes as output_modes +from transformers import glue_processors as processors + +import ray +from ray.util.sgd.torch import TrainingOperator +from ray.util.sgd import TorchTrainer +from ray.util.sgd.torch.examples.transformers.utils import ( + evaluate, load_and_cache_examples, save_and_evaluate_checkpoints) + +try: + from apex import amp +except ImportError: + amp = None + +MODEL_CONFIG_CLASSES = list(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.keys()) +MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) + +ALL_MODELS = sum( + (tuple(conf.pretrained_config_archive_map.keys()) + for conf in MODEL_CONFIG_CLASSES), + (), +) + +logger = logging.getLogger(__name__) + + +def set_seed(args): + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + +def announce_training(args, dataset_len, t_total): + # Train! + logger.info("***** Running training *****") + logger.info(" Num examples = %d", dataset_len) + logger.info(" Num Epochs = %d", args.num_train_epochs) + logger.info(" Instantaneous batch size per GPU = %d", + args.per_gpu_train_batch_size) + logger.info( + " Total train batch size (w. 
parallel, distributed & accum) = %d", + args.per_gpu_train_batch_size * args.gradient_accumulation_steps * + args.num_workers, + ) + logger.info(" Gradient Accumulation steps = %d", + args.gradient_accumulation_steps) + logger.info(" Total optimization steps = %d", t_total) + + +def model_creator(config): + with FileLock(os.path.expanduser("~/.download.lock")): + args = config["args"] + processor = processors[args.task_name]() + label_list = processor.get_labels() + num_labels = len(label_list) + config = AutoConfig.from_pretrained( + args.config_name if args.config_name else args.model_name_or_path, + num_labels=num_labels, + finetuning_task=args.task_name, + cache_dir=args.cache_dir if args.cache_dir else None, + ) + model = AutoModelForSequenceClassification.from_pretrained( + args.model_name_or_path, + from_tf=bool(".ckpt" in args.model_name_or_path), + config=config, + cache_dir=args.cache_dir if args.cache_dir else None, + ) + return model + + +def optimizer_creator(model, cfg): + args = cfg["args"] + no_decay = ["bias", "LayerNorm.weight"] + optimizer_grouped_parameters = [ + { + "params": [ + p for n, p in model.named_parameters() + if not any(nd in n for nd in no_decay) + ], + "weight_decay": args.weight_decay, + }, + { + "params": [ + p for n, p in model.named_parameters() + if any(nd in n for nd in no_decay) + ], + "weight_decay": 0.0 + }, + ] + + return AdamW( + optimizer_grouped_parameters, + lr=args.learning_rate, + eps=args.adam_epsilon) + + +def data_creator(config): + args = config["args"] + start = time.time() + tokenizer = AutoTokenizer.from_pretrained( + args.tokenizer_name + if args.tokenizer_name else args.model_name_or_path, + cache_dir=args.cache_dir if args.cache_dir else None, + ) + logger.info("tokenizer instantiation time: {}".format(time.time() - start)) + + train_dataset = load_and_cache_examples( + args, args.task_name, tokenizer, evaluate=False) + train_sampler = RandomSampler( + train_dataset) if not dist.is_initialized() else None + return DataLoader( + train_dataset, + sampler=train_sampler, + batch_size=args.per_gpu_train_batch_size) + + +class TransformerOperator(TrainingOperator): + def setup(self, config): + self.args = args = config["args"] + self.tokenizer = AutoTokenizer.from_pretrained( + args.tokenizer_name + if args.tokenizer_name else args.model_name_or_path, + cache_dir=args.cache_dir if args.cache_dir else None, + ) + + self.train_data_len = len(self.train_loader) + self._warmup_scheduler = get_linear_schedule_with_warmup( + self.optimizer, + num_warmup_steps=args.warmup_steps, + num_training_steps=self.calculate_t_total()) + self._global_step = 0 + + announce_training(args, self.train_data_len, self.calculate_t_total()) + + def train_batch(self, batch, batch_info=None): + args = self.args + model = self.model + optimizer = self.optimizer + step = batch_info["batch_idx"] + + model.train() + batch = tuple(t.to(self.device) for t in batch) + inputs = { + "input_ids": batch[0], + "attention_mask": batch[1], + "labels": batch[3] + } + if args.model_type != "distilbert": + # XLM, DistilBERT, RoBERTa, and XLM-RoBERTa don't use segment_ids + inputs["token_type_ids"] = (batch[2] if args.model_type in [ + "bert", "xlnet", "albert" + ] else None) + outputs = model(**inputs) + + # model outputs are always tuple in transformers (see doc) + loss = outputs[0] + + if args.gradient_accumulation_steps > 1: + loss = loss / args.gradient_accumulation_steps + + if args.fp16: + with amp.scale_loss(loss, optimizer) as scaled_loss: + scaled_loss.backward() + else: 
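+            # Without Apex fp16, backpropagate the unscaled loss in full
+            # precision; the fp16 branch above scales the loss with amp so
+            # that half-precision gradients do not underflow.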
+ loss.backward() + + batch_loss = loss.item() + + # last step in epoch but step is always smaller + # than gradient_accumulation_steps + ending = (self.train_data_len <= args.gradient_accumulation_steps + and (step + 1) == self.train_data_len) + if (step + 1) % args.gradient_accumulation_steps == 0 or ending: + if args.fp16: + torch.nn.utils.clip_grad_norm_( + amp.master_params(optimizer), args.max_grad_norm) + else: + torch.nn.utils.clip_grad_norm_(model.parameters(), + args.max_grad_norm) + + self.optimizer.step() + self._warmup_scheduler.step() # Update learning rate schedule + model.zero_grad() + self._global_step += 1 + + learning_rate_scalar = self._warmup_scheduler.get_lr()[0] + return {"learning_rate": learning_rate_scalar, "loss": batch_loss} + + def calculate_t_total(self): + args = self.args + grad_accum_steps = args.gradient_accumulation_steps + train_data_len = len(self.train_loader) + if args.max_steps > 0: + t_total = args.max_steps + args.num_train_epochs = args.max_steps // ( + train_data_len // grad_accum_steps) + 1 + else: + t_total = ( + train_data_len // grad_accum_steps * args.num_train_epochs) + return t_total + + +@dataclass +class ModelArguments: + """Arguments pertaining to model/config/tokenizer.""" + + model_name_or_path: str = field( + metadata=dict(help="Path to pre-trained model or shortcut name " + "selected in the list: " + ", ".join(ALL_MODELS))) + model_type: str = field( + metadata=dict(help="Model type selected " + "in the list: " + ", ".join(MODEL_TYPES))) + config_name: Optional[str] = field( + default=None, + metadata=dict( + help="Pretrained config name or path if not the same as model_name" + )) + tokenizer_name: Optional[str] = field( + default=None, + metadata=dict(help="Pretrained tokenizer name or path " + "if not the same as model_name")) + cache_dir: Optional[str] = field( + default=None, + metadata=dict(help="Where do you want to store the pre-trained " + "models downloaded from s3")) + + +@dataclass +class DataProcessingArguments: + task_name: str = field( + metadata=dict(help="The name of the task to train selected " + "in the list: " + ", ".join(processors.keys()))) + data_dir: str = field( + metadata=dict(help="The input data dir. Should contain " + "the .tsv files (or other data files) for the task.")) + max_seq_length: int = field( + default=128, + metadata=dict(help="The maximum total input sequence length " + "after tokenization. Sequences longer " + "than this will be truncated, sequences " + "shorter will be padded.")) + overwrite_cache: bool = field( + default=False, + metadata={"help": "Overwrite the cached training and evaluation sets"}) + + +@dataclass +class RayArguments: + num_workers: int = field( + default=1, + metadata={"help": "Number of data-parallel workers to use."}) + address: str = field( + default=None, + metadata={"help": "Address of the Ray cluster to connect to."}) + + +def main(): + parser = HfArgumentParser((ModelArguments, DataProcessingArguments, + TrainingArguments, RayArguments)) + all_args = parser.parse_args_into_dataclasses() + model_args, dataprocessing_args, training_args, ray_args = all_args + + # For now, let's merge all the sets of args into one, + # but soon, we'll keep distinct sets of args, with a + # cleaner separation of concerns. 
+ args = argparse.Namespace(**vars(model_args), **vars(dataprocessing_args), + **vars(training_args), **vars(ray_args)) + + if (os.path.exists(args.output_dir) and os.listdir(args.output_dir) + and args.do_train and not args.overwrite_output_dir): + raise ValueError( + "Output directory ({}) already exists and is not empty. " + "Use --overwrite_output_dir to overcome.".format(args.output_dir)) + + use_gpu = torch.cuda.is_available() and not args.no_cuda + + # Prepare GLUE task + args.task_name = args.task_name.lower() + if args.task_name not in processors: + raise ValueError("Task not found: %s" % (args.task_name)) + args.output_mode = output_modes[args.task_name] + + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO) + logger.info("Training/evaluation parameters %s", args) + ray.init(address=args.address) + # Training + + trainer = TorchTrainer( + model_creator=model_creator, + data_creator=data_creator, + optimizer_creator=optimizer_creator, + training_operator_cls=TransformerOperator, + use_fp16=args.fp16, + apex_args={"opt_level": args.fp16_opt_level}, + num_workers=args.num_workers, + use_gpu=use_gpu, + use_tqdm=True, + config={"args": args}) + + args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + tokenizer = trainer.get_local_operator().tokenizer + local_model = trainer.get_model() + + epochs_trained = 0 + train_iterator = trange( + epochs_trained, + int(args.num_train_epochs), + desc="Epoch", + ) + + trainer.apply_all_workers(lambda: set_seed(args)) + if args.do_train: + for _ in train_iterator: + stats = trainer.train() + print("Training stats:", stats) + logs = evaluate(args, local_model, tokenizer) + print(json.dumps(logs)) + + # Post-training validation + save_and_evaluate_checkpoints(args, local_model, tokenizer) + + +if __name__ == "__main__": + main() diff --git a/python/ray/util/sgd/torch/examples/transformers/utils.py b/python/ray/util/sgd/torch/examples/transformers/utils.py new file mode 100644 index 000000000..0b7137e22 --- /dev/null +++ b/python/ray/util/sgd/torch/examples/transformers/utils.py @@ -0,0 +1,212 @@ +# flake8: noqa +import glob +import logging +import os +from tqdm import tqdm +from filelock import FileLock +import numpy as np + +import torch +from torch.utils.data import (DataLoader, SequentialSampler, TensorDataset) + +from transformers import glue_processors as processors +from transformers import glue_compute_metrics as compute_metrics +from transformers import glue_output_modes as output_modes +from transformers import (glue_convert_examples_to_features as + convert_examples_to_features) +from transformers import ( + WEIGHTS_NAME, + AutoModelForSequenceClassification, + AutoTokenizer, +) +logger = logging.getLogger(__name__) + + +def load_and_cache_examples(args, task, tokenizer, evaluate=False): + processor = processors[task]() + output_mode = output_modes[task] + # Load data features from cache or dataset file + cached_features_file = os.path.join( + args.data_dir, + "cached_{}_{}_{}_{}".format( + "dev" if evaluate else "train", + list(filter(None, args.model_name_or_path.split("/"))).pop(), + str(args.max_seq_length), + str(task), + ), + ) + + with FileLock("/tmp/load_and_cache_examples.lock"): + if os.path.exists(cached_features_file) and not args.overwrite_cache: + logger.info("Loading features from cached file %s", + cached_features_file) + features = torch.load(cached_features_file) + else: + logger.info("Creating features from 
dataset file at %s", + args.data_dir) + label_list = processor.get_labels() + if task in ["mnli", "mnli-mm" + ] and args.model_type in ["roberta", "xlmroberta"]: + # HACK(label indices are swapped in RoBERTa pretrained model) + label_list[1], label_list[2] = label_list[2], label_list[1] + examples = (processor.get_dev_examples(args.data_dir) if evaluate + else processor.get_train_examples(args.data_dir)) + features = convert_examples_to_features( + examples, + tokenizer, + label_list=label_list, + max_length=args.max_seq_length, + output_mode=output_mode, + ) + if not os.path.exists(cached_features_file): + logger.info("Saving features into cached file %s", + cached_features_file) + torch.save(features, cached_features_file) + + # Convert to Tensors and build dataset + all_input_ids = torch.tensor( + [f.input_ids for f in features], dtype=torch.long) + all_attention_mask = torch.tensor( + [f.attention_mask for f in features], dtype=torch.long) + all_token_type_ids = torch.tensor( + [f.token_type_ids for f in features], dtype=torch.long) + if output_mode == "classification": + all_labels = torch.tensor( + [f.label for f in features], dtype=torch.long) + elif output_mode == "regression": + all_labels = torch.tensor( + [f.label for f in features], dtype=torch.float) + + dataset = TensorDataset(all_input_ids, all_attention_mask, + all_token_type_ids, all_labels) + return dataset + + +def save_and_evaluate_checkpoints(args, model, tokenizer): + # Saving best-practices: if you use defaults names for the model, + # you can reload it using from_pretrained() + if args.do_train: + # Create output directory if needed + if not os.path.exists(args.output_dir): + os.makedirs(args.output_dir) + + logger.info("Saving model checkpoint to %s", args.output_dir) + # Save a trained model, configuration and tokenizer using + # `save_pretrained()`. 
They can then be + # reloaded using `from_pretrained()` + model_to_save = (model.module if hasattr(model, "module") else + model) # Take care of distributed/parallel training + model_to_save.save_pretrained(args.output_dir) + tokenizer.save_pretrained(args.output_dir) + + # Good practice: save your training arguments + # together with the trained model + torch.save(args, os.path.join(args.output_dir, "training_args.bin")) + + # Load a trained model and vocabulary that you have fine-tuned + model = AutoModelForSequenceClassification.from_pretrained( + args.output_dir) + tokenizer = AutoTokenizer.from_pretrained(args.output_dir) + model.to(args.device) + + results = {} + if args.do_eval: + tokenizer = AutoTokenizer.from_pretrained(args.output_dir) + checkpoints = [args.output_dir] + if args.eval_all_checkpoints: + checkpoints = list( + os.path.dirname(c) for c in sorted( + glob.glob( + args.output_dir + "/**/" + WEIGHTS_NAME, + recursive=True))) + logging.getLogger("transformers.modeling_utils").setLevel( + logging.WARN) # Reduce logging + logger.info("Evaluate the following checkpoints: %s", checkpoints) + for checkpoint in checkpoints: + global_step = checkpoint.split("-")[ + -1] if len(checkpoints) > 1 else "" + prefix = checkpoint.split("/")[ + -1] if checkpoint.find("checkpoint") != -1 else "" + + model = AutoModelForSequenceClassification.from_pretrained( + checkpoint) + model.to(args.device) + result = evaluate(args, model, tokenizer, prefix=prefix) + result = dict( + (k + "_{}".format(global_step), v) for k, v in result.items()) + results.update(result) + + return results + + +def evaluate(args, model, tokenizer, prefix=""): + # Loop to handle MNLI double evaluation (matched, mis-matched) + eval_task_names = ("mnli", "mnli-mm") if args.task_name == "mnli" else ( + args.task_name, ) + eval_outputs_dirs = (args.output_dir, args.output_dir + "-MM" + ) if args.task_name == "mnli" else (args.output_dir, ) + + results = {} + for eval_task, eval_output_dir in zip(eval_task_names, eval_outputs_dirs): + eval_dataset = load_and_cache_examples( + args, eval_task, tokenizer, evaluate=True) + + if not os.path.exists(eval_output_dir): + os.makedirs(eval_output_dir) + + args.eval_batch_size = args.per_gpu_eval_batch_size + # Note that DistributedSampler samples randomly + eval_sampler = SequentialSampler(eval_dataset) + eval_dataloader = DataLoader( + eval_dataset, + sampler=eval_sampler, + batch_size=args.eval_batch_size) + + # Eval! 
+ logger.info("***** Running evaluation {} *****".format(prefix)) + logger.info(" Num examples = %d", len(eval_dataset)) + logger.info(" Batch size = %d", args.eval_batch_size) + eval_loss = 0.0 + nb_eval_steps = 0 + preds = None + out_label_ids = None + for batch in tqdm(eval_dataloader, desc="Evaluating"): + model.eval() + batch = tuple(t.to(args.device) for t in batch) + + with torch.no_grad(): + inputs = { + "input_ids": batch[0], + "attention_mask": batch[1], + "labels": batch[3] + } + if args.model_type != "distilbert": + # XLM, DistilBERT, RoBERTa, and XLM-RoBERTa don't + # use segment_ids + inputs["token_type_ids"] = (batch[2] + if args.model_type in [ + "bert", "xlnet", "albert" + ] else None) + outputs = model(**inputs) + tmp_eval_loss, logits = outputs[:2] + + eval_loss += tmp_eval_loss.mean().item() + nb_eval_steps += 1 + if preds is None: + preds = logits.detach().cpu().numpy() + out_label_ids = inputs["labels"].detach().cpu().numpy() + else: + preds = np.append(preds, logits.detach().cpu().numpy(), axis=0) + out_label_ids = np.append( + out_label_ids, + inputs["labels"].detach().cpu().numpy(), + axis=0) + + eval_loss = eval_loss / nb_eval_steps + if args.output_mode == "classification": + preds = np.argmax(preds, axis=1) + elif args.output_mode == "regression": + preds = np.squeeze(preds) + result = compute_metrics(eval_task, preds, out_label_ids) + results.update(result) + return results
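For readers skimming this patch, the sketch below condenses the TorchTrainer wiring that ``transformers_example.py`` follows into a self-contained toy script. It is illustrative only: the tiny linear model, random tensors, and ``MyOperator`` are placeholders for the real ``model_creator``/``optimizer_creator``/``data_creator`` and ``TransformerOperator`` added above, and the real example additionally passes ``use_fp16``, ``apex_args``, ``use_tqdm``, and ``config={"args": args}``.

.. code-block:: python

    import torch
    import torch.nn as nn
    from torch.utils.data import DataLoader, TensorDataset

    import ray
    from ray.util.sgd import TorchTrainer
    from ray.util.sgd.torch import TrainingOperator


    def model_creator(config):
        # Placeholder for AutoModelForSequenceClassification.from_pretrained(...)
        return nn.Linear(128, 2)


    def optimizer_creator(model, config):
        # Placeholder for AdamW with weight-decay parameter groups
        return torch.optim.Adam(model.parameters(), lr=config.get("lr", 2e-5))


    def data_creator(config):
        # Placeholder for load_and_cache_examples(...) wrapped in a DataLoader
        dataset = TensorDataset(
            torch.randn(64, 128), torch.randint(0, 2, (64,)))
        return DataLoader(dataset, batch_size=8)


    class MyOperator(TrainingOperator):
        def setup(self, config):
            # The real TransformerOperator builds the warmup LR scheduler here.
            self.num_batches = len(self.train_loader)

        def train_batch(self, batch, batch_info=None):
            # The real train_batch adds gradient accumulation, clipping, fp16.
            features, labels = (t.to(self.device) for t in batch)
            self.model.train()
            loss = nn.functional.cross_entropy(self.model(features), labels)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            return {"loss": loss.item()}


    if __name__ == "__main__":
        ray.init()
        trainer = TorchTrainer(
            model_creator=model_creator,
            data_creator=data_creator,
            optimizer_creator=optimizer_creator,
            training_operator_cls=MyOperator,
            num_workers=2,
            use_gpu=False,
            config={"lr": 2e-5})
        for _ in range(3):  # one trainer.train() call per epoch, as in main()
            print(trainer.train())
        trainer.shutdown()

As in ``main()`` above, evaluation and checkpointing stay outside the trainer; the example fetches the trained model with ``trainer.get_model()`` and uses the helpers in ``utils.py`` for that.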