[Serve] Use custom class name for replica class (#26574)

This commit is contained in:
Simon Mo 2022-07-14 20:10:56 -07:00 committed by GitHub
parent a304d1c145
commit df9f891416
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 29 additions and 4 deletions

View file

@ -38,6 +38,10 @@ from ray.serve.version import DeploymentVersion
logger = logging.getLogger(SERVE_LOGGER_NAME)
def _format_replica_actor_name(deployment_name: str):
return f"ServeReplica:{deployment_name}"
def create_replica_wrapper(name: str):
"""Creates a replica class wrapping the provided function or class.
@ -211,8 +215,13 @@ def create_replica_wrapper(name: str):
async def check_health(self):
await self.replica.check_health()
RayServeWrappedReplica.__name__ = name
return RayServeWrappedReplica
# Dynamically create a new class with custom name here so Ray picks it up
# correctly in actor metadata table and observability stack.
return type(
_format_replica_actor_name(name),
(RayServeWrappedReplica,),
dict(RayServeWrappedReplica.__dict__),
)
class RayServeReplica:
@ -333,7 +342,9 @@ class RayServeReplica:
def _get_handle_request_stats(self) -> Optional[Dict[str, int]]:
    """Return the actor call stats recorded for ``handle_request``.

    The stats are looked up under the key
    ``"<replica actor name>.handle_request"``, where the replica actor name
    is built from ``self.deployment_name`` by ``_format_replica_actor_name``.

    Returns:
        The entry stored under that key, or ``None`` when the stats mapping
        has no such entry (the default of ``.get``).
    """
    actor_stats = ray.runtime_context.get_runtime_context()._get_actor_call_stats()
    # The replica class is created dynamically under the formatted name (see
    # create_replica_wrapper), so the stats are keyed by that name and not by
    # the static "RayServeWrappedReplica" class name.
    stats_key = f"{_format_replica_actor_name(self.deployment_name)}.handle_request"
    return actor_stats.get(stats_key)
def _collect_autoscaling_metrics(self):

View file

@ -181,7 +181,7 @@ def test_intelligent_scale_down(ray_cluster):
actors = ray._private.state.actors()
node_to_actors = defaultdict(list)
for actor in actors.values():
if "RayServeWrappedReplica" not in actor["ActorClassName"]:
if "ServeReplica" not in actor["ActorClassName"]:
continue
if actor["State"] != "ALIVE":
continue

View file

@ -7,6 +7,7 @@ import ray
from ray import serve
from ray._private.test_utils import wait_for_condition
from ray.serve.utils import block_until_http_ready
import ray.experimental.state.api as state_api
def test_serve_metrics_for_successful_connection(serve_instance):
@ -142,6 +143,19 @@ def test_http_metrics(serve_instance):
verify_error_count(do_assert=True)
def test_actor_summary(serve_instance):
    """Check that Serve actors are listed by the state API under the expected class names."""

    @serve.deployment
    def f():
        pass

    serve.run(f.bind())

    # Gather the distinct class names of every actor the state API reports.
    reported_class_names = set()
    for actor_info in state_api.list_actors():
        reported_class_names.add(actor_info["class_name"])

    # The replica entry is "ServeReplica:" followed by the deployment name "f".
    expected_class_names = {"ServeController", "HTTPProxyActor", "ServeReplica:f"}
    assert expected_class_names.issubset(reported_class_names)
if __name__ == "__main__":
import sys