[sgd] HuggingFace GLUE Fine-tuning Example (#7792)
* Init fp16 * fp16 and schedulers * scheduler linking and fp16 * to fp16 * loss scaling and documentation * more documentation * add tests, refactor config * moredocs * more docs * fix logo, add test mode, add fp16 flag * fix tests * fix scheduler * fix apex * improve safety * fix tests * fix tests * remove pin memory default * rm * fix * Update doc/examples/doc_code/raysgd_torch_signatures.py * fix * migrate changes from other PR * ok thanks * pass * signatures * lint' * Update python/ray/experimental/sgd/pytorch/utils.py * Apply suggestions from code review Co-Authored-By: Edward Oakes <ed.nmi.oakes@gmail.com> * should address most comments * comments * fix this ci * first_pass * add overrides * override * fixing up operators * format * sgd * constants * rm * revert * save * failures * fixes * trainer * run test * operator * code * op * ok done * operator * sgd test fixes * ok * trainer * format * Apply suggestions from code review Co-Authored-By: Edward Oakes <ed.nmi.oakes@gmail.com> * Update doc/source/raysgd/raysgd_pytorch.rst * docstring * dcgan * doc * commits * nit * testing * revert * Start renaming pytorch to torch * Rename PyTorchTrainer to TorchTrainer * Rename PyTorch runners to Torch runners * Finish renaming API * Rename to torch in tests * Finish renaming docs + tests * Run format + fix DeprecationWarning * fix * move tests up * benchmarks * rename * remove some args * better metrics output * fix up the benchmark * benchmark-yaml * horovod-benchmark * benchmarks * Remove benchmark code for cleanups * benchmark-code * nits * benchmark yamls * benchmark yaml * ok * ok * ok * benchmark * nit * finish_bench * makedatacreator * relax * metrics * autosetsampler * profile * movements * OK * smoothen * fix * nitdocs * loss * envflag * comments * nit * format * visible * images * move_images * fix * rernder * rrender * rest * multgpu * fix * nit * finish * extrra * setup * experimental * as_trainable * fix * ok * format * create_torch_pbt * setup_pbt * ok * format * ok * format * docs * ok * Draft head-is-worker * Fix missing concurrency between local and remote workers * Fix tqdm to work with head-is-worker * Cleanup * Implement state_dict and load_state_dict * Reserve resources on the head node for the local worker * Update the development cluster setup * Add spot block reservation to the development yaml * ok * Draft the fault tolerance fix * Small fixes to local-remote concurrency * Cleanup + fix typo * fixes * worker_counts * some formatting and asha * fix * okme * fixactorkill * unify * Revert the cluster mounts * Cut the handler-reporter API * Fix most tests * Rm tqdm_handler.py * Re-add tune test * Automatically force-shutdown on actor errors on shutdown * Formatting * fix_tune_test * Add timeout error verification * Rename tqdm to use_tqdm * fixtests * ok * remove_redundant * deprecated * deactivated * ok_try_this * lint * nice * done * retries * fixes * kill * retry * init_transformer * init * deployit * improve_example * trans * rename * formats * format-to-py37 * time_to_test * more_changes * ok * update_args_and_script * fp16_epoch * huggingface * training stats * distributed * Apply suggestions from code review * transformer Co-authored-by: Edward Oakes <ed.nmi.oakes@gmail.com> Co-authored-by: Maksim Smolin <maximsmol@gmail.com>
parent d6f4e5b3e1
commit 857e4dba2f
6 changed files with 760 additions and 7 deletions
doc/source/raysgd/raysgd_pytorch.rst

@@ -708,23 +708,26 @@ TorchTrainer Examples

 Here are some examples of using RaySGD for training PyTorch models. If you'd like
 to contribute an example, feel free to create a `pull request here <https://github.com/ray-project/ray/>`_.

-- `Torch training example <https://github.com/ray-project/ray/blob/master/python/ray/util/sgd/torch/examples/train_example.py>`__:
+- `Torch training example <https://github.com/ray-project/ray/blob/master/python/ray/util/sgd/torch/examples/train_example.py>`__
   Simple example of using Ray's TorchTrainer.

-- `TorchTrainer and RayTune example <https://github.com/ray-project/ray/blob/master/python/ray/util/sgd/torch/examples/tune_example.py>`__:
+- `TorchTrainer and RayTune example <https://github.com/ray-project/ray/blob/master/python/ray/util/sgd/torch/examples/tune_example.py>`__
   Simple example of hyperparameter tuning with Ray's TorchTrainer.

-- `Semantic Segmentation example <https://github.com/ray-project/ray/blob/master/python/ray/util/sgd/torch/examples/segmentation/train_segmentation.py>`__:
+- `Semantic Segmentation example <https://github.com/ray-project/ray/blob/master/python/ray/util/sgd/torch/examples/segmentation/train_segmentation.py>`__
   Fine-tuning a ResNet50 model on VOC with Batch Norm.

-- `ImageNet Models example <https://github.com/ray-project/ray/blob/master/python/ray/util/sgd/torch/examples/image_models/train.py>`__:
+- `Huggingface Transformer GLUE fine tuning example <https://github.com/ray-project/ray/blob/master/python/ray/util/sgd/torch/examples/transformers/transformers_example.py>`__
+  Fine-tuning a pre-trained Transformer model on GLUE tasks. Based off of the `huggingface/transformers <https://github.com/huggingface/transformers/blob/master/examples/>`_ ``run_glue.py`` example.
+
+- `ImageNet Models example <https://github.com/ray-project/ray/blob/master/python/ray/util/sgd/torch/examples/image_models/train.py>`__
   Training state-of-the-art ImageNet models.

-- `CIFAR10 example <https://github.com/ray-project/ray/blob/master/python/ray/util/sgd/torch/examples/cifar_pytorch_example.py>`__:
+- `CIFAR10 example <https://github.com/ray-project/ray/blob/master/python/ray/util/sgd/torch/examples/cifar_pytorch_example.py>`__
   Training a ResNet18 model on CIFAR10.

-- `CIFAR10 RayTune example <https://github.com/ray-project/ray/blob/master/python/ray/util/sgd/torch/examples/cifar_pytorch_pbt.py>`__:
+- `CIFAR10 RayTune example <https://github.com/ray-project/ray/blob/master/python/ray/util/sgd/torch/examples/cifar_pytorch_pbt.py>`__
   Tuning a ResNet18 model on CIFAR10 with Population-based training on RayTune.

-- `DCGAN example <https://github.com/ray-project/ray/blob/master/python/ray/util/sgd/torch/examples/dcgan.py>`__:
+- `DCGAN example <https://github.com/ray-project/ray/blob/master/python/ray/util/sgd/torch/examples/dcgan.py>`__
   Training a Deep Convolutional GAN on MNIST. It constructs two models and two optimizers and uses a custom training operator.

python/ray/util/sgd/torch/examples/transformers/README.rst (new file, 89 lines)

@@ -0,0 +1,89 @@
HuggingFace Transformers GLUE Fine-tuning Example
=================================================

We've ported the ``huggingface/transformers/examples/run_glue.py`` example to
RaySGD. This example fine-tunes the library models for sequence classification
on the GLUE (General Language Understanding Evaluation) benchmark.

This script can fine-tune the following models: BERT, XLM, XLNet, and RoBERTa.

The information below comes from the `HuggingFace repository <https://github.com/huggingface/transformers/tree/master/examples#glue-1>`_ and is reproduced here for convenience.

Before running any of these GLUE tasks, download the
`GLUE data <https://gluebenchmark.com/tasks>`_ by running
`this script <https://gist.github.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e>`_
and unpack it into some directory ``$GLUE_DIR``.

.. code-block:: bash

    export GLUE_DIR=/path/to/glue
    export TASK_NAME=MRPC

    python transformers_example.py \
      --model_type bert \
      --model_name_or_path bert-base-cased \
      --task_name $TASK_NAME \
      --do_train \
      --do_eval \
      --data_dir glue_data/$TASK_NAME \
      --max_seq_length 128 \
      --per_gpu_train_batch_size 32 \
      --learning_rate 2e-5 \
      --num_train_epochs 3.0 \
      --output_dir /tmp/$TASK_NAME/

where the task name can be one of CoLA, SST-2, MRPC, STS-B, QQP, MNLI, QNLI, RTE, or WNLI.

The dev-set results are written to the text file ``eval_results.txt`` in the specified ``--output_dir``.
In the case of MNLI, since there are two separate dev sets (matched and mismatched), there is a separate
output folder called ``/tmp/MNLI-MM/`` in addition to ``/tmp/MNLI/``.

Multi-GPU training with Apex
----------------------------

To fine-tune on MNLI on your local machine with 8 GPUs and Apex, first install
`Apex <https://github.com/NVIDIA/apex>`_ and then run:

.. code-block:: bash

    python transformers_example.py \
      --model_type bert \
      --model_name_or_path bert-base-cased \
      --task_name mnli \
      --do_train \
      --do_eval \
      --data_dir glue_data/MNLI/ \
      --max_seq_length 128 \
      --per_gpu_train_batch_size 8 \
      --learning_rate 2e-5 \
      --num_train_epochs 3.0 \
      --output_dir output_dir \
      --num_workers 8 \
      --fp16

Multi-node training
-------------------

To fine-tune on MNLI on AWS with 16 GPUs and Apex, run:

.. code-block:: bash

    ray up cluster.yaml
    # Optionally, monitor progress with:
    # ray monitor cluster.yaml
    ray submit cluster.yaml transformers_example.py -- --model_type bert \
      --model_name_or_path bert-base-cased \
      --task_name mnli \
      --do_train \
      --do_eval \
      --data_dir /home/ubuntu/glue_data/MNLI/ \
      --max_seq_length 128 \
      --per_gpu_train_batch_size 8 \
      --learning_rate 2e-5 \
      --num_train_epochs 3.0 \
      --output_dir /home/ubuntu/output/ \
      --num_workers 16 \
      --fp16 \
      --address auto

Note that with Apex, you can increase ``per_gpu_train_batch_size`` to 32, which
should make each epoch take 10 minutes or less.
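
For orientation, the flags in the commands above are consumed by ``transformers_example.py``
(included in full further below), which passes them straight through to RaySGD's
``TorchTrainer``. The condensed sketch below assumes ``args`` is the parsed argument
namespace and ``use_gpu`` is the GPU check from that script:

.. code-block:: python

    # Condensed from transformers_example.py below: how the CLI flags reach RaySGD.
    import ray
    from ray.util.sgd import TorchTrainer
    from ray.util.sgd.torch.examples.transformers.transformers_example import (
        TransformerOperator, data_creator, model_creator, optimizer_creator)

    ray.init(address=args.address)  # --address auto when submitting to the cluster
    trainer = TorchTrainer(
        model_creator=model_creator,            # AutoModelForSequenceClassification
        data_creator=data_creator,              # GLUE DataLoader for --task_name
        optimizer_creator=optimizer_creator,    # AdamW with weight-decay groups
        training_operator_cls=TransformerOperator,
        use_fp16=args.fp16,                     # --fp16 enables Apex mixed precision
        apex_args={"opt_level": args.fp16_opt_level},
        num_workers=args.num_workers,           # --num_workers data-parallel processes
        use_gpu=use_gpu,
        use_tqdm=True,
        config={"args": args})

    for _ in range(int(args.num_train_epochs)):
        stats = trainer.train()                 # one pass over each worker's shard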

python/ray/util/sgd/torch/examples/transformers/cluster.yaml (new file, 78 lines)

@@ -0,0 +1,78 @@
# A unique identifier for the head node and workers of this cluster.
cluster_name: transformer-cluster

# The maximum number of worker nodes to launch in addition to the head
# node. This takes precedence over min_workers, which defaults to 0.
min_workers: 3
initial_workers: 3
max_workers: 3

target_utilization_fraction: 0.9
# Cloud-provider specific configuration.
provider:
    type: aws
    region: us-east-1
    availability_zone: us-east-1c

# How Ray will authenticate with newly launched nodes.
auth:
    ssh_user: ubuntu


head_node:
    InstanceType: p3.8xlarge
    ImageId: ami-0698bcaf8bd9ef56d
    InstanceMarketOptions:
        MarketType: spot
    BlockDeviceMappings:
        - DeviceName: /dev/sda1
          Ebs:
              VolumeSize: 300


worker_nodes:
    InstanceType: p3.8xlarge
    ImageId: ami-0698bcaf8bd9ef56d
    InstanceMarketOptions:
        MarketType: spot
    BlockDeviceMappings:
        - DeviceName: /dev/sda1
          Ebs:
              VolumeSize: 300
    # SpotOptions:
    #     MaxPrice: "9.0"
    # # Run workers on spot by default. Comment this out to use on-demand.
    # InstanceMarketOptions:
    #     MarketType: spot

setup_commands:
    # This replaces the standard anaconda Ray installation
    - ray || pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.9.0.dev0-cp36-cp36m-manylinux1_x86_64.whl
    - pip install -q tqdm

    # Installing this without -U to make sure we don't replace the existing Ray installation
    - pip install ray[tune]
    - pip install -U ipdb torch
    # Install HuggingFace
    - git clone https://github.com/huggingface/transformers || true
    - cd transformers &&
      pip install . &&
      pip install -r ./examples/requirements.txt
    # Download the GLUE data
    - if [[ -e glue_data ]];
      then echo "not downloading glue";
      else wget https://gist.githubusercontent.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e/raw/17b8dd0d724281ed7c3b2aeeda662b92809aadd5/download_glue_data.py && python download_glue_data.py;
      fi

    # Install Apex
    - git clone https://github.com/NVIDIA/apex;
      cd apex &&
      pip install -v --no-cache-dir ./ ||
      true


file_mounts: {
}

# Custom commands that will be run on the head node after common setup.
head_setup_commands: []

python/ray/util/sgd/torch/examples/transformers/transformers_example.py (new file, 371 lines)

@@ -0,0 +1,371 @@
# coding=utf-8
# This is a modified example originally from The Google AI Language Team
# Authors and The HuggingFace Inc. team.
# Modified by Richard Liaw.
# Copyright 2018 The Google AI Language Team Authors,
# The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Finetuning the library models for sequence classification on GLUE (
Bert, XLM, XLNet, RoBERTa, Albert, XLM-RoBERTa)."""

import argparse
import logging
import json
import os
import time
from filelock import FileLock
from dataclasses import dataclass, field
from typing import Optional
import random

import numpy as np
import torch
from torch.utils.data import DataLoader, RandomSampler
from tqdm import trange
import torch.distributed as dist

from transformers import (MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING, AdamW,
                          AutoConfig, AutoModelForSequenceClassification,
                          AutoTokenizer, get_linear_schedule_with_warmup,
                          HfArgumentParser, TrainingArguments)
from transformers import glue_output_modes as output_modes
from transformers import glue_processors as processors

import ray
from ray.util.sgd.torch import TrainingOperator
from ray.util.sgd import TorchTrainer
from ray.util.sgd.torch.examples.transformers.utils import (
    evaluate, load_and_cache_examples, save_and_evaluate_checkpoints)

try:
    from apex import amp
except ImportError:
    amp = None

MODEL_CONFIG_CLASSES = list(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)

ALL_MODELS = sum(
    (tuple(conf.pretrained_config_archive_map.keys())
     for conf in MODEL_CONFIG_CLASSES),
    (),
)

logger = logging.getLogger(__name__)


def set_seed(args):
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)


def announce_training(args, dataset_len, t_total):
    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", dataset_len)
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accum) = %d",
        args.per_gpu_train_batch_size * args.gradient_accumulation_steps *
        args.num_workers,
    )
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)


def model_creator(config):
    with FileLock(os.path.expanduser("~/.download.lock")):
        args = config["args"]
        processor = processors[args.task_name]()
        label_list = processor.get_labels()
        num_labels = len(label_list)
        config = AutoConfig.from_pretrained(
            args.config_name if args.config_name else args.model_name_or_path,
            num_labels=num_labels,
            finetuning_task=args.task_name,
            cache_dir=args.cache_dir if args.cache_dir else None,
        )
        model = AutoModelForSequenceClassification.from_pretrained(
            args.model_name_or_path,
            from_tf=bool(".ckpt" in args.model_name_or_path),
            config=config,
            cache_dir=args.cache_dir if args.cache_dir else None,
        )
    return model


def optimizer_creator(model, cfg):
    args = cfg["args"]
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay": args.weight_decay,
        },
        {
            "params": [
                p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay": 0.0
        },
    ]

    return AdamW(
        optimizer_grouped_parameters,
        lr=args.learning_rate,
        eps=args.adam_epsilon)


def data_creator(config):
    args = config["args"]
    start = time.time()
    tokenizer = AutoTokenizer.from_pretrained(
        args.tokenizer_name
        if args.tokenizer_name else args.model_name_or_path,
        cache_dir=args.cache_dir if args.cache_dir else None,
    )
    logger.info("tokenizer instantiation time: {}".format(time.time() - start))

    train_dataset = load_and_cache_examples(
        args, args.task_name, tokenizer, evaluate=False)
    train_sampler = RandomSampler(
        train_dataset) if not dist.is_initialized() else None
    return DataLoader(
        train_dataset,
        sampler=train_sampler,
        batch_size=args.per_gpu_train_batch_size)


class TransformerOperator(TrainingOperator):
    def setup(self, config):
        self.args = args = config["args"]
        self.tokenizer = AutoTokenizer.from_pretrained(
            args.tokenizer_name
            if args.tokenizer_name else args.model_name_or_path,
            cache_dir=args.cache_dir if args.cache_dir else None,
        )

        self.train_data_len = len(self.train_loader)
        self._warmup_scheduler = get_linear_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=args.warmup_steps,
            num_training_steps=self.calculate_t_total())
        self._global_step = 0

        announce_training(args, self.train_data_len, self.calculate_t_total())

    def train_batch(self, batch, batch_info=None):
        args = self.args
        model = self.model
        optimizer = self.optimizer
        step = batch_info["batch_idx"]

        model.train()
        batch = tuple(t.to(self.device) for t in batch)
        inputs = {
            "input_ids": batch[0],
            "attention_mask": batch[1],
            "labels": batch[3]
        }
        if args.model_type != "distilbert":
            # XLM, DistilBERT, RoBERTa, and XLM-RoBERTa don't use segment_ids
            inputs["token_type_ids"] = (batch[2] if args.model_type in [
                "bert", "xlnet", "albert"
            ] else None)
        outputs = model(**inputs)

        # model outputs are always tuple in transformers (see doc)
        loss = outputs[0]

        if args.gradient_accumulation_steps > 1:
            loss = loss / args.gradient_accumulation_steps

        if args.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        batch_loss = loss.item()

        # last step in epoch but step is always smaller
        # than gradient_accumulation_steps
        ending = (self.train_data_len <= args.gradient_accumulation_steps
                  and (step + 1) == self.train_data_len)
        if (step + 1) % args.gradient_accumulation_steps == 0 or ending:
            if args.fp16:
                torch.nn.utils.clip_grad_norm_(
                    amp.master_params(optimizer), args.max_grad_norm)
            else:
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               args.max_grad_norm)

            self.optimizer.step()
            self._warmup_scheduler.step()  # Update learning rate schedule
            model.zero_grad()
            self._global_step += 1

        learning_rate_scalar = self._warmup_scheduler.get_lr()[0]
        return {"learning_rate": learning_rate_scalar, "loss": batch_loss}

    def calculate_t_total(self):
        args = self.args
        grad_accum_steps = args.gradient_accumulation_steps
        train_data_len = len(self.train_loader)
        if args.max_steps > 0:
            t_total = args.max_steps
            args.num_train_epochs = args.max_steps // (
                train_data_len // grad_accum_steps) + 1
        else:
            t_total = (
                train_data_len // grad_accum_steps * args.num_train_epochs)
        return t_total


@dataclass
class ModelArguments:
    """Arguments pertaining to model/config/tokenizer."""

    model_name_or_path: str = field(
        metadata=dict(help="Path to pre-trained model or shortcut name "
                      "selected in the list: " + ", ".join(ALL_MODELS)))
    model_type: str = field(
        metadata=dict(help="Model type selected "
                      "in the list: " + ", ".join(MODEL_TYPES)))
    config_name: Optional[str] = field(
        default=None,
        metadata=dict(
            help="Pretrained config name or path if not the same as model_name"
        ))
    tokenizer_name: Optional[str] = field(
        default=None,
        metadata=dict(help="Pretrained tokenizer name or path "
                      "if not the same as model_name"))
    cache_dir: Optional[str] = field(
        default=None,
        metadata=dict(help="Where do you want to store the pre-trained "
                      "models downloaded from s3"))


@dataclass
class DataProcessingArguments:
    task_name: str = field(
        metadata=dict(help="The name of the task to train selected "
                      "in the list: " + ", ".join(processors.keys())))
    data_dir: str = field(
        metadata=dict(help="The input data dir. Should contain "
                      "the .tsv files (or other data files) for the task."))
    max_seq_length: int = field(
        default=128,
        metadata=dict(help="The maximum total input sequence length "
                      "after tokenization. Sequences longer "
                      "than this will be truncated, sequences "
                      "shorter will be padded."))
    overwrite_cache: bool = field(
        default=False,
        metadata={"help": "Overwrite the cached training and evaluation sets"})


@dataclass
class RayArguments:
    num_workers: int = field(
        default=1,
        metadata={"help": "Number of data-parallel workers to use."})
    address: str = field(
        default=None,
        metadata={"help": "Address of the Ray cluster to connect to."})


def main():
    parser = HfArgumentParser((ModelArguments, DataProcessingArguments,
                               TrainingArguments, RayArguments))
    all_args = parser.parse_args_into_dataclasses()
    model_args, dataprocessing_args, training_args, ray_args = all_args

    # For now, let's merge all the sets of args into one,
    # but soon, we'll keep distinct sets of args, with a
    # cleaner separation of concerns.
    args = argparse.Namespace(**vars(model_args), **vars(dataprocessing_args),
                              **vars(training_args), **vars(ray_args))

    if (os.path.exists(args.output_dir) and os.listdir(args.output_dir)
            and args.do_train and not args.overwrite_output_dir):
        raise ValueError(
            "Output directory ({}) already exists and is not empty. "
            "Use --overwrite_output_dir to overcome.".format(args.output_dir))

    use_gpu = torch.cuda.is_available() and not args.no_cuda

    # Prepare GLUE task
    args.task_name = args.task_name.lower()
    if args.task_name not in processors:
        raise ValueError("Task not found: %s" % (args.task_name))
    args.output_mode = output_modes[args.task_name]

    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO)
    logger.info("Training/evaluation parameters %s", args)
    ray.init(address=args.address)
    # Training

    trainer = TorchTrainer(
        model_creator=model_creator,
        data_creator=data_creator,
        optimizer_creator=optimizer_creator,
        training_operator_cls=TransformerOperator,
        use_fp16=args.fp16,
        apex_args={"opt_level": args.fp16_opt_level},
        num_workers=args.num_workers,
        use_gpu=use_gpu,
        use_tqdm=True,
        config={"args": args})

    args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    tokenizer = trainer.get_local_operator().tokenizer
    local_model = trainer.get_model()

    epochs_trained = 0
    train_iterator = trange(
        epochs_trained,
        int(args.num_train_epochs),
        desc="Epoch",
    )

    trainer.apply_all_workers(lambda: set_seed(args))
    if args.do_train:
        for _ in train_iterator:
            stats = trainer.train()
            print("Training stats:", stats)
            logs = evaluate(args, local_model, tokenizer)
            print(json.dumps(logs))

    # Post-training validation
    save_and_evaluate_checkpoints(args, local_model, tokenizer)


if __name__ == "__main__":
    main()
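
A quick sanity check of the ``calculate_t_total`` arithmetic above, using illustrative
single-worker MRPC numbers (MRPC has roughly 3.7k training examples; the values below
are for intuition only, not measured output):

.. code-block:: python

    # Illustrative only: mirrors TransformerOperator.calculate_t_total() by hand.
    train_examples = 3668                # approximate MRPC training-set size
    per_gpu_train_batch_size = 32
    gradient_accumulation_steps = 1
    num_train_epochs = 3.0

    # len(train_loader) with the default DataLoader (drop_last=False):
    train_data_len = -(-train_examples // per_gpu_train_batch_size)   # 115 batches
    t_total = train_data_len // gradient_accumulation_steps * num_train_epochs
    print(t_total)  # 345.0 steps passed to get_linear_schedule_with_warmup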

python/ray/util/sgd/torch/examples/transformers/utils.py (new file, 212 lines)

@@ -0,0 +1,212 @@
# flake8: noqa
import glob
import logging
import os
from tqdm import tqdm
from filelock import FileLock
import numpy as np

import torch
from torch.utils.data import (DataLoader, SequentialSampler, TensorDataset)

from transformers import glue_processors as processors
from transformers import glue_compute_metrics as compute_metrics
from transformers import glue_output_modes as output_modes
from transformers import (glue_convert_examples_to_features as
                          convert_examples_to_features)
from transformers import (
    WEIGHTS_NAME,
    AutoModelForSequenceClassification,
    AutoTokenizer,
)
logger = logging.getLogger(__name__)


def load_and_cache_examples(args, task, tokenizer, evaluate=False):
    processor = processors[task]()
    output_mode = output_modes[task]
    # Load data features from cache or dataset file
    cached_features_file = os.path.join(
        args.data_dir,
        "cached_{}_{}_{}_{}".format(
            "dev" if evaluate else "train",
            list(filter(None, args.model_name_or_path.split("/"))).pop(),
            str(args.max_seq_length),
            str(task),
        ),
    )

    with FileLock("/tmp/load_and_cache_examples.lock"):
        if os.path.exists(cached_features_file) and not args.overwrite_cache:
            logger.info("Loading features from cached file %s",
                        cached_features_file)
            features = torch.load(cached_features_file)
        else:
            logger.info("Creating features from dataset file at %s",
                        args.data_dir)
            label_list = processor.get_labels()
            if task in ["mnli", "mnli-mm"
                        ] and args.model_type in ["roberta", "xlmroberta"]:
                # HACK(label indices are swapped in RoBERTa pretrained model)
                label_list[1], label_list[2] = label_list[2], label_list[1]
            examples = (processor.get_dev_examples(args.data_dir) if evaluate
                        else processor.get_train_examples(args.data_dir))
            features = convert_examples_to_features(
                examples,
                tokenizer,
                label_list=label_list,
                max_length=args.max_seq_length,
                output_mode=output_mode,
            )
            if not os.path.exists(cached_features_file):
                logger.info("Saving features into cached file %s",
                            cached_features_file)
                torch.save(features, cached_features_file)

    # Convert to Tensors and build dataset
    all_input_ids = torch.tensor(
        [f.input_ids for f in features], dtype=torch.long)
    all_attention_mask = torch.tensor(
        [f.attention_mask for f in features], dtype=torch.long)
    all_token_type_ids = torch.tensor(
        [f.token_type_ids for f in features], dtype=torch.long)
    if output_mode == "classification":
        all_labels = torch.tensor(
            [f.label for f in features], dtype=torch.long)
    elif output_mode == "regression":
        all_labels = torch.tensor(
            [f.label for f in features], dtype=torch.float)

    dataset = TensorDataset(all_input_ids, all_attention_mask,
                            all_token_type_ids, all_labels)
    return dataset


def save_and_evaluate_checkpoints(args, model, tokenizer):
    # Saving best-practices: if you use defaults names for the model,
    # you can reload it using from_pretrained()
    if args.do_train:
        # Create output directory if needed
        if not os.path.exists(args.output_dir):
            os.makedirs(args.output_dir)

        logger.info("Saving model checkpoint to %s", args.output_dir)
        # Save a trained model, configuration and tokenizer using
        # `save_pretrained()`. They can then be
        # reloaded using `from_pretrained()`
        model_to_save = (model.module if hasattr(model, "module") else
                         model)  # Take care of distributed/parallel training
        model_to_save.save_pretrained(args.output_dir)
        tokenizer.save_pretrained(args.output_dir)

        # Good practice: save your training arguments
        # together with the trained model
        torch.save(args, os.path.join(args.output_dir, "training_args.bin"))

        # Load a trained model and vocabulary that you have fine-tuned
        model = AutoModelForSequenceClassification.from_pretrained(
            args.output_dir)
        tokenizer = AutoTokenizer.from_pretrained(args.output_dir)
        model.to(args.device)

    results = {}
    if args.do_eval:
        tokenizer = AutoTokenizer.from_pretrained(args.output_dir)
        checkpoints = [args.output_dir]
        if args.eval_all_checkpoints:
            checkpoints = list(
                os.path.dirname(c) for c in sorted(
                    glob.glob(
                        args.output_dir + "/**/" + WEIGHTS_NAME,
                        recursive=True)))
            logging.getLogger("transformers.modeling_utils").setLevel(
                logging.WARN)  # Reduce logging
        logger.info("Evaluate the following checkpoints: %s", checkpoints)
        for checkpoint in checkpoints:
            global_step = checkpoint.split("-")[
                -1] if len(checkpoints) > 1 else ""
            prefix = checkpoint.split("/")[
                -1] if checkpoint.find("checkpoint") != -1 else ""

            model = AutoModelForSequenceClassification.from_pretrained(
                checkpoint)
            model.to(args.device)
            result = evaluate(args, model, tokenizer, prefix=prefix)
            result = dict(
                (k + "_{}".format(global_step), v) for k, v in result.items())
            results.update(result)

    return results


def evaluate(args, model, tokenizer, prefix=""):
    # Loop to handle MNLI double evaluation (matched, mis-matched)
    eval_task_names = ("mnli", "mnli-mm") if args.task_name == "mnli" else (
        args.task_name, )
    eval_outputs_dirs = (args.output_dir, args.output_dir + "-MM"
                         ) if args.task_name == "mnli" else (args.output_dir, )

    results = {}
    for eval_task, eval_output_dir in zip(eval_task_names, eval_outputs_dirs):
        eval_dataset = load_and_cache_examples(
            args, eval_task, tokenizer, evaluate=True)

        if not os.path.exists(eval_output_dir):
            os.makedirs(eval_output_dir)

        args.eval_batch_size = args.per_gpu_eval_batch_size
        # Note that DistributedSampler samples randomly
        eval_sampler = SequentialSampler(eval_dataset)
        eval_dataloader = DataLoader(
            eval_dataset,
            sampler=eval_sampler,
            batch_size=args.eval_batch_size)

        # Eval!
        logger.info("***** Running evaluation {} *****".format(prefix))
        logger.info("  Num examples = %d", len(eval_dataset))
        logger.info("  Batch size = %d", args.eval_batch_size)
        eval_loss = 0.0
        nb_eval_steps = 0
        preds = None
        out_label_ids = None
        for batch in tqdm(eval_dataloader, desc="Evaluating"):
            model.eval()
            batch = tuple(t.to(args.device) for t in batch)

            with torch.no_grad():
                inputs = {
                    "input_ids": batch[0],
                    "attention_mask": batch[1],
                    "labels": batch[3]
                }
                if args.model_type != "distilbert":
                    # XLM, DistilBERT, RoBERTa, and XLM-RoBERTa don't
                    # use segment_ids
                    inputs["token_type_ids"] = (batch[2]
                                                if args.model_type in [
                                                    "bert", "xlnet", "albert"
                                                ] else None)
                outputs = model(**inputs)
                tmp_eval_loss, logits = outputs[:2]

                eval_loss += tmp_eval_loss.mean().item()
            nb_eval_steps += 1
            if preds is None:
                preds = logits.detach().cpu().numpy()
                out_label_ids = inputs["labels"].detach().cpu().numpy()
            else:
                preds = np.append(preds, logits.detach().cpu().numpy(), axis=0)
                out_label_ids = np.append(
                    out_label_ids,
                    inputs["labels"].detach().cpu().numpy(),
                    axis=0)

        eval_loss = eval_loss / nb_eval_steps
        if args.output_mode == "classification":
            preds = np.argmax(preds, axis=1)
        elif args.output_mode == "regression":
            preds = np.squeeze(preds)
        result = compute_metrics(eval_task, preds, out_label_ids)
        results.update(result)
    return results
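
The ``evaluate`` helper above can also be driven on its own once training has written a
model to ``--output_dir``. This is a minimal sketch, not part of the example; the
``argparse.Namespace`` below is a hypothetical stand-in for the merged namespace that
``main()`` in ``transformers_example.py`` builds from the HfArgumentParser dataclasses:

.. code-block:: python

    import argparse
    import torch
    from transformers import AutoModelForSequenceClassification, AutoTokenizer
    from ray.util.sgd.torch.examples.transformers.utils import evaluate

    # Hypothetical minimal args; only the fields used by evaluate() and
    # load_and_cache_examples() are filled in.
    args = argparse.Namespace(
        task_name="mrpc", output_dir="/tmp/MRPC/", data_dir="glue_data/MRPC",
        model_type="bert", model_name_or_path="bert-base-cased",
        max_seq_length=128, overwrite_cache=False, output_mode="classification",
        per_gpu_eval_batch_size=8,
        device=torch.device("cuda" if torch.cuda.is_available() else "cpu"))

    model = AutoModelForSequenceClassification.from_pretrained(args.output_dir)
    tokenizer = AutoTokenizer.from_pretrained(args.output_dir)
    model.to(args.device)

    results = evaluate(args, model, tokenizer)  # accuracy/F1 metrics for MRPC
    print(results)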