Serialize StringIO with pickle (#5781)

This commit is contained in:
Philipp Moritz 2019-09-26 12:55:14 -07:00 committed by GitHub
parent 57a5871ea6
commit 01d6362472
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 19 additions and 0 deletions

View file

@ -6,6 +6,7 @@ from __future__ import print_function
import collections import collections
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
import glob import glob
import io
import json import json
import logging import logging
from multiprocessing import Process from multiprocessing import Process
@ -304,6 +305,13 @@ def test_complex_serialization(ray_start_regular):
assert_equal(obj, ray.get(f.remote(obj))) assert_equal(obj, ray.get(f.remote(obj)))
assert_equal(obj, ray.get(ray.put(obj))) assert_equal(obj, ray.get(ray.put(obj)))
# Test StringIO serialization
s = io.StringIO(u"Hello, world!\n")
s.seek(0)
line = s.readline()
s.seek(0)
assert ray.get(ray.put(s)).readline() == line
def test_nested_functions(ray_start_regular): def test_nested_functions(ray_start_regular):
# Make sure that remote functions can use other values that are defined # Make sure that remote functions can use other values that are defined

View file

@ -8,6 +8,7 @@ import atexit
import faulthandler import faulthandler
import hashlib import hashlib
import inspect import inspect
import io
import json import json
import logging import logging
import numpy as np import numpy as np
@ -1278,6 +1279,16 @@ def _initialize_serialization(job_id, worker=global_worker):
local=True, local=True,
job_id=job_id, job_id=job_id,
class_id="ray.signature.FunctionSignature") class_id="ray.signature.FunctionSignature")
# Tell Ray to serialize StringIO with pickle. We do this because
# Ray's default __dict__ serialization is incorrect for this type
# (the object's __dict__ is empty and therefore doesn't
# contain the full state of the object).
register_custom_serializer(
io.StringIO,
use_pickle=True,
local=True,
job_id=job_id,
class_id="io.StringIO")
def init(address=None, def init(address=None,