Introduce flag to use pickle for serialization (#5805)

This commit is contained in:
Philipp Moritz 2019-10-18 22:29:36 -07:00 committed by GitHub
parent 29eee7f970
commit d23696de17
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 85 additions and 22 deletions

View file

@ -1,12 +1,7 @@
from __future__ import absolute_import
import sys
# TODO(suquark): This is a temporary flag for
# the new serialization implementation.
# Remove it when the old one is deprecated.
USE_NEW_SERIALIZER = False
if USE_NEW_SERIALIZER and sys.version_info[:2] >= (3, 8):
if sys.version_info[:2] >= (3, 8):
from ray.cloudpickle.cloudpickle_fast import *
FAST_CLOUDPICKLE_USED = True
else:

View file

@ -264,6 +264,10 @@ class Node(object):
def load_code_from_local(self):
return self._ray_params.load_code_from_local
@property
def use_pickle(self):
    """Get the use_pickle setting stored on this node's Ray parameters."""
    params = self._ray_params
    return params.use_pickle
@property
def object_id_seed(self):
"""Get the seed for deterministic generation of object IDs"""
@ -520,7 +524,7 @@ class Node(object):
include_java=self._ray_params.include_java,
java_worker_options=self._ray_params.java_worker_options,
load_code_from_local=self._ray_params.load_code_from_local,
)
use_pickle=self._ray_params.use_pickle)
assert ray_constants.PROCESS_TYPE_RAYLET not in self.all_processes
self.all_processes[ray_constants.PROCESS_TYPE_RAYLET] = [process_info]

View file

