[air] Consolidate Tune and Train report (#25558)

Consolidate the Tune/Train report and checkpoint functionality behind a unified Session interface.
The goal of this PR is to establish a solid Session and Session.report path.
To reduce merge conflicts (other folks are doing the package renaming) and to keep the scope of this PR contained, I have intentionally left out some of the migration. More PRs will follow. Feel free to comment on the ideal final state.


To give an idea of the final directory structure, here is the output of a 2-worker data-parallel (DP) training run:
```
├── TensorflowTrainer_ce44d_00000_0_2022-06-15_14-40-42
│   ├── checkpoint_000000
│   │   ├── _current_checkpoint_id.meta.pkl
│   │   ├── _preprocessor.meta.pkl
│   │   ├── _timestamp.meta.pkl
│   │   ├── assets
│   │   ├── keras_metadata.pb
│   │   ├── saved_model.pb
│   │   └── variables
│   │       ├── variables.data-00000-of-00001
│   │       └── variables.index
│   ├── events.out.tfevents.1655329242.xw
│   ├── params.json
│   ├── params.pkl
│   ├── progress.csv
│   ├── rank_0
│   │   └── my_model
│   │       ├── assets
│   │       ├── keras_metadata.pb
│   │       ├── saved_model.pb
│   │       └── variables
│   │           ├── variables.data-00000-of-00001
│   │           └── variables.index
│   ├── rank_1
│   │   └── my_model
│   │       ├── assets
│   │       ├── keras_metadata.pb
│   │       ├── saved_model.pb
│   │       └── variables
│   │           ├── variables.data-00000-of-00001
│   │           └── variables.index
│   └── result.json
├── basic-variant-state-2022-06-15_14-40-42.json
├── experiment_state-2022-06-15_14-40-42.json
├── trainable.pkl
└── tuner.pkl
```
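To make the mapping concrete, here is a minimal sketch (not part of this PR's diff) of a train loop that would produce a layout like the one above; `build_model()` is a hypothetical Keras model factory. Each worker writes `my_model` inside its own `rank_<i>` working directory, and the checkpoint passed to `session.report` is what ends up as `checkpoint_000000` plus the `*.meta.pkl` files AIR adds.
```
from ray.air import session
from ray.air.checkpoint import Checkpoint
from ray.train.tensorflow import TensorflowTrainer


def train_loop_per_worker():
    model = build_model()  # hypothetical model factory
    # Each worker runs inside its own rank_<i>/ working directory, so
    # "my_model" shows up under rank_0/ and rank_1/ in the tree above.
    model.save("my_model", overwrite=True)
    # The reported checkpoint becomes checkpoint_000000/ in the trial dir,
    # alongside the metadata files that AIR adds.
    session.report(
        metrics={"iter": 1},
        checkpoint=Checkpoint.from_directory("my_model"),
    )


trainer = TensorflowTrainer(
    train_loop_per_worker=train_loop_per_worker,
    scaling_config={"num_workers": 2},
)
result = trainer.fit()
```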
Update:
1. Updated a few classes to stay backward compatible while the legacy Ray Train deprecation is ongoing.
2. Marked all of the places from item 1 with "# TODO(xwjiang): Legacy Ray Train trainer clean up!" so we can easily clean them up once Antoni's work lands.
3. All CI and release tests are passing.

Co-authored-by: Eric Liang <ekhliang@gmail.com>
Author: xwjiang2010, 2022-06-17 13:49:01 -07:00 (committed by GitHub)
Parent: 2b270fd9cb
Commit: 97f42425da
18 changed files with 978 additions and 252 deletions


@ -193,6 +193,14 @@ py_test(
deps = [":ml_lib"] deps = [":ml_lib"]
) )
py_test(
name = "test_keras_callback",
size = "small",
srcs = ["tests/test_keras_callback.py"],
tags = ["team:ml", "exclusive"],
deps = [":ml_lib"]
)
py_test( py_test(
name = "test_remote_storage", name = "test_remote_storage",
size = "small", size = "small",


@ -0,0 +1,87 @@
import abc
import logging
from typing import Dict, Optional
from ray.air.checkpoint import Checkpoint
logger = logging.getLogger(__name__)
class Session(abc.ABC):
"""The canonical session interface that both Tune and Train session implements.
User can interact with this interface to get session information,
as well as reporting metrics and saving checkpoint.
"""
@abc.abstractmethod
def report(self, metrics: Dict, *, checkpoint: Optional[Checkpoint] = None) -> None:
"""Report metrics and optionally save checkpoint.
Each invocation of this method will automatically increment the underlying
iteration number. The physical meaning of this "iteration" is defined by
the user (or, more specifically, by the way they call ``report``).
It does not necessarily map to one epoch.
This API is intended to replace the legacy ``tune.report``,
``with tune.checkpoint_dir``, ``train.report`` and ``train.save_checkpoint``.
Please avoid mixing them together.
There is no requirement on the underlying representation of the
checkpoint: all forms are accepted and (will be) handled by AIR in an
efficient way. Specifically, if you pass in a directory checkpoint, AIR will
move the contents of the directory to an AIR-managed directory. By the time
this method returns, you may safely write new content to the original
directory without interfering with the AIR checkpointing flow.
Args:
metrics: The metrics you want to report.
checkpoint: The optional checkpoint you want to report.
"""
raise NotImplementedError
@property
@abc.abstractmethod
def loaded_checkpoint(self) -> Optional[Checkpoint]:
"""Access the session's loaded checkpoint to resume from if applicable.
Returns:
Checkpoint object if the session is currently being resumed.
Otherwise, return None.
"""
raise NotImplementedError
@property
def trial_name(self) -> str:
"""Trial name for the corresponding trial."""
raise NotImplementedError
@property
def trial_id(self) -> str:
"""Trial id for the corresponding trial."""
raise NotImplementedError
@property
def trial_resources(self) -> Dict[str, float]:
"""Trial resources for the corresponding trial."""
raise NotImplementedError
def _get_session() -> Optional[Session]:
from ray.train._internal.session import _session_v2 as train_session
from ray.tune.session import _session_v2 as tune_session
if train_session and tune_session:
logger.warning(
"Expected to be either in tune session or train session but not both."
)
return None
if not (train_session or tune_session):
logger.warning("In neither tune session nor train session!")
return None
return train_session or tune_session
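
Because both the Tune and Train session implementations resolved by `_get_session()` above satisfy this interface, the same user code works in a plain Tune function trainable. A rough sketch under that assumption (the objective and the `Tuner` invocation in the trailing comment are illustrative only):
```
from ray.air import session
from ray.air.checkpoint import Checkpoint


def trainable(config):
    # Resume from the restored checkpoint if Tune restarted this trial.
    start = 0
    ckpt = session.get_checkpoint()
    if ckpt:
        start = ckpt.to_dict()["step"] + 1
    for step in range(start, 10):
        loss = config["x"] * step  # illustrative objective
        # Same unified call as in the Train workers; resolved to the
        # Tune session by _get_session() above.
        session.report(
            {"loss": loss}, checkpoint=Checkpoint.from_dict({"step": step})
        )

# e.g. run with: Tuner(trainable, param_space={"x": 1}).fit()
```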


@ -0,0 +1,188 @@
from collections import Counter
from typing import Dict, List, Optional, Union
from tensorflow.keras.callbacks import Callback as KerasCallback
from ray.air import session
from ray.air.checkpoint import Checkpoint
from ray.util.annotations import PublicAPI
class _Callback(KerasCallback):
"""Base class for Air's Keras callbacks."""
_allowed = [
"batch_begin",
"batch_end",
"epoch_begin",
"epoch_end",
"train_batch_begin",
"train_batch_end",
"test_batch_begin",
"test_batch_end",
"predict_batch_begin",
"predict_batch_end",
"train_begin",
"train_end",
"test_begin",
"test_end",
"predict_begin",
"predict_end",
]
def __init__(self, on: Union[str, List[str]] = "validation_end"):
super(_Callback, self).__init__()
if not isinstance(on, list):
on = [on]
if any(w not in self._allowed for w in on):
raise ValueError(
"Invalid trigger time selected: {}. Must be one of {}".format(
on, self._allowed
)
)
self._on = on
def _handle(self, logs: Dict, when: str):
raise NotImplementedError
def on_batch_begin(self, batch, logs=None):
if "batch_begin" in self._on:
self._handle(logs, "batch_begin")
def on_batch_end(self, batch, logs=None):
if "batch_end" in self._on:
self._handle(logs, "batch_end")
def on_epoch_begin(self, epoch, logs=None):
if "epoch_begin" in self._on:
self._handle(logs, "epoch_begin")
def on_epoch_end(self, epoch, logs=None):
if "epoch_end" in self._on:
self._handle(logs, "epoch_end")
def on_train_batch_begin(self, batch, logs=None):
if "train_batch_begin" in self._on:
self._handle(logs, "train_batch_begin")
def on_train_batch_end(self, batch, logs=None):
if "train_batch_end" in self._on:
self._handle(logs, "train_batch_end")
def on_test_batch_begin(self, batch, logs=None):
if "test_batch_begin" in self._on:
self._handle(logs, "test_batch_begin")
def on_test_batch_end(self, batch, logs=None):
if "test_batch_end" in self._on:
self._handle(logs, "test_batch_end")
def on_predict_batch_begin(self, batch, logs=None):
if "predict_batch_begin" in self._on:
self._handle(logs, "predict_batch_begin")
def on_predict_batch_end(self, batch, logs=None):
if "predict_batch_end" in self._on:
self._handle(logs, "predict_batch_end")
def on_train_begin(self, logs=None):
if "train_begin" in self._on:
self._handle(logs, "train_begin")
def on_train_end(self, logs=None):
if "train_end" in self._on:
self._handle(logs, "train_end")
def on_test_begin(self, logs=None):
if "test_begin" in self._on:
self._handle(logs, "test_begin")
def on_test_end(self, logs=None):
if "test_end" in self._on:
self._handle(logs, "test_end")
def on_predict_begin(self, logs=None):
if "predict_begin" in self._on:
self._handle(logs, "predict_begin")
def on_predict_end(self, logs=None):
if "predict_end" in self._on:
self._handle(logs, "predict_end")
@PublicAPI(stability="beta")
class Callback(_Callback):
def __init__(
self,
metrics: Optional[Union[str, List[str], Dict[str, str]]] = None,
on: Union[str, List[str]] = "epoch_end",
frequency: Union[int, List[int]] = 1,
):
"""
Args:
metrics: Metrics to report. If this is a list, each item describes
the metric key reported to Keras, and it will be reported under the
same name. If this is a dict, each key will be the name reported
and the respective value will be the metric key reported to Keras.
If this is None, all Keras logs will be reported.
on: When to report metrics. Must be one of
the Keras event hooks (without the ``on_`` prefix), e.g.
"train_begin" or "predict_end". Defaults to "epoch_end".
frequency: Checkpoint frequency. If this is an integer `n`,
a checkpoint is saved every `n` times each hook is called. If
this is a list, it specifies the checkpoint frequency for each
hook individually.
You can use this callback in both a TuneSession and a TrainSession.
Example:
.. code-block:: python
############# Using it in TrainSession ###############
from ray.air.callbacks.keras import Callback
def train_loop_per_worker():
strategy = tf.distribute.MultiWorkerMirroredStrategy()
with strategy.scope():
model = build_model()
#model.compile(...)
model.fit(dataset_shard, callbacks=[Callback()])
"""
if isinstance(frequency, list):
if not isinstance(on, list) or len(frequency) != len(on):
raise ValueError(
"If you pass a list for checkpoint frequencies, the `on` "
"parameter has to be a list with the same length."
)
self._frequency = frequency
super(Callback, self).__init__(on)
self._metrics = metrics
self._counter = Counter()
def _handle(self, logs: Dict, when: str = None):
self._counter[when] += 1
if isinstance(self._frequency, list):
index = self._on.index(when)
freq = self._frequency[index]
else:
freq = self._frequency
checkpoint = None
if freq > 0 and self._counter[when] % freq == 0:
self.model.save("my_model", overwrite=True)
checkpoint = Checkpoint.from_directory("my_model")
if not self._metrics:
report_dict = logs
else:
report_dict = {}
for key in self._metrics:
if isinstance(self._metrics, dict):
metric = self._metrics[key]
else:
metric = key
report_dict[key] = logs[metric]
session.report(report_dict, checkpoint=checkpoint)
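
A hedged usage sketch of the per-hook `frequency` handling above (the metric names are illustrative, and `model`/`dataset` are assumed to be a compiled Keras model and its input):
```
from ray.air.callbacks.keras import Callback

callback = Callback(
    metrics={"loss": "loss"},       # report Keras "loss" under the same name
    on=["epoch_end", "train_end"],  # hooks to report on
    frequency=[1, 0],               # checkpoint on every epoch_end, never on train_end
)
# model.fit(dataset, callbacks=[callback])
```
A `frequency` entry of 0 skips checkpointing for that hook while still reporting its metrics, since a checkpoint is only saved when `freq > 0`.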

python/ray/air/session.py (new file)

@ -0,0 +1,258 @@
from typing import TYPE_CHECKING, Dict, Optional, Union
from ray.air._internal.session import _get_session
from ray.air.checkpoint import Checkpoint
from ray.train.session import _TrainSessionImpl
if TYPE_CHECKING:
from ray.data import Dataset, DatasetPipeline
def report(metrics: Dict, *, checkpoint: Optional[Checkpoint] = None) -> None:
"""Report metrics and optionally save a checkpoint.
Each invocation of this method will automatically increment the underlying
iteration number. The physical meaning of this "iteration" is defined by
the user (or, more specifically, by the way they call ``report``).
It does not necessarily map to one epoch.
This API is the canonical way to report metrics from Tune and Train, and
replaces the legacy ``tune.report``, ``with tune.checkpoint_dir``,
``train.report`` and ``train.save_checkpoint`` calls.
Note on directory checkpoints: AIR will take ownership of checkpoints passed
to ``report()`` by moving their contents to an AIR-managed path. By the time
``report()`` returns, the caller may safely write new content to the original
directory.
Example:
.. code-block:: python
from ray.air import session
from ray.air.checkpoint import Checkpoint
######## Using it in the *per worker* train loop (TrainSession) #######
def train_func():
model = build_model()
model.save("my_model", overwrite=True)
session.report(
metrics={"foo": "bar"},
checkpoint=Checkpoint.from_directory("my_model")
)
# AIR guarantees that by this point you can safely write new
# content to the "my_model" directory.
scaling_config = {"num_workers": 2}
trainer = TensorflowTrainer(
train_loop_per_worker=train_func, scaling_config=scaling_config
)
result = trainer.fit()
# If you navigate to result.checkpoint's path, you will find the
# content of ``model.save()`` under it.
# If you have `SyncConfig` configured, the content should also
# show up in the corresponding cloud storage path.
Args:
metrics: The metrics you want to report.
checkpoint: The optional checkpoint you want to report.
"""
_get_session().report(metrics, checkpoint=checkpoint)
def get_checkpoint() -> Optional[Checkpoint]:
"""Access the session's last checkpoint to resume from if applicable.
Returns:
Checkpoint object if the session is currently being resumed.
Otherwise, return None.
Example:
.. code-block:: python
######## Using it in the *per worker* train loop (TrainSession) ######
from ray.air import session
from ray.air.checkpoint import Checkpoint
def train_func():
if session.get_checkpoint():
with session.get_checkpoint().as_directory() as loaded_checkpoint_dir:
import tensorflow as tf
model = tf.keras.models.load_model(loaded_checkpoint_dir)
else:
model = build_model()
model.save("my_model", overwrite=True)
session.report(
metrics={"iter": 1},
checkpoint=Checkpoint.from_directory("my_model")
)
scaling_config = {"num_workers": 2}
trainer = TensorflowTrainer(
train_loop_per_worker=train_func, scaling_config=scaling_config
)
result = trainer.fit()
# trainer2 will pick up from the checkpoint saved by the first trainer.
trainer2 = TensorflowTrainer(
train_loop_per_worker=train_func,
scaling_config=scaling_config,
# this is ultimately what is accessed through
# ``Session.get_checkpoint()``
resume_from_checkpoint=result.checkpoint,
)
result2 = trainer2.fit()
"""
return _get_session().loaded_checkpoint
def get_trial_name() -> str:
"""Trial name for the corresponding trial."""
return _get_session().trial_name
def get_trial_id() -> str:
"""Trial id for the corresponding trial."""
return _get_session().trial_id
def get_trial_resources() -> Dict[str, float]:
"""Trial resources for the corresponding trial."""
return _get_session().trial_resources
def get_world_size() -> int:
"""Get the current world size (i.e. total number of workers) for this run.
.. code-block:: python
import time
from ray.air import session
def train_loop_per_worker(config):
assert session.get_world_size() == 4
train_dataset = ray.data.from_items(
[{"x": x, "y": x + 1} for x in range(32)])
trainer = TensorflowTrainer(train_loop_per_worker,
scaling_config={"num_workers": 1},
datasets={"train": train_dataset})
trainer.fit()
"""
session = _get_session()
if not isinstance(session, _TrainSessionImpl):
raise RuntimeError(
"`get_world_size` can only be called for TrainSession! "
"Make sure you only use that in `train_loop_per_worker` function"
"that is passed into `DataParallelTrainer`."
)
return session.world_size
def get_world_rank() -> int:
"""Get the world rank of this worker.
.. code-block:: python
import time
from ray.air import session
def train_loop_per_worker():
for iter in range(100):
time.sleep(1)
if session.get_world_rank() == 0:
print("Worker 0")
train_dataset = ray.data.from_items(
[{"x": x, "y": x + 1} for x in range(32)])
trainer = TensorflowTrainer(train_loop_per_worker,
scaling_config={"num_workers": 1},
datasets={"train": train_dataset})
trainer.fit()
"""
session = _get_session()
if not isinstance(session, _TrainSessionImpl):
raise RuntimeError(
"`get_world_rank` can only be called for TrainSession! "
"Make sure you only use that in `train_loop_per_worker` function"
"that is passed into `DataParallelTrainer`."
)
return session.world_rank
def get_local_rank() -> int:
"""Get the local rank of this worker (rank of the worker on its node).
.. code-block:: python
import time
from ray.air import session
def train_loop_per_worker():
if torch.cuda.is_available():
torch.cuda.set_device(session.get_local_rank())
...
train_dataset = ray.data.from_items(
[{"x": x, "y": x + 1} for x in range(32)])
trainer = TensorflowTrainer(train_loop_per_worker,
scaling_config={"num_workers": 1},
datasets={"train": train_dataset})
trainer.fit()
"""
session = _get_session()
if not isinstance(session, _TrainSessionImpl):
raise RuntimeError(
"`get_local_rank` can only be called for TrainSession! "
"Make sure you only use that in `train_loop_per_worker` function"
"that is passed into `DataParallelTrainer`."
)
return session.local_rank
def get_dataset_shard(
dataset_name: Optional[str] = None,
) -> Optional[Union["Dataset", "DatasetPipeline"]]:
"""Returns the Ray Dataset or DatasetPipeline shard for this worker.
You should call ``to_torch()`` or ``to_tf()`` on this shard to convert
it to the appropriate framework-specific Dataset.
.. code-block:: python
import ray
from ray import train
from ray.air import session
def train_loop_per_worker():
model = Net()
for iter in range(100):
# Trainer will automatically handle sharding.
data_shard = session.get_dataset_shard().to_torch()
model.train(data_shard)
return model
train_dataset = ray.data.from_items(
[{"x": x, "y": x + 1} for x in range(32)])
trainer = TorchTrainer(train_loop_per_worker,
scaling_config={"num_workers": 2},
datasets={"train": train_dataset})
trainer.fit()
Args:
dataset_name: If a Dictionary of Datasets was passed to ``Trainer``, then
specifies which dataset shard to return.
Returns:
The ``Dataset`` or ``DatasetPipeline`` shard to use for this worker.
If no dataset is passed into Trainer, then return None.
"""
session = _get_session()
if not isinstance(session, _TrainSessionImpl):
raise RuntimeError(
"`get_dataset_shard` can only be called for TrainSession! "
"Make sure you only use that in `train_loop_per_worker` function"
"that is passed into `DataParallelTrainer`."
)
return session.get_dataset_shard(dataset_name)
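
The trial-info getters above do not carry inline examples, so here is a short, purely illustrative sketch of how they could be used from a worker loop:
```
from ray.air import session


def train_loop_per_worker():
    # Available in both Tune and Train sessions.
    print(session.get_trial_name(), session.get_trial_id())
    print(session.get_trial_resources())
    # Train-session-only helpers; these raise RuntimeError in a pure Tune session.
    if session.get_world_rank() == 0:
        print("world size:", session.get_world_size())
    session.report({"done": 1})
```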


@ -0,0 +1,64 @@
import os
import tensorflow as tf
from ray.air import session
from ray.air.callbacks.keras import Callback
from ray.air.examples.tf.tensorflow_linear_dataset_example import (
build_model,
get_dataset,
)
from ray.train.constants import TRAIN_DATASET_KEY
from ray.train.tensorflow import TensorflowTrainer, prepare_dataset_shard
def train_func(config: dict):
strategy = tf.distribute.MultiWorkerMirroredStrategy()
with strategy.scope():
# Model building/compiling need to be within `strategy.scope()`.
multi_worker_model = build_model()
multi_worker_model.compile(
optimizer=tf.keras.optimizers.SGD(learning_rate=config.get("lr", 1e-3)),
loss=tf.keras.losses.mean_squared_error,
metrics=[tf.keras.metrics.mean_squared_error],
)
dataset = session.get_dataset_shard("train")
for _ in range(config.get("epochs", 3)):
tf_dataset = prepare_dataset_shard(
dataset.to_tf(
label_column="y",
output_signature=(
tf.TensorSpec(shape=(None, 1), dtype=tf.float32),
tf.TensorSpec(shape=(None), dtype=tf.float32),
),
batch_size=32,
)
)
multi_worker_model.fit(tf_dataset, callbacks=[Callback()])
def test_keras_callback():
epochs = 3
scaling_config = {"num_workers": 2}
config = {
"epochs": epochs,
}
trainer = TensorflowTrainer(
train_loop_per_worker=train_func,
train_loop_config=config,
scaling_config=scaling_config,
datasets={TRAIN_DATASET_KEY: get_dataset()},
)
checkpoint = trainer.fit().checkpoint
with checkpoint.as_directory() as ckpt_dir:
assert os.path.exists(os.path.join(ckpt_dir, "saved_model.pb"))
if __name__ == "__main__":
import sys
import pytest
sys.exit(pytest.main(["-v", "-x", __file__]))


@ -1,23 +1,29 @@
import logging import logging
import os import os
from collections import defaultdict from collections import defaultdict
from typing import Callable, List, Optional, Dict, Type, Tuple, TypeVar from typing import Callable, Dict, List, Optional, Tuple, Type, TypeVar
import ray import ray
from ray.air.checkpoint import Checkpoint
from ray.exceptions import RayActorError from ray.exceptions import RayActorError
from ray.ray_constants import env_integer from ray.ray_constants import env_integer
from ray.train._internal.dataset_spec import RayDatasetSpec
from ray.train._internal.session import (
TrainingResult,
TrialInfo,
get_session,
init_session,
shutdown_session,
)
from ray.train._internal.utils import check_for_failure
from ray.train._internal.worker_group import WorkerGroup
from ray.train.backend import BackendConfig
from ray.train.constants import ( from ray.train.constants import (
ENABLE_DETAILED_AUTOFILLED_METRICS_ENV, ENABLE_DETAILED_AUTOFILLED_METRICS_ENV,
ENABLE_SHARE_CUDA_VISIBLE_DEVICES_ENV, ENABLE_SHARE_CUDA_VISIBLE_DEVICES_ENV,
TRAIN_PLACEMENT_GROUP_TIMEOUT_S_ENV,
TRAIN_ENABLE_WORKER_SPREAD_ENV, TRAIN_ENABLE_WORKER_SPREAD_ENV,
TRAIN_PLACEMENT_GROUP_TIMEOUT_S_ENV,
) )
from ray.train.backend import BackendConfig
from ray.train._internal.dataset_spec import RayDatasetSpec
from ray.train._internal.session import TrainingResult
from ray.train._internal.session import init_session, get_session, shutdown_session
from ray.train._internal.utils import check_for_failure
from ray.train._internal.worker_group import WorkerGroup
from ray.util.placement_group import get_current_placement_group, remove_placement_group from ray.util.placement_group import get_current_placement_group, remove_placement_group
T = TypeVar("T") T = TypeVar("T")
@ -57,6 +63,8 @@ class BackendExecutor:
def __init__( def __init__(
self, self,
backend_config: BackendConfig, backend_config: BackendConfig,
# TODO(xwjiang): Legacy Ray Train trainer clean up!
trial_info: Optional[TrialInfo] = None,
num_workers: int = 1, num_workers: int = 1,
num_cpus_per_worker: float = 1, num_cpus_per_worker: float = 1,
num_gpus_per_worker: float = 0, num_gpus_per_worker: float = 0,
@ -76,6 +84,8 @@ class BackendExecutor:
self._initialization_hook = None self._initialization_hook = None
self._placement_group = None self._placement_group = None
self._trial_info = trial_info
self.worker_group = InactiveWorkerGroup() self.worker_group = InactiveWorkerGroup()
self.dataset_shards = None self.dataset_shards = None
@ -264,7 +274,7 @@ class BackendExecutor:
self, self,
train_func: Callable[[], T], train_func: Callable[[], T],
dataset_spec: RayDatasetSpec, dataset_spec: RayDatasetSpec,
checkpoint: Optional[Dict] = None, checkpoint: Optional[Checkpoint] = None,
) -> None: ) -> None:
"""Executes a training function on all workers in a separate thread. """Executes a training function on all workers in a separate thread.
@ -290,6 +300,7 @@ class BackendExecutor:
world_rank, world_rank,
local_rank, local_rank,
world_size, world_size,
trial_info,
checkpoint, checkpoint,
dataset_shard, dataset_shard,
encode_data_fn, encode_data_fn,
@ -300,6 +311,7 @@ class BackendExecutor:
world_rank=world_rank, world_rank=world_rank,
local_rank=local_rank, local_rank=local_rank,
world_size=world_size, world_size=world_size,
trial_info=trial_info,
dataset_shard=dataset_shard, dataset_shard=dataset_shard,
checkpoint=checkpoint, checkpoint=checkpoint,
encode_data_fn=encode_data_fn, encode_data_fn=encode_data_fn,
@ -328,6 +340,7 @@ class BackendExecutor:
world_rank=index, world_rank=index,
local_rank=local_rank_map[index], local_rank=local_rank_map[index],
world_size=len(self.worker_group), world_size=len(self.worker_group),
trial_info=self._trial_info,
train_func=train_func, train_func=train_func,
dataset_shard=self.dataset_shards[index], dataset_shard=self.dataset_shards[index],
checkpoint=checkpoint, checkpoint=checkpoint,


@ -1,22 +1,21 @@
import logging import logging
from pathlib import Path from pathlib import Path
from typing import List, Optional, Dict, Union, Callable from typing import Callable, Dict, List, Optional, Union
from ray.air import Checkpoint from ray.air import Checkpoint
from ray.train._internal.session import TrainingResult
from ray.train._internal.utils import construct_path
from ray.train.constants import ( from ray.train.constants import (
TIMESTAMP, TIMESTAMP,
TRAIN_CHECKPOINT_SUBDIR, TRAIN_CHECKPOINT_SUBDIR,
TUNE_CHECKPOINT_ID, TUNE_CHECKPOINT_ID,
TUNE_INSTALLED, TUNE_INSTALLED,
) )
from ray.train._internal.session import TrainingResult from ray.util.ml_utils.checkpoint_manager import CheckpointStorage, CheckpointStrategy
from ray.train._internal.utils import construct_path
from ray.util.ml_utils.checkpoint_manager import ( from ray.util.ml_utils.checkpoint_manager import (
_CheckpointManager as CommonCheckpointManager, _CheckpointManager as CommonCheckpointManager,
_TrackedCheckpoint,
CheckpointStrategy,
CheckpointStorage,
) )
from ray.util.ml_utils.checkpoint_manager import _TrackedCheckpoint
if TUNE_INSTALLED: if TUNE_INSTALLED:
from ray import tune from ray import tune
@ -80,14 +79,17 @@ class CheckpointManager(CommonCheckpointManager):
if self._checkpoint_strategy.checkpoint_score_attribute is None: if self._checkpoint_strategy.checkpoint_score_attribute is None:
self._checkpoint_strategy.checkpoint_score_attribute = TIMESTAMP self._checkpoint_strategy.checkpoint_score_attribute = TIMESTAMP
# TODO(xwjiang): Legacy Ray Train trainer clean up!
def _load_checkpoint( def _load_checkpoint(
self, checkpoint_to_load: Optional[Union[Dict, str, Path]] self, checkpoint_to_load: Optional[Union[Dict, str, Path, Checkpoint]]
) -> Optional[Dict]: ) -> Optional[Union[Dict, Checkpoint]]:
"""Load the checkpoint dictionary from the input dict or path.""" """Load the checkpoint dictionary from the input dict or path."""
if checkpoint_to_load is None: if checkpoint_to_load is None:
return None return None
if isinstance(checkpoint_to_load, Dict): if isinstance(checkpoint_to_load, Dict):
return checkpoint_to_load return checkpoint_to_load
if isinstance(checkpoint_to_load, Checkpoint):
return checkpoint_to_load
else: else:
# Load checkpoint from path. # Load checkpoint from path.
return load_checkpoint_from_path(checkpoint_to_load) return load_checkpoint_from_path(checkpoint_to_load)
@ -198,9 +200,17 @@ class CheckpointManager(CommonCheckpointManager):
class TuneCheckpointManager(CheckpointManager): class TuneCheckpointManager(CheckpointManager):
def _load_checkpoint( def _load_checkpoint(
self, checkpoint_to_load: Optional[Union[Dict, str, Path]] self, checkpoint_to_load: Optional[Union[Dict, str, Path, Checkpoint]]
) -> Optional[Dict]: ) -> Optional[Union[Dict, Checkpoint]]:
# TODO(xwjiang): Legacy Ray Train trainer clean up!
loaded_checkpoint = super()._load_checkpoint(checkpoint_to_load) loaded_checkpoint = super()._load_checkpoint(checkpoint_to_load)
# New path...
if isinstance(loaded_checkpoint, Checkpoint):
# The new logic
checkpoint_dict = loaded_checkpoint.to_dict()
self._latest_checkpoint_id = checkpoint_dict[TUNE_CHECKPOINT_ID]
return loaded_checkpoint
# legacy path...
if loaded_checkpoint is not None: if loaded_checkpoint is not None:
# If the Tune trial is restarted, a new Trainer is instantiated. # If the Tune trial is restarted, a new Trainer is instantiated.
# However, we want the checkpoint_id to continue incrementing # However, we want the checkpoint_id to continue incrementing


@ -3,29 +3,30 @@ import platform
import queue import queue
import threading import threading
import time import time
from datetime import datetime
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime
from enum import Enum, auto from enum import Enum, auto
from typing import Callable from typing import Callable, Dict, Optional, Type, Union
from typing import Optional, Dict, Type, Union
import ray import ray
from ray.air.checkpoint import Checkpoint
from ray.data import Dataset, DatasetPipeline from ray.data import Dataset, DatasetPipeline
from ray.train._internal.accelerator import Accelerator from ray.train._internal.accelerator import Accelerator
from ray.train.constants import (
DETAILED_AUTOFILLED_KEYS,
TIME_THIS_ITER_S,
PID,
TIMESTAMP,
TIME_TOTAL_S,
NODE_IP,
TRAINING_ITERATION,
HOSTNAME,
DATE,
RESULT_FETCH_TIMEOUT,
)
from ray.train._internal.utils import PropagatingThread from ray.train._internal.utils import PropagatingThread
from ray.train.constants import (
DATE,
DETAILED_AUTOFILLED_KEYS,
HOSTNAME,
NODE_IP,
PID,
RESULT_FETCH_TIMEOUT,
TIME_THIS_ITER_S,
TIME_TOTAL_S,
TIMESTAMP,
TRAINING_ITERATION,
)
from ray.train.error import SessionMisuseError from ray.train.error import SessionMisuseError
from ray.train.session import _TrainSessionImpl
class TrainingResultType(Enum): class TrainingResultType(Enum):
@ -33,13 +34,24 @@ class TrainingResultType(Enum):
CHECKPOINT = auto() CHECKPOINT = auto()
@dataclass
class TrialInfo:
"""The trial information to propagate to TrainSession."""
name: str
id: str
resources: Dict[str, float]
logdir: str
@dataclass @dataclass
class TrainingResult: class TrainingResult:
type: TrainingResultType type: TrainingResultType
data: Dict data: Dict
class Session: # TODO(xwjiang): This needs a better name.
class _TrainSession:
"""Holds information for training on each worker.""" """Holds information for training on each worker."""
def __init__( def __init__(
@ -48,8 +60,11 @@ class Session:
world_rank: int, world_rank: int,
local_rank: int, local_rank: int,
world_size: int, world_size: int,
# TODO(xwjiang): Legacy Ray Train trainer clean up!
trial_info: Optional[TrialInfo] = None,
dataset_shard: Optional[Union[Dataset, DatasetPipeline]] = None, dataset_shard: Optional[Union[Dataset, DatasetPipeline]] = None,
checkpoint: Optional[Dict] = None, # TODO(xwjiang): Legacy Ray Train trainer clean up!
checkpoint: Optional[Union[Dict, Checkpoint]] = None,
encode_data_fn: Callable = None, encode_data_fn: Callable = None,
detailed_autofilled_metrics: bool = False, detailed_autofilled_metrics: bool = False,
): ):
@ -61,7 +76,9 @@ class Session:
self.world_rank = world_rank self.world_rank = world_rank
self.local_rank = local_rank self.local_rank = local_rank
self.world_size = world_size self.world_size = world_size
self.loaded_checkpoint = checkpoint self.trial_info = trial_info
# TODO(xwjiang): Legacy Ray Train trainer clean up!
self.loaded_checkpoint: Optional[Union[Dict, Checkpoint]] = checkpoint
# Function to encode checkpoint dict before sending to the driver. # Function to encode checkpoint dict before sending to the driver.
if not encode_data_fn: if not encode_data_fn:
@ -72,6 +89,13 @@ class Session:
encode_data_fn = noop encode_data_fn = noop
self._encode_data_fn = encode_data_fn self._encode_data_fn = encode_data_fn
# TODO(xwjiang): Legacy Ray Train trainer clean up!
if trial_info:
# Change the working directory to `logdir`.
logdir = os.path.join(trial_info.logdir, f"rank_{self.world_rank}")
os.makedirs(logdir, exist_ok=True)
os.chdir(logdir)
# This lock is used to control the execution of the training thread. # This lock is used to control the execution of the training thread.
self.continue_lock = threading.Semaphore(0) self.continue_lock = threading.Semaphore(0)
@ -184,7 +208,7 @@ class Session:
result.update(auto_filled_metrics) result.update(auto_filled_metrics)
return result return result
def report(self, **kwargs): def _report_legacy(self, **kwargs):
"""Adds kwargs to the queue to be consumed by main thread.""" """Adds kwargs to the queue to be consumed by main thread."""
if self.ignore_report: if self.ignore_report:
return return
@ -234,22 +258,32 @@ class Session:
# checkpoint has been processed. # checkpoint has been processed.
self.continue_lock.acquire() self.continue_lock.acquire()
def report(self, metrics: Dict, checkpoint: Optional[Checkpoint] = None) -> None:
# TODO(xwjiang): tons of optimizations.
if checkpoint:
checkpoint_dict = checkpoint.to_dict()
self.checkpoint(**checkpoint_dict)
self._report_legacy(**metrics)
_session = None
_session: Optional[_TrainSession] = None
# V2 Session API
_session_v2: Optional[_TrainSessionImpl] = None
def init_session(*args, **kwargs) -> None: def init_session(*args, **kwargs) -> None:
global _session global _session
global _session_v2
if _session: if _session:
raise ValueError( raise ValueError(
"A Train session is already in use. Do not call " "A Train session is already in use. Do not call "
"`init_session()` manually." "`init_session()` manually."
) )
_session = Session(*args, **kwargs) _session = _TrainSession(*args, **kwargs)
_session_v2 = _TrainSessionImpl(session=_session)
def get_session() -> Optional[Session]: def get_session() -> Optional[_TrainSession]:
global _session
return _session return _session


@ -1,33 +1,22 @@
import inspect import inspect
import logging import logging
import os
from pathlib import Path from pathlib import Path
from typing import ( from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Type, Union
Any,
Callable,
Dict,
Optional,
Tuple,
Union,
Type,
TYPE_CHECKING,
)
import ray import ray
from ray import tune from ray import tune
from ray.air.constants import MODEL_KEY, PREPROCESSOR_KEY from ray.air import session
from ray.train.constants import (
TRAIN_DATASET_KEY,
WILDCARD_KEY,
)
from ray.train.trainer import BaseTrainer
from ray.air.config import ScalingConfig, RunConfig, DatasetConfig
from ray.train.trainer import GenDataset
from ray.air.checkpoint import Checkpoint from ray.air.checkpoint import Checkpoint
from ray.train._internal.dataset_spec import DataParallelIngestSpec from ray.air.config import DatasetConfig, RunConfig, ScalingConfig
from ray.air.constants import MODEL_KEY, PREPROCESSOR_KEY
from ray.train import BackendConfig, TrainingIterator from ray.train import BackendConfig, TrainingIterator
from ray.train._internal.backend_executor import BackendExecutor from ray.train._internal.backend_executor import BackendExecutor, TrialInfo
from ray.train._internal.checkpoint import TuneCheckpointManager from ray.train._internal.checkpoint import TuneCheckpointManager
from ray.train._internal.dataset_spec import DataParallelIngestSpec
from ray.train._internal.utils import construct_train_func from ray.train._internal.utils import construct_train_func
from ray.train.constants import TRAIN_DATASET_KEY, WILDCARD_KEY
from ray.train.trainer import BaseTrainer, GenDataset
from ray.util.annotations import DeveloperAPI from ray.util.annotations import DeveloperAPI
from ray.util.ml_utils.checkpoint_manager import CheckpointStrategy, _TrackedCheckpoint from ray.util.ml_utils.checkpoint_manager import CheckpointStrategy, _TrackedCheckpoint
@ -335,8 +324,16 @@ class DataParallelTrainer(BaseTrainer):
scaling_config_dataclass.additional_resources_per_worker scaling_config_dataclass.additional_resources_per_worker
) )
trial_info = TrialInfo(
name=session.get_trial_name(),
id=session.get_trial_id(),
resources=session.get_trial_resources(),
logdir=os.getcwd(),
)
backend_executor = BackendExecutor( backend_executor = BackendExecutor(
backend_config=self._backend_config, backend_config=self._backend_config,
trial_info=trial_info,
num_workers=scaling_config_dataclass.num_workers, num_workers=scaling_config_dataclass.num_workers,
num_cpus_per_worker=scaling_config_dataclass.num_cpus_per_worker, num_cpus_per_worker=scaling_config_dataclass.num_cpus_per_worker,
num_gpus_per_worker=scaling_config_dataclass.num_gpus_per_worker, num_gpus_per_worker=scaling_config_dataclass.num_gpus_per_worker,
@ -351,20 +348,13 @@ class DataParallelTrainer(BaseTrainer):
# Start the remote actors. # Start the remote actors.
backend_executor.start(initialization_hook=None) backend_executor.start(initialization_hook=None)
if self.resume_from_checkpoint:
resume_checkpoint_dict = self.resume_from_checkpoint.to_dict()
else:
resume_checkpoint_dict = None
# TODO(amog): Have TrainingIterator also accept a checkpoint ObjectRef instead
# of just a Dict.
training_iterator = TrainingIterator( training_iterator = TrainingIterator(
backend_executor=backend_executor, backend_executor=backend_executor,
backend_config=self._backend_config, backend_config=self._backend_config,
train_func=train_loop_per_worker, train_func=train_loop_per_worker,
dataset_spec=self._ingest_spec, dataset_spec=self._ingest_spec,
checkpoint_manager=checkpoint_manager, checkpoint_manager=checkpoint_manager,
checkpoint=resume_checkpoint_dict, checkpoint=self.resume_from_checkpoint,
checkpoint_strategy=None, checkpoint_strategy=None,
) )


@ -16,6 +16,7 @@ from torch.utils.data import Dataset as TorchDataset
from transformers.trainer import TRAINING_ARGS_NAME, WEIGHTS_NAME from transformers.trainer import TRAINING_ARGS_NAME, WEIGHTS_NAME
from ray import train from ray import train
from ray.air import session
from ray.air._internal.checkpointing import ( from ray.air._internal.checkpointing import (
load_preprocessor_from_dir, load_preprocessor_from_dir,
save_preprocessor_to_dir, save_preprocessor_to_dir,
@ -517,12 +518,14 @@ def _huggingface_train_loop_per_worker(config):
trainer.add_callback(TrainReportCallback) trainer.add_callback(TrainReportCallback)
checkpoint = train.load_checkpoint() checkpoint = session.get_checkpoint()
checkpoint_path = None checkpoint_path = None
remove_checkpoint_path = False remove_checkpoint_path = False
if checkpoint: if checkpoint:
source_ip = checkpoint[NODE_IP_KEY] assert isinstance(checkpoint, Checkpoint)
source_path = checkpoint[CHECKPOINT_PATH_ON_NODE_KEY] checkpoint_dict = checkpoint.to_dict()
source_ip = checkpoint_dict[NODE_IP_KEY]
source_path = checkpoint_dict[CHECKPOINT_PATH_ON_NODE_KEY]
target_ip = get_node_ip_address() target_ip = get_node_ip_address()
if source_ip == target_ip: if source_ip == target_ip:
checkpoint_path = source_path checkpoint_path = source_path


@ -0,0 +1,79 @@
import warnings
from typing import TYPE_CHECKING, Dict, Optional, Union
from ray.air._internal.session import Session
from ray.air.checkpoint import Checkpoint
if TYPE_CHECKING:
# avoid circular import
from ray.data import Dataset, DatasetPipeline
from ray.train._internal.session import _TrainSession
class _TrainSessionImpl(Session):
"""Session client that "per worker train loop" can interact with.
Notice that each worker will automatically switch to its working
directory on entering the train loop. This is to ensure that
each worker can safely write to a local directory without racing
and overwriting each other."""
def __init__(self, session: "_TrainSession"):
self._session = session
def report(self, metrics: Dict, *, checkpoint: Optional[Checkpoint] = None) -> None:
self._session.report(metrics, checkpoint)
@property
def loaded_checkpoint(self) -> Optional[Checkpoint]:
ckpt = self._session.loaded_checkpoint
if ckpt:
# The new API should only interact with Checkpoint object.
assert isinstance(ckpt, Checkpoint)
return ckpt
@property
def trial_name(self) -> str:
return self._session.trial_info.name
@property
def trial_id(self) -> str:
return self._session.trial_info.id
@property
def trial_resources(self) -> Dict[str, float]:
return self._session.trial_info.resources
@property
def world_size(self) -> int:
return self._session.world_size
@property
def world_rank(self) -> int:
return self._session.world_rank
@property
def local_rank(self) -> int:
return self._session.local_rank
def get_dataset_shard(
self,
dataset_name: Optional[str] = None,
) -> Optional[Union["Dataset", "DatasetPipeline"]]:
shard = self._session.dataset_shard
if shard is None:
warnings.warn(
"No dataset passed in. Returning None. Make sure to "
"pass in a Ray Dataset to Trainer.run to use this "
"function."
)
elif isinstance(shard, dict):
if not dataset_name:
raise RuntimeError(
"Multiple datasets were passed into ``Trainer``, "
"but no ``dataset_name`` is passed into "
"``get_dataset_shard``. Please specify which "
"dataset shard to retrieve."
)
return shard.get(dataset_name)
return shard


@ -1,11 +1,11 @@
import pytest import pytest
import ray import ray
from ray import train, tune from ray import tune
from ray.air import session
from ray.air.checkpoint import Checkpoint from ray.air.checkpoint import Checkpoint
from ray.train.constants import PREPROCESSOR_KEY
from ray.data.preprocessor import Preprocessor from ray.data.preprocessor import Preprocessor
from ray.train.constants import PREPROCESSOR_KEY
from ray.train.data_parallel_trainer import DataParallelTrainer from ray.train.data_parallel_trainer import DataParallelTrainer
from ray.tune.tune_config import TuneConfig from ray.tune.tune_config import TuneConfig
from ray.tune.tuner import Tuner from ray.tune.tuner import Tuner
@ -24,7 +24,7 @@ scale_config = {"num_workers": 2}
def test_fit_train(ray_start_4_cpus): def test_fit_train(ray_start_4_cpus):
def train_func(): def train_func():
train.report(loss=1) session.report({"loss": 1})
trainer = DataParallelTrainer( trainer = DataParallelTrainer(
train_loop_per_worker=train_func, scaling_config=scale_config train_loop_per_worker=train_func, scaling_config=scale_config
@ -35,7 +35,7 @@ def test_fit_train(ray_start_4_cpus):
def test_scaling_config(ray_start_4_cpus): def test_scaling_config(ray_start_4_cpus):
def train_func(): def train_func():
assert ray.available_resources()["CPU"] == 1 assert ray.available_resources()["CPU"] == 1
train.report(loss=1) session.report({"loss": 1})
assert ray.available_resources()["CPU"] == 4 assert ray.available_resources()["CPU"] == 4
trainer = DataParallelTrainer( trainer = DataParallelTrainer(
@ -46,7 +46,7 @@ def test_scaling_config(ray_start_4_cpus):
def test_fit_train_config(ray_start_4_cpus): def test_fit_train_config(ray_start_4_cpus):
def train_func(config): def train_func(config):
train.report(loss=config["x"]) session.report({"loss": config["x"]})
trainer = DataParallelTrainer( trainer = DataParallelTrainer(
train_loop_per_worker=train_func, train_loop_per_worker=train_func,
@ -65,10 +65,10 @@ def test_datasets(ray_start_4_cpus):
def get_dataset(): def get_dataset():
# Train dataset should be sharded. # Train dataset should be sharded.
train_dataset = train.get_dataset_shard("train") train_dataset = session.get_dataset_shard("train")
assert train_dataset.count() == num_train_data / scale_config["num_workers"] assert train_dataset.count() == num_train_data / scale_config["num_workers"]
# All other datasets should not be sharded. # All other datasets should not be sharded.
val_dataset = train.get_dataset_shard("val") val_dataset = session.get_dataset_shard("val")
assert val_dataset.count() == num_val_data assert val_dataset.count() == num_val_data
trainer = DataParallelTrainer( trainer = DataParallelTrainer(
@ -82,7 +82,7 @@ def test_datasets(ray_start_4_cpus):
def test_checkpoint(ray_start_4_cpus): def test_checkpoint(ray_start_4_cpus):
def train_func(): def train_func():
for i in range(3): for i in range(3):
train.save_checkpoint(model=i) session.report({"epoch": i}, checkpoint=Checkpoint.from_dict({"model": i}))
trainer = DataParallelTrainer( trainer = DataParallelTrainer(
train_loop_per_worker=train_func, scaling_config=scale_config train_loop_per_worker=train_func, scaling_config=scale_config
@ -99,7 +99,7 @@ def test_preprocessor_in_checkpoint(ray_start_4_cpus):
def train_func(): def train_func():
for i in range(3): for i in range(3):
train.save_checkpoint(model=i) session.report({"epoch": i}, checkpoint=Checkpoint.from_dict({"model": i}))
trainer = DataParallelTrainer( trainer = DataParallelTrainer(
train_loop_per_worker=train_func, train_loop_per_worker=train_func,
@ -113,13 +113,13 @@ def test_preprocessor_in_checkpoint(ray_start_4_cpus):
def test_resume_from_checkpoint(ray_start_4_cpus, tmpdir): def test_resume_from_checkpoint(ray_start_4_cpus, tmpdir):
def train_func(): def train_func():
checkpoint = train.load_checkpoint() checkpoint = session.get_checkpoint()
if checkpoint: if checkpoint:
epoch = checkpoint["epoch"] epoch = checkpoint.to_dict()["epoch"]
else: else:
epoch = 0 epoch = 0
for i in range(epoch, epoch + 2): for i in range(epoch, epoch + 2):
train.save_checkpoint(epoch=i) session.report({"epoch": i}, checkpoint=Checkpoint.from_dict({"epoch": i}))
trainer = DataParallelTrainer( trainer = DataParallelTrainer(
train_loop_per_worker=train_func, scaling_config=scale_config train_loop_per_worker=train_func, scaling_config=scale_config
@ -152,7 +152,7 @@ def test_invalid_train_loop(ray_start_4_cpus):
def test_tune(ray_start_4_cpus): def test_tune(ray_start_4_cpus):
def train_func(config): def train_func(config):
train.report(loss=config["x"]) session.report({"loss": config["x"]})
trainer = DataParallelTrainer( trainer = DataParallelTrainer(
train_loop_per_worker=train_func, train_loop_per_worker=train_func,
@ -173,7 +173,8 @@ def test_tune(ray_start_4_cpus):
if __name__ == "__main__": if __name__ == "__main__":
import pytest
import sys import sys
import pytest
sys.exit(pytest.main(["-v", "-x", __file__])) sys.exit(pytest.main(["-v", "-x", __file__]))


@ -1,8 +1,12 @@
import os
import numpy as np import numpy as np
import pytest import pytest
import ray import ray
from ray import train from ray import train
from ray.air import session
from ray.air.checkpoint import Checkpoint
from ray.air.examples.tf.tensorflow_linear_dataset_example import get_dataset from ray.air.examples.tf.tensorflow_linear_dataset_example import get_dataset
from ray.air.examples.tf.tensorflow_linear_dataset_example import ( from ray.air.examples.tf.tensorflow_linear_dataset_example import (
train_func as tensorflow_linear_train_func, train_func as tensorflow_linear_train_func,
@ -34,6 +38,8 @@ def build_model():
@pytest.mark.parametrize("num_workers", [1, 2]) @pytest.mark.parametrize("num_workers", [1, 2])
def test_tensorflow_linear(ray_start_4_cpus, num_workers): def test_tensorflow_linear(ray_start_4_cpus, num_workers):
"""Also tests air Keras callback."""
def train_func(config): def train_func(config):
result = tensorflow_linear_train_func(config) result = tensorflow_linear_train_func(config)
assert len(result) == epochs assert len(result) == epochs
@ -83,6 +89,39 @@ def test_tensorflow_e2e(ray_start_4_cpus):
assert predictions.count() == 3 assert predictions.count() == 3
def test_report_and_load_using_ml_session(ray_start_4_cpus):
def train_func():
if session.get_checkpoint():
with session.get_checkpoint().as_directory() as checkpoint_dir:
import tensorflow as tf
model = tf.keras.models.load_model(checkpoint_dir)
else:
model = build_model()
model.save("my_model", overwrite=True)
session.report(
metrics={"iter": 1}, checkpoint=Checkpoint.from_directory("my_model")
)
scaling_config = {"num_workers": 2}
trainer = TensorflowTrainer(
train_loop_per_worker=train_func, scaling_config=scaling_config
)
result = trainer.fit()
trainer2 = TensorflowTrainer(
train_loop_per_worker=train_func,
scaling_config=scaling_config,
resume_from_checkpoint=result.checkpoint,
)
result = trainer2.fit()
checkpoint = result.checkpoint
with checkpoint.as_directory() as ckpt_dir:
assert os.path.exists(os.path.join(ckpt_dir, "saved_model.pb"))
assert result.metrics["iter"] == 1
if __name__ == "__main__": if __name__ == "__main__":
import sys import sys


@ -1,11 +1,8 @@
from typing import TYPE_CHECKING
from typing import Optional, Dict, Union
import warnings import warnings
from typing import TYPE_CHECKING, Dict, Optional, Union
from ray.train._internal.session import get_session
from ray.train.constants import SESSION_MISUSE_LOG_ONCE_KEY from ray.train.constants import SESSION_MISUSE_LOG_ONCE_KEY
from ray.train._internal.session import (
get_session,
)
from ray.util import PublicAPI, log_once from ray.util import PublicAPI, log_once
if TYPE_CHECKING: if TYPE_CHECKING:
@ -118,7 +115,7 @@ def report(**kwargs) -> None:
if session is None: if session is None:
_warn_session_misuse(report.__name__) _warn_session_misuse(report.__name__)
return return
session.report(**kwargs) session._report_legacy(**kwargs)
@PublicAPI(stability="beta") @PublicAPI(stability="beta")


@ -1,23 +1,14 @@
import copy import copy
from datetime import datetime
import logging import logging
import os import os
from pathlib import Path
from typing import Union, Callable, List, TypeVar, Optional, Any, Dict, Type
import warnings import warnings
from datetime import datetime
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union
import ray import ray
from ray.actor import ActorHandle from ray.actor import ActorHandle
from ray.train.backend import ( from ray.air.checkpoint import Checkpoint
BackendConfig,
)
from ray.train.callbacks.callback import TrainingCallback
from ray.train._internal.dataset_spec import RayDataset, RayDatasetSpec
from ray.train._internal.session import TrainingResultType
from ray.train._internal.utils import (
construct_train_func,
ActorWrapper,
)
from ray.train._internal.backend_executor import ( from ray.train._internal.backend_executor import (
BackendExecutor, BackendExecutor,
InactiveWorkerGroupError, InactiveWorkerGroupError,
@ -25,35 +16,37 @@ from ray.train._internal.backend_executor import (
TrainingWorkerError, TrainingWorkerError,
) )
from ray.train._internal.checkpoint import ( from ray.train._internal.checkpoint import (
TuneCheckpointManager,
CheckpointManager, CheckpointManager,
TuneCheckpointManager,
load_checkpoint_from_path, load_checkpoint_from_path,
) )
from ray.train.constants import ( from ray.train._internal.dataset_spec import RayDataset, RayDatasetSpec
TUNE_INSTALLED, from ray.train._internal.session import TrainingResultType
DEFAULT_RESULTS_DIR,
ENABLE_DETAILED_AUTOFILLED_METRICS_ENV,
ENABLE_SHARE_CUDA_VISIBLE_DEVICES_ENV,
TRAIN_PLACEMENT_GROUP_TIMEOUT_S_ENV,
TRAIN_ENABLE_WORKER_SPREAD_ENV,
)
# Ray Train should be usable even if Tune is not installed. # Ray Train should be usable even if Tune is not installed.
from ray.train._internal.utils import construct_path from ray.train._internal.utils import ActorWrapper, construct_path, construct_train_func
from ray.train._internal.worker_group import WorkerGroup from ray.train._internal.worker_group import WorkerGroup
from ray.util.annotations import DeveloperAPI, Deprecated from ray.train.backend import BackendConfig
from ray.util.ml_utils.checkpoint_manager import CheckpointStrategy
from ray.train.base_trainer import ( # noqa: F401 from ray.train.base_trainer import ( # noqa: F401
BaseTrainer, BaseTrainer,
GenDataset, GenDataset,
TrainingFailedError, TrainingFailedError,
) )
from ray.train.callbacks.callback import TrainingCallback
from ray.train.constants import (
DEFAULT_RESULTS_DIR,
ENABLE_DETAILED_AUTOFILLED_METRICS_ENV,
ENABLE_SHARE_CUDA_VISIBLE_DEVICES_ENV,
TRAIN_ENABLE_WORKER_SPREAD_ENV,
TRAIN_PLACEMENT_GROUP_TIMEOUT_S_ENV,
TUNE_INSTALLED,
)
from ray.util.annotations import Deprecated, DeveloperAPI
from ray.util.ml_utils.checkpoint_manager import CheckpointStrategy
if TUNE_INSTALLED: if TUNE_INSTALLED:
from ray import tune from ray import tune
from ray.tune import Trainable from ray.tune import PlacementGroupFactory, Trainable
from ray.tune import PlacementGroupFactory
from ray.tune.function_runner import wrap_function from ray.tune.function_runner import wrap_function
else: else:
tune = PlacementGroupFactory = Trainable = object tune = PlacementGroupFactory = Trainable = object
@ -676,7 +669,7 @@ class TrainingIterator:
train_func: Union[Callable[[], T], Callable[[Dict[str, Any]], T]], train_func: Union[Callable[[], T], Callable[[Dict[str, Any]], T]],
dataset_spec: RayDatasetSpec, dataset_spec: RayDatasetSpec,
checkpoint_manager: CheckpointManager, checkpoint_manager: CheckpointManager,
checkpoint: Optional[Union[Dict, str, Path]], checkpoint: Optional[Union[Dict, str, Path, Checkpoint]],
checkpoint_strategy: Optional[CheckpointStrategy], checkpoint_strategy: Optional[CheckpointStrategy],
run_dir: Optional[Path] = None, run_dir: Optional[Path] = None,
): ):
@ -715,12 +708,12 @@ class TrainingIterator:
run_dir=run_dir, run_dir=run_dir,
latest_checkpoint_id=latest_checkpoint_id, latest_checkpoint_id=latest_checkpoint_id,
) )
checkpoint_dict = self._checkpoint_manager._load_checkpoint(checkpoint) checkpoint = self._checkpoint_manager._load_checkpoint(checkpoint)
self._run_with_error_handling( self._run_with_error_handling(
lambda: self._backend_executor.start_training( lambda: self._backend_executor.start_training(
train_func=train_func, train_func=train_func,
dataset_spec=dataset_spec, dataset_spec=dataset_spec,
checkpoint=checkpoint_dict, checkpoint=checkpoint,
) )
) )


@ -1,34 +1,34 @@
import inspect
import logging import logging
import os import os
import sys
import time
import inspect
import shutil import shutil
import sys
import threading import threading
import time
import uuid import uuid
from functools import partial from functools import partial
from numbers import Number from numbers import Number
from typing import Any, Callable, Dict, Optional
from typing import Any, Callable, Optional
from ray.util.annotations import DeveloperAPI
from six.moves import queue from six.moves import queue
from ray.util.debug import log_once from ray.air.checkpoint import Checkpoint
from ray.tune import TuneError, session from ray.tune import TuneError, session
from ray.tune.trainable import Trainable, TrainableUtil
from ray.tune.result import ( from ray.tune.result import (
DEFAULT_METRIC, DEFAULT_METRIC,
TIME_THIS_ITER_S,
RESULT_DUPLICATE, RESULT_DUPLICATE,
SHOULD_CHECKPOINT, SHOULD_CHECKPOINT,
TIME_THIS_ITER_S,
) )
from ray.tune.trainable import Trainable, TrainableUtil
from ray.tune.utils import ( from ray.tune.utils import (
detect_checkpoint_function, detect_checkpoint_function,
detect_config_single, detect_config_single,
detect_reporter, detect_reporter,
) )
from ray.tune.utils.trainable import with_parameters # noqa: F401 from ray.tune.utils.trainable import with_parameters # noqa: F401
from ray.util.annotations import DeveloperAPI
from ray.util.debug import log_once
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -123,8 +123,6 @@ class FuncCheckpointUtil:
class _StatusReporter: class _StatusReporter:
"""Object passed into your function that you can report status through."""
def __init__( def __init__(
self, self,
result_queue, result_queue,
@ -145,6 +143,10 @@ class _StatusReporter:
self._last_checkpoint = None self._last_checkpoint = None
self._fresh_checkpoint = False self._fresh_checkpoint = False
self._trial_resources = trial_resources self._trial_resources = trial_resources
# Also used as a marker of whether new `report()` API is being used,
# in which case, `_iter` will be incremented from 0 every time `report`
# is called.
self._iter = None
def reset(self, trial_name=None, trial_id=None, logdir=None, trial_resources=None): def reset(self, trial_name=None, trial_id=None, logdir=None, trial_resources=None):
self._trial_name = trial_name self._trial_name = trial_name
@ -153,6 +155,7 @@ class _StatusReporter:
self._last_checkpoint = None self._last_checkpoint = None
self._fresh_checkpoint = False self._fresh_checkpoint = False
self._trial_resources = trial_resources self._trial_resources = trial_resources
self._iter = None
def __call__(self, _metric=None, **kwargs): def __call__(self, _metric=None, **kwargs):
"""Report updated training status. """Report updated training status.
@ -168,7 +171,7 @@ class _StatusReporter:
""" """
assert self._last_report_time is not None, ( assert self._last_report_time is not None, (
"StatusReporter._start() must be called before the first " "_StatusReporter._start() must be called before the first "
"report __call__ is made to ensure correct runtime metrics." "report __call__ is made to ensure correct runtime metrics."
) )
@ -229,6 +232,26 @@ class _StatusReporter:
def _start(self): def _start(self):
self._last_report_time = time.time() self._last_report_time = time.time()
def report(self, metrics: Dict, *, checkpoint: Optional[Checkpoint] = None) -> None:
# TODO(xwjiang): Tons of optimizations.
if not self._iter:
self._iter = 0
if checkpoint:
checkpoint_dir = self.make_checkpoint_dir(step=self._iter)
self.set_checkpoint(checkpoint_dir)
checkpoint.to_directory(checkpoint_dir)
# TODO(krfricke): Remove this once support is added in Checkpoint.
open(os.path.join(checkpoint_dir, ".is_checkpoint"), "a").close()
self.__call__(**metrics)
self._iter += 1
@property
def loaded_checkpoint(self) -> Optional[Checkpoint]:
if self._last_checkpoint:
assert isinstance(self._last_checkpoint, str)
return Checkpoint.from_directory(self._last_checkpoint)
return None
@property @property
def logdir(self): def logdir(self):
return self._logdir return self._logdir
@ -484,7 +507,7 @@ class FunctionRunner(Trainable):
obj = TrainableUtil.checkpoint_to_object(checkpoint_path) obj = TrainableUtil.checkpoint_to_object(checkpoint_path)
return obj return obj
def load_checkpoint(self, checkpoint): def load_checkpoint(self, checkpoint: str):
# This should be removed once Trainables are refactored. # This should be removed once Trainables are refactored.
if "tune_checkpoint_path" in checkpoint: if "tune_checkpoint_path" in checkpoint:
del checkpoint["tune_checkpoint_path"] del checkpoint["tune_checkpoint_path"]


@ -1,113 +1,9 @@
from collections import Counter
from typing import Dict, List, Union, Optional
from tensorflow.keras.callbacks import Callback
from ray import tune
import os import os
from collections import Counter
from typing import Dict, List, Optional, Union
from ray import tune
class TuneCallback(Callback): from ray.air.callbacks.keras import _Callback as TuneCallback
"""Base class for Tune's Keras callbacks."""
_allowed = [
"batch_begin",
"batch_end",
"epoch_begin",
"epoch_end",
"train_batch_begin",
"train_batch_end",
"test_batch_begin",
"test_batch_end",
"predict_batch_begin",
"predict_batch_end",
"train_begin",
"train_end",
"test_begin",
"test_end",
"predict_begin",
"predict_end",
]
def __init__(self, on: Union[str, List[str]] = "validation_end"):
super(TuneCallback, self).__init__()
if not isinstance(on, list):
on = [on]
if any(w not in self._allowed for w in on):
raise ValueError(
"Invalid trigger time selected: {}. Must be one of {}".format(
on, self._allowed
)
)
self._on = on
def _handle(self, logs: Dict, when: str):
raise NotImplementedError
def on_batch_begin(self, batch, logs=None):
if "batch_begin" in self._on:
self._handle(logs, "batch_begin")
def on_batch_end(self, batch, logs=None):
if "batch_end" in self._on:
self._handle(logs, "batch_end")
def on_epoch_begin(self, epoch, logs=None):
if "epoch_begin" in self._on:
self._handle(logs, "epoch_begin")
def on_epoch_end(self, epoch, logs=None):
if "epoch_end" in self._on:
self._handle(logs, "epoch_end")
def on_train_batch_begin(self, batch, logs=None):
if "train_batch_begin" in self._on:
self._handle(logs, "train_batch_begin")
def on_train_batch_end(self, batch, logs=None):
if "train_batch_end" in self._on:
self._handle(logs, "train_batch_end")
def on_test_batch_begin(self, batch, logs=None):
if "test_batch_begin" in self._on:
self._handle(logs, "test_batch_begin")
def on_test_batch_end(self, batch, logs=None):
if "test_batch_end" in self._on:
self._handle(logs, "test_batch_end")
def on_predict_batch_begin(self, batch, logs=None):
if "predict_batch_begin" in self._on:
self._handle(logs, "predict_batch_begin")
def on_predict_batch_end(self, batch, logs=None):
if "predict_batch_end" in self._on:
self._handle(logs, "predict_batch_end")
def on_train_begin(self, logs=None):
if "train_begin" in self._on:
self._handle(logs, "train_begin")
def on_train_end(self, logs=None):
if "train_end" in self._on:
self._handle(logs, "train_end")
def on_test_begin(self, logs=None):
if "test_begin" in self._on:
self._handle(logs, "test_begin")
def on_test_end(self, logs=None):
if "test_end" in self._on:
self._handle(logs, "test_end")
def on_predict_begin(self, logs=None):
if "predict_begin" in self._on:
self._handle(logs, "predict_begin")
def on_predict_end(self, logs=None):
if "predict_end" in self._on:
self._handle(logs, "predict_end")
class TuneReportCallback(TuneCallback): class TuneReportCallback(TuneCallback):


@ -1,23 +1,54 @@
from contextlib import contextmanager
import inspect import inspect
import os
import logging import logging
import os
import traceback import traceback
from contextlib import contextmanager
from typing import Dict, Optional, Set from typing import Dict, Optional, Set
import ray import ray
from ray.air._internal.session import Session
from ray.air.checkpoint import Checkpoint
from ray.tune.error import TuneError
from ray.tune.function_runner import _StatusReporter
from ray.util.annotations import DeveloperAPI, PublicAPI
from ray.util.debug import log_once from ray.util.debug import log_once
from ray.util.annotations import PublicAPI, DeveloperAPI
from ray.util.placement_group import _valid_resource_shape from ray.util.placement_group import _valid_resource_shape
from ray.util.scheduling_strategies import ( from ray.util.scheduling_strategies import (
SchedulingStrategyT,
PlacementGroupSchedulingStrategy, PlacementGroupSchedulingStrategy,
SchedulingStrategyT,
) )
from ray.tune.error import TuneError
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
_session = None _session: Optional[_StatusReporter] = None
# V2 Session API.
_session_v2: Optional["_TuneSessionImpl"] = None
class _TuneSessionImpl(Session):
"""Session client that function trainable can interact with."""
def __init__(self, status_reporter: _StatusReporter):
self._status_reporter = status_reporter
def report(self, metrics: Dict, *, checkpoint: Optional[Checkpoint] = None) -> None:
self._status_reporter.report(metrics, checkpoint=checkpoint)
@property
def loaded_checkpoint(self) -> Optional[Checkpoint]:
return self._status_reporter.loaded_checkpoint
@property
def trial_name(self) -> str:
return self._status_reporter.trial_name
@property
def trial_id(self) -> str:
return self._status_reporter.trial_id
@property
def trial_resources(self) -> Dict[str, float]:
return self._status_reporter.trial_resources.required_resources
@PublicAPI @PublicAPI
@ -50,6 +81,7 @@ def get_session():
def init(reporter, ignore_reinit_error=True): def init(reporter, ignore_reinit_error=True):
"""Initializes the global trial context for this process.""" """Initializes the global trial context for this process."""
global _session global _session
global _session_v2
if _session is not None: if _session is not None:
# TODO(ng): would be nice to stack crawl at creation time to report # TODO(ng): would be nice to stack crawl at creation time to report
@ -83,6 +115,7 @@ def init(reporter, ignore_reinit_error=True):
remote_function._task_launch_hook = tune_task_and_actor_launch_hook remote_function._task_launch_hook = tune_task_and_actor_launch_hook
_session = reporter _session = reporter
_session_v2 = _TuneSessionImpl(status_reporter=reporter)
# Cache of resource dicts that have been checked by the launch hook already. # Cache of resource dicts that have been checked by the launch hook already.
@ -183,6 +216,11 @@ def report(_metric=None, **kwargs):
""" """
_session = get_session() _session = get_session()
if _session: if _session:
if _session._iter:
raise ValueError(
"It is not allowed to mix `tune.report` with `session.report`."
)
return _session(_metric, **kwargs) return _session(_metric, **kwargs)
@ -242,6 +280,11 @@ def checkpoint_dir(step: int):
raise ValueError("checkpoint_dir(step) must be provided - got None.") raise ValueError("checkpoint_dir(step) must be provided - got None.")
if _session: if _session:
if _session._iter:
raise ValueError(
"It is not allowed to mix `with tune.checkpoint_dir` "
"with `session.report`."
)
_checkpoint_dir = _session.make_checkpoint_dir(step=step) _checkpoint_dir = _session.make_checkpoint_dir(step=step)
else: else:
_checkpoint_dir = os.path.abspath("./") _checkpoint_dir = os.path.abspath("./")
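
To illustrate the guard added above: once the new `session.report` has been used in a trial, falling back to the legacy Tune APIs raises. A minimal sketch (the metric values are arbitrary):
```
from ray import tune
from ray.air import session


def trainable(config):
    session.report({"loss": 1})  # new unified API; marks _iter on the reporter
    tune.report(loss=2)          # now raises ValueError per the check above
```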