[train] add TorchTensorboardProfilerCallback (#22345)

The [original PR](https://github.com/ray-project/ray/pull/21864) was [reverted](https://github.com/ray-project/ray/pull/22117) because it made `torch` (specifically, `torch>=1.8.1`) a hard requirement just to import `ray.train`:

```
  | File "ray_sgd_training.py", line 18, in <module>
  | from ray import train
  | File "/home/ray/anaconda3/lib/python3.7/site-packages/ray/train/__init__.py", line 2, in <module>
  | from ray.train.callbacks import TrainingCallback
  | File "/home/ray/anaconda3/lib/python3.7/site-packages/ray/train/callbacks/__init__.py", line 8, in <module>
  | from ray.train.callbacks.profile import TorchTensorboardProfilerCallback
  | File "/home/ray/anaconda3/lib/python3.7/site-packages/ray/train/callbacks/profile.py", line 6, in <module>
  | from torch.profiler import profile
  | ModuleNotFoundError: No module named 'torch.profiler'
```

A [minimal installation test suite](https://github.com/ray-project/ray/pull/22300) was added to catch this class of regression. In addition, this PR makes the following changes:
1. Move `TorchWorkerProfiler` to `ray.train.torch` so all torch imports are centralized.
2. Add import validation logic to `TorchWorkerProfiler.__init__` so that an exception is raised only when the user tries to initialize a `TorchWorkerProfiler` without a compatible version of `torch` installed (a sketch of the guard pattern follows the example below):

```
>>> import ray
>>> import ray.train
>>> import ray.train.torch
>>> from ray.train.torch import TorchWorkerProfiler
>>> twp = TorchWorkerProfiler()
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/matt/workspace/ray/python/ray/train/torch.py", line 365, in __init__
    "Torch Profiler requires torch>=1.8.1. "
ImportError: Torch Profiler requires torch>=1.8.1. Run `pip install 'torch>=1.8.1'` to use TorchWorkerProfiler.
```
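For context, the validation uses a deferred-import guard: `torch.profiler` is imported at module load time inside a `try`/`except`, and the `ImportError` only surfaces when a profiler object is actually constructed. A minimal sketch of the pattern (the `_ProfilerUtility` name is hypothetical; the real guard lives in `ray.train.torch`):

```python
# Minimal sketch of the deferred-import guard (illustrative only; the real
# implementation is in ray.train.torch).
try:
    from torch.profiler import profile
except ImportError:
    profile = None  # torch missing or too old -- defer the error


class _ProfilerUtility:  # hypothetical stand-in for TorchWorkerProfiler
    def __init__(self):
        # Only fail when the profiler is actually requested, so that
        # `import ray.train` keeps working without torch installed.
        if profile is None:
            raise ImportError(
                "Torch Profiler requires torch>=1.8.1. "
                "Run `pip install 'torch>=1.8.1'` to use it."
            )
```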
Author: matthewdeng
Date: 2022-02-14 16:16:55 -08:00 (committed by GitHub)
Parent: 35a157948e
Commit: 8f9e0d7f6b
13 changed files with 393 additions and 7 deletions


```diff
@@ -236,6 +236,7 @@ MOCK_MODULES = [
     "torch.distributed",
     "torch.nn",
     "torch.nn.parallel",
+    "torch.profiler",
     "torch.utils.data",
     "torch.utils.data.distributed",
     "wandb",
```


```diff
@@ -86,6 +86,14 @@ MLflowLoggerCallback
 .. autoclass:: ray.train.callbacks.MLflowLoggerCallback
 
+.. _train-api-torch-tensorboard-profiler-callback:
+
+TorchTensorboardProfilerCallback
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: ray.train.callbacks.TorchTensorboardProfilerCallback
+
 ResultsPreprocessors
 ~~~~~~~~~~~~~~~~~~~~
@@ -175,6 +183,14 @@ train.torch.get_device
 .. autofunction:: ray.train.torch.get_device
 
+.. _train-api-torch-worker-profiler:
+
+train.torch.TorchWorkerProfiler
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: ray.train.torch.TorchWorkerProfiler
+    :members:
+
 TensorFlow Training Function Utilities
 --------------------------------------
```


```diff
@@ -20,6 +20,7 @@ In this guide, we cover examples for the following use cases:
 * How do I :ref:`monitor <train-monitoring>` my training?
 * How do I run my training on pre-emptible instances
   (:ref:`fault tolerance <train-fault-tolerance>`)?
+* How do I :ref:`profile <train-profiling>` my training?
 * How do I use Ray Train to :ref:`train with a large dataset <train-datasets>`?
 * How do I :ref:`tune <train-tune>` my Ray Train model?
@@ -429,6 +430,7 @@ The following ``TrainingCallback``\s are available and will log the intermediate
 2. :ref:`train-api-json-logger-callback`
 3. :ref:`train-api-tbx-logger-callback`
 4. :ref:`train-api-mlflow-logger-callback`
+5. :ref:`train-api-torch-tensorboard-profiler-callback`
 
 Example: Logging to MLflow and TensorBoard
 ++++++++++++++++++++++++++++++++++++++++++
@@ -919,6 +921,60 @@ number of retries is configurable through the ``max_retries`` argument of the
 .. TODO.
 
+.. _train-profiling:
+
+Profiling
+---------
+
+Ray Train comes with an integration with `PyTorch Profiler <https://pytorch.org/blog/introducing-pytorch-profiler-the-new-and-improved-performance-tool/>`_.
+Specifically, it provides a :ref:`TorchWorkerProfiler <train-api-torch-worker-profiler>` utility class and a :ref:`train-api-torch-tensorboard-profiler-callback` callback
+that allow you to use the PyTorch Profiler as you would in a non-distributed PyTorch script, and that synchronize the generated TensorBoard traces onto
+the disk of the node from which your script was executed.
+
+**Step 1: Update training function with** ``TorchWorkerProfiler``
+
+.. code-block:: python
+
+    from torch.profiler import profile
+
+    from ray import train
+    from ray.train.torch import TorchWorkerProfiler
+
+    def train_func():
+        twp = TorchWorkerProfiler()
+        with profile(..., on_trace_ready=twp.trace_handler) as p:
+            ...
+            profile_results = twp.get_and_clear_profile_traces()
+            train.report(..., **profile_results)
+        ...
+
+**Step 2: Run training function with** ``TorchTensorboardProfilerCallback``
+
+.. code-block:: python
+
+    from ray.train import Trainer
+    from ray.train.callbacks import TorchTensorboardProfilerCallback
+
+    trainer = Trainer(backend="torch", num_workers=2)
+    trainer.start()
+    trainer.run(train_func, callbacks=[TorchTensorboardProfilerCallback()])
+    trainer.shutdown()
+
+**Step 3: Visualize the logs**
+
+.. code-block:: bash
+
+    # Navigate to the run directory of the trainer.
+    # For example: `cd /home/ray_results/train_2021-09-01_12-00-00/run_001/pytorch_profiler`
+    $ cd <TRAINER_RUN_DIR>/pytorch_profiler
+
+    # Install the PyTorch Profiler TensorBoard Plugin.
+    $ pip install torch_tb_profiler
+
+    # Start the TensorBoard UI.
+    $ tensorboard --logdir .
+
+    # View the PyTorch Profiler traces.
+    $ open http://localhost:6006/#pytorch_profiler
+
 .. _train-datasets:
 
 Distributed Data Ingest (Ray Datasets)
```


```diff
@@ -39,6 +39,15 @@ py_test(
     deps = [":train_lib"]
 )
 
+py_test(
+    name = "torch_tensorboard_profiler_example",
+    size = "small",
+    main = "examples/torch_tensorboard_profiler_example.py",
+    srcs = ["examples/torch_tensorboard_profiler_example.py"],
+    tags = ["team:ml", "exclusive"],
+    deps = [":train_lib"]
+)
+
 py_test(
     name = "transformers_example",
     size = "large",
```


```diff
@@ -5,11 +5,13 @@ from ray.train.callbacks.logging import (
     TBXLoggerCallback,
 )
 from ray.train.callbacks.print import PrintCallback
+from ray.train.callbacks.profile import TorchTensorboardProfilerCallback
 
 __all__ = [
     "TrainingCallback",
     "JsonLoggerCallback",
     "MLflowLoggerCallback",
     "TBXLoggerCallback",
+    "TorchTensorboardProfilerCallback",
     "PrintCallback",
 ]
```


```diff
@@ -1,13 +1,22 @@
 import abc
 from typing import List, Dict
 
-from ray.train.callbacks.results_preprocessors import ResultsPreprocessor
+from ray.train.callbacks.results_preprocessors import (
+    ResultsPreprocessor,
+    ExcludedKeysResultsPreprocessor,
+    SequentialResultsPreprocessor,
+)
+from ray.train.constants import ALL_RESERVED_KEYS
 
 
 class TrainingCallback(abc.ABC):
     """Abstract Train callback class."""
 
     results_preprocessor: ResultsPreprocessor = None
 
+    # Reserved keys used by this specific Callback.
+    # This should be set in a Callback class implementation so that the keys
+    # are not filtered out. See ``_preprocess_results`` for more details.
+    RESERVED_KEYS = {}
+
     def start_training(self, logdir: str, config: Dict, **info):
         """Called once on training start.
@@ -34,10 +43,37 @@ class TrainingCallback(abc.ABC):
                 the training function from each worker.
             **info: kwargs dict for forward compatibility.
         """
-        if self.results_preprocessor:
-            results = self.results_preprocessor.preprocess(results)
+        results = self._preprocess_results(results)
         self.handle_result(results, **info)
 
+    def _preprocess_results(self, results: List[Dict]) -> List[Dict]:
+        """Preprocesses the reported training results.
+
+        This will:
+        * Exclude all keys that are present in ``ALL_RESERVED_KEYS`` but
+          not in ``self.RESERVED_KEYS``.
+        * Execute ``self.results_preprocessor`` if defined.
+
+        Args:
+            results (List[Dict]): List of results from the training
+                function. Each value in the list corresponds to the output of
+                the training function from each worker.
+
+        Returns:
+            The preprocessed results.
+        """
+        results_to_exclude = ALL_RESERVED_KEYS.difference(self.RESERVED_KEYS)
+        system_preprocessor = ExcludedKeysResultsPreprocessor(results_to_exclude)
+        if self.results_preprocessor:
+            self.results_preprocessor = SequentialResultsPreprocessor(
+                [system_preprocessor, self.results_preprocessor]
+            )
+        else:
+            self.results_preprocessor = system_preprocessor
+        results = self.results_preprocessor.preprocess(results)
+        return results
+
     def handle_result(self, results: List[Dict], **info):
         """Called every time train.report() is called after preprocessing.
```


```diff
@@ -59,14 +59,14 @@ class TrainCallbackLogdirManager:
         self._logdir = Path(logdir) if logdir else None
         self._create_logdir = create_logdir
 
-    def setup_logdir(self, default_logdir: str) -> Path:
+    def setup_logdir(self, default_logdir: Union[str, Path]) -> Path:
         """Sets up the logdir.
 
         The directory will be created if it does not exist and
         ``create_logdir`` is set to True.
 
         Args:
-            default_logdir (str): The default logdir to use, only if the
+            default_logdir (str|Path): The default logdir to use, only if the
                 ``TrainCallbackLogdirManager`` was not initialized with a ``logdir``.
 
         Returns:
```


```diff
@@ -0,0 +1,53 @@
+import logging
+from pathlib import Path
+from typing import List, Dict, Optional, Union
+
+from ray.train.callbacks import TrainingCallback
+from ray.train.callbacks.logging import TrainCallbackLogdirManager
+from ray.train.callbacks.results_preprocessors import IndexedResultsPreprocessor
+from ray.train.constants import PYTORCH_PROFILER_KEY
+
+logger = logging.getLogger(__name__)
+
+DRIVER_TRACE_DIR_NAME = "pytorch_profiler"
+
+
+class TorchTensorboardProfilerCallback(TrainingCallback):
+    """Synchronizes PyTorch Profiler traces onto disk.
+
+    This should typically be used in conjunction with ``TorchWorkerProfiler``,
+    though the actual requirement is for the ``_train_torch_profiler`` key
+    to be populated in the results from ``train.report()``.
+
+    Args:
+        logdir (Optional[str]): The directory to store traces. If ``None``,
+            this will use a default temporary dir.
+        workers_to_log (Optional[int|List[int]]): Worker indices to log.
+            If ``None``, will log all workers. By default, this will log all
+            workers.
+    """
+
+    RESERVED_KEYS = [PYTORCH_PROFILER_KEY]
+
+    def __init__(
+        self,
+        logdir: Optional[str] = None,
+        workers_to_log: Optional[Union[int, List[int]]] = None,
+    ) -> None:
+        super().__init__()
+        self._logdir = logdir
+        self._logdir_manager = TrainCallbackLogdirManager(logdir=logdir)
+        self.results_preprocessor = IndexedResultsPreprocessor(indices=workers_to_log)
+
+    def start_training(self, logdir: str, **info):
+        default_logdir = Path(logdir).joinpath(DRIVER_TRACE_DIR_NAME)
+        self._logdir_manager.setup_logdir(default_logdir=default_logdir)
+
+    def handle_result(self, results: List[Dict], **info):
+        for result in results:
+            if PYTORCH_PROFILER_KEY in result and result[PYTORCH_PROFILER_KEY]:
+                profile_traces = result[PYTORCH_PROFILER_KEY]
+                for (name, data) in profile_traces:
+                    path = self._logdir_manager.logdir_path.joinpath(name)
+                    with path.open("w") as f:
+                        f.write(data)
```
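A quick usage sketch of the callback on its own (the `./profiler_traces` path and the trivial `train_func` are placeholders): passing an explicit `logdir` writes each reported trace file into that directory instead of the trainer's default run directory.

```python
import ray.train as train
from ray.train import Trainer
from ray.train.callbacks import TorchTensorboardProfilerCallback


def train_func():
    # Placeholder training function; see TorchWorkerProfiler for how
    # profiler traces are actually produced and reported.
    train.report(epoch=0)


# Write traces under ./profiler_traces rather than the trainer's default run dir.
callback = TorchTensorboardProfilerCallback(logdir="./profiler_traces")

trainer = Trainer(backend="torch", num_workers=2)
trainer.start()
trainer.run(train_func, callbacks=[callback])
trainer.shutdown()
```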


```diff
@@ -64,3 +64,13 @@ TRAIN_ENABLE_WORKER_SPREAD_ENV = "TRAIN_ENABLE_WORKER_SPREAD"
 # The key used to identify whether we have already warned about ray.train
 # functions being used outside of the session
 SESSION_MISUSE_LOG_ONCE_KEY = "train_warn_session_misuse"
+
+# Reserved keyword used by the ``TorchWorkerProfiler`` and
+# ``TorchTensorboardProfilerCallback`` for passing PyTorch Profiler data
+# through ``train.report()``.
+PYTORCH_PROFILER_KEY = "_train_torch_profiler"
+
+# Reserved keys used across all Callbacks.
+# By default these will be filtered out from ``train.report()``.
+# See ``TrainingCallback._preprocess_results`` for more details.
+ALL_RESERVED_KEYS = {PYTORCH_PROFILER_KEY}
```
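For intuition, the reserved key simply namespaces the profiler payload inside each worker's reported result; a sketch of what such a result dict might look like (values are illustrative):

```python
from ray.train.constants import PYTORCH_PROFILER_KEY

# Illustrative shape of a single worker's result, as produced by
# TorchWorkerProfiler.get_and_clear_profile_traces() merged into train.report().
result = {
    "epoch": 0,
    PYTORCH_PROFILER_KEY: [
        # (trace_filename, trace_json_text) tuples
        ("worker_0_epoch_0.pt.trace.json", "{ ... exported chrome trace ... }"),
    ],
}
```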


```diff
@@ -0,0 +1,84 @@
+import argparse
+
+import torch
+from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present
+from torch.profiler import profile, record_function, schedule
+
+import ray
+import ray.train as train
+from ray.train import Trainer
+from ray.train.callbacks import TBXLoggerCallback
+from ray.train.callbacks.profile import TorchTensorboardProfilerCallback
+from ray.train.torch import TorchWorkerProfiler
+
+
+def train_func():
+    twp = TorchWorkerProfiler()
+    with profile(
+        activities=[],
+        schedule=schedule(wait=0, warmup=0, active=1),
+        on_trace_ready=twp.trace_handler,
+    ) as p:
+
+        # Setup model.
+        model = torch.nn.Linear(1, 1)
+        model = train.torch.prepare_model(model)
+        loss_fn = torch.nn.MSELoss()
+        optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
+
+        # Setup data.
+        input = torch.randn(1000, 1)
+        labels = input * 2
+        dataset = torch.utils.data.TensorDataset(input, labels)
+        dataloader = torch.utils.data.DataLoader(dataset, batch_size=32)
+        dataloader = train.torch.prepare_data_loader(dataloader)
+
+        # Train.
+        for epoch in range(5):
+            with record_function("train_epoch"):
+                for X, y in dataloader:
+                    pred = model(X)
+                    loss = loss_fn(pred, y)
+                    optimizer.zero_grad()
+                    loss.backward()
+                    optimizer.step()
+
+            with record_function("train_checkpoint"):
+                state_dict = model.state_dict()
+                consume_prefix_in_state_dict_if_present(state_dict, "module.")
+                train.save_checkpoint(epoch=epoch, model_weights=state_dict)
+
+            p.step()
+
+            with record_function("train_report"):
+                profile_results = twp.get_and_clear_profile_traces()
+                train.report(epoch=epoch, **profile_results)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--address", required=False, type=str, help="the address to use for Ray"
+    )
+    parser.add_argument(
+        "--num-workers",
+        "-n",
+        type=int,
+        default=2,
+        help="Sets number of workers for training.",
+    )
+    parser.add_argument(
+        "--use-gpu", action="store_true", default=False, help="Enables GPU training"
+    )
+
+    args = parser.parse_args()
+
+    ray.init(address=args.address)
+
+    callbacks = [TorchTensorboardProfilerCallback(), TBXLoggerCallback()]
+    trainer = Trainer(
+        backend="torch", num_workers=args.num_workers, use_gpu=args.use_gpu
+    )
+    trainer.start()
+    trainer.run(train_func, callbacks=callbacks)
+    trainer.shutdown()
```


```diff
@@ -11,7 +11,12 @@ import ray
 import ray.train as train
 from ray.train import Trainer
 from ray.train.backend import BackendConfig, Backend
-from ray.train.callbacks import JsonLoggerCallback, PrintCallback, TBXLoggerCallback
+from ray.train.callbacks import (
+    JsonLoggerCallback,
+    PrintCallback,
+    TBXLoggerCallback,
+    TorchTensorboardProfilerCallback,
+)
 from ray.train.callbacks.logging import MLflowLoggerCallback, TrainCallbackLogdirManager
 from ray.train.constants import (
     TRAINING_ITERATION,
@@ -255,6 +260,47 @@ def test_mlflow(ray_start_4_cpus, tmp_path):
     assert rewards == [4, 5, 6]
 
 
+def test_torch_tensorboard_profiler_callback(ray_start_4_cpus, tmp_path):
+    config = TestConfig()
+
+    temp_dir = tmp_path
+    num_workers = 4
+    num_epochs = 2
+
+    def train_func():
+        from ray.train.torch import TorchWorkerProfiler
+        from torch.profiler import profile, record_function, schedule
+
+        twp = TorchWorkerProfiler()
+        with profile(
+            activities=[],
+            schedule=schedule(wait=0, warmup=0, active=1),
+            on_trace_ready=twp.trace_handler,
+        ) as p:
+            for epoch in range(num_epochs):
+                with record_function("test_function"):
+                    pass
+
+                p.step()
+
+                profile_results = twp.get_and_clear_profile_traces()
+                train.report(epoch=epoch, **profile_results)
+
+    callback = TorchTensorboardProfilerCallback(temp_dir)
+    trainer = Trainer(config, num_workers=num_workers)
+    trainer.start()
+    trainer.run(train_func, callbacks=[callback])
+
+    assert temp_dir.exists()
+
+    count = 0
+    for path in temp_dir.iterdir():
+        assert path.is_file()
+        count += 1
+    assert count == num_workers * num_epochs
+
+
 if __name__ == "__main__":
     import pytest
     import sys
```


```diff
@@ -6,7 +6,7 @@ import ray
 import ray.train as train
 from ray.train import Trainer
 from ray.train.backend import BackendConfig, Backend
-from ray.train.callbacks.callback import TrainingCallback
+from ray.train.callbacks import TrainingCallback
 from ray.train.worker_group import WorkerGroup
```


```diff
@@ -1,14 +1,17 @@
+import tempfile
 from dataclasses import dataclass
 import io
 import logging
 import os
 from datetime import timedelta
+from pathlib import Path
 from typing import Optional, Dict, Any
 
 import ray
 from ray import train
 from ray.train.backend import BackendConfig, Backend, EncodedData
+from ray.train.constants import PYTORCH_PROFILER_KEY
 from ray.train.worker_group import WorkerGroup
 from ray.train.utils import get_address_and_port
@@ -23,6 +26,11 @@ from torch.utils.data import (
     SequentialSampler,
 )
 
+try:
+    from torch.profiler import profile
+except ImportError:
+    profile = None
+
 logger = logging.getLogger(__name__)
@@ -338,3 +346,68 @@ def prepare_data_loader(
         data_loader = _WrappedDataLoader(data_loader, device)
 
     return data_loader
+
+
+WORKER_TRACE_DIR_NAME = "pytorch_profiler_worker_traces"
+
+
+class TorchWorkerProfiler:
+    """Utility class for running PyTorch Profiler on a Train worker.
+
+    Args:
+        trace_dir (Optional[str]): The directory to store traces on the
+            worker node. If ``None``, this will use a default temporary dir.
+    """
+
+    def __init__(self, trace_dir: Optional[str] = None):
+        if profile is None:
+            raise ImportError(
+                "Torch Profiler requires torch>=1.8.1. "
+                "Run `pip install 'torch>=1.8.1'` to use TorchWorkerProfiler."
+            )
+
+        trace_dir = trace_dir or Path(tempfile.gettempdir()).joinpath(
+            WORKER_TRACE_DIR_NAME
+        )
+        self.trace_dir = Path(trace_dir)
+        self.trace_dir.mkdir(parents=True, exist_ok=True)
+        # Accumulated traces.
+        self.profiler_trace_filenames = []
+
+    def trace_handler(self, p: profile):
+        """A stateful PyTorch Profiler trace handler.
+
+        This will export the chrome trace to a file on disk.
+
+        These exported traces can then be fetched by calling
+        ``get_and_clear_profile_traces``.
+
+        Args:
+            p (profile): A PyTorch Profiler profile.
+        """
+        trace_filename = f"worker_{train.world_rank()}_epoch_{p.step_num}.pt.trace.json"
+        trace_path = self.trace_dir.joinpath(trace_filename)
+
+        logger.debug(f"Writing worker trace to {trace_path}.")
+        p.export_chrome_trace(str(trace_path))
+        self.profiler_trace_filenames.append(trace_filename)
+
+    def get_and_clear_profile_traces(self):
+        """Reads unread Profiler traces from this worker.
+
+        Returns:
+            The traces in a format consumable by
+            ``TorchTensorboardProfilerCallback``.
+        """
+
+        def get_trace(filename):
+            trace_path = self.trace_dir.joinpath(filename)
+            return trace_path.read_text()
+
+        traces = [
+            (trace_filename, get_trace(trace_filename))
+            for trace_filename in self.profiler_trace_filenames
+        ]
+        self.profiler_trace_filenames = []
+        return {PYTORCH_PROFILER_KEY: traces}
```