diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py index d3f1dbdf7..e7031fd8c 100644 --- a/python/ray/data/dataset.py +++ b/python/ray/data/dataset.py @@ -2492,8 +2492,13 @@ List[str]]]): The names of the columns to use as the features. Can be a list of """ import raydp + core_worker = ray.worker.global_worker.core_worker + locations = [ + core_worker.get_owner_address(block) + for block in self.get_internal_block_refs() + ] return raydp.spark.ray_dataset_to_spark_dataframe( - spark, self.schema(), self.get_internal_block_refs() + spark, self.schema(), self.get_internal_block_refs(), locations ) def to_pandas(self, limit: int = 100000) -> "pandas.DataFrame": diff --git a/python/ray/data/tests/test_raydp_dataset.py b/python/ray/data/tests/test_raydp_dataset.py index 26add48c9..2d3cf394a 100644 --- a/python/ray/data/tests/test_raydp_dataset.py +++ b/python/ray/data/tests/test_raydp_dataset.py @@ -5,19 +5,26 @@ import torch @pytest.fixture(scope="function") -def spark(request): +def spark_on_ray_small(request): ray.init(num_cpus=2, include_dashboard=False) - spark_session = raydp.init_spark("test", 1, 1, "500 M") + spark = raydp.init_spark("test", 1, 1, "500 M") def stop_all(): raydp.stop_spark() ray.shutdown() request.addfinalizer(stop_all) - return spark_session + return spark -def test_raydp_roundtrip(spark): +@pytest.mark.skip( + reason=( + "raydp.spark.spark_dataframe_to_ray_dataset needs to be updated to " + "use ray.data.from_arrow_refs." + ) +) +def test_raydp_roundtrip(spark_on_ray_small): + spark = spark_on_ray_small spark_df = spark.createDataFrame([(1, "a"), (2, "b"), (3, "c")], ["one", "two"]) rows = [(r.one, r.two) for r in spark_df.take(3)] ds = ray.data.from_spark(spark_df) @@ -28,7 +35,11 @@ def test_raydp_roundtrip(spark): assert values == rows_2 -def test_raydp_to_spark(spark): +@pytest.mark.skip( + reason="raydp need to be updated to work without redis.", +) +def test_raydp_to_spark(spark_on_ray_small): + spark = spark_on_ray_small n = 5 ds = ray.data.range_arrow(n) values = [r["value"] for r in ds.take(5)]