import base64

from collections import defaultdict
from enum import Enum
from typing import List

import ray

from ray._raylet import (TaskID, ActorID, JobID)
import logging

logger = logging.getLogger(__name__)

# These values are used to calculate if objectRefs are actor handles.
TASKID_BYTES_SIZE = TaskID.size()
ACTORID_BYTES_SIZE = ActorID.size()
JOBID_BYTES_SIZE = JobID.size()
# We multiply by 2 because these sizes are used as lengths in the hex string
# representation of an ID (2 hex characters per byte), not as byte counts.
TASKID_RANDOM_BITS_SIZE = (TASKID_BYTES_SIZE - ACTORID_BYTES_SIZE) * 2
ACTORID_RANDOM_BITS_SIZE = (ACTORID_BYTES_SIZE - JOBID_BYTES_SIZE) * 2

def decode_object_ref_if_needed(object_ref: str) -> bytes:
    """Decode an objectRef byte string.

    A gRPC reply contains an objectRef that is encoded with Base64.
    This function is used to decode the objectRef.
    Note that sometimes the objectRef is already given as a hex string.
    In this case, just convert it to binary bytes.
    """
    if object_ref.endswith("="):
        # If the object ref ends with =, that means it is base64 encoded.
        # Object refs will always have = as padding when they are base64
        # encoded because an objectRef is always 20B.
        return base64.standard_b64decode(object_ref)
    else:
        return ray.utils.hex_to_binary(object_ref)
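# For reference: a 20-byte object ref Base64-encodes to a 28-character string
# that always ends in a single "=" padding character, while the same ref
# rendered as hex is 40 characters long and never contains "=".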


class SortingType(Enum):
    PID = 1
    OBJECT_SIZE = 3
    REFERENCE_TYPE = 4


class GroupByType(Enum):
    NODE_ADDRESS = "node"
    STACK_TRACE = "stack_trace"


class ReferenceType:
    # We don't use enum because enum is not json serializable.
    ACTOR_HANDLE = "ACTOR_HANDLE"
    PINNED_IN_MEMORY = "PINNED_IN_MEMORY"
    LOCAL_REFERENCE = "LOCAL_REFERENCE"
    USED_BY_PENDING_TASK = "USED_BY_PENDING_TASK"
    CAPTURED_IN_OBJECT = "CAPTURED_IN_OBJECT"
    UNKNOWN_STATUS = "UNKNOWN_STATUS"
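# Keeping these as plain strings (rather than an Enum) lets the dicts built
# below be passed directly to json.dumps() without a custom encoder.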


class MemoryTableEntry:
    def __init__(self, *, object_ref: dict, node_address: str, is_driver: bool,
                 pid: int):
        # worker info
        self.is_driver = is_driver
        self.pid = pid
        self.node_address = node_address

        # object info
        self.object_size = int(object_ref.get("objectSize", -1))
        self.call_site = object_ref.get("callSite", "<Unknown>")
        self.object_ref = ray.ObjectRef(
            decode_object_ref_if_needed(object_ref["objectId"]))

        # reference info
        self.local_ref_count = int(object_ref.get("localRefCount", 0))
        self.pinned_in_memory = bool(object_ref.get("pinnedInMemory", False))
        self.submitted_task_ref_count = int(
            object_ref.get("submittedTaskRefCount", 0))
        self.contained_in_owned = [
            ray.ObjectRef(decode_object_ref_if_needed(object_ref))
            for object_ref in object_ref.get("containedInOwned", [])
        ]
        self.reference_type = self._get_reference_type()

    def is_valid(self) -> bool:
        # If the entry doesn't have a reference type or is in an invalid
        # state (e.g., no object ref present), it is considered invalid.
        if (not self.pinned_in_memory and self.local_ref_count == 0
                and self.submitted_task_ref_count == 0
                and len(self.contained_in_owned) == 0):
            return False
        elif self.object_ref.is_nil():
            return False
        else:
            return True

    def group_key(self, group_by_type: GroupByType) -> str:
        if group_by_type == GroupByType.NODE_ADDRESS:
            return self.node_address
        elif group_by_type == GroupByType.STACK_TRACE:
            return self.call_site
        else:
            raise ValueError(f"group by type {group_by_type} is invalid.")

    def _get_reference_type(self) -> str:
        if self._is_object_ref_actor_handle():
            return ReferenceType.ACTOR_HANDLE
        if self.pinned_in_memory:
            return ReferenceType.PINNED_IN_MEMORY
        elif self.submitted_task_ref_count > 0:
            return ReferenceType.USED_BY_PENDING_TASK
        elif self.local_ref_count > 0:
            return ReferenceType.LOCAL_REFERENCE
        elif len(self.contained_in_owned) > 0:
            return ReferenceType.CAPTURED_IN_OBJECT
        else:
            return ReferenceType.UNKNOWN_STATUS

    def _is_object_ref_actor_handle(self) -> bool:
        object_ref_hex = self.object_ref.hex()

        # random (8B) | ActorID (6B) | flag (2B) | index (6B)
        # ActorID (6B) == ActorRandomBytes (4B) + JobID (2B)
        # If the random bytes are all 'f' but the ActorRandomBytes are not
        # all 'f', the ref comes from an actor creation task, i.e. it is an
        # actor handle.
        random_bits = object_ref_hex[:TASKID_RANDOM_BITS_SIZE]
        actor_random_bits = object_ref_hex[TASKID_RANDOM_BITS_SIZE:
                                           TASKID_RANDOM_BITS_SIZE +
                                           ACTORID_RANDOM_BITS_SIZE]
        if (random_bits == "f" * 16 and not actor_random_bits == "f" * 8):
            return True
        else:
            return False
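
    # Example: under the layout above, an actor handle's ref hex starts with
    # 16 "f" characters (the TaskID random part) followed by actor-specific
    # characters that are not all "f"; ordinary task-return refs typically
    # have non-"f" characters in that leading region instead.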

    def as_dict(self):
        return {
            "object_ref": self.object_ref.hex(),
            "pid": self.pid,
            "node_ip_address": self.node_address,
            "object_size": self.object_size,
            "reference_type": self.reference_type,
            "call_site": self.call_site,
            "local_ref_count": self.local_ref_count,
            "pinned_in_memory": self.pinned_in_memory,
            "submitted_task_ref_count": self.submitted_task_ref_count,
            "contained_in_owned": [
                object_ref.hex() for object_ref in self.contained_in_owned
            ],
            "type": "Driver" if self.is_driver else "Worker"
        }

    def __str__(self):
        return self.__repr__()

    def __repr__(self):
        return str(self.as_dict())


class MemoryTable:
    def __init__(self,
                 entries: List[MemoryTableEntry],
                 group_by_type: GroupByType = GroupByType.NODE_ADDRESS,
                 sort_by_type: SortingType = SortingType.PID):
        self.table = entries
        # `group` maps a group key to a MemoryTable holding that group's
        # entries.
        self.group = {}
        self.summary = defaultdict(int)
        # NOTE: YOU MUST SORT THE TABLE BEFORE GROUPING.
        # self._group_by(..)._sort_by(..) != self._sort_by(..)._group_by(..)
        if group_by_type and sort_by_type:
            self.setup(group_by_type, sort_by_type)
        elif group_by_type:
            self._group_by(group_by_type)
        elif sort_by_type:
            self._sort_by(sort_by_type)

    def setup(self, group_by_type: GroupByType, sort_by_type: SortingType):
        """Set up the memory table.

        This sorts the entries first and then groups them.
        The sort order is preserved within each group.
        """
        self._sort_by(sort_by_type)._group_by(group_by_type)
        for group_memory_table in self.group.values():
            group_memory_table.summarize()
        self.summarize()
        return self
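
    # Note: grouping after sorting keeps each group internally sorted because
    # entries are appended to their group in (already sorted) table order and
    # the per-group MemoryTable is built with sort_by_type=None.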

    def insert_entry(self, entry: MemoryTableEntry):
        self.table.append(entry)

    def summarize(self):
        # Reset summary.
        total_object_size = 0
        total_local_ref_count = 0
        total_pinned_in_memory = 0
        total_used_by_pending_task = 0
        total_captured_in_objects = 0
        total_actor_handles = 0

        for entry in self.table:
            if entry.object_size > 0:
                total_object_size += entry.object_size
            if entry.reference_type == ReferenceType.LOCAL_REFERENCE:
                total_local_ref_count += 1
            elif entry.reference_type == ReferenceType.PINNED_IN_MEMORY:
                total_pinned_in_memory += 1
            elif entry.reference_type == ReferenceType.USED_BY_PENDING_TASK:
                total_used_by_pending_task += 1
            elif entry.reference_type == ReferenceType.CAPTURED_IN_OBJECT:
                total_captured_in_objects += 1
            elif entry.reference_type == ReferenceType.ACTOR_HANDLE:
                total_actor_handles += 1

        self.summary = {
            "total_object_size": total_object_size,
            "total_local_ref_count": total_local_ref_count,
            "total_pinned_in_memory": total_pinned_in_memory,
            "total_used_by_pending_task": total_used_by_pending_task,
            "total_captured_in_objects": total_captured_in_objects,
            "total_actor_handles": total_actor_handles
        }
        return self

    def _sort_by(self, sorting_type: SortingType):
        if sorting_type == SortingType.PID:
            self.table.sort(key=lambda entry: entry.pid)
        elif sorting_type == SortingType.OBJECT_SIZE:
            self.table.sort(key=lambda entry: entry.object_size)
        elif sorting_type == SortingType.REFERENCE_TYPE:
            self.table.sort(key=lambda entry: entry.reference_type)
        else:
            raise ValueError(f"Given sorting type: {sorting_type} is invalid.")
        return self

    def _group_by(self, group_by_type: GroupByType):
        """Group entries and summarize the result.

        NOTE: Each group is another MemoryTable.
        """
        # Reset group.
        self.group = {}

        # Build entries per group.
        group = defaultdict(list)
        for entry in self.table:
            group[entry.group_key(group_by_type)].append(entry)

        # Build a group table.
        for group_key, entries in group.items():
            self.group[group_key] = MemoryTable(
                entries, group_by_type=None, sort_by_type=None)
        for group_key, group_memory_table in self.group.items():
            group_memory_table.summarize()
        return self

    def as_dict(self):
        return {
            "summary": self.summary,
            "group": {
                group_key: {
                    "entries": group_memory_table.get_entries(),
                    "summary": group_memory_table.summary
                }
                for group_key, group_memory_table in self.group.items()
            }
        }

    def get_entries(self) -> List[dict]:
        # Use as_dict() so each entry is returned as a JSON-serializable dict.
        return [entry.as_dict() for entry in self.table]

    def __repr__(self):
        return str(self.as_dict())

    def __str__(self):
        return self.__repr__()


def construct_memory_table(workers_info: List,
                           group_by: GroupByType = GroupByType.NODE_ADDRESS,
                           sort_by=SortingType.OBJECT_SIZE) -> MemoryTable:
    memory_table_entries = []
    for worker_info in workers_info:
        pid = worker_info["pid"]
        is_driver = worker_info.get("isDriver", False)
        core_worker_stats = worker_info["coreWorkerStats"]
        node_address = core_worker_stats["ipAddress"]
        object_refs = core_worker_stats.get("objectRefs", [])

        for object_ref in object_refs:
            memory_table_entry = MemoryTableEntry(
                object_ref=object_ref,
                node_address=node_address,
                is_driver=is_driver,
                pid=pid)
            if memory_table_entry.is_valid():
                memory_table_entries.append(memory_table_entry)
    memory_table = MemoryTable(
        memory_table_entries, group_by_type=group_by, sort_by_type=sort_by)
    return memory_table
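

# A minimal usage sketch (hypothetical input): the worker payload below only
# contains the fields this module actually reads; real gRPC replies carry
# more. It assumes the 20-byte object ref format described above.
if __name__ == "__main__":
    sample_workers_info = [{
        "pid": 1234,
        "isDriver": False,
        "coreWorkerStats": {
            "ipAddress": "127.0.0.1",
            "objectRefs": [{
                # A 20-byte ref given as a 40-character hex string.
                "objectId": "a" * 40,
                "objectSize": 1024,
                "callSite": "<example call site>",
                "localRefCount": 1,
            }],
        },
    }]
    table = construct_memory_table(sample_workers_info)
    print(table.summary)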