* [Core] zero-copy serializer for pytorch (#12344)
* zero-copy serializer for pytorch
* address possible bottleneck
* add tests & device support
(cherry picked from commit 0a505ca83d)
* add environment variables
* update doc
This commit is contained in:
parent bb03e2499b
commit 3f22448834

5 changed files with 111 additions and 1 deletion
@@ -29,6 +29,12 @@ The numpy array is stored as a read-only object, and all Ray workers on the same

.. tip:: You can often avoid serialization issues by using only native types (e.g., numpy arrays or lists/dicts of numpy arrays and other primitive types), or by using Actors to hold objects that cannot be serialized.

PyTorch Tensors
---------------

Ray supports zero-copy serialization for PyTorch tensors by default. However, this forces the
``torch`` module to be imported. To disable this feature, set the environment variable ``RAY_DISABLE_PYTORCH_SERIALIZER=1``.

Serialization notes
-------------------
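For illustration, a minimal sketch of the behavior the doc addition describes; the opt-out variable must be set before ``ray.init()`` creates the serialization context, and the ``data_ptr()`` comparison is an informal way to observe buffer sharing (it mirrors the test added later in this commit), not an official API contract:

import os

# To opt out, uncomment this before ray.init(); the serialization context
# reads the variable when it is created (see serialization_addons.apply).
# os.environ["RAY_DISABLE_PYTORCH_SERIALIZER"] = "1"

import ray
import torch

ray.init()

t = torch.rand(32, 3, 64, 64)
ref = ray.put(t)
# With the default zero-copy serializer, both results are views over the
# same read-only buffer, so their storage addresses coincide.
a, b = ray.get([ref] * 2)
assert a.data_ptr() == b.data_ptr()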
@@ -26,6 +26,7 @@ from ray._raylet import (
    MessagePackSerializedObject,
    RawSerializedObject,
)
from ray import serialization_addons

logger = logging.getLogger(__name__)
@@ -155,6 +156,7 @@ class SerializationContext:
        # Because objects have a default __reduce__ method, we only need to
        # treat ObjectRef specially.
        self._register_cloudpickle_reducer(ray.ObjectRef, object_ref_reducer)
        serialization_addons.apply(self)

    def _register_cloudpickle_reducer(self, cls, reducer):
        pickle.CloudPickler.dispatch[cls] = reducer
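For readers unfamiliar with pickle reducers, the dispatch-table idea used above can be shown with the standard library's ``copyreg`` alone; the ``Point`` class and helper functions below are hypothetical illustrations, not Ray code:

import copyreg
import pickle


class Point:
    def __init__(self, x, y):
        self.x, self.y = x, y


def _rebuild_point(x, y):
    return Point(x, y)


def point_reducer(p):
    # A reducer returns (callable, args); unpickling calls callable(*args).
    return _rebuild_point, (p.x, p.y)


# Register the reducer in pickle's dispatch table, analogous to
# pickle.CloudPickler.dispatch[cls] = reducer above.
copyreg.pickle(Point, point_reducer)

restored = pickle.loads(pickle.dumps(Point(1, 2)))
assert (restored.x, restored.y) == (1, 2)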
python/ray/serialization_addons.py (new file, 76 lines)
@@ -0,0 +1,76 @@
"""
This module is intended for implementing internal serializers for some
site packages.
"""

import os
import warnings

_TORCH_WARNING_FILTER_ACTIVATE = True


class _TorchTensorReducingHelper:
    def __init__(self, tensor):
        self.tensor = tensor

    @classmethod
    def rebuild_tensor(cls, rebuild_func, device, ndarray, params):
        import torch
        global _TORCH_WARNING_FILTER_ACTIVATE
        # Filtering warning messages is a bottleneck when deserializing
        # torch tensors. Since the warning is only raised once, we only
        # pay for the filter the first time.
        if _TORCH_WARNING_FILTER_ACTIVATE:
            with warnings.catch_warnings():
                warnings.filterwarnings(
                    "ignore",
                    category=UserWarning,
                    message="The given NumPy array is not writeable")
                _tensor = torch.from_numpy(ndarray)
            _TORCH_WARNING_FILTER_ACTIVATE = False
        else:
            _tensor = torch.from_numpy(ndarray)
        if device != torch.device("cpu"):
            _tensor = _tensor.to(device)
        tensor = rebuild_func(_tensor.storage(), *params)
        return cls(tensor)

    @classmethod
    def rebuild_sparse_tensor(cls, rebuild_func, content):
        tensor = rebuild_func(*content)
        return cls(tensor)

    def __reduce_ex__(self, protocol):
        _rebuild_func, content = self.tensor.__reduce_ex__(protocol)
        if self.tensor.is_sparse:
            # Torch will help us reduce the sparse tensor into
            # several contiguous tensors.
            return self.rebuild_sparse_tensor, (_rebuild_func, content)
        # By only replacing the storage with a numpy array, we can reuse
        # zero-copy serialization while keeping all other params of the
        # torch tensor.
        return self.rebuild_tensor, (_rebuild_func, self.tensor.device,
                                     self.tensor.detach().cpu().numpy(),
                                     content[1:])


def _unwrap_tensor(s):
    return s.tensor


def torch_tensor_reducer(tensor):
    return _unwrap_tensor, (_TorchTensorReducingHelper(tensor), )


def register_pytorch_serializer(serialization_context):
    try:
        import torch
        serialization_context._register_cloudpickle_reducer(
            torch.Tensor, torch_tensor_reducer)
    except ImportError:
        pass


def apply(serialization_context):
    if os.environ.get("RAY_DISABLE_PYTORCH_SERIALIZER") != "1":
        register_pytorch_serializer(serialization_context)
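To see the reduction in isolation, here is a minimal round-trip sketch through plain ``pickle``; no Ray object store is involved, so the numpy payload is copied rather than shared, and it assumes a torch version whose ``Tensor.__reduce_ex__`` returns the ``(storage, *params)`` layout the helper expects:

import pickle

import torch

from ray.serialization_addons import torch_tensor_reducer

t = torch.rand(2, 3)
# The reducer wraps the tensor in the helper; pickling the wrapper invokes
# its __reduce_ex__, which swaps the torch storage for a numpy array.
unwrap, (helper, ) = torch_tensor_reducer(t)
restored = unwrap(pickle.loads(pickle.dumps(helper)))
assert torch.equal(restored, t)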
@@ -543,7 +543,7 @@ def test_reducer_override_no_reference_cycle(ray_start_shared_local_modes):
    assert new_obj() is None


-def test_buffer_alignment():
+def test_buffer_alignment(ray_start_shared_local_modes):
    # Deserialized large numpy arrays should be 64-byte aligned.
    x = np.random.normal(size=(10, 20, 30))
    y = ray.get(ray.put(x))
@@ -568,6 +568,30 @@ def test_buffer_alignment():
    assert y.ctypes.data % 8 == 0


def test_pytorch_tensor_zerocopy_serialization(ray_start_shared_local_modes):
    import torch
    # test dense tensor
    tensor = torch.rand(32, 3, 64, 64)
    ref = ray.put(tensor)
    tensor_1, tensor_2 = ray.get([ref] * 2)
    assert tensor_1.data_ptr() == tensor_2.data_ptr()

    # test sparse tensor
    i = torch.arange(0, 1024 * 1024, 4).view(1, -1)
    v = torch.rand(1024 * 1024 // 4)
    k = torch.sparse_coo_tensor(i, v, size=(1024 * 1024, ))
    ref = ray.put(k)
    tensor_1, tensor_2 = ray.get([ref] * 2)
    assert tensor_1._indices().data_ptr() == tensor_2._indices().data_ptr()
    assert tensor_1._values().data_ptr() == tensor_2._values().data_ptr()

    # test attributes
    tensor = torch.rand(4).requires_grad_(True)
    ref = ray.put(tensor)
    tensor = ray.get(ref)
    assert tensor.requires_grad


if __name__ == "__main__":
    import pytest
    sys.exit(pytest.main(["-v", __file__]))
@@ -6,6 +6,7 @@ import sys
if __name__ == "__main__":
    # Do not import torch for testing purposes.
    os.environ["RLLIB_TEST_NO_TORCH_IMPORT"] = "1"
    os.environ["RAY_DISABLE_PYTORCH_SERIALIZER"] = "1"

    from ray.rllib.agents.a3c import A2CTrainer
    assert "torch" not in sys.modules, \
@@ -23,5 +24,6 @@ if __name__ == "__main__":

    # Clean up.
    del os.environ["RLLIB_TEST_NO_TORCH_IMPORT"]
    del os.environ["RAY_DISABLE_PYTORCH_SERIALIZER"]

    print("ok")