[train] add TorchTensorboardProfilerCallback (#22345)
The [original PR](https://github.com/ray-project/ray/pull/21864) was [reverted](https://github.com/ray-project/ray/pull/22117) because it caused `torch` (more specifically, `torch>=1.8.1`) to be required to use `ray.train`:

```
  File "ray_sgd_training.py", line 18, in <module>
    from ray import train
  File "/home/ray/anaconda3/lib/python3.7/site-packages/ray/train/__init__.py", line 2, in <module>
    from ray.train.callbacks import TrainingCallback
  File "/home/ray/anaconda3/lib/python3.7/site-packages/ray/train/callbacks/__init__.py", line 8, in <module>
    from ray.train.callbacks.profile import TorchTensorboardProfilerCallback
  File "/home/ray/anaconda3/lib/python3.7/site-packages/ray/train/callbacks/profile.py", line 6, in <module>
    from torch.profiler import profile
ModuleNotFoundError: No module named 'torch.profiler'
```

A [minimal installation test suite](https://github.com/ray-project/ray/pull/22300) was added to detect this.

Further, in this PR we make the following changes:

1. Move `TorchWorkerProfiler` to `ray.train.torch` so all torch imports are centralized.
2. Add import validation logic to `TorchWorkerProfiler.__init__` so an exception is raised only if the user tries to initialize a `TorchWorkerProfiler` without a valid version of `torch` installed:

```
>>> import ray
>>> import ray.train
>>> import ray.train.torch
>>> from ray.train.torch import TorchWorkerProfiler
>>> twp = TorchWorkerProfiler()
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/matt/workspace/ray/python/ray/train/torch.py", line 365, in __init__
    "Torch Profiler requires torch>=1.8.1. "
ImportError: Torch Profiler requires torch>=1.8.1. Run `pip install 'torch>=1.8.1'` to use TorchWorkerProfiler.
```
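For illustration, here is a minimal sketch of the deferred-import validation pattern described above, condensed from the `python/ray/train/torch.py` changes in the diff below (not a drop-in snippet):

```python
# Minimal sketch of the deferred-import validation pattern.
try:
    from torch.profiler import profile
except ImportError:
    # `import ray.train` should still succeed without torch>=1.8.1 installed.
    profile = None


class TorchWorkerProfiler:
    def __init__(self, trace_dir=None):
        # Raise only when the profiler is actually constructed, not at import time.
        if profile is None:
            raise ImportError(
                "Torch Profiler requires torch>=1.8.1. "
                "Run `pip install 'torch>=1.8.1'` to use TorchWorkerProfiler."
            )
```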
This commit is contained in: parent 35a157948e, commit 8f9e0d7f6b
13 changed files with 393 additions and 7 deletions
@@ -236,6 +236,7 @@ MOCK_MODULES = [
    "torch.distributed",
    "torch.nn",
    "torch.nn.parallel",
    "torch.profiler",
    "torch.utils.data",
    "torch.utils.data.distributed",
    "wandb",
@@ -86,6 +86,14 @@ MLflowLoggerCallback

.. autoclass:: ray.train.callbacks.MLflowLoggerCallback

.. _train-api-torch-tensorboard-profiler-callback:

TorchTensorboardProfilerCallback
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: ray.train.callbacks.TorchTensorboardProfilerCallback

ResultsPreprocessors
~~~~~~~~~~~~~~~~~~~~
@@ -175,6 +183,14 @@ train.torch.get_device

.. autofunction:: ray.train.torch.get_device

.. _train-api-torch-worker-profiler:

train.torch.TorchWorkerProfiler
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: ray.train.torch.TorchWorkerProfiler
    :members:

TensorFlow Training Function Utilities
--------------------------------------
@@ -20,6 +20,7 @@ In this guide, we cover examples for the following use cases:

* How do I :ref:`monitor <train-monitoring>` my training?
* How do I run my training on pre-emptible instances
  (:ref:`fault tolerance <train-fault-tolerance>`)?
* How do I :ref:`profile <train-profiling>` my training?
* How do I use Ray Train to :ref:`train with a large dataset <train-datasets>`?
* How do I :ref:`tune <train-tune>` my Ray Train model?
@@ -429,6 +430,7 @@ The following ``TrainingCallback``\s are available and will log the intermediate

2. :ref:`train-api-json-logger-callback`
3. :ref:`train-api-tbx-logger-callback`
4. :ref:`train-api-mlflow-logger-callback`
5. :ref:`train-api-torch-tensorboard-profiler-callback`

Example: Logging to MLflow and TensorBoard
++++++++++++++++++++++++++++++++++++++++++
@@ -919,6 +921,60 @@ number of retries is configurable through the ``max_retries`` argument of the

.. TODO.

.. _train-profiling:

Profiling
---------

Ray Train comes with an integration with `PyTorch Profiler <https://pytorch.org/blog/introducing-pytorch-profiler-the-new-and-improved-performance-tool/>`_.
Specifically, it comes with a :ref:`TorchWorkerProfiler <train-api-torch-worker-profiler>` utility class and a :ref:`train-api-torch-tensorboard-profiler-callback` callback
that allow you to use the PyTorch Profiler as you would in a non-distributed PyTorch script, and synchronize the generated TensorBoard traces onto
the disk from which your script was executed.

**Step 1: Update training function with** ``TorchWorkerProfiler``

.. code-block:: python

    from ray.train.torch import TorchWorkerProfiler

    def train_func():
        twp = TorchWorkerProfiler()
        with profile(..., on_trace_ready=twp.trace_handler) as p:
            ...
            profile_results = twp.get_and_clear_profile_traces()
            train.report(..., **profile_results)
        ...

**Step 2: Run training function with** ``TorchTensorboardProfilerCallback``

.. code-block:: python

    from ray.train import Trainer
    from ray.train.callbacks import TorchTensorboardProfilerCallback

    trainer = Trainer(backend="torch", num_workers=2)
    trainer.start()
    trainer.run(train_func, callbacks=[TorchTensorboardProfilerCallback()])
    trainer.shutdown()

**Step 3: Visualize the logs**

.. code-block:: bash

    # Navigate to the run directory of the trainer.
    # For example `cd /home/ray_results/train_2021-09-01_12-00-00/run_001/pytorch_profiler`
    $ cd <TRAINER_RUN_DIR>/pytorch_profiler

    # Install the PyTorch Profiler TensorBoard Plugin.
    $ pip install torch_tb_profiler

    # Start the TensorBoard UI.
    $ tensorboard --logdir .

    # View the PyTorch Profiler traces.
    $ open http://localhost:6006/#pytorch_profiler

.. _train-datasets:

Distributed Data Ingest (Ray Datasets)
@@ -39,6 +39,15 @@ py_test(
    deps = [":train_lib"]
)

py_test(
    name = "torch_tensorboard_profiler_example",
    size = "small",
    main = "examples/torch_tensorboard_profiler_example.py",
    srcs = ["examples/torch_tensorboard_profiler_example.py"],
    tags = ["team:ml", "exclusive"],
    deps = [":train_lib"]
)

py_test(
    name = "transformers_example",
    size = "large",
@@ -5,11 +5,13 @@ from ray.train.callbacks.logging import (
    TBXLoggerCallback,
)
from ray.train.callbacks.print import PrintCallback
from ray.train.callbacks.profile import TorchTensorboardProfilerCallback

__all__ = [
    "TrainingCallback",
    "JsonLoggerCallback",
    "MLflowLoggerCallback",
    "TBXLoggerCallback",
    "TorchTensorboardProfilerCallback",
    "PrintCallback",
]
@@ -1,13 +1,22 @@
import abc
from typing import List, Dict

-from ray.train.callbacks.results_preprocessors import ResultsPreprocessor
+from ray.train.callbacks.results_preprocessors import (
+    ResultsPreprocessor,
+    ExcludedKeysResultsPreprocessor,
+    SequentialResultsPreprocessor,
+)
+from ray.train.constants import ALL_RESERVED_KEYS


class TrainingCallback(abc.ABC):
    """Abstract Train callback class."""

    results_preprocessor: ResultsPreprocessor = None
    # Reserved keys used by this specific Callback.
    # This should be set in a Callback class implementation so that the keys
    # are not filtered out. See ``_preprocess_results`` for more details.
    RESERVED_KEYS = {}

    def start_training(self, logdir: str, config: Dict, **info):
        """Called once on training start.

@@ -34,10 +43,37 @@ class TrainingCallback(abc.ABC):
                the training function from each worker.
            **info: kwargs dict for forward compatibility.
        """
-        if self.results_preprocessor:
-            results = self.results_preprocessor.preprocess(results)
+        results = self._preprocess_results(results)
        self.handle_result(results, **info)

    def _preprocess_results(self, results: List[Dict]) -> List[Dict]:
        """Preprocesses the reported training results.

        This will:

        * Exclude all keys that are present in ``ALL_RESERVED_KEYS`` but
          not in ``self.RESERVED_KEYS``.
        * Execute ``self.results_preprocessor`` if defined.

        Args:
            results (List[Dict]): List of results from the training
                function. Each value in the list corresponds to the output of
                the training function from each worker.

        Returns:
            The preprocessed results.
        """
        results_to_exclude = ALL_RESERVED_KEYS.difference(self.RESERVED_KEYS)
        system_preprocessor = ExcludedKeysResultsPreprocessor(results_to_exclude)
        if self.results_preprocessor:
            self.results_preprocessor = SequentialResultsPreprocessor(
                [system_preprocessor, self.results_preprocessor]
            )
        else:
            self.results_preprocessor = system_preprocessor
        results = self.results_preprocessor.preprocess(results)
        return results

    def handle_result(self, results: List[Dict], **info):
        """Called every time train.report() is called after preprocessing.
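For context, a minimal hypothetical sketch of how the reserved-key mechanism above is consumed: a callback that declares a key in `RESERVED_KEYS` keeps that key in the results passed to `handle_result`, while other reserved keys are filtered out by `_preprocess_results`. `MyProfileConsumerCallback` is illustrative only and not part of this commit.

```python
from typing import Dict, List

from ray.train.callbacks import TrainingCallback
from ray.train.constants import PYTORCH_PROFILER_KEY


class MyProfileConsumerCallback(TrainingCallback):
    # Opt in to the profiler key; reserved keys not listed here are
    # stripped from the results before handle_result is called.
    RESERVED_KEYS = {PYTORCH_PROFILER_KEY}

    def handle_result(self, results: List[Dict], **info):
        for result in results:
            traces = result.get(PYTORCH_PROFILER_KEY) or []
            print(f"Worker reported {len(traces)} profiler trace(s)")
```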
@@ -59,14 +59,14 @@ class TrainCallbackLogdirManager:
        self._logdir = Path(logdir) if logdir else None
        self._create_logdir = create_logdir

-    def setup_logdir(self, default_logdir: str) -> Path:
+    def setup_logdir(self, default_logdir: Union[str, Path]) -> Path:
        """Sets up the logdir.

        The directory will be created if it does not exist and
        ``create_logdir`` is set to True.

        Args:
-            default_logdir (str): The default logdir to use, only if the
+            default_logdir (str|Path): The default logdir to use, only if the
                ``TrainCallbackLogdirManager`` was not initialized with a ``logdir``.

        Returns:
python/ray/train/callbacks/profile.py (new file, 53 lines)
@@ -0,0 +1,53 @@
import logging
from pathlib import Path
from typing import List, Dict, Optional, Union

from ray.train.callbacks import TrainingCallback
from ray.train.callbacks.logging import TrainCallbackLogdirManager
from ray.train.callbacks.results_preprocessors import IndexedResultsPreprocessor
from ray.train.constants import PYTORCH_PROFILER_KEY

logger = logging.getLogger(__name__)

DRIVER_TRACE_DIR_NAME = "pytorch_profiler"


class TorchTensorboardProfilerCallback(TrainingCallback):
    """Synchronizes PyTorch Profiler traces onto disk.

    This should typically be used in conjunction with ``TorchWorkerProfiler``,
    though the actual requirement is for the ``_train_torch_profiler`` key
    to be populated in the results from ``train.report()``.

    Args:
        logdir (Optional[str]): The directory to store traces. If ``None``,
            this will use a default temporary dir.
        workers_to_log (Optional[int|List[int]]): Worker indices to log.
            If ``None``, will log all workers. By default, this will log all
            workers.
    """

    RESERVED_KEYS = [PYTORCH_PROFILER_KEY]

    def __init__(
        self,
        logdir: Optional[str] = None,
        workers_to_log: Optional[Union[int, List[int]]] = None,
    ) -> None:
        super().__init__()
        self._logdir = logdir
        self._logdir_manager = TrainCallbackLogdirManager(logdir=logdir)
        self.results_preprocessor = IndexedResultsPreprocessor(indices=workers_to_log)

    def start_training(self, logdir: str, **info):
        default_logdir = Path(logdir).joinpath(DRIVER_TRACE_DIR_NAME)
        self._logdir_manager.setup_logdir(default_logdir=default_logdir)

    def handle_result(self, results: List[Dict], **info):
        for result in results:
            if PYTORCH_PROFILER_KEY in result and result[PYTORCH_PROFILER_KEY]:
                profile_traces = result[PYTORCH_PROFILER_KEY]
                for (name, data) in profile_traces:
                    path = self._logdir_manager.logdir_path.joinpath(name)
                    with path.open("w") as f:
                        f.write(data)
@@ -64,3 +64,13 @@ TRAIN_ENABLE_WORKER_SPREAD_ENV = "TRAIN_ENABLE_WORKER_SPREAD"
# The key used to identify whether we have already warned about ray.train
# functions being used outside of the session
SESSION_MISUSE_LOG_ONCE_KEY = "train_warn_session_misuse"

# Reserved keyword used by the ``TorchWorkerProfiler`` and
# ``TorchTensorboardProfilerCallback`` for passing PyTorch Profiler data
# through ``train.report()``
PYTORCH_PROFILER_KEY = "_train_torch_profiler"

# Reserved keys used across all Callbacks.
# By default these will be filtered out from ``train.report()``.
# See ``TrainingCallback._preprocess_results`` for more details.
ALL_RESERVED_KEYS = {PYTORCH_PROFILER_KEY}
@@ -0,0 +1,84 @@
import argparse

import torch
from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present
from torch.profiler import profile, record_function, schedule

import ray
import ray.train as train
from ray.train import Trainer
from ray.train.callbacks import TBXLoggerCallback
from ray.train.callbacks.profile import TorchTensorboardProfilerCallback
from ray.train.torch import TorchWorkerProfiler


def train_func():
    twp = TorchWorkerProfiler()
    with profile(
        activities=[],
        schedule=schedule(wait=0, warmup=0, active=1),
        on_trace_ready=twp.trace_handler,
    ) as p:

        # Setup model.
        model = torch.nn.Linear(1, 1)
        model = train.torch.prepare_model(model)
        loss_fn = torch.nn.MSELoss()
        optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

        # Setup data.
        input = torch.randn(1000, 1)
        labels = input * 2
        dataset = torch.utils.data.TensorDataset(input, labels)
        dataloader = torch.utils.data.DataLoader(dataset, batch_size=32)
        dataloader = train.torch.prepare_data_loader(dataloader)

        # Train.
        for epoch in range(5):
            with record_function("train_epoch"):
                for X, y in dataloader:
                    pred = model(X)
                    loss = loss_fn(pred, y)
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

            with record_function("train_checkpoint"):
                state_dict = model.state_dict()
                consume_prefix_in_state_dict_if_present(state_dict, "module.")
                train.save_checkpoint(epoch=epoch, model_weights=state_dict)

            p.step()

            with record_function("train_report"):
                profile_results = twp.get_and_clear_profile_traces()
                train.report(epoch=epoch, **profile_results)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--address", required=False, type=str, help="the address to use for Ray"
    )
    parser.add_argument(
        "--num-workers",
        "-n",
        type=int,
        default=2,
        help="Sets number of workers for training.",
    )
    parser.add_argument(
        "--use-gpu", action="store_true", default=False, help="Enables GPU training"
    )

    args = parser.parse_args()

    ray.init(address=args.address)

    callbacks = [TorchTensorboardProfilerCallback(), TBXLoggerCallback()]
    trainer = Trainer(
        backend="torch", num_workers=args.num_workers, use_gpu=args.use_gpu
    )
    trainer.start()
    trainer.run(train_func, callbacks=callbacks)
    trainer.shutdown()
@@ -11,7 +11,12 @@ import ray
import ray.train as train
from ray.train import Trainer
from ray.train.backend import BackendConfig, Backend
-from ray.train.callbacks import JsonLoggerCallback, PrintCallback, TBXLoggerCallback
+from ray.train.callbacks import (
+    JsonLoggerCallback,
+    PrintCallback,
+    TBXLoggerCallback,
+    TorchTensorboardProfilerCallback,
+)
from ray.train.callbacks.logging import MLflowLoggerCallback, TrainCallbackLogdirManager
from ray.train.constants import (
    TRAINING_ITERATION,
@@ -255,6 +260,47 @@ def test_mlflow(ray_start_4_cpus, tmp_path):
    assert rewards == [4, 5, 6]


def test_torch_tensorboard_profiler_callback(ray_start_4_cpus, tmp_path):
    config = TestConfig()

    temp_dir = tmp_path
    num_workers = 4
    num_epochs = 2

    def train_func():
        from ray.train.torch import TorchWorkerProfiler
        from torch.profiler import profile, record_function, schedule

        twp = TorchWorkerProfiler()
        with profile(
            activities=[],
            schedule=schedule(wait=0, warmup=0, active=1),
            on_trace_ready=twp.trace_handler,
        ) as p:

            for epoch in range(num_epochs):
                with record_function("test_function"):
                    pass

                p.step()

                profile_results = twp.get_and_clear_profile_traces()
                train.report(epoch=epoch, **profile_results)

    callback = TorchTensorboardProfilerCallback(temp_dir)
    trainer = Trainer(config, num_workers=num_workers)
    trainer.start()
    trainer.run(train_func, callbacks=[callback])

    assert temp_dir.exists()

    count = 0
    for path in temp_dir.iterdir():
        assert path.is_file()
        count += 1
    assert count == num_workers * num_epochs


if __name__ == "__main__":
    import pytest
    import sys
@@ -6,7 +6,7 @@ import ray
import ray.train as train
from ray.train import Trainer
from ray.train.backend import BackendConfig, Backend
-from ray.train.callbacks.callback import TrainingCallback
+from ray.train.callbacks import TrainingCallback
from ray.train.worker_group import WorkerGroup
@@ -1,14 +1,17 @@
import tempfile
from dataclasses import dataclass
import io
import logging
import os

from datetime import timedelta
from pathlib import Path
from typing import Optional, Dict, Any

import ray
from ray import train
from ray.train.backend import BackendConfig, Backend, EncodedData
from ray.train.constants import PYTORCH_PROFILER_KEY
from ray.train.worker_group import WorkerGroup
from ray.train.utils import get_address_and_port
@@ -23,6 +26,11 @@ from torch.utils.data import (
    SequentialSampler,
)

try:
    from torch.profiler import profile
except ImportError:
    profile = None

logger = logging.getLogger(__name__)
@@ -338,3 +346,68 @@ def prepare_data_loader(
        data_loader = _WrappedDataLoader(data_loader, device)

    return data_loader


WORKER_TRACE_DIR_NAME = "pytorch_profiler_worker_traces"


class TorchWorkerProfiler:
    """Utility class for running PyTorch Profiler on a Train worker.

    Args:
        trace_dir (Optional[str]): The directory to store traces on the
            worker node. If ``None``, this will use a default temporary dir.
    """

    def __init__(self, trace_dir: Optional[str] = None):
        if profile is None:
            raise ImportError(
                "Torch Profiler requires torch>=1.8.1. "
                "Run `pip install 'torch>=1.8.1'` to use TorchWorkerProfiler."
            )

        trace_dir = trace_dir or Path(tempfile.gettempdir()).joinpath(
            WORKER_TRACE_DIR_NAME
        )
        self.trace_dir = Path(trace_dir)
        self.trace_dir.mkdir(parents=True, exist_ok=True)
        # Accumulated traces.
        self.profiler_trace_filenames = []

    def trace_handler(self, p: profile):
        """A stateful PyTorch Profiler trace handler.

        This will export the chrome trace to a file on disk.

        These exported traces can then be fetched by calling
        ``get_and_clear_profile_traces``.

        Args:
            p (profile): A PyTorch Profiler profile.
        """
        trace_filename = f"worker_{train.world_rank()}_epoch_{p.step_num}.pt.trace.json"
        trace_path = self.trace_dir.joinpath(trace_filename)

        logger.debug(f"Writing worker trace to {trace_path}.")
        p.export_chrome_trace(str(trace_path))
        self.profiler_trace_filenames.append(trace_filename)

    def get_and_clear_profile_traces(self):
        """Reads unread Profiler traces from this worker.

        Returns:
            The traces in a format consumable by
            ``TorchTensorboardProfilerCallback``.
        """

        def get_trace(filename):
            trace_path = self.trace_dir.joinpath(filename)
            return trace_path.read_text()

        traces = [
            (trace_filename, get_trace(trace_filename))
            for trace_filename in self.profiler_trace_filenames
        ]

        self.profiler_trace_filenames = []
        return {PYTORCH_PROFILER_KEY: traces}