mirror of
https://github.com/vale981/ray
synced 2025-03-08 19:41:38 -05:00
Update arrow to include pandas serialization (#1102)
* update arrow to include pandas serialization * update
This commit is contained in:
parent
b1660c4edf
commit
0684258d2e
2 changed files with 41 additions and 1 deletions
|
@ -1080,6 +1080,46 @@ def _initialize_serialization(worker=global_worker):
|
||||||
custom_serializer=default_dict_custom_serializer,
|
custom_serializer=default_dict_custom_serializer,
|
||||||
custom_deserializer=default_dict_custom_deserializer)
|
custom_deserializer=default_dict_custom_deserializer)
|
||||||
|
|
||||||
|
def _serialize_pandas_series(s):
    """Turn a pandas Series into a plain dict carrying Arrow-encoded bytes.

    The Series is wrapped in a single-column DataFrame (keyed by the
    Series' own ``name``) so that ``pyarrow.serialize_pandas`` can encode it.

    Returns:
        A dict with ``'type'`` set to ``'Series'`` and ``'data'`` holding the
        serialized bytes.
    """
    import pandas as pd

    # TODO: serializing Series without extra copy
    single_column_frame = pd.DataFrame({s.name: s})
    payload = pyarrow.serialize_pandas(single_column_frame).to_pybytes()
    return {'type': 'Series', 'data': payload}
|
||||||
|
|
||||||
|
def _serialize_pandas_dataframe(df):
    """Turn a pandas DataFrame into a plain dict carrying Arrow-encoded bytes.

    Returns:
        A dict with ``'type'`` set to ``'DataFrame'`` and ``'data'`` holding
        the bytes produced by ``pyarrow.serialize_pandas``.
    """
    payload = pyarrow.serialize_pandas(df).to_pybytes()
    return {'type': 'DataFrame', 'data': payload}
|
||||||
|
|
||||||
|
def _deserialize_callback_pandas(data):
    """Rebuild a pandas object from the dict made by the pandas serializers.

    Args:
        data: A dict with a ``'data'`` entry (the serialized bytes) and a
            ``'type'`` entry naming the original pandas type.

    Returns:
        The decoded DataFrame when ``'type'`` is ``'DataFrame'``, or its
        first column when ``'type'`` is ``'Series'``.

    Raises:
        ValueError: If ``'type'`` is neither ``'Series'`` nor ``'DataFrame'``.
    """
    # Decode first, then dispatch on the recorded type tag.
    frame = pyarrow.deserialize_pandas(data['data'])
    kind = data['type']

    if kind == 'Series':
        # A Series was stored as a one-column frame; unwrap that column.
        first_column = frame.columns[0]
        return frame[first_column]
    if kind == 'DataFrame':
        return frame
    raise ValueError(kind)
|
||||||
|
|
||||||
|
# Register Arrow-backed (de)serializers for pandas types, if pandas is
# installed; otherwise skip registration silently.
try:
    import pandas as pd

    # (class, type id, serializer) for each supported pandas type; both share
    # the same deserialization callback.
    pandas_registrations = (
        (pd.Series, 'pandas.Series', _serialize_pandas_series),
        (pd.DataFrame, 'pandas.DataFrame', _serialize_pandas_dataframe),
    )
    for pandas_cls, type_id, serializer in pandas_registrations:
        worker.serialization_context.register_type(
            pandas_cls, type_id,
            custom_serializer=serializer,
            custom_deserializer=_deserialize_callback_pandas)
except ImportError:
    # no pandas
    pass
|
||||||
|
|
||||||
if worker.mode in [SCRIPT_MODE, SILENT_MODE]:
|
if worker.mode in [SCRIPT_MODE, SILENT_MODE]:
|
||||||
# These should only be called on the driver because _register_class
|
# These should only be called on the driver because _register_class
|
||||||
# will export the class to all of the workers.
|
# will export the class to all of the workers.
|
||||||
|
|
2
src/thirdparty/download_thirdparty.sh
vendored
2
src/thirdparty/download_thirdparty.sh
vendored
|
@ -13,4 +13,4 @@ fi
|
||||||
cd $TP_DIR/arrow
|
cd $TP_DIR/arrow
|
||||||
git fetch origin master
|
git fetch origin master
|
||||||
|
|
||||||
git checkout 988338c544580ffd367a5540f1061dd7b0fccc0e
|
git checkout ee78cdcb1c475a05df9cd9de63358e80ba280a63
|
||||||
|
|
Loading…
Add table
Reference in a new issue