[Serve] Feature flag and turn off placement group usage. (#15865)

This commit is contained in:
Simon Mo 2021-05-19 15:43:46 -07:00 committed by GitHub
parent 4825f1b2a5
commit 7a5981f244
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 29 additions and 20 deletions

View file

@ -3,6 +3,7 @@ import time
from abc import ABC from abc import ABC
from collections import defaultdict from collections import defaultdict
from enum import Enum from enum import Enum
import os
from typing import Any, Dict, List, Optional, Tuple from typing import Any, Dict, List, Optional, Tuple
import ray.cloudpickle as pickle import ray.cloudpickle as pickle
@ -33,6 +34,7 @@ class ReplicaState(Enum):
ALL_REPLICA_STATES = list(ReplicaState) ALL_REPLICA_STATES = list(ReplicaState)
USE_PLACEMENT_GROUP = os.environ.get("SERVE_USE_PLACEMENT_GROUP", "0") == "1"
class ActorReplicaWrapper: class ActorReplicaWrapper:
@ -82,6 +84,10 @@ class ActorReplicaWrapper:
def start(self, backend_info: BackendInfo): def start(self, backend_info: BackendInfo):
self._actor_resources = backend_info.replica_config.resource_dict self._actor_resources = backend_info.replica_config.resource_dict
# Feature-flagged because placement groups don't handle
# newly added nodes.
# https://github.com/ray-project/ray/issues/15801
if USE_PLACEMENT_GROUP:
try: try:
self._placement_group = ray.util.get_placement_group( self._placement_group = ray.util.get_placement_group(
self._placement_group_name) self._placement_group_name)
@ -137,9 +143,11 @@ class ActorReplicaWrapper:
return True return True
try: try:
ray.get_actor(self._actor_name) handle = ray.get_actor(self._actor_name)
ready, _ = ray.wait([self._drain_obj_ref], timeout=0) ready, _ = ray.wait([self._drain_obj_ref], timeout=0)
self._stopped = len(ready) == 1 self._stopped = len(ready) == 1
if self._stopped:
ray.kill(handle, no_restart=True)
except ValueError: except ValueError:
self._stopped = True self._stopped = True
@ -165,6 +173,9 @@ class ActorReplicaWrapper:
Currently, this just removes the placement group. Currently, this just removes the placement group.
""" """
if not USE_PLACEMENT_GROUP:
return
try: try:
ray.util.remove_placement_group( ray.util.remove_placement_group(
ray.util.get_placement_group(self._placement_group_name)) ray.util.get_placement_group(self._placement_group_name))

View file

@ -305,5 +305,3 @@ class RayServeReplica:
f"Waiting for an additional {sleep_time}s to shut down " f"Waiting for an additional {sleep_time}s to shut down "
f"because there are {self.num_ongoing_requests} " f"because there are {self.num_ongoing_requests} "
"ongoing requests.") "ongoing requests.")
ray.actor.exit_actor()

View file

@ -75,7 +75,7 @@ def test_node_failure(ray_cluster):
worker_node = cluster.add_node(num_cpus=2) worker_node = cluster.add_node(num_cpus=2)
@serve.deployment(version="1", num_replicas=3) @serve.deployment(version="1", num_replicas=5)
def D(*args): def D(*args):
return os.getpid() return os.getpid()
@ -92,24 +92,24 @@ def test_node_failure(ray_cluster):
print("Initial deploy.") print("Initial deploy.")
D.deploy() D.deploy()
pids1 = get_pids(3) pids1 = get_pids(5)
# Remove the node. There should still be one replica running. # Remove the node. There should still be three replicas running.
print("Kill node.") print("Kill node.")
cluster.remove_node(worker_node) cluster.remove_node(worker_node)
pids2 = get_pids(1) pids2 = get_pids(3)
assert pids2.issubset(pids1) assert pids2.issubset(pids1)
# Add a worker node back. One replica should get placed. # Add a worker node back. One replica should get placed.
print("Add back first node.") print("Add back first node.")
cluster.add_node(num_cpus=1) cluster.add_node(num_cpus=1)
pids3 = get_pids(2) pids3 = get_pids(4)
assert pids2.issubset(pids3) assert pids2.issubset(pids3)
# Add another worker node. One more replica should get placed. # Add another worker node. One more replica should get placed.
print("Add back second node.") print("Add back second node.")
cluster.add_node(num_cpus=1) cluster.add_node(num_cpus=1)
pids4 = get_pids(3) pids4 = get_pids(5)
assert pids3.issubset(pids4) assert pids3.issubset(pids4)