[air] Consolidate Tune and Train report (#25558)

Consolidate tune/train report/checkpoint functionality by working with a unified Session interface.
The goal of this PR is to establish a solid Session and Session.report path. 
To reduce merge conflicts (other folks are doing the package-wide renaming) and to keep the scope of this PR contained, I have intentionally left out some of the migration. More PRs will follow. Feel free to comment on the ideal final state.
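
For reviewers, a rough before/after sketch of the reporting path this PR establishes (the metric name and the "my_model" directory are illustrative, not part of this change):

```python
from ray.air import session
from ray.air.checkpoint import Checkpoint

# Legacy Tune:  tune.report(loss=0.1)  /  with tune.checkpoint_dir(step=...) as d: ...
# Legacy Train: train.report(loss=0.1) /  train.save_checkpoint(model=...)
#
# Unified path (the same call works inside a Tune trainable and a Train
# worker loop):
def train_func():
    # ... train, then write checkpoint files into "my_model" ...
    session.report(
        {"loss": 0.1},
        checkpoint=Checkpoint.from_directory("my_model"),
    )
```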


To give an idea of the final directory structure, here is the layout for a 2-worker DP training run:
```
├── TensorflowTrainer_ce44d_00000_0_2022-06-15_14-40-42
│   ├── checkpoint_000000
│   │   ├── _current_checkpoint_id.meta.pkl
│   │   ├── _preprocessor.meta.pkl
│   │   ├── _timestamp.meta.pkl
│   │   ├── assets
│   │   ├── keras_metadata.pb
│   │   ├── saved_model.pb
│   │   └── variables
│   │       ├── variables.data-00000-of-00001
│   │       └── variables.index
│   ├── events.out.tfevents.1655329242.xw
│   ├── params.json
│   ├── params.pkl
│   ├── progress.csv
│   ├── rank_0
│   │   └── my_model
│   │       ├── assets
│   │       ├── keras_metadata.pb
│   │       ├── saved_model.pb
│   │       └── variables
│   │           ├── variables.data-00000-of-00001
│   │           └── variables.index
│   ├── rank_1
│   │   └── my_model
│   │       ├── assets
│   │       ├── keras_metadata.pb
│   │       ├── saved_model.pb
│   │       └── variables
│   │           ├── variables.data-00000-of-00001
│   │           └── variables.index
│   └── result.json
├── basic-variant-state-2022-06-15_14-40-42.json
├── experiment_state-2022-06-15_14-40-42.json
├── trainable.pkl
└── tuner.pkl
```
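
For context, here is a minimal sketch of the kind of per-worker loop that produces a layout like the one above. It is adapted from the new `test_keras_callback.py` in this PR; the toy model and dataset values are illustrative.

```python
import tensorflow as tf

import ray
from ray.air import session
from ray.air.callbacks.keras import Callback
from ray.train.tensorflow import TensorflowTrainer, prepare_dataset_shard


def train_loop_per_worker(config: dict):
    strategy = tf.distribute.MultiWorkerMirroredStrategy()
    with strategy.scope():
        # Model building/compiling must happen inside `strategy.scope()`.
        model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(1,))])
        model.compile(optimizer="sgd", loss="mse")

    dataset = session.get_dataset_shard("train")
    for _ in range(config.get("epochs", 3)):
        tf_dataset = prepare_dataset_shard(
            dataset.to_tf(
                label_column="y",
                output_signature=(
                    tf.TensorSpec(shape=(None, 1), dtype=tf.float32),
                    tf.TensorSpec(shape=(None,), dtype=tf.float32),
                ),
                batch_size=32,
            )
        )
        # On each epoch_end the AIR Keras Callback saves the model under
        # "my_model" in this worker's rank_<i> directory and reports it
        # via session.report().
        model.fit(tf_dataset, callbacks=[Callback()])


train_dataset = ray.data.from_items(
    [{"x": float(x), "y": 2.0 * x} for x in range(32)]
)
trainer = TensorflowTrainer(
    train_loop_per_worker=train_loop_per_worker,
    train_loop_config={"epochs": 3},
    scaling_config={"num_workers": 2},
    datasets={"train": train_dataset},
)
result = trainer.fit()
```
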
Update:
1. Updated a few classes to be backward compatible while the legacy Ray Train deprecation is ongoing.
2. Marked all the places from item 1 with `# TODO(xwjiang): Legacy Ray Train trainer clean up!` so we can easily clean them up once Antoni's work has landed.
3. All CI and release tests are passing.

Co-authored-by: Eric Liang <ekhliang@gmail.com>
xwjiang2010 2022-06-17 13:49:01 -07:00 committed by GitHub
parent 2b270fd9cb
commit 97f42425da
18 changed files with 978 additions and 252 deletions


@ -193,6 +193,14 @@ py_test(
deps = [":ml_lib"]
)
py_test(
name = "test_keras_callback",
size = "small",
srcs = ["tests/test_keras_callback.py"],
tags = ["team:ml", "exclusive"],
deps = [":ml_lib"]
)
py_test(
name = "test_remote_storage",
size = "small",


@ -0,0 +1,87 @@
import abc
import logging
from typing import Dict, Optional
from ray.air.checkpoint import Checkpoint
logger = logging.getLogger(__name__)
class Session(abc.ABC):
"""The canonical session interface that both Tune and Train session implements.
User can interact with this interface to get session information,
as well as reporting metrics and saving checkpoint.
"""
@abc.abstractmethod
def report(self, metrics: Dict, *, checkpoint: Optional[Checkpoint] = None) -> None:
"""Report metrics and optionally save checkpoint.
Each invocation of this method will automatically increment the underlying
iteration number. The physical meaning of this "iteration" is defined by
user (or more specifically the way they call ``report``).
It does not necessarily map to one epoch.
This API is supposed to replace the legacy ``tune.report``,
``with tune.checkpoint_dir``, ``train.report`` and ``train.save_checkpoint``.
Please avoid mixing them together.
There is no requirement on the underlying representation of the checkpoint:
all forms are accepted and (will be) handled by AIR in an efficient way.
Specifically, if you are passing in a directory checkpoint, AIR will move
the content of the directory to an AIR-managed directory. By the time this
method returns, you may safely write new content to the original directory
without interfering with the AIR checkpointing flow.
Args:
metrics: The metrics you want to report.
checkpoint: The optional checkpoint you want to report.
"""
raise NotImplementedError
@property
@abc.abstractmethod
def loaded_checkpoint(self) -> Optional[Checkpoint]:
"""Access the session's loaded checkpoint to resume from if applicable.
Returns:
Checkpoint object if the session is currently being resumed.
Otherwise, return None.
"""
raise NotImplementedError
@property
def trial_name(self) -> str:
"""Trial name for the corresponding trial."""
raise NotImplementedError
@property
def trial_id(self) -> str:
"""Trial id for the corresponding trial."""
raise NotImplementedError
@property
def trial_resources(self) -> Dict[str, float]:
"""Trial resources for the corresponding trial."""
raise NotImplementedError
def _get_session() -> Optional[Session]:
from ray.train._internal.session import _session_v2 as train_session
from ray.tune.session import _session_v2 as tune_session
if train_session and tune_session:
logger.warning(
"Expected to be either in tune session or train session but not both."
)
return None
if not (train_session or tune_session):
logger.warning("In neither tune session nor train session!")
return None
return train_session or tune_session
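
To illustrate the dispatch above, a small sketch of how `_get_session()` behaves outside and inside a session (the driver-side assertion is only for illustration):

```python
from ray.air import session
from ray.air._internal.session import _get_session

# On the driver, outside of any Tune or Train session, `_get_session()`
# logs a warning and returns None.
assert _get_session() is None


def train_loop_per_worker():
    # Inside a Train worker loop (or a Tune function trainable), the same
    # public entry points resolve to the active session implementation.
    session.report({"loss": 0.1})
```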


@ -0,0 +1,188 @@
from collections import Counter
from typing import Dict, List, Optional, Union
from tensorflow.keras.callbacks import Callback as KerasCallback
from ray.air import session
from ray.air.checkpoint import Checkpoint
from ray.util.annotations import PublicAPI
class _Callback(KerasCallback):
"""Base class for Air's Keras callbacks."""
_allowed = [
"batch_begin",
"batch_end",
"epoch_begin",
"epoch_end",
"train_batch_begin",
"train_batch_end",
"test_batch_begin",
"test_batch_end",
"predict_batch_begin",
"predict_batch_end",
"train_begin",
"train_end",
"test_begin",
"test_end",
"predict_begin",
"predict_end",
]
def __init__(self, on: Union[str, List[str]] = "validation_end"):
super(_Callback, self).__init__()
if not isinstance(on, list):
on = [on]
if any(w not in self._allowed for w in on):
raise ValueError(
"Invalid trigger time selected: {}. Must be one of {}".format(
on, self._allowed
)
)
self._on = on
def _handle(self, logs: Dict, when: str):
raise NotImplementedError
def on_batch_begin(self, batch, logs=None):
if "batch_begin" in self._on:
self._handle(logs, "batch_begin")
def on_batch_end(self, batch, logs=None):
if "batch_end" in self._on:
self._handle(logs, "batch_end")
def on_epoch_begin(self, epoch, logs=None):
if "epoch_begin" in self._on:
self._handle(logs, "epoch_begin")
def on_epoch_end(self, epoch, logs=None):
if "epoch_end" in self._on:
self._handle(logs, "epoch_end")
def on_train_batch_begin(self, batch, logs=None):
if "train_batch_begin" in self._on:
self._handle(logs, "train_batch_begin")
def on_train_batch_end(self, batch, logs=None):
if "train_batch_end" in self._on:
self._handle(logs, "train_batch_end")
def on_test_batch_begin(self, batch, logs=None):
if "test_batch_begin" in self._on:
self._handle(logs, "test_batch_begin")
def on_test_batch_end(self, batch, logs=None):
if "test_batch_end" in self._on:
self._handle(logs, "test_batch_end")
def on_predict_batch_begin(self, batch, logs=None):
if "predict_batch_begin" in self._on:
self._handle(logs, "predict_batch_begin")
def on_predict_batch_end(self, batch, logs=None):
if "predict_batch_end" in self._on:
self._handle(logs, "predict_batch_end")
def on_train_begin(self, logs=None):
if "train_begin" in self._on:
self._handle(logs, "train_begin")
def on_train_end(self, logs=None):
if "train_end" in self._on:
self._handle(logs, "train_end")
def on_test_begin(self, logs=None):
if "test_begin" in self._on:
self._handle(logs, "test_begin")
def on_test_end(self, logs=None):
if "test_end" in self._on:
self._handle(logs, "test_end")
def on_predict_begin(self, logs=None):
if "predict_begin" in self._on:
self._handle(logs, "predict_begin")
def on_predict_end(self, logs=None):
if "predict_end" in self._on:
self._handle(logs, "predict_end")
@PublicAPI(stability="beta")
class Callback(_Callback):
def __init__(
self,
metrics: Optional[Union[str, List[str], Dict[str, str]]] = None,
on: Union[str, List[str]] = "epoch_end",
frequency: Union[int, List[int]] = 1,
):
"""
Args:
metrics: Metrics to report. If this is a list, each item describes
the metric key reported to Keras, and it will be reported under the
same name. If this is a dict, each key will be the name reported
and the respective value will be the metric key reported to Keras.
If this is None, all Keras logs will be reported.
on: When to report metrics. Must be one of
the Keras event hooks (less the ``on_``), e.g.
"train_begin" or "predict_end". Defaults to "epoch_end".
frequency: Checkpoint frequency. If this is an integer `n`,
checkpoints are saved every `n` times each hook was called. If
this is a list, it specifies the checkpoint frequencies for each
hook individually.
You can use this in both TuneSession and TrainSession.
Example:
.. code-block:: python
############# Using it in TrainSession ###############
from ray.air.callbacks.keras import Callback
def train_loop_per_worker():
strategy = tf.distribute.MultiWorkerMirroredStrategy()
with strategy.scope():
model = build_model()
# model.compile(...)
model.fit(dataset_shard, callbacks=[Callback()])
"""
if isinstance(frequency, list):
if not isinstance(on, list) or len(frequency) != len(on):
raise ValueError(
"If you pass a list for checkpoint frequencies, the `on` "
"parameter has to be a list with the same length."
)
self._frequency = frequency
super(Callback, self).__init__(on)
self._metrics = metrics
self._counter = Counter()
def _handle(self, logs: Dict, when: str = None):
self._counter[when] += 1
if isinstance(self._frequency, list):
index = self._on.index(when)
freq = self._frequency[index]
else:
freq = self._frequency
checkpoint = None
if freq > 0 and self._counter[when] % freq == 0:
self.model.save("my_model", overwrite=True)
checkpoint = Checkpoint.from_directory("my_model")
if not self._metrics:
report_dict = logs
else:
report_dict = {}
for key in self._metrics:
if isinstance(self._metrics, dict):
metric = self._metrics[key]
else:
metric = key
report_dict[key] = logs[metric]
session.report(report_dict, checkpoint=checkpoint)
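
Since the docstring notes the callback works with both the Tune and Train sessions, here is an illustrative sketch of the Tune side (the toy Keras model and data are assumptions, not from this PR):

```python
import numpy as np
import tensorflow as tf

from ray.air.callbacks.keras import Callback
from ray.tune.tuner import Tuner


def trainable(config):
    model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(1,))])
    model.compile(optimizer="sgd", loss="mse")
    x = np.arange(32, dtype="float32").reshape(-1, 1)
    # On each epoch_end the callback saves "my_model" into this trial's
    # working directory and reports it through session.report(), which
    # dispatches to the Tune session here.
    model.fit(x, 2 * x, epochs=3, verbose=0, callbacks=[Callback()])


Tuner(trainable).fit()
```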

python/ray/air/session.py (new file, 258 lines)

@ -0,0 +1,258 @@
from typing import TYPE_CHECKING, Dict, Optional, Union
from ray.air._internal.session import _get_session
from ray.air.checkpoint import Checkpoint
from ray.train.session import _TrainSessionImpl
if TYPE_CHECKING:
from ray.data import Dataset, DatasetPipeline
def report(metrics: Dict, *, checkpoint: Optional[Checkpoint] = None) -> None:
"""Report metrics and optionally save a checkpoint.
Each invocation of this method will automatically increment the underlying
iteration number. The physical meaning of this "iteration" is defined by
the user (or more specifically by the way they call ``report``).
It does not necessarily map to one epoch.
This API is the canonical way to report metrics from Tune and Train, and
replaces the legacy ``tune.report``, ``with tune.checkpoint_dir``,
``train.report`` and ``train.save_checkpoint`` calls.
Note on directory checkpoints: AIR will take ownership of checkpoints passed
to ``report()`` by moving them to a new path. The original directory will no
longer be accessible to the caller after the report call.
Example:
.. code-block:: python
from ray.air import session
from ray.air.checkpoint import Checkpoint
######## Using it in the *per worker* train loop (TrainSession) #######
def train_func():
model = build_model()
model.save("my_model", overwrite=True)
session.report(
metrics={"foo": "bar"},
checkpoint=Checkpoint.from_directory("my_model")
)
# AIR guarantees that by this point, you can safely write new content
# to the "my_model" directory.
scaling_config = {"num_workers": 2}
trainer = TensorflowTrainer(
train_loop_per_worker=train_func, scaling_config=scaling_config
)
result = trainer.fit()
# If you navigate to result.checkpoint's path, you will find the
# content of ``model.save()`` under it.
# If you have `SyncConfig` configured, the content should also
# show up in the corresponding cloud storage path.
Args:
metrics: The metrics you want to report.
checkpoint: The optional checkpoint you want to report.
"""
_get_session().report(metrics, checkpoint=checkpoint)
def get_checkpoint() -> Optional[Checkpoint]:
"""Access the session's last checkpoint to resume from if applicable.
Returns:
Checkpoint object if the session is currently being resumed.
Otherwise, return None.
Example:
.. code-block:: python
######## Using it in the *per worker* train loop (TrainSession) ######
from ray.air import session
from ray.air.checkpoint import Checkpoint
def train_func():
if session.get_checkpoint():
with session.get_checkpoint().as_directory() as loaded_checkpoint_dir:
import tensorflow as tf
model = tf.keras.models.load_model(loaded_checkpoint_dir)
else:
model = build_model()
model.save("my_model", overwrite=True)
session.report(
metrics={"iter": 1},
checkpoint=Checkpoint.from_directory("my_model")
)
scaling_config = {"num_workers": 2}
trainer = TensorflowTrainer(
train_loop_per_worker=train_func, scaling_config=scaling_config
)
result = trainer.fit()
# trainer2 will pick up from the checkpoint saved by trainer1.
trainer2 = TensorflowTrainer(
train_loop_per_worker=train_func,
scaling_config=scaling_config,
# this is ultimately what is accessed through
# ``Session.get_checkpoint()``
resume_from_checkpoint=result.checkpoint,
)
result2 = trainer2.fit()
"""
return _get_session().loaded_checkpoint
def get_trial_name() -> str:
"""Trial name for the corresponding trial."""
return _get_session().trial_name
def get_trial_id() -> str:
"""Trial id for the corresponding trial."""
return _get_session().trial_id
def get_trial_resources() -> Dict[str, float]:
"""Trial resources for the corresponding trial."""
return _get_session().trial_resources
def get_world_size() -> int:
"""Get the current world size (i.e. total number of workers) for this run.
.. code-block:: python
import time
from ray.air import session
def train_loop_per_worker(config):
assert session.get_world_size() == 4
train_dataset = ray.data.from_items(
[{"x": x, "y": x + 1} for x in range(32)])
trainer = TensorflowTrainer(train_loop_per_worker,
scaling_config={"num_workers": 1},
datasets={"train": train_dataset})
trainer.fit()
"""
session = _get_session()
if not isinstance(session, _TrainSessionImpl):
raise RuntimeError(
"`get_world_size` can only be called for TrainSession! "
"Make sure you only use that in `train_loop_per_worker` function"
"that is passed into `DataParallelTrainer`."
)
return session.world_size
def get_world_rank() -> int:
"""Get the world rank of this worker.
.. code-block:: python
import time
from ray.air import session
def train_loop_per_worker():
for iter in range(100):
time.sleep(1)
if session.get_world_rank() == 0:
print("Worker 0")
train_dataset = ray.data.from_items(
[{"x": x, "y": x + 1} for x in range(32)])
trainer = TensorflowTrainer(train_loop_per_worker,
scaling_config={"num_workers": 1},
datasets={"train": train_dataset})
trainer.fit()
"""
session = _get_session()
if not isinstance(session, _TrainSessionImpl):
raise RuntimeError(
"`get_world_rank` can only be called for TrainSession! "
"Make sure you only use that in `train_loop_per_worker` function"
"that is passed into `DataParallelTrainer`."
)
return session.world_rank
def get_local_rank() -> int:
"""Get the local rank of this worker (rank of the worker on its node).
.. code-block:: python
import time
from ray.air import session
def train_loop_per_worker():
if torch.cuda.is_available():
torch.cuda.set_device(session.get_local_rank())
...
train_dataset = ray.data.from_items(
[{"x": x, "y": x + 1} for x in range(32)])
trainer = TensorflowTrainer(train_loop_per_worker,
scaling_config={"num_workers": 1},
datasets={"train": train_dataset})
trainer.fit()
"""
session = _get_session()
if not isinstance(session, _TrainSessionImpl):
raise RuntimeError(
"`get_local_rank` can only be called for TrainSession! "
"Make sure you only use that in `train_loop_per_worker` function"
"that is passed into `DataParallelTrainer`."
)
return session.local_rank
def get_dataset_shard(
dataset_name: Optional[str] = None,
) -> Optional[Union["Dataset", "DatasetPipeline"]]:
"""Returns the Ray Dataset or DatasetPipeline shard for this worker.
You should call ``to_torch()`` or ``to_tf()`` on this shard to convert
it to the appropriate framework-specific Dataset.
.. code-block:: python
import ray
from ray import train
from ray.air import session
def train_loop_per_worker():
model = Net()
for iter in range(100):
# Trainer will automatically handle sharding.
data_shard = session.get_dataset_shard().to_torch()
model.train(data_shard)
return model
train_dataset = ray.data.from_items(
[{"x": x, "y": x + 1} for x in range(32)])
trainer = TorchTrainer(train_loop_per_worker,
scaling_config={"num_workers": 2},
datasets={"train": train_dataset})
trainer.fit()
Args:
dataset_name: If a Dictionary of Datasets was passed to ``Trainer``, then
specifies which dataset shard to return.
Returns:
The ``Dataset`` or ``DatasetPipeline`` shard to use for this worker.
If no dataset is passed into Trainer, then return None.
"""
session = _get_session()
if not isinstance(session, _TrainSessionImpl):
raise RuntimeError(
"`get_dataset_shard` can only be called for TrainSession! "
"Make sure you only use that in `train_loop_per_worker` function"
"that is passed into `DataParallelTrainer`."
)
return session.get_dataset_shard(dataset_name)


@ -0,0 +1,64 @@
import os
import tensorflow as tf
from ray.air import session
from ray.air.callbacks.keras import Callback
from ray.air.examples.tf.tensorflow_linear_dataset_example import (
build_model,
get_dataset,
)
from ray.train.constants import TRAIN_DATASET_KEY
from ray.train.tensorflow import TensorflowTrainer, prepare_dataset_shard
def train_func(config: dict):
strategy = tf.distribute.MultiWorkerMirroredStrategy()
with strategy.scope():
# Model building/compiling need to be within `strategy.scope()`.
multi_worker_model = build_model()
multi_worker_model.compile(
optimizer=tf.keras.optimizers.SGD(learning_rate=config.get("lr", 1e-3)),
loss=tf.keras.losses.mean_squared_error,
metrics=[tf.keras.metrics.mean_squared_error],
)
dataset = session.get_dataset_shard("train")
for _ in range(config.get("epochs", 3)):
tf_dataset = prepare_dataset_shard(
dataset.to_tf(
label_column="y",
output_signature=(
tf.TensorSpec(shape=(None, 1), dtype=tf.float32),
tf.TensorSpec(shape=(None), dtype=tf.float32),
),
batch_size=32,
)
)
multi_worker_model.fit(tf_dataset, callbacks=[Callback()])
def test_keras_callback():
epochs = 3
scaling_config = {"num_workers": 2}
config = {
"epochs": epochs,
}
trainer = TensorflowTrainer(
train_loop_per_worker=train_func,
train_loop_config=config,
scaling_config=scaling_config,
datasets={TRAIN_DATASET_KEY: get_dataset()},
)
checkpoint = trainer.fit().checkpoint
with checkpoint.as_directory() as ckpt_dir:
assert os.path.exists(os.path.join(ckpt_dir, "saved_model.pb"))
if __name__ == "__main__":
import sys
import pytest
sys.exit(pytest.main(["-v", "-x", __file__]))


@ -1,23 +1,29 @@
import logging
import os
from collections import defaultdict
from typing import Callable, List, Optional, Dict, Type, Tuple, TypeVar
from typing import Callable, Dict, List, Optional, Tuple, Type, TypeVar
import ray
from ray.air.checkpoint import Checkpoint
from ray.exceptions import RayActorError
from ray.ray_constants import env_integer
from ray.train._internal.dataset_spec import RayDatasetSpec
from ray.train._internal.session import (
TrainingResult,
TrialInfo,
get_session,
init_session,
shutdown_session,
)
from ray.train._internal.utils import check_for_failure
from ray.train._internal.worker_group import WorkerGroup
from ray.train.backend import BackendConfig
from ray.train.constants import (
ENABLE_DETAILED_AUTOFILLED_METRICS_ENV,
ENABLE_SHARE_CUDA_VISIBLE_DEVICES_ENV,
TRAIN_PLACEMENT_GROUP_TIMEOUT_S_ENV,
TRAIN_ENABLE_WORKER_SPREAD_ENV,
TRAIN_PLACEMENT_GROUP_TIMEOUT_S_ENV,
)
from ray.train.backend import BackendConfig
from ray.train._internal.dataset_spec import RayDatasetSpec
from ray.train._internal.session import TrainingResult
from ray.train._internal.session import init_session, get_session, shutdown_session
from ray.train._internal.utils import check_for_failure
from ray.train._internal.worker_group import WorkerGroup
from ray.util.placement_group import get_current_placement_group, remove_placement_group
T = TypeVar("T")
@ -57,6 +63,8 @@ class BackendExecutor:
def __init__(
self,
backend_config: BackendConfig,
# TODO(xwjiang): Legacy Ray Train trainer clean up!
trial_info: Optional[TrialInfo] = None,
num_workers: int = 1,
num_cpus_per_worker: float = 1,
num_gpus_per_worker: float = 0,
@ -76,6 +84,8 @@ class BackendExecutor:
self._initialization_hook = None
self._placement_group = None
self._trial_info = trial_info
self.worker_group = InactiveWorkerGroup()
self.dataset_shards = None
@ -264,7 +274,7 @@ class BackendExecutor:
self,
train_func: Callable[[], T],
dataset_spec: RayDatasetSpec,
checkpoint: Optional[Dict] = None,
checkpoint: Optional[Checkpoint] = None,
) -> None:
"""Executes a training function on all workers in a separate thread.
@ -290,6 +300,7 @@ class BackendExecutor:
world_rank,
local_rank,
world_size,
trial_info,
checkpoint,
dataset_shard,
encode_data_fn,
@ -300,6 +311,7 @@ class BackendExecutor:
world_rank=world_rank,
local_rank=local_rank,
world_size=world_size,
trial_info=trial_info,
dataset_shard=dataset_shard,
checkpoint=checkpoint,
encode_data_fn=encode_data_fn,
@ -328,6 +340,7 @@ class BackendExecutor:
world_rank=index,
local_rank=local_rank_map[index],
world_size=len(self.worker_group),
trial_info=self._trial_info,
train_func=train_func,
dataset_shard=self.dataset_shards[index],
checkpoint=checkpoint,


@ -1,22 +1,21 @@
import logging
from pathlib import Path
from typing import List, Optional, Dict, Union, Callable
from typing import Callable, Dict, List, Optional, Union
from ray.air import Checkpoint
from ray.train._internal.session import TrainingResult
from ray.train._internal.utils import construct_path
from ray.train.constants import (
TIMESTAMP,
TRAIN_CHECKPOINT_SUBDIR,
TUNE_CHECKPOINT_ID,
TUNE_INSTALLED,
)
from ray.train._internal.session import TrainingResult
from ray.train._internal.utils import construct_path
from ray.util.ml_utils.checkpoint_manager import CheckpointStorage, CheckpointStrategy
from ray.util.ml_utils.checkpoint_manager import (
_CheckpointManager as CommonCheckpointManager,
_TrackedCheckpoint,
CheckpointStrategy,
CheckpointStorage,
)
from ray.util.ml_utils.checkpoint_manager import _TrackedCheckpoint
if TUNE_INSTALLED:
from ray import tune
@ -80,14 +79,17 @@ class CheckpointManager(CommonCheckpointManager):
if self._checkpoint_strategy.checkpoint_score_attribute is None:
self._checkpoint_strategy.checkpoint_score_attribute = TIMESTAMP
# TODO(xwjiang): Legacy Ray Train trainer clean up!
def _load_checkpoint(
self, checkpoint_to_load: Optional[Union[Dict, str, Path]]
) -> Optional[Dict]:
self, checkpoint_to_load: Optional[Union[Dict, str, Path, Checkpoint]]
) -> Optional[Union[Dict, Checkpoint]]:
"""Load the checkpoint dictionary from the input dict or path."""
if checkpoint_to_load is None:
return None
if isinstance(checkpoint_to_load, Dict):
return checkpoint_to_load
if isinstance(checkpoint_to_load, Checkpoint):
return checkpoint_to_load
else:
# Load checkpoint from path.
return load_checkpoint_from_path(checkpoint_to_load)
@ -198,9 +200,17 @@ class CheckpointManager(CommonCheckpointManager):
class TuneCheckpointManager(CheckpointManager):
def _load_checkpoint(
self, checkpoint_to_load: Optional[Union[Dict, str, Path]]
) -> Optional[Dict]:
self, checkpoint_to_load: Optional[Union[Dict, str, Path, Checkpoint]]
) -> Optional[Union[Dict, Checkpoint]]:
# TODO(xwjiang): Legacy Ray Train trainer clean up!
loaded_checkpoint = super()._load_checkpoint(checkpoint_to_load)
# New path...
if isinstance(loaded_checkpoint, Checkpoint):
# The new logic
checkpoint_dict = loaded_checkpoint.to_dict()
self._latest_checkpoint_id = checkpoint_dict[TUNE_CHECKPOINT_ID]
return loaded_checkpoint
# legacy path...
if loaded_checkpoint is not None:
# If the Tune trial is restarted, a new Trainer is instantiated.
# However, we want the checkpoint_id to continue incrementing


@ -3,29 +3,30 @@ import platform
import queue
import threading
import time
from datetime import datetime
from dataclasses import dataclass
from datetime import datetime
from enum import Enum, auto
from typing import Callable
from typing import Optional, Dict, Type, Union
from typing import Callable, Dict, Optional, Type, Union
import ray
from ray.air.checkpoint import Checkpoint
from ray.data import Dataset, DatasetPipeline
from ray.train._internal.accelerator import Accelerator
from ray.train.constants import (
DETAILED_AUTOFILLED_KEYS,
TIME_THIS_ITER_S,
PID,
TIMESTAMP,
TIME_TOTAL_S,
NODE_IP,
TRAINING_ITERATION,
HOSTNAME,
DATE,
RESULT_FETCH_TIMEOUT,
)
from ray.train._internal.utils import PropagatingThread
from ray.train.constants import (
DATE,
DETAILED_AUTOFILLED_KEYS,
HOSTNAME,
NODE_IP,
PID,
RESULT_FETCH_TIMEOUT,
TIME_THIS_ITER_S,
TIME_TOTAL_S,
TIMESTAMP,
TRAINING_ITERATION,
)
from ray.train.error import SessionMisuseError
from ray.train.session import _TrainSessionImpl
class TrainingResultType(Enum):
@ -33,13 +34,24 @@ class TrainingResultType(Enum):
CHECKPOINT = auto()
@dataclass
class TrialInfo:
"""The trial information to propagate to TrainSession."""
name: str
id: str
resources: Dict[str, float]
logdir: str
@dataclass
class TrainingResult:
type: TrainingResultType
data: Dict
class Session:
# TODO(xwjiang): This needs a better name.
class _TrainSession:
"""Holds information for training on each worker."""
def __init__(
@ -48,8 +60,11 @@ class Session:
world_rank: int,
local_rank: int,
world_size: int,
# TODO(xwjiang): Legacy Ray Train trainer clean up!
trial_info: Optional[TrialInfo] = None,
dataset_shard: Optional[Union[Dataset, DatasetPipeline]] = None,
checkpoint: Optional[Dict] = None,
# TODO(xwjiang): Legacy Ray Train trainer clean up!
checkpoint: Optional[Union[Dict, Checkpoint]] = None,
encode_data_fn: Callable = None,
detailed_autofilled_metrics: bool = False,
):
@ -61,7 +76,9 @@ class Session:
self.world_rank = world_rank
self.local_rank = local_rank
self.world_size = world_size
self.loaded_checkpoint = checkpoint
self.trial_info = trial_info
# TODO(xwjiang): Legacy Ray Train trainer clean up!
self.loaded_checkpoint: Optional[Union[Dict, Checkpoint]] = checkpoint
# Function to encode checkpoint dict before sending to the driver.
if not encode_data_fn:
@ -72,6 +89,13 @@ class Session:
encode_data_fn = noop
self._encode_data_fn = encode_data_fn
# TODO(xwjiang): Legacy Ray Train trainer clean up!
if trial_info:
# Change the working directory to `logdir`.
logdir = os.path.join(trial_info.logdir, f"rank_{self.world_rank}")
os.makedirs(logdir, exist_ok=True)
os.chdir(logdir)
# This lock is used to control the execution of the training thread.
self.continue_lock = threading.Semaphore(0)
@ -184,7 +208,7 @@ class Session:
result.update(auto_filled_metrics)
return result
def report(self, **kwargs):
def _report_legacy(self, **kwargs):
"""Adds kwargs to the queue to be consumed by main thread."""
if self.ignore_report:
return
@ -234,22 +258,32 @@ class Session:
# checkpoint has been processed.
self.continue_lock.acquire()
def report(self, metrics: Dict, checkpoint: Optional[Checkpoint] = None) -> None:
# TODO(xwjiang): tons of optimizations.
if checkpoint:
checkpoint_dict = checkpoint.to_dict()
self.checkpoint(**checkpoint_dict)
self._report_legacy(**metrics)
_session = None
_session: Optional[_TrainSession] = None
# V2 Session API
_session_v2: Optional[_TrainSessionImpl] = None
def init_session(*args, **kwargs) -> None:
global _session
global _session_v2
if _session:
raise ValueError(
"A Train session is already in use. Do not call "
"`init_session()` manually."
)
_session = Session(*args, **kwargs)
_session = _TrainSession(*args, **kwargs)
_session_v2 = _TrainSessionImpl(session=_session)
def get_session() -> Optional[Session]:
global _session
def get_session() -> Optional[_TrainSession]:
return _session
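
The `rank_<i>` directories in the layout at the top of this PR come from the `chdir` above. A short sketch of what each worker observes (the asserted layout is inferred from this diff):

```python
import os

from ray.air import session


def train_loop_per_worker():
    # Each worker now runs inside its own rank_<world_rank> subdirectory of
    # the trial logdir, so relative paths like "my_model" do not collide
    # across workers.
    rank = session.get_world_rank()
    assert os.path.basename(os.getcwd()) == f"rank_{rank}"
    session.report({"cwd": os.getcwd()})
```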


@ -1,33 +1,22 @@
import inspect
import logging
import os
from pathlib import Path
from typing import (
Any,
Callable,
Dict,
Optional,
Tuple,
Union,
Type,
TYPE_CHECKING,
)
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Type, Union
import ray
from ray import tune
from ray.air.constants import MODEL_KEY, PREPROCESSOR_KEY
from ray.train.constants import (
TRAIN_DATASET_KEY,
WILDCARD_KEY,
)
from ray.train.trainer import BaseTrainer
from ray.air.config import ScalingConfig, RunConfig, DatasetConfig
from ray.train.trainer import GenDataset
from ray.air import session
from ray.air.checkpoint import Checkpoint
from ray.train._internal.dataset_spec import DataParallelIngestSpec
from ray.air.config import DatasetConfig, RunConfig, ScalingConfig
from ray.air.constants import MODEL_KEY, PREPROCESSOR_KEY
from ray.train import BackendConfig, TrainingIterator
from ray.train._internal.backend_executor import BackendExecutor
from ray.train._internal.backend_executor import BackendExecutor, TrialInfo
from ray.train._internal.checkpoint import TuneCheckpointManager
from ray.train._internal.dataset_spec import DataParallelIngestSpec
from ray.train._internal.utils import construct_train_func
from ray.train.constants import TRAIN_DATASET_KEY, WILDCARD_KEY
from ray.train.trainer import BaseTrainer, GenDataset
from ray.util.annotations import DeveloperAPI
from ray.util.ml_utils.checkpoint_manager import CheckpointStrategy, _TrackedCheckpoint
@ -335,8 +324,16 @@ class DataParallelTrainer(BaseTrainer):
scaling_config_dataclass.additional_resources_per_worker
)
trial_info = TrialInfo(
name=session.get_trial_name(),
id=session.get_trial_id(),
resources=session.get_trial_resources(),
logdir=os.getcwd(),
)
backend_executor = BackendExecutor(
backend_config=self._backend_config,
trial_info=trial_info,
num_workers=scaling_config_dataclass.num_workers,
num_cpus_per_worker=scaling_config_dataclass.num_cpus_per_worker,
num_gpus_per_worker=scaling_config_dataclass.num_gpus_per_worker,
@ -351,20 +348,13 @@ class DataParallelTrainer(BaseTrainer):
# Start the remote actors.
backend_executor.start(initialization_hook=None)
if self.resume_from_checkpoint:
resume_checkpoint_dict = self.resume_from_checkpoint.to_dict()
else:
resume_checkpoint_dict = None
# TODO(amog): Have TrainingIterator also accept a checkpoint ObjectRef instead
# of just a Dict.
training_iterator = TrainingIterator(
backend_executor=backend_executor,
backend_config=self._backend_config,
train_func=train_loop_per_worker,
dataset_spec=self._ingest_spec,
checkpoint_manager=checkpoint_manager,
checkpoint=resume_checkpoint_dict,
checkpoint=self.resume_from_checkpoint,
checkpoint_strategy=None,
)


@ -16,6 +16,7 @@ from torch.utils.data import Dataset as TorchDataset
from transformers.trainer import TRAINING_ARGS_NAME, WEIGHTS_NAME
from ray import train
from ray.air import session
from ray.air._internal.checkpointing import (
load_preprocessor_from_dir,
save_preprocessor_to_dir,
@ -517,12 +518,14 @@ def _huggingface_train_loop_per_worker(config):
trainer.add_callback(TrainReportCallback)
checkpoint = train.load_checkpoint()
checkpoint = session.get_checkpoint()
checkpoint_path = None
remove_checkpoint_path = False
if checkpoint:
source_ip = checkpoint[NODE_IP_KEY]
source_path = checkpoint[CHECKPOINT_PATH_ON_NODE_KEY]
assert isinstance(checkpoint, Checkpoint)
checkpoint_dict = checkpoint.to_dict()
source_ip = checkpoint_dict[NODE_IP_KEY]
source_path = checkpoint_dict[CHECKPOINT_PATH_ON_NODE_KEY]
target_ip = get_node_ip_address()
if source_ip == target_ip:
checkpoint_path = source_path


@ -0,0 +1,79 @@
import warnings
from typing import TYPE_CHECKING, Dict, Optional, Union
from ray.air._internal.session import Session
from ray.air.checkpoint import Checkpoint
if TYPE_CHECKING:
# avoid circular import
from ray.data import Dataset, DatasetPipeline
from ray.train._internal.session import _TrainSession
class _TrainSessionImpl(Session):
"""Session client that "per worker train loop" can interact with.
Notice that each worker will automatically switch to its working
directory on entering the train loop. This is to ensure that
each worker can safely write to a local directory without racing
and overwriting each other."""
def __init__(self, session: "_TrainSession"):
self._session = session
def report(self, metrics: Dict, *, checkpoint: Optional[Checkpoint] = None) -> None:
self._session.report(metrics, checkpoint)
@property
def loaded_checkpoint(self) -> Optional[Checkpoint]:
ckpt = self._session.loaded_checkpoint
if ckpt:
# The new API should only interact with Checkpoint object.
assert isinstance(ckpt, Checkpoint)
return ckpt
@property
def trial_name(self) -> str:
return self._session.trial_info.name
@property
def trial_id(self) -> str:
return self._session.trial_info.id
@property
def trial_resources(self) -> Dict[str, float]:
return self._session.trial_info.resources
@property
def world_size(self) -> int:
return self._session.world_size
@property
def world_rank(self) -> int:
return self._session.world_rank
@property
def local_rank(self) -> int:
return self._session.local_rank
def get_dataset_shard(
self,
dataset_name: Optional[str] = None,
) -> Optional[Union["Dataset", "DatasetPipeline"]]:
shard = self._session.dataset_shard
if shard is None:
warnings.warn(
"No dataset passed in. Returning None. Make sure to "
"pass in a Ray Dataset to Trainer.run to use this "
"function."
)
elif isinstance(shard, dict):
if not dataset_name:
raise RuntimeError(
"Multiple datasets were passed into ``Trainer``, "
"but no ``dataset_name`` is passed into "
"``get_dataset_shard``. Please specify which "
"dataset shard to retrieve."
)
return shard.get(dataset_name)
return shard


@ -1,11 +1,11 @@
import pytest
import ray
from ray import train, tune
from ray import tune
from ray.air import session
from ray.air.checkpoint import Checkpoint
from ray.train.constants import PREPROCESSOR_KEY
from ray.data.preprocessor import Preprocessor
from ray.train.constants import PREPROCESSOR_KEY
from ray.train.data_parallel_trainer import DataParallelTrainer
from ray.tune.tune_config import TuneConfig
from ray.tune.tuner import Tuner
@ -24,7 +24,7 @@ scale_config = {"num_workers": 2}
def test_fit_train(ray_start_4_cpus):
def train_func():
train.report(loss=1)
session.report({"loss": 1})
trainer = DataParallelTrainer(
train_loop_per_worker=train_func, scaling_config=scale_config
@ -35,7 +35,7 @@ def test_fit_train(ray_start_4_cpus):
def test_scaling_config(ray_start_4_cpus):
def train_func():
assert ray.available_resources()["CPU"] == 1
train.report(loss=1)
session.report({"loss": 1})
assert ray.available_resources()["CPU"] == 4
trainer = DataParallelTrainer(
@ -46,7 +46,7 @@ def test_scaling_config(ray_start_4_cpus):
def test_fit_train_config(ray_start_4_cpus):
def train_func(config):
train.report(loss=config["x"])
session.report({"loss": config["x"]})
trainer = DataParallelTrainer(
train_loop_per_worker=train_func,
@ -65,10 +65,10 @@ def test_datasets(ray_start_4_cpus):
def get_dataset():
# Train dataset should be sharded.
train_dataset = train.get_dataset_shard("train")
train_dataset = session.get_dataset_shard("train")
assert train_dataset.count() == num_train_data / scale_config["num_workers"]
# All other datasets should not be sharded.
val_dataset = train.get_dataset_shard("val")
val_dataset = session.get_dataset_shard("val")
assert val_dataset.count() == num_val_data
trainer = DataParallelTrainer(
@ -82,7 +82,7 @@ def test_datasets(ray_start_4_cpus):
def test_checkpoint(ray_start_4_cpus):
def train_func():
for i in range(3):
train.save_checkpoint(model=i)
session.report({"epoch": i}, checkpoint=Checkpoint.from_dict({"model": i}))
trainer = DataParallelTrainer(
train_loop_per_worker=train_func, scaling_config=scale_config
@ -99,7 +99,7 @@ def test_preprocessor_in_checkpoint(ray_start_4_cpus):
def train_func():
for i in range(3):
train.save_checkpoint(model=i)
session.report({"epoch": i}, checkpoint=Checkpoint.from_dict({"model": i}))
trainer = DataParallelTrainer(
train_loop_per_worker=train_func,
@ -113,13 +113,13 @@ def test_preprocessor_in_checkpoint(ray_start_4_cpus):
def test_resume_from_checkpoint(ray_start_4_cpus, tmpdir):
def train_func():
checkpoint = train.load_checkpoint()
checkpoint = session.get_checkpoint()
if checkpoint:
epoch = checkpoint["epoch"]
epoch = checkpoint.to_dict()["epoch"]
else:
epoch = 0
for i in range(epoch, epoch + 2):
train.save_checkpoint(epoch=i)
session.report({"epoch": i}, checkpoint=Checkpoint.from_dict({"epoch": i}))
trainer = DataParallelTrainer(
train_loop_per_worker=train_func, scaling_config=scale_config
@ -152,7 +152,7 @@ def test_invalid_train_loop(ray_start_4_cpus):
def test_tune(ray_start_4_cpus):
def train_func(config):
train.report(loss=config["x"])
session.report({"loss": config["x"]})
trainer = DataParallelTrainer(
train_loop_per_worker=train_func,
@ -173,7 +173,8 @@ def test_tune(ray_start_4_cpus):
if __name__ == "__main__":
import pytest
import sys
import pytest
sys.exit(pytest.main(["-v", "-x", __file__]))


@ -1,8 +1,12 @@
import os
import numpy as np
import pytest
import ray
from ray import train
from ray.air import session
from ray.air.checkpoint import Checkpoint
from ray.air.examples.tf.tensorflow_linear_dataset_example import get_dataset
from ray.air.examples.tf.tensorflow_linear_dataset_example import (
train_func as tensorflow_linear_train_func,
@ -34,6 +38,8 @@ def build_model():
@pytest.mark.parametrize("num_workers", [1, 2])
def test_tensorflow_linear(ray_start_4_cpus, num_workers):
"""Also tests air Keras callback."""
def train_func(config):
result = tensorflow_linear_train_func(config)
assert len(result) == epochs
@ -83,6 +89,39 @@ def test_tensorflow_e2e(ray_start_4_cpus):
assert predictions.count() == 3
def test_report_and_load_using_ml_session(ray_start_4_cpus):
def train_func():
if session.get_checkpoint():
with session.get_checkpoint().as_directory() as checkpoint_dir:
import tensorflow as tf
model = tf.keras.models.load_model(checkpoint_dir)
else:
model = build_model()
model.save("my_model", overwrite=True)
session.report(
metrics={"iter": 1}, checkpoint=Checkpoint.from_directory("my_model")
)
scaling_config = {"num_workers": 2}
trainer = TensorflowTrainer(
train_loop_per_worker=train_func, scaling_config=scaling_config
)
result = trainer.fit()
trainer2 = TensorflowTrainer(
train_loop_per_worker=train_func,
scaling_config=scaling_config,
resume_from_checkpoint=result.checkpoint,
)
result = trainer2.fit()
checkpoint = result.checkpoint
with checkpoint.as_directory() as ckpt_dir:
assert os.path.exists(os.path.join(ckpt_dir, "saved_model.pb"))
assert result.metrics["iter"] == 1
if __name__ == "__main__":
import sys


@ -1,11 +1,8 @@
from typing import TYPE_CHECKING
from typing import Optional, Dict, Union
import warnings
from typing import TYPE_CHECKING, Dict, Optional, Union
from ray.train._internal.session import get_session
from ray.train.constants import SESSION_MISUSE_LOG_ONCE_KEY
from ray.train._internal.session import (
get_session,
)
from ray.util import PublicAPI, log_once
if TYPE_CHECKING:
@ -118,7 +115,7 @@ def report(**kwargs) -> None:
if session is None:
_warn_session_misuse(report.__name__)
return
session.report(**kwargs)
session._report_legacy(**kwargs)
@PublicAPI(stability="beta")


@ -1,23 +1,14 @@
import copy
from datetime import datetime
import logging
import os
from pathlib import Path
from typing import Union, Callable, List, TypeVar, Optional, Any, Dict, Type
import warnings
from datetime import datetime
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union
import ray
from ray.actor import ActorHandle
from ray.train.backend import (
BackendConfig,
)
from ray.train.callbacks.callback import TrainingCallback
from ray.train._internal.dataset_spec import RayDataset, RayDatasetSpec
from ray.train._internal.session import TrainingResultType
from ray.train._internal.utils import (
construct_train_func,
ActorWrapper,
)
from ray.air.checkpoint import Checkpoint
from ray.train._internal.backend_executor import (
BackendExecutor,
InactiveWorkerGroupError,
@ -25,35 +16,37 @@ from ray.train._internal.backend_executor import (
TrainingWorkerError,
)
from ray.train._internal.checkpoint import (
TuneCheckpointManager,
CheckpointManager,
TuneCheckpointManager,
load_checkpoint_from_path,
)
from ray.train.constants import (
TUNE_INSTALLED,
DEFAULT_RESULTS_DIR,
ENABLE_DETAILED_AUTOFILLED_METRICS_ENV,
ENABLE_SHARE_CUDA_VISIBLE_DEVICES_ENV,
TRAIN_PLACEMENT_GROUP_TIMEOUT_S_ENV,
TRAIN_ENABLE_WORKER_SPREAD_ENV,
)
from ray.train._internal.dataset_spec import RayDataset, RayDatasetSpec
from ray.train._internal.session import TrainingResultType
# Ray Train should be usable even if Tune is not installed.
from ray.train._internal.utils import construct_path
from ray.train._internal.utils import ActorWrapper, construct_path, construct_train_func
from ray.train._internal.worker_group import WorkerGroup
from ray.util.annotations import DeveloperAPI, Deprecated
from ray.util.ml_utils.checkpoint_manager import CheckpointStrategy
from ray.train.backend import BackendConfig
from ray.train.base_trainer import ( # noqa: F401
BaseTrainer,
GenDataset,
TrainingFailedError,
)
from ray.train.callbacks.callback import TrainingCallback
from ray.train.constants import (
DEFAULT_RESULTS_DIR,
ENABLE_DETAILED_AUTOFILLED_METRICS_ENV,
ENABLE_SHARE_CUDA_VISIBLE_DEVICES_ENV,
TRAIN_ENABLE_WORKER_SPREAD_ENV,
TRAIN_PLACEMENT_GROUP_TIMEOUT_S_ENV,
TUNE_INSTALLED,
)
from ray.util.annotations import Deprecated, DeveloperAPI
from ray.util.ml_utils.checkpoint_manager import CheckpointStrategy
if TUNE_INSTALLED:
from ray import tune
from ray.tune import Trainable
from ray.tune import PlacementGroupFactory
from ray.tune import PlacementGroupFactory, Trainable
from ray.tune.function_runner import wrap_function
else:
tune = PlacementGroupFactory = Trainable = object
@ -676,7 +669,7 @@ class TrainingIterator:
train_func: Union[Callable[[], T], Callable[[Dict[str, Any]], T]],
dataset_spec: RayDatasetSpec,
checkpoint_manager: CheckpointManager,
checkpoint: Optional[Union[Dict, str, Path]],
checkpoint: Optional[Union[Dict, str, Path, Checkpoint]],
checkpoint_strategy: Optional[CheckpointStrategy],
run_dir: Optional[Path] = None,
):
@ -715,12 +708,12 @@ class TrainingIterator:
run_dir=run_dir,
latest_checkpoint_id=latest_checkpoint_id,
)
checkpoint_dict = self._checkpoint_manager._load_checkpoint(checkpoint)
checkpoint = self._checkpoint_manager._load_checkpoint(checkpoint)
self._run_with_error_handling(
lambda: self._backend_executor.start_training(
train_func=train_func,
dataset_spec=dataset_spec,
checkpoint=checkpoint_dict,
checkpoint=checkpoint,
)
)


@ -1,34 +1,34 @@
import inspect
import logging
import os
import sys
import time
import inspect
import shutil
import sys
import threading
import time
import uuid
from functools import partial
from numbers import Number
from typing import Any, Callable, Dict, Optional
from typing import Any, Callable, Optional
from ray.util.annotations import DeveloperAPI
from six.moves import queue
from ray.util.debug import log_once
from ray.air.checkpoint import Checkpoint
from ray.tune import TuneError, session
from ray.tune.trainable import Trainable, TrainableUtil
from ray.tune.result import (
DEFAULT_METRIC,
TIME_THIS_ITER_S,
RESULT_DUPLICATE,
SHOULD_CHECKPOINT,
TIME_THIS_ITER_S,
)
from ray.tune.trainable import Trainable, TrainableUtil
from ray.tune.utils import (
detect_checkpoint_function,
detect_config_single,
detect_reporter,
)
from ray.tune.utils.trainable import with_parameters # noqa: F401
from ray.util.annotations import DeveloperAPI
from ray.util.debug import log_once
logger = logging.getLogger(__name__)
@ -123,8 +123,6 @@ class FuncCheckpointUtil:
class _StatusReporter:
"""Object passed into your function that you can report status through."""
def __init__(
self,
result_queue,
@ -145,6 +143,10 @@ class _StatusReporter:
self._last_checkpoint = None
self._fresh_checkpoint = False
self._trial_resources = trial_resources
# Also used as a marker of whether the new `report()` API is being used;
# in that case, `_iter` will be incremented from 0 every time `report`
# is called.
self._iter = None
def reset(self, trial_name=None, trial_id=None, logdir=None, trial_resources=None):
self._trial_name = trial_name
@ -153,6 +155,7 @@ class _StatusReporter:
self._last_checkpoint = None
self._fresh_checkpoint = False
self._trial_resources = trial_resources
self._iter = None
def __call__(self, _metric=None, **kwargs):
"""Report updated training status.
@ -168,7 +171,7 @@ class _StatusReporter:
"""
assert self._last_report_time is not None, (
"StatusReporter._start() must be called before the first "
"_StatusReporter._start() must be called before the first "
"report __call__ is made to ensure correct runtime metrics."
)
@ -229,6 +232,26 @@ class _StatusReporter:
def _start(self):
self._last_report_time = time.time()
def report(self, metrics: Dict, *, checkpoint: Optional[Checkpoint] = None) -> None:
# TODO(xwjiang): Tons of optimizations.
if not self._iter:
self._iter = 0
if checkpoint:
checkpoint_dir = self.make_checkpoint_dir(step=self._iter)
self.set_checkpoint(checkpoint_dir)
checkpoint.to_directory(checkpoint_dir)
# TODO(krfricke): Remove this once support is added in Checkpoint.
open(os.path.join(checkpoint_dir, ".is_checkpoint"), "a").close()
self.__call__(**metrics)
self._iter += 1
@property
def loaded_checkpoint(self) -> Optional[Checkpoint]:
if self._last_checkpoint:
assert isinstance(self._last_checkpoint, str)
return Checkpoint.from_directory(self._last_checkpoint)
return None
@property
def logdir(self):
return self._logdir
@ -484,7 +507,7 @@ class FunctionRunner(Trainable):
obj = TrainableUtil.checkpoint_to_object(checkpoint_path)
return obj
def load_checkpoint(self, checkpoint):
def load_checkpoint(self, checkpoint: str):
# This should be removed once Trainables are refactored.
if "tune_checkpoint_path" in checkpoint:
del checkpoint["tune_checkpoint_path"]
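
A sketch of a new-style Tune function trainable exercising the `report()` / `loaded_checkpoint` path added above (the dict checkpoint and step counts are illustrative):

```python
from ray.air import session
from ray.air.checkpoint import Checkpoint
from ray.tune.tuner import Tuner


def trainable(config):
    start = 0
    ckpt = session.get_checkpoint()
    if ckpt:
        # The loaded checkpoint comes back as a Checkpoint object backed by
        # the trial's checkpoint directory.
        start = ckpt.to_dict()["step"]
    for step in range(start, start + 3):
        session.report(
            {"step": step},
            checkpoint=Checkpoint.from_dict({"step": step + 1}),
        )


Tuner(trainable).fit()
```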


@ -1,113 +1,9 @@
from collections import Counter
from typing import Dict, List, Union, Optional
from tensorflow.keras.callbacks import Callback
from ray import tune
import os
from collections import Counter
from typing import Dict, List, Optional, Union
class TuneCallback(Callback):
"""Base class for Tune's Keras callbacks."""
_allowed = [
"batch_begin",
"batch_end",
"epoch_begin",
"epoch_end",
"train_batch_begin",
"train_batch_end",
"test_batch_begin",
"test_batch_end",
"predict_batch_begin",
"predict_batch_end",
"train_begin",
"train_end",
"test_begin",
"test_end",
"predict_begin",
"predict_end",
]
def __init__(self, on: Union[str, List[str]] = "validation_end"):
super(TuneCallback, self).__init__()
if not isinstance(on, list):
on = [on]
if any(w not in self._allowed for w in on):
raise ValueError(
"Invalid trigger time selected: {}. Must be one of {}".format(
on, self._allowed
)
)
self._on = on
def _handle(self, logs: Dict, when: str):
raise NotImplementedError
def on_batch_begin(self, batch, logs=None):
if "batch_begin" in self._on:
self._handle(logs, "batch_begin")
def on_batch_end(self, batch, logs=None):
if "batch_end" in self._on:
self._handle(logs, "batch_end")
def on_epoch_begin(self, epoch, logs=None):
if "epoch_begin" in self._on:
self._handle(logs, "epoch_begin")
def on_epoch_end(self, epoch, logs=None):
if "epoch_end" in self._on:
self._handle(logs, "epoch_end")
def on_train_batch_begin(self, batch, logs=None):
if "train_batch_begin" in self._on:
self._handle(logs, "train_batch_begin")
def on_train_batch_end(self, batch, logs=None):
if "train_batch_end" in self._on:
self._handle(logs, "train_batch_end")
def on_test_batch_begin(self, batch, logs=None):
if "test_batch_begin" in self._on:
self._handle(logs, "test_batch_begin")
def on_test_batch_end(self, batch, logs=None):
if "test_batch_end" in self._on:
self._handle(logs, "test_batch_end")
def on_predict_batch_begin(self, batch, logs=None):
if "predict_batch_begin" in self._on:
self._handle(logs, "predict_batch_begin")
def on_predict_batch_end(self, batch, logs=None):
if "predict_batch_end" in self._on:
self._handle(logs, "predict_batch_end")
def on_train_begin(self, logs=None):
if "train_begin" in self._on:
self._handle(logs, "train_begin")
def on_train_end(self, logs=None):
if "train_end" in self._on:
self._handle(logs, "train_end")
def on_test_begin(self, logs=None):
if "test_begin" in self._on:
self._handle(logs, "test_begin")
def on_test_end(self, logs=None):
if "test_end" in self._on:
self._handle(logs, "test_end")
def on_predict_begin(self, logs=None):
if "predict_begin" in self._on:
self._handle(logs, "predict_begin")
def on_predict_end(self, logs=None):
if "predict_end" in self._on:
self._handle(logs, "predict_end")
from ray import tune
from ray.air.callbacks.keras import _Callback as TuneCallback
class TuneReportCallback(TuneCallback):


@ -1,23 +1,54 @@
from contextlib import contextmanager
import inspect
import os
import logging
import os
import traceback
from contextlib import contextmanager
from typing import Dict, Optional, Set
import ray
from ray.air._internal.session import Session
from ray.air.checkpoint import Checkpoint
from ray.tune.error import TuneError
from ray.tune.function_runner import _StatusReporter
from ray.util.annotations import DeveloperAPI, PublicAPI
from ray.util.debug import log_once
from ray.util.annotations import PublicAPI, DeveloperAPI
from ray.util.placement_group import _valid_resource_shape
from ray.util.scheduling_strategies import (
SchedulingStrategyT,
PlacementGroupSchedulingStrategy,
SchedulingStrategyT,
)
from ray.tune.error import TuneError
logger = logging.getLogger(__name__)
_session = None
_session: Optional[_StatusReporter] = None
# V2 Session API.
_session_v2: Optional["_TuneSessionImpl"] = None
class _TuneSessionImpl(Session):
"""Session client that function trainable can interact with."""
def __init__(self, status_reporter: _StatusReporter):
self._status_reporter = status_reporter
def report(self, metrics: Dict, *, checkpoint: Optional[Checkpoint] = None) -> None:
self._status_reporter.report(metrics, checkpoint=checkpoint)
@property
def loaded_checkpoint(self) -> Optional[Checkpoint]:
return self._status_reporter.loaded_checkpoint
@property
def trial_name(self) -> str:
return self._status_reporter.trial_name
@property
def trial_id(self) -> str:
return self._status_reporter.trial_id
@property
def trial_resources(self) -> Dict[str, float]:
return self._status_reporter.trial_resources.required_resources
@PublicAPI
@ -50,6 +81,7 @@ def get_session():
def init(reporter, ignore_reinit_error=True):
"""Initializes the global trial context for this process."""
global _session
global _session_v2
if _session is not None:
# TODO(ng): would be nice to stack crawl at creation time to report
@ -83,6 +115,7 @@ def init(reporter, ignore_reinit_error=True):
remote_function._task_launch_hook = tune_task_and_actor_launch_hook
_session = reporter
_session_v2 = _TuneSessionImpl(status_reporter=reporter)
# Cache of resource dicts that have been checked by the launch hook already.
@ -183,6 +216,11 @@ def report(_metric=None, **kwargs):
"""
_session = get_session()
if _session:
if _session._iter:
raise ValueError(
"It is not allowed to mix `tune.report` with `session.report`."
)
return _session(_metric, **kwargs)
@ -242,6 +280,11 @@ def checkpoint_dir(step: int):
raise ValueError("checkpoint_dir(step) must be provided - got None.")
if _session:
if _session._iter:
raise ValueError(
"It is not allowed to mix `with tune.checkpoint_dir` "
"with `session.report`."
)
_checkpoint_dir = _session.make_checkpoint_dir(step=step)
else:
_checkpoint_dir = os.path.abspath("./")