[Datasets] Use detached lifetime for stats actor (#25271)

The actor handle held at the Ray client becomes dangling if the Ray cluster is shut down; if the user then tries to get the actor again, it results in a crash. This happened to a real user and blocked them from making progress.

This change makes the stats actor detached and, instead of keeping a handle, accesses it via its name. This way we can make sure the actor is re-created if the cluster gets restarted.
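
For context, here is a minimal sketch of the detached, named-actor pattern this change relies on. The actor class, name, and namespace below are placeholders for illustration, not the ones added in this PR (the real change uses STATS_ACTOR_NAME / STATS_ACTOR_NAMESPACE for the Datasets stats actor):

```python
import ray

ray.init(ignore_reinit_error=True)


@ray.remote(num_cpus=0)
class Counter:
    def __init__(self):
        self.n = 0

    def incr(self):
        self.n += 1
        return self.n


# Placeholder name/namespace for this sketch only.
NAME, NAMESPACE = "demo_stats_actor", "demo_stats_namespace"


def get_or_create_counter():
    # get_if_exists=True returns the existing actor instead of raising if
    # the name is already taken; lifetime="detached" decouples the actor
    # from the job that created it, so a later job (or a Ray client
    # reconnect) can look it up by name rather than holding on to a handle
    # that may have gone stale.
    return Counter.options(
        name=NAME,
        namespace=NAMESPACE,
        get_if_exists=True,
        lifetime="detached",
    ).remote()


counter = get_or_create_counter()
print(ray.get(counter.incr.remote()))
```

If the cluster is restarted, the named actor no longer exists, and the next call to the getter simply re-creates it under the same name via get_if_exists.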

Co-authored-by: Ubuntu <ubuntu@ip-172-31-32-136.us-west-2.compute.internal>
Jian Xiao 2022-08-11 17:47:13 -07:00 committed by GitHub
parent b7a6a1294a
commit b1cad0a112
2 changed files with 27 additions and 38 deletions


@@ -95,6 +95,7 @@ class LazyBlockList(BlockList):
# Whether the block list is owned by consuming APIs, and if so it can be
# eagerly deleted after read by the consumer.
self._owned_by_consumer = owned_by_consumer
self._stats_actor = _get_or_create_stats_actor()
def get_metadata(self, fetch_if_missing: bool = False) -> List[BlockMetadata]:
"""Get the metadata for all blocks."""
@@ -139,6 +140,7 @@ class LazyBlockList(BlockList):
None for _ in self._block_partition_meta_refs
]
self._cached_metadata = [None for _ in self._cached_metadata]
self._stats_actor = None
def is_cleared(self) -> bool:
return all(ref is None for ref in self._block_partition_refs)
@@ -537,7 +539,9 @@ class LazyBlockList(BlockList):
self, task_idx: int
) -> Tuple[ObjectRef[MaybeBlockPartition], ObjectRef[BlockPartitionMetadata]]:
"""Submit the task with index task_idx."""
stats_actor = _get_or_create_stats_actor()
if self._stats_actor is None:
self._stats_actor = _get_or_create_stats_actor()
stats_actor = self._stats_actor
if not self._execution_started:
stats_actor.record_start.remote(self._stats_uuid)
self._execution_started = True


@@ -11,6 +11,9 @@ from ray.data.block import BlockMetadata
from ray.data.context import DatasetContext
from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy
STATS_ACTOR_NAME = "datasets_stats_actor"
STATS_ACTOR_NAMESPACE = "_dataset_stats_actor"
def fmt(seconds: float) -> str:
if seconds > 1:
@@ -84,7 +87,7 @@ class _DatasetStatsBuilder:
class _StatsActor:
"""Actor holding stats for blocks created by LazyBlockList.
This actor is shared across all datasets created by the same process.
This actor is shared across all datasets created in the same cluster.
The stats data is small so we don't worry about clean up for now.
TODO(ekl) we should consider refactoring LazyBlockList so stats can be
@@ -112,40 +115,23 @@ class _StatsActor:
)
# Actor handle, job id the actor was created for.
_stats_actor = [None, None]
def _get_or_create_stats_actor():
# Need to re-create it if Ray restarts (mostly for unit tests).
if (
not _stats_actor[0]
or not ray.is_initialized()
or _stats_actor[1] != ray.get_runtime_context().job_id.hex()
):
ctx = DatasetContext.get_current()
scheduling_strategy = ctx.scheduling_strategy
if not ray.util.client.ray.is_connected():
# Pin the stats actor to the local node
# so it fate-shares with the driver.
scheduling_strategy = NodeAffinitySchedulingStrategy(
ray.get_runtime_context().get_node_id(),
soft=False,
)
_stats_actor[0] = _StatsActor.options(
name="datasets_stats_actor",
get_if_exists=True,
scheduling_strategy=scheduling_strategy,
).remote()
_stats_actor[1] = ray.get_runtime_context().job_id.hex()
# Clear the actor handle after Ray reinits since it's no longer valid.
def clear_actor():
_stats_actor[0] = None
ray._private.worker._post_init_hooks.append(clear_actor)
return _stats_actor[0]
ctx = DatasetContext.get_current()
scheduling_strategy = ctx.scheduling_strategy
if not ray.util.client.ray.is_connected():
# Pin the stats actor to the local node
# so it fate-shares with the driver.
scheduling_strategy = NodeAffinitySchedulingStrategy(
ray.get_runtime_context().get_node_id(),
soft=False,
)
return _StatsActor.options(
name=STATS_ACTOR_NAME,
namespace=STATS_ACTOR_NAMESPACE,
get_if_exists=True,
lifetime="detached",
scheduling_strategy=scheduling_strategy,
).remote()
class DatasetStats:
@@ -223,10 +209,9 @@ class DatasetStats:
already_printed = set()
if self.needs_stats_actor:
ac = self.stats_actor
# XXX this is a super hack, clean it up.
stats_map, self.time_total_s = ray.get(
self.stats_actor.get.remote(self.stats_uuid)
)
stats_map, self.time_total_s = ray.get(ac.get.remote(self.stats_uuid))
for i, metadata in stats_map.items():
self.stages["read"][i] = metadata
out = ""