mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
Introduce flag to use pickle for serialization (#5805)
This commit is contained in:
parent
29eee7f970
commit
d23696de17
9 changed files with 85 additions and 22 deletions
|
@ -1,12 +1,7 @@
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
# TODO(suquark): This is a temporary flag for
|
if sys.version_info[:2] >= (3, 8):
|
||||||
# the new serialization implementation.
|
|
||||||
# Remove it when the old one is deprecated.
|
|
||||||
USE_NEW_SERIALIZER = False
|
|
||||||
|
|
||||||
if USE_NEW_SERIALIZER and sys.version_info[:2] >= (3, 8):
|
|
||||||
from ray.cloudpickle.cloudpickle_fast import *
|
from ray.cloudpickle.cloudpickle_fast import *
|
||||||
FAST_CLOUDPICKLE_USED = True
|
FAST_CLOUDPICKLE_USED = True
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -264,6 +264,10 @@ class Node(object):
|
||||||
def load_code_from_local(self):
|
def load_code_from_local(self):
|
||||||
return self._ray_params.load_code_from_local
|
return self._ray_params.load_code_from_local
|
||||||
|
|
||||||
|
@property
|
||||||
|
def use_pickle(self):
|
||||||
|
return self._ray_params.use_pickle
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def object_id_seed(self):
|
def object_id_seed(self):
|
||||||
"""Get the seed for deterministic generation of object IDs"""
|
"""Get the seed for deterministic generation of object IDs"""
|
||||||
|
@ -520,7 +524,7 @@ class Node(object):
|
||||||
include_java=self._ray_params.include_java,
|
include_java=self._ray_params.include_java,
|
||||||
java_worker_options=self._ray_params.java_worker_options,
|
java_worker_options=self._ray_params.java_worker_options,
|
||||||
load_code_from_local=self._ray_params.load_code_from_local,
|
load_code_from_local=self._ray_params.load_code_from_local,
|
||||||
)
|
use_pickle=self._ray_params.use_pickle)
|
||||||
assert ray_constants.PROCESS_TYPE_RAYLET not in self.all_processes
|
assert ray_constants.PROCESS_TYPE_RAYLET not in self.all_processes
|
||||||
self.all_processes[ray_constants.PROCESS_TYPE_RAYLET] = [process_info]
|
self.all_processes[ray_constants.PROCESS_TYPE_RAYLET] = [process_info]
|
||||||
|
|
||||||
|
|
|
@ -74,6 +74,7 @@ class RayParams(object):
|
||||||
Java worker.
|
Java worker.
|
||||||
java_worker_options (str): The command options for Java worker.
|
java_worker_options (str): The command options for Java worker.
|
||||||
load_code_from_local: Whether load code from local file or from GCS.
|
load_code_from_local: Whether load code from local file or from GCS.
|
||||||
|
use_pickle: Whether data objects should be serialized with cloudpickle.
|
||||||
_internal_config (str): JSON configuration for overriding
|
_internal_config (str): JSON configuration for overriding
|
||||||
RayConfig defaults. For testing purposes ONLY.
|
RayConfig defaults. For testing purposes ONLY.
|
||||||
"""
|
"""
|
||||||
|
@ -113,6 +114,7 @@ class RayParams(object):
|
||||||
include_java=False,
|
include_java=False,
|
||||||
java_worker_options=None,
|
java_worker_options=None,
|
||||||
load_code_from_local=False,
|
load_code_from_local=False,
|
||||||
|
use_pickle=False,
|
||||||
_internal_config=None):
|
_internal_config=None):
|
||||||
self.object_id_seed = object_id_seed
|
self.object_id_seed = object_id_seed
|
||||||
self.redis_address = redis_address
|
self.redis_address = redis_address
|
||||||
|
@ -146,6 +148,7 @@ class RayParams(object):
|
||||||
self.include_java = include_java
|
self.include_java = include_java
|
||||||
self.java_worker_options = java_worker_options
|
self.java_worker_options = java_worker_options
|
||||||
self.load_code_from_local = load_code_from_local
|
self.load_code_from_local = load_code_from_local
|
||||||
|
self.use_pickle = use_pickle
|
||||||
self._internal_config = _internal_config
|
self._internal_config = _internal_config
|
||||||
self._check_usage()
|
self._check_usage()
|
||||||
|
|
||||||
|
|
|
@ -225,6 +225,11 @@ def cli(logging_level, logging_format):
|
||||||
is_flag=True,
|
is_flag=True,
|
||||||
default=False,
|
default=False,
|
||||||
help="Specify whether load code from local file or GCS serialization.")
|
help="Specify whether load code from local file or GCS serialization.")
|
||||||
|
@click.option(
|
||||||
|
"--use-pickle",
|
||||||
|
is_flag=True,
|
||||||
|
default=False,
|
||||||
|
help="Use pickle for serialization.")
|
||||||
def start(node_ip_address, redis_address, address, redis_port,
|
def start(node_ip_address, redis_address, address, redis_port,
|
||||||
num_redis_shards, redis_max_clients, redis_password,
|
num_redis_shards, redis_max_clients, redis_password,
|
||||||
redis_shard_ports, object_manager_port, node_manager_port, memory,
|
redis_shard_ports, object_manager_port, node_manager_port, memory,
|
||||||
|
@ -232,7 +237,8 @@ def start(node_ip_address, redis_address, address, redis_port,
|
||||||
head, include_webui, block, plasma_directory, huge_pages,
|
head, include_webui, block, plasma_directory, huge_pages,
|
||||||
autoscaling_config, no_redirect_worker_output, no_redirect_output,
|
autoscaling_config, no_redirect_worker_output, no_redirect_output,
|
||||||
plasma_store_socket_name, raylet_socket_name, temp_dir, include_java,
|
plasma_store_socket_name, raylet_socket_name, temp_dir, include_java,
|
||||||
java_worker_options, load_code_from_local, internal_config):
|
java_worker_options, load_code_from_local, use_pickle,
|
||||||
|
internal_config):
|
||||||
# Convert hostnames to numerical IP address.
|
# Convert hostnames to numerical IP address.
|
||||||
if node_ip_address is not None:
|
if node_ip_address is not None:
|
||||||
node_ip_address = services.address_to_ip(node_ip_address)
|
node_ip_address = services.address_to_ip(node_ip_address)
|
||||||
|
@ -273,6 +279,7 @@ def start(node_ip_address, redis_address, address, redis_port,
|
||||||
include_webui=include_webui,
|
include_webui=include_webui,
|
||||||
java_worker_options=java_worker_options,
|
java_worker_options=java_worker_options,
|
||||||
load_code_from_local=load_code_from_local,
|
load_code_from_local=load_code_from_local,
|
||||||
|
use_pickle=use_pickle,
|
||||||
_internal_config=internal_config)
|
_internal_config=internal_config)
|
||||||
|
|
||||||
if head:
|
if head:
|
||||||
|
|
|
@ -1060,7 +1060,8 @@ def start_raylet(redis_address,
|
||||||
config=None,
|
config=None,
|
||||||
include_java=False,
|
include_java=False,
|
||||||
java_worker_options=None,
|
java_worker_options=None,
|
||||||
load_code_from_local=False):
|
load_code_from_local=False,
|
||||||
|
use_pickle=False):
|
||||||
"""Start a raylet, which is a combined local scheduler and object manager.
|
"""Start a raylet, which is a combined local scheduler and object manager.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -1092,6 +1093,7 @@ def start_raylet(redis_address,
|
||||||
include_java (bool): If True, the raylet backend can also support
|
include_java (bool): If True, the raylet backend can also support
|
||||||
Java worker.
|
Java worker.
|
||||||
java_worker_options (str): The command options for Java worker.
|
java_worker_options (str): The command options for Java worker.
|
||||||
|
use_pickle (bool): If True, use cloudpickle for serialization.
|
||||||
Returns:
|
Returns:
|
||||||
ProcessInfo for the process that was started.
|
ProcessInfo for the process that was started.
|
||||||
"""
|
"""
|
||||||
|
@ -1155,6 +1157,8 @@ def start_raylet(redis_address,
|
||||||
|
|
||||||
if load_code_from_local:
|
if load_code_from_local:
|
||||||
start_worker_command += " --load-code-from-local "
|
start_worker_command += " --load-code-from-local "
|
||||||
|
if use_pickle:
|
||||||
|
start_worker_command += " --use-pickle "
|
||||||
|
|
||||||
command = [
|
command = [
|
||||||
RAYLET_EXECUTABLE,
|
RAYLET_EXECUTABLE,
|
||||||
|
|
|
@ -130,7 +130,7 @@ def test_fair_queueing(shutdown_only):
|
||||||
assert len(ready) == 1000, len(ready)
|
assert len(ready) == 1000, len(ready)
|
||||||
|
|
||||||
|
|
||||||
def test_complex_serialization(ray_start_regular):
|
def complex_serialization(use_pickle):
|
||||||
def assert_equal(obj1, obj2):
|
def assert_equal(obj1, obj2):
|
||||||
module_numpy = (type(obj1).__module__ == np.__name__
|
module_numpy = (type(obj1).__module__ == np.__name__
|
||||||
or type(obj2).__module__ == np.__name__)
|
or type(obj2).__module__ == np.__name__)
|
||||||
|
@ -340,6 +340,15 @@ def test_complex_serialization(ray_start_regular):
|
||||||
assert ray.get(ray.put(s)).readline() == line
|
assert ray.get(ray.put(s)).readline() == line
|
||||||
|
|
||||||
|
|
||||||
|
def test_complex_serialization(ray_start_regular):
|
||||||
|
complex_serialization(use_pickle=False)
|
||||||
|
|
||||||
|
|
||||||
|
def test_complex_serialization_with_pickle(shutdown_only):
|
||||||
|
ray.init(use_pickle=True)
|
||||||
|
complex_serialization(use_pickle=True)
|
||||||
|
|
||||||
|
|
||||||
def test_nested_functions(ray_start_regular):
|
def test_nested_functions(ray_start_regular):
|
||||||
# Make sure that remote functions can use other values that are defined
|
# Make sure that remote functions can use other values that are defined
|
||||||
# after the remote function but before the first function invocation.
|
# after the remote function but before the first function invocation.
|
||||||
|
@ -410,7 +419,7 @@ def test_ray_recursive_objects(ray_start_regular):
|
||||||
# Create a list of recursive objects.
|
# Create a list of recursive objects.
|
||||||
recursive_objects = [lst, a1, a2, a3, d1]
|
recursive_objects = [lst, a1, a2, a3, d1]
|
||||||
|
|
||||||
if ray.worker.USE_NEW_SERIALIZER:
|
if ray.worker.global_worker.use_pickle:
|
||||||
# Serialize the recursive objects.
|
# Serialize the recursive objects.
|
||||||
for obj in recursive_objects:
|
for obj in recursive_objects:
|
||||||
ray.put(obj)
|
ray.put(obj)
|
||||||
|
|
|
@ -551,3 +551,23 @@ print("success")
|
||||||
|
|
||||||
# Make sure we can still talk with the raylet.
|
# Make sure we can still talk with the raylet.
|
||||||
ray.get(f.remote())
|
ray.get(f.remote())
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"call_ray_start", ["ray start --head --num-cpus=1 --use-pickle"],
|
||||||
|
indirect=True)
|
||||||
|
def test_use_pickle(call_ray_start):
|
||||||
|
address = call_ray_start
|
||||||
|
|
||||||
|
ray.init(address=address, use_pickle=True)
|
||||||
|
|
||||||
|
assert ray.worker.global_worker.use_pickle
|
||||||
|
x = (2, "hello")
|
||||||
|
|
||||||
|
@ray.remote
|
||||||
|
def f(x):
|
||||||
|
assert x == (2, "hello")
|
||||||
|
assert ray.worker.global_worker.use_pickle
|
||||||
|
return (3, "world")
|
||||||
|
|
||||||
|
assert ray.get(f.remote(x)) == (3, "world")
|
||||||
|
|
|
@ -26,7 +26,6 @@ import random
|
||||||
import pyarrow
|
import pyarrow
|
||||||
import pyarrow.plasma as plasma
|
import pyarrow.plasma as plasma
|
||||||
import ray.cloudpickle as pickle
|
import ray.cloudpickle as pickle
|
||||||
from ray.cloudpickle import USE_NEW_SERIALIZER
|
|
||||||
import ray.experimental.signal as ray_signal
|
import ray.experimental.signal as ray_signal
|
||||||
import ray.experimental.no_return
|
import ray.experimental.no_return
|
||||||
import ray.gcs_utils
|
import ray.gcs_utils
|
||||||
|
@ -176,6 +175,11 @@ class Worker(object):
|
||||||
self.check_connected()
|
self.check_connected()
|
||||||
return self.node.load_code_from_local
|
return self.node.load_code_from_local
|
||||||
|
|
||||||
|
@property
|
||||||
|
def use_pickle(self):
|
||||||
|
self.check_connected()
|
||||||
|
return self.node.use_pickle
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def task_context(self):
|
def task_context(self):
|
||||||
"""A thread-local that contains the following attributes.
|
"""A thread-local that contains the following attributes.
|
||||||
|
@ -391,7 +395,7 @@ class Worker(object):
|
||||||
for attempt in reversed(
|
for attempt in reversed(
|
||||||
range(ray_constants.DEFAULT_PUT_OBJECT_RETRIES)):
|
range(ray_constants.DEFAULT_PUT_OBJECT_RETRIES)):
|
||||||
try:
|
try:
|
||||||
if USE_NEW_SERIALIZER:
|
if self.use_pickle:
|
||||||
self.store_with_plasma(object_id, value)
|
self.store_with_plasma(object_id, value)
|
||||||
else:
|
else:
|
||||||
self._try_store_and_register(object_id, value)
|
self._try_store_and_register(object_id, value)
|
||||||
|
@ -433,8 +437,13 @@ class Worker(object):
|
||||||
value, object_id, memcopy_threads=self.memcopy_threads)
|
value, object_id, memcopy_threads=self.memcopy_threads)
|
||||||
else:
|
else:
|
||||||
writer = Pickle5Writer()
|
writer = Pickle5Writer()
|
||||||
|
if ray.cloudpickle.FAST_CLOUDPICKLE_USED:
|
||||||
inband = pickle.dumps(
|
inband = pickle.dumps(
|
||||||
value, protocol=5, buffer_callback=writer.buffer_callback)
|
value,
|
||||||
|
protocol=5,
|
||||||
|
buffer_callback=writer.buffer_callback)
|
||||||
|
else:
|
||||||
|
inband = pickle.dumps(value)
|
||||||
self.core_worker.put_pickle5_buffers(object_id, inband, writer,
|
self.core_worker.put_pickle5_buffers(object_id, inband, writer,
|
||||||
self.memcopy_threads)
|
self.memcopy_threads)
|
||||||
except pyarrow.plasma.PlasmaObjectExists:
|
except pyarrow.plasma.PlasmaObjectExists:
|
||||||
|
@ -512,10 +521,12 @@ class Worker(object):
|
||||||
def _deserialize_object_from_arrow(self, data, metadata, object_id,
|
def _deserialize_object_from_arrow(self, data, metadata, object_id,
|
||||||
serialization_context):
|
serialization_context):
|
||||||
if metadata:
|
if metadata:
|
||||||
if (USE_NEW_SERIALIZER
|
if metadata == ray_constants.PICKLE5_BUFFER_METADATA:
|
||||||
and metadata == ray_constants.PICKLE5_BUFFER_METADATA):
|
|
||||||
in_band, buffers = unpack_pickle5_buffers(data)
|
in_band, buffers = unpack_pickle5_buffers(data)
|
||||||
|
if len(buffers) > 0:
|
||||||
return pickle.loads(in_band, buffers=buffers)
|
return pickle.loads(in_band, buffers=buffers)
|
||||||
|
else:
|
||||||
|
return pickle.loads(in_band)
|
||||||
# Check if the object should be returned as raw bytes.
|
# Check if the object should be returned as raw bytes.
|
||||||
if metadata == ray_constants.RAW_BUFFER_METADATA:
|
if metadata == ray_constants.RAW_BUFFER_METADATA:
|
||||||
if data is None:
|
if data is None:
|
||||||
|
@ -1085,7 +1096,7 @@ def _initialize_serialization(job_id, worker=global_worker):
|
||||||
|
|
||||||
worker.serialization_context_map[job_id] = serialization_context
|
worker.serialization_context_map[job_id] = serialization_context
|
||||||
|
|
||||||
if not USE_NEW_SERIALIZER:
|
if not worker.use_pickle:
|
||||||
for error_cls in RAY_EXCEPTION_TYPES:
|
for error_cls in RAY_EXCEPTION_TYPES:
|
||||||
register_custom_serializer(
|
register_custom_serializer(
|
||||||
error_cls,
|
error_cls,
|
||||||
|
@ -1158,6 +1169,7 @@ def init(address=None,
|
||||||
raylet_socket_name=None,
|
raylet_socket_name=None,
|
||||||
temp_dir=None,
|
temp_dir=None,
|
||||||
load_code_from_local=False,
|
load_code_from_local=False,
|
||||||
|
use_pickle=False,
|
||||||
_internal_config=None):
|
_internal_config=None):
|
||||||
"""Connect to an existing Ray cluster or start one and connect to it.
|
"""Connect to an existing Ray cluster or start one and connect to it.
|
||||||
|
|
||||||
|
@ -1242,6 +1254,7 @@ def init(address=None,
|
||||||
directory for the Ray process.
|
directory for the Ray process.
|
||||||
load_code_from_local: Whether code should be loaded from a local module
|
load_code_from_local: Whether code should be loaded from a local module
|
||||||
or from the GCS.
|
or from the GCS.
|
||||||
|
use_pickle: Whether data objects should be serialized with cloudpickle.
|
||||||
_internal_config (str): JSON configuration for overriding
|
_internal_config (str): JSON configuration for overriding
|
||||||
RayConfig defaults. For testing purposes ONLY.
|
RayConfig defaults. For testing purposes ONLY.
|
||||||
|
|
||||||
|
@ -1316,6 +1329,7 @@ def init(address=None,
|
||||||
raylet_socket_name=raylet_socket_name,
|
raylet_socket_name=raylet_socket_name,
|
||||||
temp_dir=temp_dir,
|
temp_dir=temp_dir,
|
||||||
load_code_from_local=load_code_from_local,
|
load_code_from_local=load_code_from_local,
|
||||||
|
use_pickle=use_pickle,
|
||||||
_internal_config=_internal_config,
|
_internal_config=_internal_config,
|
||||||
)
|
)
|
||||||
# Start the Ray processes. We set shutdown_at_exit=False because we
|
# Start the Ray processes. We set shutdown_at_exit=False because we
|
||||||
|
@ -1372,7 +1386,8 @@ def init(address=None,
|
||||||
redis_password=redis_password,
|
redis_password=redis_password,
|
||||||
object_id_seed=object_id_seed,
|
object_id_seed=object_id_seed,
|
||||||
temp_dir=temp_dir,
|
temp_dir=temp_dir,
|
||||||
load_code_from_local=load_code_from_local)
|
load_code_from_local=load_code_from_local,
|
||||||
|
use_pickle=use_pickle)
|
||||||
_global_node = ray.node.Node(
|
_global_node = ray.node.Node(
|
||||||
ray_params, head=False, shutdown_at_exit=False, connect_only=True)
|
ray_params, head=False, shutdown_at_exit=False, connect_only=True)
|
||||||
|
|
||||||
|
@ -2045,7 +2060,7 @@ def register_custom_serializer(cls,
|
||||||
assert isinstance(job_id, JobID)
|
assert isinstance(job_id, JobID)
|
||||||
|
|
||||||
def register_class_for_serialization(worker_info):
|
def register_class_for_serialization(worker_info):
|
||||||
if USE_NEW_SERIALIZER:
|
if worker_info["worker"].use_pickle:
|
||||||
if pickle.FAST_CLOUDPICKLE_USED:
|
if pickle.FAST_CLOUDPICKLE_USED:
|
||||||
# construct a reducer
|
# construct a reducer
|
||||||
pickle.CloudPickler.dispatch[
|
pickle.CloudPickler.dispatch[
|
||||||
|
|
|
@ -62,6 +62,11 @@ parser.add_argument(
|
||||||
default=False,
|
default=False,
|
||||||
action="store_true",
|
action="store_true",
|
||||||
help="True if code is loaded from local files, as opposed to the GCS.")
|
help="True if code is loaded from local files, as opposed to the GCS.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--use-pickle",
|
||||||
|
default=False,
|
||||||
|
action="store_true",
|
||||||
|
help="True if cloudpickle should be used for serialization.")
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
@ -75,7 +80,8 @@ if __name__ == "__main__":
|
||||||
plasma_store_socket_name=args.object_store_name,
|
plasma_store_socket_name=args.object_store_name,
|
||||||
raylet_socket_name=args.raylet_name,
|
raylet_socket_name=args.raylet_name,
|
||||||
temp_dir=args.temp_dir,
|
temp_dir=args.temp_dir,
|
||||||
load_code_from_local=args.load_code_from_local)
|
load_code_from_local=args.load_code_from_local,
|
||||||
|
use_pickle=args.use_pickle)
|
||||||
|
|
||||||
node = ray.node.Node(
|
node = ray.node.Node(
|
||||||
ray_params, head=False, shutdown_at_exit=False, connect_only=True)
|
ray_params, head=False, shutdown_at_exit=False, connect_only=True)
|
||||||
|
|
Loading…
Add table
Reference in a new issue