[tune] Remove TrialCheckpoint class (#25406)

The old user-facing `TrialCheckpoint` class has been deprecated in favor of `ray.ml.Checkpoint` (now `ray.air.Checkpoint`) and is removed in this PR.

The main change in this PR is to delete the old `TrialCheckpoint` class and replace remaining API calls (e.g. `checkpoint.local_path`) with the correct AIR equivalents.
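
For illustration, the typical call-site migration looks like the notebook change below (a sketch assuming `analysis` is the `ExperimentAnalysis` returned by `tune.run` and `best_bst` is an `xgb.Booster`):

```python
import os

# Before: TrialCheckpoint behaved like a plain path string.
# best_bst.load_model(os.path.join(analysis.best_checkpoint, "model.xgb"))

# After: materialize the AIR Checkpoint as a local directory first.
with analysis.best_checkpoint.as_directory() as checkpoint_dir:
    best_bst.load_model(os.path.join(checkpoint_dir, "model.xgb"))
```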

One issue that comes up is that with Ray client usage, checkpoint directories are not available on the local (client) node, so we can't easily construct `Checkpoint` objects. (Previously, the `TrialCheckpoint` object held a reference to the location even if it was not locally available.) There are ongoing discussions on how to resolve this in the future; for now, we print an error when such a checkpoint is requested.
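
As a hedged sketch of the workaround available today (the trainable, metric name, and bucket URI below are placeholders), syncing checkpoints to cloud storage keeps them reachable from a Ray client session, and `return_path=True` returns a location string (possibly a cloud URI) instead of a `Checkpoint` object:

```python
from ray import tune

analysis = tune.run(
    my_trainable,  # placeholder trainable
    metric="accuracy",
    mode="max",
    sync_config=tune.SyncConfig(upload_dir="s3://my-bucket/tune-results"),  # placeholder bucket
)

# With return_path=True, a cloud URI is preferred when available, so no local
# checkpoint directory is required on the client node.
best_path = analysis.get_best_checkpoint(analysis.best_trial, return_path=True)
```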

Depends on #25805

Signed-off-by: Kai Fricke <kai@anyscale.com>
Kai Fricke 2022-07-11 20:08:10 +01:00 committed by GitHub
parent 923209895d
commit 753f5feaf4
13 changed files with 92 additions and 223 deletions

View file

@@ -529,7 +529,8 @@
"\n",
"def get_best_model_checkpoint(analysis):\n",
" best_bst = xgb.Booster()\n",
" best_bst.load_model(os.path.join(analysis.best_checkpoint, \"model.xgb\"))\n",
" with analysis.best_checkpoint.as_directory() as best_checkpoint_dir:\n",
" best_bst.load_model(os.path.join(best_checkpoint_dir, \"model.xgb\"))\n",
" accuracy = 1.0 - analysis.best_result[\"eval-error\"]\n",
" print(f\"Best model parameters: {analysis.best_config}\")\n",
" print(f\"Best model total accuracy: {accuracy:.4f}\")\n",
@@ -574,7 +575,7 @@
" type=str,\n",
" default=None,\n",
" required=False,\n",
" help=\"The address of server to connect to if using \" \"Ray Client.\",\n",
" help=\"The address of server to connect to if using Ray Client.\",\n",
" )\n",
" args, _ = parser.parse_known_args()\n",
"\n",

View file

@@ -7,11 +7,11 @@ from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union
from ray.air.checkpoint import Checkpoint
from ray.tune.cloud import TrialCheckpoint
from ray.tune.syncer import SyncConfig
from ray.tune.utils import flatten_dict
from ray.tune.utils.serialization import TuneFunctionDecoder
from ray.tune.utils.util import is_nan_or_inf, is_nan
from ray.util import log_once
try:
import pandas as pd
@@ -115,10 +115,6 @@ class ExperimentAnalysis:
self._sync_config = sync_config
# If True, will return a legacy TrialCheckpoint class.
# If False, will just return a Checkpoint class.
self._legacy_checkpoint = True
def _parse_cloud_path(self, local_path: str):
"""Convert local path into cloud storage path"""
if not self._sync_config or not self._sync_config.upload_dir:
@@ -451,8 +447,12 @@ class ExperimentAnalysis:
raise ValueError("trial should be a string or a Trial instance.")
def get_best_checkpoint(
self, trial: Trial, metric: Optional[str] = None, mode: Optional[str] = None
) -> Optional[Checkpoint]:
self,
trial: Trial,
metric: Optional[str] = None,
mode: Optional[str] = None,
return_path: bool = False,
) -> Optional[Union[Checkpoint, str]]:
"""Gets best persistent checkpoint path of provided trial.
Any checkpoints with an associated metric value of ``nan`` will be filtered out.
@@ -463,9 +463,14 @@
"training_iteration" is used by default if no value was
passed to ``self.default_metric``.
mode: One of [min, max]. Defaults to ``self.default_mode``.
return_path: If True, only returns the path (and not the
``Checkpoint`` object). If using Ray client, it is not
guaranteed that this path is available on the local
(client) node. Can also contain a cloud URI.
Returns:
:class:`Checkpoint <ray.air.Checkpoint>` object.
:class:`Checkpoint <ray.air.Checkpoint>` object or string
if ``return_path=True``.
"""
metric = metric or self.default_metric or TRAINING_ITERATION
mode = self._validate_mode(mode)
@@ -487,23 +492,27 @@
best_path, best_metric = best_path_metrics[0]
cloud_path = self._parse_cloud_path(best_path)
if self._legacy_checkpoint:
return TrialCheckpoint(local_path=best_path, cloud_path=cloud_path)
if cloud_path:
# Prefer cloud path over local path for downstream processing
if return_path:
return cloud_path
return Checkpoint.from_uri(cloud_path)
elif os.path.exists(best_path):
if return_path:
return best_path
return Checkpoint.from_directory(best_path)
else:
logger.error(
f"No checkpoint locations for {trial} available on "
f"this node. To avoid this, you "
f"should enable checkpoint synchronization with the"
f"`sync_config` argument in Ray Tune. "
f"The checkpoint may be available on a different node - "
f"please check this location on worker nodes: {best_path}"
)
if log_once("checkpoint_not_available"):
logger.error(
f"The requested checkpoint for trial {trial} is not available on "
f"this node, most likely because you are using Ray client or "
f"disabled checkpoint synchronization. To avoid this, enable "
f"checkpoint synchronization to cloud storage by specifying a "
f"`SyncConfig`. The checkpoint may be available on a different "
f"node - please check this location on worker nodes: {best_path}"
)
if return_path:
return best_path
return None
def get_all_configs(self, prefix: bool = False) -> Dict[str, Dict]:
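
For context, a minimal sketch of consuming the new return value (assuming `analysis` and `trial` come from a finished `tune.run`, the metric name is a placeholder, and the checkpoint contains a `checkpoint.pt` file as in the PyTorch example further down):

```python
import os
import torch

checkpoint = analysis.get_best_checkpoint(trial, metric="loss", mode="min")
if checkpoint is not None:
    # Works for both cloud-backed and local checkpoints: the data is made
    # available in a local directory for the duration of the context manager.
    with checkpoint.as_directory() as checkpoint_dir:
        state = torch.load(os.path.join(checkpoint_dir, "checkpoint.pt"))
```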

View file

@@ -1,172 +0,0 @@
import os
from typing import Optional
from ray.air.checkpoint import (
Checkpoint,
_get_local_path,
_get_external_path,
)
from ray.util.annotations import Deprecated
@Deprecated
class _TrialCheckpoint(os.PathLike):
def __init__(
self, local_path: Optional[str] = None, cloud_path: Optional[str] = None
):
self._local_path = local_path
self._cloud_path_tcp = cloud_path
@property
def local_path(self):
return self._local_path
@local_path.setter
def local_path(self, path: str):
self._local_path = path
@property
def cloud_path(self):
return self._cloud_path_tcp
@cloud_path.setter
def cloud_path(self, path: str):
self._cloud_path_tcp = path
# The following magic methods are implemented to keep backwards
# compatibility with the old path-based return values.
def __str__(self):
return self.local_path or self.cloud_path
def __fspath__(self):
return self.local_path
def __eq__(self, other):
if isinstance(other, str):
return self.local_path == other
elif isinstance(other, TrialCheckpoint):
return (
self.local_path == other.local_path
and self.cloud_path == other.cloud_path
)
def __add__(self, other):
if isinstance(other, str):
return self.local_path + other
raise NotImplementedError
def __radd__(self, other):
if isinstance(other, str):
return other + self.local_path
raise NotImplementedError
def __repr__(self):
return (
f"<TrialCheckpoint "
f"local_path={self.local_path}, "
f"cloud_path={self.cloud_path}"
f">"
)
# Deprecated: Remove in Ray > 1.13
@Deprecated
class TrialCheckpoint(Checkpoint, _TrialCheckpoint):
def __init__(
self,
local_path: Optional[str] = None,
cloud_path: Optional[str] = None,
):
_TrialCheckpoint.__init__(self)
# Checkpoint does not allow empty data, but TrialCheckpoint
# did. To keep backwards compatibility, we use a placeholder URI
# here, and manually set self._uri and self._local_dir later.
PLACEHOLDER = "s3://placeholder"
Checkpoint.__init__(self, uri=PLACEHOLDER)
# Reset local variables
self._uri = None
self._local_path = None
self._cloud_path_tcp = None
self._local_path_tcp = None
locations = set()
if local_path:
# Add _tcp to not conflict with Checkpoint._local_path
self._local_path_tcp = local_path
if os.path.exists(local_path):
self._local_path = local_path
locations.add(local_path)
if cloud_path:
self._cloud_path_tcp = cloud_path
self._uri = cloud_path
locations.add(cloud_path)
self._locations = locations
@property
def local_path(self):
local_path = _get_local_path(self._local_path)
if not local_path:
for candidate in self._locations:
local_path = _get_local_path(candidate)
if local_path:
break
return local_path or self._local_path_tcp
@local_path.setter
def local_path(self, path: str):
self._local_path = path
if not path or not os.path.exists(path):
return
self._locations.add(path)
@property
def cloud_path(self):
cloud_path = _get_external_path(self._uri)
if not cloud_path:
for candidate in self._locations:
cloud_path = _get_external_path(candidate)
if cloud_path:
break
return cloud_path or self._cloud_path_tcp
@cloud_path.setter
def cloud_path(self, path: str):
self._cloud_path_tcp = path
if not self._uri:
self._uri = path
self._locations.add(path)
def download(
self,
cloud_path: Optional[str] = None,
local_path: Optional[str] = None,
overwrite: bool = False,
) -> str:
# Deprecated: Remove whole class in Ray > 1.13
raise DeprecationWarning(
"`checkpoint.download()` is deprecated and will be removed in "
"the future. Please use `checkpoint.to_directory()` instead."
)
def upload(
self,
cloud_path: Optional[str] = None,
local_path: Optional[str] = None,
clean_before: bool = False,
):
# Deprecated: Remove whole class in Ray > 1.13
raise DeprecationWarning(
"`checkpoint.upload()` is deprecated and will be removed in "
"the future. Please use `checkpoint.to_uri()` instead."
)
def save(self, path: Optional[str] = None, force_download: bool = False):
# Deprecated: Remove whole class in Ray > 1.13
raise DeprecationWarning(
"`checkpoint.save()` is deprecated and will be removed in "
"the future. Please use `checkpoint.to_directory()` or"
"`checkpoint.to_uri()` instead."
)
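
The deprecation messages above already point at the AIR replacements; a minimal sketch of the equivalent calls (the local directory and bucket URI are placeholders):

```python
from ray.air import Checkpoint

checkpoint = Checkpoint.from_directory("/tmp/my_checkpoint")  # placeholder directory

# Replaces TrialCheckpoint.upload(): persist the checkpoint to (cloud) storage.
uri = checkpoint.to_uri("s3://my-bucket/checkpoints/trial_0")  # placeholder URI

# Replaces TrialCheckpoint.download()/save(): materialize a checkpoint locally.
local_dir = Checkpoint.from_uri(uri).to_directory()
```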

View file

@@ -68,14 +68,16 @@ def train_convnet(config):
def test_best_model(analysis):
"""Test the best model given output of tune.run"""
best_checkpoint_path = analysis.best_checkpoint
best_model = ConvNet()
best_checkpoint = torch.load(os.path.join(best_checkpoint_path, "checkpoint.pt"))
best_model.load_state_dict(best_checkpoint["model_state_dict"])
# Note that test only runs on a small random set of the test data, thus the
# accuracy may be different from metrics shown in tuning process.
test_acc = test(best_model, get_data_loaders()[1])
print("best model accuracy: ", test_acc)
with analysis.best_checkpoint.as_directory() as best_checkpoint_path:
best_model = ConvNet()
best_checkpoint = torch.load(
os.path.join(best_checkpoint_path, "checkpoint.pt")
)
best_model.load_state_dict(best_checkpoint["model_state_dict"])
# Note that test only runs on a small random set of the test data, thus the
# accuracy may be different from metrics shown in tuning process.
test_acc = test(best_model, get_data_loaders()[1])
print("best model accuracy: ", test_acc)
if __name__ == "__main__":

View file

@@ -62,7 +62,8 @@ def train_breast_cancer_cv(config: dict):
def get_best_model_checkpoint(analysis):
best_bst = xgb.Booster()
best_bst.load_model(os.path.join(analysis.best_checkpoint, "model.xgb"))
with analysis.best_checkpoint.as_directory() as checkpoint_dir:
best_bst.load_model(os.path.join(checkpoint_dir, "model.xgb"))
accuracy = 1.0 - analysis.best_result["test-error"]
print(f"Best model parameters: {analysis.best_config}")
print(f"Best model total accuracy: {accuracy:.4f}")

View file

@@ -167,7 +167,7 @@ class ExperimentAnalysisSuite(unittest.TestCase):
best_checkpoint = self.ea.get_best_checkpoint(
best_trial, self.metric, mode="max"
)
assert expected_path == best_checkpoint
assert expected_path == best_checkpoint._local_path
def testGetBestCheckpointNan(self):
"""Tests if nan values are excluded from best checkpoint."""
@@ -196,7 +196,7 @@ class ExperimentAnalysisSuite(unittest.TestCase):
],
key=lambda x: x[1],
)[0]
assert best_checkpoint == expected_checkpoint_no_nan
assert best_checkpoint._local_path == expected_checkpoint_no_nan
def testGetLastCheckpoint(self):
# one more experiment with 2 iterations
@@ -213,7 +213,7 @@ class ExperimentAnalysisSuite(unittest.TestCase):
)
# check if it's loaded correctly
last_checkpoint = new_ea.get_last_checkpoint().local_path
last_checkpoint = new_ea.get_last_checkpoint()._local_path
assert self.test_path in last_checkpoint
assert "checkpoint_000002" in last_checkpoint

View file

@@ -31,7 +31,6 @@ def test_result_grid(ray_start_2_cpus):
f.write(json.dumps({"step": 0}))
analysis = tune.run(f, config={"a": 1})
analysis._legacy_checkpoint = False
result_grid = ResultGrid(analysis)
result = result_grid[0]
assert isinstance(result.checkpoint, Checkpoint)
@@ -98,7 +97,6 @@ def test_result_grid_no_checkpoint(ray_start_2_cpus):
pass
analysis = tune.run(f)
analysis._legacy_checkpoint = False
result_grid = ResultGrid(analysis)
result = result_grid[0]
assert result.checkpoint is None
@@ -213,7 +211,6 @@ def test_result_grid_df(ray_start_2_cpus):
tune.report(metric=config["nested"]["param"] * 3)
analysis = tune.run(f, config={"nested": {"param": tune.grid_search([1, 2])}})
analysis._legacy_checkpoint = False
result_grid = ResultGrid(analysis)
assert len(result_grid) == 2

View file

@@ -60,7 +60,7 @@ class TrainableUtilTest(unittest.TestCase):
default_mode="max",
)
df = a.dataframe()
checkpoint_dir = a.get_best_checkpoint(df["logdir"].iloc[0]).local_path
checkpoint_dir = a.get_best_checkpoint(df["logdir"].iloc[0])._local_path
assert checkpoint_dir.endswith("/checkpoint_000001/")
def testFindCheckpointDir(self):

View file

@@ -17,7 +17,6 @@ from ray.train.torch import TorchTrainer
from ray.train.trainer import BaseTrainer
from ray.train.xgboost import XGBoostTrainer
from ray.tune import Callback, TuneError
from ray.tune.cloud import TrialCheckpoint
from ray.tune.result import DEFAULT_RESULTS_DIR
from ray.tune.tune_config import TuneConfig
from ray.tune.tuner import Tuner
@@ -128,7 +127,6 @@
_tuner_kwargs={"max_concurrent_trials": 1},
)
results = tuner.fit()
assert not isinstance(results.get_best_result().checkpoint, TrialCheckpoint)
assert len(results) == 4
def test_tuner_with_xgboost_trainer_driver_fail_and_resume(self):

View file

@@ -516,6 +516,13 @@ class FunctionTrainable(Trainable):
# as a new checkpoint.
self._status_reporter.set_checkpoint(checkpoint, is_new=False)
def _restore_from_checkpoint_obj(self, checkpoint: Checkpoint):
self.temp_checkpoint_dir = FuncCheckpointUtil.mk_temp_checkpoint_dir(
self.logdir
)
checkpoint.to_directory(self.temp_checkpoint_dir)
self.restore(self.temp_checkpoint_dir)
def restore_from_object(self, obj):
self.temp_checkpoint_dir = FuncCheckpointUtil.mk_temp_checkpoint_dir(
self.logdir

View file

@@ -17,7 +17,6 @@ from ray.air.checkpoint import (
Checkpoint,
_DICT_CHECKPOINT_ADDITIONAL_FILE_KEY,
)
from ray.tune.cloud import TrialCheckpoint
from ray.tune.resources import Resources
from ray.tune.result import (
DEBUG_METRICS,
@@ -553,7 +552,18 @@ class Trainable:
shutil.rmtree(temp_container_dir)
return obj_ref
def restore(self, checkpoint_path: str, checkpoint_node_ip: Optional[str] = None):
def _restore_from_checkpoint_obj(self, checkpoint: Checkpoint):
with checkpoint.as_directory() as converted_checkpoint_path:
return self.restore(
checkpoint_path=converted_checkpoint_path,
checkpoint_node_ip=None,
)
def restore(
self,
checkpoint_path: Union[str, Checkpoint],
checkpoint_node_ip: Optional[str] = None,
):
"""Restores training state from a given model checkpoint.
These checkpoints are returned from calls to save().
@@ -585,9 +595,9 @@
on cloud storage.
"""
# Ensure TrialCheckpoints are converted
if isinstance(checkpoint_path, TrialCheckpoint):
checkpoint_path = checkpoint_path.local_path
# Ensure Checkpoints are converted
if isinstance(checkpoint_path, Checkpoint):
return self._restore_from_checkpoint_obj(checkpoint_path)
if not self._maybe_load_from_cloud(checkpoint_path) and (
# If a checkpoint source IP is given
@@ -662,15 +672,15 @@
with checkpoint.as_directory() as checkpoint_path:
self.restore(checkpoint_path)
def delete_checkpoint(self, checkpoint_path: str):
def delete_checkpoint(self, checkpoint_path: Union[str, Checkpoint]):
"""Deletes local copy of checkpoint.
Args:
checkpoint_path: Path to checkpoint.
"""
# Ensure TrialCheckpoints are converted
if isinstance(checkpoint_path, TrialCheckpoint):
checkpoint_path = checkpoint_path.local_path
# Ensure Checkpoints are converted
if isinstance(checkpoint_path, Checkpoint) and checkpoint_path._local_path:
checkpoint_path = checkpoint_path._local_path
try:
checkpoint_dir = TrainableUtil.find_checkpoint_dir(checkpoint_path)
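
For reference, a short sketch of the updated `Trainable` interface (assuming `trainable` is an instantiated `Trainable` subclass; the directory is a placeholder):

```python
from ray.air import Checkpoint

checkpoint = Checkpoint.from_directory("/tmp/my_checkpoint")  # placeholder directory

# restore() now also accepts a Checkpoint object; it is materialized into a
# local directory via as_directory() before the regular path-based restore runs.
trainable.restore(checkpoint)

# delete_checkpoint() likewise accepts a Checkpoint and falls back to its
# local path, if one exists.
trainable.delete_checkpoint(checkpoint)
```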

View file

@@ -14,6 +14,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import ray
from ray.air import Checkpoint
from ray.tune.result import NODE_IP
from ray.util import log_once
from ray.util.annotations import Deprecated, DeveloperAPI, PublicAPI
from ray.util.ml_utils.util import is_nan
@@ -129,7 +130,20 @@ class _TrackedCheckpoint:
checkpoint_data = ray.get(checkpoint_data)
if isinstance(checkpoint_data, str):
checkpoint_dir = TrainableUtil.find_checkpoint_dir(checkpoint_data)
try:
checkpoint_dir = TrainableUtil.find_checkpoint_dir(checkpoint_data)
except FileNotFoundError:
if log_once("checkpoint_not_available"):
logger.error(
f"The requested checkpoint is not available on this node, "
f"most likely because you are using Ray client or disabled "
f"checkpoint synchronization. To avoid this, enable checkpoint "
f"synchronization to cloud storage by specifying a "
f"`SyncConfig`. The checkpoint may be available on a different "
f"node - please check this location on worker nodes: "
f"{checkpoint_data}"
)
return None
checkpoint = Checkpoint.from_directory(checkpoint_dir)
elif isinstance(checkpoint_data, bytes):
checkpoint = Checkpoint.from_bytes(checkpoint_data)

View file

@@ -256,8 +256,10 @@ if __name__ == "__main__":
analysis = train_mnist(args.smoke_test, num_workers, use_gpu)
print("Retrieving best model.")
best_checkpoint = analysis.best_checkpoint.local_path
model = get_remote_model(best_checkpoint)
best_checkpoint_path = analysis.get_best_checkpoint(
analysis.best_trial, return_path=True
)
model = get_remote_model(best_checkpoint_path)
print("Setting up Serve.")
setup_serve(model, use_gpu)