ray/rllib/utils/compression.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

92 lines
2.1 KiB
Python
Raw Normal View History

from ray.rllib.utils.annotations import DeveloperAPI
import logging
import time
import base64
import numpy as np
2020-02-19 19:01:26 -08:00
from ray import cloudpickle as pickle
from six import string_types
logger = logging.getLogger(__name__)
try:
import lz4.frame
LZ4_ENABLED = True
except ImportError:
logger.warning(
"lz4 not available, disabling sample compression. "
"This will significantly impact RLlib performance. "
"To install lz4, run `pip install lz4`."
)
LZ4_ENABLED = False
@DeveloperAPI
def compression_supported():
return LZ4_ENABLED
@DeveloperAPI
def pack(data):
if LZ4_ENABLED:
2020-02-19 19:01:26 -08:00
data = pickle.dumps(data)
data = lz4.frame.compress(data)
# TODO(ekl) we shouldn't need to base64 encode this data, but this
# seems to not survive a transfer through the object store if we don't.
data = base64.b64encode(data).decode("ascii")
return data
@DeveloperAPI
def pack_if_needed(data):
if isinstance(data, np.ndarray):
data = pack(data)
return data
@DeveloperAPI
def unpack(data):
if LZ4_ENABLED:
data = base64.b64decode(data)
data = lz4.frame.decompress(data)
2020-02-19 19:01:26 -08:00
data = pickle.loads(data)
return data
@DeveloperAPI
def unpack_if_needed(data):
if is_compressed(data):
data = unpack(data)
return data
@DeveloperAPI
def is_compressed(data):
return isinstance(data, bytes) or isinstance(data, string_types)
# Intel(R) Core(TM) i7-4600U CPU @ 2.10GHz
# Compression speed: 753.664 MB/s
# Compression ratio: 87.4839812046
# Decompression speed: 910.9504 MB/s
if __name__ == "__main__":
size = 32 * 80 * 80 * 4
data = np.ones(size).reshape((32, 80, 80, 4))
count = 0
start = time.time()
while time.time() - start < 1:
pack(data)
count += 1
compressed = pack(data)
print("Compression speed: {} MB/s".format(count * size * 4 / 1e6))
print("Compression ratio: {}".format(round(size * 4 / len(compressed), 2)))
count = 0
start = time.time()
while time.time() - start < 1:
unpack(compressed)
count += 1
print("Decompression speed: {} MB/s".format(count * size * 4 / 1e6))