Re-Revert "[Core] zero-copy serializer for pytorch (#12344)" (#12478)

* [Core] zero-copy serializer for pytorch (#12344)

* zero-copy serializer for pytorch

* address possible bottleneck

* add tests & device support

(cherry picked from commit 0a505ca83d)

* add environmental variables

* update doc
Siyuan (Ryans) Zhuang 2020-11-30 11:43:03 -08:00 committed by GitHub
parent bb03e2499b
commit 3f22448834
5 changed files with 111 additions and 1 deletion

doc/source/serialization.rst

@@ -29,6 +29,12 @@ The numpy array is stored as a read-only object, and all Ray workers on the same
.. tip:: You can often avoid serialization issues by using only native types (e.g., numpy arrays or lists/dicts of numpy arrays and other primitive types), or by using Actors to hold objects that cannot be serialized.

PyTorch Tensors
---------------

Ray supports zero-copy serialization for PyTorch tensors by default. However, this forces the ``torch`` module to be imported. To disable this feature, set the environment variable ``RAY_DISABLE_PYTORCH_SERIALIZER=1``.

Serialization notes
-------------------
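
Illustration (not part of the commit): a minimal sketch of the documented behavior, assuming a Ray build that includes this patch and an installed PyTorch.

import ray
import torch

# To opt out, set RAY_DISABLE_PYTORCH_SERIALIZER=1 in the environment
# before Ray builds its serialization context (i.e., before ray.init()).
ray.init()

tensor = torch.rand(8, 8)
ref = ray.put(tensor)
a, b = ray.get([ref] * 2)

# With the zero-copy serializer enabled, both results are backed by the
# same buffer in the object store.
print(a.data_ptr() == b.data_ptr())  # True unless the serializer is disabled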

python/ray/serialization.py

@@ -26,6 +26,7 @@ from ray._raylet import (
    MessagePackSerializedObject,
    RawSerializedObject,
)
from ray import serialization_addons

logger = logging.getLogger(__name__)
@@ -155,6 +156,7 @@ class SerializationContext:
        # Because objects have a default __reduce__ method, we only need
        # to treat ObjectRef specially.
        self._register_cloudpickle_reducer(ray.ObjectRef, object_ref_reducer)
        serialization_addons.apply(self)

    def _register_cloudpickle_reducer(self, cls, reducer):
        pickle.CloudPickler.dispatch[cls] = reducer
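
Aside (not part of the commit): a reducer registered through this hook follows the standard pickle reduce protocol, returning a rebuild callable plus its arguments. A hypothetical sketch, where Point and point_reducer are illustrative names only:

# Point and point_reducer are hypothetical; they are not part of Ray.
class Point:
    def __init__(self, x, y):
        self.x, self.y = x, y


def _rebuild_point(x, y):
    return Point(x, y)


def point_reducer(p):
    # A reducer returns (callable, args); deserialization calls
    # _rebuild_point(*args) to reconstruct the object.
    return _rebuild_point, (p.x, p.y)


# Inside the SerializationContext above, this would be wired up as:
#     self._register_cloudpickle_reducer(Point, point_reducer)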

python/ray/serialization_addons.py

@@ -0,0 +1,76 @@
"""
This module is intended for implementing internal serializers for some
site packages.
"""
import os
import warnings
_TORCH_WARNING_FILTER_ACTIVATE = True
class _TorchTensorReducingHelper:
def __init__(self, tensor):
self.tensor = tensor
@classmethod
def rebuild_tensor(cls, rebuild_func, device, ndarray, params):
import torch
global _TORCH_WARNING_FILTER_ACTIVATE
# filtering warning messages would be the bottleneck for
# deserializing torch tensors. Since the warning only prompts once,
# we would only deal with it for the first time.
if _TORCH_WARNING_FILTER_ACTIVATE:
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
category=UserWarning,
message="The given NumPy array is not writeable")
_tensor = torch.from_numpy(ndarray)
_TORCH_WARNING_FILTER_ACTIVATE = False
else:
_tensor = torch.from_numpy(ndarray)
if device != torch.device("cpu"):
_tensor = _tensor.to(device)
tensor = rebuild_func(_tensor.storage(), *params)
return cls(tensor)
@classmethod
def rebuild_sparse_tensor(cls, rebuild_func, content):
tensor = rebuild_func(*content)
return cls(tensor)
def __reduce_ex__(self, protocol):
_rebuild_func, content = self.tensor.__reduce_ex__(protocol)
if self.tensor.is_sparse:
# Torch will help us reduce the sparse tensor into
# several continuous tensors.
return self.rebuild_sparse_tensor, (_rebuild_func, content)
# By only replacing the storage with a numpy array, we can reuse
# zero-copy serialization while keeping all other params of the
# torch tensor.
return self.rebuild_tensor, (_rebuild_func, self.tensor.device,
self.tensor.detach().cpu().numpy(),
content[1:])
def _unwrap_tensor(s):
return s.tensor
def torch_tensor_reducer(tensor):
return _unwrap_tensor, (_TorchTensorReducingHelper(tensor), )
def register_pytorch_serializer(serialization_context):
try:
import torch
serialization_context._register_cloudpickle_reducer(
torch.Tensor, torch_tensor_reducer)
except ImportError:
pass
def apply(serialization_context):
if os.environ.get("RAY_DISABLE_PYTORCH_SERIALIZER") != "1":
register_pytorch_serializer(serialization_context)
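
Aside (not part of the commit): the reducer can be exercised directly to see the round trip; a rough sketch following the code above, assuming PyTorch is installed (protocol 2 is an arbitrary choice here):

import torch

from ray import serialization_addons

t = torch.rand(3, 3)

# torch_tensor_reducer wraps the tensor in a helper whose __reduce_ex__
# swaps the storage for a numpy view, so pickling moves raw bytes
# instead of copying element by element.
unwrap, (helper, ) = serialization_addons.torch_tensor_reducer(t)
rebuild, args = helper.__reduce_ex__(2)

restored = unwrap(rebuild(*args))
assert torch.equal(restored, t)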

python/ray/tests/test_serialization.py

@@ -543,7 +543,7 @@ def test_reducer_override_no_reference_cycle(ray_start_shared_local_modes):
    assert new_obj() is None


-def test_buffer_alignment():
+def test_buffer_alignment(ray_start_shared_local_modes):
    # Deserialized large numpy arrays should be 64-byte aligned.
    x = np.random.normal(size=(10, 20, 30))
    y = ray.get(ray.put(x))
@@ -568,6 +568,30 @@ def test_buffer_alignment():
    assert y.ctypes.data % 8 == 0


def test_pytorch_tensor_zerocopy_serialization(ray_start_shared_local_modes):
    import torch

    # test dense tensor
    tensor = torch.rand(32, 3, 64, 64)
    ref = ray.put(tensor)
    tensor_1, tensor_2 = ray.get([ref] * 2)
    assert tensor_1.data_ptr() == tensor_2.data_ptr()

    # test sparse tensor
    i = torch.arange(0, 1024 * 1024, 4).view(1, -1)
    v = torch.rand(1024 * 1024 // 4)
    k = torch.sparse_coo_tensor(i, v, size=(1024 * 1024, ))
    ref = ray.put(k)
    tensor_1, tensor_2 = ray.get([ref] * 2)
    assert tensor_1._indices().data_ptr() == tensor_2._indices().data_ptr()
    assert tensor_1._values().data_ptr() == tensor_2._values().data_ptr()

    # test attributes
    tensor = torch.rand(4).requires_grad_(True)
    ref = ray.put(tensor)
    tensor = ray.get(ref)
    assert tensor.requires_grad


if __name__ == "__main__":
    import pytest
    sys.exit(pytest.main(["-v", __file__]))

rllib/tests/test_dependency.py

@@ -6,6 +6,7 @@ import sys

if __name__ == "__main__":
    # Do not import torch for testing purposes.
    os.environ["RLLIB_TEST_NO_TORCH_IMPORT"] = "1"
    os.environ["RAY_DISABLE_PYTORCH_SERIALIZER"] = "1"

    from ray.rllib.agents.a3c import A2CTrainer

    assert "torch" not in sys.modules, \
@@ -23,5 +24,6 @@ if __name__ == "__main__":
    # Clean up.
    del os.environ["RLLIB_TEST_NO_TORCH_IMPORT"]
    del os.environ["RAY_DISABLE_PYTORCH_SERIALIZER"]

    print("ok")