Cap the number of stats kept in StatsActor and purge in FIFO order if the limit exceeded (#27964)

There is a risk of StatsActor using too much memory, because its lifetime is the same as the cluster's lifetime.
This change puts a cap on how many stats the actor keeps and purges stats in FIFO order when that cap is exceeded.
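The mechanism is a bounded FIFO queue of stats UUIDs: each record_start() appends the dataset's stats UUID to the queue, and once the queue grows past max_stats the oldest UUID is popped and its entries are deleted from every bookkeeping dict. A minimal standalone sketch of that pattern (illustrative only, not the Ray implementation; the FifoCappedStats name and the now argument are made up for the example):

import collections


class FifoCappedStats:
    """Sketch: keep at most `max_stats` entries, evicting the oldest first."""

    def __init__(self, max_stats=1000):
        self.metadata = collections.defaultdict(dict)  # uuid -> per-task metadata
        self.start_time = {}                           # uuid -> start timestamp
        self.fifo_queue = collections.deque()          # uuids in insertion order
        self.max_stats = max_stats

    def record_start(self, stats_uuid, now):
        self.start_time[stats_uuid] = now
        self.fifo_queue.append(stats_uuid)
        if len(self.fifo_queue) > self.max_stats:
            oldest = self.fifo_queue.popleft()
            # Drop every piece of state tracked for the evicted uuid.
            self.start_time.pop(oldest, None)
            self.metadata.pop(oldest, None)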
Jian Xiao authored on 2022-08-18 10:25:31 -07:00 (committed by GitHub)
parent 24aeea8332
commit 440ae620eb
2 changed files with 63 additions and 5 deletions


@@ -88,23 +88,40 @@ class _StatsActor:
     """Actor holding stats for blocks created by LazyBlockList.
 
     This actor is shared across all datasets created in the same cluster.
-    The stats data is small so we don't worry about clean up for now.
+    In order to cap memory usage, we set a max number of stats to keep
+    in the actor. When this limit is exceeded, the stats will be garbage
+    collected in FIFO order.
 
     TODO(ekl) we should consider refactoring LazyBlockList so stats can be
     extracted without using an out-of-band actor."""
 
-    def __init__(self):
+    def __init__(self, max_stats=1000):
         # Mapping from uuid -> dataset-specific stats.
         self.metadata = collections.defaultdict(dict)
         self.last_time = {}
         self.start_time = {}
+        self.max_stats = max_stats
+        self.fifo_queue = []
 
     def record_start(self, stats_uuid):
         self.start_time[stats_uuid] = time.perf_counter()
+        self.fifo_queue.append(stats_uuid)
+        # Purge the oldest stats if the limit is exceeded.
+        if len(self.fifo_queue) > self.max_stats:
+            uuid = self.fifo_queue.pop(0)
+            if uuid in self.start_time:
+                del self.start_time[uuid]
+            if uuid in self.last_time:
+                del self.last_time[uuid]
+            if uuid in self.metadata:
+                del self.metadata[uuid]
 
-    def record_task(self, stats_uuid, i, metadata):
-        self.metadata[stats_uuid][i] = metadata
-        self.last_time[stats_uuid] = time.perf_counter()
+    def record_task(self, stats_uuid, task_idx, metadata):
+        # Null out the schema to keep the stats size small.
+        metadata.schema = None
+        if stats_uuid in self.start_time:
+            self.metadata[stats_uuid][task_idx] = metadata
+            self.last_time[stats_uuid] = time.perf_counter()
 
     def get(self, stats_uuid):
         if stats_uuid not in self.metadata:
@@ -114,6 +131,9 @@ class _StatsActor:
             self.last_time[stats_uuid] - self.start_time[stats_uuid],
         )
 
+    def _get_stats_dict_size(self):
+        return len(self.start_time), len(self.last_time), len(self.metadata)
+
 
 def _get_or_create_stats_actor():
     ctx = DatasetContext.get_current()
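As a rough usage sketch of the changed API (mirroring the new test added below; the cap of 3, the loop, and the ray.init() call are only for illustration), the cap is passed when the actor is constructed and eviction happens as a side effect of record_start():

import ray
from ray.data._internal.stats import _StatsActor

ray.init()
# Keep stats for at most 3 datasets; older entries are purged in FIFO order.
stats_actor = _StatsActor.remote(3)
for uuid in range(4):
    stats_actor.record_start.remote(uuid)
# The 4th record_start() pushed the queue past the cap, so uuid=0 was evicted
# and its stats dict now comes back empty.
assert ray.get(stats_actor.get.remote(0))[0] == {}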


@@ -12,6 +12,7 @@ import pytest
 import ray
 from ray._private.test_utils import wait_for_condition
+from ray.data._internal.stats import _StatsActor
 from ray.data._internal.arrow_block import ArrowRow
 from ray.data._internal.block_builder import BlockBuilder
 from ray.data._internal.lazy_block_list import LazyBlockList
@@ -4688,6 +4689,43 @@ def test_parquet_read_spread(ray_start_cluster, tmp_path):
     assert set(locations) == {node1_id, node2_id}
 
 
+def test_stats_actor_cap_num_stats(ray_start_cluster):
+    actor = _StatsActor.remote(3)
+    metadatas = []
+    task_idx = 0
+    for uuid in range(3):
+        metadatas.append(
+            BlockMetadata(
+                num_rows=uuid,
+                size_bytes=None,
+                schema=None,
+                input_files=None,
+                exec_stats=None,
+            )
+        )
+        num_stats = uuid + 1
+        actor.record_start.remote(uuid)
+        assert ray.get(actor._get_stats_dict_size.remote()) == (
+            num_stats,
+            num_stats - 1,
+            num_stats - 1,
+        )
+        actor.record_task.remote(uuid, task_idx, metadatas[-1])
+        assert ray.get(actor._get_stats_dict_size.remote()) == (
+            num_stats,
+            num_stats,
+            num_stats,
+        )
+    for uuid in range(3):
+        assert ray.get(actor.get.remote(uuid))[0][task_idx] == metadatas[uuid]
+    # Add the fourth stats to exceed the limit.
+    actor.record_start.remote(3)
+    # The first stats (with uuid=0) should have been purged.
+    assert ray.get(actor.get.remote(0))[0] == {}
+    # The start_time has 3 entries because we just added it above with record_start().
+    assert ray.get(actor._get_stats_dict_size.remote()) == (3, 2, 2)
+
+
 @ray.remote
 class Counter:
     def __init__(self):