[serve] Fix whacky worker replica failure test (#13696)

This commit is contained in:
Edward Oakes 2021-01-27 14:05:37 -06:00 committed by GitHub
parent 2d34e95c93
commit 06fac785b8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -1,13 +1,11 @@
import os
import requests
import sys
import tempfile
import time
import pytest
import ray
from ray.test_utils import wait_for_condition
from ray import serve
from ray.serve.config import BackendConfig, ReplicaConfig
@ -160,34 +158,30 @@ def test_worker_restart(serve_instance):
def test_worker_replica_failure(serve_instance):
client = serve_instance
@ray.remote
class Counter:
def __init__(self):
self.count = 0
def inc_and_get(self):
self.count += 1
return self.count
class Worker:
# Assumes that two replicas are started. Will hang forever in the
# constructor for any workers that are restarted.
def __init__(self, path):
def __init__(self, counter):
self.should_hang = False
if not os.path.exists(path):
with open(path, "w") as f:
f.write("1")
else:
with open(path, "r") as f:
num = int(f.read())
with open(path, "w") as f:
if num == 2:
self.should_hang = True
else:
f.write(str(num + 1))
if self.should_hang:
self.index = ray.get(counter.inc_and_get.remote())
if self.index > 2:
while True:
pass
def __call__(self, *args):
pass
return self.index
temp_path = os.path.join(tempfile.gettempdir(),
serve.utils.get_random_letters())
client.create_backend("replica_failure", Worker, temp_path)
counter = Counter.remote()
client.create_backend("replica_failure", Worker, counter)
client.update_backend_config(
"replica_failure", BackendConfig(num_replicas=2))
client.create_endpoint(
@ -195,9 +189,16 @@ def test_worker_replica_failure(serve_instance):
# Wait until both replicas have been started.
responses = set()
while len(responses) == 1:
responses.add(request_with_retries("/replica_failure", timeout=1).text)
start = time.time()
while time.time() - start < 30:
time.sleep(0.1)
response = request_with_retries("/replica_failure", timeout=1).text
assert response in ["1", "2"]
responses.add(response)
if len(responses) > 1:
break
else:
raise TimeoutError("Timed out waiting for replicas after 30s.")
# Kill one of the replicas.
handles = _get_worker_handles(client, "replica_failure")
@ -263,6 +264,4 @@ def test_create_endpoint_idempotent(serve_instance):
if __name__ == "__main__":
import sys
import pytest
sys.exit(pytest.main(["-v", "-s", __file__]))