[Tune] Transformer blog example (#9789)

Co-authored-by: Kai Fricke <kai@anyscale.com>
Amog Kamsetty 2020-08-04 22:05:01 -07:00 committed by GitHub
parent ead8b86372
commit 5af7d24f66
15 changed files with 443 additions and 0 deletions


@@ -161,6 +161,10 @@ $SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE}
python /ray/python/ray/tune/examples/pbt_dcgan_mnist/pbt_dcgan_mnist.py \
--smoke-test
$SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} --memory-swap=-1 "$DOCKER_SHA" \
python /ray/python/ray/tune/examples/pbt_transformers/pbt_transformers.py \
--smoke-test
$SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} --memory-swap=-1 "$DOCKER_SHA" \
python /ray/ci/long_running_distributed_tests/workloads/pytorch_pbt_failure.py \
--smoke-test


@@ -18,6 +18,7 @@ redis
sphinx==3.0.4
sphinx-click
sphinx-copybutton
sphinxemoji
sphinx-gallery
sphinx-jsonschema
sphinx-tabs


@@ -67,6 +67,7 @@ extensions = [
'sphinx_tabs.tabs',
'sphinx-jsonschema',
'sphinx_gallery.gen_gallery',
'sphinxemoji.sphinxemoji',
'sphinx_copybutton',
'versionwarning.extension',
]


@@ -56,6 +56,11 @@ LightGBM Example
- :doc:`/tune/examples/lightgbm_example`: Trains a basic LightGBM model with Tune using the function-based API and a LightGBM callback.
|:hugging_face:| Huggingface Transformers Example
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
- :doc:`/tune/examples/pbt_transformers`: Fine-tunes a Huggingface transformer with Tune's Population Based Training.
Contributed Examples
~~~~~~~~~~~~~~~~~~~~


@@ -0,0 +1,8 @@
:orphan:
pbt_transformers_example
~~~~~~~~~~~~~~~~~~~~~~~~
.. literalinclude:: /../../python/ray/tune/examples/pbt_transformers/pbt_transformers.py
.. literalinclude:: /../../python/ray/tune/examples/pbt_transformers/trainer.py
.. literalinclude:: /../../python/ray/tune/examples/pbt_transformers/utils.py


@@ -474,6 +474,16 @@ py_test(
# args = ["--smoke-test"]
# )
py_test(
name = "pbt_transformers",
size = "large",
srcs = ["examples/pbt_transformers/pbt_transformers.py"],
deps = [":tune_lib"],
tags = ["exclusive", "example"],
args = ["--smoke-test"]
)
# Requires GPUs. Add smoke test?
# py_test(
# name = "pbt_tune_cifar10_with_keras",


@@ -52,6 +52,11 @@ LightGBM Example
- `lightgbm_example <https://github.com/ray-project/ray/blob/master/python/ray/tune/examples/lightgbm_example.py>`__: Trains a basic LightGBM model with Tune using the function-based API and a LightGBM callback.
Huggingface Transformers Example
--------------------------------
- `pbt_transformers <https://github.com/ray-project/ray/blob/master/python/ray/tune/examples/pbt_transformers/pbt_transformers.py>`__: Fine-tunes a Huggingface transformer with Tune's Population Based Training.
Contributed Examples
--------------------


@@ -0,0 +1,262 @@
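"""Fine-tune a Huggingface transformer on the GLUE RTE task with
Population Based Training from Ray Tune."""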
import os
import ray
from ray.tune import CLIReporter
from ray.tune.schedulers import PopulationBasedTraining
from ray import tune
from ray.tune.examples.pbt_transformers.utils import \
build_compute_metrics_fn, download_data
from ray.tune.examples.pbt_transformers import trainer
from transformers import (AutoConfig, AutoModelForSequenceClassification,
AutoTokenizer, GlueDataset, GlueDataTrainingArguments
as DataTrainingArguments, glue_tasks_num_labels,
Trainer, TrainingArguments)
def get_trainer(model_name_or_path,
train_dataset,
eval_dataset,
task_name,
training_args,
wandb_args=None):
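    """Build a TuneTransformerTrainer for a GLUE task.

    Loads the pre-trained config and model, then wraps them together with
    the datasets and the metrics function in the Tune-aware trainer.
    """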
try:
num_labels = glue_tasks_num_labels[task_name]
except KeyError:
raise ValueError("Task not found: %s" % (task_name))
config = AutoConfig.from_pretrained(
model_name_or_path,
num_labels=num_labels,
finetuning_task=task_name,
)
model = AutoModelForSequenceClassification.from_pretrained(
model_name_or_path,
config=config,
)
tune_trainer = trainer.TuneTransformerTrainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
compute_metrics=build_compute_metrics_fn(task_name),
wandb_args=wandb_args)
return tune_trainer
def recover_checkpoint(tune_checkpoint_dir, model_name=None):
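    """Return the Huggingface checkpoint subdirectory inside a Tune
    checkpoint directory, or fall back to the base model name if no
    checkpoint directory was passed in."""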
if tune_checkpoint_dir is None or len(tune_checkpoint_dir) == 0:
return model_name
# Get subdirectory used for Huggingface.
subdirs = [
os.path.join(tune_checkpoint_dir, name)
for name in os.listdir(tune_checkpoint_dir)
if os.path.isdir(os.path.join(tune_checkpoint_dir, name))
]
# There should only be 1 subdir.
assert len(subdirs) == 1, subdirs
return subdirs[0]
# __train_begin__
def train_transformer(config, checkpoint_dir=None):
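    """Per-trial training function executed by Tune.

    Builds the GLUE train/eval datasets, assembles TrainingArguments from
    the sampled config, and runs the Tune-integrated trainer, optionally
    resuming from a PBT checkpoint.
    """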
data_args = DataTrainingArguments(
task_name=config["task_name"], data_dir=config["data_dir"])
tokenizer = AutoTokenizer.from_pretrained(config["model_name"])
train_dataset = GlueDataset(
data_args,
tokenizer=tokenizer,
mode="train",
cache_dir=config["data_dir"])
eval_dataset = GlueDataset(
data_args,
tokenizer=tokenizer,
mode="dev",
cache_dir=config["data_dir"])
eval_dataset = eval_dataset[:len(eval_dataset) // 2]
training_args = TrainingArguments(
output_dir=tune.get_trial_dir(),
learning_rate=config["learning_rate"],
do_train=True,
do_eval=True,
evaluate_during_training=True,
eval_steps=(len(train_dataset) // config["per_gpu_train_batch_size"]) +
1,
# We explicitly set save to 0, and do saving in evaluate instead
save_steps=0,
num_train_epochs=config["num_epochs"],
max_steps=config["max_steps"],
per_device_train_batch_size=config["per_gpu_train_batch_size"],
per_device_eval_batch_size=config["per_gpu_val_batch_size"],
warmup_steps=0,
weight_decay=config["weight_decay"],
logging_dir="./logs",
)
# Arguments for W&B.
name = tune.get_trial_name()
wandb_args = {
"project_name": "transformers_pbt",
"watch": "false", # Either set to gradient, false, or all
"run_name": name,
}
tune_trainer = get_trainer(
recover_checkpoint(checkpoint_dir, config["model_name"]),
train_dataset,
eval_dataset,
config["task_name"],
training_args,
wandb_args=wandb_args)
tune_trainer.train(
recover_checkpoint(checkpoint_dir, config["model_name"]))
# __train_end__
# __tune_begin__
def tune_transformer(num_samples=8,
gpus_per_trial=0,
smoke_test=False,
ray_address=None):
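    """Download the model, tokenizer, and data, then launch a Population
    Based Training run over learning rate, weight decay, batch size, and
    number of epochs."""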
ray.init(ray_address, log_to_driver=False)
data_dir_name = "./data" if not smoke_test else "./test_data"
data_dir = os.path.abspath(os.path.join(os.getcwd(), data_dir_name))
if not os.path.exists(data_dir):
os.mkdir(data_dir, 0o755)
# Change these as needed.
model_name = "bert-base-uncased" if not smoke_test \
else "distilbert-base-uncased"
task_name = "rte"
task_data_dir = os.path.join(data_dir, task_name.upper())
# Download and cache tokenizer, model, and features
print("Downloading and caching Tokenizer")
# Triggers tokenizer download to cache
AutoTokenizer.from_pretrained(model_name)
print("Downloading and caching pre-trained model")
# Triggers model download to cache
AutoModelForSequenceClassification.from_pretrained(model_name)
# Download data.
download_data(task_name, data_dir)
config = {
"model_name": model_name,
"task_name": task_name,
"data_dir": task_data_dir,
"per_gpu_val_batch_size": 32,
"per_gpu_train_batch_size": tune.choice([16, 32, 64]),
"learning_rate": tune.uniform(1e-5, 5e-5),
"weight_decay": tune.uniform(0.0, 0.3),
"num_epochs": tune.choice([2, 3, 4, 5]),
"max_steps": 1 if smoke_test else -1, # Used for smoke test.
}
scheduler = PopulationBasedTraining(
time_attr="training_iteration",
metric="eval_acc",
mode="max",
perturbation_interval=1,
hyperparam_mutations={
"weight_decay": lambda: tune.uniform(0.0, 0.3).func(None),
"learning_rate": lambda: tune.uniform(1e-5, 5e-5).func(None),
"per_gpu_train_batch_size": [16, 32, 64],
})
reporter = CLIReporter(
parameter_columns={
"weight_decay": "w_decay",
"learning_rate": "lr",
"per_gpu_train_batch_size": "train_bs/gpu",
"num_epochs": "num_epochs"
},
metric_columns=[
"eval_acc", "eval_loss", "epoch", "training_iteration"
])
analysis = tune.run(
train_transformer,
resources_per_trial={
"cpu": 1,
"gpu": gpus_per_trial
},
config=config,
num_samples=num_samples,
scheduler=scheduler,
keep_checkpoints_num=3,
checkpoint_score_attr="training_iteration",
stop={"training_iteration": 1} if smoke_test else None,
progress_reporter=reporter,
local_dir="~/ray_results/",
name="tune_transformer_pbt")
if not smoke_test:
test_best_model(analysis, config["model_name"], config["task_name"],
config["data_dir"])
# __tune_end__
def test_best_model(analysis, model_name, task_name, data_dir):
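    """Evaluate the best checkpoint found by the search on the second half
    of the dev set (the half held out from tuning). Requires a CUDA device."""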
data_args = DataTrainingArguments(task_name=task_name, data_dir=data_dir)
tokenizer = AutoTokenizer.from_pretrained(model_name)
best_config = analysis.get_best_config(metric="eval_acc", mode="max")
print(best_config)
best_checkpoint = recover_checkpoint(
analysis.get_best_trial(metric="eval_acc",
mode="max").checkpoint.value)
print(best_checkpoint)
best_model = AutoModelForSequenceClassification.from_pretrained(
best_checkpoint).to("cuda")
test_args = TrainingArguments(output_dir="./best_model_results", )
test_dataset = GlueDataset(
data_args, tokenizer=tokenizer, mode="dev", cache_dir=data_dir)
test_dataset = test_dataset[len(test_dataset) // 2:]
test_trainer = Trainer(
best_model,
test_args,
compute_metrics=build_compute_metrics_fn(task_name))
metrics = test_trainer.evaluate(test_dataset)
print(metrics)
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument(
"--smoke-test", action="store_true", help="Finish quickly for testing")
parser.add_argument(
"--ray-address",
type=str,
default=None,
help="Address to use for Ray. "
"Use \"auto\" for cluster. "
"Defaults to None for local.")
args, _ = parser.parse_known_args()
if args.smoke_test:
tune_transformer(
num_samples=1,
gpus_per_trial=0,
smoke_test=True,
ray_address=args.ray_address)
else:
# You can change the number of GPUs here:
tune_transformer(
num_samples=8, gpus_per_trial=1, ray_address=args.ray_address)


@@ -0,0 +1,10 @@
index sentence1 sentence2 label
0 Dana Reeve, the widow of the actor Christopher Reeve, has died of lung cancer at age 44, according to the Christopher Reeve Foundation. Christopher Reeve had an accident. not_entailment
1 Yet, we now are discovering that antibiotics are losing their effectiveness against illness. Disease-causing bacteria are mutating faster than we can come up with new antibiotics to fight the new variations. Bacteria is winning the war against antibiotics. entailment
2 Cairo is now home to some 15 million people - a burgeoning population that produces approximately 10,000 tonnes of rubbish per day, putting an enormous strain on public services. In the past 10 years, the government has tried hard to encourage private investment in the refuse sector, but some estimate 4,000 tonnes of waste is left behind every day, festering in the heat as it waits for someone to clear it up. It is often the people in the poorest neighbourhoods that are worst affected. But in some areas they are fighting back. In Shubra, one of the northern districts of the city, the residents have taken to the streets armed with dustpans and brushes to clean up public areas which have been used as public dumps. 15 million tonnes of rubbish are produced daily in Cairo. not_entailment
3 The Amish community in Pennsylvania, which numbers about 55,000, lives an agrarian lifestyle, shunning technological advances like electricity and automobiles. And many say their insular lifestyle gives them a sense that they are protected from the violence of American society. But as residents gathered near the school, some wearing traditional garb and arriving in horse-drawn buggies, they said that sense of safety had been shattered. "If someone snaps and wants to do something stupid, there's no distance that's going to stop them," said Jake King, 56, an Amish lantern maker who knew several families whose children had been shot. Pennsylvania has the biggest Amish community in the U.S. not_entailment
4 Security forces were on high alert after an election campaign in which more than 1,000 people, including seven election candidates, have been killed. Security forces were on high alert after a campaign marred by violence. entailment
5 In 1979, the leaders signed the Egypt-Israel peace treaty on the White House lawn. Both President Begin and Sadat received the Nobel Peace Prize for their work. The two nations have enjoyed peaceful relations to this day. The Israel-Egypt Peace Agreement was signed in 1979. entailment
6 singer and actress Britney Spears, 24, has filled papers in Los Angeles County Superior Court to divorce her husband Kevin Federline, 28. A spokeswoman for the court, Kathy Roberts stated that the papers cited irreconcilable differences" as the reason for the divorce and have, according to the courts, been legally separated as of Monday, November 6, the same day that Spears appeared on Late Night with David Letterman. Spears is to divorce from Kevin Federline. entailment
7 Following the successful bid to bring the 2010 Ryder Cup to Wales, the Wales Tourist Board has wasted little time in commissioning work to ensure that the benefits accruing from the event are felt throughout the country. Wales to host 2010 Ryder Cup. entailment
8 Steve Jobs was attacked by Sculley and other Apple executives for not delivering enough hot new products and resigned from the company a few weeks later. Steve Jobs worked for Apple. entailment


@@ -0,0 +1,10 @@
index sentence1 sentence2 label
0 No Weapons of Mass Destruction Found in Iraq Yet. Weapons of Mass Destruction Found in Iraq. not_entailment
1 A place of sorrow, after Pope John Paul II died, became a place of celebration, as Roman Catholic faithful gathered in downtown Chicago to mark the installation of new Pope Benedict XVI. Pope Benedict XVI is the new leader of the Roman Catholic Church. entailment
2 Herceptin was already approved to treat the sickest breast cancer patients, and the company said, Monday, it will discuss with federal regulators the possibility of prescribing the drug for more breast cancer patients. Herceptin can be used to treat breast cancer. entailment
3 Judie Vivian, chief executive at ProMedica, a medical service company that helps sustain the 2-year-old Vietnam Heart Institute in Ho Chi Minh City (formerly Saigon), said that so far about 1,500 children have received treatment. The previous name of Ho Chi Minh City was Saigon. entailment
4 A man is due in court later charged with the murder 26 years ago of a teenager whose case was the first to be featured on BBC One's Crimewatch. Colette Aram, 16, was walking to her boyfriend's house in Keyworth, Nottinghamshire, on 30 October 1983 when she disappeared. Her body was later found in a field close to her home. Paul Stewart Hutchinson, 50, has been charged with murder and is due before Nottingham magistrates later. Paul Stewart Hutchinson is accused of having stabbed a girl. not_entailment
5 Britain said, Friday, that it has barred cleric, Omar Bakri, from returning to the country from Lebanon, where he was released by police after being detained for 24 hours. Bakri was briefly detained, but was released. entailment
6 Nearly 4 million children who have at least one parent who entered the U.S. illegally were born in the United States and are U.S. citizens as a result, according to the study conducted by the Pew Hispanic Center. That's about three quarters of the estimated 5.5 million children of illegal immigrants inside the United States, according to the study. About 1.8 million children of undocumented immigrants live in poverty, the study found. Three quarters of U.S. illegal immigrants have children. not_entailment
7 Like the United States, U.N. officials are also dismayed that Aristide killed a conference called by Prime Minister Robert Malval in Port-au-Prince in hopes of bringing all the feuding parties together. Aristide had Prime Minister Robert Malval murdered in Port-au-Prince. not_entailment
8 WASHINGTON -- A newly declassified narrative of the Bush administration's advice to the CIA on harsh interrogations shows that the small group of Justice Department lawyers who wrote memos authorizing controversial interrogation techniques were operating not on their own but with direction from top administration officials, including then-Vice President Dick Cheney and national security adviser Condoleezza Rice. At the same time, the narrative suggests that then-Defense Secretary Donald H. Rumsfeld and then-Secretary of State Colin Powell were largely left out of the decision-making process. Dick Cheney was the Vice President of Bush. entailment


@@ -0,0 +1,80 @@
import logging
import os
from typing import Dict, Optional, Tuple
from ray import tune
import transformers
from transformers.file_utils import is_torch_tpu_available
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
import torch
from torch.utils.data import Dataset
import wandb
logger = logging.getLogger(__name__)
"""A Trainer class integrated with Tune.
The only changes to the original transformers.Trainer are:
- Report eval metrics to Tune
- Save state using Tune's checkpoint directories
"""
class TuneTransformerTrainer(transformers.Trainer):
def __init__(self, *args, wandb_args=None, **kwargs):
self.wandb_args = wandb_args
super().__init__(*args, **kwargs)
def get_optimizers(
self, num_training_steps: int
) -> Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]:
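        # Cache the optimizer and scheduler so save_state() can
        # checkpoint them alongside the model.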
self.current_optimizer, self.current_scheduler = super(
).get_optimizers(num_training_steps)
return (self.current_optimizer, self.current_scheduler)
def evaluate(self,
eval_dataset: Optional[Dataset] = None) -> Dict[str, float]:
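        # Run the standard evaluation loop, then report the metrics to
        # Tune and write a checkpoint so PBT can exploit this trial.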
eval_dataloader = self.get_eval_dataloader(eval_dataset)
output = self._prediction_loop(
eval_dataloader, description="Evaluation")
self._log(output.metrics)
tune.report(**output.metrics)
self.save_state()
return output.metrics
def save_state(self):
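        # Save the model, optimizer, and scheduler into a Tune-managed
        # checkpoint directory, using the directory layout that
        # Huggingface's from_pretrained() expects.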
with tune.checkpoint_dir(step=self.global_step) as checkpoint_dir:
self.args.output_dir = checkpoint_dir
# This is the directory name that Huggingface requires.
output_dir = os.path.join(
self.args.output_dir,
f"{PREFIX_CHECKPOINT_DIR}-{self.global_step}")
self.save_model(output_dir)
if self.is_world_master():
torch.save(self.current_optimizer.state_dict(),
os.path.join(output_dir, "optimizer.pt"))
torch.save(self.current_scheduler.state_dict(),
os.path.join(output_dir, "scheduler.pt"))
def _setup_wandb(self):
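        # Start one W&B run per trial, keyed by the trial name, inside
        # the trial's log directory.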
if self.is_world_master() and self.wandb_args is not None:
wandb.init(
project=self.wandb_args["project_name"],
name=self.wandb_args["run_name"],
id=self.wandb_args["run_name"],
dir=tune.get_trial_dir(),
config=vars(self.args),
reinit=True,
allow_val_change=True,
resume=self.wandb_args["run_name"])
# keep track of model topology and gradients, unsupported on TPU
if not is_torch_tpu_available(
) and self.wandb_args["watch"] != "false":
wandb.watch(
self.model,
log=self.wandb_args["watch"],
log_freq=max(100, self.args.logging_steps))


@@ -0,0 +1,46 @@
"""Utilities to load and cache data."""
import os
from typing import Callable, Dict
import numpy as np
from transformers import EvalPrediction
from transformers import glue_compute_metrics, glue_output_modes
def build_compute_metrics_fn(
task_name: str) -> Callable[[EvalPrediction], Dict]:
"""Function from transformers/examples/text-classification/run_glue.py"""
output_mode = glue_output_modes[task_name]
def compute_metrics_fn(p: EvalPrediction):
if output_mode == "classification":
preds = np.argmax(p.predictions, axis=1)
elif output_mode == "regression":
preds = np.squeeze(p.predictions)
metrics = glue_compute_metrics(task_name, preds, p.label_ids)
return metrics
return compute_metrics_fn
def download_data(task_name, data_dir="./data"):
# Download RTE training data
print("Downloading dataset.")
import urllib
import zipfile
if task_name == "rte":
url = "https://firebasestorage.googleapis.com/v0/b/" \
"mtl-sentence-representations.appspot.com" \
"/o/data%2FRTE.zip?alt=media" \
"&token=5efa7e85-a0bb-4f19-8ea2-9e1840f077fb"
else:
raise ValueError("Unknown task: {}".format(task_name))
data_file = os.path.join(data_dir, "{}.zip".format(task_name))
if not os.path.exists(data_file):
urllib.request.urlretrieve(url, data_file)
with zipfile.ZipFile(data_file) as zip_ref:
zip_ref.extractall(data_dir)
print("Downloaded data for task {} to {}".format(task_name, data_dir))
else:
print("Data already exists. Using downloaded data for task {} from {}".
format(task_name, data_dir))


@@ -23,6 +23,7 @@ tensorflow_probability
timm
torch>=1.5.0
torchvision>=0.6.0
transformers
tune-sklearn==0.0.5
wandb
xgboost