diff --git a/python/ray/data/_internal/lazy_block_list.py b/python/ray/data/_internal/lazy_block_list.py
index ef69d28f1..fcc10dbe2 100644
--- a/python/ray/data/_internal/lazy_block_list.py
+++ b/python/ray/data/_internal/lazy_block_list.py
@@ -95,6 +95,7 @@ class LazyBlockList(BlockList):
         # Whether the block list is owned by consuming APIs, and if so it can be
         # eagerly deleted after read by the consumer.
         self._owned_by_consumer = owned_by_consumer
+        self._stats_actor = _get_or_create_stats_actor()
 
     def get_metadata(self, fetch_if_missing: bool = False) -> List[BlockMetadata]:
         """Get the metadata for all blocks."""
@@ -139,6 +140,7 @@ class LazyBlockList(BlockList):
             None for _ in self._block_partition_meta_refs
         ]
         self._cached_metadata = [None for _ in self._cached_metadata]
+        self._stats_actor = None
 
     def is_cleared(self) -> bool:
         return all(ref is None for ref in self._block_partition_refs)
@@ -537,7 +539,9 @@ class LazyBlockList(BlockList):
         self, task_idx: int
     ) -> Tuple[ObjectRef[MaybeBlockPartition], ObjectRef[BlockPartitionMetadata]]:
         """Submit the task with index task_idx."""
-        stats_actor = _get_or_create_stats_actor()
+        if self._stats_actor is None:
+            self._stats_actor = _get_or_create_stats_actor()
+        stats_actor = self._stats_actor
         if not self._execution_started:
             stats_actor.record_start.remote(self._stats_uuid)
             self._execution_started = True
diff --git a/python/ray/data/_internal/stats.py b/python/ray/data/_internal/stats.py
index f3c5436ae..a6f8b03f4 100644
--- a/python/ray/data/_internal/stats.py
+++ b/python/ray/data/_internal/stats.py
@@ -11,6 +11,9 @@ from ray.data.block import BlockMetadata
 from ray.data.context import DatasetContext
 from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy
 
+STATS_ACTOR_NAME = "datasets_stats_actor"
+STATS_ACTOR_NAMESPACE = "_dataset_stats_actor"
+
 
 def fmt(seconds: float) -> str:
     if seconds > 1:
@@ -84,7 +87,7 @@ class _DatasetStatsBuilder:
 class _StatsActor:
     """Actor holding stats for blocks created by LazyBlockList.
 
-    This actor is shared across all datasets created by the same process.
+    This actor is shared across all datasets created in the same cluster.
     The stats data is small so we don't worry about clean up for now.
 
     TODO(ekl) we should consider refactoring LazyBlockList so stats can be
@@ -112,40 +115,23 @@ class _StatsActor:
         )
 
 
-# Actor handle, job id the actor was created for.
-_stats_actor = [None, None]
-
-
 def _get_or_create_stats_actor():
-    # Need to re-create it if Ray restarts (mostly for unit tests).
-    if (
-        not _stats_actor[0]
-        or not ray.is_initialized()
-        or _stats_actor[1] != ray.get_runtime_context().job_id.hex()
-    ):
-        ctx = DatasetContext.get_current()
-        scheduling_strategy = ctx.scheduling_strategy
-        if not ray.util.client.ray.is_connected():
-            # Pin the stats actor to the local node
-            # so it fate-shares with the driver.
-            scheduling_strategy = NodeAffinitySchedulingStrategy(
-                ray.get_runtime_context().get_node_id(),
-                soft=False,
-            )
-
-        _stats_actor[0] = _StatsActor.options(
-            name="datasets_stats_actor",
-            get_if_exists=True,
-            scheduling_strategy=scheduling_strategy,
-        ).remote()
-        _stats_actor[1] = ray.get_runtime_context().job_id.hex()
-
-        # Clear the actor handle after Ray reinits since it's no longer valid.
-        def clear_actor():
-            _stats_actor[0] = None
-
-        ray._private.worker._post_init_hooks.append(clear_actor)
-    return _stats_actor[0]
+    ctx = DatasetContext.get_current()
+    scheduling_strategy = ctx.scheduling_strategy
+    if not ray.util.client.ray.is_connected():
+        # Pin the stats actor to the local node
+        # so it fate-shares with the driver.
+        scheduling_strategy = NodeAffinitySchedulingStrategy(
+            ray.get_runtime_context().get_node_id(),
+            soft=False,
+        )
+    return _StatsActor.options(
+        name=STATS_ACTOR_NAME,
+        namespace=STATS_ACTOR_NAMESPACE,
+        get_if_exists=True,
+        lifetime="detached",
+        scheduling_strategy=scheduling_strategy,
+    ).remote()
 
 
 class DatasetStats:
@@ -223,10 +209,9 @@ class DatasetStats:
         already_printed = set()
 
         if self.needs_stats_actor:
+            ac = self.stats_actor
             # XXX this is a super hack, clean it up.
-            stats_map, self.time_total_s = ray.get(
-                self.stats_actor.get.remote(self.stats_uuid)
-            )
+            stats_map, self.time_total_s = ray.get(ac.get.remote(self.stats_uuid))
             for i, metadata in stats_map.items():
                 self.stages["read"][i] = metadata
         out = ""
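
For context, the patch replaces the per-process cached stats actor handle with a named, detached actor looked up by name and namespace, so every job in the cluster attaches to the same instance. Below is a minimal standalone sketch of that pattern, assuming a toy Counter actor and made-up name/namespace strings; none of these identifiers come from the Ray codebase.

import ray


@ray.remote
class Counter:
    """Toy stand-in for a shared stats-like actor (hypothetical, not from the patch)."""

    def __init__(self):
        self.value = 0

    def increment(self) -> int:
        self.value += 1
        return self.value


def get_or_create_counter():
    # Named + detached: the actor outlives the creating job and is shared
    # by any driver that connects to the same cluster.
    return Counter.options(
        name="shared_counter",        # hypothetical name
        namespace="demo_namespace",   # hypothetical namespace
        get_if_exists=True,           # attach to the actor if it already exists
        lifetime="detached",
    ).remote()


if __name__ == "__main__":
    ray.init()
    counter = get_or_create_counter()
    # Prints 1 on the first run; keeps growing on later runs because the
    # detached actor persists across jobs in the cluster.
    print(ray.get(counter.increment.remote()))

Because the actor is detached and created with get_if_exists=True, a second driver re-attaches to the existing actor instead of creating a new one; this is what allows LazyBlockList to cache the handle in self._stats_actor, drop it in clear(), and lazily re-fetch it in _submit_task.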