[Dataset] implement from_spark, to_spark and some optimizations (#17340)
commit 2fcd1bcb4b (parent fdd52106bf)
7 changed files with 84 additions and 11 deletions
@@ -1720,6 +1720,12 @@ cdef class CoreWorker:
         CCoreWorkerProcess.GetCoreWorker().RemoveLocalReference(
             c_object_id)
 
+    def get_owner_address(self, ObjectRef object_ref):
+        cdef:
+            CObjectID c_object_id = object_ref.native()
+        return CCoreWorkerProcess.GetCoreWorker().GetOwnerAddress(
+            c_object_id).SerializeAsString()
+
     def serialize_and_promote_object_ref(self, ObjectRef object_ref):
         cdef:
             CObjectID c_object_id = object_ref.native()
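The get_owner_address binding added above returns the serialized address of the worker that owns an object reference; to_spark (further down) uses it to tell RayDP where each block lives. A minimal sketch of driving it from Python, assuming a local Ray session (ray.worker.global_worker is an internal handle and is touched here only for illustration):

import ray

ray.init()

ref = ray.put([1, 2, 3])

# Same internal core-worker handle that to_spark() uses below.
core_worker = ray.worker.global_worker.core_worker

# Serialized (protobuf-encoded) address of the object's owner.
owner_address = core_worker.get_owner_address(ref)
assert isinstance(owner_address, bytes)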
@@ -17,7 +17,7 @@ T = TypeVar("T")
 #
 # Block data can be accessed in a uniform way via ``BlockAccessors`` such as
 # ``SimpleBlockAccessor``, ``ArrowBlockAccessor``, and ``TensorBlockAccessor``.
-Block = Union[List[T], np.ndarray, "pyarrow.Table"]
+Block = Union[List[T], np.ndarray, "pyarrow.Table", bytes]
 
 
 @DeveloperAPI
@@ -124,6 +124,10 @@ class BlockAccessor(Generic[T]):
             from ray.data.impl.arrow_block import \
                 ArrowBlockAccessor
             return ArrowBlockAccessor(block)
+        elif isinstance(block, bytes):
+            from ray.data.impl.arrow_block import \
+                ArrowBlockAccessor
+            return ArrowBlockAccessor.from_bytes(block)
         elif isinstance(block, list):
             from ray.data.impl.simple_block import \
                 SimpleBlockAccessor
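With bytes added to the Block union, a block may arrive as an Arrow IPC stream payload, and BlockAccessor.for_block rehydrates it through ArrowBlockAccessor.from_bytes. A rough sketch of that dispatch, assuming pyarrow is installed and using the ray.data.block module shown in this hunk:

import pyarrow as pa

from ray.data.block import BlockAccessor

# Build a small Arrow table and serialize it with the IPC streaming format.
table = pa.table({"one": [1, 2, 3], "two": ["a", "b", "c"]})
sink = pa.BufferOutputStream()
with pa.ipc.new_stream(sink, table.schema) as writer:
    writer.write_table(table)
payload = sink.getvalue().to_pybytes()  # plain bytes, now a valid Block

# for_block() sees bytes and builds an ArrowBlockAccessor via from_bytes().
accessor = BlockAccessor.for_block(payload)
assert accessor.num_rows() == 3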
@@ -1269,7 +1269,8 @@ class Dataset(Generic[T]):
         pd_objs = self.to_pandas()
         return from_partitions(pd_objs, axis=0)
 
-    def to_spark(self) -> "pyspark.sql.DataFrame":
+    def to_spark(self,
+                 spark: "pyspark.sql.SparkSession") -> "pyspark.sql.DataFrame":
         """Convert this dataset into a Spark dataframe.
 
         Time complexity: O(dataset size / parallelism)
@@ -1277,7 +1278,14 @@
         Returns:
             A Spark dataframe created from this dataset.
         """
-        raise NotImplementedError  # P2
+        import raydp
+        core_worker = ray.worker.global_worker.core_worker
+        locations = [
+            core_worker.get_owner_address(block)
+            for block in self.get_blocks()
+        ]
+        return raydp.spark.ray_dataset_to_spark_dataframe(
+            spark, self.schema(), self.get_blocks(), locations)
 
     def to_pandas(self) -> List[ObjectRef["pandas.DataFrame"]]:
         """Convert this dataset into a distributed set of Pandas dataframes.
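Together with the RayDP helper, to_spark(spark) hands the block references, their owner addresses, and the schema to raydp.spark.ray_dataset_to_spark_dataframe. A hedged end-to-end sketch that mirrors the test added at the bottom of this commit (raydp.init_spark and ray.data.range_arrow are the same calls the new test uses; the app name and resource sizes are arbitrary):

import ray
import raydp

ray.init(num_cpus=2)

# to_spark() requires a Spark-on-Ray session created by RayDP.
spark = raydp.init_spark("to_spark_demo", 1, 1, "500 M")

ds = ray.data.range_arrow(5)   # Arrow-backed dataset with values 0..4
df = ds.to_spark(spark)        # blocks + owner locations handed to Spark
assert df.count() == 5

raydp.stop_spark()
ray.shutdown()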
@@ -145,6 +145,11 @@ class ArrowBlockAccessor(BlockAccessor):
             raise ImportError("Run `pip install pyarrow` for Arrow support")
         self._table = table
 
+    @classmethod
+    def from_bytes(cls, data: bytes):
+        reader = pyarrow.ipc.open_stream(data)
+        return cls(reader.read_all())
+
     def iter_rows(self) -> Iterator[ArrowRow]:
         outer = self
 
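from_bytes relies on the Arrow IPC streaming format: anything written with a RecordBatchStreamWriter can be read back with pyarrow.ipc.open_stream(...).read_all(), which is all the classmethod does. A standalone pyarrow round trip (no Ray involved) for reference:

import pyarrow as pa

table = pa.table({"value": list(range(5))})

# Writer side: produces the `bytes` form a Block may now take.
sink = pa.BufferOutputStream()
with pa.ipc.new_stream(sink, table.schema) as writer:
    writer.write_table(table)
data = sink.getvalue().to_pybytes()

# Reader side: mirrors ArrowBlockAccessor.from_bytes.
restored = pa.ipc.open_stream(data).read_all()
assert restored.equals(table)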
@@ -506,7 +506,6 @@ def from_pandas(dfs: List[ObjectRef["pandas.DataFrame"]]) -> Dataset[ArrowRow]:
     return Dataset(BlockList(blocks, ray.get(list(metadata))))
 
-
 
 @PublicAPI(stability="beta")
 def from_numpy(ndarrays: List[ObjectRef[np.ndarray]]) -> Dataset[np.ndarray]:
     """Create a dataset from a set of NumPy ndarrays.
@@ -524,34 +523,40 @@ def from_numpy(ndarrays: List[ObjectRef[np.ndarray]]) -> Dataset[np.ndarray]:
 
 
 @PublicAPI(stability="beta")
-def from_arrow(tables: List[ObjectRef["pyarrow.Table"]]) -> Dataset[ArrowRow]:
+def from_arrow(tables: List[ObjectRef[Union["pyarrow.Table", bytes]]]
+               ) -> Dataset[ArrowRow]:
     """Create a dataset from a set of Arrow tables.
 
     Args:
-        dfs: A list of Ray object references to Arrow tables.
+        tables: A list of Ray object references to Arrow tables,
+            or its streaming format in bytes.
 
     Returns:
         Dataset holding Arrow records from the tables.
     """
 
     get_metadata = cached_remote_fn(_get_metadata)
     metadata = [get_metadata.remote(t) for t in tables]
     return Dataset(BlockList(tables, ray.get(metadata)))
 
 
 @PublicAPI(stability="beta")
-def from_spark(df: "pyspark.sql.DataFrame", *,
-               parallelism: int = 200) -> Dataset[ArrowRow]:
+def from_spark(df: "pyspark.sql.DataFrame",
+               *,
+               parallelism: Optional[int] = None) -> Dataset[ArrowRow]:
     """Create a dataset from a Spark dataframe.
 
     Args:
-        spark: A SparkSession, which must be created by RayDP (Spark-on-Ray).
-        parallelism: The amount of parallelism to use for the dataset.
+        df: A Spark dataframe, which must be created by RayDP (Spark-on-Ray).
+        parallelism: The amount of parallelism to use for the dataset.
+            If not provided, it will be equal to the number of partitions of
+            the original Spark dataframe.
 
     Returns:
         Dataset holding Arrow records read from the dataframe.
     """
-    raise NotImplementedError  # P2
+    import raydp
+    return raydp.spark.spark_dataframe_to_ray_dataset(df, parallelism)
 
 
 def _df_to_block(df: "pandas.DataFrame") -> Block[ArrowRow]:
python/ray/data/tests/test_raydp_dataset.py  (new file, 44 lines)

@@ -0,0 +1,44 @@
+import pytest
+import ray
+import raydp
+
+
+@pytest.fixture(scope="function")
+def spark_on_ray_small(request):
+    ray.init(num_cpus=2, include_dashboard=False)
+    spark = raydp.init_spark("test", 1, 1, "500 M")
+
+    def stop_all():
+        raydp.stop_spark()
+        ray.shutdown()
+
+    request.addfinalizer(stop_all)
+    return spark
+
+
+def test_raydp_roundtrip(spark_on_ray_small):
+    spark = spark_on_ray_small
+    spark_df = spark.createDataFrame([(1, "a"), (2, "b"), (3, "c")],
+                                     ["one", "two"])
+    rows = [(r.one, r.two) for r in spark_df.take(3)]
+    ds = ray.data.from_spark(spark_df)
+    values = [(r["one"], r["two"]) for r in ds.take(6)]
+    assert values == rows
+    df = ds.to_spark(spark)
+    rows_2 = [(r.one, r.two) for r in df.take(3)]
+    assert values == rows_2
+
+
+def test_raydp_to_spark(spark_on_ray_small):
+    spark = spark_on_ray_small
+    n = 5
+    ds = ray.data.range_arrow(n)
+    values = [r["value"] for r in ds.take(5)]
+    df = ds.to_spark(spark)
+    rows = [r.value for r in df.take(5)]
+    assert values == rows
+
+
+if __name__ == "__main__":
+    import sys
+    sys.exit(pytest.main(["-v", __file__]))
@@ -7,3 +7,4 @@ s3fs
 modin>=0.8.3; python_version < '3.7'
 modin>=0.10.0; python_version >= '3.7'
 pytest-repeat
+raydp-nightly