[Dataset] implement from_spark, to_spark and some optimizations (#17340)

This commit is contained in:
Zhi Lin 2021-09-10 02:43:47 +08:00 committed by GitHub
parent fdd52106bf
commit 2fcd1bcb4b
7 changed files with 84 additions and 11 deletions


@@ -1720,6 +1720,12 @@ cdef class CoreWorker:
CCoreWorkerProcess.GetCoreWorker().RemoveLocalReference(
c_object_id)
def get_owner_address(self, ObjectRef object_ref):
cdef:
CObjectID c_object_id = object_ref.native()
return CCoreWorkerProcess.GetCoreWorker().GetOwnerAddress(
c_object_id).SerializeAsString()
def serialize_and_promote_object_ref(self, ObjectRef object_ref):
cdef:
CObjectID c_object_id = object_ref.native()
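The new get_owner_address binding is what lets the Datasets layer look up where each block lives, so that block locations can be passed along to RayDP. A minimal sketch of how it is driven from Python, mirroring the to_spark change further down (the dataset ds is only a placeholder):

import ray

ray.init()
ds = ray.data.range_arrow(8)  # any dataset with materialized blocks

core_worker = ray.worker.global_worker.core_worker
# Serialized owner address for each block ObjectRef; these are handed to
# RayDP as location hints when building the Spark dataframe.
locations = [
    core_worker.get_owner_address(block)
    for block in ds.get_blocks()
]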


@@ -17,7 +17,7 @@ T = TypeVar("T")
#
# Block data can be accessed in a uniform way via ``BlockAccessors`` such as
# ``SimpleBlockAccessor``, ``ArrowBlockAccessor``, and ``TensorBlockAccessor``.
Block = Union[List[T], np.ndarray, "pyarrow.Table"]
Block = Union[List[T], np.ndarray, "pyarrow.Table", bytes]
@DeveloperAPI
@@ -124,6 +124,10 @@ class BlockAccessor(Generic[T]):
from ray.data.impl.arrow_block import \
ArrowBlockAccessor
return ArrowBlockAccessor(block)
elif isinstance(block, bytes):
from ray.data.impl.arrow_block import \
ArrowBlockAccessor
return ArrowBlockAccessor.from_bytes(block)
elif isinstance(block, list):
from ray.data.impl.simple_block import \
SimpleBlockAccessor
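With bytes added to the Block union, a block may now be an Arrow table serialized in the Arrow IPC streaming format, and for_block() routes it to ArrowBlockAccessor.from_bytes(). A rough sketch of the dispatch; the ray.data.block import path is an assumption for this version:

import pyarrow
from ray.data.block import BlockAccessor  # module path assumed

table = pyarrow.table({"value": [1, 2, 3]})

# Existing dispatch: plain Python lists and Arrow tables.
assert BlockAccessor.for_block([1, 2, 3]).num_rows() == 3
assert BlockAccessor.for_block(table).num_rows() == 3

# New dispatch: bytes holding the same table in the IPC streaming format.
sink = pyarrow.BufferOutputStream()
with pyarrow.ipc.new_stream(sink, table.schema) as writer:
    writer.write_table(table)
assert BlockAccessor.for_block(sink.getvalue().to_pybytes()).num_rows() == 3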


@@ -1269,7 +1269,8 @@ class Dataset(Generic[T]):
pd_objs = self.to_pandas()
return from_partitions(pd_objs, axis=0)
def to_spark(self) -> "pyspark.sql.DataFrame":
def to_spark(self,
spark: "pyspark.sql.SparkSession") -> "pyspark.sql.DataFrame":
"""Convert this dataset into a Spark dataframe.
Time complexity: O(dataset size / parallelism)
@@ -1277,7 +1278,14 @@
Returns:
A Spark dataframe created from this dataset.
"""
raise NotImplementedError # P2
import raydp
core_worker = ray.worker.global_worker.core_worker
locations = [
core_worker.get_owner_address(block)
for block in self.get_blocks()
]
return raydp.spark.ray_dataset_to_spark_dataframe(
spark, self.schema(), self.get_blocks(), locations)
def to_pandas(self) -> List[ObjectRef["pandas.DataFrame"]]:
"""Convert this dataset into a distributed set of Pandas dataframes.


@@ -145,6 +145,11 @@ class ArrowBlockAccessor(BlockAccessor):
raise ImportError("Run `pip install pyarrow` for Arrow support")
self._table = table
@classmethod
def from_bytes(cls, data: bytes):
reader = pyarrow.ipc.open_stream(data)
return cls(reader.read_all())
def iter_rows(self) -> Iterator[ArrowRow]:
outer = self
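from_bytes assumes the payload is an Arrow table in the IPC streaming format. The pyarrow round trip it relies on, shown in isolation:

import pyarrow

table = pyarrow.table({"one": [1, 2, 3], "two": ["a", "b", "c"]})

# Write the table as an IPC stream into an in-memory buffer ...
sink = pyarrow.BufferOutputStream()
with pyarrow.ipc.new_stream(sink, table.schema) as writer:
    writer.write_table(table)
data = sink.getvalue().to_pybytes()

# ... and read it back, which is exactly what from_bytes() does.
reader = pyarrow.ipc.open_stream(data)
assert reader.read_all().equals(table)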


@@ -506,7 +506,6 @@ def from_pandas(dfs: List[ObjectRef["pandas.DataFrame"]]) -> Dataset[ArrowRow]:
return Dataset(BlockList(blocks, ray.get(list(metadata))))
@PublicAPI(stability="beta")
def from_numpy(ndarrays: List[ObjectRef[np.ndarray]]) -> Dataset[np.ndarray]:
"""Create a dataset from a set of NumPy ndarrays.
@@ -524,34 +523,40 @@ def from_numpy(ndarrays: List[ObjectRef[np.ndarray]]) -> Dataset[np.ndarray]:
@PublicAPI(stability="beta")
def from_arrow(tables: List[ObjectRef["pyarrow.Table"]]) -> Dataset[ArrowRow]:
def from_arrow(tables: List[ObjectRef[Union["pyarrow.Table", bytes]]]
) -> Dataset[ArrowRow]:
"""Create a dataset from a set of Arrow tables.
Args:
dfs: A list of Ray object references to Arrow tables.
tables: A list of Ray object references to Arrow tables,
or to tables serialized as bytes in the Arrow streaming format.
Returns:
Dataset holding Arrow records from the tables.
"""
get_metadata = cached_remote_fn(_get_metadata)
metadata = [get_metadata.remote(t) for t in tables]
return Dataset(BlockList(tables, ray.get(metadata)))
@PublicAPI(stability="beta")
def from_spark(df: "pyspark.sql.DataFrame", *,
parallelism: int = 200) -> Dataset[ArrowRow]:
def from_spark(df: "pyspark.sql.DataFrame",
*,
parallelism: Optional[int] = None) -> Dataset[ArrowRow]:
"""Create a dataset from a Spark dataframe.
Args:
spark: A SparkSession, which must be created by RayDP (Spark-on-Ray).
df: A Spark dataframe, which must be created by RayDP (Spark-on-Ray).
parallelism: The amount of parallelism to use for the dataset.
If not provided, it will be equal to the number of partitions of
the original Spark dataframe.
Returns:
Dataset holding Arrow records read from the dataframe.
"""
raise NotImplementedError # P2
import raydp
return raydp.spark.spark_dataframe_to_ray_dataset(df, parallelism)
def _df_to_block(df: "pandas.DataFrame") -> Block[ArrowRow]:
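from_spark is likewise implemented through RayDP, and parallelism now defaults to the Spark dataframe's own partition count rather than a fixed 200. A hedged usage sketch (session parameters arbitrary):

import ray
import raydp

ray.init()
spark = raydp.init_spark("example", 1, 1, "500 M")

spark_df = spark.range(0, 100)      # any dataframe from a RayDP session
ds = ray.data.from_spark(spark_df)  # one block per Spark partition by default
print(ds.count(), ds.schema())

raydp.stop_spark()
ray.shutdown()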


@@ -0,0 +1,44 @@
import pytest
import ray
import raydp


@pytest.fixture(scope="function")
def spark_on_ray_small(request):
    ray.init(num_cpus=2, include_dashboard=False)
    spark = raydp.init_spark("test", 1, 1, "500 M")

    def stop_all():
        raydp.stop_spark()
        ray.shutdown()

    request.addfinalizer(stop_all)
    return spark


def test_raydp_roundtrip(spark_on_ray_small):
    spark = spark_on_ray_small
    spark_df = spark.createDataFrame([(1, "a"), (2, "b"), (3, "c")],
                                     ["one", "two"])
    rows = [(r.one, r.two) for r in spark_df.take(3)]
    ds = ray.data.from_spark(spark_df)
    values = [(r["one"], r["two"]) for r in ds.take(6)]
    assert values == rows
    df = ds.to_spark(spark)
    rows_2 = [(r.one, r.two) for r in df.take(3)]
    assert values == rows_2


def test_raydp_to_spark(spark_on_ray_small):
    spark = spark_on_ray_small
    n = 5
    ds = ray.data.range_arrow(n)
    values = [r["value"] for r in ds.take(5)]
    df = ds.to_spark(spark)
    rows = [r.value for r in df.take(5)]
    assert values == rows


if __name__ == "__main__":
    import sys
    sys.exit(pytest.main(["-v", __file__]))


@@ -7,3 +7,4 @@ s3fs
modin>=0.8.3; python_version < '3.7'
modin>=0.10.0; python_version >= '3.7'
pytest-repeat
raydp-nightly