# ray/release/data_processing_tests/workloads/streaming_shuffle.py

import time
import json
import ray
import numpy as np
from typing import List
from ray import ObjectRef
from tqdm import tqdm
from ray.cluster_utils import Cluster

num_nodes = 4
num_cpus = 4
partition_size = int(500e6) # 500MB
# Number of map & reduce tasks == num_partitions.
# Number of objects == num_partitions ^ 2.
num_partitions = 200
# There are two int64 per row, so we divide by 8 * 2 bytes.
rows_per_partition = partition_size // (8 * 2)
object_store_size = 20 * 1024 * 1024 * 1024 # 20G
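
# Sanity-check arithmetic from the constants above (comments only, not used by
# the code):
#   rows_per_partition  = 500e6 // 16       = 31,250,000 rows per map input
#   map output block    = 31,250,000 // 200 = 156,250 rows (~2.5 MB each)
#   shuffle objects     = 200^2             = 40,000
#   total shuffle size  = 500 MB * 200      = ~100 GB, well above the 20 GB
#   object store per node, so object spilling is expected to kick in.
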

system_config = {
    "automatic_object_spilling_enabled": True,
    "max_io_workers": 1,
    "object_spilling_config": json.dumps(
        {
            "type": "filesystem",
            "params": {
                "directory_path": "/tmp/spill"
            }
        },
        separators=(",", ":"))
}
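
# Note: these _system_config keys (automatic_object_spilling_enabled,
# max_io_workers, object_spilling_config) match the object spilling options
# available around the Ray 1.x releases this workload targets; the spilling
# config itself is passed as a JSON string, which is why it is json.dumps'd
# above. Spilled objects land under /tmp/spill on whichever node spills them.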


def display_spilling_info(address):
    """Print the object spill/restore statistics from the Ray memory summary."""
    state = ray.state.GlobalState()
    state._initialize_global_state(address,
                                   ray.ray_constants.REDIS_DEFAULT_PASSWORD)
    raylet = state.node_table()[0]
    memory_summary = ray.internal.internal_api.memory_summary(
        raylet["NodeManagerAddress"], raylet["NodeManagerPort"])
    for line in memory_summary.split("\n"):
        if "Spilled" in line:
            print(line)
        if "Restored" in line:
            print(line)
    print("\n\n")


@ray.remote
class Counter:
    """Actor that counts completed map and reduce tasks."""

    def __init__(self):
        self.num_map = 0
        self.num_reduce = 0

    def inc(self):
        self.num_map += 1
        # print("Num map tasks finished", self.num_map)

    def inc2(self):
        self.num_reduce += 1
        # print("Num reduce tasks finished", self.num_reduce)

    def finish(self):
        pass


# object store peak memory: O(partition size / num partitions)
# heap memory: O(partition size / num partitions)
@ray.remote(num_returns=num_partitions)
def shuffle_map_streaming(i,
                          counter_handle=None) -> List[ObjectRef[np.ndarray]]:
    outputs = [
        ray.put(
            np.ones((rows_per_partition // num_partitions, 2), dtype=np.int64))
        for _ in range(num_partitions)
    ]
    counter_handle.inc.remote()
    return outputs
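
# The comprehension above ray.put()s each ~2.5 MB block as soon as it is
# built, so a map task's heap only ever holds one block at a time; each of the
# num_partitions return values is itself an ObjectRef to one block (the reduce
# tasks ray.get() them explicitly), which is what makes this the "streaming"
# variant.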


# object store peak memory: O(partition size / num partitions)
# heap memory: O(partition size) -- TODO can be reduced too
@ray.remote
def shuffle_reduce_streaming(*inputs, counter_handle=None) -> np.ndarray:
    out = None
    for chunk in inputs:
        if out is None:
            out = ray.get(chunk)
        else:
            out = np.concatenate([out, ray.get(chunk)])
    counter_handle.inc2.remote()
    return out
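

# Illustrative sketch only (not wired into the benchmark): the TODO above notes
# the reducer's heap stays O(partition size). One assumed way to at least avoid
# the transient 2x copy that repeated np.concatenate incurs is to preallocate
# the output once and copy each fetched chunk into place:
@ray.remote
def shuffle_reduce_preallocated(*inputs, counter_handle=None) -> np.ndarray:
    # Preallocate the full reducer output, then fill it chunk by chunk so each
    # fetched block can be released as soon as it has been copied in.
    out = np.empty((rows_per_partition, 2), dtype=np.int64)
    offset = 0
    for chunk_ref in inputs:
        chunk = ray.get(chunk_ref)
        out[offset:offset + chunk.shape[0]] = chunk
        offset += chunk.shape[0]
    counter_handle.inc2.remote()
    return out[:offset]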


shuffle_map = shuffle_map_streaming
shuffle_reduce = shuffle_reduce_streaming


def run_shuffle():
    counter = Counter.remote()
    start = time.time()
    print("start map")
    shuffle_map_out = [
        shuffle_map.remote(i, counter_handle=counter)
        for i in range(num_partitions)
    ]
    # Wait until all map tasks are done before starting the reduce phase.
    for out in tqdm(shuffle_map_out):
        ray.get(out)

    # Start reducing. Reducer j gathers block j from every map task's output.
    shuffle_reduce_out = [
        shuffle_reduce.remote(
            *[shuffle_map_out[i][j] for i in range(num_partitions)],
            counter_handle=counter) for j in range(num_partitions)
    ]
    print("start shuffle.")
    pbar = tqdm(total=num_partitions)
    total_rows = 0
    unready = shuffle_reduce_out
    while unready:
        ready, unready = ray.wait(unready)
        for output in ready:
            pbar.update(1)
            total_rows += ray.get(output).shape[0]
    delta = time.time() - start
    ray.get(counter.finish.remote())
    print("Shuffled", total_rows * 8 * 2, "bytes in", delta, "seconds.\n")


def run_single_node():
    address = ray.init(
        num_cpus=num_cpus * num_nodes,
        object_store_memory=object_store_size,
        _system_config=system_config)
    # Run shuffle.
    print(
        "\n\nTest streaming shuffle with a single node.\n"
        f"Shuffle size: {partition_size * num_partitions / 1024 / 1024 / 1024}"
        "GB")
    run_shuffle()
    time.sleep(5)
    display_spilling_info(address["redis_address"])
    ray.shutdown()
    time.sleep(5)


def run_multi_nodes():
    c = Cluster()
    c.add_node(
        num_cpus=4,
        object_store_memory=object_store_size,
        _system_config=system_config)
    ray.init(address=c.address)
    # The head node is already added above, so add num_nodes - 1 more workers.
    for _ in range(num_nodes - 1):
        c.add_node(num_cpus=4, object_store_memory=object_store_size)
    c.wait_for_nodes()
    # Run shuffle.
    print(
        f"\n\nTest streaming shuffle with {num_nodes} nodes.\n"
        f"Shuffle size: {partition_size * num_partitions / 1024 / 1024 / 1024}"
        "GB")
    run_shuffle()
    time.sleep(5)
    display_spilling_info(c.address)
    ray.shutdown()
    c.shutdown()
    time.sleep(5)


run_single_node()
run_multi_nodes()