Mirror of https://github.com/vale981/ray, synced 2025-03-06 02:21:39 -05:00
[train] add TorchTensorboardProfilerCallback (#22345)
The [original PR](https://github.com/ray-project/ray/pull/21864) was [reverted](https://github.com/ray-project/ray/pull/22117) because it caused `torch` (more specifically, `torch>=1.8.1`) to be required to use `ray.train`:

```
| File "ray_sgd_training.py", line 18, in <module>
|   from ray import train
| File "/home/ray/anaconda3/lib/python3.7/site-packages/ray/train/__init__.py", line 2, in <module>
|   from ray.train.callbacks import TrainingCallback
| File "/home/ray/anaconda3/lib/python3.7/site-packages/ray/train/callbacks/__init__.py", line 8, in <module>
|   from ray.train.callbacks.profile import TorchTensorboardProfilerCallback
| File "/home/ray/anaconda3/lib/python3.7/site-packages/ray/train/callbacks/profile.py", line 6, in <module>
|   from torch.profiler import profile
| ModuleNotFoundError: No module named 'torch.profiler'
```

A [minimal installation test suite](https://github.com/ray-project/ray/pull/22300) was added to detect this.

Further, in this PR we make the following changes:

1. Move `TorchWorkerProfiler` to `ray.train.torch` so all torch imports are centralized.
2. Add import validation logic to `TorchWorkerProfiler.__init__` so an exception is raised only if the user tries to initialize a `TorchWorkerProfiler` without having a valid version of `torch` installed:

```
>>> import ray
>>> import ray.train
>>> import ray.train.torch
>>> from ray.train.torch import TorchWorkerProfiler
>>> twp = TorchWorkerProfiler()
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/matt/workspace/ray/python/ray/train/torch.py", line 365, in __init__
    "Torch Profiler requires torch>=1.8.1. "
ImportError: Torch Profiler requires torch>=1.8.1. Run `pip install 'torch>=1.8.1'` to use TorchWorkerProfiler.
```
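The fix follows the standard optional-dependency pattern: attempt the `torch.profiler` import at module load, record the failure, and raise only when the profiler is actually constructed. A minimal standalone sketch of that pattern (class body trimmed; the full implementation is in the `python/ray/train/torch.py` diff below):

```python
# Optional-import guard: the module still imports cleanly without torch.
try:
    from torch.profiler import profile  # Available in torch>=1.8.1.
except ImportError:
    profile = None


class TorchWorkerProfiler:
    def __init__(self):
        # Fail at construction time, not at `import ray.train` time.
        if profile is None:
            raise ImportError(
                "Torch Profiler requires torch>=1.8.1. "
                "Run `pip install 'torch>=1.8.1'` to use TorchWorkerProfiler."
            )
```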
This commit is contained in:
parent 35a157948e
commit 8f9e0d7f6b

13 changed files with 393 additions and 7 deletions
```diff
@@ -236,6 +236,7 @@ MOCK_MODULES = [
     "torch.distributed",
     "torch.nn",
     "torch.nn.parallel",
+    "torch.profiler",
     "torch.utils.data",
     "torch.utils.data.distributed",
     "wandb",
```
```diff
@@ -86,6 +86,14 @@ MLflowLoggerCallback

 .. autoclass:: ray.train.callbacks.MLflowLoggerCallback

+.. _train-api-torch-tensorboard-profiler-callback:
+
+TorchTensorboardProfilerCallback
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: ray.train.callbacks.TorchTensorboardProfilerCallback
+
+
 ResultsPreprocessors
 ~~~~~~~~~~~~~~~~~~~~

@@ -175,6 +183,14 @@ train.torch.get_device

 .. autofunction:: ray.train.torch.get_device

+.. _train-api-torch-worker-profiler:
+
+train.torch.TorchWorkerProfiler
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: ray.train.torch.TorchWorkerProfiler
+    :members:
+
 TensorFlow Training Function Utilities
 --------------------------------------
```
```diff
@@ -20,6 +20,7 @@ In this guide, we cover examples for the following use cases:

 * How do I :ref:`monitor <train-monitoring>` my training?
 * How do I run my training on pre-emptible instances
   (:ref:`fault tolerance <train-fault-tolerance>`)?
+* How do I :ref:`profile <train-profiling>` my training?
 * How do I use Ray Train to :ref:`train with a large dataset <train-datasets>`?
 * How do I :ref:`tune <train-tune>` my Ray Train model?

@@ -429,6 +430,7 @@ The following ``TrainingCallback``\s are available and will log the intermediate

 2. :ref:`train-api-json-logger-callback`
 3. :ref:`train-api-tbx-logger-callback`
 4. :ref:`train-api-mlflow-logger-callback`
+5. :ref:`train-api-torch-tensorboard-profiler-callback`

 Example: Logging to MLflow and TensorBoard
 ++++++++++++++++++++++++++++++++++++++++++

@@ -919,6 +921,60 @@ number of retries is configurable through the ``max_retries`` argument of the

 .. TODO.

+.. _train-profiling:
+
+Profiling
+---------
+
+Ray Train comes with an integration with `PyTorch Profiler <https://pytorch.org/blog/introducing-pytorch-profiler-the-new-and-improved-performance-tool/>`_.
+Specifically, it comes with a :ref:`TorchWorkerProfiler <train-api-torch-worker-profiler>` utility class and a :ref:`train-api-torch-tensorboard-profiler-callback` callback
+that allow you to use the PyTorch Profiler as you would in a non-distributed PyTorch script, and synchronize the generated TensorBoard traces onto
+the disk from which your script was executed.
+
+**Step 1: Update training function with** ``TorchWorkerProfiler``
+
+.. code-block:: python
+
+    from ray.train.torch import TorchWorkerProfiler
+
+    def train_func():
+        twp = TorchWorkerProfiler()
+        with profile(..., on_trace_ready=twp.trace_handler) as p:
+            ...
+            profile_results = twp.get_and_clear_profile_traces()
+            train.report(..., **profile_results)
+        ...
+
+**Step 2: Run training function with** ``TorchTensorboardProfilerCallback``
+
+.. code-block:: python
+
+    from ray.train import Trainer
+    from ray.train.callbacks import TorchTensorboardProfilerCallback
+
+    trainer = Trainer(backend="torch", num_workers=2)
+    trainer.start()
+    trainer.run(train_func, callbacks=[TorchTensorboardProfilerCallback()])
+    trainer.shutdown()
+
+**Step 3: Visualize the logs**
+
+.. code-block:: bash
+
+    # Navigate to the run directory of the trainer.
+    # For example `cd /home/ray_results/train_2021-09-01_12-00-00/run_001/pytorch_profiler`
+    $ cd <TRAINER_RUN_DIR>/pytorch_profiler
+
+    # Install the PyTorch Profiler TensorBoard Plugin.
+    $ pip install torch_tb_profiler
+
+    # Start the TensorBoard UI.
+    $ tensorboard --logdir .
+
+    # View the PyTorch Profiler traces.
+    $ open http://localhost:6006/#pytorch_profiler
+
 .. _train-datasets:

 Distributed Data Ingest (Ray Datasets)
```
```diff
@@ -39,6 +39,15 @@ py_test(
     deps = [":train_lib"]
 )

+py_test(
+    name = "torch_tensorboard_profiler_example",
+    size = "small",
+    main = "examples/torch_tensorboard_profiler_example.py",
+    srcs = ["examples/torch_tensorboard_profiler_example.py"],
+    tags = ["team:ml", "exclusive"],
+    deps = [":train_lib"]
+)
+
 py_test(
     name = "transformers_example",
     size = "large",
```
```diff
@@ -5,11 +5,13 @@ from ray.train.callbacks.logging import (
     TBXLoggerCallback,
 )
 from ray.train.callbacks.print import PrintCallback
+from ray.train.callbacks.profile import TorchTensorboardProfilerCallback

 __all__ = [
     "TrainingCallback",
     "JsonLoggerCallback",
     "MLflowLoggerCallback",
     "TBXLoggerCallback",
+    "TorchTensorboardProfilerCallback",
     "PrintCallback",
 ]
```
```diff
@@ -1,13 +1,22 @@
 import abc
 from typing import List, Dict

-from ray.train.callbacks.results_preprocessors import ResultsPreprocessor
+from ray.train.callbacks.results_preprocessors import (
+    ResultsPreprocessor,
+    ExcludedKeysResultsPreprocessor,
+    SequentialResultsPreprocessor,
+)
+from ray.train.constants import ALL_RESERVED_KEYS


 class TrainingCallback(abc.ABC):
     """Abstract Train callback class."""

     results_preprocessor: ResultsPreprocessor = None
+    # Reserved keys used by this specific Callback.
+    # This should be set in a Callback class implementation so that the keys
+    # are not filtered out. See ``_preprocess_results`` for more details.
+    RESERVED_KEYS = {}

     def start_training(self, logdir: str, config: Dict, **info):
         """Called once on training start.

@@ -34,10 +43,37 @@ class TrainingCallback(abc.ABC):
             the training function from each worker.
             **info: kwargs dict for forward compatibility.
         """
-        if self.results_preprocessor:
-            results = self.results_preprocessor.preprocess(results)
+        results = self._preprocess_results(results)
         self.handle_result(results, **info)

+    def _preprocess_results(self, results: List[Dict]) -> List[Dict]:
+        """Preprocesses the reported training results.
+
+        This will:
+
+        * Exclude all keys that are present in ``ALL_RESERVED_KEYS`` but
+          not in ``self.RESERVED_KEYS``.
+        * Execute ``self.results_preprocessor`` if defined.
+
+        Args:
+            results (List[Dict]): List of results from the training
+                function. Each value in the list corresponds to the output of
+                the training function from each worker.
+
+        Returns:
+            The preprocessed results.
+        """
+        results_to_exclude = ALL_RESERVED_KEYS.difference(self.RESERVED_KEYS)
+        system_preprocessor = ExcludedKeysResultsPreprocessor(results_to_exclude)
+        if self.results_preprocessor:
+            self.results_preprocessor = SequentialResultsPreprocessor(
+                [system_preprocessor, self.results_preprocessor]
+            )
+        else:
+            self.results_preprocessor = system_preprocessor
+        results = self.results_preprocessor.preprocess(results)
+        return results
+
     def handle_result(self, results: List[Dict], **info):
         """Called every time train.report() is called after preprocessing.
```
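The reserved-key filtering above means an ordinary callback never sees profiler payloads, while a callback that declares the key in `RESERVED_KEYS` does. A hedged sketch of how a custom callback would opt in (`MyProfilerConsumer` is hypothetical, not part of this PR):

```python
from typing import Dict, List

from ray.train.callbacks import TrainingCallback
from ray.train.constants import PYTORCH_PROFILER_KEY


class MyProfilerConsumer(TrainingCallback):
    # Declaring the key here keeps ``_preprocess_results`` from filtering
    # "_train_torch_profiler" out of the reported results.
    RESERVED_KEYS = {PYTORCH_PROFILER_KEY}

    def handle_result(self, results: List[Dict], **info):
        for result in results:
            traces = result.get(PYTORCH_PROFILER_KEY, [])
            print(f"received {len(traces)} profiler trace(s)")
```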
```diff
@@ -59,14 +59,14 @@ class TrainCallbackLogdirManager:
         self._logdir = Path(logdir) if logdir else None
         self._create_logdir = create_logdir

-    def setup_logdir(self, default_logdir: str) -> Path:
+    def setup_logdir(self, default_logdir: Union[str, Path]) -> Path:
         """Sets up the logdir.

         The directory will be created if it does not exist and
         ``create_logdir`` is set to True.

         Args:
-            default_logdir (str): The default logdir to use, only if the
+            default_logdir (str|Path): The default logdir to use, only if the
                 ``TrainCallbackLogdirManager`` was not initialized with a ``logdir``.

         Returns:
```
python/ray/train/callbacks/profile.py (new file, 53 lines)

```diff
@@ -0,0 +1,53 @@
+import logging
+from pathlib import Path
+from typing import List, Dict, Optional, Union
+
+from ray.train.callbacks import TrainingCallback
+from ray.train.callbacks.logging import TrainCallbackLogdirManager
+from ray.train.callbacks.results_preprocessors import IndexedResultsPreprocessor
+from ray.train.constants import PYTORCH_PROFILER_KEY
+
+logger = logging.getLogger(__name__)
+
+DRIVER_TRACE_DIR_NAME = "pytorch_profiler"
+
+
+class TorchTensorboardProfilerCallback(TrainingCallback):
+    """Synchronizes PyTorch Profiler traces onto disk.
+
+    This should typically be used in conjunction with ``TorchWorkerProfiler``,
+    though the actual requirement is for the ``_train_torch_profiler`` key
+    to be populated in the results from ``train.report()``.
+
+    Args:
+        logdir (Optional[str]): The directory to store traces. If ``None``,
+            this will use a default temporary dir.
+        workers_to_log (Optional[int|List[int]]): Worker indices to log.
+            If ``None``, will log all workers. By default, this will log all
+            workers.
+    """
+
+    RESERVED_KEYS = [PYTORCH_PROFILER_KEY]
+
+    def __init__(
+        self,
+        logdir: Optional[str] = None,
+        workers_to_log: Optional[Union[int, List[int]]] = None,
+    ) -> None:
+        super().__init__()
+        self._logdir = logdir
+        self._logdir_manager = TrainCallbackLogdirManager(logdir=logdir)
+        self.results_preprocessor = IndexedResultsPreprocessor(indices=workers_to_log)
+
+    def start_training(self, logdir: str, **info):
+        default_logdir = Path(logdir).joinpath(DRIVER_TRACE_DIR_NAME)
+        self._logdir_manager.setup_logdir(default_logdir=default_logdir)
+
+    def handle_result(self, results: List[Dict], **info):
+        for result in results:
+            if PYTORCH_PROFILER_KEY in result and result[PYTORCH_PROFILER_KEY]:
+                profile_traces = result[PYTORCH_PROFILER_KEY]
+                for (name, data) in profile_traces:
+                    path = self._logdir_manager.logdir_path.joinpath(name)
+                    with path.open("w") as f:
+                        f.write(data)
```
@ -64,3 +64,13 @@ TRAIN_ENABLE_WORKER_SPREAD_ENV = "TRAIN_ENABLE_WORKER_SPREAD"
|
||||||
# The key used to identify whether we have already warned about ray.train
|
# The key used to identify whether we have already warned about ray.train
|
||||||
# functions being used outside of the session
|
# functions being used outside of the session
|
||||||
SESSION_MISUSE_LOG_ONCE_KEY = "train_warn_session_misuse"
|
SESSION_MISUSE_LOG_ONCE_KEY = "train_warn_session_misuse"
|
||||||
|
|
||||||
|
# Reserved keyword used by the ``TorchWorkerProfiler`` and
|
||||||
|
# ``TorchTensorboardProfilerCallback`` for passing PyTorch Profiler data
|
||||||
|
# through ``train.report()``
|
||||||
|
PYTORCH_PROFILER_KEY = "_train_torch_profiler"
|
||||||
|
|
||||||
|
# Reserved keys used across all Callbacks.
|
||||||
|
# By default these will be filtered out from ``train.report()``.
|
||||||
|
# See ``TrainingCallback._preprocess_results`` for more details.
|
||||||
|
ALL_RESERVED_KEYS = {PYTORCH_PROFILER_KEY}
|
||||||
|
|
|
examples/torch_tensorboard_profiler_example.py (new file, 84 lines)

```diff
@@ -0,0 +1,84 @@
+import argparse
+
+import torch
+from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present
+from torch.profiler import profile, record_function, schedule
+
+import ray
+import ray.train as train
+from ray.train import Trainer
+from ray.train.callbacks import TBXLoggerCallback
+from ray.train.callbacks.profile import TorchTensorboardProfilerCallback
+from ray.train.torch import TorchWorkerProfiler
+
+
+def train_func():
+    twp = TorchWorkerProfiler()
+    with profile(
+        activities=[],
+        schedule=schedule(wait=0, warmup=0, active=1),
+        on_trace_ready=twp.trace_handler,
+    ) as p:
+
+        # Setup model.
+        model = torch.nn.Linear(1, 1)
+        model = train.torch.prepare_model(model)
+        loss_fn = torch.nn.MSELoss()
+        optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
+
+        # Setup data.
+        input = torch.randn(1000, 1)
+        labels = input * 2
+        dataset = torch.utils.data.TensorDataset(input, labels)
+        dataloader = torch.utils.data.DataLoader(dataset, batch_size=32)
+        dataloader = train.torch.prepare_data_loader(dataloader)
+
+        # Train.
+        for epoch in range(5):
+            with record_function("train_epoch"):
+                for X, y in dataloader:
+                    pred = model(X)
+                    loss = loss_fn(pred, y)
+                    optimizer.zero_grad()
+                    loss.backward()
+                    optimizer.step()
+
+            with record_function("train_checkpoint"):
+                state_dict = model.state_dict()
+                consume_prefix_in_state_dict_if_present(state_dict, "module.")
+                train.save_checkpoint(epoch=epoch, model_weights=state_dict)
+
+            p.step()
+
+            with record_function("train_report"):
+                profile_results = twp.get_and_clear_profile_traces()
+                train.report(epoch=epoch, **profile_results)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--address", required=False, type=str, help="the address to use for Ray"
+    )
+    parser.add_argument(
+        "--num-workers",
+        "-n",
+        type=int,
+        default=2,
+        help="Sets number of workers for training.",
+    )
+    parser.add_argument(
+        "--use-gpu", action="store_true", default=False, help="Enables GPU training"
+    )
+
+    args = parser.parse_args()
+
+    ray.init(address=args.address)
+
+    callbacks = [TorchTensorboardProfilerCallback(), TBXLoggerCallback()]
+    trainer = Trainer(
+        backend="torch", num_workers=args.num_workers, use_gpu=args.use_gpu
+    )
+    trainer.start()
+    trainer.run(train_func, callbacks=callbacks)
+    trainer.shutdown()
```
```diff
@@ -11,7 +11,12 @@ import ray
 import ray.train as train
 from ray.train import Trainer
 from ray.train.backend import BackendConfig, Backend
-from ray.train.callbacks import JsonLoggerCallback, PrintCallback, TBXLoggerCallback
+from ray.train.callbacks import (
+    JsonLoggerCallback,
+    PrintCallback,
+    TBXLoggerCallback,
+    TorchTensorboardProfilerCallback,
+)
 from ray.train.callbacks.logging import MLflowLoggerCallback, TrainCallbackLogdirManager
 from ray.train.constants import (
     TRAINING_ITERATION,

@@ -255,6 +260,47 @@ def test_mlflow(ray_start_4_cpus, tmp_path):
     assert rewards == [4, 5, 6]


+def test_torch_tensorboard_profiler_callback(ray_start_4_cpus, tmp_path):
+    config = TestConfig()
+
+    temp_dir = tmp_path
+    num_workers = 4
+    num_epochs = 2
+
+    def train_func():
+        from ray.train.torch import TorchWorkerProfiler
+        from torch.profiler import profile, record_function, schedule
+
+        twp = TorchWorkerProfiler()
+        with profile(
+            activities=[],
+            schedule=schedule(wait=0, warmup=0, active=1),
+            on_trace_ready=twp.trace_handler,
+        ) as p:
+
+            for epoch in range(num_epochs):
+                with record_function("test_function"):
+                    pass
+
+                p.step()
+
+                profile_results = twp.get_and_clear_profile_traces()
+                train.report(epoch=epoch, **profile_results)
+
+    callback = TorchTensorboardProfilerCallback(temp_dir)
+    trainer = Trainer(config, num_workers=num_workers)
+    trainer.start()
+    trainer.run(train_func, callbacks=[callback])
+
+    assert temp_dir.exists()
+
+    count = 0
+    for path in temp_dir.iterdir():
+        assert path.is_file()
+        count += 1
+    assert count == num_workers * num_epochs
+
+
 if __name__ == "__main__":
     import pytest
     import sys
```
```diff
@@ -6,7 +6,7 @@ import ray
 import ray.train as train
 from ray.train import Trainer
 from ray.train.backend import BackendConfig, Backend
-from ray.train.callbacks.callback import TrainingCallback
+from ray.train.callbacks import TrainingCallback
 from ray.train.worker_group import WorkerGroup
```
```diff
@@ -1,14 +1,17 @@
+import tempfile
 from dataclasses import dataclass
 import io
 import logging
 import os

 from datetime import timedelta
+from pathlib import Path
 from typing import Optional, Dict, Any

 import ray
 from ray import train
 from ray.train.backend import BackendConfig, Backend, EncodedData
+from ray.train.constants import PYTORCH_PROFILER_KEY
 from ray.train.worker_group import WorkerGroup
 from ray.train.utils import get_address_and_port

@@ -23,6 +26,11 @@ from torch.utils.data import (
     SequentialSampler,
 )

+try:
+    from torch.profiler import profile
+except ImportError:
+    profile = None
+
 logger = logging.getLogger(__name__)

@@ -338,3 +346,68 @@ def prepare_data_loader(
         data_loader = _WrappedDataLoader(data_loader, device)

     return data_loader
+
+
+WORKER_TRACE_DIR_NAME = "pytorch_profiler_worker_traces"
+
+
+class TorchWorkerProfiler:
+    """Utility class for running PyTorch Profiler on a Train worker.
+
+    Args:
+        trace_dir (Optional[str]): The directory to store traces on the
+            worker node. If ``None``, this will use a default temporary dir.
+    """
+
+    def __init__(self, trace_dir: Optional[str] = None):
+        if profile is None:
+            raise ImportError(
+                "Torch Profiler requires torch>=1.8.1. "
+                "Run `pip install 'torch>=1.8.1'` to use TorchWorkerProfiler."
+            )
+
+        trace_dir = trace_dir or Path(tempfile.gettempdir()).joinpath(
+            WORKER_TRACE_DIR_NAME
+        )
+        self.trace_dir = Path(trace_dir)
+        self.trace_dir.mkdir(parents=True, exist_ok=True)
+        # Accumulated traces.
+        self.profiler_trace_filenames = []
+
+    def trace_handler(self, p: profile):
+        """A stateful PyTorch Profiler trace handler.
+
+        This will export the chrome trace to a file on disk.
+
+        These exported traces can then be fetched by calling
+        ``get_and_clear_profile_traces``.
+
+        Args:
+            p (profile): A PyTorch Profiler profile.
+        """
+        trace_filename = f"worker_{train.world_rank()}_epoch_{p.step_num}.pt.trace.json"
+        trace_path = self.trace_dir.joinpath(trace_filename)
+
+        logger.debug(f"Writing worker trace to {trace_path}.")
+        p.export_chrome_trace(str(trace_path))
+        self.profiler_trace_filenames.append(trace_filename)
+
+    def get_and_clear_profile_traces(self):
+        """Reads unread Profiler traces from this worker.
+
+        Returns:
+            The traces in a format consumable by
+            ``TorchTensorboardProfilerCallback``.
+        """
+
+        def get_trace(filename):
+            trace_path = self.trace_dir.joinpath(filename)
+            return trace_path.read_text()
+
+        traces = [
+            (trace_filename, get_trace(trace_filename))
+            for trace_filename in self.profiler_trace_filenames
+        ]
+
+        self.profiler_trace_filenames = []
+        return {PYTORCH_PROFILER_KEY: traces}
```
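For reference, the payload that `get_and_clear_profile_traces()` returns, and that `TorchTensorboardProfilerCallback.handle_result()` consumes, maps the reserved key to a list of `(filename, trace_json)` tuples. An illustrative value (trace contents elided; the filename follows the `worker_{rank}_epoch_{step}` format above):

```python
# Hypothetical return value for worker rank 0 after one profiler step.
{
    "_train_torch_profiler": [
        ("worker_0_epoch_1.pt.trace.json", '{"traceEvents": [...]}'),
    ]
}
```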