[serve] Support kwargs to deployment constructor (#19023)

This commit is contained in:
Edward Oakes 2021-10-06 14:16:23 -05:00 committed by GitHub
parent 77d0a08c38
commit 9316a9977f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 147 additions and 17 deletions

View file

@ -188,7 +188,8 @@ class Client:
def deploy(self,
name: str,
backend_def: Union[Callable, Type[Callable], str],
*init_args: Any,
init_args: Tuple[Any],
init_kwargs: Dict[Any, Any],
ray_actor_options: Optional[Dict] = None,
config: Optional[Union[BackendConfig, Dict[str, Any]]] = None,
version: Optional[str] = None,
@ -212,7 +213,10 @@ class Client:
del ray_actor_options["runtime_env"]["working_dir"]
replica_config = ReplicaConfig(
backend_def, *init_args, ray_actor_options=ray_actor_options)
backend_def,
init_args=init_args,
init_kwargs=init_kwargs,
ray_actor_options=ray_actor_options)
if isinstance(config, dict):
backend_config = BackendConfig.parse_obj(config)
@ -601,6 +605,7 @@ class Deployment:
version: Optional[str] = None,
prev_version: Optional[str] = None,
init_args: Optional[Tuple[Any]] = None,
init_kwargs: Optional[Tuple[Any]] = None,
route_prefix: Optional[str] = None,
ray_actor_options: Optional[Dict] = None,
_internal=False) -> None:
@ -626,6 +631,8 @@ class Deployment:
raise TypeError("prev_version must be a string.")
if not (init_args is None or isinstance(init_args, tuple)):
raise TypeError("init_args must be a tuple.")
if not (init_kwargs is None or isinstance(init_kwargs, dict)):
raise TypeError("init_kwargs must be a dict.")
if route_prefix is not None:
if not isinstance(route_prefix, str):
raise TypeError("route_prefix must be a string.")
@ -642,6 +649,8 @@ class Deployment:
if init_args is None:
init_args = ()
if init_kwargs is None:
init_kwargs = {}
# TODO(architkulkarni): Enforce that autoscaling_config and
# user-provided num_replicas should be mutually exclusive.
@ -657,6 +666,7 @@ class Deployment:
self._prev_version = prev_version
self._config = config
self._init_args = init_args
self._init_kwargs = init_kwargs
self._route_prefix = route_prefix
self._ray_actor_options = ray_actor_options
@ -714,7 +724,12 @@ class Deployment:
@property
def init_args(self) -> Tuple[Any]:
"""Arguments passed to the underlying class's constructor."""
"""Positional args passed to the underlying class's constructor."""
return self._init_args
@property
def init_kwargs(self) -> Tuple[Any]:
"""Keyword args passed to the underlying class's constructor."""
return self._init_args
@property
@ -728,20 +743,25 @@ class Deployment:
"Use `deployment.deploy() instead.`")
@PublicAPI
def deploy(self, *init_args, _blocking=True):
def deploy(self, *init_args, _blocking=True, **init_kwargs):
"""Deploy or update this deployment.
Args:
init_args (optional): args to pass to the class __init__
method. Not valid if this deployment wraps a function.
init_kwargs (optional): kwargs to pass to the class __init__
method. Not valid if this deployment wraps a function.
"""
if len(init_args) == 0 and self._init_args is not None:
init_args = self._init_args
if len(init_kwargs) == 0 and self._init_kwargs is not None:
init_kwargs = self._init_kwargs
return _get_global_client().deploy(
self._name,
self._func_or_class,
*init_args,
init_args,
init_kwargs,
ray_actor_options=self._ray_actor_options,
config=self._config,
version=self._version,
@ -780,6 +800,7 @@ class Deployment:
version: Optional[str] = None,
prev_version: Optional[str] = None,
init_args: Optional[Tuple[Any]] = None,
init_kwargs: Optional[Dict[Any, Any]] = None,
route_prefix: Optional[str] = None,
num_replicas: Optional[int] = None,
ray_actor_options: Optional[Dict] = None,
@ -813,6 +834,9 @@ class Deployment:
if init_args is None:
init_args = self._init_args
if init_kwargs is None:
init_kwargs = self._init_kwargs
if route_prefix is None:
if self._route_prefix == f"/{self._name}":
route_prefix = None
@ -832,6 +856,7 @@ class Deployment:
version=version,
prev_version=prev_version,
init_args=init_args,
init_kwargs=init_kwargs,
route_prefix=route_prefix,
ray_actor_options=ray_actor_options,
_internal=True,
@ -843,6 +868,7 @@ class Deployment:
self._version == other._version,
self._config == other._config,
self._init_args == other._init_args,
self._init_kwargs == other._init_kwargs,
self._route_prefix == other._route_prefix,
self._ray_actor_options == self._ray_actor_options,
])
@ -872,6 +898,7 @@ def deployment(
prev_version: Optional[str] = None,
num_replicas: Optional[int] = None,
init_args: Optional[Tuple[Any]] = None,
init_kwargs: Optional[Dict[Any, Any]] = None,
ray_actor_options: Optional[Dict] = None,
user_config: Optional[Any] = None,
max_concurrent_queries: Optional[int] = None,
@ -888,6 +915,7 @@ def deployment(
prev_version: Optional[str] = None,
num_replicas: Optional[int] = None,
init_args: Optional[Tuple[Any]] = None,
init_kwargs: Optional[Dict[Any, Any]] = None,
route_prefix: Optional[str] = None,
ray_actor_options: Optional[Dict] = None,
user_config: Optional[Any] = None,
@ -911,7 +939,10 @@ def deployment(
not check the existing deployment's version.
num_replicas (Optional[int]): The number of processes to start up that
will handle requests to this deployment. Defaults to 1.
init_args (Optional[Tuple]): Arguments to be passed to the class
init_args (Optional[Tuple]): Positional args to be passed to the class
constructor when starting up deployment replicas. These can also be
passed when you call `.deploy()` on the returned Deployment.
init_kwargs (Optional[Dict]): Keyword args to be passed to the class
constructor when starting up deployment replicas. These can also be
passed when you call `.deploy()` on the returned Deployment.
route_prefix (Optional[str]): Requests to paths under this HTTP path
@ -968,6 +999,7 @@ def deployment(
version=version,
prev_version=prev_version,
init_args=init_args,
init_kwargs=init_kwargs,
route_prefix=route_prefix,
ray_actor_options=ray_actor_options,
_internal=True,
@ -1009,6 +1041,7 @@ def get_deployment(name: str) -> Deployment:
backend_info.backend_config,
version=backend_info.version,
init_args=backend_info.replica_config.init_args,
init_kwargs=backend_info.replica_config.init_kwargs,
route_prefix=route_prefix,
ray_actor_options=backend_info.replica_config.ray_actor_options,
_internal=True,
@ -1032,6 +1065,7 @@ def list_deployments() -> Dict[str, Deployment]:
backend_info.backend_config,
version=backend_info.version,
init_args=backend_info.replica_config.init_args,
init_kwargs=backend_info.replica_config.init_kwargs,
route_prefix=route_prefix,
ray_actor_options=backend_info.replica_config.ray_actor_options,
_internal=True,

View file

@ -164,6 +164,7 @@ class ActorReplicaWrapper:
**backend_info.replica_config.ray_actor_options).remote(
self.backend_tag, self.replica_tag,
backend_info.replica_config.init_args,
backend_info.replica_config.init_kwargs,
backend_info.backend_config.to_proto_bytes(), version,
self._controller_name, self._detached)

View file

@ -43,7 +43,7 @@ def create_backend_replica(name: str, serialized_backend_def: bytes):
# TODO(architkulkarni): Add type hints after upgrading cloudpickle
class RayServeWrappedReplica(object):
async def __init__(self, backend_tag, replica_tag, init_args,
backend_config_proto_bytes: bytes,
init_kwargs, backend_config_proto_bytes: bytes,
version: BackendVersion, controller_name: str,
detached: bool):
backend = cloudpickle.loads(serialized_backend_def)
@ -72,7 +72,8 @@ def create_backend_replica(name: str, serialized_backend_def: bytes):
# This allows backends to define an async __init__ method
# (required for FastAPI backend definition).
_callable = backend.__new__(backend)
await sync_to_async(_callable.__init__)(*init_args)
await sync_to_async(_callable.__init__)(*init_args,
**init_kwargs)
# Setting the context again to update the servable_object.
ray.serve.api._set_internal_replica_context(
backend_tag,

View file

@ -1,7 +1,7 @@
import inspect
import pickle
from enum import Enum
from typing import Any, List, Optional
from typing import Any, Callable, Dict, List, Optional, Tuple
import pydantic
from google.protobuf.json_format import MessageToDict
@ -124,16 +124,23 @@ class BackendConfig(BaseModel):
class ReplicaConfig:
def __init__(self, backend_def, *init_args, ray_actor_options=None):
def __init__(self,
backend_def: Callable,
init_args: Optional[Tuple[Any]] = None,
init_kwargs: Optional[Dict[Any, Any]] = None,
ray_actor_options=None):
# Validate that backend_def is an import path, function, or class.
if isinstance(backend_def, str):
self.func_or_class_name = backend_def
pass
elif inspect.isfunction(backend_def):
self.func_or_class_name = backend_def.__name__
if len(init_args) != 0:
if init_args:
raise ValueError(
"init_args not supported for function backend.")
if init_kwargs:
raise ValueError(
"init_kwargs not supported for function backend.")
elif inspect.isclass(backend_def):
self.func_or_class_name = backend_def.__name__
else:
@ -142,7 +149,8 @@ class ReplicaConfig:
format(type(backend_def)))
self.serialized_backend_def = cloudpickle.dumps(backend_def)
self.init_args = init_args
self.init_args = init_args if init_args is not None else ()
self.init_kwargs = init_kwargs if init_kwargs is not None else {}
if ray_actor_options is None:
self.ray_actor_options = {}
else:
@ -161,12 +169,13 @@ class ReplicaConfig:
raise TypeError("ray_actor_options must be a dictionary.")
elif "lifetime" in self.ray_actor_options:
raise ValueError(
"Specifying lifetime in init_args is not allowed.")
"Specifying lifetime in ray_actor_options is not allowed.")
elif "name" in self.ray_actor_options:
raise ValueError("Specifying name in init_args is not allowed.")
raise ValueError(
"Specifying name in ray_actor_options is not allowed.")
elif "max_restarts" in self.ray_actor_options:
raise ValueError("Specifying max_restarts in "
"init_args is not allowed.")
"ray_actor_options is not allowed.")
else:
# Ray defaults to zero CPUs for placement, we default to one here.
if "num_cpus" not in self.ray_actor_options:

View file

@ -20,8 +20,8 @@ from ray.serve.common import (
NodeId,
ReplicaTag,
)
from ray.serve.config import (BackendConfig, HTTPOptions, ReplicaConfig)
from ray.serve.constants import (CONTROL_LOOP_PERIOD_S, SERVE_ROOT_URL_ENV_KEY)
from ray.serve.config import BackendConfig, HTTPOptions, ReplicaConfig
from ray.serve.constants import CONTROL_LOOP_PERIOD_S, SERVE_ROOT_URL_ENV_KEY
from ray.serve.endpoint_state import EndpointState
from ray.serve.http_state import HTTPState
from ray.serve.storage.checkpoint_path import make_kv_store

View file

@ -52,6 +52,8 @@ def test_replica_config_validation():
# Check ray_actor_options validation.
ReplicaConfig(
Class,
tuple(),
dict(),
ray_actor_options={
"num_cpus": 1.0,
"num_gpus": 10,

View file

@ -734,6 +734,58 @@ def test_init_args(serve_instance):
check(10, 11, 12)
def test_init_kwargs(serve_instance):
with pytest.raises(TypeError):
@serve.deployment(init_kwargs=[1, 2, 3])
class BadInitArgs:
pass
@serve.deployment(init_kwargs={"a": 1, "b": 2})
class D:
def __init__(self, **kwargs):
self._kwargs = kwargs
def get_kwargs(self, *args):
return self._kwargs
D.deploy()
handle = D.get_handle()
def check(kwargs):
assert ray.get(handle.get_kwargs.remote()) == kwargs
# Basic sanity check.
check({"a": 1, "b": 2})
# Check passing args to `.deploy()`.
D.deploy(a=3, b=4)
check({"a": 3, "b": 4})
# Passing args to `.deploy()` shouldn't override those passed in decorator.
D.deploy()
check({"a": 1, "b": 2})
# Check setting with `.options()`.
new_D = D.options(init_kwargs={"c": 8, "d": 10})
new_D.deploy()
check({"c": 8, "d": 10})
# Should not have changed old deployment object.
D.deploy()
check({"a": 1, "b": 2})
# Check that args are only updated on version change.
D.options(version="1").deploy()
check({"a": 1, "b": 2})
D.options(version="1").deploy(c=10, d=11)
check({"a": 1, "b": 2})
D.options(version="2").deploy(c=10, d=11)
check({"c": 10, "d": 11})
def test_input_validation():
name = "test"

View file

@ -116,6 +116,37 @@ def test_init_args(serve_instance):
assert pid3 != pid2
def test_init_kwargs(serve_instance):
name = "test"
@serve.deployment(name=name)
class D:
def __init__(self, *, val=None):
assert val is not None
self._val = val
def __call__(self, *arg):
return self._val, os.getpid()
D.deploy(val="1")
val1, pid1 = ray.get(D.get_handle().remote())
assert val1 == "1"
del D
D2 = serve.get_deployment(name=name)
D2.deploy()
val2, pid2 = ray.get(D2.get_handle().remote())
assert val2 == "1"
assert pid2 != pid1
D2 = serve.get_deployment(name=name)
D2.deploy(val="2")
val3, pid3 = ray.get(D2.get_handle().remote())
assert val3 == "2"
assert pid3 != pid2
def test_scale_replicas(serve_instance):
name = "test"