Mirror of https://github.com/vale981/ray (synced 2025-03-06 02:21:39 -05:00)
[Train] Move load_checkpoint to utils (#25940)
Moves load_checkpoint methods from trainer files to util files for consistency and better modularity.
Parent: 3d6a5450c9
Commit: b7d4ae541d
28 changed files with 316 additions and 284 deletions
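Call sites are unaffected in practice: each framework package still re-exports ``load_checkpoint`` from its ``__init__``, so only imports of the defining module change. A minimal, hedged sketch of loading a model back from a checkpoint after this change (the checkpoint directory path is an assumed placeholder, not taken from this PR):

# Hedged sketch: restore an XGBoost model and its AIR preprocessor from a checkpoint
# directory produced by an earlier XGBoostTrainer run ("my_run/checkpoint_000005" is
# an assumed path for illustration).
from ray.air.checkpoint import Checkpoint
from ray.train.xgboost.utils import load_checkpoint  # new defining module after this PR
# from ray.train.xgboost import load_checkpoint      # package re-export, unchanged for callers

checkpoint = Checkpoint.from_directory("my_run/checkpoint_000005")
booster, preprocessor = load_checkpoint(checkpoint)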
@@ -1,6 +1,8 @@
from ray.train.huggingface.huggingface_predictor import HuggingFacePredictor
from ray.train.huggingface.huggingface_trainer import (
    HuggingFaceTrainer,
)
from ray.train.huggingface.utils import (
    load_checkpoint,
)
@@ -5,23 +5,19 @@ import tempfile
import warnings
from distutils.version import LooseVersion
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Type, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Type

import torch
import transformers
import transformers.modeling_utils
import transformers.trainer
import transformers.training_args
from torch.utils.data import Dataset as TorchDataset
from transformers.trainer import TRAINING_ARGS_NAME, WEIGHTS_NAME

from ray import train
from ray.air import session
from ray.air._internal.checkpointing import (
    load_preprocessor_from_dir,
    save_preprocessor_to_dir,
)
from ray.air._internal.torch_utils import load_torch_model
from ray.air.checkpoint import Checkpoint
from ray.air.config import DatasetConfig, RunConfig, ScalingConfig
from ray.train.constants import (

@@ -31,7 +27,7 @@ from ray.train.constants import (
    TUNE_CHECKPOINT_ID,
)
from ray.train.data_parallel_trainer import _DataParallelCheckpointManager
from ray.train.huggingface.huggingface_utils import (
from ray.train.huggingface._huggingface_utils import (
    CHECKPOINT_PATH_ON_NODE_KEY,
    NODE_IP_KEY,
    TrainReportCallback,
@@ -407,65 +403,6 @@ class HuggingFaceTrainer(TorchTrainer):
        return ret


def load_checkpoint(
    checkpoint: Checkpoint,
    model: Union[Type[transformers.modeling_utils.PreTrainedModel], torch.nn.Module],
    tokenizer: Optional[Type[transformers.PreTrainedTokenizer]] = None,
    *,
    tokenizer_kwargs: Optional[Dict[str, Any]] = None,
    **pretrained_model_kwargs,
) -> Tuple[
    Union[transformers.modeling_utils.PreTrainedModel, torch.nn.Module],
    transformers.training_args.TrainingArguments,
    Optional[transformers.PreTrainedTokenizer],
    Optional["Preprocessor"],
]:
    """Load a Checkpoint from ``HuggingFaceTrainer``.

    Args:
        checkpoint: The checkpoint to load the model and
            preprocessor from. It is expected to be from the result of a
            ``HuggingFaceTrainer`` run.
        model: Either a ``transformers.PreTrainedModel`` class
            (e.g. ``AutoModelForCausalLM``), or a PyTorch model to load the
            weights to. This should be the same model used for training.
        tokenizer: A ``transformers.PreTrainedTokenizer`` class to load
            the model tokenizer to. If not specified, the tokenizer will
            not be loaded. Will throw an exception if specified but no
            tokenizer was found in the checkpoint.
        tokenizer_kwargs: Dict of kwargs to pass to ``tokenizer.from_pretrained``
            call. Ignored if ``tokenizer`` is None.
        **pretrained_model_kwargs: Kwargs to pass to ``model.from_pretrained``
            call. Ignored if ``model`` is not a ``transformers.PreTrainedModel``
            class.

    Returns:
        The model, ``TrainingArguments``, tokenizer and AIR preprocessor
        contained within. Those can be used to initialize a ``transformers.Trainer``
        object locally.
    """
    tokenizer_kwargs = tokenizer_kwargs or {}
    with checkpoint.as_directory() as checkpoint_path:
        preprocessor = load_preprocessor_from_dir(checkpoint_path)
        if isinstance(model, torch.nn.Module):
            state_dict = torch.load(
                os.path.join(checkpoint_path, WEIGHTS_NAME), map_location="cpu"
            )
            model = load_torch_model(saved_model=state_dict, model_definition=model)
        else:
            model = model.from_pretrained(checkpoint_path, **pretrained_model_kwargs)
        if tokenizer:
            tokenizer = tokenizer.from_pretrained(checkpoint_path, **tokenizer_kwargs)
        training_args_path = os.path.join(checkpoint_path, TRAINING_ARGS_NAME)
        if os.path.exists(training_args_path):
            with open(training_args_path, "rb") as f:
                training_args = torch.load(f, map_location="cpu")
        else:
            training_args = None
        return model, training_args, tokenizer, preprocessor


def _huggingface_train_loop_per_worker(config):
    """Per-worker training loop for HuggingFace Transformers."""
    trainer_init_per_worker = config.pop("_trainer_init_per_worker")
python/ray/train/huggingface/utils.py (new file, 77 lines)
@@ -0,0 +1,77 @@
import os
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Type, Union

import torch
import transformers
import transformers.modeling_utils
import transformers.trainer
import transformers.training_args
from transformers.trainer import TRAINING_ARGS_NAME, WEIGHTS_NAME

from ray.air._internal.checkpointing import (
    load_preprocessor_from_dir,
)
from ray.air._internal.torch_utils import load_torch_model
from ray.air.checkpoint import Checkpoint

if TYPE_CHECKING:
    from ray.data.preprocessor import Preprocessor


def load_checkpoint(
    checkpoint: Checkpoint,
    model: Union[Type[transformers.modeling_utils.PreTrainedModel], torch.nn.Module],
    tokenizer: Optional[Type[transformers.PreTrainedTokenizer]] = None,
    *,
    tokenizer_kwargs: Optional[Dict[str, Any]] = None,
    **pretrained_model_kwargs,
) -> Tuple[
    Union[transformers.modeling_utils.PreTrainedModel, torch.nn.Module],
    transformers.training_args.TrainingArguments,
    Optional[transformers.PreTrainedTokenizer],
    Optional["Preprocessor"],
]:
    """Load a Checkpoint from ``HuggingFaceTrainer``.

    Args:
        checkpoint: The checkpoint to load the model and
            preprocessor from. It is expected to be from the result of a
            ``HuggingFaceTrainer`` run.
        model: Either a ``transformers.PreTrainedModel`` class
            (e.g. ``AutoModelForCausalLM``), or a PyTorch model to load the
            weights to. This should be the same model used for training.
        tokenizer: A ``transformers.PreTrainedTokenizer`` class to load
            the model tokenizer to. If not specified, the tokenizer will
            not be loaded. Will throw an exception if specified but no
            tokenizer was found in the checkpoint.
        tokenizer_kwargs: Dict of kwargs to pass to ``tokenizer.from_pretrained``
            call. Ignored if ``tokenizer`` is None.
        **pretrained_model_kwargs: Kwargs to pass to ``model.from_pretrained``
            call. Ignored if ``model`` is not a ``transformers.PreTrainedModel``
            class.

    Returns:
        The model, ``TrainingArguments``, tokenizer and AIR preprocessor
        contained within. Those can be used to initialize a ``transformers.Trainer``
        object locally.
    """
    tokenizer_kwargs = tokenizer_kwargs or {}
    with checkpoint.as_directory() as checkpoint_path:
        preprocessor = load_preprocessor_from_dir(checkpoint_path)
        if isinstance(model, torch.nn.Module):
            state_dict = torch.load(
                os.path.join(checkpoint_path, WEIGHTS_NAME), map_location="cpu"
            )
            model = load_torch_model(saved_model=state_dict, model_definition=model)
        else:
            model = model.from_pretrained(checkpoint_path, **pretrained_model_kwargs)
        if tokenizer:
            tokenizer = tokenizer.from_pretrained(checkpoint_path, **tokenizer_kwargs)
        training_args_path = os.path.join(checkpoint_path, TRAINING_ARGS_NAME)
        if os.path.exists(training_args_path):
            with open(training_args_path, "rb") as f:
                training_args = torch.load(f, map_location="cpu")
        else:
            training_args = None
        return model, training_args, tokenizer, preprocessor
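For reference, a hedged usage sketch of the function above: rebuild a ``transformers.Trainer`` locally from a ``HuggingFaceTrainer`` checkpoint. The checkpoint path and the ``AutoModelForCausalLM``/``AutoTokenizer`` classes are illustrative assumptions, not part of this diff.

import transformers

from ray.air.checkpoint import Checkpoint
from ray.train.huggingface.utils import load_checkpoint

# Assumed directory produced by an earlier HuggingFaceTrainer run.
checkpoint = Checkpoint.from_directory("hf_run/checkpoint_000002")
model, training_args, tokenizer, preprocessor = load_checkpoint(
    checkpoint,
    model=transformers.AutoModelForCausalLM,
    tokenizer=transformers.AutoTokenizer,
)
# The returned pieces can seed a local transformers.Trainer, as the docstring notes.
trainer = transformers.Trainer(model=model, args=training_args, tokenizer=tokenizer)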
@@ -1,6 +1,6 @@
from ray.train.lightgbm.lightgbm_predictor import LightGBMPredictor
from ray.train.lightgbm.lightgbm_trainer import LightGBMTrainer, load_checkpoint
from ray.train.lightgbm.utils import to_air_checkpoint
from ray.train.lightgbm.lightgbm_trainer import LightGBMTrainer
from ray.train.lightgbm.utils import to_air_checkpoint, load_checkpoint

__all__ = [
    "LightGBMPredictor",
@@ -5,7 +5,7 @@ import pandas as pd

from ray.air.checkpoint import Checkpoint
from ray.air.constants import TENSOR_COLUMN_NAME
from ray.train.lightgbm.lightgbm_trainer import load_checkpoint
from ray.train.lightgbm.utils import load_checkpoint
from ray.train.predictor import Predictor

if TYPE_CHECKING:
@@ -1,11 +1,9 @@
from typing import Dict, Any, Optional, Tuple, TYPE_CHECKING
import os

from ray.air.checkpoint import Checkpoint
from ray.train.gbdt_trainer import GBDTTrainer
from ray.air._internal.checkpointing import load_preprocessor_from_dir
from ray.util.annotations import PublicAPI
from ray.train.constants import MODEL_KEY
from ray.train.lightgbm.utils import load_checkpoint

import lightgbm
import lightgbm_ray
@@ -85,25 +83,3 @@ class LightGBMTrainer(GBDTTrainer):
        self, checkpoint: Checkpoint
    ) -> Tuple[lightgbm.Booster, Optional["Preprocessor"]]:
        return load_checkpoint(checkpoint)


def load_checkpoint(
    checkpoint: Checkpoint,
) -> Tuple[lightgbm.Booster, Optional["Preprocessor"]]:
    """Load a Checkpoint from ``LightGBMTrainer``.

    Args:
        checkpoint: The checkpoint to load the model and
            preprocessor from. It is expected to be from the result of a
            ``LightGBMTrainer`` run.

    Returns:
        The model and AIR preprocessor contained within.
    """
    with checkpoint.as_directory() as checkpoint_path:
        lgbm_model = lightgbm.Booster(
            model_file=os.path.join(checkpoint_path, MODEL_KEY)
        )
        preprocessor = load_preprocessor_from_dir(checkpoint_path)

    return lgbm_model, preprocessor
@@ -1,9 +1,12 @@
import os
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Optional, Tuple

import lightgbm

from ray.air._internal.checkpointing import save_preprocessor_to_dir
from ray.air._internal.checkpointing import (
    save_preprocessor_to_dir,
    load_preprocessor_from_dir,
)
from ray.air.checkpoint import Checkpoint
from ray.air.constants import MODEL_KEY
@@ -34,3 +37,25 @@ def to_air_checkpoint(
    checkpoint = Checkpoint.from_directory(path)

    return checkpoint


def load_checkpoint(
    checkpoint: Checkpoint,
) -> Tuple[lightgbm.Booster, Optional["Preprocessor"]]:
    """Load a Checkpoint from ``LightGBMTrainer``.

    Args:
        checkpoint: The checkpoint to load the model and
            preprocessor from. It is expected to be from the result of a
            ``LightGBMTrainer`` run.

    Returns:
        The model and AIR preprocessor contained within.
    """
    with checkpoint.as_directory() as checkpoint_path:
        lgbm_model = lightgbm.Booster(
            model_file=os.path.join(checkpoint_path, MODEL_KEY)
        )
        preprocessor = load_preprocessor_from_dir(checkpoint_path)

    return lgbm_model, preprocessor
@@ -1,4 +1,5 @@
from ray.train.rl.rl_predictor import RLPredictor
from ray.train.rl.rl_trainer import RLTrainer, load_checkpoint
from ray.train.rl.rl_trainer import RLTrainer
from ray.train.rl.utils import load_checkpoint

__all__ = ["RLPredictor", "RLTrainer", "load_checkpoint"]
@@ -8,7 +8,7 @@ from ray.air.constants import TENSOR_COLUMN_NAME
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.typing import EnvType
from ray.train.predictor import Predictor
from ray.train.rl.rl_trainer import load_checkpoint
from ray.train.rl.utils import load_checkpoint

if TYPE_CHECKING:
    from ray.data.preprocessor import Preprocessor
@@ -1,17 +1,15 @@
import inspect
import os
from typing import Optional, Dict, Tuple, Type, Union, Callable, Any, TYPE_CHECKING
from typing import Optional, Dict, Type, Union, Callable, Any, TYPE_CHECKING

import ray.cloudpickle as cpickle
from ray.air.checkpoint import Checkpoint
from ray.air.config import ScalingConfig, RunConfig
from ray.train.trainer import BaseTrainer, GenDataset
from ray.air._internal.checkpointing import (
    load_preprocessor_from_dir,
    save_preprocessor_to_dir,
)
from ray.rllib.algorithms.algorithm import Algorithm as RLlibAlgo
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.typing import PartialAlgorithmConfigDict, EnvType
from ray.tune import Trainable, PlacementGroupFactory
from ray.tune.logger import Logger

@@ -20,13 +18,11 @@ from ray.tune.resources import Resources
from ray.tune.syncer import Syncer
from ray.util.annotations import PublicAPI
from ray.util.ml_utils.dict import merge_dicts
from ray.train.rl.utils import RL_TRAINER_CLASS_FILE, RL_CONFIG_FILE

if TYPE_CHECKING:
    from ray.data.preprocessor import Preprocessor

RL_TRAINER_CLASS_FILE = "trainer_class.pkl"
RL_CONFIG_FILE = "config.pkl"


@PublicAPI(stability="alpha")
class RLTrainer(BaseTrainer):

@@ -253,64 +249,3 @@ class RLTrainer(BaseTrainer):

        AIRRLTrainer.__name__ = f"AIR{rllib_trainer.__name__}"
        return AIRRLTrainer


def load_checkpoint(
    checkpoint: Checkpoint,
    env: Optional[EnvType] = None,
) -> Tuple[Policy, Optional["Preprocessor"]]:
    """Load a Checkpoint from ``RLTrainer``.

    Args:
        checkpoint: The checkpoint to load the policy and
            preprocessor from. It is expected to be from the result of a
            ``RLTrainer`` run.
        env: Optional environment to instantiate the trainer with. If not given,
            it is parsed from the saved trainer configuration instead.

    Returns:
        The policy and AIR preprocessor contained within.
    """
    with checkpoint.as_directory() as checkpoint_path:
        trainer_class_path = os.path.join(checkpoint_path, RL_TRAINER_CLASS_FILE)
        config_path = os.path.join(checkpoint_path, RL_CONFIG_FILE)

        if not os.path.exists(trainer_class_path):
            raise ValueError(
                f"RLPredictor only works with checkpoints created by "
                f"RLTrainer. The checkpoint you specified is missing the "
                f"`{RL_TRAINER_CLASS_FILE}` file."
            )

        if not os.path.exists(config_path):
            raise ValueError(
                f"RLPredictor only works with checkpoints created by "
                f"RLTrainer. The checkpoint you specified is missing the "
                f"`{RL_CONFIG_FILE}` file."
            )

        with open(trainer_class_path, "rb") as fp:
            trainer_cls = cpickle.load(fp)

        with open(config_path, "rb") as fp:
            config = cpickle.load(fp)

        checkpoint_data_path = None
        for file in os.listdir(checkpoint_path):
            if file.startswith("checkpoint") and not file.endswith(".tune_metadata"):
                checkpoint_data_path = os.path.join(checkpoint_path, file)

        if not checkpoint_data_path:
            raise ValueError(
                f"Could not find checkpoint data in RLlib checkpoint. "
                f"Found files: {list(os.listdir(checkpoint_path))}"
            )

        preprocessor = load_preprocessor_from_dir(checkpoint_path)

        config.get("evaluation_config", {}).pop("in_evaluation", None)
        trainer = trainer_cls(config=config, env=env)
        trainer.restore(checkpoint_data_path)

        policy = trainer.get_policy()
        return policy, preprocessor
python/ray/train/rl/utils.py (new file, 77 lines)
@@ -0,0 +1,77 @@
import os
from typing import Optional, Tuple, TYPE_CHECKING

import ray.cloudpickle as cpickle
from ray.air.checkpoint import Checkpoint
from ray.air._internal.checkpointing import (
    load_preprocessor_from_dir,
)
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.typing import EnvType

if TYPE_CHECKING:
    from ray.data.preprocessor import Preprocessor

RL_TRAINER_CLASS_FILE = "trainer_class.pkl"
RL_CONFIG_FILE = "config.pkl"


def load_checkpoint(
    checkpoint: Checkpoint,
    env: Optional[EnvType] = None,
) -> Tuple[Policy, Optional["Preprocessor"]]:
    """Load a Checkpoint from ``RLTrainer``.

    Args:
        checkpoint: The checkpoint to load the policy and
            preprocessor from. It is expected to be from the result of a
            ``RLTrainer`` run.
        env: Optional environment to instantiate the trainer with. If not given,
            it is parsed from the saved trainer configuration instead.

    Returns:
        The policy and AIR preprocessor contained within.
    """
    with checkpoint.as_directory() as checkpoint_path:
        trainer_class_path = os.path.join(checkpoint_path, RL_TRAINER_CLASS_FILE)
        config_path = os.path.join(checkpoint_path, RL_CONFIG_FILE)

        if not os.path.exists(trainer_class_path):
            raise ValueError(
                f"RLPredictor only works with checkpoints created by "
                f"RLTrainer. The checkpoint you specified is missing the "
                f"`{RL_TRAINER_CLASS_FILE}` file."
            )

        if not os.path.exists(config_path):
            raise ValueError(
                f"RLPredictor only works with checkpoints created by "
                f"RLTrainer. The checkpoint you specified is missing the "
                f"`{RL_CONFIG_FILE}` file."
            )

        with open(trainer_class_path, "rb") as fp:
            trainer_cls = cpickle.load(fp)

        with open(config_path, "rb") as fp:
            config = cpickle.load(fp)

        checkpoint_data_path = None
        for file in os.listdir(checkpoint_path):
            if file.startswith("checkpoint") and not file.endswith(".tune_metadata"):
                checkpoint_data_path = os.path.join(checkpoint_path, file)

        if not checkpoint_data_path:
            raise ValueError(
                f"Could not find checkpoint data in RLlib checkpoint. "
                f"Found files: {list(os.listdir(checkpoint_path))}"
            )

        preprocessor = load_preprocessor_from_dir(checkpoint_path)

        config.get("evaluation_config", {}).pop("in_evaluation", None)
        trainer = trainer_cls(config=config, env=env)
        trainer.restore(checkpoint_data_path)

        policy = trainer.get_policy()
        return policy, preprocessor
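A hedged usage sketch for the RL variant above: restore the RLlib ``Policy`` from an ``RLTrainer`` checkpoint and query it for one action. The checkpoint path, environment id, and observation are illustrative assumptions, not part of this diff.

import numpy as np

from ray.air.checkpoint import Checkpoint
from ray.train.rl.utils import load_checkpoint

# Assumed directory produced by an earlier RLTrainer run on CartPole.
checkpoint = Checkpoint.from_directory("rl_run/checkpoint_000010")
policy, preprocessor = load_checkpoint(checkpoint, env="CartPole-v1")

obs = np.zeros(4, dtype=np.float32)  # placeholder CartPole observation
action, _, _ = policy.compute_single_action(obs)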
@@ -1,5 +1,5 @@
from ray.train.sklearn.sklearn_predictor import SklearnPredictor
from ray.train.sklearn.sklearn_trainer import SklearnTrainer, load_checkpoint
from ray.train.sklearn.utils import to_air_checkpoint
from ray.train.sklearn.sklearn_trainer import SklearnTrainer
from ray.train.sklearn.utils import to_air_checkpoint, load_checkpoint

__all__ = ["SklearnPredictor", "SklearnTrainer", "load_checkpoint", "to_air_checkpoint"]
@@ -8,7 +8,7 @@ from ray.air.checkpoint import Checkpoint
from ray.air.constants import TENSOR_COLUMN_NAME
from ray.train.predictor import Predictor
from ray.train.sklearn._sklearn_utils import _set_cpu_params
from ray.train.sklearn.sklearn_trainer import load_checkpoint
from ray.train.sklearn.utils import load_checkpoint
from ray.util.joblib import register_ray

if TYPE_CHECKING:
@@ -4,7 +4,7 @@ import warnings
from collections import defaultdict
from time import time
from traceback import format_exc
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, Optional, Union, Tuple

import numpy as np
import pandas as pd

@@ -19,10 +19,8 @@ from sklearn.model_selection._validation import _check_multimetric_scoring, _sco
import ray.cloudpickle as cpickle
from ray import tune
from ray.air._internal.checkpointing import (
    load_preprocessor_from_dir,
    save_preprocessor_to_dir,
)
from ray.air.checkpoint import Checkpoint
from ray.air.config import RunConfig, ScalingConfig
from ray.train.constants import MODEL_KEY, TRAIN_DATASET_KEY
from ray.train.sklearn._sklearn_utils import _has_cpu_params, _set_cpu_params

@@ -434,25 +432,3 @@ class SklearnTrainer(BaseTrainer):
            "fit_time": fit_time,
        }
        tune.report(**results)


def load_checkpoint(
    checkpoint: Checkpoint,
) -> Tuple[BaseEstimator, Optional["Preprocessor"]]:
    """Load a Checkpoint from ``SklearnTrainer``.

    Args:
        checkpoint: The checkpoint to load the estimator and
            preprocessor from. It is expected to be from the result of a
            ``SklearnTrainer`` run.

    Returns:
        The estimator and AIR preprocessor contained within.
    """
    with checkpoint.as_directory() as checkpoint_path:
        estimator_path = os.path.join(checkpoint_path, MODEL_KEY)
        with open(estimator_path, "rb") as f:
            estimator = cpickle.load(f)
        preprocessor = load_preprocessor_from_dir(checkpoint_path)

    return estimator, preprocessor
@@ -1,10 +1,13 @@
import os
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Optional, Tuple

from sklearn.base import BaseEstimator

import ray.cloudpickle as cpickle
from ray.air._internal.checkpointing import save_preprocessor_to_dir
from ray.air._internal.checkpointing import (
    save_preprocessor_to_dir,
    load_preprocessor_from_dir,
)
from ray.air.checkpoint import Checkpoint
from ray.air.constants import MODEL_KEY
@@ -36,3 +39,25 @@ def to_air_checkpoint(
    checkpoint = Checkpoint.from_directory(path)

    return checkpoint


def load_checkpoint(
    checkpoint: Checkpoint,
) -> Tuple[BaseEstimator, Optional["Preprocessor"]]:
    """Load a Checkpoint from ``SklearnTrainer``.

    Args:
        checkpoint: The checkpoint to load the estimator and
            preprocessor from. It is expected to be from the result of a
            ``SklearnTrainer`` run.

    Returns:
        The estimator and AIR preprocessor contained within.
    """
    with checkpoint.as_directory() as checkpoint_path:
        estimator_path = os.path.join(checkpoint_path, MODEL_KEY)
        with open(estimator_path, "rb") as f:
            estimator = cpickle.load(f)
        preprocessor = load_preprocessor_from_dir(checkpoint_path)

    return estimator, preprocessor
@@ -10,9 +10,9 @@ except ModuleNotFoundError:

from ray.train.tensorflow.config import TensorflowConfig
from ray.train.tensorflow.tensorflow_predictor import TensorflowPredictor
from ray.train.tensorflow.tensorflow_trainer import TensorflowTrainer, load_checkpoint
from ray.train.tensorflow.tensorflow_trainer import TensorflowTrainer
from ray.train.tensorflow.train_loop_utils import prepare_dataset_shard
from ray.train.tensorflow.utils import to_air_checkpoint
from ray.train.tensorflow.utils import to_air_checkpoint, load_checkpoint

__all__ = [
    "TensorflowConfig",
@@ -1,9 +1,8 @@
from typing import Callable, Optional, Dict, Tuple, Type, Union, TYPE_CHECKING
import tensorflow as tf
from typing import Callable, Optional, Dict, Union, TYPE_CHECKING

from ray.train.tensorflow.config import TensorflowConfig
from ray.train.trainer import GenDataset
from ray.train.data_parallel_trainer import DataParallelTrainer, _load_checkpoint
from ray.train.data_parallel_trainer import DataParallelTrainer
from ray.air.config import ScalingConfig, RunConfig, DatasetConfig
from ray.air.checkpoint import Checkpoint
from ray.util import PublicAPI

@@ -185,27 +184,3 @@ class TensorflowTrainer(DataParallelTrainer):
            preprocessor=preprocessor,
            resume_from_checkpoint=resume_from_checkpoint,
        )


def load_checkpoint(
    checkpoint: Checkpoint,
    model: Union[Callable[[], tf.keras.Model], Type[tf.keras.Model], tf.keras.Model],
) -> Tuple[tf.keras.Model, Optional["Preprocessor"]]:
    """Load a Checkpoint from ``TensorflowTrainer``.

    Args:
        checkpoint: The checkpoint to load the model and
            preprocessor from. It is expected to be from the result of a
            ``TensorflowTrainer`` run.
        model: A callable that returns a TensorFlow Keras model
            to use, or an instantiated model.
            Model weights will be loaded from the checkpoint.

    Returns:
        The model with set weights and AIR preprocessor contained within.
    """
    model_weights, preprocessor = _load_checkpoint(checkpoint, "TensorflowTrainer")
    if isinstance(model, type) or callable(model):
        model = model()
    model.set_weights(model_weights)
    return model, preprocessor
@@ -1,9 +1,11 @@
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Optional, Union, Callable, Type, Tuple

import tensorflow as tf
from tensorflow import keras

from ray.air.checkpoint import Checkpoint
from ray.air.constants import MODEL_KEY, PREPROCESSOR_KEY
from ray.train.data_parallel_trainer import _load_checkpoint

if TYPE_CHECKING:
    from ray.data.preprocessor import Preprocessor

@@ -25,3 +27,27 @@ def to_air_checkpoint(
        {PREPROCESSOR_KEY: preprocessor, MODEL_KEY: model.get_weights()}
    )
    return checkpoint


def load_checkpoint(
    checkpoint: Checkpoint,
    model: Union[Callable[[], tf.keras.Model], Type[tf.keras.Model], tf.keras.Model],
) -> Tuple[tf.keras.Model, Optional["Preprocessor"]]:
    """Load a Checkpoint from ``TensorflowTrainer``.

    Args:
        checkpoint: The checkpoint to load the model and
            preprocessor from. It is expected to be from the result of a
            ``TensorflowTrainer`` run.
        model: A callable that returns a TensorFlow Keras model
            to use, or an instantiated model.
            Model weights will be loaded from the checkpoint.

    Returns:
        The model with set weights and AIR preprocessor contained within.
    """
    model_weights, preprocessor = _load_checkpoint(checkpoint, "TensorflowTrainer")
    if isinstance(model, type) or callable(model):
        model = model()
    model.set_weights(model_weights)
    return model, preprocessor
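A hedged sketch for the Keras variant above: because ``load_checkpoint`` only restores weights, the caller passes the same model-building function used for training. The builder and checkpoint path below are illustrative assumptions, not part of this diff.

import tensorflow as tf

from ray.air.checkpoint import Checkpoint
from ray.train.tensorflow.utils import load_checkpoint


def build_model() -> tf.keras.Model:
    # Assumed architecture; it must match the one used during training.
    return tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])


checkpoint = Checkpoint.from_directory("tf_run/checkpoint_000003")  # placeholder path
model, preprocessor = load_checkpoint(checkpoint, build_model)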
@@ -14,7 +14,7 @@ from transformers.trainer_callback import TrainerState
import ray.data
from ray.train.batch_predictor import BatchPredictor
from ray.train.huggingface import HuggingFacePredictor, HuggingFaceTrainer
from ray.train.huggingface.huggingface_utils import TrainReportCallback
from ray.train.huggingface._huggingface_utils import TrainReportCallback
from ray.train.tests._huggingface_data import train_data, validation_data

# 16 first rows of tokenized wikitext-2-raw-v1 training & validation
@@ -9,7 +9,7 @@ except ModuleNotFoundError:

from ray.train.torch.config import TorchConfig
from ray.train.torch.torch_predictor import TorchPredictor
from ray.train.torch.torch_trainer import TorchTrainer, load_checkpoint
from ray.train.torch.torch_trainer import TorchTrainer
from ray.train.torch.train_loop_utils import (
    TorchWorkerProfiler,
    accelerate,

@@ -20,7 +20,7 @@ from ray.train.torch.train_loop_utils import (
    prepare_model,
    prepare_optimizer,
)
from ray.train.torch.utils import to_air_checkpoint
from ray.train.torch.utils import to_air_checkpoint, load_checkpoint

__all__ = [
    "TorchTrainer",
@@ -7,7 +7,7 @@ import torch
from ray.air._internal.torch_utils import convert_pandas_to_torch_tensor
from ray.air.checkpoint import Checkpoint
from ray.train.predictor import DataBatchType, Predictor
from ray.train.torch.torch_trainer import load_checkpoint
from ray.train.torch.utils import load_checkpoint

if TYPE_CHECKING:
    from ray.data.preprocessor import Preprocessor
@@ -1,11 +1,8 @@
from typing import TYPE_CHECKING, Callable, Dict, Optional, Tuple, Union
from typing import TYPE_CHECKING, Callable, Dict, Optional, Union

import torch

from ray.air._internal.torch_utils import load_torch_model
from ray.air.checkpoint import Checkpoint
from ray.air.config import DatasetConfig, RunConfig, ScalingConfig
from ray.train.data_parallel_trainer import DataParallelTrainer, _load_checkpoint
from ray.train.data_parallel_trainer import DataParallelTrainer
from ray.train.torch.config import TorchConfig
from ray.train.trainer import GenDataset
from ray.util import PublicAPI

@@ -196,24 +193,3 @@ class TorchTrainer(DataParallelTrainer):
            preprocessor=preprocessor,
            resume_from_checkpoint=resume_from_checkpoint,
        )


def load_checkpoint(
    checkpoint: Checkpoint, model: Optional[torch.nn.Module] = None
) -> Tuple[torch.nn.Module, Optional["Preprocessor"]]:
    """Load a Checkpoint from ``TorchTrainer``.

    Args:
        checkpoint: The checkpoint to load the model and
            preprocessor from. It is expected to be from the result of a
            ``TorchTrainer`` run.
        model: If the checkpoint contains a model state dict, and not
            the model itself, then the state dict will be loaded to this
            ``model``.

    Returns:
        The model with set weights and AIR preprocessor contained within.
    """
    saved_model, preprocessor = _load_checkpoint(checkpoint, "TorchTrainer")
    model = load_torch_model(saved_model=saved_model, model_definition=model)
    return model, preprocessor
@@ -1,9 +1,11 @@
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Optional, Tuple

import torch

from ray.air.checkpoint import Checkpoint
from ray.air.constants import MODEL_KEY, PREPROCESSOR_KEY
from ray.train.data_parallel_trainer import _load_checkpoint
from ray.air._internal.torch_utils import load_torch_model

if TYPE_CHECKING:
    from ray.data.preprocessor import Preprocessor

@@ -25,3 +27,24 @@ def to_air_checkpoint(
        {PREPROCESSOR_KEY: preprocessor, MODEL_KEY: model}
    )
    return checkpoint


def load_checkpoint(
    checkpoint: Checkpoint, model: Optional[torch.nn.Module] = None
) -> Tuple[torch.nn.Module, Optional["Preprocessor"]]:
    """Load a Checkpoint from ``TorchTrainer``.

    Args:
        checkpoint: The checkpoint to load the model and
            preprocessor from. It is expected to be from the result of a
            ``TorchTrainer`` run.
        model: If the checkpoint contains a model state dict, and not
            the model itself, then the state dict will be loaded to this
            ``model``.

    Returns:
        The model with set weights and AIR preprocessor contained within.
    """
    saved_model, preprocessor = _load_checkpoint(checkpoint, "TorchTrainer")
    model = load_torch_model(saved_model=saved_model, model_definition=model)
    return model, preprocessor
@@ -1,5 +1,5 @@
from ray.train.xgboost.utils import to_air_checkpoint
from ray.train.xgboost.utils import load_checkpoint, to_air_checkpoint
from ray.train.xgboost.xgboost_predictor import XGBoostPredictor
from ray.train.xgboost.xgboost_trainer import XGBoostTrainer, load_checkpoint
from ray.train.xgboost.xgboost_trainer import XGBoostTrainer

__all__ = ["XGBoostPredictor", "XGBoostTrainer", "load_checkpoint", "to_air_checkpoint"]
@@ -1,9 +1,12 @@
import os
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Optional, Tuple

import xgboost

from ray.air._internal.checkpointing import save_preprocessor_to_dir
from ray.air._internal.checkpointing import (
    save_preprocessor_to_dir,
    load_preprocessor_from_dir,
)
from ray.air.checkpoint import Checkpoint
from ray.air.constants import MODEL_KEY

@@ -34,3 +37,24 @@ def to_air_checkpoint(
    checkpoint = Checkpoint.from_directory(path)

    return checkpoint


def load_checkpoint(
    checkpoint: Checkpoint,
) -> Tuple[xgboost.Booster, Optional["Preprocessor"]]:
    """Load a Checkpoint from ``XGBoostTrainer``.

    Args:
        checkpoint: The checkpoint to load the model and
            preprocessor from. It is expected to be from the result of a
            ``XGBoostTrainer`` run.

    Returns:
        The model and AIR preprocessor contained within.
    """
    with checkpoint.as_directory() as checkpoint_path:
        xgb_model = xgboost.Booster()
        xgb_model.load_model(os.path.join(checkpoint_path, MODEL_KEY))
        preprocessor = load_preprocessor_from_dir(checkpoint_path)

    return xgb_model, preprocessor
@@ -6,7 +6,7 @@ import xgboost
from ray.air.checkpoint import Checkpoint
from ray.air.constants import TENSOR_COLUMN_NAME
from ray.train.predictor import Predictor
from ray.train.xgboost.xgboost_trainer import load_checkpoint
from ray.train.xgboost.utils import load_checkpoint

if TYPE_CHECKING:
    from ray.data.preprocessor import Preprocessor
@@ -1,11 +1,9 @@
import os
from typing import Optional, Tuple, TYPE_CHECKING

from ray.air.checkpoint import Checkpoint
from ray.train.gbdt_trainer import GBDTTrainer
from ray.air._internal.checkpointing import load_preprocessor_from_dir
from ray.util.annotations import PublicAPI
from ray.train.constants import MODEL_KEY
from ray.train.xgboost.utils import load_checkpoint

import xgboost
import xgboost_ray

@@ -75,24 +73,3 @@ class XGBoostTrainer(GBDTTrainer):
        self, checkpoint: Checkpoint
    ) -> Tuple[xgboost.Booster, Optional["Preprocessor"]]:
        return load_checkpoint(checkpoint)


def load_checkpoint(
    checkpoint: Checkpoint,
) -> Tuple[xgboost.Booster, Optional["Preprocessor"]]:
    """Load a Checkpoint from ``XGBoostTrainer``.

    Args:
        checkpoint: The checkpoint to load the model and
            preprocessor from. It is expected to be from the result of a
            ``XGBoostTrainer`` run.

    Returns:
        The model and AIR preprocessor contained within.
    """
    with checkpoint.as_directory() as checkpoint_path:
        xgb_model = xgboost.Booster()
        xgb_model.load_model(os.path.join(checkpoint_path, MODEL_KEY))
        preprocessor = load_preprocessor_from_dir(checkpoint_path)

    return xgb_model, preprocessor