[tune/air] Tuner().restore() from cloud URIs (#26963)

Currently, restoring from cloud URIs does not work for `Tuner()` objects. With this PR, restoring via e.g. `Tuner.restore("s3://bucket/exp")` works.
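
A minimal usage sketch of the new behavior (the bucket name, experiment name, and one-line trainable here are made up for illustration; `SyncConfig(upload_dir=...)` is what uploads the Tuner state to cloud storage in the first place, as in the test added below):

    from ray import tune
    from ray.air import RunConfig
    from ray.tune.tuner import Tuner

    # Run an experiment whose state is synced to cloud storage.
    tuner = Tuner(
        lambda config: 1,  # placeholder trainable
        run_config=RunConfig(
            name="exp",
            sync_config=tune.SyncConfig(upload_dir="s3://bucket"),
        ),
    )
    tuner.fit()

    # Later, possibly on another machine, restore straight from the cloud URI.
    restored = Tuner.restore("s3://bucket/exp")
    restored.fit()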

Signed-off-by: Kai Fricke <kai@anyscale.com>
Author: Kai Fricke <kai@anyscale.com>, 2022-07-26 12:20:07 +01:00 (committed via GitHub)
Parent: 50b20809b8
Commit: 78d6fc689b
2 changed files with 131 additions and 28 deletions


@@ -1,8 +1,12 @@
 import copy
 import os
-from typing import Any, Callable, Dict, Optional, Type, Union, TYPE_CHECKING
+import shutil
+import tempfile
+from pathlib import Path
+from typing import Any, Callable, Dict, Optional, Type, Union, TYPE_CHECKING, Tuple

 import ray.cloudpickle as pickle
+from ray.air._internal.remote_storage import download_from_uri, is_non_local_path_uri
 from ray.air.config import RunConfig, ScalingConfig
 from ray.tune import Experiment, TuneError, ExperimentAnalysis
 from ray.tune.execution.trial_runner import _ResumeConfig
@@ -69,38 +73,29 @@ class TunerInternal:
     ):
         from ray.train.trainer import BaseTrainer

-        # Restored from Tuner checkpoint.
+        # If no run config was passed to Tuner directly, use the one from the Trainer,
+        # if available
+        if not run_config and isinstance(trainable, BaseTrainer):
+            run_config = trainable.run_config
+
+        self._tune_config = tune_config or TuneConfig()
+        self._run_config = run_config or RunConfig()
+
+        # Restore from Tuner checkpoint.
         if restore_path:
-            trainable_ckpt = os.path.join(restore_path, _TRAINABLE_PKL)
-            with open(trainable_ckpt, "rb") as fp:
-                trainable = pickle.load(fp)
-
-            tuner_ckpt = os.path.join(restore_path, _TUNER_PKL)
-            with open(tuner_ckpt, "rb") as fp:
-                tuner = pickle.load(fp)
-                self.__dict__.update(tuner.__dict__)
-
-            self._is_restored = True
-            self._trainable = trainable
-            self._experiment_checkpoint_dir = restore_path
-            self._resume_config = resume_config
+            self._restore_from_path_or_uri(
+                path_or_uri=restore_path, resume_config=resume_config
+            )
             return

         # Start from fresh
         if not trainable:
             raise TuneError("You need to provide a trainable to tune.")

-        self._resume_config = None
-
-        # If no run config was passed to Tuner directly, use the one from the Trainer,
-        # if available
-        if not run_config and isinstance(trainable, BaseTrainer):
-            run_config = trainable.run_config
-
         self._is_restored = False
         self._trainable = trainable
-        self._tune_config = tune_config or TuneConfig()
-        self._run_config = run_config or RunConfig()
+        self._resume_config = None
+
         self._tuner_kwargs = copy.deepcopy(_tuner_kwargs) or {}
         self._experiment_checkpoint_dir = self._setup_create_experiment_checkpoint_dir(
             self._run_config
@@ -116,14 +111,75 @@ class TunerInternal:
         # without allowing for checkpointing tuner and trainable.
         # Thus this has to happen before tune.run() so that we can have something
         # to restore from.
-        tuner_ckpt = os.path.join(self._experiment_checkpoint_dir, _TUNER_PKL)
-        with open(tuner_ckpt, "wb") as fp:
+        experiment_checkpoint_path = Path(self._experiment_checkpoint_dir)
+        with open(experiment_checkpoint_path / _TUNER_PKL, "wb") as fp:
             pickle.dump(self, fp)

-        trainable_ckpt = os.path.join(self._experiment_checkpoint_dir, _TRAINABLE_PKL)
-        with open(trainable_ckpt, "wb") as fp:
+        with open(experiment_checkpoint_path / _TRAINABLE_PKL, "wb") as fp:
             pickle.dump(self._trainable, fp)

+    def _restore_from_path_or_uri(
+        self, path_or_uri: str, resume_config: Optional[_ResumeConfig]
+    ):
+        # Sync down from cloud storage if needed
+        synced, experiment_checkpoint_dir = self._maybe_sync_down_tuner_state(
+            path_or_uri
+        )
+        experiment_checkpoint_path = Path(experiment_checkpoint_dir)
+
+        if (
+            not (experiment_checkpoint_path / _TRAINABLE_PKL).exists()
+            or not (experiment_checkpoint_path / _TUNER_PKL).exists()
+        ):
+            raise RuntimeError(
+                f"Could not find Tuner state in restore directory. Did you pass "
+                f"the correct path (including the experiment directory)? Got: "
+                f"{path_or_uri}"
+            )
+
+        # Load trainable and tuner state
+        with open(experiment_checkpoint_path / _TRAINABLE_PKL, "rb") as fp:
+            trainable = pickle.load(fp)
+
+        with open(experiment_checkpoint_path / _TUNER_PKL, "rb") as fp:
+            tuner = pickle.load(fp)
+            self.__dict__.update(tuner.__dict__)
+
+        self._is_restored = True
+        self._trainable = trainable
+        self._resume_config = resume_config
+
+        if not synced:
+            # If we didn't sync, use the restore_path local dir
+            self._experiment_checkpoint_dir = path_or_uri
+        else:
+            # If we synced, `experiment_checkpoint_dir` will contain a temporary
+            # directory. Create an experiment checkpoint dir instead and move
+            # our data there.
+            new_exp_path = Path(
+                self._setup_create_experiment_checkpoint_dir(self._run_config)
+            )
+            for file_dir in experiment_checkpoint_path.glob("*"):
+                file_dir.rename(new_exp_path / file_dir.name)
+            shutil.rmtree(experiment_checkpoint_path)
+
+            self._experiment_checkpoint_dir = str(new_exp_path)
+
+    def _maybe_sync_down_tuner_state(self, restore_path: str) -> Tuple[bool, str]:
+        """Sync down trainable state from remote storage.
+
+        Returns:
+            Tuple of (downloaded from remote, local_dir)
+        """
+        if not is_non_local_path_uri(restore_path):
+            return False, restore_path
+
+        tempdir = Path(tempfile.mkdtemp("tmp_experiment_dir"))
+
+        path = Path(restore_path)
+        download_from_uri(str(path / _TRAINABLE_PKL), str(tempdir / _TRAINABLE_PKL))
+        download_from_uri(str(path / _TUNER_PKL), str(tempdir / _TUNER_PKL))
+        return True, str(tempdir)
+
     def _process_scaling_config(self) -> None:
         """Converts ``self._param_space["scaling_config"]`` to a dict.


@@ -5,7 +5,9 @@ import pytest
 import ray
 from ray import tune
 from ray.air import RunConfig, Checkpoint, session, FailureConfig
+from ray.air._internal.remote_storage import download_from_uri
 from ray.tune import Callback
+from ray.tune.execution.trial_runner import find_newest_experiment_checkpoint
 from ray.tune.experiment import Trial
 from ray.tune.tune_config import TuneConfig
 from ray.tune.tuner import Tuner
@@ -301,6 +303,51 @@ def test_tuner_resume_errored_only(ray_start_2_cpus, tmpdir):
     assert sorted([r.metrics.get("it", 0) for r in results]) == sorted([2, 1, 3, 0])


+def test_tuner_restore_from_cloud(ray_start_2_cpus, tmpdir):
+    """Check that restoring Tuner() objects from cloud storage works"""
+    tuner = Tuner(
+        lambda config: 1,
+        run_config=RunConfig(
+            name="exp_dir",
+            local_dir=str(tmpdir / "ray_results"),
+            sync_config=tune.SyncConfig(upload_dir="memory:///test/restore"),
+        ),
+    )
+    tuner.fit()
+
+    check_path = tmpdir / "check_save"
+    download_from_uri("memory:///test/restore", str(check_path))
+    remote_contents = os.listdir(check_path / "exp_dir")
+
+    assert "tuner.pkl" in remote_contents
+    assert "trainable.pkl" in remote_contents
+
+    prev_cp = find_newest_experiment_checkpoint(str(check_path / "exp_dir"))
+    prev_lstat = os.lstat(prev_cp)
+
+    (tmpdir / "ray_results").remove(ignore_errors=True)
+
+    tuner2 = Tuner.restore("memory:///test/restore/exp_dir")
+    results = tuner2.fit()
+
+    assert results[0].metrics["_metric"] == 1
+    local_contents = os.listdir(tmpdir / "ray_results" / "exp_dir")
+    assert "tuner.pkl" in local_contents
+    assert "trainable.pkl" in local_contents
+
+    after_cp = find_newest_experiment_checkpoint(
+        str(tmpdir / "ray_results" / "exp_dir")
+    )
+    after_lstat = os.lstat(after_cp)
+
+    # Experiment checkpoint was updated
+    assert os.path.basename(prev_cp) != os.path.basename(after_cp)
+    # Old experiment checkpoint still exists in dir
+    assert os.path.basename(prev_cp) in local_contents
+    # Contents changed
+    assert prev_lstat.st_size != after_lstat.st_size
+
+
 if __name__ == "__main__":
     import sys