@ -74,6 +74,7 @@ class RayParams(object):
Java worker.
java_worker_options (str): The command options for Java worker.
load_code_from_local: Whether load code from local file or from GCS.
use_pickle: Whether data objects should be serialized with cloudpickle.
_internal_config (str): JSON configuration for overriding
RayConfig defaults. For testing purposes ONLY.
"""
@ -113,6 +114,7 @@ class RayParams(object):
include_java=False,
java_worker_options=None,
load_code_from_local=False,
use_pickle=False,
_internal_config=None):
self.object_id_seed = object_id_seed
self.redis_address = redis_address
@ -146,6 +148,7 @@ class RayParams(object):
self.include_java = include_java
self.java_worker_options = java_worker_options
self.load_code_from_local = load_code_from_local
self.use_pickle = use_pickle
self._internal_config = _internal_config
self._check_usage()

View file

@ -225,6 +225,11 @@ def cli(logging_level, logging_format):
is_flag=True,
default=False,
help="Specify whether load code from local file or GCS serialization.")
@click.option(
"--use-pickle",
is_flag=True,
default=False,
help="Use pickle for serialization.")
def start(node_ip_address, redis_address, address, redis_port,
num_redis_shards, redis_max_clients, redis_password,
redis_shard_ports, object_manager_port, node_manager_port, memory,
@ -232,7 +237,8 @@ def start(node_ip_address, redis_address, address, redis_port,
head, include_webui, block, plasma_directory, huge_pages,
autoscaling_config, no_redirect_worker_output, no_redirect_output,
plasma_store_socket_name, raylet_socket_name, temp_dir, include_java,
java_worker_options, load_code_from_local, internal_config):
java_worker_options, load_code_from_local, use_pickle,
internal_config):
# Convert hostnames to numerical IP address.
if node_ip_address is not None:
node_ip_address = services.address_to_ip(node_ip_address)
@ -273,6 +279,7 @@ def start(node_ip_address, redis_address, address, redis_port,
include_webui=include_webui,
java_worker_options=java_worker_options,
load_code_from_local=load_code_from_local,
use_pickle=use_pickle,
_internal_config=internal_config)
if head:

View file

@ -1060,7 +1060,8 @@ def start_raylet(redis_address,
config=None,
include_java=False,
java_worker_options=None,
load_code_from_local=False):
load_code_from_local=False,
use_pickle=False):
"""Start a raylet, which is a combined local scheduler and object manager.
Args:
@ -1092,6 +1093,7 @@ def start_raylet(redis_address,
include_java (bool): If True, the raylet backend can also support
Java worker.
java_worker_options (str): The command options for Java worker.
use_pickle (bool): If True, use cloudpickle for serialization.
Returns:
ProcessInfo for the process that was started.
"""
@ -1155,6 +1157,8 @@ def start_raylet(redis_address,
if load_code_from_local:
start_worker_command += " --load-code-from-local "
if use_pickle:
start_worker_command += " --use-pickle "
command = [
RAYLET_EXECUTABLE,

View file

@ -130,7 +130,7 @@ def test_fair_queueing(shutdown_only):
assert len(ready) == 1000, len(ready)
def test_complex_serialization(ray_start_regular):
def complex_serialization(use_pickle):
def assert_equal(obj1, obj2):
module_numpy = (type(obj1).__module__ == np.__name__
or type(obj2).__module__ == np.__name__)
@ -340,6 +340,15 @@ def test_complex_serialization(ray_start_regular):
assert ray.get(ray.put(s)).readline() == line
def test_complex_serialization(ray_start_regular):
    """Run the shared serialization checks with use_pickle turned off.

    The ray_start_regular fixture is requested but its value is not used
    here (presumably it starts Ray for the test -- confirm in conftest).
    """
    pickle_enabled = False
    complex_serialization(use_pickle=pickle_enabled)
def test_complex_serialization_with_pickle(shutdown_only):
    """Run the shared serialization checks with use_pickle turned on.

    Ray is initialized explicitly here so that use_pickle=True can be
    passed to ray.init before the checks run.
    """
    pickle_enabled = True
    ray.init(use_pickle=pickle_enabled)
    complex_serialization(use_pickle=pickle_enabled)
def test_nested_functions(ray_start_regular):
# Make sure that remote functions can use other values that are defined
# after the remote function but before the first function invocation.
@ -410,7 +419,7 @@ def test_ray_recursive_objects(ray_start_regular):
# Create a list of recursive objects.
recursive_objects = [lst, a1, a2, a3, d1]
if ray.worker.USE_NEW_SERIALIZER:
if ray.worker.global_worker.use_pickle:
# Serialize the recursive objects.
for obj in recursive_objects:
ray.put(obj)

View file

@ -551,3 +551,23 @@ print("success")
# Make sure we can still talk with the raylet.
ray.get(f.remote())
@pytest.mark.parametrize(
    "call_ray_start", ["ray start --head --num-cpus=1 --use-pickle"],
    indirect=True)
def test_use_pickle(call_ray_start):
    """Check that use_pickle is visible on both the driver and a worker.

    The value supplied by the call_ray_start fixture is used as the
    cluster address (presumably the fixture runs the parametrized
    ``ray start`` command -- confirm in conftest).
    """
    ray.init(address=call_ray_start, use_pickle=True)
    # Driver side: the flag passed to ray.init must be reflected here.
    assert ray.worker.global_worker.use_pickle

    task_argument = (2, "hello")

    @ray.remote
    def f(x):
        # Worker side: the argument round-trips and the flag is set too.
        assert x == (2, "hello")
        assert ray.worker.global_worker.use_pickle
        return (3, "world")

    task_result = ray.get(f.remote(task_argument))
    assert task_result == (3, "world")

View file

@ -26,7 +26,6 @@ import random
import pyarrow
import pyarrow.plasma as plasma
import ray.cloudpickle as pickle
from ray.cloudpickle import USE_NEW_SERIALIZER
import ray.experimental.signal as ray_signal
import ray.experimental.no_return
import ray.gcs_utils
@ -176,6 +175,11 @@ class Worker(object):
self.check_connected()
return self.node.load_code_from_local
@property
def use_pickle(self):
    """Get the use_pickle setting of the node this worker belongs to.

    check_connected() is called first, before the node is consulted.
    """
    self.check_connected()
    node = self.node
    return node.use_pickle
@property
def task_context(self):
"""A thread-local that contains the following attributes.
@ -391,7 +395,7 @@ class Worker(object):
for attempt in reversed(
range(ray_constants.DEFAULT_PUT_OBJECT_RETRIES)):
try:
if USE_NEW_SERIALIZER:
if self.use_pickle:
self.store_with_plasma(object_id, value)
else:
self._try_store_and_register(object_id, value)
@ -433,8 +437,13 @@ class Worker(object):
value, object_id, memcopy_threads=self.memcopy_threads)
else:
writer = Pickle5Writer()
inband = pickle.dumps(
value, protocol=5, buffer_callback=writer.buffer_callback)
if ray.cloudpickle.FAST_CLOUDPICKLE_USED:
inband = pickle.dumps(
value,
protocol=5,
buffer_callback=writer.buffer_callback)
else:
inband = pickle.dumps(value)
self.core_worker.put_pickle5_buffers(object_id, inband, writer,
self.memcopy_threads)
except pyarrow.plasma.PlasmaObjectExists:
@ -512,10 +521,12 @@ class Worker(object):
def _deserialize_object_from_arrow(self, data, metadata, object_id,
serialization_context):
if metadata:
if (USE_NEW_SERIALIZER
and metadata == ray_constants.PICKLE5_BUFFER_METADATA):
if metadata == ray_constants.PICKLE5_BUFFER_METADATA:
in_band, buffers = unpack_pickle5_buffers(data)
return pickle.loads(in_band, buffers=buffers)
if len(buffers) > 0:
return pickle.loads(in_band, buffers=buffers)
else:
return pickle.loads(in_band)
# Check if the object should be returned as raw bytes.
if metadata == ray_constants.RAW_BUFFER_METADATA:
if data is None:
@ -1085,7 +1096,7 @@ def _initialize_serialization(job_id, worker=global_worker):
worker.serialization_context_map[job_id] = serialization_context
if not USE_NEW_SERIALIZER:
if not worker.use_pickle:
for error_cls in RAY_EXCEPTION_TYPES:
register_custom_serializer(
error_cls,
@ -1158,6 +1169,7 @@ def init(address=None,
raylet_socket_name=None,
temp_dir=None,
load_code_from_local=False,
use_pickle=False,
_internal_config=None):
"""Connect to an existing Ray cluster or start one and connect to it.
@ -1242,6 +1254,7 @@ def init(address=None,
directory for the Ray process.
load_code_from_local: Whether code should be loaded from a local module
or from the GCS.
use_pickle: Whether data objects should be serialized with cloudpickle.
_internal_config (str): JSON configuration for overriding
RayConfig defaults. For testing purposes ONLY.
@ -1316,6 +1329,7 @@ def init(address=None,
raylet_socket_name=raylet_socket_name,
temp_dir=temp_dir,
load_code_from_local=load_code_from_local,
use_pickle=use_pickle,
_internal_config=_internal_config,
)
# Start the Ray processes. We set shutdown_at_exit=False because we
@ -1372,7 +1386,8 @@ def init(address=None,
redis_password=redis_password,
object_id_seed=object_id_seed,
temp_dir=temp_dir,
load_code_from_local=load_code_from_local)
load_code_from_local=load_code_from_local,
use_pickle=use_pickle)
_global_node = ray.node.Node(
ray_params, head=False, shutdown_at_exit=False, connect_only=True)
@ -2045,7 +2060,7 @@ def register_custom_serializer(cls,
assert isinstance(job_id, JobID)
def register_class_for_serialization(worker_info):
if USE_NEW_SERIALIZER:
if worker_info["worker"].use_pickle:
if pickle.FAST_CLOUDPICKLE_USED:
# construct a reducer
pickle.CloudPickler.dispatch[

View file

@ -62,6 +62,11 @@ parser.add_argument(
default=False,
action="store_true",
help="True if code is loaded from local files, as opposed to the GCS.")
# Boolean switch: absent -> False, present -> True.
parser.add_argument(
    "--use-pickle",
    action="store_true",
    default=False,
    help="True if cloudpickle should be used for serialization.")
if __name__ == "__main__":
args = parser.parse_args()
@ -75,7 +80,8 @@ if __name__ == "__main__":
plasma_store_socket_name=args.object_store_name,
raylet_socket_name=args.raylet_name,
temp_dir=args.temp_dir,
load_code_from_local=args.load_code_from_local)
load_code_from_local=args.load_code_from_local,
use_pickle=args.use_pickle)
node = ray.node.Node(
ray_params, head=False, shutdown_at_exit=False, connect_only=True)