mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
Introduce flag to use pickle for serialization (#5805)
This commit is contained in:
parent
29eee7f970
commit
d23696de17
9 changed files with 85 additions and 22 deletions
|
@ -1,12 +1,7 @@
|
|||
from __future__ import absolute_import
|
||||
import sys
|
||||
|
||||
# TODO(suquark): This is a temporary flag for
|
||||
# the new serialization implementation.
|
||||
# Remove it when the old one is deprecated.
|
||||
USE_NEW_SERIALIZER = False
|
||||
|
||||
if USE_NEW_SERIALIZER and sys.version_info[:2] >= (3, 8):
|
||||
if sys.version_info[:2] >= (3, 8):
|
||||
from ray.cloudpickle.cloudpickle_fast import *
|
||||
FAST_CLOUDPICKLE_USED = True
|
||||
else:
|
||||
|
|
|
@ -264,6 +264,10 @@ class Node(object):
|
|||
def load_code_from_local(self):
|
||||
return self._ray_params.load_code_from_local
|
||||
|
||||
@property
def use_pickle(self):
    """Return the ``use_pickle`` setting stored on this node's RayParams."""
    params = self._ray_params
    return params.use_pickle
|
||||
|
||||
@property
|
||||
def object_id_seed(self):
|
||||
"""Get the seed for deterministic generation of object IDs"""
|
||||
|
@ -520,7 +524,7 @@ class Node(object):
|
|||
include_java=self._ray_params.include_java,
|
||||
java_worker_options=self._ray_params.java_worker_options,
|
||||
load_code_from_local=self._ray_params.load_code_from_local,
|
||||
)
|
||||
use_pickle=self._ray_params.use_pickle)
|
||||
assert ray_constants.PROCESS_TYPE_RAYLET not in self.all_processes
|
||||
self.all_processes[ray_constants.PROCESS_TYPE_RAYLET] = [process_info]
|
||||
|
||||
|
|
|
@ -74,6 +74,7 @@ class RayParams(object):
|
|||
Java worker.
|
||||
java_worker_options (str): The command options for Java worker.
|
||||
load_code_from_local: Whether to load code from a local file or from the GCS.
|
||||
use_pickle: Whether data objects should be serialized with cloudpickle.
|
||||
_internal_config (str): JSON configuration for overriding
|
||||
RayConfig defaults. For testing purposes ONLY.
|
||||
"""
|
||||
|
@ -113,6 +114,7 @@ class RayParams(object):
|
|||
include_java=False,
|
||||
java_worker_options=None,
|
||||
load_code_from_local=False,
|
||||
use_pickle=False,
|
||||
_internal_config=None):
|
||||
self.object_id_seed = object_id_seed
|
||||
self.redis_address = redis_address
|
||||
|
@ -146,6 +148,7 @@ class RayParams(object):
|
|||
self.include_java = include_java
|
||||
self.java_worker_options = java_worker_options
|
||||
self.load_code_from_local = load_code_from_local
|
||||
self.use_pickle = use_pickle
|
||||
self._internal_config = _internal_config
|
||||
self._check_usage()
|
||||
|
||||
|
|
|
@ -225,6 +225,11 @@ def cli(logging_level, logging_format):
|
|||
is_flag=True,
|
||||
default=False,
|
||||
help="Specify whether load code from local file or GCS serialization.")
|
||||
@click.option(
|
||||
"--use-pickle",
|
||||
is_flag=True,
|
||||
default=False,
|
||||
help="Use pickle for serialization.")
|
||||
def start(node_ip_address, redis_address, address, redis_port,
|
||||
num_redis_shards, redis_max_clients, redis_password,
|
||||
redis_shard_ports, object_manager_port, node_manager_port, memory,
|
||||
|
@ -232,7 +237,8 @@ def start(node_ip_address, redis_address, address, redis_port,
|
|||
head, include_webui, block, plasma_directory, huge_pages,
|
||||
autoscaling_config, no_redirect_worker_output, no_redirect_output,
|
||||
plasma_store_socket_name, raylet_socket_name, temp_dir, include_java,
|
||||
java_worker_options, load_code_from_local, internal_config):
|
||||
java_worker_options, load_code_from_local, use_pickle,
|
||||
internal_config):
|
||||
# Convert hostnames to numerical IP address.
|
||||
if node_ip_address is not None:
|
||||
node_ip_address = services.address_to_ip(node_ip_address)
|
||||
|
@ -273,6 +279,7 @@ def start(node_ip_address, redis_address, address, redis_port,
|
|||
include_webui=include_webui,
|
||||
java_worker_options=java_worker_options,
|
||||
load_code_from_local=load_code_from_local,
|
||||
use_pickle=use_pickle,
|
||||
_internal_config=internal_config)
|
||||
|
||||
if head:
|
||||
|
|
|
@ -1060,7 +1060,8 @@ def start_raylet(redis_address,
|
|||
config=None,
|
||||
include_java=False,
|
||||
java_worker_options=None,
|
||||
load_code_from_local=False):
|
||||
load_code_from_local=False,
|
||||
use_pickle=False):
|
||||
"""Start a raylet, which is a combined local scheduler and object manager.
|
||||
|
||||
Args:
|
||||
|
@ -1092,6 +1093,7 @@ def start_raylet(redis_address,
|
|||
include_java (bool): If True, the raylet backend can also support
|
||||
Java worker.
|
||||
java_worker_options (str): The command options for Java worker.
|
||||
use_pickle (bool): If True, use cloudpickle for serialization.
|
||||
Returns:
|
||||
ProcessInfo for the process that was started.
|
||||
"""
|
||||
|
@ -1155,6 +1157,8 @@ def start_raylet(redis_address,
|
|||
|
||||
if load_code_from_local:
|
||||
start_worker_command += " --load-code-from-local "
|
||||
if use_pickle:
|
||||
start_worker_command += " --use-pickle "
|
||||
|
||||
command = [
|
||||
RAYLET_EXECUTABLE,
|
||||
|
|
|
@ -130,7 +130,7 @@ def test_fair_queueing(shutdown_only):
|
|||
assert len(ready) == 1000, len(ready)
|
||||
|
||||
|
||||
def test_complex_serialization(ray_start_regular):
|
||||
def complex_serialization(use_pickle):
|
||||
def assert_equal(obj1, obj2):
|
||||
module_numpy = (type(obj1).__module__ == np.__name__
|
||||
or type(obj2).__module__ == np.__name__)
|
||||
|
@ -340,6 +340,15 @@ def test_complex_serialization(ray_start_regular):
|
|||
assert ray.get(ray.put(s)).readline() == line
|
||||
|
||||
|
||||
def test_complex_serialization(ray_start_regular):
    """Run the shared serialization checks with ``use_pickle`` disabled."""
    pickle_enabled = False
    complex_serialization(use_pickle=pickle_enabled)
|
||||
|
||||
|
||||
def test_complex_serialization_with_pickle(shutdown_only):
    """Run the shared serialization checks with ``use_pickle`` enabled.

    Ray is initialized here (rather than by a fixture) so the flag can be
    passed to ``ray.init``.
    """
    pickle_enabled = True
    ray.init(use_pickle=pickle_enabled)
    complex_serialization(use_pickle=pickle_enabled)
|
||||
|
||||
|
||||
def test_nested_functions(ray_start_regular):
|
||||
# Make sure that remote functions can use other values that are defined
|
||||
# after the remote function but before the first function invocation.
|
||||
|
@ -410,7 +419,7 @@ def test_ray_recursive_objects(ray_start_regular):
|
|||
# Create a list of recursive objects.
|
||||
recursive_objects = [lst, a1, a2, a3, d1]
|
||||
|
||||
if ray.worker.USE_NEW_SERIALIZER:
|
||||
if ray.worker.global_worker.use_pickle:
|
||||
# Serialize the recursive objects.
|
||||
for obj in recursive_objects:
|
||||
ray.put(obj)
|
||||
|
|
|
@ -551,3 +551,23 @@ print("success")
|
|||
|
||||
# Make sure we can still talk with the raylet.
|
||||
ray.get(f.remote())
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "call_ray_start", ["ray start --head --num-cpus=1 --use-pickle"],
    indirect=True)
def test_use_pickle(call_ray_start):
    """Check that ``use_pickle`` is visible on both the driver and a worker.

    The ``call_ray_start`` fixture is parametrized with a ``ray start``
    command that includes ``--use-pickle``; its value is used as the address
    to connect to.
    """
    # Connect the driver with the same flag.
    ray.init(address=call_ray_start, use_pickle=True)

    driver = ray.worker.global_worker
    assert driver.use_pickle

    @ray.remote
    def f(x):
        # The argument must round-trip unchanged, and the process running
        # this task must also report the flag as set.
        assert x == (2, "hello")
        assert ray.worker.global_worker.use_pickle
        return (3, "world")

    task_arg = (2, "hello")
    result = ray.get(f.remote(task_arg))
    assert result == (3, "world")
|
||||
|
|
|
@ -26,7 +26,6 @@ import random
|
|||
import pyarrow
|
||||
import pyarrow.plasma as plasma
|
||||
import ray.cloudpickle as pickle
|
||||
from ray.cloudpickle import USE_NEW_SERIALIZER
|
||||
import ray.experimental.signal as ray_signal
|
||||
import ray.experimental.no_return
|
||||
import ray.gcs_utils
|
||||
|
@ -176,6 +175,11 @@ class Worker(object):
|
|||
self.check_connected()
|
||||
return self.node.load_code_from_local
|
||||
|
||||
@property
def use_pickle(self):
    """Return the node's ``use_pickle`` setting.

    ``check_connected()`` is called first, before the node is read.
    """
    self.check_connected()
    node = self.node
    return node.use_pickle
|
||||
|
||||
@property
|
||||
def task_context(self):
|
||||
"""A thread-local that contains the following attributes.
|
||||
|
@ -391,7 +395,7 @@ class Worker(object):
|
|||
for attempt in reversed(
|
||||
range(ray_constants.DEFAULT_PUT_OBJECT_RETRIES)):
|
||||
try:
|
||||
if USE_NEW_SERIALIZER:
|
||||
if self.use_pickle:
|
||||
self.store_with_plasma(object_id, value)
|
||||
else:
|
||||
self._try_store_and_register(object_id, value)
|
||||
|
@ -433,8 +437,13 @@ class Worker(object):
|
|||
value, object_id, memcopy_threads=self.memcopy_threads)
|
||||
else:
|
||||
writer = Pickle5Writer()
|
||||
inband = pickle.dumps(
|
||||
value, protocol=5, buffer_callback=writer.buffer_callback)
|
||||
if ray.cloudpickle.FAST_CLOUDPICKLE_USED:
|
||||
inband = pickle.dumps(
|
||||
value,
|
||||
protocol=5,
|
||||
buffer_callback=writer.buffer_callback)
|
||||
else:
|
||||
inband = pickle.dumps(value)
|
||||
self.core_worker.put_pickle5_buffers(object_id, inband, writer,
|
||||
self.memcopy_threads)
|
||||
except pyarrow.plasma.PlasmaObjectExists:
|
||||
|
@ -512,10 +521,12 @@ class Worker(object):
|
|||
def _deserialize_object_from_arrow(self, data, metadata, object_id,
|
||||
serialization_context):
|
||||
if metadata:
|
||||
if (USE_NEW_SERIALIZER
|
||||
and metadata == ray_constants.PICKLE5_BUFFER_METADATA):
|
||||
if metadata == ray_constants.PICKLE5_BUFFER_METADATA:
|
||||
in_band, buffers = unpack_pickle5_buffers(data)
|
||||
return pickle.loads(in_band, buffers=buffers)
|
||||
if len(buffers) > 0:
|
||||
return pickle.loads(in_band, buffers=buffers)
|
||||
else:
|
||||
return pickle.loads(in_band)
|
||||
# Check if the object should be returned as raw bytes.
|
||||
if metadata == ray_constants.RAW_BUFFER_METADATA:
|
||||
if data is None:
|
||||
|
@ -1085,7 +1096,7 @@ def _initialize_serialization(job_id, worker=global_worker):
|
|||
|
||||
worker.serialization_context_map[job_id] = serialization_context
|
||||
|
||||
if not USE_NEW_SERIALIZER:
|
||||
if not worker.use_pickle:
|
||||
for error_cls in RAY_EXCEPTION_TYPES:
|
||||
register_custom_serializer(
|
||||
error_cls,
|
||||
|
@ -1158,6 +1169,7 @@ def init(address=None,
|
|||
raylet_socket_name=None,
|
||||
temp_dir=None,
|
||||
load_code_from_local=False,
|
||||
use_pickle=False,
|
||||
_internal_config=None):
|
||||
"""Connect to an existing Ray cluster or start one and connect to it.
|
||||
|
||||
|
@ -1242,6 +1254,7 @@ def init(address=None,
|
|||
directory for the Ray process.
|
||||
load_code_from_local: Whether code should be loaded from a local module
|
||||
or from the GCS.
|
||||
use_pickle: Whether data objects should be serialized with cloudpickle.
|
||||
_internal_config (str): JSON configuration for overriding
|
||||
RayConfig defaults. For testing purposes ONLY.
|
||||
|
||||
|
@ -1316,6 +1329,7 @@ def init(address=None,
|
|||
raylet_socket_name=raylet_socket_name,
|
||||
temp_dir=temp_dir,
|
||||
load_code_from_local=load_code_from_local,
|
||||
use_pickle=use_pickle,
|
||||
_internal_config=_internal_config,
|
||||
)
|
||||
# Start the Ray processes. We set shutdown_at_exit=False because we
|
||||
|
@ -1372,7 +1386,8 @@ def init(address=None,
|
|||
redis_password=redis_password,
|
||||
object_id_seed=object_id_seed,
|
||||
temp_dir=temp_dir,
|
||||
load_code_from_local=load_code_from_local)
|
||||
load_code_from_local=load_code_from_local,
|
||||
use_pickle=use_pickle)
|
||||
_global_node = ray.node.Node(
|
||||
ray_params, head=False, shutdown_at_exit=False, connect_only=True)
|
||||
|
||||
|
@ -2045,7 +2060,7 @@ def register_custom_serializer(cls,
|
|||
assert isinstance(job_id, JobID)
|
||||
|
||||
def register_class_for_serialization(worker_info):
|
||||
if USE_NEW_SERIALIZER:
|
||||
if worker_info["worker"].use_pickle:
|
||||
if pickle.FAST_CLOUDPICKLE_USED:
|
||||
# construct a reducer
|
||||
pickle.CloudPickler.dispatch[
|
||||
|
|
|
@ -62,6 +62,11 @@ parser.add_argument(
|
|||
default=False,
|
||||
action="store_true",
|
||||
help="True if code is loaded from local files, as opposed to the GCS.")
|
||||
# Boolean switch: absent -> False, present -> True.
parser.add_argument(
    "--use-pickle",
    action="store_true",
    default=False,
    help="True if cloudpickle should be used for serialization.")
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
|
@ -75,7 +80,8 @@ if __name__ == "__main__":
|
|||
plasma_store_socket_name=args.object_store_name,
|
||||
raylet_socket_name=args.raylet_name,
|
||||
temp_dir=args.temp_dir,
|
||||
load_code_from_local=args.load_code_from_local)
|
||||
load_code_from_local=args.load_code_from_local,
|
||||
use_pickle=args.use_pickle)
|
||||
|
||||
node = ray.node.Node(
|
||||
ray_params, head=False, shutdown_at_exit=False, connect_only=True)
|
||||
|
|
Loading…
Add table
Reference in a new issue