Mirror of https://github.com/vale981/ray, synced 2025-03-06 02:21:39 -05:00
[tune/air] Tuner().restore() from cloud URIs (#26963)
Currently, restoring from cloud URIs does not work for Tuner() objects. With this PR, e.g. `Tuner.restore("s3://bucket/exp")` will work.
Signed-off-by: Kai Fricke <kai@anyscale.com>
parent 50b20809b8
commit 78d6fc689b
2 changed files with 131 additions and 28 deletions
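
For orientation, a minimal usage sketch of the call pattern this PR enables (the bucket and experiment names are placeholders, not taken from the diff):

# Sketch: restore a Tuner from a cloud URI and resume the run.
# Assumes the experiment previously uploaded tuner.pkl / trainable.pkl
# to the given bucket path (see the test added below).
from ray.tune.tuner import Tuner

tuner = Tuner.restore("s3://bucket/exp")  # downloads the pickled Tuner state
results = tuner.fit()                     # resumes the experiment locally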
@@ -1,8 +1,12 @@
 import copy
 import os
-from typing import Any, Callable, Dict, Optional, Type, Union, TYPE_CHECKING
+import shutil
+import tempfile
+from pathlib import Path
+from typing import Any, Callable, Dict, Optional, Type, Union, TYPE_CHECKING, Tuple
 
 import ray.cloudpickle as pickle
+from ray.air._internal.remote_storage import download_from_uri, is_non_local_path_uri
 from ray.air.config import RunConfig, ScalingConfig
 from ray.tune import Experiment, TuneError, ExperimentAnalysis
 from ray.tune.execution.trial_runner import _ResumeConfig
@@ -69,38 +73,29 @@ class TunerInternal:
     ):
         from ray.train.trainer import BaseTrainer
 
-        # Restored from Tuner checkpoint.
+        # If no run config was passed to Tuner directly, use the one from the Trainer,
+        # if available
+        if not run_config and isinstance(trainable, BaseTrainer):
+            run_config = trainable.run_config
+
+        self._tune_config = tune_config or TuneConfig()
+        self._run_config = run_config or RunConfig()
+
+        # Restore from Tuner checkpoint.
         if restore_path:
-            trainable_ckpt = os.path.join(restore_path, _TRAINABLE_PKL)
-            with open(trainable_ckpt, "rb") as fp:
-                trainable = pickle.load(fp)
-
-            tuner_ckpt = os.path.join(restore_path, _TUNER_PKL)
-            with open(tuner_ckpt, "rb") as fp:
-                tuner = pickle.load(fp)
-                self.__dict__.update(tuner.__dict__)
-
-            self._is_restored = True
-            self._trainable = trainable
-            self._experiment_checkpoint_dir = restore_path
-            self._resume_config = resume_config
+            self._restore_from_path_or_uri(
+                path_or_uri=restore_path, resume_config=resume_config
+            )
             return
 
         # Start from fresh
         if not trainable:
             raise TuneError("You need to provide a trainable to tune.")
 
-        self._resume_config = None
-
-        # If no run config was passed to Tuner directly, use the one from the Trainer,
-        # if available
-        if not run_config and isinstance(trainable, BaseTrainer):
-            run_config = trainable.run_config
-
         self._is_restored = False
         self._trainable = trainable
-        self._tune_config = tune_config or TuneConfig()
-        self._run_config = run_config or RunConfig()
+        self._resume_config = None
 
         self._tuner_kwargs = copy.deepcopy(_tuner_kwargs) or {}
         self._experiment_checkpoint_dir = self._setup_create_experiment_checkpoint_dir(
             self._run_config
@@ -116,14 +111,75 @@ class TunerInternal:
         # without allowing for checkpointing tuner and trainable.
         # Thus this has to happen before tune.run() so that we can have something
         # to restore from.
-        tuner_ckpt = os.path.join(self._experiment_checkpoint_dir, _TUNER_PKL)
-        with open(tuner_ckpt, "wb") as fp:
+        experiment_checkpoint_path = Path(self._experiment_checkpoint_dir)
+        with open(experiment_checkpoint_path / _TUNER_PKL, "wb") as fp:
             pickle.dump(self, fp)
 
-        trainable_ckpt = os.path.join(self._experiment_checkpoint_dir, _TRAINABLE_PKL)
-        with open(trainable_ckpt, "wb") as fp:
+        with open(experiment_checkpoint_path / _TRAINABLE_PKL, "wb") as fp:
             pickle.dump(self._trainable, fp)
 
+    def _restore_from_path_or_uri(
+        self, path_or_uri: str, resume_config: Optional[_ResumeConfig]
+    ):
+        # Sync down from cloud storage if needed
+        synced, experiment_checkpoint_dir = self._maybe_sync_down_tuner_state(
+            path_or_uri
+        )
+        experiment_checkpoint_path = Path(experiment_checkpoint_dir)
+
+        if (
+            not (experiment_checkpoint_path / _TRAINABLE_PKL).exists()
+            or not (experiment_checkpoint_path / _TUNER_PKL).exists()
+        ):
+            raise RuntimeError(
+                f"Could not find Tuner state in restore directory. Did you pass"
+                f"the correct path (including experiment directory?) Got: "
+                f"{path_or_uri}"
+            )
+
+        # Load trainable and tuner state
+        with open(experiment_checkpoint_path / _TRAINABLE_PKL, "rb") as fp:
+            trainable = pickle.load(fp)
+
+        with open(experiment_checkpoint_path / _TUNER_PKL, "rb") as fp:
+            tuner = pickle.load(fp)
+            self.__dict__.update(tuner.__dict__)
+
+        self._is_restored = True
+        self._trainable = trainable
+        self._resume_config = resume_config
+
+        if not synced:
+            # If we didn't sync, use the restore_path local dir
+            self._experiment_checkpoint_dir = path_or_uri
+        else:
+            # If we synced, `experiment_checkpoint_dir` will contain a temporary
+            # directory. Create an experiment checkpoint dir instead and move
+            # our data there.
+            new_exp_path = Path(
+                self._setup_create_experiment_checkpoint_dir(self._run_config)
+            )
+            for file_dir in experiment_checkpoint_path.glob("*"):
+                file_dir.rename(new_exp_path / file_dir.name)
+            shutil.rmtree(experiment_checkpoint_path)
+            self._experiment_checkpoint_dir = str(new_exp_path)
+
+    def _maybe_sync_down_tuner_state(self, restore_path: str) -> Tuple[bool, str]:
+        """Sync down trainable state from remote storage.
+
+        Returns:
+            Tuple of (downloaded from remote, local_dir)
+        """
+        if not is_non_local_path_uri(restore_path):
+            return False, restore_path
+
+        tempdir = Path(tempfile.mkdtemp("tmp_experiment_dir"))
+
+        path = Path(restore_path)
+        download_from_uri(str(path / _TRAINABLE_PKL), str(tempdir / _TRAINABLE_PKL))
+        download_from_uri(str(path / _TUNER_PKL), str(tempdir / _TUNER_PKL))
+        return True, str(tempdir)
+
     def _process_scaling_config(self) -> None:
         """Converts ``self._param_space["scaling_config"]`` to a dict.
 
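
As a rough standalone sketch (not part of the diff), this is the decision the new `_maybe_sync_down_tuner_state()` helper encodes: local paths are used in place, while remote URIs are first downloaded into a temporary directory. The helper name `fetch_tuner_state` and the example paths are made up for illustration; only `is_non_local_path_uri` and `download_from_uri` come from the diff's imports.

# Illustration only: mirrors the local-vs-remote logic of
# _maybe_sync_down_tuner_state() using the same ray.air helpers.
import tempfile
from pathlib import Path

from ray.air._internal.remote_storage import download_from_uri, is_non_local_path_uri


def fetch_tuner_state(restore_path: str) -> str:
    """Return a local directory that contains tuner.pkl and trainable.pkl."""
    if not is_non_local_path_uri(restore_path):
        # e.g. "/home/user/ray_results/exp": already local, use as-is.
        return restore_path
    # e.g. "s3://bucket/exp": download only the two pickled state files.
    tempdir = Path(tempfile.mkdtemp())
    for fname in ("tuner.pkl", "trainable.pkl"):
        download_from_uri(f"{restore_path}/{fname}", str(tempdir / fname))
    return str(tempdir)

The remaining hunks belong to the second changed file, which adds a regression test for this behavior.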
@@ -5,7 +5,9 @@ import pytest
 import ray
+from ray import tune
 from ray.air import RunConfig, Checkpoint, session, FailureConfig
+from ray.air._internal.remote_storage import download_from_uri
 from ray.tune import Callback
 from ray.tune.execution.trial_runner import find_newest_experiment_checkpoint
 from ray.tune.experiment import Trial
 from ray.tune.tune_config import TuneConfig
 from ray.tune.tuner import Tuner
@@ -301,6 +303,51 @@ def test_tuner_resume_errored_only(ray_start_2_cpus, tmpdir):
     assert sorted([r.metrics.get("it", 0) for r in results]) == sorted([2, 1, 3, 0])
 
 
+def test_tuner_restore_from_cloud(ray_start_2_cpus, tmpdir):
+    """Check that restoring Tuner() objects from cloud storage works"""
+    tuner = Tuner(
+        lambda config: 1,
+        run_config=RunConfig(
+            name="exp_dir",
+            local_dir=str(tmpdir / "ray_results"),
+            sync_config=tune.SyncConfig(upload_dir="memory:///test/restore"),
+        ),
+    )
+    tuner.fit()
+
+    check_path = tmpdir / "check_save"
+    download_from_uri("memory:///test/restore", str(check_path))
+    remote_contents = os.listdir(check_path / "exp_dir")
+
+    assert "tuner.pkl" in remote_contents
+    assert "trainable.pkl" in remote_contents
+
+    prev_cp = find_newest_experiment_checkpoint(str(check_path / "exp_dir"))
+    prev_lstat = os.lstat(prev_cp)
+
+    (tmpdir / "ray_results").remove(ignore_errors=True)
+
+    tuner2 = Tuner.restore("memory:///test/restore/exp_dir")
+    results = tuner2.fit()
+
+    assert results[0].metrics["_metric"] == 1
+    local_contents = os.listdir(tmpdir / "ray_results" / "exp_dir")
+    assert "tuner.pkl" in local_contents
+    assert "trainable.pkl" in local_contents
+
+    after_cp = find_newest_experiment_checkpoint(
+        str(tmpdir / "ray_results" / "exp_dir")
+    )
+    after_lstat = os.lstat(after_cp)
+
+    # Experiment checkpoint was updated
+    assert os.path.basename(prev_cp) != os.path.basename(after_cp)
+    # Old experiment checkpoint still exists in dir
+    assert os.path.basename(prev_cp) in local_contents
+    # Contents changed
+    assert prev_lstat.st_size != after_lstat.st_size
+
+
 if __name__ == "__main__":
     import sys
 
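
Condensed, the round trip the new test exercises looks as follows. It reuses the test's `memory:///` URI (an in-memory filesystem used by the test suite) and trivial lambda trainable, so treat it as a sketch of the expected flow rather than a recommended production setup:

# Write side: fit with an upload_dir so tuner.pkl / trainable.pkl are synced up.
from ray import tune
from ray.air import RunConfig
from ray.tune.tuner import Tuner

tuner = Tuner(
    lambda config: 1,  # trivial trainable, as in the test
    run_config=RunConfig(
        name="exp_dir",
        sync_config=tune.SyncConfig(upload_dir="memory:///test/restore"),
    ),
)
tuner.fit()

# Restore side: the local results dir may be gone; state is pulled from the URI.
restored = Tuner.restore("memory:///test/restore/exp_dir")
restored.fit()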