[tune] Remove TrialCheckpoint class (#25406)
The old user-facing TrialCheckpoint class has been deprecated in favor of `ray.ml.Checkpoint` and is removed in this PR. The main change is to delete the old `TrialCheckpoint` class and replace the remaining API calls (e.g. `checkpoint.local_path`) with the correct AIR equivalents. One issue is that when using Ray Client, checkpoint directories are not available on the local (client) node, so we cannot easily construct `Checkpoint` objects. (Previously, the `TrialCheckpoint` object held a reference to the location even if it was not locally available.) There are ongoing discussions on how to resolve this in the future; for now, we print an error when such a checkpoint is requested.
Depends on #25805
Signed-off-by: Kai Fricke <kai@anyscale.com>
Parent: 923209895d
Commit: 753f5feaf4
13 changed files with 92 additions and 223 deletions
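For context, here is a minimal sketch of the migration this change implies for user code, based on the hunks below (the function names are illustrative, and `analysis` is assumed to be the `ExperimentAnalysis` returned by `tune.run()`): checkpoints are now AIR `Checkpoint` objects rather than path-like `TrialCheckpoint`s, so they are materialized with `as_directory()` instead of being joined with `os.path`, and a plain path can still be requested via `get_best_checkpoint(..., return_path=True)`.

import os

import xgboost as xgb


def get_best_model_checkpoint(analysis):
    # `analysis` is assumed to be the ExperimentAnalysis returned by tune.run().
    best_bst = xgb.Booster()

    # Old (TrialCheckpoint behaved like a filesystem path):
    #     best_bst.load_model(os.path.join(analysis.best_checkpoint, "model.xgb"))

    # New: materialize the AIR Checkpoint as a local directory first.
    with analysis.best_checkpoint.as_directory() as best_checkpoint_dir:
        best_bst.load_model(os.path.join(best_checkpoint_dir, "model.xgb"))
    return best_bst


def get_best_checkpoint_path(analysis):
    # Plain path/URI instead of a Checkpoint object; under Ray Client this
    # path is not guaranteed to exist on the local (client) node.
    return analysis.get_best_checkpoint(analysis.best_trial, return_path=True)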
@@ -529,7 +529,8 @@
     "\n",
     "def get_best_model_checkpoint(analysis):\n",
     "    best_bst = xgb.Booster()\n",
-    "    best_bst.load_model(os.path.join(analysis.best_checkpoint, \"model.xgb\"))\n",
+    "    with analysis.best_checkpoint.as_directory() as best_checkpoint_dir:\n",
+    "        best_bst.load_model(os.path.join(best_checkpoint_dir, \"model.xgb\"))\n",
     "    accuracy = 1.0 - analysis.best_result[\"eval-error\"]\n",
     "    print(f\"Best model parameters: {analysis.best_config}\")\n",
     "    print(f\"Best model total accuracy: {accuracy:.4f}\")\n",
@@ -574,7 +575,7 @@
     "        type=str,\n",
     "        default=None,\n",
     "        required=False,\n",
-    "        help=\"The address of server to connect to if using \" \"Ray Client.\",\n",
+    "        help=\"The address of server to connect to if using Ray Client.\",\n",
     "    )\n",
     "    args, _ = parser.parse_known_args()\n",
     "\n",
@@ -7,11 +7,11 @@ from pathlib import Path
 from typing import Any, Dict, List, Optional, Tuple, Union
 
 from ray.air.checkpoint import Checkpoint
-from ray.tune.cloud import TrialCheckpoint
 from ray.tune.syncer import SyncConfig
 from ray.tune.utils import flatten_dict
 from ray.tune.utils.serialization import TuneFunctionDecoder
 from ray.tune.utils.util import is_nan_or_inf, is_nan
+from ray.util import log_once
 
 try:
     import pandas as pd
@@ -115,10 +115,6 @@ class ExperimentAnalysis:
 
         self._sync_config = sync_config
 
-        # If True, will return a legacy TrialCheckpoint class.
-        # If False, will just return a Checkpoint class.
-        self._legacy_checkpoint = True
-
     def _parse_cloud_path(self, local_path: str):
         """Convert local path into cloud storage path"""
         if not self._sync_config or not self._sync_config.upload_dir:
@@ -451,8 +447,12 @@ class ExperimentAnalysis:
             raise ValueError("trial should be a string or a Trial instance.")
 
     def get_best_checkpoint(
-        self, trial: Trial, metric: Optional[str] = None, mode: Optional[str] = None
-    ) -> Optional[Checkpoint]:
+        self,
+        trial: Trial,
+        metric: Optional[str] = None,
+        mode: Optional[str] = None,
+        return_path: bool = False,
+    ) -> Optional[Union[Checkpoint, str]]:
         """Gets best persistent checkpoint path of provided trial.
 
         Any checkpoints with an associated metric value of ``nan`` will be filtered out.
@@ -463,9 +463,14 @@ class ExperimentAnalysis:
                 "training_iteration" is used by default if no value was
                 passed to ``self.default_metric``.
             mode: One of [min, max]. Defaults to ``self.default_mode``.
+            return_path: If True, only returns the path (and not the
+                ``Checkpoint`` object). If using Ray client, it is not
+                guaranteed that this path is available on the local
+                (client) node. Can also contain a cloud URI.
 
         Returns:
-            :class:`Checkpoint <ray.air.Checkpoint>` object.
+            :class:`Checkpoint <ray.air.Checkpoint>` object or string
+            if ``return_path=True``.
         """
         metric = metric or self.default_metric or TRAINING_ITERATION
         mode = self._validate_mode(mode)
@@ -487,23 +492,27 @@ class ExperimentAnalysis:
         best_path, best_metric = best_path_metrics[0]
         cloud_path = self._parse_cloud_path(best_path)
 
-        if self._legacy_checkpoint:
-            return TrialCheckpoint(local_path=best_path, cloud_path=cloud_path)
-
         if cloud_path:
             # Prefer cloud path over local path for downsteam processing
+            if return_path:
+                return cloud_path
             return Checkpoint.from_uri(cloud_path)
         elif os.path.exists(best_path):
+            if return_path:
+                return best_path
             return Checkpoint.from_directory(best_path)
         else:
-            logger.error(
-                f"No checkpoint locations for {trial} available on "
-                f"this node. To avoid this, you "
-                f"should enable checkpoint synchronization with the"
-                f"`sync_config` argument in Ray Tune. "
-                f"The checkpoint may be available on a different node - "
-                f"please check this location on worker nodes: {best_path}"
-            )
+            if log_once("checkpoint_not_available"):
+                logger.error(
+                    f"The requested checkpoint for trial {trial} is not available on "
+                    f"this node, most likely because you are using Ray client or "
+                    f"disabled checkpoint synchronization. To avoid this, enable "
+                    f"checkpoint synchronization to cloud storage by specifying a "
+                    f"`SyncConfig`. The checkpoint may be available on a different "
+                    f"node - please check this location on worker nodes: {best_path}"
+                )
+            if return_path:
+                return best_path
             return None
 
     def get_all_configs(self, prefix: bool = False) -> Dict[str, Dict]:
@@ -1,172 +0,0 @@
-import os
-from typing import Optional
-
-from ray.air.checkpoint import (
-    Checkpoint,
-    _get_local_path,
-    _get_external_path,
-)
-from ray.util.annotations import Deprecated
-
-
-@Deprecated
-class _TrialCheckpoint(os.PathLike):
-    def __init__(
-        self, local_path: Optional[str] = None, cloud_path: Optional[str] = None
-    ):
-        self._local_path = local_path
-        self._cloud_path_tcp = cloud_path
-
-    @property
-    def local_path(self):
-        return self._local_path
-
-    @local_path.setter
-    def local_path(self, path: str):
-        self._local_path = path
-
-    @property
-    def cloud_path(self):
-        return self._cloud_path_tcp
-
-    @cloud_path.setter
-    def cloud_path(self, path: str):
-        self._cloud_path_tcp = path
-
-    # The following magic methods are implemented to keep backwards
-    # compatibility with the old path-based return values.
-    def __str__(self):
-        return self.local_path or self.cloud_path
-
-    def __fspath__(self):
-        return self.local_path
-
-    def __eq__(self, other):
-        if isinstance(other, str):
-            return self.local_path == other
-        elif isinstance(other, TrialCheckpoint):
-            return (
-                self.local_path == other.local_path
-                and self.cloud_path == other.cloud_path
-            )
-
-    def __add__(self, other):
-        if isinstance(other, str):
-            return self.local_path + other
-        raise NotImplementedError
-
-    def __radd__(self, other):
-        if isinstance(other, str):
-            return other + self.local_path
-        raise NotImplementedError
-
-    def __repr__(self):
-        return (
-            f"<TrialCheckpoint "
-            f"local_path={self.local_path}, "
-            f"cloud_path={self.cloud_path}"
-            f">"
-        )
-
-
-# Deprecated: Remove in Ray > 1.13
-@Deprecated
-class TrialCheckpoint(Checkpoint, _TrialCheckpoint):
-    def __init__(
-        self,
-        local_path: Optional[str] = None,
-        cloud_path: Optional[str] = None,
-    ):
-        _TrialCheckpoint.__init__(self)
-
-        # Checkpoint does not allow empty data, but TrialCheckpoint
-        # did. To keep backwards compatibility, we use a placeholder URI
-        # here, and manually set self._uri and self._local_dir later.
-        PLACEHOLDER = "s3://placeholder"
-        Checkpoint.__init__(self, uri=PLACEHOLDER)
-
-        # Reset local variables
-        self._uri = None
-        self._local_path = None
-
-        self._cloud_path_tcp = None
-        self._local_path_tcp = None
-
-        locations = set()
-        if local_path:
-            # Add _tcp to not conflict with Checkpoint._local_path
-            self._local_path_tcp = local_path
-            if os.path.exists(local_path):
-                self._local_path = local_path
-            locations.add(local_path)
-        if cloud_path:
-            self._cloud_path_tcp = cloud_path
-            self._uri = cloud_path
-            locations.add(cloud_path)
-        self._locations = locations
-
-    @property
-    def local_path(self):
-        local_path = _get_local_path(self._local_path)
-        if not local_path:
-            for candidate in self._locations:
-                local_path = _get_local_path(candidate)
-                if local_path:
-                    break
-        return local_path or self._local_path_tcp
-
-    @local_path.setter
-    def local_path(self, path: str):
-        self._local_path = path
-        if not path or not os.path.exists(path):
-            return
-        self._locations.add(path)
-
-    @property
-    def cloud_path(self):
-        cloud_path = _get_external_path(self._uri)
-        if not cloud_path:
-            for candidate in self._locations:
-                cloud_path = _get_external_path(candidate)
-                if cloud_path:
-                    break
-        return cloud_path or self._cloud_path_tcp
-
-    @cloud_path.setter
-    def cloud_path(self, path: str):
-        self._cloud_path_tcp = path
-        if not self._uri:
-            self._uri = path
-        self._locations.add(path)
-
-    def download(
-        self,
-        cloud_path: Optional[str] = None,
-        local_path: Optional[str] = None,
-        overwrite: bool = False,
-    ) -> str:
-        # Deprecated: Remove whole class in Ray > 1.13
-        raise DeprecationWarning(
-            "`checkpoint.download()` is deprecated and will be removed in "
-            "the future. Please use `checkpoint.to_directory()` instead."
-        )
-
-    def upload(
-        self,
-        cloud_path: Optional[str] = None,
-        local_path: Optional[str] = None,
-        clean_before: bool = False,
-    ):
-        # Deprecated: Remove whole class in Ray > 1.13
-        raise DeprecationWarning(
-            "`checkpoint.upload()` is deprecated and will be removed in "
-            "the future. Please use `checkpoint.to_uri()` instead."
-        )
-
-    def save(self, path: Optional[str] = None, force_download: bool = False):
-        # Deprecated: Remove whole class in Ray > 1.13
-        raise DeprecationWarning(
-            "`checkpoint.save()` is deprecated and will be removed in "
-            "the future. Please use `checkpoint.to_directory()` or"
-            "`checkpoint.to_uri()` instead."
-        )
@@ -68,14 +68,16 @@ def train_convnet(config):
 
 def test_best_model(analysis):
     """Test the best model given output of tune.run"""
-    best_checkpoint_path = analysis.best_checkpoint
-    best_model = ConvNet()
-    best_checkpoint = torch.load(os.path.join(best_checkpoint_path, "checkpoint.pt"))
-    best_model.load_state_dict(best_checkpoint["model_state_dict"])
-    # Note that test only runs on a small random set of the test data, thus the
-    # accuracy may be different from metrics shown in tuning process.
-    test_acc = test(best_model, get_data_loaders()[1])
-    print("best model accuracy: ", test_acc)
+    with analysis.best_checkpoint.as_directory() as best_checkpoint_path:
+        best_model = ConvNet()
+        best_checkpoint = torch.load(
+            os.path.join(best_checkpoint_path, "checkpoint.pt")
+        )
+        best_model.load_state_dict(best_checkpoint["model_state_dict"])
+        # Note that test only runs on a small random set of the test data, thus the
+        # accuracy may be different from metrics shown in tuning process.
+        test_acc = test(best_model, get_data_loaders()[1])
+        print("best model accuracy: ", test_acc)
 
 
 if __name__ == "__main__":
@@ -62,7 +62,8 @@ def train_breast_cancer_cv(config: dict):
 
 def get_best_model_checkpoint(analysis):
     best_bst = xgb.Booster()
-    best_bst.load_model(os.path.join(analysis.best_checkpoint, "model.xgb"))
+    with analysis.best_checkpoint.as_directory() as checkpoint_dir:
+        best_bst.load_model(os.path.join(checkpoint_dir, "model.xgb"))
     accuracy = 1.0 - analysis.best_result["test-error"]
     print(f"Best model parameters: {analysis.best_config}")
     print(f"Best model total accuracy: {accuracy:.4f}")
@@ -167,7 +167,7 @@ class ExperimentAnalysisSuite(unittest.TestCase):
         best_checkpoint = self.ea.get_best_checkpoint(
             best_trial, self.metric, mode="max"
         )
-        assert expected_path == best_checkpoint
+        assert expected_path == best_checkpoint._local_path
 
     def testGetBestCheckpointNan(self):
         """Tests if nan values are excluded from best checkpoint."""
@@ -196,7 +196,7 @@ class ExperimentAnalysisSuite(unittest.TestCase):
             ],
             key=lambda x: x[1],
         )[0]
-        assert best_checkpoint == expected_checkpoint_no_nan
+        assert best_checkpoint._local_path == expected_checkpoint_no_nan
 
     def testGetLastCheckpoint(self):
         # one more experiment with 2 iterations
@@ -213,7 +213,7 @@ class ExperimentAnalysisSuite(unittest.TestCase):
         )
 
         # check if it's loaded correctly
-        last_checkpoint = new_ea.get_last_checkpoint().local_path
+        last_checkpoint = new_ea.get_last_checkpoint()._local_path
         assert self.test_path in last_checkpoint
         assert "checkpoint_000002" in last_checkpoint
 
@@ -31,7 +31,6 @@ def test_result_grid(ray_start_2_cpus):
                 f.write(json.dumps({"step": 0}))
 
     analysis = tune.run(f, config={"a": 1})
-    analysis._legacy_checkpoint = False
     result_grid = ResultGrid(analysis)
     result = result_grid[0]
     assert isinstance(result.checkpoint, Checkpoint)
@@ -98,7 +97,6 @@ def test_result_grid_no_checkpoint(ray_start_2_cpus):
         pass
 
     analysis = tune.run(f)
-    analysis._legacy_checkpoint = False
     result_grid = ResultGrid(analysis)
     result = result_grid[0]
     assert result.checkpoint is None
@@ -213,7 +211,6 @@ def test_result_grid_df(ray_start_2_cpus):
         tune.report(metric=config["nested"]["param"] * 3)
 
     analysis = tune.run(f, config={"nested": {"param": tune.grid_search([1, 2])}})
-    analysis._legacy_checkpoint = False
     result_grid = ResultGrid(analysis)
 
     assert len(result_grid) == 2
@@ -60,7 +60,7 @@ class TrainableUtilTest(unittest.TestCase):
             default_mode="max",
         )
         df = a.dataframe()
-        checkpoint_dir = a.get_best_checkpoint(df["logdir"].iloc[0]).local_path
+        checkpoint_dir = a.get_best_checkpoint(df["logdir"].iloc[0])._local_path
         assert checkpoint_dir.endswith("/checkpoint_000001/")
 
     def testFindCheckpointDir(self):
@@ -17,7 +17,6 @@ from ray.train.torch import TorchTrainer
 from ray.train.trainer import BaseTrainer
 from ray.train.xgboost import XGBoostTrainer
 from ray.tune import Callback, TuneError
-from ray.tune.cloud import TrialCheckpoint
 from ray.tune.result import DEFAULT_RESULTS_DIR
 from ray.tune.tune_config import TuneConfig
 from ray.tune.tuner import Tuner
@@ -128,7 +127,6 @@ class TunerTest(unittest.TestCase):
             _tuner_kwargs={"max_concurrent_trials": 1},
         )
         results = tuner.fit()
-        assert not isinstance(results.get_best_result().checkpoint, TrialCheckpoint)
         assert len(results) == 4
 
     def test_tuner_with_xgboost_trainer_driver_fail_and_resume(self):
@@ -516,6 +516,13 @@ class FunctionTrainable(Trainable):
         # as a new checkpoint.
         self._status_reporter.set_checkpoint(checkpoint, is_new=False)
 
+    def _restore_from_checkpoint_obj(self, checkpoint: Checkpoint):
+        self.temp_checkpoint_dir = FuncCheckpointUtil.mk_temp_checkpoint_dir(
+            self.logdir
+        )
+        checkpoint.to_directory(self.temp_checkpoint_dir)
+        self.restore(self.temp_checkpoint_dir)
+
     def restore_from_object(self, obj):
         self.temp_checkpoint_dir = FuncCheckpointUtil.mk_temp_checkpoint_dir(
             self.logdir
@@ -17,7 +17,6 @@ from ray.air.checkpoint import (
     Checkpoint,
     _DICT_CHECKPOINT_ADDITIONAL_FILE_KEY,
 )
-from ray.tune.cloud import TrialCheckpoint
 from ray.tune.resources import Resources
 from ray.tune.result import (
     DEBUG_METRICS,
@@ -553,7 +552,18 @@ class Trainable:
             shutil.rmtree(temp_container_dir)
         return obj_ref
 
-    def restore(self, checkpoint_path: str, checkpoint_node_ip: Optional[str] = None):
+    def _restore_from_checkpoint_obj(self, checkpoint: Checkpoint):
+        with checkpoint.as_directory() as converted_checkpoint_path:
+            return self.restore(
+                checkpoint_path=converted_checkpoint_path,
+                checkpoint_node_ip=None,
+            )
+
+    def restore(
+        self,
+        checkpoint_path: Union[str, Checkpoint],
+        checkpoint_node_ip: Optional[str] = None,
+    ):
         """Restores training state from a given model checkpoint.
 
         These checkpoints are returned from calls to save().
@@ -585,9 +595,9 @@ class Trainable:
                 on cloud storage.
 
         """
-        # Ensure TrialCheckpoints are converted
-        if isinstance(checkpoint_path, TrialCheckpoint):
-            checkpoint_path = checkpoint_path.local_path
+        # Ensure Checkpoints are converted
+        if isinstance(checkpoint_path, Checkpoint):
+            return self._restore_from_checkpoint_obj(checkpoint_path)
 
         if not self._maybe_load_from_cloud(checkpoint_path) and (
             # If a checkpoint source IP is given
@@ -662,15 +672,15 @@ class Trainable:
         with checkpoint.as_directory() as checkpoint_path:
             self.restore(checkpoint_path)
 
-    def delete_checkpoint(self, checkpoint_path: str):
+    def delete_checkpoint(self, checkpoint_path: Union[str, Checkpoint]):
         """Deletes local copy of checkpoint.
 
         Args:
             checkpoint_path: Path to checkpoint.
         """
-        # Ensure TrialCheckpoints are converted
-        if isinstance(checkpoint_path, TrialCheckpoint):
-            checkpoint_path = checkpoint_path.local_path
+        # Ensure Checkpoints are converted
+        if isinstance(checkpoint_path, Checkpoint) and checkpoint_path._local_path:
+            checkpoint_path = checkpoint_path._local_path
 
         try:
             checkpoint_dir = TrainableUtil.find_checkpoint_dir(checkpoint_path)
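As a usage note (a sketch inferred from the `Trainable` hunks above, not code from this PR): `Trainable.restore()` now accepts either a path string or an AIR `Checkpoint`, and `Checkpoint` arguments are routed through the new `_restore_from_checkpoint_obj()` helper, which materializes the checkpoint with `as_directory()` before restoring. The `trainable` instance and the local path below are illustrative assumptions.

from ray.air.checkpoint import Checkpoint

# `trainable` is assumed to be an instance of a Trainable subclass.
ckpt = Checkpoint.from_directory("/tmp/my_checkpoint")  # illustrative local path

trainable.restore(ckpt)                  # Checkpoint object: converted internally
trainable.restore("/tmp/my_checkpoint")  # plain paths keep working as before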
@@ -14,6 +14,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 import ray
+from ray.air import Checkpoint
 from ray.tune.result import NODE_IP
 from ray.util import log_once
 from ray.util.annotations import Deprecated, DeveloperAPI, PublicAPI
 from ray.util.ml_utils.util import is_nan
 
@@ -129,7 +130,20 @@ class _TrackedCheckpoint:
             checkpoint_data = ray.get(checkpoint_data)
 
         if isinstance(checkpoint_data, str):
-            checkpoint_dir = TrainableUtil.find_checkpoint_dir(checkpoint_data)
+            try:
+                checkpoint_dir = TrainableUtil.find_checkpoint_dir(checkpoint_data)
+            except FileNotFoundError:
+                if log_once("checkpoint_not_available"):
+                    logger.error(
+                        f"The requested checkpoint is not available on this node, "
+                        f"most likely because you are using Ray client or disabled "
+                        f"checkpoint synchronization. To avoid this, enable checkpoint "
+                        f"synchronization to cloud storage by specifying a "
+                        f"`SyncConfig`. The checkpoint may be available on a different "
+                        f"node - please check this location on worker nodes: "
+                        f"{checkpoint_data}"
+                    )
+                return None
             checkpoint = Checkpoint.from_directory(checkpoint_dir)
         elif isinstance(checkpoint_data, bytes):
             checkpoint = Checkpoint.from_bytes(checkpoint_data)
@@ -256,8 +256,10 @@ if __name__ == "__main__":
     analysis = train_mnist(args.smoke_test, num_workers, use_gpu)
 
     print("Retrieving best model.")
-    best_checkpoint = analysis.best_checkpoint.local_path
-    model = get_remote_model(best_checkpoint)
+    best_checkpoint_path = analysis.get_best_checkpoint(
+        analysis.best_trial, return_path=True
+    )
+    model = get_remote_model(best_checkpoint_path)
 
     print("Setting up Serve.")
     setup_serve(model, use_gpu)