[serve] Support kwargs to deployment constructor (#19023)

This commit is contained in:
Edward Oakes 2021-10-06 14:16:23 -05:00 committed by GitHub
parent 77d0a08c38
commit 9316a9977f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 147 additions and 17 deletions

View file

@ -188,7 +188,8 @@ class Client:
def deploy(self, def deploy(self,
name: str, name: str,
backend_def: Union[Callable, Type[Callable], str], backend_def: Union[Callable, Type[Callable], str],
*init_args: Any, init_args: Tuple[Any],
init_kwargs: Dict[Any, Any],
ray_actor_options: Optional[Dict] = None, ray_actor_options: Optional[Dict] = None,
config: Optional[Union[BackendConfig, Dict[str, Any]]] = None, config: Optional[Union[BackendConfig, Dict[str, Any]]] = None,
version: Optional[str] = None, version: Optional[str] = None,
@ -212,7 +213,10 @@ class Client:
del ray_actor_options["runtime_env"]["working_dir"] del ray_actor_options["runtime_env"]["working_dir"]
replica_config = ReplicaConfig( replica_config = ReplicaConfig(
backend_def, *init_args, ray_actor_options=ray_actor_options) backend_def,
init_args=init_args,
init_kwargs=init_kwargs,
ray_actor_options=ray_actor_options)
if isinstance(config, dict): if isinstance(config, dict):
backend_config = BackendConfig.parse_obj(config) backend_config = BackendConfig.parse_obj(config)
@ -601,6 +605,7 @@ class Deployment:
version: Optional[str] = None, version: Optional[str] = None,
prev_version: Optional[str] = None, prev_version: Optional[str] = None,
init_args: Optional[Tuple[Any]] = None, init_args: Optional[Tuple[Any]] = None,
init_kwargs: Optional[Tuple[Any]] = None,
route_prefix: Optional[str] = None, route_prefix: Optional[str] = None,
ray_actor_options: Optional[Dict] = None, ray_actor_options: Optional[Dict] = None,
_internal=False) -> None: _internal=False) -> None:
@ -626,6 +631,8 @@ class Deployment:
raise TypeError("prev_version must be a string.") raise TypeError("prev_version must be a string.")
if not (init_args is None or isinstance(init_args, tuple)): if not (init_args is None or isinstance(init_args, tuple)):
raise TypeError("init_args must be a tuple.") raise TypeError("init_args must be a tuple.")
if not (init_kwargs is None or isinstance(init_kwargs, dict)):
raise TypeError("init_kwargs must be a dict.")
if route_prefix is not None: if route_prefix is not None:
if not isinstance(route_prefix, str): if not isinstance(route_prefix, str):
raise TypeError("route_prefix must be a string.") raise TypeError("route_prefix must be a string.")
@ -642,6 +649,8 @@ class Deployment:
if init_args is None: if init_args is None:
init_args = () init_args = ()
if init_kwargs is None:
init_kwargs = {}
# TODO(architkulkarni): Enforce that autoscaling_config and # TODO(architkulkarni): Enforce that autoscaling_config and
# user-provided num_replicas should be mutually exclusive. # user-provided num_replicas should be mutually exclusive.
@ -657,6 +666,7 @@ class Deployment:
self._prev_version = prev_version self._prev_version = prev_version
self._config = config self._config = config
self._init_args = init_args self._init_args = init_args
self._init_kwargs = init_kwargs
self._route_prefix = route_prefix self._route_prefix = route_prefix
self._ray_actor_options = ray_actor_options self._ray_actor_options = ray_actor_options
@ -714,7 +724,12 @@ class Deployment:
@property @property
def init_args(self) -> Tuple[Any]: def init_args(self) -> Tuple[Any]:
"""Arguments passed to the underlying class's constructor.""" """Positional args passed to the underlying class's constructor."""
return self._init_args
@property
def init_kwargs(self) -> Tuple[Any]:
"""Keyword args passed to the underlying class's constructor."""
return self._init_args return self._init_args
@property @property
@ -728,20 +743,25 @@ class Deployment:
"Use `deployment.deploy() instead.`") "Use `deployment.deploy() instead.`")
@PublicAPI @PublicAPI
def deploy(self, *init_args, _blocking=True): def deploy(self, *init_args, _blocking=True, **init_kwargs):
"""Deploy or update this deployment. """Deploy or update this deployment.
Args: Args:
init_args (optional): args to pass to the class __init__ init_args (optional): args to pass to the class __init__
method. Not valid if this deployment wraps a function. method. Not valid if this deployment wraps a function.
init_kwargs (optional): kwargs to pass to the class __init__
method. Not valid if this deployment wraps a function.
""" """
if len(init_args) == 0 and self._init_args is not None: if len(init_args) == 0 and self._init_args is not None:
init_args = self._init_args init_args = self._init_args
if len(init_kwargs) == 0 and self._init_kwargs is not None:
init_kwargs = self._init_kwargs
return _get_global_client().deploy( return _get_global_client().deploy(
self._name, self._name,
self._func_or_class, self._func_or_class,
*init_args, init_args,
init_kwargs,
ray_actor_options=self._ray_actor_options, ray_actor_options=self._ray_actor_options,
config=self._config, config=self._config,
version=self._version, version=self._version,
@ -780,6 +800,7 @@ class Deployment:
version: Optional[str] = None, version: Optional[str] = None,
prev_version: Optional[str] = None, prev_version: Optional[str] = None,
init_args: Optional[Tuple[Any]] = None, init_args: Optional[Tuple[Any]] = None,
init_kwargs: Optional[Dict[Any, Any]] = None,
route_prefix: Optional[str] = None, route_prefix: Optional[str] = None,
num_replicas: Optional[int] = None, num_replicas: Optional[int] = None,
ray_actor_options: Optional[Dict] = None, ray_actor_options: Optional[Dict] = None,
@ -813,6 +834,9 @@ class Deployment:
if init_args is None: if init_args is None:
init_args = self._init_args init_args = self._init_args
if init_kwargs is None:
init_kwargs = self._init_kwargs
if route_prefix is None: if route_prefix is None:
if self._route_prefix == f"/{self._name}": if self._route_prefix == f"/{self._name}":
route_prefix = None route_prefix = None
@ -832,6 +856,7 @@ class Deployment:
version=version, version=version,
prev_version=prev_version, prev_version=prev_version,
init_args=init_args, init_args=init_args,
init_kwargs=init_kwargs,
route_prefix=route_prefix, route_prefix=route_prefix,
ray_actor_options=ray_actor_options, ray_actor_options=ray_actor_options,
_internal=True, _internal=True,
@ -843,6 +868,7 @@ class Deployment:
self._version == other._version, self._version == other._version,
self._config == other._config, self._config == other._config,
self._init_args == other._init_args, self._init_args == other._init_args,
self._init_kwargs == other._init_kwargs,
self._route_prefix == other._route_prefix, self._route_prefix == other._route_prefix,
self._ray_actor_options == self._ray_actor_options, self._ray_actor_options == self._ray_actor_options,
]) ])
@ -872,6 +898,7 @@ def deployment(
prev_version: Optional[str] = None, prev_version: Optional[str] = None,
num_replicas: Optional[int] = None, num_replicas: Optional[int] = None,
init_args: Optional[Tuple[Any]] = None, init_args: Optional[Tuple[Any]] = None,
init_kwargs: Optional[Dict[Any, Any]] = None,
ray_actor_options: Optional[Dict] = None, ray_actor_options: Optional[Dict] = None,
user_config: Optional[Any] = None, user_config: Optional[Any] = None,
max_concurrent_queries: Optional[int] = None, max_concurrent_queries: Optional[int] = None,
@ -888,6 +915,7 @@ def deployment(
prev_version: Optional[str] = None, prev_version: Optional[str] = None,
num_replicas: Optional[int] = None, num_replicas: Optional[int] = None,
init_args: Optional[Tuple[Any]] = None, init_args: Optional[Tuple[Any]] = None,
init_kwargs: Optional[Dict[Any, Any]] = None,
route_prefix: Optional[str] = None, route_prefix: Optional[str] = None,
ray_actor_options: Optional[Dict] = None, ray_actor_options: Optional[Dict] = None,
user_config: Optional[Any] = None, user_config: Optional[Any] = None,
@ -911,7 +939,10 @@ def deployment(
not check the existing deployment's version. not check the existing deployment's version.
num_replicas (Optional[int]): The number of processes to start up that num_replicas (Optional[int]): The number of processes to start up that
will handle requests to this deployment. Defaults to 1. will handle requests to this deployment. Defaults to 1.
init_args (Optional[Tuple]): Arguments to be passed to the class init_args (Optional[Tuple]): Positional args to be passed to the class
constructor when starting up deployment replicas. These can also be
passed when you call `.deploy()` on the returned Deployment.
init_kwargs (Optional[Dict]): Keyword args to be passed to the class
constructor when starting up deployment replicas. These can also be constructor when starting up deployment replicas. These can also be
passed when you call `.deploy()` on the returned Deployment. passed when you call `.deploy()` on the returned Deployment.
route_prefix (Optional[str]): Requests to paths under this HTTP path route_prefix (Optional[str]): Requests to paths under this HTTP path
@ -968,6 +999,7 @@ def deployment(
version=version, version=version,
prev_version=prev_version, prev_version=prev_version,
init_args=init_args, init_args=init_args,
init_kwargs=init_kwargs,
route_prefix=route_prefix, route_prefix=route_prefix,
ray_actor_options=ray_actor_options, ray_actor_options=ray_actor_options,
_internal=True, _internal=True,
@ -1009,6 +1041,7 @@ def get_deployment(name: str) -> Deployment:
backend_info.backend_config, backend_info.backend_config,
version=backend_info.version, version=backend_info.version,
init_args=backend_info.replica_config.init_args, init_args=backend_info.replica_config.init_args,
init_kwargs=backend_info.replica_config.init_kwargs,
route_prefix=route_prefix, route_prefix=route_prefix,
ray_actor_options=backend_info.replica_config.ray_actor_options, ray_actor_options=backend_info.replica_config.ray_actor_options,
_internal=True, _internal=True,
@ -1032,6 +1065,7 @@ def list_deployments() -> Dict[str, Deployment]:
backend_info.backend_config, backend_info.backend_config,
version=backend_info.version, version=backend_info.version,
init_args=backend_info.replica_config.init_args, init_args=backend_info.replica_config.init_args,
init_kwargs=backend_info.replica_config.init_kwargs,
route_prefix=route_prefix, route_prefix=route_prefix,
ray_actor_options=backend_info.replica_config.ray_actor_options, ray_actor_options=backend_info.replica_config.ray_actor_options,
_internal=True, _internal=True,

View file

@ -164,6 +164,7 @@ class ActorReplicaWrapper:
**backend_info.replica_config.ray_actor_options).remote( **backend_info.replica_config.ray_actor_options).remote(
self.backend_tag, self.replica_tag, self.backend_tag, self.replica_tag,
backend_info.replica_config.init_args, backend_info.replica_config.init_args,
backend_info.replica_config.init_kwargs,
backend_info.backend_config.to_proto_bytes(), version, backend_info.backend_config.to_proto_bytes(), version,
self._controller_name, self._detached) self._controller_name, self._detached)

View file

@ -43,7 +43,7 @@ def create_backend_replica(name: str, serialized_backend_def: bytes):
# TODO(architkulkarni): Add type hints after upgrading cloudpickle # TODO(architkulkarni): Add type hints after upgrading cloudpickle
class RayServeWrappedReplica(object): class RayServeWrappedReplica(object):
async def __init__(self, backend_tag, replica_tag, init_args, async def __init__(self, backend_tag, replica_tag, init_args,
backend_config_proto_bytes: bytes, init_kwargs, backend_config_proto_bytes: bytes,
version: BackendVersion, controller_name: str, version: BackendVersion, controller_name: str,
detached: bool): detached: bool):
backend = cloudpickle.loads(serialized_backend_def) backend = cloudpickle.loads(serialized_backend_def)
@ -72,7 +72,8 @@ def create_backend_replica(name: str, serialized_backend_def: bytes):
# This allows backends to define an async __init__ method # This allows backends to define an async __init__ method
# (required for FastAPI backend definition). # (required for FastAPI backend definition).
_callable = backend.__new__(backend) _callable = backend.__new__(backend)
await sync_to_async(_callable.__init__)(*init_args) await sync_to_async(_callable.__init__)(*init_args,
**init_kwargs)
# Setting the context again to update the servable_object. # Setting the context again to update the servable_object.
ray.serve.api._set_internal_replica_context( ray.serve.api._set_internal_replica_context(
backend_tag, backend_tag,

View file

@ -1,7 +1,7 @@
import inspect import inspect
import pickle import pickle
from enum import Enum from enum import Enum
from typing import Any, List, Optional from typing import Any, Callable, Dict, List, Optional, Tuple
import pydantic import pydantic
from google.protobuf.json_format import MessageToDict from google.protobuf.json_format import MessageToDict
@ -124,16 +124,23 @@ class BackendConfig(BaseModel):
class ReplicaConfig: class ReplicaConfig:
def __init__(self, backend_def, *init_args, ray_actor_options=None): def __init__(self,
backend_def: Callable,
init_args: Optional[Tuple[Any]] = None,
init_kwargs: Optional[Dict[Any, Any]] = None,
ray_actor_options=None):
# Validate that backend_def is an import path, function, or class. # Validate that backend_def is an import path, function, or class.
if isinstance(backend_def, str): if isinstance(backend_def, str):
self.func_or_class_name = backend_def self.func_or_class_name = backend_def
pass pass
elif inspect.isfunction(backend_def): elif inspect.isfunction(backend_def):
self.func_or_class_name = backend_def.__name__ self.func_or_class_name = backend_def.__name__
if len(init_args) != 0: if init_args:
raise ValueError( raise ValueError(
"init_args not supported for function backend.") "init_args not supported for function backend.")
if init_kwargs:
raise ValueError(
"init_kwargs not supported for function backend.")
elif inspect.isclass(backend_def): elif inspect.isclass(backend_def):
self.func_or_class_name = backend_def.__name__ self.func_or_class_name = backend_def.__name__
else: else:
@ -142,7 +149,8 @@ class ReplicaConfig:
format(type(backend_def))) format(type(backend_def)))
self.serialized_backend_def = cloudpickle.dumps(backend_def) self.serialized_backend_def = cloudpickle.dumps(backend_def)
self.init_args = init_args self.init_args = init_args if init_args is not None else ()
self.init_kwargs = init_kwargs if init_kwargs is not None else {}
if ray_actor_options is None: if ray_actor_options is None:
self.ray_actor_options = {} self.ray_actor_options = {}
else: else:
@ -161,12 +169,13 @@ class ReplicaConfig:
raise TypeError("ray_actor_options must be a dictionary.") raise TypeError("ray_actor_options must be a dictionary.")
elif "lifetime" in self.ray_actor_options: elif "lifetime" in self.ray_actor_options:
raise ValueError( raise ValueError(
"Specifying lifetime in init_args is not allowed.") "Specifying lifetime in ray_actor_options is not allowed.")
elif "name" in self.ray_actor_options: elif "name" in self.ray_actor_options:
raise ValueError("Specifying name in init_args is not allowed.") raise ValueError(
"Specifying name in ray_actor_options is not allowed.")
elif "max_restarts" in self.ray_actor_options: elif "max_restarts" in self.ray_actor_options:
raise ValueError("Specifying max_restarts in " raise ValueError("Specifying max_restarts in "
"init_args is not allowed.") "ray_actor_options is not allowed.")
else: else:
# Ray defaults to zero CPUs for placement, we default to one here. # Ray defaults to zero CPUs for placement, we default to one here.
if "num_cpus" not in self.ray_actor_options: if "num_cpus" not in self.ray_actor_options:

View file

@ -20,8 +20,8 @@ from ray.serve.common import (
NodeId, NodeId,
ReplicaTag, ReplicaTag,
) )
from ray.serve.config import (BackendConfig, HTTPOptions, ReplicaConfig) from ray.serve.config import BackendConfig, HTTPOptions, ReplicaConfig
from ray.serve.constants import (CONTROL_LOOP_PERIOD_S, SERVE_ROOT_URL_ENV_KEY) from ray.serve.constants import CONTROL_LOOP_PERIOD_S, SERVE_ROOT_URL_ENV_KEY
from ray.serve.endpoint_state import EndpointState from ray.serve.endpoint_state import EndpointState
from ray.serve.http_state import HTTPState from ray.serve.http_state import HTTPState
from ray.serve.storage.checkpoint_path import make_kv_store from ray.serve.storage.checkpoint_path import make_kv_store

View file

@ -52,6 +52,8 @@ def test_replica_config_validation():
# Check ray_actor_options validation. # Check ray_actor_options validation.
ReplicaConfig( ReplicaConfig(
Class, Class,
tuple(),
dict(),
ray_actor_options={ ray_actor_options={
"num_cpus": 1.0, "num_cpus": 1.0,
"num_gpus": 10, "num_gpus": 10,

View file

@ -734,6 +734,58 @@ def test_init_args(serve_instance):
check(10, 11, 12) check(10, 11, 12)
def test_init_kwargs(serve_instance):
with pytest.raises(TypeError):
@serve.deployment(init_kwargs=[1, 2, 3])
class BadInitArgs:
pass
@serve.deployment(init_kwargs={"a": 1, "b": 2})
class D:
def __init__(self, **kwargs):
self._kwargs = kwargs
def get_kwargs(self, *args):
return self._kwargs
D.deploy()
handle = D.get_handle()
def check(kwargs):
assert ray.get(handle.get_kwargs.remote()) == kwargs
# Basic sanity check.
check({"a": 1, "b": 2})
# Check passing args to `.deploy()`.
D.deploy(a=3, b=4)
check({"a": 3, "b": 4})
# Passing args to `.deploy()` shouldn't override those passed in decorator.
D.deploy()
check({"a": 1, "b": 2})
# Check setting with `.options()`.
new_D = D.options(init_kwargs={"c": 8, "d": 10})
new_D.deploy()
check({"c": 8, "d": 10})
# Should not have changed old deployment object.
D.deploy()
check({"a": 1, "b": 2})
# Check that args are only updated on version change.
D.options(version="1").deploy()
check({"a": 1, "b": 2})
D.options(version="1").deploy(c=10, d=11)
check({"a": 1, "b": 2})
D.options(version="2").deploy(c=10, d=11)
check({"c": 10, "d": 11})
def test_input_validation(): def test_input_validation():
name = "test" name = "test"

View file

@ -116,6 +116,37 @@ def test_init_args(serve_instance):
assert pid3 != pid2 assert pid3 != pid2
def test_init_kwargs(serve_instance):
name = "test"
@serve.deployment(name=name)
class D:
def __init__(self, *, val=None):
assert val is not None
self._val = val
def __call__(self, *arg):
return self._val, os.getpid()
D.deploy(val="1")
val1, pid1 = ray.get(D.get_handle().remote())
assert val1 == "1"
del D
D2 = serve.get_deployment(name=name)
D2.deploy()
val2, pid2 = ray.get(D2.get_handle().remote())
assert val2 == "1"
assert pid2 != pid1
D2 = serve.get_deployment(name=name)
D2.deploy(val="2")
val3, pid3 = ray.get(D2.get_handle().remote())
assert val3 == "2"
assert pid3 != pid2
def test_scale_replicas(serve_instance): def test_scale_replicas(serve_instance):
name = "test" name = "test"