[AIR] Fix HF checkpointing with same-node workers (#28154)
If multiple workers are scheduled on the head node with HuggingFaceTrainer, a race condition can occur: each worker moves the checkpoint files from its own rank folder into the same checkpoint folder, and a move can raise an exception when another worker has already moved that file. This PR fixes that and adds a test that would fail without the change.

Signed-off-by: Antoni Baum <antoni.baum@protonmail.com>
parent e643b75129
commit 13457dab03

2 changed files with 67 additions and 6 deletions
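For context, a minimal standalone sketch of the failure mode described above: two same-node workers each hold an identical checkpoint in their own rank folder and move it into one shared checkpoint folder, and the second move raises. The directory names here are made up for illustration; this is not the Ray code itself, which appears in the first hunk below.

import shutil
import tempfile
from pathlib import Path

root = Path(tempfile.mkdtemp())
checkpoint = root / "checkpoint"
checkpoint.mkdir()
for rank in (0, 1):
    rank_dir = root / f"rank_{rank}"
    rank_dir.mkdir()
    (rank_dir / "weights.bin").write_text("identical across ranks")

# Worker 0 moves its files first; worker 1 then collides on the same file name.
shutil.move(str(root / "rank_0" / "weights.bin"), str(checkpoint))
shutil.move(str(root / "rank_1" / "weights.bin"), str(checkpoint))
# -> shutil.Error: Destination path '.../checkpoint/weights.bin' already exists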
@@ -72,11 +72,16 @@ class _SyncedTrackedCheckpoint(_TrackedCheckpoint):
         target_ip = get_node_ip_address()
 
         if source_ip == target_ip:
-            # Move contents of source_path, but not source_path
-            # itself. shutil.move is already recursive.
-            for inner in Path(source_path).iterdir():
-                shutil.move(str(inner.absolute()), str(path))
-            shutil.rmtree(source_path, ignore_errors=True)
+            source_path = Path(source_path)
+            for inner in source_path.iterdir():
+                try:
+                    shutil.move(str(inner.absolute()), str(path.absolute()))
+                except OSError:
+                    # This file may have already been moved by another rank worker.
+                    # Disregard, as the files are identical across all ranks.
+                    pass
+            # No need to file lock here as each rank worker has its own folder.
+            shutil.rmtree(str(source_path.absolute()), ignore_errors=True)
         else:
             sync_dir_between_nodes(
                 source_ip=source_ip,
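Continuing the sketch from above with the pattern this hunk introduces: the second worker's collision is absorbed, and both rank folders end up cleaned. Catching OSError is enough because shutil.move raises shutil.Error, an OSError subclass, when the destination folder already contains a same-named entry, and skipping is safe because the files are identical across ranks. Names below are again made up for illustration.

import shutil
import tempfile
from pathlib import Path

root = Path(tempfile.mkdtemp())
checkpoint = root / "checkpoint"
checkpoint.mkdir()
for rank in (0, 1):
    rank_dir = root / f"rank_{rank}"
    rank_dir.mkdir()
    (rank_dir / "weights.bin").write_text("identical across ranks")

for rank in (0, 1):  # sequential here; real same-node workers race concurrently
    source_path = root / f"rank_{rank}"
    for inner in source_path.iterdir():
        try:
            shutil.move(str(inner.absolute()), str(checkpoint.absolute()))
        except OSError:  # shutil.Error ("already exists") subclasses OSError
            pass
    # Each worker owns its rank folder, so cleanup needs no file lock.
    shutil.rmtree(str(source_path.absolute()), ignore_errors=True)

print(sorted(p.name for p in checkpoint.iterdir()))  # ['weights.bin']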
@@ -18,6 +18,13 @@ from ray.train.huggingface import HuggingFacePredictor, HuggingFaceTrainer
 from ray.air.config import ScalingConfig
 from ray.train.huggingface._huggingface_utils import TrainReportCallback
 from ray.train.tests._huggingface_data import train_data, validation_data
+from ray import tune
+from ray.tune import Tuner
+from ray.tune.schedulers.async_hyperband import ASHAScheduler
+from ray.tune.schedulers.resource_changing_scheduler import (
+    DistributeResources,
+    ResourceChangingScheduler,
+)
 
 # 16 first rows of tokenized wikitext-2-raw-v1 training & validation
 train_df = pd.read_json(train_data)
@@ -40,6 +47,14 @@ def ray_start_4_cpus():
     ray.shutdown()
 
 
+@pytest.fixture
+def ray_start_8_cpus():
+    address_info = ray.init(num_cpus=8)
+    yield address_info
+    # The code after the yield will run as teardown code.
+    ray.shutdown()
+
+
 def train_function(train_dataset, eval_dataset=None, **config):
     model_config = AutoConfig.from_pretrained(model_checkpoint)
     model = AutoModelForCausalLM.from_config(model_config)
@@ -47,7 +62,7 @@ def train_function(train_dataset, eval_dataset=None, **config):
         f"{model_checkpoint}-wikitext2",
         evaluation_strategy=config.pop("evaluation_strategy", "epoch"),
         num_train_epochs=config.pop("epochs", 3),
-        learning_rate=2e-5,
+        learning_rate=config.pop("learning_rate", 2e-5),
         weight_decay=0.01,
         disable_tqdm=True,
         no_cuda=True,
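The one-line change above is what lets a learning_rate sampled by Tune in trainer_init_config (used by the new test in the next hunk) reach TrainingArguments, while keeping 2e-5 as the default. A tiny sketch of the pattern, with a hypothetical config dict standing in for what train_function receives:

# Hypothetical config dict, as train_function would receive it via **config.
config = {"learning_rate": 5e-6, "epochs": 2}

training_kwargs = dict(
    num_train_epochs=config.pop("epochs", 3),
    # Falls back to 2e-5 when the tuner does not supply a learning_rate.
    learning_rate=config.pop("learning_rate", 2e-5),
)
print(training_kwargs)  # {'num_train_epochs': 2, 'learning_rate': 5e-06}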
@@ -167,6 +182,47 @@ def test_validation(ray_start_4_cpus):
     trainer.fit().error
 
 
+# Tests if checkpointing and restoring during tuning works correctly.
+def test_tune(ray_start_8_cpus):
+    ray_train = ray.data.from_pandas(train_df)
+    ray_validation = ray.data.from_pandas(validation_df)
+    scaling_config = ScalingConfig(
+        num_workers=2, use_gpu=False, trainer_resources={"CPU": 0}
+    )
+    trainer = HuggingFaceTrainer(
+        trainer_init_per_worker=train_function,
+        scaling_config=scaling_config,
+        datasets={"train": ray_train, "evaluation": ray_validation},
+    )
+
+    tune_epochs = 5
+    tuner = Tuner(
+        trainer,
+        param_space={
+            "trainer_init_config": {
+                "learning_rate": tune.loguniform(2e-6, 2e-5),
+                "epochs": tune_epochs,
+                "save_strategy": "epoch",
+            }
+        },
+        tune_config=tune.TuneConfig(
+            metric="eval_loss",
+            mode="min",
+            num_samples=3,
+            scheduler=ResourceChangingScheduler(
+                ASHAScheduler(
+                    max_t=tune_epochs,
+                ),
+                resources_allocation_function=DistributeResources(
+                    add_bundles=True, reserve_resources={"CPU": 1}
+                ),
+            ),
+        ),
+    )
+    tune_results = tuner.fit()
+    assert not tune_results.errors
+
+
 if __name__ == "__main__":
     import sys
 
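To run only the new test locally, an invocation along these lines should work; the test file's path is not shown in this diff, so the module name below is an assumption, and the run needs the HuggingFace test dependencies plus at least 8 CPUs for the ray_start_8_cpus fixture.

# Hypothetical runner; the test module name is assumed, as the diff does not show it.
import sys

import pytest

sys.exit(pytest.main(["-v", "-x", "-k", "test_tune", "test_huggingface_trainer.py"]))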