[data] MLDataset based on ParallelIterator (#11849)

This commit is contained in:
Xianyang Liu 2020-11-19 16:33:37 +08:00 committed by GitHub
parent 2fe1321c3f
commit 9481ecd180
13 changed files with 1265 additions and 13 deletions


@@ -26,6 +26,7 @@ py_test_module_list(
"test_error_ray_not_initialized.py",
"test_gcs_fault_tolerance.py",
"test_iter.py",
"test_mldataset.py",
],
size = "medium",
extra_srcs = SRCS,


@@ -0,0 +1,126 @@
import ray.util.iter as parallel_it
import ray.util.data as ml_data
import pytest
import pyarrow as pa
import pyarrow.parquet as pq
import pandas as pd
import os
def test_read_parquet(ray_start_regular_shared, tmp_path):
df1 = pd.DataFrame({"one": [1, 2, 3], "two": ["a", "b", "c"]})
table = pa.Table.from_pandas(df1)
pq.write_table(table, os.path.join(tmp_path, "test1.parquet"))
df2 = pd.DataFrame({"one": [4, 5, 6], "two": ["e", "f", "g"]})
table = pa.Table.from_pandas(df2)
pq.write_table(table, os.path.join(tmp_path, "test2.parquet"))
# without columns
ds = ml_data.read_parquet(tmp_path, num_shards=2)
result = list(ds.gather_sync())
assert df1.equals(result[0])
assert df2.equals(result[1])
# with columns one
ds = ml_data.read_parquet(tmp_path, num_shards=2, columns=["one"])
result = list(ds.gather_sync())
assert df1[["one"]].equals(result[0])
assert df2[["one"]].equals(result[1])
# with columns two
ds = ml_data.read_parquet(tmp_path, num_shards=2, columns=["two"])
result = list(ds.gather_sync())
assert df1[["two"]].equals(result[0])
assert df2[["two"]].equals(result[1])
def test_from_parallel_it(ray_start_regular_shared):
para_it = parallel_it.from_range(4).for_each(lambda x: [x])
ds = ml_data.from_parallel_iter(para_it, batch_size=2)
assert repr(ds) == ("MLDataset[from_range[4, shards=2]"
".for_each().batch(2).to_pandas()]")
collected = list(ds.gather_sync())
assert len(collected) == 2
assert all(d.shape == (2, 1) for d in collected)
expected = para_it.flatten().batch(2).gather_sync().flatten()
flattened = ds.gather_sync().for_each(lambda x: x[0].to_list()).flatten()
assert list(flattened) == list(expected)
def test_batch(ray_start_regular_shared):
para_it = parallel_it.from_range(16).for_each(lambda x: [x])
ds = ml_data.from_parallel_iter(para_it, batch_size=2)
collected = list(ds.gather_sync())
assert len(collected) == 8
assert all(d.shape == (2, 1) for d in collected)
ds = ds.batch(4)
assert repr(ds) == ("MLDataset[from_range[16, shards=2]"
".for_each().batch(2).to_pandas().batch(4)]")
collected = list(ds.gather_sync())
assert len(collected) == 4
assert all(d.shape == (4, 1) for d in collected)
expected = para_it.flatten().batch(4).gather_sync().flatten()
flattened = ds.gather_sync().for_each(lambda x: x[0].to_list()).flatten()
assert list(flattened) == list(expected)
def test_local_shuffle(ray_start_regular_shared):
para_it = parallel_it.from_range(100).for_each(lambda x: [x])
# batch_size larger than 1 and shuffle_buffer_size larger than 1
ds = ml_data.from_parallel_iter(para_it, batch_size=10)
ds1 = ds.local_shuffle(shuffle_buffer_size=5)
ds2 = ds.local_shuffle(shuffle_buffer_size=5)
l1 = list(ds1.gather_sync())
l2 = list(ds2.gather_sync())
assert not all(df1.equals(df2) for df1, df2 in zip(l1, l2))
# batch_size equals 1 and shuffle_buffer_size larger than 1
ds = ml_data.from_parallel_iter(para_it, batch_size=1)
ds1 = ds.local_shuffle(shuffle_buffer_size=5)
ds2 = ds.local_shuffle(shuffle_buffer_size=5)
l1 = list(ds1.gather_sync())
l2 = list(ds2.gather_sync())
assert not all(df1.equals(df2) for df1, df2 in zip(l1, l2))
# batch_size equals 1 and shuffle_buffer_size equals 1
ds = ml_data.from_parallel_iter(para_it, batch_size=1)
ds1 = ds.local_shuffle(shuffle_buffer_size=1)
ds2 = ds.local_shuffle(shuffle_buffer_size=1)
l1 = list(ds1.gather_sync())
l2 = list(ds2.gather_sync())
assert all(df1.equals(df2) for df1, df2 in zip(l1, l2))
def test_union(ray_start_regular_shared):
para_it1 = parallel_it.from_range(4, 2, False).for_each(lambda x: [x])
ds1 = ml_data.from_parallel_iter(para_it1, True, 2, False)
para_it2 = parallel_it.from_range(4, 2, True).for_each(lambda x: [x])
ds2 = ml_data.from_parallel_iter(para_it2, True, 2, True)
with pytest.raises(TypeError) as ex:
ds1.union(ds2)
assert "two MLDataset which have different repeated type" in str(ex.value)
# union two MLDataset with same batch size
para_it2 = parallel_it.from_range(4, 2, False).for_each(lambda x: [x])
ds2 = ml_data.from_parallel_iter(para_it2, True, 2, False)
ds = ds1.union(ds2)
assert ds.batch_size == 2
# union two MLDataset with different batch size
para_it2 = parallel_it.from_range(4, 2, False).for_each(lambda x: [x])
ds2 = ml_data.from_parallel_iter(para_it2, True, 1, False)
ds = ds1.union(ds2)
# batch_size 0 means batch_size unknown
assert ds.batch_size == 0
if __name__ == "__main__":
import sys
sys.exit(pytest.main(["-v", __file__]))


@@ -0,0 +1,93 @@
from collections import defaultdict
from typing import Iterable
import pandas as pd
from ray.util.data.dataset import MLDataset
from ray.util.data.parquet import read_parquet
from ray.util.iter import T, ParallelIterator
try:
import dataclasses
except ImportError:
pass
else:
from dataclasses import is_dataclass
def to_pandas(it: ParallelIterator[T],
batch_size: int = 32) -> "ParallelIterator[pd.DataFrame]":
"""Convert the a ParallelIterator to ParallelIterator of pd.DataFrame.
The record type should be list like object or dataclass instance. If
the record is a iterable, we will convert to a list. If the record has
__getitem__ attr, we will use __getitem__ to get the given column
indexes data to create pandas DataFrame. If the record is dataclass
instance we will use __getattr__ to get the given column.
Args:
it (ParallelIterator[T]): the ParallelIterator to converted
batch_size (int): batch the given size to create a pandas DataFrame
Returns:
A ParallelIterator of pd.DataFrame
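For example, a minimal sketch (the Point dataclass and its values are
illustrative):
.. code-block:: python
    from dataclasses import dataclass
    import ray.util.iter as parallel_it
    @dataclass
    class Point:
        x: int
        y: int
    it = parallel_it.from_items([Point(1, 2), Point(3, 4)], 1)
    df_it = to_pandas(it, batch_size=2)
    # each element of df_it is a pandas.DataFrame with columns x and y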
"""
it = it.batch(batch_size)
def convert_fn(input_it: Iterable[T]) -> Iterable[pd.DataFrame]:
names = []
for batch in input_it:
assert isinstance(batch, list)
if hasattr(batch[0], "__getitem__"):
batch = pd.DataFrame(batch)
elif hasattr(batch[0], "__iter__"):
batch = [list(item) for item in batch]
batch = pd.DataFrame(batch)
elif is_dataclass(batch[0]):
if not names:
names = [f.name for f in dataclasses.fields(batch[0])]
values = defaultdict(list)
for item in batch:
for col in names:
values[col].append(getattr(item, col))
batch = pd.DataFrame(values, columns=names)
else:
raise ValueError("MLDataset only support list like item or "
"dataclass instance")
yield batch
it = it._with_transform(lambda local_it: local_it.transform(convert_fn),
".to_pandas()")
return it
def from_parallel_iter(para_it: ParallelIterator[T],
need_convert: bool = True,
batch_size: int = 32,
repeated: bool = False) -> MLDataset:
"""Create a MLDataset from an existing ParallelIterator.
The object of the ParallelIterator should be list like object or dataclass
instance.
Args:
para_it (ParallelIterator[T]): An existing parallel iterator, and each
should be a list like object or dataclass instance.
need_convert (bool): whether need to convert to pandas.DataFrame. This
should be False if the record type is pandas.DataFrame.
batch_size (int): if need_convert is True, we will batch the batch_size
records to a pandas.DataFrame
repeated (bool): whether the para_it is repeated.
Returns:
a MLDataset
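For example, a minimal sketch (the range size, shard count and batch
size are illustrative):
.. code-block:: python
    import ray.util.iter as parallel_it
    import ray.util.data as ml_data
    it = parallel_it.from_range(8, 2).for_each(lambda x: [x])
    ds = ml_data.from_parallel_iter(it, batch_size=4)
    for df in ds.gather_sync():
        print(df.shape)  # each item is a (4, 1) pandas.DataFrame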
"""
if need_convert:
para_it = to_pandas(para_it, batch_size)
else:
batch_size = 0
return MLDataset.from_parallel_it(para_it, batch_size, repeated)
__all__ = ["from_parallel_iter", "read_parquet", "MLDataset"]


@@ -0,0 +1,386 @@
import random
from typing import Callable, List, Iterable, Iterator
import pandas as pd
from ray.util.iter import (_NextValueNotReady, LocalIterator, ParallelIterator,
T, U)
class MLDataset(ParallelIterator[pd.DataFrame]):
"""A distributed ML dataset implemented based on ParallelIterator
All item should be a list like object or dataclass instance.
Args:
batch_size (int): The batch size of the current dataset. It should be
larger than zero, and 0 means unknown.
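For example, a minimal sketch of a typical pipeline (shard count and
batch sizes are illustrative):
.. code-block:: python
    import ray.util.iter as parallel_it
    import ray.util.data as ml_data
    it = parallel_it.from_range(8, 2).for_each(lambda x: [x])
    ds = ml_data.from_parallel_iter(it, batch_size=4)
    ds = ds.batch(2)  # rebatch to 2 rows per pandas.DataFrame
    shard = ds.get_repeatable_shard(0)  # repeatable iterator over shard 0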
"""
def __init__(self, actor_sets: List["_ActorSet"], name: str,
parent_iterators: List[ParallelIterator[pd.DataFrame]],
batch_size: int, repeated: bool):
super(MLDataset, self).__init__(actor_sets, name, parent_iterators)
self._batch_size = batch_size
self._repeated = repeated
@staticmethod
def from_parallel_it(para_it: ParallelIterator[pd.DataFrame],
batch_size: int,
repeated: bool = False) -> "MLDataset":
"""Create a MLDataset from an parallel iterator
The record of ParallelIterator should be pandas.DataFrame.
Args:
para_it (ParallelIterator[T]): An existing parallel iterator,
and each should be a list like object or dataclass instance
batch_size (int): The batch size of the current dataset. It
should be larger than zero, and 0 means unknown.
repeated (bool): whether the para_it is repeated.
Returns:
A MLDataset
"""
return MLDataset(para_it.actor_sets, para_it.name,
para_it.parent_iterators, batch_size, repeated)
def __iter__(self):
raise TypeError(
"You must use it.gather_sync() or it.gather_async() to "
"iterate over the results of a MLDataset.")
def __str__(self):
return repr(self)
def __repr__(self):
return f"MLDataset[{self.name}]"
def _with_transform(self, local_it_fn, name) -> "MLDataset":
"""Helper function to create new MLDataset"""
para_it = super()._with_transform(local_it_fn, name)
return MLDataset.from_parallel_it(para_it, self._batch_size,
self._repeated)
def transform(
self,
fn: Callable[[Iterable[pd.DataFrame]], Iterable[pd.DataFrame]]
) -> "MLDataset":
"""Apply the fn function to the MLDataset
Args:
fn (Callable[[Iterable[DataFrame]], Iterable[DataFrame]]):
The function to applied. The input is a iterator of
pandas.DataFrame, and the output should also be a iterator of
pandas.DataFrame.
Returns:
A new MLDataset
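For example, a minimal sketch (assumes the DataFrames hold numeric
columns so that adding a constant is well defined):
.. code-block:: python
    def add_one(dfs):
        for df in dfs:
            yield df + 1
    ds2 = ds.transform(add_one)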
"""
return self._with_transform(lambda local_it: local_it.transform(fn),
".transform()")
def batch(self, batch_size: int) -> "MLDataset":
"""Rebatch the number of rows for each pandas.DataFrame record
Unlike the ParallelIterator.batch. This method rebatch the underlying
the pandas DataFrame, and each pandas DataFrame will have batch_size
rows.
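For example, a minimal sketch mirroring the tests in this change
(assumes the usual ray.util.iter / ray.util.data imports):
.. code-block:: python
    it = parallel_it.from_range(16, 2).for_each(lambda x: [x])
    ds = ml_data.from_parallel_iter(it, batch_size=2)  # 2-row DataFrames
    ds = ds.batch(4)  # each DataFrame now has 4 rows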
"""
if batch_size == self._batch_size:
return self
def batch_fn(it: Iterable[pd.DataFrame]) -> Iterable[pd.DataFrame]:
it = iter(it)
return_df = None
while True:
try:
cur_df = next(it)
cur_index = 0
cur_size = cur_df.shape[0]
while cur_df is not None or (
cur_index + batch_size) < cur_size:
if cur_df is None or cur_index == cur_size:
cur_df = next(it)
cur_index = 0
cur_size = cur_df.shape[0]
if return_df is not None:
ri = cur_index + batch_size - return_df.shape[0]
ri = min(ri, cur_size)
tmp = cur_df.iloc[cur_index:ri]
return_df = pd.concat([return_df, tmp])
cur_index = ri
else:
ri = cur_index + batch_size
ri = min(ri, cur_size)
return_df = cur_df.iloc[cur_index:ri]
cur_index = ri
if return_df.shape[0] == batch_size:
return_df.index = range(return_df.shape[0])
yield return_df
return_df = None
except StopIteration:
break
if return_df is not None:
return_df.index = range(return_df.shape[0])
yield return_df
new_ds = self._with_transform(
lambda local_it: local_it.transform(batch_fn),
f".batch({batch_size})")
# record the new batch size on the returned dataset instead of
# mutating the current dataset
new_ds._batch_size = batch_size
return new_ds
def flatten(self) -> "MLDataset":
raise Exception("Unsupported operation")
def combine(self, fn: Callable[[T], List[U]]) -> "MLDataset":
raise Exception("Unsupported operation")
@property
def repeated(self) -> bool:
return self._repeated
@property
def batch_size(self) -> int:
return self._batch_size
def local_shuffle(self, shuffle_buffer_size: int,
seed: int = None) -> "MLDataset":
"""Applying local shuffle
Unlike the ParallelIterator.local_shuffle. This shuffle will first
apply the local_shuffle for each shards and then shuffle the each
pandas DataFrame.
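For example, a minimal sketch (buffer size and seed are illustrative;
para_it is an existing ParallelIterator of list-like records):
.. code-block:: python
    ds = ml_data.from_parallel_iter(para_it, batch_size=10)
    shuffled = ds.local_shuffle(shuffle_buffer_size=5, seed=42)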
"""
ds = super().local_shuffle(shuffle_buffer_size, seed)
def shuffle_fn(it: Iterable[pd.DataFrame]) -> Iterable[pd.DataFrame]:
for df in it:
df = df.sample(frac=1, random_state=seed)
yield df
ds = ds._with_transform(
lambda local_it: local_it.transform(shuffle_fn),
".inner_pandas_shuffle()")
return ds
def repartition(self, num_partitions: int,
batch_ms: int = 0) -> "MLDataset":
"""see ParallelIterator.repartition"""
if num_partitions == self.num_shards():
return self
para_it = super().repartition(num_partitions, batch_ms)
return MLDataset.from_parallel_it(para_it, self._batch_size,
self._repeated)
def union(self, other: "MLDataset") -> "MLDataset":
"""Return an iterator that is the union of this and the other."""
if not isinstance(other, MLDataset):
raise TypeError(
f"other must be of type MLDataset, got {type(other)}")
if self._repeated != other.repeated:
raise TypeError(
f"want to union two MLDataset which have different repeated "
f"type, self repeated: {self._repeated}, other repeated: "
f"{other.repeated}")
batch_size = 0
if self._batch_size == other._batch_size:
batch_size = self._batch_size
actor_sets = []
actor_sets.extend(self.actor_sets)
actor_sets.extend(other.actor_sets)
# if one of these iterators is a result of a repartition, we need to
# keep an explicit reference to its parent iterator
return MLDataset(
actor_sets,
f"ParallelUnion[{self}, {other}]",
parent_iterators=self.parent_iterators + other.parent_iterators,
batch_size=batch_size,
repeated=self._repeated)
def select_shards(self, shards_to_keep: List[int]) -> "MLDataset":
para_it = super().select_shards(shards_to_keep)
return MLDataset.from_parallel_it(para_it, self._batch_size,
self._repeated)
def get_repeatable_shard(self,
index: int,
batch_ms: int = 0,
num_async: int = 1,
shuffle: bool = False,
shuffle_buffer_size: int = 1,
seed: int = None) -> Iterator:
"""Get the given shard of the current dataset.
The return is a iterator. Each call iter on the returned iterator will
get the shard data from beginning. And it support shuffle the return
iterator when each call iter on the return.
Args:
index (int): the shard index id, -1 means collect all data.
batch_ms (int): Batches items for batch_ms milliseconds
before retrieving it. Increasing batch_ms increases latency
but improves throughput. If this value is 0, then items are
returned immediately.
num_async (int): The max number of requests in flight. Increasing
this improves the amount of pipeline parallelism in the
iterator.
shuffle (bool): whether to shuffle the given shard data
shuffle_buffer_size (int): same as ParallelIterator.local_shuffle
seed (int): the random seed
Returns:
The given shard iterator. If shuffle is True, each call to iter
returns the items in a different order.
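For example, a minimal sketch (shard index and shuffle settings are
illustrative):
.. code-block:: python
    shard = ds.get_repeatable_shard(
        0, shuffle=True, shuffle_buffer_size=4)
    for df in iter(shard):  # first pass over the shard
        ...
    for df in iter(shard):  # second pass, rows in a different order
        ...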
"""
return _RepeatableIterator(self, index, batch_ms, num_async, shuffle,
shuffle_buffer_size, seed)
def to_torch(self,
feature_columns=None,
feature_shapes=None,
feature_types=None,
label_column=None,
label_shape=None,
label_type=None):
"""Create a TorchMLDataset from the current MLDataset.
Args:
feature_columns (List[Any]): the names (or indexes) of the
feature columns.
feature_shapes (Optional[List[Any]]): the feature shapes; should
match the feature columns if provided.
feature_types (Optional[List["torch.dtype"]]): the feature types;
should match the feature columns if provided. All features are
cast to torch.float by default; otherwise they are cast to the
provided types.
label_column (Any): the label column name.
label_shape (Optional[int]): the label shape.
label_type (Optional["torch.dtype"]): the label type; the label is
cast to torch.float by default.
Returns:
A TorchMLDataset
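For example, a minimal sketch (assumes torch.utils.data.DataLoader is
imported; the column indexes follow the mlp_identity examples in this
change):
.. code-block:: python
    torch_ds = ds.to_torch(feature_columns=[0], label_column=1)
    shard = torch_ds.get_shard(0)
    loader = DataLoader(shard, batch_size=32)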
"""
from ray.util.sgd.torch.torch_dataset import TorchMLDataset
return TorchMLDataset(self, feature_columns, feature_shapes,
feature_types, label_column, label_shape,
label_type)
def to_tf(self,
feature_columns=None,
feature_shapes=None,
feature_types=None,
label_column=None,
label_shape=None,
label_type=None):
"""Create a TFMLDataset from the current MLDataset.
Args:
feature_columns (List[Any]): the feature column names.
feature_shapes (Optional[List[tf.TensorShape]]): the feature shapes;
should match the feature columns if provided.
feature_types (Optional[List["tf.DType"]]): the feature types;
should match the feature columns if provided. All features are
cast to tf.float32 by default; otherwise they are cast to the
provided types.
label_column (Any): the label column name.
label_shape (Optional[tf.TensorShape]): the label shape.
label_type (Optional["tf.DType"]): the label type; the label is
cast to tf.float32 by default.
Returns:
A TFMLDataset
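For example, a minimal sketch (assumes tensorflow is imported as tf;
the column indexes follow the mlp_identity examples in this change):
.. code-block:: python
    tf_ds = ds.to_tf(
        feature_columns=[0],
        feature_types=[tf.float32],
        label_column=1,
        label_type=tf.float32)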
"""
from ray.util.sgd.tf.tf_dataset import TFMLDataset
return TFMLDataset(self, feature_columns, feature_shapes,
feature_types, label_column, label_shape,
label_type)
class _RepeatableIterator(Iterator[T]):
"""A repeatable iterator for the given shard index data.
Each call iter(_RepeatableIterator instance) will fetch the data from
beginning and will return a different order or data if set shuffle
Args:
ds (MLDataset): a MLDataset
shard_index (int): the shard index id. -1 means collect all data.
batch_ms (int): Batches items for batch_ms milliseconds
before retrieving it. Increasing batch_ms increases latency
but improves throughput. If this value is 0, then items are
returned immediately.
num_async (int): The max number of requests in flight. Increasing this
improves the amount of pipeline parallelism in the iterator.
shuffle (bool): whether to shuffle the given shard data
shuffle_buffer_size (int): same as ParallelIterator.local_shuffle
seed (int): the random seed
"""
def __init__(self,
ds: MLDataset,
shard_index: int,
batch_ms: int = 0,
num_async: int = 1,
shuffle: bool = False,
shuffle_buffer_size: int = 1,
seed: int = None):
super(_RepeatableIterator, self).__init__()
self._ds = ds
self._shard_index = shard_index
self._batch_ms = batch_ms
self._num_async = num_async
self._shuffle = shuffle
self._shuffle_buffer_size = shuffle_buffer_size
self._seed = seed
self._local_it: LocalIterator[T] = None
self._i = 0
def __next__(self) -> T:
assert self._local_it is not None
return next(self._local_it)
def __iter__(self) -> Iterator[T]:
if self._shard_index >= 0:
it = self._ds.get_shard(self._shard_index, self._batch_ms,
self._num_async)
else:
if self._num_async > 0:
it = self._ds.gather_async(
batch_ms=self._batch_ms, num_async=self._num_async)
else:
it = self._ds.gather_sync()
if self._shuffle:
it = self.shuffle(it)
self._local_it = it
return self
def shuffle(self,
local_it: LocalIterator[T]) -> LocalIterator[pd.DataFrame]:
shuffle_random = random.Random(self._seed)
def apply_shuffle(it):
buffer = []
for item in it:
if isinstance(item, _NextValueNotReady):
yield item
else:
buffer.append(item)
if len(buffer) >= self._shuffle_buffer_size:
item = buffer.pop(
shuffle_random.randint(0,
len(buffer) - 1))
item = item.sample(frac=1, random_state=self._seed)
yield item
while len(buffer) > 0:
item = buffer.pop(shuffle_random.randint(0, len(buffer) - 1))
item = item.sample(frac=1, random_state=self._seed)
yield item
return LocalIterator(
local_it.base_iterator,
local_it.shared_metrics,
local_it.local_transforms + [apply_shuffle],
name=local_it.name +
".shuffle(shuffle_buffer_size={}, seed={})".format(
self._shuffle_buffer_size,
str(self._seed) if self._seed is not None else "None"))


@@ -0,0 +1,66 @@
import json
import os
import ray
import ray.util.data as ml_data
import ray.util.iter as parallel_it
from ray.util.sgd.tf.tf_dataset import TFMLDataset
from ray.util.sgd.tf.tf_trainer import TFTrainer
def model_creator(config):
import tensorflow as tf
model = tf.keras.models.Sequential([
tf.keras.Input(shape=(1, )),
tf.keras.layers.Dense(128, activation="relu"),
tf.keras.layers.Dense(1)
])
optimizer = tf.keras.optimizers.Adam(lr=1e-4)
model.compile(optimizer=optimizer, loss="mse", metrics=["accuracy"])
return model
def make_data_creator(tf_ds: TFMLDataset):
def data_creator(config):
world_rank = None
if "TF_CONFIG" in os.environ:
tf_config = json.loads(os.environ["TF_CONFIG"])
world_rank = tf_config["task"]["index"]
else:
world_rank = -1
batch_size = config["batch_size"]
ds = tf_ds.get_shard(shard_index=world_rank).batch(batch_size).repeat()
return ds, None
return data_creator
def main():
num_points = 32 * 100 * 2
data = [i * (1 / num_points) for i in range(num_points)]
it = parallel_it.from_items(data, 2, False).for_each(lambda x: [x, x])
# this creates an MLDataset with columns RangeIndex(range(2))
ds = ml_data.from_parallel_iter(it, True, batch_size=32, repeated=False)
tf_ds = ds.to_tf(feature_columns=[0], label_column=1)
trainer = TFTrainer(
model_creator=model_creator,
data_creator=make_data_creator(tf_ds),
num_replicas=2,
config={
"batch_size": 32,
"fit_config": {
"steps_per_epoch": 100,
}
})
for _ in range(10):
trainer.train()
model = trainer.get_model()
print("f(0.5)=", float(model.predict([0.5])))
if __name__ == "__main__":
ray.init()
main()


@@ -0,0 +1,70 @@
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader
import ray
import ray.util.data as ml_data
import ray.util.iter as parallel_it
from ray.util.sgd.torch.torch_dataset import TorchMLDataset
from ray.util.sgd.torch.torch_trainer import TorchTrainer
from ray.util.sgd.torch.training_operator import TrainingOperator
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.fc1 = nn.Linear(1, 128)
self.fc2 = nn.Linear(128, 1)
def forward(self, x):
x = self.fc1(x)
x = F.relu(x)
x = self.fc2(x)
return x
def make_train_operator(ds: TorchMLDataset):
class IdentityTrainOperator(TrainingOperator):
def setup(self, config):
model = Net()
optimizer = torch.optim.SGD(
model.parameters(), lr=config.get("lr", 1e-4))
loss = torch.nn.MSELoss()
batch_size = config["batch_size"]
train_data = ds.get_shard(
self.world_rank, shuffle=True, shuffle_buffer_size=4)
train_loader = DataLoader(train_data, batch_size=batch_size)
self.model, self.optimizer, self.criterion = self.register(
models=model, optimizers=optimizer, criterion=loss)
self.register_data(
train_loader=train_loader, validation_loader=None)
return IdentityTrainOperator
def main():
num_points = 32 * 100 * 2
data = [i * (1 / num_points) for i in range(num_points)]
it = parallel_it.from_items(data, 2, False).for_each(lambda x: [x, x])
# this creates an MLDataset with columns RangeIndex(range(2))
ds = ml_data.from_parallel_iter(it, True, batch_size=32, repeated=False)
torch_ds = ds.to_torch(feature_columns=[0], label_column=1)
trainer = TorchTrainer(
num_workers=2,
training_operator_cls=make_train_operator(torch_ds),
add_dist_sampler=False,
config={"batch_size": 32})
for i in range(10):
trainer.train(num_steps=100)
model = trainer.get_model()
print("f(0.5)=", float(model(torch.tensor([[0.5]]).float())[0][0]))
if __name__ == "__main__":
ray.init()
main()


@@ -0,0 +1,21 @@
from typing import Iterable
import pandas as pd
class _SourceShard:
def prefix(self) -> str:
raise NotImplementedError
@property
def shard_id(self) -> int:
raise NotImplementedError
def __iter__(self) -> Iterable[pd.DataFrame]:
raise NotImplementedError
def __str__(self):
return repr(self)
def __repr__(self):
return f"{self.prefix()}SourceShard[{self.shard_id}]"


@@ -0,0 +1,111 @@
import random
from typing import Iterable
from typing import List, Optional, Union
import pyarrow.parquet as pq
from pandas import DataFrame
import ray.util.iter as para_iter
from .dataset import MLDataset
from .interface import _SourceShard
class ParquetSourceShard(_SourceShard):
def __init__(self, data_pieces: List[pq.ParquetDatasetPiece],
columns: Optional[List[str]],
partitions: Optional[pq.ParquetPartitions], shard_id: int):
self._data_pieces = data_pieces
self._columns = columns
self._partitions = partitions
self._shard_id = shard_id
def prefix(self) -> str:
return "Parquet"
@property
def shard_id(self) -> int:
return self._shard_id
def __iter__(self) -> Iterable[DataFrame]:
for piece in self._data_pieces:
yield piece.read(
columns=self._columns,
use_threads=False,
partitions=self._partitions).to_pandas()
def read_parquet(paths: Union[str, List[str]],
num_shards: int,
rowgroup_split: bool = True,
shuffle: bool = False,
shuffle_seed: int = None,
columns: Optional[List[str]] = None,
**kwargs) -> MLDataset:
"""Read parquet format data from hdfs like filesystem into a MLDataset.
.. code-block:: python
# create dummy data
spark.range(...).write.parquet(...)
# create MLDataset
data = ray.util.data.read_parquet(...)
# convert to TorchMLDataset
ds = data.to_torch(feature_columns=..., label_column=...)
# get the given shard data
shard = ds.get_shard(0)
# create the DataLoader from the shard data and this can be used for
# the TorchTrainer data creator as well
data = DataLoader(shard, batch_size=32)
Args:
paths (Union[str, List[str]]): a single file path or a list of file paths
num_shards (int): the number of shards
rowgroup_split (bool): whether to split the files into shards based
on row groups. If False, each shard will contain a list of files.
shuffle (bool): whether to shuffle the order of the
ParquetDatasetPieces when dividing them into shards
shuffle_seed (int): the shuffle seed
columns (Optional[List[str]]): a list of column names to read
kwargs: the other parquet read options
Returns:
A MLDataset
"""
pq_ds = pq.ParquetDataset(paths, **kwargs)
pieces = pq_ds.pieces
data_pieces = []
if rowgroup_split:
# split based on row groups
for piece in pieces:
num_row_groups = piece.get_metadata().to_dict()["num_row_groups"]
for i in range(num_row_groups):
data_pieces.append(
pq.ParquetDatasetPiece(piece.path, piece.open_file_func,
piece.file_options, i,
piece.partition_keys))
else:
# split based on file pieces
data_pieces = pieces.copy()
if len(data_pieces) < num_shards:
raise ValueError(f"number of data pieces: {len(data_pieces)} should "
f"larger than num_shards: {num_shards}")
if shuffle:
random_shuffle = random.Random(shuffle_seed)
random_shuffle.shuffle(data_pieces)
shards = [[] for _ in range(num_shards)]
for i, item in enumerate(data_pieces):
shard = shards[i % num_shards]
if item.row_group is None:
num_row_groups = item.get_metadata().to_dict()["num_row_groups"]
for number in range(num_row_groups):
shard.append(
pq.ParquetDatasetPiece(item.path, item.open_file_func,
item.file_options, number,
item.partition_keys))
else:
shard.append(item)
for i, shard in enumerate(shards):
shards[i] = ParquetSourceShard(shard, columns, pq_ds.partitions, i)
it = para_iter.from_iterators(shards, False, "parquet")
return MLDataset.from_parallel_it(it, batch_size=0, repeated=False)


@@ -1,13 +1,18 @@
import os
import pytest
import tempfile
import numpy as np
import shutil
import tempfile
import numpy as np
import pytest
import ray
import ray.util.data as ml_data
import ray.util.iter as parallel_it
from ray import tune
from ray.tests.conftest import ray_start_2_cpus # noqa: F401
from ray.util.data.examples.mlp_identity_tf import (model_creator,
make_data_creator)
from ray.util.sgd.tf import TFTrainer, TFTrainable
from ray.util.sgd.tf.examples.tensorflow_train_example import (simple_model,
simple_dataset)
@@ -22,6 +27,14 @@ SIMPLE_CONFIG = {
}
@pytest.fixture
def ray_start_4_cpus():
address_info = ray.init(num_cpus=4)
yield address_info
# The code after the yield will run as teardown code.
ray.shutdown()
@pytest.mark.parametrize( # noqa: F811
"num_replicas", [1, 2])
def test_train(ray_start_2_cpus, num_replicas): # noqa: F811
@@ -99,6 +112,35 @@ def test_save_and_restore(ray_start_2_cpus, num_replicas): # noqa: F811
model2.optimizer.get_weights()
@pytest.mark.parametrize( # noqa: F811
"num_replicas", [1, 2])
def test_tf_dataset(ray_start_4_cpus): # noqa: F811
num_points = 32 * 100 * 2
data = [i * (1 / num_points) for i in range(num_points)]
it = parallel_it.from_items(data, 2, False).for_each(lambda x: [x, x])
# this creates an MLDataset with columns RangeIndex(range(2))
ds = ml_data.from_parallel_iter(it, True, batch_size=32, repeated=False)
tf_ds = ds.to_tf(feature_columns=[0], label_column=1)
trainer = TFTrainer(
model_creator=model_creator,
data_creator=make_data_creator(tf_ds),
num_replicas=2,
config={
"batch_size": 32,
"fit_config": {
"steps_per_epoch": 100,
}
})
for _ in range(10):
trainer.train()
model = trainer.get_model()
prediction = model.predict([0.5])[0][0]
assert 0.4 <= prediction <= 0.6
trainer.shutdown()
def _compare(d1, d2, skip_keys=None):
"""Compare two lists or dictionaries or array"""
if type(d1) != type(d2):


@@ -1,23 +1,26 @@
import numpy as np
import os
import numpy as np
import pytest
import torch
import torch.nn as nn
import torch.distributed as dist
from ray.tune.utils import merge_dicts
import torch.nn as nn
from torch.utils.data import DataLoader
import ray
import ray.util.data as ml_data
import ray.util.iter as parallel_it
from ray import tune
from ray.util.sgd.torch import TorchTrainer
from ray.util.sgd.torch.training_operator import (
get_test_operator, get_test_metrics_operator, TrainingOperator)
from ray.util.sgd.torch.constants import SCHEDULER_STEP
from ray.util.sgd.utils import (NUM_SAMPLES, BATCH_COUNT, BATCH_SIZE)
from ray.tune.utils import merge_dicts
from ray.util.data.examples.mlp_identity_torch import make_train_operator
from ray.util.sgd.data.examples import mlp_identity
from ray.util.sgd.torch import TorchTrainer
from ray.util.sgd.torch.constants import SCHEDULER_STEP
from ray.util.sgd.torch.examples.train_example import (
model_creator, optimizer_creator, data_creator, LinearDataset)
from ray.util.sgd.torch.training_operator import (
get_test_operator, get_test_metrics_operator, TrainingOperator)
from ray.util.sgd.utils import (NUM_SAMPLES, BATCH_COUNT, BATCH_SIZE)
@pytest.fixture
@@ -834,6 +837,30 @@ def test_multi_input_model(ray_start_2_cpus, use_local):
trainer.shutdown()
@pytest.mark.parametrize("use_local", [True, False])
def test_torch_dataset(ray_start_4_cpus, use_local):
num_points = 32 * 100 * 2
data = [i * (1 / num_points) for i in range(num_points)]
para_it = parallel_it.from_items(data, 2, False).for_each(lambda x: [x, x])
ds = ml_data.from_parallel_iter(para_it, batch_size=32)
torch_ds = ds.to_torch(feature_columns=[0], label_column=1)
operator = make_train_operator(torch_ds)
trainer = TorchTrainer(
training_operator_cls=operator,
num_workers=2,
use_local=use_local,
add_dist_sampler=False,
config={"batch_size": 32})
for i in range(10):
trainer.train(num_steps=100)
model = trainer.get_model()
prediction = float(model(torch.tensor([[0.5]]).float())[0][0])
assert 0.4 <= prediction <= 0.6
trainer.shutdown()
if __name__ == "__main__":
import pytest
import sys


@@ -0,0 +1,137 @@
import logging
from typing import Any, List, Optional
import tensorflow as tf
from ray.util.data import MLDataset
class TFMLDataset:
""" A TFMLDataset which converted from MLDataset
.. code-block:: python
ds = ml_dataset.to_tf(feature_columns=["x"], label_column="y")
shard = ds.get_shard(0) # the data as (x_value, y_value)
ds = ml_dataset.to_tf(feature_columns=["x", "y"], label_column="z")
shard = ds.get_shard(0) # the data as ((x_value, y_value), z_value)
Args:
ds (MLDataset): a MLDataset
feature_columns (List[Any]): the feature column names
feature_shapes (Optional[List[tf.TensorShape]]): the shape of each
feature. If provided, it should match the size of feature_columns
feature_types (Optional[List[tf.DType]]): the data type of each
feature. If provided, it should match the size of feature_columns
label_column (Any): the label column name
label_shape (Optional[tf.TensorShape]): the shape for the label data
label_type (Optional[tf.DType]): the data type for the label data
"""
def __init__(self, ds: MLDataset, feature_columns: List[Any],
feature_shapes: Optional[List[tf.TensorShape]],
feature_types: Optional[List[tf.DType]], label_column: Any,
label_shape: Optional[tf.TensorShape],
label_type: Optional[tf.DType]):
self._feature_columns = feature_columns
self._feature_shapes = feature_shapes
self._feature_types = feature_types
self._label_column = label_column
self._label_shape = label_shape
self._label_type = label_type
self._check_and_convert()
self._ds = ds
def _check_and_convert(self):
# convert to list for convenience
if not isinstance(self._feature_columns, list):
self._feature_columns = [self._feature_columns]
if self._feature_shapes:
if not isinstance(self._feature_shapes, list):
self._feature_shapes = [self._feature_shapes]
assert len(self._feature_columns) == len(self._feature_shapes), \
"The feature_shapes size must match the feature_columns"
if self._feature_types:
if not isinstance(self._feature_types, list):
self._feature_types = [self._feature_types]
assert len(self._feature_columns) == len(self._feature_types), \
"The feature_types size must match the feature_columns"
assert all(isinstance(dtype, tf.DType)
for dtype in self._feature_types), \
"All values in feature_types should be tf.DType instances"
if not self._feature_shapes:
self._feature_shapes = [tf.TensorShape(
([]))] * len(self._feature_columns)
if not self._feature_types:
self._feature_types = [tf.float32] * len(self._feature_columns)
if not self._label_type:
self._label_type = tf.float32
if not self._label_shape:
self._label_shape = tf.TensorShape(([]))
def set_num_shards(self, num_shards):
""" Repartition the MLDataset to given number of shards"""
if num_shards != self._ds.num_shards():
logging.info("Setting num shards", num_shards)
self._ds = self._ds.repartition(num_shards)
def get_shard(self,
shard_index: int,
batch_ms: int = 0,
num_async: int = 1,
shuffle: bool = False,
shuffle_buffer_size: int = 1,
seed: int = None) -> "tf.data.Dataset":
""" Get the given shard data.
Get a the given shard data from MLDataset and convert into a
tensorflow.data.Dataset. If the shard_index is smaller than zero,
it will collect all data as a tensorflow.data.Dataset.
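For example, a minimal sketch (the keras model and shard index are
illustrative; see the mlp_identity_tf example in this change):
.. code-block:: python
    data = tf_ds.get_shard(shard_index=0).batch(32).repeat()
    model.fit(data, steps_per_epoch=100)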
"""
it = self._ds.get_repeatable_shard(shard_index, batch_ms, num_async,
shuffle, shuffle_buffer_size, seed)
def make_generator():
for df in iter(it):
num_rows = df.shape[0]
feature_columns = [
df[col].values for col in self._feature_columns
]
label_column = df[self._label_column].values
for i in range(num_rows):
features = [f[i] for f in feature_columns]
if len(features) > 1:
yield tuple(features), label_column[i]
else:
yield features[0], label_column[i]
output_shapes = self._feature_shapes.copy()
if len(output_shapes) > 1:
output_shapes = (tuple(output_shapes), self._label_shape)
else:
output_shapes = (output_shapes[0], self._label_shape)
output_types = self._feature_types.copy()
if len(output_types) > 1:
output_types = (tuple(output_types), self._label_type)
else:
output_types = output_types[0], self._label_type
ds = tf.data.Dataset.from_generator(
make_generator,
output_types=output_types,
output_shapes=output_shapes)
return ds


@@ -0,0 +1,172 @@
import functools
import logging
from collections.abc import Iterable, Iterator
from typing import Any, Callable, List, Optional
import numpy as np
import torch
import pandas as pd
from torch.utils.data import IterableDataset
from ray.util.data import MLDataset
def convert_to_tensor(df, feature_columns: List[Any],
feature_shapes: List[Any],
feature_types: List[torch.dtype], label_column: Any,
label_shape: Optional[int], label_type: torch.dtype):
feature_tensor = []
for col, shape, dtype in zip(feature_columns, feature_shapes,
feature_types):
column = df[col].values
if column.dtype == np.object:
if isinstance(column[0], np.ndarray):
column = np.stack(column)
elif isinstance(column[0], (list, tuple)):
column = list(column)
else:
raise Exception(
f"Column {col}'s type: {type(column[0])} is not supported."
" It must be numpy built in type or numpy object of "
"(ndarray, list, tuple)")
t = torch.as_tensor(column, dtype=dtype)
if shape is not None:
t = t.view(*(-1, *shape))
else:
t = t.view(-1, 1)
feature_tensor.append(t)
label_df = df[label_column].values
label_tensor = torch.as_tensor(label_df, dtype=label_type)
if label_shape:
label_tensor = label_tensor.view(-1, label_shape)
else:
label_tensor = label_tensor.view(-1, 1)
return feature_tensor, label_tensor
class TorchMLDataset:
"""A TorchMLDataset which converted from MLDataset
.. code-block:: python
ds = ml_dataset.to_torch(feature_columns=["x"], label_column="y")
shard = ds.get_shard(0)
data = DataLoader(shard, batch_size=32)
batch_tensor_x, batch_tensor_y = next(iter(data))
ds = ml_dataset.to_torch(feature_columns=["x", "y"], label_column="z")
shard = ds.get_shard(0)
data = DataLoader(shard, batch_size=32)
batch_tensor_x, batch_tensor_y, batch_tensor_z = next(iter(data))
Args:
ds (MLDataset): a MLDataset
feature_columns (List[Any]): the feature column names
feature_shapes (Optional[List[Any]]): the shape of each
feature. If provided, it should match the size of feature_columns.
feature_types (Optional[List[torch.dtype]]): the data type of each
feature. If provided, it should match the size of feature_columns
label_column (Any): the label column name
label_shape (Optional[int]): the shape for the label data
label_type (Optional[torch.dtype]): the data type for the label data
"""
def __init__(self,
ds: MLDataset = None,
feature_columns: List[Any] = None,
feature_shapes: Optional[List[Any]] = None,
feature_types: Optional[List[torch.dtype]] = None,
label_column: Any = None,
label_shape: Optional[int] = None,
label_type: Optional[torch.dtype] = None):
self._feature_columns = feature_columns
self._feature_shapes = feature_shapes
self._feature_types = feature_types
self._label_column = label_column
self._label_shape = label_shape
self._label_type = label_type
self._check_and_convert()
self._ds = ds
def _check_and_convert(self):
# convert to list for convenience
if not isinstance(self._feature_columns, list):
self._feature_columns = [self._feature_columns]
if self._feature_shapes:
if not isinstance(self._feature_shapes, list):
self._feature_shapes = [self._feature_shapes]
assert len(self._feature_columns) == len(self._feature_shapes), \
"The feature_shapes size must match the feature_columns"
for i in range(len(self._feature_shapes)):
if not isinstance(self._feature_shapes[i], Iterable):
self._feature_shapes[i] = [self._feature_shapes[i]]
else:
self._feature_shapes = [None] * len(self._feature_columns)
if self._feature_types:
if not isinstance(self._feature_types, list):
self._feature_types = [self._feature_types]
assert len(self._feature_columns) == len(self._feature_types), \
"The feature_types size must match the feature_columns"
assert all(isinstance(dtype, torch.dtype)
for dtype in self._feature_types), \
"All values in feature_types should be torch.dtype instances"
else:
self._feature_types = [torch.float] * len(self._feature_columns)
if not self._label_type:
self._label_type = torch.float
def set_num_shards(self, num_shards):
"""Reshards the iterator if necessary"""
if num_shards != self._ds.num_shards():
logging.info("Setting num shards", num_shards)
self._ds = self._ds.repartition(num_shards)
def get_shard(self,
shard_index: int,
batch_ms: int = 0,
num_async: int = 1,
shuffle: bool = False,
shuffle_buffer_size: int = 1,
seed: int = None) -> torch.utils.data.IterableDataset:
it = self._ds.get_repeatable_shard(shard_index, batch_ms, num_async,
shuffle, shuffle_buffer_size, seed)
convert_fn = functools.partial(
convert_to_tensor,
feature_columns=self._feature_columns,
feature_shapes=self._feature_shapes,
feature_types=self._feature_types,
label_column=self._label_column,
label_shape=self._label_shape,
label_type=self._label_type)
return TorchIterableDataset(it, convert_fn)
class TorchIterableDataset(IterableDataset):
def __init__(self, it: Iterator,
convert_fn: Callable[[pd.DataFrame], Any]):
super().__init__()
self._it = it
self._convert_fn = convert_fn
def __iter__(self):
for df in iter(self._it):
num_rows = df.shape[0]
feature_tensor, label_tensor = self._convert_fn(df)
for i in range(num_rows):
features = [tensor[i] for tensor in feature_tensor]
label = label_tensor[i]
yield (*features, label)