[Train] Move load_checkpoint to utils (#25940)

Moves the load_checkpoint functions from the trainer modules into the per-framework utils modules for consistency and better modularity.
Antoni Baum 2022-06-21 22:03:56 +02:00 committed by GitHub
parent 3d6a5450c9
commit b7d4ae541d
28 changed files with 316 additions and 284 deletions
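
Illustrative only, not part of the diff: a minimal sketch of what the move means for import paths, using the torch module as the example (the other frameworks follow the same pattern).

# Before this commit, load_checkpoint was defined alongside the trainer:
# from ray.train.torch.torch_trainer import load_checkpoint

# After this commit, it is defined in the framework's utils module:
from ray.train.torch.utils import load_checkpoint

# The top-level re-export is unchanged, so this form keeps working:
from ray.train.torch import load_checkpoint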

View file

@@ -1,6 +1,8 @@
from ray.train.huggingface.huggingface_predictor import HuggingFacePredictor
from ray.train.huggingface.huggingface_trainer import (
HuggingFaceTrainer,
)
from ray.train.huggingface.utils import (
load_checkpoint,
)

View file

@@ -5,23 +5,19 @@ import tempfile
import warnings
from distutils.version import LooseVersion
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Type, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Type
import torch
import transformers
import transformers.modeling_utils
import transformers.trainer
import transformers.training_args
from torch.utils.data import Dataset as TorchDataset
from transformers.trainer import TRAINING_ARGS_NAME, WEIGHTS_NAME
from ray import train
from ray.air import session
from ray.air._internal.checkpointing import (
load_preprocessor_from_dir,
save_preprocessor_to_dir,
)
from ray.air._internal.torch_utils import load_torch_model
from ray.air.checkpoint import Checkpoint
from ray.air.config import DatasetConfig, RunConfig, ScalingConfig
from ray.train.constants import (
@@ -31,7 +27,7 @@ from ray.train.constants (
TUNE_CHECKPOINT_ID,
)
from ray.train.data_parallel_trainer import _DataParallelCheckpointManager
from ray.train.huggingface.huggingface_utils import (
from ray.train.huggingface._huggingface_utils import (
CHECKPOINT_PATH_ON_NODE_KEY,
NODE_IP_KEY,
TrainReportCallback,
@@ -407,65 +403,6 @@ class HuggingFaceTrainer(TorchTrainer):
return ret
def load_checkpoint(
checkpoint: Checkpoint,
model: Union[Type[transformers.modeling_utils.PreTrainedModel], torch.nn.Module],
tokenizer: Optional[Type[transformers.PreTrainedTokenizer]] = None,
*,
tokenizer_kwargs: Optional[Dict[str, Any]] = None,
**pretrained_model_kwargs,
) -> Tuple[
Union[transformers.modeling_utils.PreTrainedModel, torch.nn.Module],
transformers.training_args.TrainingArguments,
Optional[transformers.PreTrainedTokenizer],
Optional["Preprocessor"],
]:
"""Load a Checkpoint from ``HuggingFaceTrainer``.
Args:
checkpoint: The checkpoint to load the model and
preprocessor from. It is expected to be from the result of a
``HuggingFaceTrainer`` run.
model: Either a ``transformers.PreTrainedModel`` class
(e.g. ``AutoModelForCausalLM``), or a PyTorch model to load the
weights to. This should be the same model used for training.
tokenizer: A ``transformers.PreTrainedTokenizer`` class to load
the model tokenizer to. If not specified, the tokenizer will
not be loaded. Will throw an exception if specified, but no
tokenizer was found in the checkpoint.
tokenizer_kwargs: Dict of kwargs to pass to ``tokenizer.from_pretrained``
call. Ignored if ``tokenizer`` is None.
**pretrained_model_kwargs: Kwargs to pass to ``model.from_pretrained``
call. Ignored if ``model`` is not a ``transformers.PreTrainedModel``
class.
Returns:
The model, ``TrainingArguments``, tokenizer and AIR preprocessor
contained within. Those can be used to initialize a ``transformers.Trainer``
object locally.
"""
tokenizer_kwargs = tokenizer_kwargs or {}
with checkpoint.as_directory() as checkpoint_path:
preprocessor = load_preprocessor_from_dir(checkpoint_path)
if isinstance(model, torch.nn.Module):
state_dict = torch.load(
os.path.join(checkpoint_path, WEIGHTS_NAME), map_location="cpu"
)
model = load_torch_model(saved_model=state_dict, model_definition=model)
else:
model = model.from_pretrained(checkpoint_path, **pretrained_model_kwargs)
if tokenizer:
tokenizer = tokenizer.from_pretrained(checkpoint_path, **tokenizer_kwargs)
training_args_path = os.path.join(checkpoint_path, TRAINING_ARGS_NAME)
if os.path.exists(training_args_path):
with open(training_args_path, "rb") as f:
training_args = torch.load(f, map_location="cpu")
else:
training_args = None
return model, training_args, tokenizer, preprocessor
def _huggingface_train_loop_per_worker(config):
"""Per-worker training loop for HuggingFace Transformers."""
trainer_init_per_worker = config.pop("_trainer_init_per_worker")

View file

@@ -0,0 +1,77 @@
import os
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Type, Union
import torch
import transformers
import transformers.modeling_utils
import transformers.trainer
import transformers.training_args
from transformers.trainer import TRAINING_ARGS_NAME, WEIGHTS_NAME
from ray.air._internal.checkpointing import (
load_preprocessor_from_dir,
)
from ray.air._internal.torch_utils import load_torch_model
from ray.air.checkpoint import Checkpoint
if TYPE_CHECKING:
from ray.data.preprocessor import Preprocessor
def load_checkpoint(
checkpoint: Checkpoint,
model: Union[Type[transformers.modeling_utils.PreTrainedModel], torch.nn.Module],
tokenizer: Optional[Type[transformers.PreTrainedTokenizer]] = None,
*,
tokenizer_kwargs: Optional[Dict[str, Any]] = None,
**pretrained_model_kwargs,
) -> Tuple[
Union[transformers.modeling_utils.PreTrainedModel, torch.nn.Module],
transformers.training_args.TrainingArguments,
Optional[transformers.PreTrainedTokenizer],
Optional["Preprocessor"],
]:
"""Load a Checkpoint from ``HuggingFaceTrainer``.
Args:
checkpoint: The checkpoint to load the model and
preprocessor from. It is expected to be from the result of a
``HuggingFaceTrainer`` run.
model: Either a ``transformers.PreTrainedModel`` class
(e.g. ``AutoModelForCausalLM``), or a PyTorch model to load the
weights to. This should be the same model used for training.
tokenizer: A ``transformers.PreTrainedTokenizer`` class to load
the model tokenizer to. If not specified, the tokenizer will
not be loaded. Will throw an exception if specified, but no
tokenizer was found in the checkpoint.
tokenizer_kwargs: Dict of kwargs to pass to ``tokenizer.from_pretrained``
call. Ignored if ``tokenizer`` is None.
**pretrained_model_kwargs: Kwargs to pass to ``model.from_pretrained``
call. Ignored if ``model`` is not a ``transformers.PreTrainedModel``
class.
Returns:
The model, ``TrainingArguments``, tokenizer and AIR preprocessor
contained within. Those can be used to initialize a ``transformers.Trainer``
object locally.
"""
tokenizer_kwargs = tokenizer_kwargs or {}
with checkpoint.as_directory() as checkpoint_path:
preprocessor = load_preprocessor_from_dir(checkpoint_path)
if isinstance(model, torch.nn.Module):
state_dict = torch.load(
os.path.join(checkpoint_path, WEIGHTS_NAME), map_location="cpu"
)
model = load_torch_model(saved_model=state_dict, model_definition=model)
else:
model = model.from_pretrained(checkpoint_path, **pretrained_model_kwargs)
if tokenizer:
tokenizer = tokenizer.from_pretrained(checkpoint_path, **tokenizer_kwargs)
training_args_path = os.path.join(checkpoint_path, TRAINING_ARGS_NAME)
if os.path.exists(training_args_path):
with open(training_args_path, "rb") as f:
training_args = torch.load(f, map_location="cpu")
else:
training_args = None
return model, training_args, tokenizer, preprocessor
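
Illustrative only, not part of this commit: a minimal sketch of using the relocated function; ``result`` is assumed to be the object returned by a prior ``HuggingFaceTrainer.fit()`` call, and the model/tokenizer classes are placeholders.

import transformers

from ray.train.huggingface.utils import load_checkpoint

# `result` is assumed to come from a completed HuggingFaceTrainer.fit() run.
model, training_args, tokenizer, preprocessor = load_checkpoint(
    result.checkpoint,
    transformers.AutoModelForCausalLM,  # same model class used for training
    tokenizer=transformers.AutoTokenizer,
)

# As the docstring notes, the returned pieces can seed a local transformers.Trainer.
trainer = transformers.Trainer(model=model, args=training_args, tokenizer=tokenizer)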

View file

@@ -1,6 +1,6 @@
from ray.train.lightgbm.lightgbm_predictor import LightGBMPredictor
from ray.train.lightgbm.lightgbm_trainer import LightGBMTrainer, load_checkpoint
from ray.train.lightgbm.utils import to_air_checkpoint
from ray.train.lightgbm.lightgbm_trainer import LightGBMTrainer
from ray.train.lightgbm.utils import to_air_checkpoint, load_checkpoint
__all__ = [
"LightGBMPredictor",

View file

@@ -5,7 +5,7 @@ import pandas as pd
from ray.air.checkpoint import Checkpoint
from ray.air.constants import TENSOR_COLUMN_NAME
from ray.train.lightgbm.lightgbm_trainer import load_checkpoint
from ray.train.lightgbm.utils import load_checkpoint
from ray.train.predictor import Predictor
if TYPE_CHECKING:

View file

@@ -1,11 +1,9 @@
from typing import Dict, Any, Optional, Tuple, TYPE_CHECKING
import os
from ray.air.checkpoint import Checkpoint
from ray.train.gbdt_trainer import GBDTTrainer
from ray.air._internal.checkpointing import load_preprocessor_from_dir
from ray.util.annotations import PublicAPI
from ray.train.constants import MODEL_KEY
from ray.train.lightgbm.utils import load_checkpoint
import lightgbm
import lightgbm_ray
@@ -85,25 +83,3 @@ class LightGBMTrainer(GBDTTrainer):
self, checkpoint: Checkpoint
) -> Tuple[lightgbm.Booster, Optional["Preprocessor"]]:
return load_checkpoint(checkpoint)
def load_checkpoint(
checkpoint: Checkpoint,
) -> Tuple[lightgbm.Booster, Optional["Preprocessor"]]:
"""Load a Checkpoint from ``LightGBMTrainer``.
Args:
checkpoint: The checkpoint to load the model and
preprocessor from. It is expected to be from the result of a
``LightGBMTrainer`` run.
Returns:
The model and AIR preprocessor contained within.
"""
with checkpoint.as_directory() as checkpoint_path:
lgbm_model = lightgbm.Booster(
model_file=os.path.join(checkpoint_path, MODEL_KEY)
)
preprocessor = load_preprocessor_from_dir(checkpoint_path)
return lgbm_model, preprocessor

View file

@@ -1,9 +1,12 @@
import os
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Optional, Tuple
import lightgbm
from ray.air._internal.checkpointing import save_preprocessor_to_dir
from ray.air._internal.checkpointing import (
save_preprocessor_to_dir,
load_preprocessor_from_dir,
)
from ray.air.checkpoint import Checkpoint
from ray.air.constants import MODEL_KEY
@@ -34,3 +37,25 @@ def to_air_checkpoint(
checkpoint = Checkpoint.from_directory(path)
return checkpoint
def load_checkpoint(
checkpoint: Checkpoint,
) -> Tuple[lightgbm.Booster, Optional["Preprocessor"]]:
"""Load a Checkpoint from ``LightGBMTrainer``.
Args:
checkpoint: The checkpoint to load the model and
preprocessor from. It is expected to be from the result of a
``LightGBMTrainer`` run.
Returns:
The model and AIR preprocessor contained within.
"""
with checkpoint.as_directory() as checkpoint_path:
lgbm_model = lightgbm.Booster(
model_file=os.path.join(checkpoint_path, MODEL_KEY)
)
preprocessor = load_preprocessor_from_dir(checkpoint_path)
return lgbm_model, preprocessor
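
Illustrative only: a short sketch of loading and using the booster; ``result`` and the feature frame are assumptions.

import pandas as pd

from ray.train.lightgbm.utils import load_checkpoint

# `result` is assumed to come from a completed LightGBMTrainer.fit() run.
booster, preprocessor = load_checkpoint(result.checkpoint)

batch = pd.DataFrame({"x": [0.1, 0.7]})  # hypothetical inference batch
if preprocessor:
    batch = preprocessor.transform_batch(batch)
predictions = booster.predict(batch)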

View file

@@ -1,4 +1,5 @@
from ray.train.rl.rl_predictor import RLPredictor
from ray.train.rl.rl_trainer import RLTrainer, load_checkpoint
from ray.train.rl.rl_trainer import RLTrainer
from ray.train.rl.utils import load_checkpoint
__all__ = ["RLPredictor", "RLTrainer", "load_checkpoint"]

View file

@@ -8,7 +8,7 @@ from ray.air.constants import TENSOR_COLUMN_NAME
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.typing import EnvType
from ray.train.predictor import Predictor
from ray.train.rl.rl_trainer import load_checkpoint
from ray.train.rl.utils import load_checkpoint
if TYPE_CHECKING:
from ray.data.preprocessor import Preprocessor

View file

@@ -1,17 +1,15 @@
import inspect
import os
from typing import Optional, Dict, Tuple, Type, Union, Callable, Any, TYPE_CHECKING
from typing import Optional, Dict, Type, Union, Callable, Any, TYPE_CHECKING
import ray.cloudpickle as cpickle
from ray.air.checkpoint import Checkpoint
from ray.air.config import ScalingConfig, RunConfig
from ray.train.trainer import BaseTrainer, GenDataset
from ray.air._internal.checkpointing import (
load_preprocessor_from_dir,
save_preprocessor_to_dir,
)
from ray.rllib.algorithms.algorithm import Algorithm as RLlibAlgo
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.typing import PartialAlgorithmConfigDict, EnvType
from ray.tune import Trainable, PlacementGroupFactory
from ray.tune.logger import Logger
@@ -20,13 +18,11 @@ from ray.tune.resources import Resources
from ray.tune.syncer import Syncer
from ray.util.annotations import PublicAPI
from ray.util.ml_utils.dict import merge_dicts
from ray.train.rl.utils import RL_TRAINER_CLASS_FILE, RL_CONFIG_FILE
if TYPE_CHECKING:
from ray.data.preprocessor import Preprocessor
RL_TRAINER_CLASS_FILE = "trainer_class.pkl"
RL_CONFIG_FILE = "config.pkl"
@PublicAPI(stability="alpha")
class RLTrainer(BaseTrainer):
@@ -253,64 +249,3 @@ class RLTrainer(BaseTrainer):
AIRRLTrainer.__name__ = f"AIR{rllib_trainer.__name__}"
return AIRRLTrainer
def load_checkpoint(
checkpoint: Checkpoint,
env: Optional[EnvType] = None,
) -> Tuple[Policy, Optional["Preprocessor"]]:
"""Load a Checkpoint from ``RLTrainer``.
Args:
checkpoint: The checkpoint to load the policy and
preprocessor from. It is expected to be from the result of a
``RLTrainer`` run.
env: Optional environment to instantiate the trainer with. If not given,
it is parsed from the saved trainer configuration instead.
Returns:
The policy and AIR preprocessor contained within.
"""
with checkpoint.as_directory() as checkpoint_path:
trainer_class_path = os.path.join(checkpoint_path, RL_TRAINER_CLASS_FILE)
config_path = os.path.join(checkpoint_path, RL_CONFIG_FILE)
if not os.path.exists(trainer_class_path):
raise ValueError(
f"RLPredictor only works with checkpoints created by "
f"RLTrainer. The checkpoint you specified is missing the "
f"`{RL_TRAINER_CLASS_FILE}` file."
)
if not os.path.exists(config_path):
raise ValueError(
f"RLPredictor only works with checkpoints created by "
f"RLTrainer. The checkpoint you specified is missing the "
f"`{RL_CONFIG_FILE}` file."
)
with open(trainer_class_path, "rb") as fp:
trainer_cls = cpickle.load(fp)
with open(config_path, "rb") as fp:
config = cpickle.load(fp)
checkpoint_data_path = None
for file in os.listdir(checkpoint_path):
if file.startswith("checkpoint") and not file.endswith(".tune_metadata"):
checkpoint_data_path = os.path.join(checkpoint_path, file)
if not checkpoint_data_path:
raise ValueError(
f"Could not find checkpoint data in RLlib checkpoint. "
f"Found files: {list(os.listdir(checkpoint_path))}"
)
preprocessor = load_preprocessor_from_dir(checkpoint_path)
config.get("evaluation_config", {}).pop("in_evaluation", None)
trainer = trainer_cls(config=config, env=env)
trainer.restore(checkpoint_data_path)
policy = trainer.get_policy()
return policy, preprocessor

View file

@@ -0,0 +1,77 @@
import os
from typing import Optional, Tuple, TYPE_CHECKING
import ray.cloudpickle as cpickle
from ray.air.checkpoint import Checkpoint
from ray.air._internal.checkpointing import (
load_preprocessor_from_dir,
)
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.typing import EnvType
if TYPE_CHECKING:
from ray.data.preprocessor import Preprocessor
RL_TRAINER_CLASS_FILE = "trainer_class.pkl"
RL_CONFIG_FILE = "config.pkl"
def load_checkpoint(
checkpoint: Checkpoint,
env: Optional[EnvType] = None,
) -> Tuple[Policy, Optional["Preprocessor"]]:
"""Load a Checkpoint from ``RLTrainer``.
Args:
checkpoint: The checkpoint to load the policy and
preprocessor from. It is expected to be from the result of a
``RLTrainer`` run.
env: Optional environment to instantiate the trainer with. If not given,
it is parsed from the saved trainer configuration instead.
Returns:
The policy and AIR preprocessor contained within.
"""
with checkpoint.as_directory() as checkpoint_path:
trainer_class_path = os.path.join(checkpoint_path, RL_TRAINER_CLASS_FILE)
config_path = os.path.join(checkpoint_path, RL_CONFIG_FILE)
if not os.path.exists(trainer_class_path):
raise ValueError(
f"RLPredictor only works with checkpoints created by "
f"RLTrainer. The checkpoint you specified is missing the "
f"`{RL_TRAINER_CLASS_FILE}` file."
)
if not os.path.exists(config_path):
raise ValueError(
f"RLPredictor only works with checkpoints created by "
f"RLTrainer. The checkpoint you specified is missing the "
f"`{RL_CONFIG_FILE}` file."
)
with open(trainer_class_path, "rb") as fp:
trainer_cls = cpickle.load(fp)
with open(config_path, "rb") as fp:
config = cpickle.load(fp)
checkpoint_data_path = None
for file in os.listdir(checkpoint_path):
if file.startswith("checkpoint") and not file.endswith(".tune_metadata"):
checkpoint_data_path = os.path.join(checkpoint_path, file)
if not checkpoint_data_path:
raise ValueError(
f"Could not find checkpoint data in RLlib checkpoint. "
f"Found files: {list(os.listdir(checkpoint_path))}"
)
preprocessor = load_preprocessor_from_dir(checkpoint_path)
config.get("evaluation_config", {}).pop("in_evaluation", None)
trainer = trainer_cls(config=config, env=env)
trainer.restore(checkpoint_data_path)
policy = trainer.get_policy()
return policy, preprocessor
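
Illustrative only: a sketch of querying the restored policy; ``result`` and the observation are assumptions.

from ray.train.rl.utils import load_checkpoint

# `result` is assumed to come from a completed RLTrainer.fit() run.
policy, preprocessor = load_checkpoint(result.checkpoint)

# Hypothetical observation shaped like the training environment's observations.
obs = [0.0, 0.0, 0.0, 0.0]
action, _, _ = policy.compute_single_action(obs)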

View file

@@ -1,5 +1,5 @@
from ray.train.sklearn.sklearn_predictor import SklearnPredictor
from ray.train.sklearn.sklearn_trainer import SklearnTrainer, load_checkpoint
from ray.train.sklearn.utils import to_air_checkpoint
from ray.train.sklearn.sklearn_trainer import SklearnTrainer
from ray.train.sklearn.utils import to_air_checkpoint, load_checkpoint
__all__ = ["SklearnPredictor", "SklearnTrainer", "load_checkpoint", "to_air_checkpoint"]

View file

@@ -8,7 +8,7 @@ from ray.air.checkpoint import Checkpoint
from ray.air.constants import TENSOR_COLUMN_NAME
from ray.train.predictor import Predictor
from ray.train.sklearn._sklearn_utils import _set_cpu_params
from ray.train.sklearn.sklearn_trainer import load_checkpoint
from ray.train.sklearn.utils import load_checkpoint
from ray.util.joblib import register_ray
if TYPE_CHECKING:

View file

@@ -4,7 +4,7 @@ import warnings
from collections import defaultdict
from time import time
from traceback import format_exc
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, Optional, Union, Tuple
import numpy as np
import pandas as pd
@@ -19,10 +19,8 @@ from sklearn.model_selection._validation import _check_multimetric_scoring, _sco
import ray.cloudpickle as cpickle
from ray import tune
from ray.air._internal.checkpointing import (
load_preprocessor_from_dir,
save_preprocessor_to_dir,
)
from ray.air.checkpoint import Checkpoint
from ray.air.config import RunConfig, ScalingConfig
from ray.train.constants import MODEL_KEY, TRAIN_DATASET_KEY
from ray.train.sklearn._sklearn_utils import _has_cpu_params, _set_cpu_params
@@ -434,25 +432,3 @@ class SklearnTrainer(BaseTrainer):
"fit_time": fit_time,
}
tune.report(**results)
def load_checkpoint(
checkpoint: Checkpoint,
) -> Tuple[BaseEstimator, Optional["Preprocessor"]]:
"""Load a Checkpoint from ``SklearnTrainer``.
Args:
checkpoint: The checkpoint to load the estimator and
preprocessor from. It is expected to be from the result of a
``SklearnTrainer`` run.
Returns:
The estimator and AIR preprocessor contained within.
"""
with checkpoint.as_directory() as checkpoint_path:
estimator_path = os.path.join(checkpoint_path, MODEL_KEY)
with open(estimator_path, "rb") as f:
estimator = cpickle.load(f)
preprocessor = load_preprocessor_from_dir(checkpoint_path)
return estimator, preprocessor

View file

@@ -1,10 +1,13 @@
import os
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Optional, Tuple
from sklearn.base import BaseEstimator
import ray.cloudpickle as cpickle
from ray.air._internal.checkpointing import save_preprocessor_to_dir
from ray.air._internal.checkpointing import (
save_preprocessor_to_dir,
load_preprocessor_from_dir,
)
from ray.air.checkpoint import Checkpoint
from ray.air.constants import MODEL_KEY
@@ -36,3 +39,25 @@ def to_air_checkpoint(
checkpoint = Checkpoint.from_directory(path)
return checkpoint
def load_checkpoint(
checkpoint: Checkpoint,
) -> Tuple[BaseEstimator, Optional["Preprocessor"]]:
"""Load a Checkpoint from ``SklearnTrainer``.
Args:
checkpoint: The checkpoint to load the estimator and
preprocessor from. It is expected to be from the result of a
``SklearnTrainer`` run.
Returns:
The estimator and AIR preprocessor contained within.
"""
with checkpoint.as_directory() as checkpoint_path:
estimator_path = os.path.join(checkpoint_path, MODEL_KEY)
with open(estimator_path, "rb") as f:
estimator = cpickle.load(f)
preprocessor = load_preprocessor_from_dir(checkpoint_path)
return estimator, preprocessor
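
Illustrative only: a sketch of using the restored estimator; ``result`` and the feature rows are assumptions.

from ray.train.sklearn.utils import load_checkpoint

# `result` is assumed to come from a completed SklearnTrainer.fit() run.
estimator, preprocessor = load_checkpoint(result.checkpoint)

features = [[0.1, 0.2], [0.3, 0.4]]  # hypothetical feature rows
predictions = estimator.predict(features)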

View file

@@ -10,9 +10,9 @@ except ModuleNotFoundError:
from ray.train.tensorflow.config import TensorflowConfig
from ray.train.tensorflow.tensorflow_predictor import TensorflowPredictor
from ray.train.tensorflow.tensorflow_trainer import TensorflowTrainer, load_checkpoint
from ray.train.tensorflow.tensorflow_trainer import TensorflowTrainer
from ray.train.tensorflow.train_loop_utils import prepare_dataset_shard
from ray.train.tensorflow.utils import to_air_checkpoint
from ray.train.tensorflow.utils import to_air_checkpoint, load_checkpoint
__all__ = [
"TensorflowConfig",

View file

@@ -1,9 +1,8 @@
from typing import Callable, Optional, Dict, Tuple, Type, Union, TYPE_CHECKING
import tensorflow as tf
from typing import Callable, Optional, Dict, Union, TYPE_CHECKING
from ray.train.tensorflow.config import TensorflowConfig
from ray.train.trainer import GenDataset
from ray.train.data_parallel_trainer import DataParallelTrainer, _load_checkpoint
from ray.train.data_parallel_trainer import DataParallelTrainer
from ray.air.config import ScalingConfig, RunConfig, DatasetConfig
from ray.air.checkpoint import Checkpoint
from ray.util import PublicAPI
@@ -185,27 +184,3 @@ class TensorflowTrainer(DataParallelTrainer):
preprocessor=preprocessor,
resume_from_checkpoint=resume_from_checkpoint,
)
def load_checkpoint(
checkpoint: Checkpoint,
model: Union[Callable[[], tf.keras.Model], Type[tf.keras.Model], tf.keras.Model],
) -> Tuple[tf.keras.Model, Optional["Preprocessor"]]:
"""Load a Checkpoint from ``TensorflowTrainer``.
Args:
checkpoint: The checkpoint to load the model and
preprocessor from. It is expected to be from the result of a
``TensorflowTrainer`` run.
model: A callable that returns a TensorFlow Keras model
to use, or an instantiated model.
Model weights will be loaded from the checkpoint.
Returns:
The model with set weights and AIR preprocessor contained within.
"""
model_weights, preprocessor = _load_checkpoint(checkpoint, "TensorflowTrainer")
if isinstance(model, type) or callable(model):
model = model()
model.set_weights(model_weights)
return model, preprocessor

View file

@@ -1,9 +1,11 @@
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Optional, Union, Callable, Type, Tuple
import tensorflow as tf
from tensorflow import keras
from ray.air.checkpoint import Checkpoint
from ray.air.constants import MODEL_KEY, PREPROCESSOR_KEY
from ray.train.data_parallel_trainer import _load_checkpoint
if TYPE_CHECKING:
from ray.data.preprocessor import Preprocessor
@@ -25,3 +27,27 @@ def to_air_checkpoint(
{PREPROCESSOR_KEY: preprocessor, MODEL_KEY: model.get_weights()}
)
return checkpoint
def load_checkpoint(
checkpoint: Checkpoint,
model: Union[Callable[[], tf.keras.Model], Type[tf.keras.Model], tf.keras.Model],
) -> Tuple[tf.keras.Model, Optional["Preprocessor"]]:
"""Load a Checkpoint from ``TensorflowTrainer``.
Args:
checkpoint: The checkpoint to load the model and
preprocessor from. It is expected to be from the result of a
``TensorflowTrainer`` run.
model: A callable that returns a TensorFlow Keras model
to use, or an instantiated model.
Model weights will be loaded from the checkpoint.
Returns:
The model with set weights and AIR preprocessor contained within.
"""
model_weights, preprocessor = _load_checkpoint(checkpoint, "TensorflowTrainer")
if isinstance(model, type) or callable(model):
model = model()
model.set_weights(model_weights)
return model, preprocessor
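
Illustrative only: a sketch of restoring weights into a freshly built Keras model; ``result`` and ``build_model`` are assumptions.

import tensorflow as tf

from ray.train.tensorflow.utils import load_checkpoint

# Hypothetical model definition; it must match the architecture used in training,
# since the checkpoint stores only the weights.
def build_model() -> tf.keras.Model:
    return tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(1,))])

# `result` is assumed to come from a completed TensorflowTrainer.fit() run.
model, preprocessor = load_checkpoint(result.checkpoint, build_model)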

View file

@@ -14,7 +14,7 @@ from transformers.trainer_callback import TrainerState
import ray.data
from ray.train.batch_predictor import BatchPredictor
from ray.train.huggingface import HuggingFacePredictor, HuggingFaceTrainer
from ray.train.huggingface.huggingface_utils import TrainReportCallback
from ray.train.huggingface._huggingface_utils import TrainReportCallback
from ray.train.tests._huggingface_data import train_data, validation_data
# 16 first rows of tokenized wikitext-2-raw-v1 training & validation

View file

@@ -9,7 +9,7 @@ except ModuleNotFoundError:
from ray.train.torch.config import TorchConfig
from ray.train.torch.torch_predictor import TorchPredictor
from ray.train.torch.torch_trainer import TorchTrainer, load_checkpoint
from ray.train.torch.torch_trainer import TorchTrainer
from ray.train.torch.train_loop_utils import (
TorchWorkerProfiler,
accelerate,
@@ -20,7 +20,7 @@ from ray.train.torch.train_loop_utils import (
prepare_model,
prepare_optimizer,
)
from ray.train.torch.utils import to_air_checkpoint
from ray.train.torch.utils import to_air_checkpoint, load_checkpoint
__all__ = [
"TorchTrainer",

View file

@@ -7,7 +7,7 @@ import torch
from ray.air._internal.torch_utils import convert_pandas_to_torch_tensor
from ray.air.checkpoint import Checkpoint
from ray.train.predictor import DataBatchType, Predictor
from ray.train.torch.torch_trainer import load_checkpoint
from ray.train.torch.utils import load_checkpoint
if TYPE_CHECKING:
from ray.data.preprocessor import Preprocessor

View file

@@ -1,11 +1,8 @@
from typing import TYPE_CHECKING, Callable, Dict, Optional, Tuple, Union
from typing import TYPE_CHECKING, Callable, Dict, Optional, Union
import torch
from ray.air._internal.torch_utils import load_torch_model
from ray.air.checkpoint import Checkpoint
from ray.air.config import DatasetConfig, RunConfig, ScalingConfig
from ray.train.data_parallel_trainer import DataParallelTrainer, _load_checkpoint
from ray.train.data_parallel_trainer import DataParallelTrainer
from ray.train.torch.config import TorchConfig
from ray.train.trainer import GenDataset
from ray.util import PublicAPI
@@ -196,24 +193,3 @@ class TorchTrainer(DataParallelTrainer):
preprocessor=preprocessor,
resume_from_checkpoint=resume_from_checkpoint,
)
def load_checkpoint(
checkpoint: Checkpoint, model: Optional[torch.nn.Module] = None
) -> Tuple[torch.nn.Module, Optional["Preprocessor"]]:
"""Load a Checkpoint from ``TorchTrainer``.
Args:
checkpoint: The checkpoint to load the model and
preprocessor from. It is expected to be from the result of a
``TorchTrainer`` run.
model: If the checkpoint contains a model state dict, and not
the model itself, then the state dict will be loaded to this
``model``.
Returns:
The model with set weights and AIR preprocessor contained within.
"""
saved_model, preprocessor = _load_checkpoint(checkpoint, "TorchTrainer")
model = load_torch_model(saved_model=saved_model, model_definition=model)
return model, preprocessor

View file

@@ -1,9 +1,11 @@
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Optional, Tuple
import torch
from ray.air.checkpoint import Checkpoint
from ray.air.constants import MODEL_KEY, PREPROCESSOR_KEY
from ray.train.data_parallel_trainer import _load_checkpoint
from ray.air._internal.torch_utils import load_torch_model
if TYPE_CHECKING:
from ray.data.preprocessor import Preprocessor
@@ -25,3 +27,24 @@ def to_air_checkpoint(
{PREPROCESSOR_KEY: preprocessor, MODEL_KEY: model}
)
return checkpoint
def load_checkpoint(
checkpoint: Checkpoint, model: Optional[torch.nn.Module] = None
) -> Tuple[torch.nn.Module, Optional["Preprocessor"]]:
"""Load a Checkpoint from ``TorchTrainer``.
Args:
checkpoint: The checkpoint to load the model and
preprocessor from. It is expected to be from the result of a
``TorchTrainer`` run.
model: If the checkpoint contains a model state dict, and not
the model itself, then the state dict will be loaded to this
``model``.
Returns:
The model with set weights and AIR preprocessor contained within.
"""
saved_model, preprocessor = _load_checkpoint(checkpoint, "TorchTrainer")
model = load_torch_model(saved_model=saved_model, model_definition=model)
return model, preprocessor
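
Illustrative only: a sketch of restoring a PyTorch model; ``result`` and ``net`` are assumptions.

import torch.nn as nn

from ray.train.torch.utils import load_checkpoint

# Hypothetical module matching the training architecture; only required when the
# checkpoint stores a state dict rather than a full model object.
net = nn.Linear(1, 1)

# `result` is assumed to come from a completed TorchTrainer.fit() run.
model, preprocessor = load_checkpoint(result.checkpoint, model=net)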

View file

@@ -1,5 +1,5 @@
from ray.train.xgboost.utils import to_air_checkpoint
from ray.train.xgboost.utils import load_checkpoint, to_air_checkpoint
from ray.train.xgboost.xgboost_predictor import XGBoostPredictor
from ray.train.xgboost.xgboost_trainer import XGBoostTrainer, load_checkpoint
from ray.train.xgboost.xgboost_trainer import XGBoostTrainer
__all__ = ["XGBoostPredictor", "XGBoostTrainer", "load_checkpoint", "to_air_checkpoint"]

View file

@@ -1,9 +1,12 @@
import os
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Optional, Tuple
import xgboost
from ray.air._internal.checkpointing import save_preprocessor_to_dir
from ray.air._internal.checkpointing import (
save_preprocessor_to_dir,
load_preprocessor_from_dir,
)
from ray.air.checkpoint import Checkpoint
from ray.air.constants import MODEL_KEY
@@ -34,3 +37,24 @@ def to_air_checkpoint(
checkpoint = Checkpoint.from_directory(path)
return checkpoint
def load_checkpoint(
checkpoint: Checkpoint,
) -> Tuple[xgboost.Booster, Optional["Preprocessor"]]:
"""Load a Checkpoint from ``XGBoostTrainer``.
Args:
checkpoint: The checkpoint to load the model and
preprocessor from. It is expected to be from the result of a
``XGBoostTrainer`` run.
Returns:
The model and AIR preprocessor contained within.
"""
with checkpoint.as_directory() as checkpoint_path:
xgb_model = xgboost.Booster()
xgb_model.load_model(os.path.join(checkpoint_path, MODEL_KEY))
preprocessor = load_preprocessor_from_dir(checkpoint_path)
return xgb_model, preprocessor
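
Illustrative only: a sketch of loading and scoring with the booster; ``result`` and the inference batch are assumptions.

import pandas as pd
import xgboost

from ray.train.xgboost.utils import load_checkpoint

# `result` is assumed to come from a completed XGBoostTrainer.fit() run.
booster, preprocessor = load_checkpoint(result.checkpoint)

batch = pd.DataFrame({"x": [0.5]})  # hypothetical inference batch
if preprocessor:
    batch = preprocessor.transform_batch(batch)
predictions = booster.predict(xgboost.DMatrix(batch))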

View file

@@ -6,7 +6,7 @@ import xgboost
from ray.air.checkpoint import Checkpoint
from ray.air.constants import TENSOR_COLUMN_NAME
from ray.train.predictor import Predictor
from ray.train.xgboost.xgboost_trainer import load_checkpoint
from ray.train.xgboost.utils import load_checkpoint
if TYPE_CHECKING:
from ray.data.preprocessor import Preprocessor

View file

@@ -1,11 +1,9 @@
import os
from typing import Optional, Tuple, TYPE_CHECKING
from ray.air.checkpoint import Checkpoint
from ray.train.gbdt_trainer import GBDTTrainer
from ray.air._internal.checkpointing import load_preprocessor_from_dir
from ray.util.annotations import PublicAPI
from ray.train.constants import MODEL_KEY
from ray.train.xgboost.utils import load_checkpoint
import xgboost
import xgboost_ray
@@ -75,24 +73,3 @@ class XGBoostTrainer(GBDTTrainer):
self, checkpoint: Checkpoint
) -> Tuple[xgboost.Booster, Optional["Preprocessor"]]:
return load_checkpoint(checkpoint)
def load_checkpoint(
checkpoint: Checkpoint,
) -> Tuple[xgboost.Booster, Optional["Preprocessor"]]:
"""Load a Checkpoint from ``XGBoostTrainer``.
Args:
checkpoint: The checkpoint to load the model and
preprocessor from. It is expected to be from the result of a
``XGBoostTrainer`` run.
Returns:
The model and AIR preprocessor contained within.
"""
with checkpoint.as_directory() as checkpoint_path:
xgb_model = xgboost.Booster()
xgb_model.load_model(os.path.join(checkpoint_path, MODEL_KEY))
preprocessor = load_preprocessor_from_dir(checkpoint_path)
return xgb_model, preprocessor