"""
|
||
|
This test is parity of
|
||
|
release/serve_tests/workloads/deployment_graph_long_chain.py
|
||
|
Instead of using graph api, the test is using pure handle to
|
||
|
test long chain graph.
|
||
|
|
||
|
INPUT -> Node_1 -> Node_2 -> ... -> Node_10 -> OUTPUT
|
||
|
|
||
|
1) Intermediate blob size can be large / small
|
||
|
2) Compute time each node can be long / short
|
||
|
3) Init time can be long / short
|
||
|
"""
|
||
|
|
||
|
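
# NOTE: The graph-API counterpart referenced above builds the same chain
# declaratively; this file instead deploys each Node explicitly and wires the
# chain together through Serve handles obtained from the Serve client.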

import time
import asyncio
import click

from typing import Optional

import ray
from ray import serve
from ray.serve.context import get_global_client
from serve_test_cluster_utils import (
    setup_local_single_node_cluster,
    setup_anyscale_cluster,
)
from serve_test_utils import save_test_results
from benchmark_utils import benchmark_throughput_tps, benchmark_latency_ms

DEFAULT_CHAIN_LENGTH = 4

DEFAULT_NUM_REQUESTS_PER_CLIENT = 20  # Requests sent per client in the latency test
DEFAULT_NUM_CLIENTS = 1  # Clients concurrently sending requests to the deployment

DEFAULT_THROUGHPUT_TRIAL_DURATION_SECS = 10


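# Each Node adds 1 to its input and forwards the request to the previous node
# in the chain, so a request into the last node of an N-node chain returns N;
# the sanity checks in main() rely on this invariant.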
@serve.deployment
class Node:
    def __init__(
        self,
        id: int,
        prev_node=None,
        init_delay_secs=0,
        compute_delay_secs=0,
        sync_handle=True,
    ):
        time.sleep(init_delay_secs)
        self.id = id
        self.prev_node = prev_node
        self.compute_delay_secs = compute_delay_secs
        self.sync_handle = sync_handle

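    # With a sync handle, .remote() returns an ObjectRef that a single await
    # resolves; with an async handle, .remote() returns a coroutine that must
    # first be awaited to get the ObjectRef, hence the double await below.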
    async def predict(self, input_data: int):
        await asyncio.sleep(self.compute_delay_secs)
        if self.prev_node:
            if self.sync_handle:
                return await self.prev_node.predict.remote(input_data) + 1
            else:
                return await (await self.prev_node.predict.remote(input_data)) + 1
        else:
            return input_data + 1


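# Deploy chain_length Node deployments, each holding a handle to the node
# deployed before it; the returned handle points at the last node in the
# chain, which is the entry point for requests.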
def construct_long_chain_graph_with_pure_handle(
    chain_length, sync_handle: bool, init_delay_secs=0, compute_delay_secs=0
):
    prev_handle = None
    for id in range(chain_length):
        Node.options(name=str(id)).deploy(
            id, prev_handle, init_delay_secs, compute_delay_secs, sync_handle
        )
        prev_handle = get_global_client().get_handle(str(id), sync=sync_handle)
    return prev_handle


async def sanity_check_graph_deployment_with_async_handle(handle, expected_result):
    assert await (await handle.predict.remote(0)) == expected_result


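# Example invocation (hypothetical file name; flags match the click options
# declared below):
#   python long_chain_with_pure_handle.py --chain-length 8 --sync-handle false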
@click.command()
@click.option("--chain-length", type=int, default=DEFAULT_CHAIN_LENGTH)
@click.option("--init-delay-secs", type=int, default=0)
@click.option("--compute-delay-secs", type=int, default=0)
@click.option(
    "--num-requests-per-client",
    type=int,
    default=DEFAULT_NUM_REQUESTS_PER_CLIENT,
)
@click.option("--num-clients", type=int, default=DEFAULT_NUM_CLIENTS)
@click.option(
    "--throughput-trial-duration-secs",
    type=int,
    default=DEFAULT_THROUGHPUT_TRIAL_DURATION_SECS,
)
@click.option("--local-test", type=bool, default=True)
@click.option("--sync-handle", type=bool, default=True)
def main(
    chain_length: Optional[int],
    init_delay_secs: Optional[int],
    compute_delay_secs: Optional[int],
    num_requests_per_client: Optional[int],
    num_clients: Optional[int],
    throughput_trial_duration_secs: Optional[int],
    local_test: Optional[bool],
    sync_handle: Optional[bool],
):
    if local_test:
        setup_local_single_node_cluster(1, num_cpu_per_node=8)
    else:
        setup_anyscale_cluster()

    handle = construct_long_chain_graph_with_pure_handle(
        chain_length,
        sync_handle,
        init_delay_secs=init_delay_secs,
        compute_delay_secs=compute_delay_secs,
    )

    loop = asyncio.get_event_loop()
    if sync_handle:
        assert ray.get(handle.predict.remote(0)) == chain_length
    else:
        # The async sanity check is a coroutine, so it must be driven by the
        # event loop rather than called directly.
        loop.run_until_complete(
            sanity_check_graph_deployment_with_async_handle(handle, chain_length)
        )

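    # Benchmark end-to-end throughput (transactions per second) and latency
    # (milliseconds) through the full chain, using the helpers imported from
    # benchmark_utils.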
    throughput_mean_tps, throughput_std_tps = loop.run_until_complete(
        benchmark_throughput_tps(
            handle,
            chain_length,
            duration_secs=throughput_trial_duration_secs,
            num_clients=num_clients,
        )
    )
    latency_mean_ms, latency_std_ms = loop.run_until_complete(
        benchmark_latency_ms(
            handle,
            chain_length,
            num_requests=num_requests_per_client,
            num_clients=num_clients,
        )
    )

print(f"chain_length: {chain_length}, num_clients: {num_clients}")
|
||
|
print(f"latency_mean_ms: {latency_mean_ms}, " f"latency_std_ms: {latency_std_ms}")
|
||
|
print(
|
||
|
f"throughput_mean_tps: {throughput_mean_tps}, "
|
||
|
f"throughput_std_tps: {throughput_std_tps}"
|
||
|
)
|
||
|
|
||
|
    results = {
        "chain_length": chain_length,
        "init_delay_secs": init_delay_secs,
        "compute_delay_secs": compute_delay_secs,
        "local_test": local_test,
        "sync_handle": sync_handle,
    }
    results["perf_metrics"] = [
        {
            "perf_metric_name": "throughput_mean_tps",
            "perf_metric_value": throughput_mean_tps,
            "perf_metric_type": "THROUGHPUT",
        },
        {
            "perf_metric_name": "throughput_std_tps",
            "perf_metric_value": throughput_std_tps,
            "perf_metric_type": "THROUGHPUT",
        },
        {
            "perf_metric_name": "latency_mean_ms",
            "perf_metric_value": latency_mean_ms,
            "perf_metric_type": "LATENCY",
        },
        {
            "perf_metric_name": "latency_std_ms",
            "perf_metric_value": latency_std_ms,
            "perf_metric_type": "LATENCY",
        },
    ]
    save_test_results(results)


if __name__ == "__main__":
    main()

    import sys
    import pytest

    sys.exit(pytest.main(["-v", "-s", __file__]))