mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
[data] MLDataset based on ParallelIterator (#11849)
This commit is contained in:
parent
2fe1321c3f
commit
9481ecd180
13 changed files with 1265 additions and 13 deletions
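A minimal end-to-end usage sketch, condensed from the examples added in this commit (not itself part of the diff); it assumes Ray, pandas, and torch are installed, and the variable names are illustrative only:

    import ray
    import ray.util.iter as parallel_it
    import ray.util.data as ml_data
    from torch.utils.data import DataLoader

    ray.init()

    # Each record is a [feature, label] list, spread across 2 shards.
    it = parallel_it.from_items([[i, i] for i in range(64)], 2, False)
    # Group 32 records into one pandas.DataFrame per item.
    ds = ml_data.from_parallel_iter(it, need_convert=True, batch_size=32)

    # Column 0 is the feature, column 1 the label.
    torch_ds = ds.to_torch(feature_columns=[0], label_column=1)
    shard = torch_ds.get_shard(0)  # a torch IterableDataset over shard 0
    x, y = next(iter(DataLoader(shard, batch_size=32)))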
@@ -26,6 +26,7 @@ py_test_module_list(
        "test_error_ray_not_initialized.py",
        "test_gcs_fault_tolerance.py",
        "test_iter.py",
        "test_mldataset.py",
    ],
    size = "medium",
    extra_srcs = SRCS,
126 python/ray/tests/test_mldataset.py Normal file
@@ -0,0 +1,126 @@
import ray.util.iter as parallel_it
import ray.util.data as ml_data
import pytest

import pyarrow as pa
import pyarrow.parquet as pq
import pandas as pd
import os


def test_read_parquet(ray_start_regular_shared, tmp_path):
    df1 = pd.DataFrame({"one": [1, 2, 3], "two": ["a", "b", "c"]})
    table = pa.Table.from_pandas(df1)
    pq.write_table(table, os.path.join(tmp_path, "test1.parquet"))
    df2 = pd.DataFrame({"one": [4, 5, 6], "two": ["e", "f", "g"]})
    table = pa.Table.from_pandas(df2)
    pq.write_table(table, os.path.join(tmp_path, "test2.parquet"))

    # without columns
    ds = ml_data.read_parquet(tmp_path, num_shards=2)
    result = list(ds.gather_sync())
    assert df1.equals(result[0])
    assert df2.equals(result[1])

    # with columns one
    ds = ml_data.read_parquet(tmp_path, num_shards=2, columns=["one"])
    result = list(ds.gather_sync())
    assert df1[["one"]].equals(result[0])
    assert df2[["one"]].equals(result[1])

    # with columns two
    ds = ml_data.read_parquet(tmp_path, num_shards=2, columns=["two"])
    result = list(ds.gather_sync())
    assert df1[["two"]].equals(result[0])
    assert df2[["two"]].equals(result[1])


def test_from_parallel_it(ray_start_regular_shared):
    para_it = parallel_it.from_range(4).for_each(lambda x: [x])
    ds = ml_data.from_parallel_iter(para_it, batch_size=2)
    assert repr(ds) == ("MLDataset[from_range[4, shards=2]"
                        ".for_each().batch(2).to_pandas()]")
    collected = list(ds.gather_sync())
    assert len(collected) == 2
    assert all(d.shape == (2, 1) for d in collected)
    expected = para_it.flatten().batch(2).gather_sync().flatten()
    flattened = ds.gather_sync().for_each(lambda x: x[0].to_list()).flatten()
    assert list(flattened) == list(expected)


def test_batch(ray_start_regular_shared):
    para_it = parallel_it.from_range(16).for_each(lambda x: [x])
    ds = ml_data.from_parallel_iter(para_it, batch_size=2)
    collected = list(ds.gather_sync())
    assert len(collected) == 8
    assert all(d.shape == (2, 1) for d in collected)

    ds = ds.batch(4)
    assert repr(ds) == ("MLDataset[from_range[16, shards=2]"
                        ".for_each().batch(2).to_pandas().batch(4)]")
    collected = list(ds.gather_sync())
    assert len(collected) == 4
    assert all(d.shape == (4, 1) for d in collected)
    expected = para_it.flatten().batch(4).gather_sync().flatten()
    flattened = ds.gather_sync().for_each(lambda x: x[0].to_list()).flatten()
    assert list(flattened) == list(expected)


def test_local_shuffle(ray_start_regular_shared):
    para_it = parallel_it.from_range(100).for_each(lambda x: [x])

    # batch_size larger than 1 and shuffle_buffer_size larger than 1
    ds = ml_data.from_parallel_iter(para_it, batch_size=10)
    ds1 = ds.local_shuffle(shuffle_buffer_size=5)
    ds2 = ds.local_shuffle(shuffle_buffer_size=5)

    l1 = list(ds1.gather_sync())
    l2 = list(ds2.gather_sync())
    assert not all(df1.equals(df2) for df1, df2 in zip(l1, l2))

    # batch_size equals 1 and shuffle_buffer_size larger than 1
    ds = ml_data.from_parallel_iter(para_it, batch_size=1)
    ds1 = ds.local_shuffle(shuffle_buffer_size=5)
    ds2 = ds.local_shuffle(shuffle_buffer_size=5)

    l1 = list(ds1.gather_sync())
    l2 = list(ds2.gather_sync())
    assert not all(df1.equals(df2) for df1, df2 in zip(l1, l2))

    # batch_size equals 1 and shuffle_buffer_size equals 1
    ds = ml_data.from_parallel_iter(para_it, batch_size=1)
    ds1 = ds.local_shuffle(shuffle_buffer_size=1)
    ds2 = ds.local_shuffle(shuffle_buffer_size=1)

    l1 = list(ds1.gather_sync())
    l2 = list(ds2.gather_sync())
    assert all(df1.equals(df2) for df1, df2 in zip(l1, l2))


def test_union(ray_start_regular_shared):
    para_it1 = parallel_it.from_range(4, 2, False).for_each(lambda x: [x])
    ds1 = ml_data.from_parallel_iter(para_it1, True, 2, False)
    para_it2 = parallel_it.from_range(4, 2, True).for_each(lambda x: [x])
    ds2 = ml_data.from_parallel_iter(para_it2, True, 2, True)

    with pytest.raises(TypeError) as ex:
        ds1.union(ds2)
    assert "two MLDataset which have different repeated type" in str(ex.value)

    # union two MLDataset with same batch size
    para_it2 = parallel_it.from_range(4, 2, False).for_each(lambda x: [x])
    ds2 = ml_data.from_parallel_iter(para_it2, True, 2, False)
    ds = ds1.union(ds2)
    assert ds.batch_size == 2

    # union two MLDataset with different batch size
    para_it2 = parallel_it.from_range(4, 2, False).for_each(lambda x: [x])
    ds2 = ml_data.from_parallel_iter(para_it2, True, 1, False)
    ds = ds1.union(ds2)
    # batch_size 0 means batch_size unknown
    assert ds.batch_size == 0


if __name__ == "__main__":
    import sys
    sys.exit(pytest.main(["-v", __file__]))
93 python/ray/util/data/__init__.py Normal file
@@ -0,0 +1,93 @@
from collections import defaultdict
from typing import Iterable

import pandas as pd

from ray.util.data.dataset import MLDataset
from ray.util.data.parquet import read_parquet
from ray.util.iter import T, ParallelIterator

try:
    import dataclasses
except:  # noqa: E722
    pass
else:
    from dataclasses import is_dataclass


def to_pandas(it: ParallelIterator[T],
              batch_size: int = 32) -> "ParallelIterator[pd.DataFrame]":
    """Convert a ParallelIterator into a ParallelIterator of pd.DataFrame.

    The record type should be a list-like object or a dataclass instance.
    If the record is an iterable, it is first converted to a list. If the
    record has a __getitem__ attribute, __getitem__ is used with the given
    column indexes to build the pandas DataFrame. If the record is a
    dataclass instance, getattr is used to read the given columns.

    Args:
        it (ParallelIterator[T]): the ParallelIterator to convert
        batch_size (int): the number of records batched into one
            pandas DataFrame
    Returns:
        A ParallelIterator of pd.DataFrame
    """
    it = it.batch(batch_size)

    def convert_fn(input_it: Iterable[T]) -> Iterable[pd.DataFrame]:
        names = []
        for batch in input_it:
            assert isinstance(batch, list)
            if hasattr(batch[0], "__getitem__"):
                batch = pd.DataFrame(batch)
            elif hasattr(batch[0], "__iter__"):
                batch = [list(item) for item in batch]
                batch = pd.DataFrame(batch)
            elif is_dataclass(batch[0]):
                if not names:
                    names = [f.name for f in dataclasses.fields(batch[0])]
                values = defaultdict(list)
                for item in batch:
                    for col in names:
                        values[col].append(getattr(item, col))
                batch = pd.DataFrame(values, columns=names)
            else:
                raise ValueError("MLDataset only supports list-like items or "
                                 "dataclass instances")

            yield batch

    it = it._with_transform(lambda local_it: local_it.transform(convert_fn),
                            ".to_pandas()")
    return it


def from_parallel_iter(para_it: ParallelIterator[T],
                       need_convert: bool = True,
                       batch_size: int = 32,
                       repeated: bool = False) -> MLDataset:
    """Create a MLDataset from an existing ParallelIterator.

    The records of the ParallelIterator should be list-like objects or
    dataclass instances.

    Args:
        para_it (ParallelIterator[T]): an existing parallel iterator; each
            record should be a list-like object or dataclass instance.
        need_convert (bool): whether the records need to be converted to
            pandas.DataFrame. This should be False if the record type is
            already pandas.DataFrame.
        batch_size (int): if need_convert is True, batch_size records are
            batched into one pandas.DataFrame
        repeated (bool): whether the para_it is repeated.
    Returns:
        A MLDataset
    """

    if need_convert:
        para_it = to_pandas(para_it, batch_size)
    else:
        batch_size = 0

    return MLDataset.from_parallel_it(para_it, batch_size, repeated)


__all__ = ["from_parallel_iter", "read_parquet", "MLDataset"]
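A short sketch of the two record types from_parallel_iter accepts (hypothetical usage, not part of this file): list-like records become DataFrame columns 0..n-1, while dataclass records become columns named after their fields, mirroring convert_fn above.

    from dataclasses import dataclass

    import ray
    import ray.util.iter as parallel_it
    import ray.util.data as ml_data


    @dataclass
    class Point:
        x: float
        y: float


    ray.init()

    # list-like records -> DataFrame with integer columns 0 and 1
    list_it = parallel_it.from_items([[i, i * 2] for i in range(8)], 2, False)
    list_ds = ml_data.from_parallel_iter(list_it, need_convert=True, batch_size=4)

    # dataclass records -> DataFrame with columns "x" and "y"
    dc_it = parallel_it.from_items([Point(i, i * 2) for i in range(8)], 2, False)
    dc_ds = ml_data.from_parallel_iter(dc_it, need_convert=True, batch_size=4)

    print([df.columns for df in list_ds.gather_sync()])  # RangeIndex columns
    print([df.columns for df in dc_ds.gather_sync()])    # ['x', 'y'] columns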
386 python/ray/util/data/dataset.py Normal file
@@ -0,0 +1,386 @@
import random
from typing import Callable, List, Iterable, Iterator

import pandas as pd

from ray.util.iter import (_NextValueNotReady, LocalIterator, ParallelIterator,
                           T, U)


class MLDataset(ParallelIterator[pd.DataFrame]):
    """A distributed ML dataset implemented on top of ParallelIterator.

    All items should be list-like objects or dataclass instances.

    Args:
        batch_size (int): The batch size of the current dataset. It should be
            larger than zero, and 0 means unknown.
    """

    def __init__(self, actor_sets: List["_ActorSet"], name: str,
                 parent_iterators: List[ParallelIterator[pd.DataFrame]],
                 batch_size: int, repeated: bool):
        super(MLDataset, self).__init__(actor_sets, name, parent_iterators)
        self._batch_size = batch_size
        self._repeated = repeated

    @staticmethod
    def from_parallel_it(para_it: ParallelIterator[pd.DataFrame],
                         batch_size: int,
                         repeated: bool = False) -> "MLDataset":
        """Create a MLDataset from a parallel iterator.

        The records of the ParallelIterator should be pandas.DataFrame.

        Args:
            para_it (ParallelIterator[T]): an existing parallel iterator
                whose records are pandas.DataFrame
            batch_size (int): The batch size of the current dataset. It
                should be larger than zero, and 0 means unknown.
            repeated (bool): whether the para_it is repeated.
        Returns:
            A MLDataset
        """
        return MLDataset(para_it.actor_sets, para_it.name,
                         para_it.parent_iterators, batch_size, repeated)

    def __iter__(self):
        raise TypeError(
            "You must use it.gather_sync() or it.gather_async() to "
            "iterate over the results of a MLDataset.")

    def __str__(self):
        return repr(self)

    def __repr__(self):
        return f"MLDataset[{self.name}]"

    def _with_transform(self, local_it_fn, name) -> "MLDataset":
        """Helper function to create a new MLDataset"""
        para_it = super()._with_transform(local_it_fn, name)
        return MLDataset.from_parallel_it(para_it, self._batch_size,
                                          self._repeated)

    def transform(
            self,
            fn: Callable[[Iterable[pd.DataFrame]], Iterable[pd.DataFrame]]
    ) -> "MLDataset":
        """Apply the given function to the MLDataset.

        Args:
            fn (Callable[[Iterable[DataFrame]], Iterable[DataFrame]]):
                The function to apply. The input is an iterator of
                pandas.DataFrame, and the output should also be an iterator
                of pandas.DataFrame.
        Returns:
            A new MLDataset
        """
        return self._with_transform(lambda local_it: local_it.transform(fn),
                                    ".transform()")

    def batch(self, batch_size: int) -> "MLDataset":
        """Rebatch the number of rows of each pandas.DataFrame record.

        Unlike ParallelIterator.batch, this method rebatches the underlying
        pandas DataFrames so that each pandas DataFrame has batch_size rows.
        """
        if batch_size == self._batch_size:
            return self

        def batch_fn(it: Iterable[pd.DataFrame]) -> Iterable[pd.DataFrame]:
            it = iter(it)
            return_df = None
            while True:
                try:
                    cur_df = next(it)
                    cur_index = 0
                    cur_size = cur_df.shape[0]
                    while cur_df is not None or (
                            cur_index + batch_size) < cur_size:
                        if cur_df is None or cur_index == cur_size:
                            cur_df = next(it)
                            cur_index = 0
                            cur_size = cur_df.shape[0]
                        if return_df is not None:
                            ri = cur_index + batch_size - return_df.shape[0]
                            ri = min(ri, cur_size)
                            tmp = cur_df.iloc[cur_index:ri]
                            return_df = pd.concat([return_df, tmp])
                            cur_index = ri
                        else:
                            ri = cur_index + batch_size
                            ri = min(ri, cur_size)
                            return_df = cur_df.iloc[cur_index:ri]
                            cur_index = ri
                        if return_df.shape[0] == batch_size:
                            return_df.index = range(return_df.shape[0])
                            yield return_df
                            return_df = None
                except StopIteration:
                    break

            if return_df is not None:
                return_df.index = range(return_df.shape[0])
                yield return_df

        self._batch_size = batch_size
        return self._with_transform(
            lambda local_it: local_it.transform(batch_fn),
            f".batch({batch_size})")

    def flatten(self) -> "MLDataset":
        raise Exception("Unsupported operation")

    def combine(self, fn: Callable[[T], List[U]]) -> "MLDataset":
        raise Exception("Unsupported operation")

    @property
    def repeated(self) -> bool:
        return self._repeated

    @property
    def batch_size(self) -> int:
        return self._batch_size

    def local_shuffle(self, shuffle_buffer_size: int,
                      seed: int = None) -> "MLDataset":
        """Apply a local shuffle.

        Unlike ParallelIterator.local_shuffle, this shuffle first applies
        local_shuffle to each shard and then shuffles the rows of each
        pandas DataFrame.
        """
        ds = super().local_shuffle(shuffle_buffer_size, seed)

        def shuffle_fn(it: Iterable[pd.DataFrame]) -> Iterable[pd.DataFrame]:
            for df in it:
                df = df.sample(frac=1, random_state=seed)
                yield df

        ds = ds._with_transform(
            lambda local_it: local_it.transform(shuffle_fn),
            ".inner_pandas_shuffle()")

        return ds

    def repartition(self, num_partitions: int,
                    batch_ms: int = 0) -> "MLDataset":
        """See ParallelIterator.repartition"""
        if num_partitions == self.num_shards():
            return self
        para_it = super().repartition(num_partitions, batch_ms)
        return MLDataset.from_parallel_it(para_it, self._batch_size)

    def union(self, other: "MLDataset") -> "MLDataset":
        """Return an iterator that is the union of this and the other."""
        if not isinstance(other, MLDataset):
            raise TypeError(
                f"other must be of type MLDataset, got {type(other)}")

        if self._repeated != other.repeated:
            raise TypeError(
                f"want to union two MLDataset which have different repeated "
                f"type, self repeated: {self._repeated}, other repeated: "
                f"{other.repeated}")

        batch_size = 0
        if self._batch_size == other._batch_size:
            batch_size = self._batch_size

        actor_sets = []
        actor_sets.extend(self.actor_sets)
        actor_sets.extend(other.actor_sets)
        # if one of these iterators is a result of a repartition, we need to
        # keep an explicit reference to its parent iterator
        return MLDataset(
            actor_sets,
            f"ParallelUnion[{self}, {other}]",
            parent_iterators=self.parent_iterators + other.parent_iterators,
            batch_size=batch_size,
            repeated=self._repeated)

    def select_shards(self, shards_to_keep: List[int]) -> "MLDataset":
        para_it = super().select_shards(shards_to_keep)
        return MLDataset.from_parallel_it(para_it, self._batch_size,
                                          self._repeated)

    def get_repeatable_shard(self,
                             index: int,
                             batch_ms: int = 0,
                             num_async: int = 1,
                             shuffle: bool = False,
                             shuffle_buffer_size: int = 1,
                             seed: int = None) -> Iterator:
        """Get the given shard of the current dataset.

        The returned object is an iterator. Each call to iter() on it fetches
        the shard data from the beginning, and the data can be shuffled on
        each pass when shuffle is enabled.
        Args:
            index (int): the shard index id, -1 means collect all data.
            batch_ms (int): Batches items for batch_ms milliseconds
                before retrieving it. Increasing batch_ms increases latency
                but improves throughput. If this value is 0, then items are
                returned immediately.
            num_async (int): The max number of requests in flight. Increasing
                this improves the amount of pipeline parallelism in the
                iterator.
            shuffle (bool): whether to shuffle the given shard data
            shuffle_buffer_size (int): same as ParallelIterator.local_shuffle
            seed (int): the random seed
        Returns:
            The given shard iterator. If shuffle is True, each call to iter()
            returns the data in a different order.
        """
        return _RepeatableIterator(self, index, batch_ms, num_async, shuffle,
                                   shuffle_buffer_size, seed)

    def to_torch(self,
                 feature_columns=None,
                 feature_shapes=None,
                 feature_types=None,
                 label_column=None,
                 label_shape=None,
                 label_type=None):
        """Create a TorchMLDataset from the current MLDataset.

        Args:
            feature_columns (List[Any]): the feature column names (or
                column indexes).
            feature_shapes (Optional[List[Any]]): the feature shapes; should
                match the feature columns if provided.
            feature_types (Optional[List["torch.dtype"]]): the feature types;
                should match the feature columns if provided. All features
                are cast to torch.float by default, otherwise to the
                provided types.
            label_column (Any): the label name.
            label_shape (Optional[int]): the label shape.
            label_type (Optional["torch.dtype"]): the label type; the label
                is cast to torch.float by default.
        Returns:
            A TorchMLDataset
        """
        from ray.util.sgd.torch.torch_dataset import TorchMLDataset
        return TorchMLDataset(self, feature_columns, feature_shapes,
                              feature_types, label_column, label_shape,
                              label_type)

    def to_tf(self,
              feature_columns=None,
              feature_shapes=None,
              feature_types=None,
              label_column=None,
              label_shape=None,
              label_type=None):
        """Create a TFMLDataset from the current MLDataset.

        Args:
            feature_columns (List[Any]): the column names.
            feature_shapes (Optional[List[tf.TensorShape]]): the feature
                shapes; should match the feature columns if provided.
            feature_types (Optional[List["tf.DType"]]): the feature types;
                should match the feature columns if provided. All features
                are cast to tf.float32 by default, otherwise to the
                provided types.
            label_column (Any): the label name.
            label_shape (Optional[tf.TensorShape]): the label shape.
            label_type (Optional["tf.DType"]): the label type; the label is
                cast to tf.float32 by default.
        Returns:
            A TFMLDataset
        """
        from ray.util.sgd.tf.tf_dataset import TFMLDataset
        return TFMLDataset(self, feature_columns, feature_shapes,
                           feature_types, label_column, label_shape,
                           label_type)


class _RepeatableIterator(Iterator[T]):
    """A repeatable iterator for the given shard index data.

    Each call to iter() on a _RepeatableIterator instance fetches the data
    from the beginning, and returns the data in a different order when
    shuffle is set.
    Args:
        ds (MLDataset): a MLDataset
        shard_index (int): the shard index id. -1 means collect all data.
        batch_ms (int): Batches items for batch_ms milliseconds
            before retrieving it. Increasing batch_ms increases latency
            but improves throughput. If this value is 0, then items are
            returned immediately.
        num_async (int): The max number of requests in flight. Increasing this
            improves the amount of pipeline parallelism in the iterator.
        shuffle (bool): whether to shuffle the given shard data
        shuffle_buffer_size (int): same as ParallelIterator.local_shuffle
        seed (int): the random seed
    """

    def __init__(self,
                 ds: MLDataset,
                 shard_index: int,
                 batch_ms: int = 0,
                 num_async: int = 1,
                 shuffle: bool = False,
                 shuffle_buffer_size: int = 1,
                 seed: int = None):
        super(_RepeatableIterator, self).__init__()
        self._ds = ds
        self._shard_index = shard_index
        self._batch_ms = batch_ms
        self._num_async = num_async
        self._shuffle = shuffle
        self._shuffle_buffer_size = shuffle_buffer_size
        self._seed = seed
        self._local_it: LocalIterator[T] = None

        self._i = 0

    def __next__(self) -> T:
        assert self._local_it is not None
        return next(self._local_it)

    def __iter__(self) -> Iterator[T]:
        if self._shard_index >= 0:
            it = self._ds.get_shard(self._shard_index, self._batch_ms,
                                    self._num_async)
        else:
            if self._num_async > 0:
                it = self._ds.gather_async(
                    batch_ms=self._batch_ms, num_async=self._num_async)
            else:
                it = self._ds.gather_sync()
        if self._shuffle:
            it = self.shuffle(it)

        self._local_it = it
        return self

    def shuffle(self,
                local_it: LocalIterator[T]) -> LocalIterator[pd.DataFrame]:
        shuffle_random = random.Random(self._seed)

        def apply_shuffle(it):
            buffer = []
            for item in it:
                if isinstance(item, _NextValueNotReady):
                    yield item
                else:
                    buffer.append(item)
                    if len(buffer) >= self._shuffle_buffer_size:
                        item = buffer.pop(
                            shuffle_random.randint(0,
                                                   len(buffer) - 1))
                        item = item.sample(frac=1, random_state=self._seed)
                        yield item
            while len(buffer) > 0:
                item = buffer.pop(shuffle_random.randint(0, len(buffer) - 1))
                item = item.sample(frac=1, random_state=self._seed)
                yield item

        return LocalIterator(
            local_it.base_iterator,
            local_it.shared_metrics,
            local_it.local_transforms + [apply_shuffle],
            name=local_it.name +
            ".shuffle(shuffle_buffer_size={}, seed={})".format(
                self._shuffle_buffer_size,
                str(self._seed) if self._seed is not None else "None"))
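A sketch of the repeatable-shard behavior documented in get_repeatable_shard above (hypothetical usage, not part of this file): each call to iter() is meant to replay the shard from the beginning, and with shuffle=True the order can differ per pass.

    import ray
    import ray.util.iter as parallel_it
    import ray.util.data as ml_data

    ray.init()
    it = parallel_it.from_range(100, 2, False).for_each(lambda x: [x])
    ds = ml_data.from_parallel_iter(it, batch_size=10)

    shard = ds.get_repeatable_shard(0, shuffle=True, shuffle_buffer_size=4)
    epoch1 = [df for df in iter(shard)]  # 5 DataFrames of 10 rows (50 records per shard)
    epoch2 = [df for df in iter(shard)]  # replayed from the start, possibly reordered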
0 python/ray/util/data/examples/__init__.py Normal file
66 python/ray/util/data/examples/mlp_identity_tf.py Normal file
@@ -0,0 +1,66 @@
import json
import os

import ray
import ray.util.data as ml_data
import ray.util.iter as parallel_it
from ray.util.sgd.tf.tf_dataset import TFMLDataset
from ray.util.sgd.tf.tf_trainer import TFTrainer


def model_creator(config):
    import tensorflow as tf
    model = tf.keras.models.Sequential([
        tf.keras.Input(shape=(1, )),
        tf.keras.layers.Dense(128, activation="relu"),
        tf.keras.layers.Dense(1)
    ])
    optimizer = tf.keras.optimizers.Adam(lr=1e-4)
    model.compile(optimizer=optimizer, loss="mse", metrics=["accuracy"])
    return model


def make_data_creator(tf_ds: TFMLDataset):
    def data_creator(config):
        world_rank = None
        if "TF_CONFIG" in os.environ:
            tf_config = json.loads(os.environ["TF_CONFIG"])
            world_rank = tf_config["task"]["index"]
        else:
            world_rank = -1
        batch_size = config["batch_size"]
        ds = tf_ds.get_shard(shard_index=world_rank).batch(batch_size).repeat()
        return ds, None

    return data_creator


def main():
    num_points = 32 * 100 * 2
    data = [i * (1 / num_points) for i in range(num_points)]
    it = parallel_it.from_items(data, 2, False).for_each(lambda x: [x, x])
    # this will create MLDataset with column RangeIndex(range(2))
    ds = ml_data.from_parallel_iter(it, True, batch_size=32, repeated=False)
    tf_ds = ds.to_tf(feature_columns=[0], label_column=1)

    trainer = TFTrainer(
        model_creator=model_creator,
        data_creator=make_data_creator(tf_ds),
        num_replicas=2,
        config={
            "batch_size": 32,
            "fit_config": {
                "steps_per_epoch": 100,
            }
        })

    for _ in range(10):
        trainer.train()

    model = trainer.get_model()
    print("f(0.5)=", float(model.predict([0.5])))


if __name__ == "__main__":
    ray.init()
    main()
70 python/ray/util/data/examples/mlp_identity_torch.py Normal file
@@ -0,0 +1,70 @@
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader

import ray
import ray.util.data as ml_data
import ray.util.iter as parallel_it
from ray.util.sgd.torch.torch_dataset import TorchMLDataset
from ray.util.sgd.torch.torch_trainer import TorchTrainer
from ray.util.sgd.torch.training_operator import TrainingOperator


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(1, 128)
        self.fc2 = nn.Linear(128, 1)

    def forward(self, x):
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        return x


def make_train_operator(ds: TorchMLDataset):
    class IdentityTrainOperator(TrainingOperator):
        def setup(self, config):
            model = Net()
            optimizer = torch.optim.SGD(
                model.parameters(), lr=config.get("lr", 1e-4))
            loss = torch.nn.MSELoss()

            batch_size = config["batch_size"]
            train_data = ds.get_shard(
                self.world_rank, shuffle=True, shuffle_buffer_size=4)
            train_loader = DataLoader(train_data, batch_size=batch_size)

            self.model, self.optimizer, self.criterion = self.register(
                models=model, optimizers=optimizer, criterion=loss)

            self.register_data(
                train_loader=train_loader, validation_loader=None)

    return IdentityTrainOperator


def main():
    num_points = 32 * 100 * 2
    data = [i * (1 / num_points) for i in range(num_points)]
    it = parallel_it.from_items(data, 2, False).for_each(lambda x: [x, x])
    # this will create MLDataset with column RangeIndex(range(2))
    ds = ml_data.from_parallel_iter(it, True, batch_size=32, repeated=False)
    torch_ds = ds.to_torch(feature_columns=[0], label_column=1)

    trainer = TorchTrainer(
        num_workers=2,
        training_operator_cls=make_train_operator(torch_ds),
        add_dist_sampler=False,
        config={"batch_size": 32})
    for i in range(10):
        trainer.train(num_steps=100)
    model = trainer.get_model()
    print("f(0.5)=", float(model(torch.tensor([[0.5]]).float())[0][0]))


if __name__ == "__main__":
    ray.init()
    main()
21 python/ray/util/data/interface.py Normal file
@@ -0,0 +1,21 @@
from typing import Iterable

import pandas as pd


class _SourceShard:
    def prefix(self) -> str:
        raise NotImplementedError

    @property
    def shard_id(self) -> int:
        raise NotImplementedError

    def __iter__(self) -> Iterable[pd.DataFrame]:
        raise NotImplementedError

    def __str__(self):
        return repr(self)

    def __repr__(self):
        return f"{self.prefix()}SourceShard[{self.shard_id}]"
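A hypothetical in-memory shard (not part of this commit) showing how the _SourceShard interface above is consumed; parquet.py below wires its shards into an MLDataset the same way, via from_iterators.

    from typing import Iterable, List

    import pandas as pd

    import ray
    import ray.util.iter as para_iter
    from ray.util.data.dataset import MLDataset
    from ray.util.data.interface import _SourceShard


    class InMemorySourceShard(_SourceShard):
        def __init__(self, frames: List[pd.DataFrame], shard_id: int):
            self._frames = frames
            self._shard_id = shard_id

        def prefix(self) -> str:
            return "InMemory"

        @property
        def shard_id(self) -> int:
            return self._shard_id

        def __iter__(self) -> Iterable[pd.DataFrame]:
            yield from self._frames


    ray.init()
    shards = [
        InMemorySourceShard([pd.DataFrame({"x": range(5)})], 0),
        InMemorySourceShard([pd.DataFrame({"x": range(5, 10)})], 1),
    ]
    it = para_iter.from_iterators(shards, False, "in_memory")
    ds = MLDataset.from_parallel_it(it, batch_size=0, repeated=False)
    print(list(ds.gather_sync()))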
111 python/ray/util/data/parquet.py Normal file
@@ -0,0 +1,111 @@
import random
from typing import Iterable
from typing import List, Optional, Union

import pyarrow.parquet as pq
from pandas import DataFrame

import ray.util.iter as para_iter
from .dataset import MLDataset
from .interface import _SourceShard


class ParquetSourceShard(_SourceShard):
    def __init__(self, data_pieces: List[pq.ParquetDatasetPiece],
                 columns: Optional[List[str]],
                 partitions: Optional[pq.ParquetPartitions], shard_id: int):
        self._data_pieces = data_pieces
        self._columns = columns
        self._partitions = partitions
        self._shard_id = shard_id

    def prefix(self) -> str:
        return "Parquet"

    @property
    def shard_id(self) -> int:
        return self._shard_id

    def __iter__(self) -> Iterable[DataFrame]:
        for piece in self._data_pieces:
            yield piece.read(
                columns=self._columns,
                use_threads=False,
                partitions=self._partitions).to_pandas()


def read_parquet(paths: Union[str, List[str]],
                 num_shards: int,
                 rowgroup_split: bool = True,
                 shuffle: bool = False,
                 shuffle_seed: int = None,
                 columns: Optional[List[str]] = None,
                 **kwargs) -> MLDataset:
    """Read parquet format data from an HDFS-like filesystem into a MLDataset.

    .. code-block:: python

        # create dummy data
        spark.range(...).write.parquet(...)
        # create MLDataset
        data = ray.util.data.read_parquet(...)
        # convert to TorchMLDataset
        ds = data.to_torch(feature_columns=..., label_column=...)
        # get the given shard data
        shard = ds.get_shard(0)
        # create the DataLoader from the shard data; this can be used for
        # the TorchTrainer data creator as well
        data = DataLoader(shard, batch_size=32)

    Args:
        paths (Union[str, List[str]]): a single file path or a list of
            file paths
        num_shards (int): the number of shards
        rowgroup_split (bool): whether to split the files into shards based
            on row group. If False, each shard will have a list of files.
        shuffle (bool): whether to shuffle the ParquetDatasetPiece order when
            dividing into shards
        shuffle_seed (int): the shuffle seed
        columns (Optional[List[str]]): a list of column names to read
        kwargs: the other parquet read options
    Returns:
        A MLDataset
    """
    pq_ds = pq.ParquetDataset(paths, **kwargs)
    pieces = pq_ds.pieces
    data_pieces = []
    if rowgroup_split:
        # split based on row group
        for piece in pieces:
            num_row_groups = piece.get_metadata().to_dict()["num_row_groups"]
            for i in range(num_row_groups):
                data_pieces.append(
                    pq.ParquetDatasetPiece(piece.path, piece.open_file_func,
                                           piece.file_options, i,
                                           piece.partition_keys))
    else:
        # split based on file pieces
        data_pieces = pieces.copy()

    if len(data_pieces) < num_shards:
        raise ValueError(f"number of data pieces: {len(data_pieces)} should "
                         f"be no less than num_shards: {num_shards}")

    if shuffle:
        random_shuffle = random.Random(shuffle_seed)
        random_shuffle.shuffle(data_pieces)
    shards = [[] for _ in range(num_shards)]
    for i, item in enumerate(data_pieces):
        shard = shards[i % num_shards]
        if item.row_group is None:
            for number in range(item.get_metadata().to_dict()["num_row_groups"]):
                shard.append(
                    pq.ParquetDatasetPiece(item.path, item.open_file_func,
                                           item.file_options, number,
                                           item.partition_keys))
        else:
            shard.append(item)

    for i, shard in enumerate(shards):
        shards[i] = ParquetSourceShard(shard, columns, pq_ds.partitions, i)
    it = para_iter.from_iterators(shards, False, "parquet")
    return MLDataset.from_parallel_it(it, batch_size=0, repeated=False)
@@ -1,13 +1,18 @@
import os
import pytest
import tempfile
import numpy as np
import shutil
import tempfile

import numpy as np
import pytest

import ray
import ray.util.data as ml_data
import ray.util.iter as parallel_it
from ray import tune
from ray.tests.conftest import ray_start_2_cpus  # noqa: F401
from ray.util.data.examples.mlp_identity_tf import (model_creator,
                                                    make_data_creator)
from ray.util.sgd.tf import TFTrainer, TFTrainable

from ray.util.sgd.tf.examples.tensorflow_train_example import (simple_model,
                                                                simple_dataset)

@@ -22,6 +27,14 @@ SIMPLE_CONFIG = {
}


@pytest.fixture
def ray_start_4_cpus():
    address_info = ray.init(num_cpus=4)
    yield address_info
    # The code after the yield will run as teardown code.
    ray.shutdown()


@pytest.mark.parametrize(  # noqa: F811
    "num_replicas", [1, 2])
def test_train(ray_start_2_cpus, num_replicas):  # noqa: F811

@@ -99,6 +112,35 @@ def test_save_and_restore(ray_start_2_cpus, num_replicas):  # noqa: F811
        model2.optimizer.get_weights()


@pytest.mark.parametrize(  # noqa: F811
    "num_replicas", [1, 2])
def test_tf_dataset(ray_start_4_cpus):  # noqa: F811
    num_points = 32 * 100 * 2
    data = [i * (1 / num_points) for i in range(num_points)]
    it = parallel_it.from_items(data, 2, False).for_each(lambda x: [x, x])
    # this will create MLDataset with column RangeIndex(range(2))
    ds = ml_data.from_parallel_iter(it, True, batch_size=32, repeated=False)
    tf_ds = ds.to_tf(feature_columns=[0], label_column=1)
    trainer = TFTrainer(
        model_creator=model_creator,
        data_creator=make_data_creator(tf_ds),
        num_replicas=2,
        config={
            "batch_size": 32,
            "fit_config": {
                "steps_per_epoch": 100,
            }
        })

    for _ in range(10):
        trainer.train()

    model = trainer.get_model()
    prediction = model.predict([0.5])[0][0]
    assert 0.4 <= prediction <= 0.6
    trainer.shutdown()


def _compare(d1, d2, skip_keys=None):
    """Compare two lists or dictionaries or array"""
    if type(d1) != type(d2):

@@ -1,23 +1,26 @@
import numpy as np
import os

import numpy as np
import pytest
import torch
import torch.nn as nn
import torch.distributed as dist
from ray.tune.utils import merge_dicts
import torch.nn as nn
from torch.utils.data import DataLoader

import ray
import ray.util.data as ml_data
import ray.util.iter as parallel_it
from ray import tune
from ray.util.sgd.torch import TorchTrainer
from ray.util.sgd.torch.training_operator import (
    get_test_operator, get_test_metrics_operator, TrainingOperator)
from ray.util.sgd.torch.constants import SCHEDULER_STEP
from ray.util.sgd.utils import (NUM_SAMPLES, BATCH_COUNT, BATCH_SIZE)

from ray.tune.utils import merge_dicts
from ray.util.data.examples.mlp_identity_torch import make_train_operator
from ray.util.sgd.data.examples import mlp_identity
from ray.util.sgd.torch import TorchTrainer
from ray.util.sgd.torch.constants import SCHEDULER_STEP
from ray.util.sgd.torch.examples.train_example import (
    model_creator, optimizer_creator, data_creator, LinearDataset)
from ray.util.sgd.torch.training_operator import (
    get_test_operator, get_test_metrics_operator, TrainingOperator)
from ray.util.sgd.utils import (NUM_SAMPLES, BATCH_COUNT, BATCH_SIZE)


@pytest.fixture

@@ -834,6 +837,30 @@ def test_multi_input_model(ray_start_2_cpus, use_local):
    trainer.shutdown()


@pytest.mark.parametrize("use_local", [True, False])
def test_torch_dataset(ray_start_4_cpus, use_local):
    num_points = 32 * 100 * 2
    data = [i * (1 / num_points) for i in range(num_points)]
    para_it = parallel_it.from_items(data, 2, False).for_each(lambda x: [x, x])
    ds = ml_data.from_parallel_iter(para_it, batch_size=32)

    torch_ds = ds.to_torch(feature_columns=[0], label_column=1)
    operator = make_train_operator(torch_ds)
    trainer = TorchTrainer(
        training_operator_cls=operator,
        num_workers=2,
        use_local=use_local,
        add_dist_sampler=False,
        config={"batch_size": 32})
    for i in range(10):
        trainer.train(num_steps=100)

    model = trainer.get_model()
    prediction = float(model(torch.tensor([[0.5]]).float())[0][0])
    assert 0.4 <= prediction <= 0.6
    trainer.shutdown()


if __name__ == "__main__":
    import pytest
    import sys
137 python/ray/util/sgd/tf/tf_dataset.py Normal file
@@ -0,0 +1,137 @@
import logging
from typing import Any, List, Optional

import tensorflow as tf

from ray.util.data import MLDataset


class TFMLDataset:
    """A TFMLDataset converted from a MLDataset.

    .. code-block:: python

        ds = ml_dataset.to_tf(feature_columns=["x"], label_column="y")
        shard = ds.get_shard(0)  # the data as (x_value, y_value)

        ds = ml_dataset.to_tf(feature_columns=["x", "y"], label_column="z")
        shard = ds.get_shard(0)  # the data as ((x_value, y_value), z_value)


    Args:
        ds (MLDataset): a MLDataset
        feature_columns (List[Any]): the feature column names
        feature_shapes (Optional[List[tf.TensorShape]]): the shape for each
            feature. If provided, it should match the size of feature_columns
        feature_types (Optional[List[tf.DType]]): the data type for each
            feature. If provided, it should match the size of feature_columns
        label_column (Any): the label column name
        label_shape (Optional[tf.TensorShape]): the shape for the label data
        label_type (Optional[tf.DType]): the data type for the label data
    """

    def __init__(self, ds: MLDataset, feature_columns: List[Any],
                 feature_shapes: Optional[List[tf.TensorShape]],
                 feature_types: Optional[List[tf.DType]], label_column: Any,
                 label_shape: Optional[tf.TensorShape],
                 label_type: Optional[tf.DType]):

        self._feature_columns = feature_columns
        self._feature_shapes = feature_shapes
        self._feature_types = feature_types
        self._label_column = label_column
        self._label_shape = label_shape
        self._label_type = label_type

        self._check_and_convert()

        self._ds = ds

    def _check_and_convert(self):
        # convert to list for convenience
        if not isinstance(self._feature_columns, list):
            self._feature_columns = [self._feature_columns]

        if self._feature_shapes:
            if not isinstance(self._feature_shapes, list):
                self._feature_shapes = [self._feature_shapes]

            assert len(self._feature_columns) == len(self._feature_shapes), \
                "The feature_shapes size must match the feature_columns"

        if self._feature_types:
            if not isinstance(self._feature_types, list):
                self._feature_types = [self._feature_types]

            assert len(self._feature_columns) == len(self._feature_types), \
                "The feature_types size must match the feature_columns"
            for i in range(len(self._feature_types)):
                assert (all(isinstance(dtype, tf.DType)
                            for dtype in self._feature_types)), \
                    "All values in feature_types should be tf.DType instances"

        if not self._feature_shapes:
            self._feature_shapes = [tf.TensorShape(
                ([]))] * len(self._feature_columns)

        if not self._feature_types:
            self._feature_types = [tf.float32] * len(self._feature_columns)

        if not self._label_type:
            self._label_type = tf.float32

        if not self._label_shape:
            self._label_shape = tf.TensorShape(([]))

    def set_num_shards(self, num_shards):
        """Repartition the MLDataset to the given number of shards"""
        if num_shards != self._ds.num_shards():
            logging.info("Setting num shards to %s", num_shards)
            self._ds = self._ds.repartition(num_shards)

    def get_shard(self,
                  shard_index: int,
                  batch_ms: int = 0,
                  num_async: int = 1,
                  shuffle: bool = False,
                  shuffle_buffer_size: int = 1,
                  seed: int = None) -> "tf.data.Dataset":
        """Get the given shard data.

        Get the given shard data from the MLDataset and convert it into a
        tensorflow.data.Dataset. If shard_index is smaller than zero, all
        data is collected into one tensorflow.data.Dataset.
        """
        it = self._ds.get_repeatable_shard(shard_index, batch_ms, num_async,
                                           shuffle, shuffle_buffer_size, seed)

        def make_generator():
            for df in iter(it):
                num_rows = df.shape[0]
                feature_columns = [
                    df[col].values for col in self._feature_columns
                ]
                label_column = df[self._label_column].values
                for i in range(num_rows):
                    features = [f[i] for f in feature_columns]
                    if len(features) > 1:
                        yield tuple(features), label_column[i]
                    else:
                        yield features[0], label_column[i]

        output_shapes = self._feature_shapes.copy()
        if len(output_shapes) > 1:
            output_shapes = (tuple(output_shapes), self._label_shape)
        else:
            output_shapes = (output_shapes[0], self._label_shape)

        output_types = self._feature_types.copy()
        if len(output_types) > 1:
            output_types = (tuple(output_types), self._label_type)
        else:
            output_types = output_types[0], self._label_type
        ds = tf.data.Dataset.from_generator(
            make_generator,
            output_types=output_types,
            output_shapes=output_shapes)
        return ds
172 python/ray/util/sgd/torch/torch_dataset.py Normal file
@@ -0,0 +1,172 @@
import functools
import logging
from collections.abc import Iterable, Iterator
from typing import Any, Callable, List, Optional

import numpy as np
import torch
import pandas as pd
from torch.utils.data import IterableDataset

from ray.util.data import MLDataset


def convert_to_tensor(df, feature_columns: List[Any],
                      feature_shapes: List[Any],
                      feature_types: List[torch.dtype], label_column: Any,
                      label_shape: Optional[int], label_type: torch.dtype):
    feature_tensor = []
    for col, shape, dtype in zip(feature_columns, feature_shapes,
                                 feature_types):
        column = df[col].values
        if column.dtype == np.object:
            if isinstance(column[0], np.ndarray):
                column = np.stack(column)
            elif isinstance(column[0], (list, tuple)):
                column = list(column)
            else:
                raise Exception(
                    f"Column {col}'s type: {type(column[0])} is not supported."
                    " It must be a numpy built-in type or a numpy object of "
                    "(ndarray, list, tuple)")

        t = torch.as_tensor(column, dtype=dtype)
        if shape is not None:
            t = t.view(*(-1, *shape))
        else:
            t = t.view(-1, 1)
        feature_tensor.append(t)

    label_df = df[label_column].values
    label_tensor = torch.as_tensor(label_df, dtype=label_type)
    if label_shape:
        label_tensor = label_tensor.view(-1, label_shape)
    else:
        label_tensor = label_tensor.view(-1, 1)
    return feature_tensor, label_tensor


class TorchMLDataset:
    """A TorchMLDataset converted from a MLDataset.

    .. code-block:: python

        ds = ml_dataset.to_torch(feature_columns=["x"], label_column="y")
        shard = ds.get_shard(0)
        data = DataLoader(shard, batch_size=32)
        batch_tensor_x, batch_tensor_y = next(iter(data))

        ds = ml_dataset.to_torch(feature_columns=["x", "y"], label_column="z")
        shard = ds.get_shard(0)
        data = DataLoader(shard, batch_size=32)
        batch_tensor_x, batch_tensor_y, batch_tensor_z = next(iter(data))


    Args:
        ds (MLDataset): a MLDataset
        feature_columns (List[Any]): the feature column names
        feature_shapes (Optional[List[Any]]): the shape for each
            feature. If provided, it should match the size of feature_columns.
        feature_types (Optional[List[torch.dtype]]): the data type for each
            feature. If provided, it should match the size of feature_columns
        label_column (Any): the label column name
        label_shape (Optional[int]): the shape for the label data
        label_type (Optional[torch.dtype]): the data type for the label data
    """

    def __init__(self,
                 ds: MLDataset = None,
                 feature_columns: List[Any] = None,
                 feature_shapes: Optional[List[Any]] = None,
                 feature_types: Optional[List[torch.dtype]] = None,
                 label_column: Any = None,
                 label_shape: Optional[int] = None,
                 label_type: Optional[torch.dtype] = None):

        self._feature_columns = feature_columns
        self._feature_shapes = feature_shapes
        self._feature_types = feature_types
        self._label_column = label_column
        self._label_shape = label_shape
        self._label_type = label_type

        self._check_and_convert()

        self._ds = ds

    def _check_and_convert(self):
        # convert to list for convenience
        if not isinstance(self._feature_columns, list):
            self._feature_columns = [self._feature_columns]

        if self._feature_shapes:
            if not isinstance(self._feature_shapes, list):
                self._feature_shapes = [self._feature_shapes]

            assert len(self._feature_columns) == len(self._feature_shapes), \
                "The feature_shapes size must match the feature_columns"
            for i in range(len(self._feature_shapes)):
                if not isinstance(self._feature_shapes[i], Iterable):
                    self._feature_shapes[i] = [self._feature_shapes[i]]
        else:
            self._feature_shapes = [None] * len(self._feature_columns)

        if self._feature_types:
            if not isinstance(self._feature_types, list):
                self._feature_types = [self._feature_types]

            assert len(self._feature_columns) == len(self._feature_types), \
                "The feature_types size must match the feature_columns"
            for i in range(len(self._feature_types)):
                assert (all(isinstance(dtype, torch.dtype)
                            for dtype in self._feature_types)), \
                    "All values in feature_types should be torch.dtype instances"
        else:
            self._feature_types = [torch.float] * len(self._feature_columns)

        if not self._label_type:
            self._label_type = torch.float

    def set_num_shards(self, num_shards):
        """Reshards the iterator if necessary"""
        if num_shards != self._ds.num_shards():
            logging.info("Setting num shards to %s", num_shards)
            self._ds = self._ds.repartition(num_shards)

    def get_shard(self,
                  shard_index: int,
                  batch_ms: int = 0,
                  num_async: int = 1,
                  shuffle: bool = False,
                  shuffle_buffer_size: int = 1,
                  seed: int = None) -> torch.utils.data.IterableDataset:

        it = self._ds.get_repeatable_shard(shard_index, batch_ms, num_async,
                                           shuffle, shuffle_buffer_size, seed)
        convert_fn = functools.partial(
            convert_to_tensor,
            feature_columns=self._feature_columns,
            feature_shapes=self._feature_shapes,
            feature_types=self._feature_types,
            label_column=self._label_column,
            label_shape=self._label_shape,
            label_type=self._label_type)
        return TorchIterableDataset(it, convert_fn)


class TorchIterableDataset(IterableDataset):
    def __init__(self, it: Iterator,
                 convert_fn: Callable[[pd.DataFrame], Any]):
        super().__init__()
        self._it = it
        self._convert_fn = convert_fn

    def __iter__(self):
        for df in iter(self._it):
            num_rows = df.shape[0]
            feature_tensor, label_tensor = self._convert_fn(df)
            for i in range(num_rows):
                features = [tensor[i] for tensor in feature_tensor]
                label = label_tensor[i]
                yield (*features, label)