mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
[serve] Support kwargs to deployment constructor (#19023)
This commit is contained in:
parent
77d0a08c38
commit
9316a9977f
8 changed files with 147 additions and 17 deletions
|
@ -188,7 +188,8 @@ class Client:
|
|||
def deploy(self,
|
||||
name: str,
|
||||
backend_def: Union[Callable, Type[Callable], str],
|
||||
*init_args: Any,
|
||||
init_args: Tuple[Any],
|
||||
init_kwargs: Dict[Any, Any],
|
||||
ray_actor_options: Optional[Dict] = None,
|
||||
config: Optional[Union[BackendConfig, Dict[str, Any]]] = None,
|
||||
version: Optional[str] = None,
|
||||
|
@ -212,7 +213,10 @@ class Client:
|
|||
del ray_actor_options["runtime_env"]["working_dir"]
|
||||
|
||||
replica_config = ReplicaConfig(
|
||||
backend_def, *init_args, ray_actor_options=ray_actor_options)
|
||||
backend_def,
|
||||
init_args=init_args,
|
||||
init_kwargs=init_kwargs,
|
||||
ray_actor_options=ray_actor_options)
|
||||
|
||||
if isinstance(config, dict):
|
||||
backend_config = BackendConfig.parse_obj(config)
|
||||
|
@ -601,6 +605,7 @@ class Deployment:
|
|||
version: Optional[str] = None,
|
||||
prev_version: Optional[str] = None,
|
||||
init_args: Optional[Tuple[Any]] = None,
|
||||
init_kwargs: Optional[Tuple[Any]] = None,
|
||||
route_prefix: Optional[str] = None,
|
||||
ray_actor_options: Optional[Dict] = None,
|
||||
_internal=False) -> None:
|
||||
|
@ -626,6 +631,8 @@ class Deployment:
|
|||
raise TypeError("prev_version must be a string.")
|
||||
if not (init_args is None or isinstance(init_args, tuple)):
|
||||
raise TypeError("init_args must be a tuple.")
|
||||
if not (init_kwargs is None or isinstance(init_kwargs, dict)):
|
||||
raise TypeError("init_kwargs must be a dict.")
|
||||
if route_prefix is not None:
|
||||
if not isinstance(route_prefix, str):
|
||||
raise TypeError("route_prefix must be a string.")
|
||||
|
@ -642,6 +649,8 @@ class Deployment:
|
|||
|
||||
if init_args is None:
|
||||
init_args = ()
|
||||
if init_kwargs is None:
|
||||
init_kwargs = {}
|
||||
|
||||
# TODO(architkulkarni): Enforce that autoscaling_config and
|
||||
# user-provided num_replicas should be mutually exclusive.
|
||||
|
@ -657,6 +666,7 @@ class Deployment:
|
|||
self._prev_version = prev_version
|
||||
self._config = config
|
||||
self._init_args = init_args
|
||||
self._init_kwargs = init_kwargs
|
||||
self._route_prefix = route_prefix
|
||||
self._ray_actor_options = ray_actor_options
|
||||
|
||||
|
@ -714,7 +724,12 @@ class Deployment:
|
|||
|
||||
@property
|
||||
def init_args(self) -> Tuple[Any]:
|
||||
"""Arguments passed to the underlying class's constructor."""
|
||||
"""Positional args passed to the underlying class's constructor."""
|
||||
return self._init_args
|
||||
|
||||
@property
|
||||
def init_kwargs(self) -> Tuple[Any]:
|
||||
"""Keyword args passed to the underlying class's constructor."""
|
||||
return self._init_args
|
||||
|
||||
@property
|
||||
|
@ -728,20 +743,25 @@ class Deployment:
|
|||
"Use `deployment.deploy() instead.`")
|
||||
|
||||
@PublicAPI
|
||||
def deploy(self, *init_args, _blocking=True):
|
||||
def deploy(self, *init_args, _blocking=True, **init_kwargs):
|
||||
"""Deploy or update this deployment.
|
||||
|
||||
Args:
|
||||
init_args (optional): args to pass to the class __init__
|
||||
method. Not valid if this deployment wraps a function.
|
||||
init_kwargs (optional): kwargs to pass to the class __init__
|
||||
method. Not valid if this deployment wraps a function.
|
||||
"""
|
||||
if len(init_args) == 0 and self._init_args is not None:
|
||||
init_args = self._init_args
|
||||
if len(init_kwargs) == 0 and self._init_kwargs is not None:
|
||||
init_kwargs = self._init_kwargs
|
||||
|
||||
return _get_global_client().deploy(
|
||||
self._name,
|
||||
self._func_or_class,
|
||||
*init_args,
|
||||
init_args,
|
||||
init_kwargs,
|
||||
ray_actor_options=self._ray_actor_options,
|
||||
config=self._config,
|
||||
version=self._version,
|
||||
|
@ -780,6 +800,7 @@ class Deployment:
|
|||
version: Optional[str] = None,
|
||||
prev_version: Optional[str] = None,
|
||||
init_args: Optional[Tuple[Any]] = None,
|
||||
init_kwargs: Optional[Dict[Any, Any]] = None,
|
||||
route_prefix: Optional[str] = None,
|
||||
num_replicas: Optional[int] = None,
|
||||
ray_actor_options: Optional[Dict] = None,
|
||||
|
@ -813,6 +834,9 @@ class Deployment:
|
|||
if init_args is None:
|
||||
init_args = self._init_args
|
||||
|
||||
if init_kwargs is None:
|
||||
init_kwargs = self._init_kwargs
|
||||
|
||||
if route_prefix is None:
|
||||
if self._route_prefix == f"/{self._name}":
|
||||
route_prefix = None
|
||||
|
@ -832,6 +856,7 @@ class Deployment:
|
|||
version=version,
|
||||
prev_version=prev_version,
|
||||
init_args=init_args,
|
||||
init_kwargs=init_kwargs,
|
||||
route_prefix=route_prefix,
|
||||
ray_actor_options=ray_actor_options,
|
||||
_internal=True,
|
||||
|
@ -843,6 +868,7 @@ class Deployment:
|
|||
self._version == other._version,
|
||||
self._config == other._config,
|
||||
self._init_args == other._init_args,
|
||||
self._init_kwargs == other._init_kwargs,
|
||||
self._route_prefix == other._route_prefix,
|
||||
self._ray_actor_options == self._ray_actor_options,
|
||||
])
|
||||
|
@ -872,6 +898,7 @@ def deployment(
|
|||
prev_version: Optional[str] = None,
|
||||
num_replicas: Optional[int] = None,
|
||||
init_args: Optional[Tuple[Any]] = None,
|
||||
init_kwargs: Optional[Dict[Any, Any]] = None,
|
||||
ray_actor_options: Optional[Dict] = None,
|
||||
user_config: Optional[Any] = None,
|
||||
max_concurrent_queries: Optional[int] = None,
|
||||
|
@ -888,6 +915,7 @@ def deployment(
|
|||
prev_version: Optional[str] = None,
|
||||
num_replicas: Optional[int] = None,
|
||||
init_args: Optional[Tuple[Any]] = None,
|
||||
init_kwargs: Optional[Dict[Any, Any]] = None,
|
||||
route_prefix: Optional[str] = None,
|
||||
ray_actor_options: Optional[Dict] = None,
|
||||
user_config: Optional[Any] = None,
|
||||
|
@ -911,7 +939,10 @@ def deployment(
|
|||
not check the existing deployment's version.
|
||||
num_replicas (Optional[int]): The number of processes to start up that
|
||||
will handle requests to this deployment. Defaults to 1.
|
||||
init_args (Optional[Tuple]): Arguments to be passed to the class
|
||||
init_args (Optional[Tuple]): Positional args to be passed to the class
|
||||
constructor when starting up deployment replicas. These can also be
|
||||
passed when you call `.deploy()` on the returned Deployment.
|
||||
init_kwargs (Optional[Dict]): Keyword args to be passed to the class
|
||||
constructor when starting up deployment replicas. These can also be
|
||||
passed when you call `.deploy()` on the returned Deployment.
|
||||
route_prefix (Optional[str]): Requests to paths under this HTTP path
|
||||
|
@ -968,6 +999,7 @@ def deployment(
|
|||
version=version,
|
||||
prev_version=prev_version,
|
||||
init_args=init_args,
|
||||
init_kwargs=init_kwargs,
|
||||
route_prefix=route_prefix,
|
||||
ray_actor_options=ray_actor_options,
|
||||
_internal=True,
|
||||
|
@ -1009,6 +1041,7 @@ def get_deployment(name: str) -> Deployment:
|
|||
backend_info.backend_config,
|
||||
version=backend_info.version,
|
||||
init_args=backend_info.replica_config.init_args,
|
||||
init_kwargs=backend_info.replica_config.init_kwargs,
|
||||
route_prefix=route_prefix,
|
||||
ray_actor_options=backend_info.replica_config.ray_actor_options,
|
||||
_internal=True,
|
||||
|
@ -1032,6 +1065,7 @@ def list_deployments() -> Dict[str, Deployment]:
|
|||
backend_info.backend_config,
|
||||
version=backend_info.version,
|
||||
init_args=backend_info.replica_config.init_args,
|
||||
init_kwargs=backend_info.replica_config.init_kwargs,
|
||||
route_prefix=route_prefix,
|
||||
ray_actor_options=backend_info.replica_config.ray_actor_options,
|
||||
_internal=True,
|
||||
|
|
|
@ -164,6 +164,7 @@ class ActorReplicaWrapper:
|
|||
**backend_info.replica_config.ray_actor_options).remote(
|
||||
self.backend_tag, self.replica_tag,
|
||||
backend_info.replica_config.init_args,
|
||||
backend_info.replica_config.init_kwargs,
|
||||
backend_info.backend_config.to_proto_bytes(), version,
|
||||
self._controller_name, self._detached)
|
||||
|
||||
|
|
|
@ -43,7 +43,7 @@ def create_backend_replica(name: str, serialized_backend_def: bytes):
|
|||
# TODO(architkulkarni): Add type hints after upgrading cloudpickle
|
||||
class RayServeWrappedReplica(object):
|
||||
async def __init__(self, backend_tag, replica_tag, init_args,
|
||||
backend_config_proto_bytes: bytes,
|
||||
init_kwargs, backend_config_proto_bytes: bytes,
|
||||
version: BackendVersion, controller_name: str,
|
||||
detached: bool):
|
||||
backend = cloudpickle.loads(serialized_backend_def)
|
||||
|
@ -72,7 +72,8 @@ def create_backend_replica(name: str, serialized_backend_def: bytes):
|
|||
# This allows backends to define an async __init__ method
|
||||
# (required for FastAPI backend definition).
|
||||
_callable = backend.__new__(backend)
|
||||
await sync_to_async(_callable.__init__)(*init_args)
|
||||
await sync_to_async(_callable.__init__)(*init_args,
|
||||
**init_kwargs)
|
||||
# Setting the context again to update the servable_object.
|
||||
ray.serve.api._set_internal_replica_context(
|
||||
backend_tag,
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import inspect
|
||||
import pickle
|
||||
from enum import Enum
|
||||
from typing import Any, List, Optional
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||
|
||||
import pydantic
|
||||
from google.protobuf.json_format import MessageToDict
|
||||
|
@ -124,16 +124,23 @@ class BackendConfig(BaseModel):
|
|||
|
||||
|
||||
class ReplicaConfig:
|
||||
def __init__(self, backend_def, *init_args, ray_actor_options=None):
|
||||
def __init__(self,
|
||||
backend_def: Callable,
|
||||
init_args: Optional[Tuple[Any]] = None,
|
||||
init_kwargs: Optional[Dict[Any, Any]] = None,
|
||||
ray_actor_options=None):
|
||||
# Validate that backend_def is an import path, function, or class.
|
||||
if isinstance(backend_def, str):
|
||||
self.func_or_class_name = backend_def
|
||||
pass
|
||||
elif inspect.isfunction(backend_def):
|
||||
self.func_or_class_name = backend_def.__name__
|
||||
if len(init_args) != 0:
|
||||
if init_args:
|
||||
raise ValueError(
|
||||
"init_args not supported for function backend.")
|
||||
if init_kwargs:
|
||||
raise ValueError(
|
||||
"init_kwargs not supported for function backend.")
|
||||
elif inspect.isclass(backend_def):
|
||||
self.func_or_class_name = backend_def.__name__
|
||||
else:
|
||||
|
@ -142,7 +149,8 @@ class ReplicaConfig:
|
|||
format(type(backend_def)))
|
||||
|
||||
self.serialized_backend_def = cloudpickle.dumps(backend_def)
|
||||
self.init_args = init_args
|
||||
self.init_args = init_args if init_args is not None else ()
|
||||
self.init_kwargs = init_kwargs if init_kwargs is not None else {}
|
||||
if ray_actor_options is None:
|
||||
self.ray_actor_options = {}
|
||||
else:
|
||||
|
@ -161,12 +169,13 @@ class ReplicaConfig:
|
|||
raise TypeError("ray_actor_options must be a dictionary.")
|
||||
elif "lifetime" in self.ray_actor_options:
|
||||
raise ValueError(
|
||||
"Specifying lifetime in init_args is not allowed.")
|
||||
"Specifying lifetime in ray_actor_options is not allowed.")
|
||||
elif "name" in self.ray_actor_options:
|
||||
raise ValueError("Specifying name in init_args is not allowed.")
|
||||
raise ValueError(
|
||||
"Specifying name in ray_actor_options is not allowed.")
|
||||
elif "max_restarts" in self.ray_actor_options:
|
||||
raise ValueError("Specifying max_restarts in "
|
||||
"init_args is not allowed.")
|
||||
"ray_actor_options is not allowed.")
|
||||
else:
|
||||
# Ray defaults to zero CPUs for placement, we default to one here.
|
||||
if "num_cpus" not in self.ray_actor_options:
|
||||
|
|
|
@ -20,8 +20,8 @@ from ray.serve.common import (
|
|||
NodeId,
|
||||
ReplicaTag,
|
||||
)
|
||||
from ray.serve.config import (BackendConfig, HTTPOptions, ReplicaConfig)
|
||||
from ray.serve.constants import (CONTROL_LOOP_PERIOD_S, SERVE_ROOT_URL_ENV_KEY)
|
||||
from ray.serve.config import BackendConfig, HTTPOptions, ReplicaConfig
|
||||
from ray.serve.constants import CONTROL_LOOP_PERIOD_S, SERVE_ROOT_URL_ENV_KEY
|
||||
from ray.serve.endpoint_state import EndpointState
|
||||
from ray.serve.http_state import HTTPState
|
||||
from ray.serve.storage.checkpoint_path import make_kv_store
|
||||
|
|
|
@ -52,6 +52,8 @@ def test_replica_config_validation():
|
|||
# Check ray_actor_options validation.
|
||||
ReplicaConfig(
|
||||
Class,
|
||||
tuple(),
|
||||
dict(),
|
||||
ray_actor_options={
|
||||
"num_cpus": 1.0,
|
||||
"num_gpus": 10,
|
||||
|
|
|
@ -734,6 +734,58 @@ def test_init_args(serve_instance):
|
|||
check(10, 11, 12)
|
||||
|
||||
|
||||
def test_init_kwargs(serve_instance):
|
||||
with pytest.raises(TypeError):
|
||||
|
||||
@serve.deployment(init_kwargs=[1, 2, 3])
|
||||
class BadInitArgs:
|
||||
pass
|
||||
|
||||
@serve.deployment(init_kwargs={"a": 1, "b": 2})
|
||||
class D:
|
||||
def __init__(self, **kwargs):
|
||||
self._kwargs = kwargs
|
||||
|
||||
def get_kwargs(self, *args):
|
||||
return self._kwargs
|
||||
|
||||
D.deploy()
|
||||
handle = D.get_handle()
|
||||
|
||||
def check(kwargs):
|
||||
assert ray.get(handle.get_kwargs.remote()) == kwargs
|
||||
|
||||
# Basic sanity check.
|
||||
check({"a": 1, "b": 2})
|
||||
|
||||
# Check passing args to `.deploy()`.
|
||||
D.deploy(a=3, b=4)
|
||||
check({"a": 3, "b": 4})
|
||||
|
||||
# Passing args to `.deploy()` shouldn't override those passed in decorator.
|
||||
D.deploy()
|
||||
check({"a": 1, "b": 2})
|
||||
|
||||
# Check setting with `.options()`.
|
||||
new_D = D.options(init_kwargs={"c": 8, "d": 10})
|
||||
new_D.deploy()
|
||||
check({"c": 8, "d": 10})
|
||||
|
||||
# Should not have changed old deployment object.
|
||||
D.deploy()
|
||||
check({"a": 1, "b": 2})
|
||||
|
||||
# Check that args are only updated on version change.
|
||||
D.options(version="1").deploy()
|
||||
check({"a": 1, "b": 2})
|
||||
|
||||
D.options(version="1").deploy(c=10, d=11)
|
||||
check({"a": 1, "b": 2})
|
||||
|
||||
D.options(version="2").deploy(c=10, d=11)
|
||||
check({"c": 10, "d": 11})
|
||||
|
||||
|
||||
def test_input_validation():
|
||||
name = "test"
|
||||
|
||||
|
|
|
@ -116,6 +116,37 @@ def test_init_args(serve_instance):
|
|||
assert pid3 != pid2
|
||||
|
||||
|
||||
def test_init_kwargs(serve_instance):
|
||||
name = "test"
|
||||
|
||||
@serve.deployment(name=name)
|
||||
class D:
|
||||
def __init__(self, *, val=None):
|
||||
assert val is not None
|
||||
self._val = val
|
||||
|
||||
def __call__(self, *arg):
|
||||
return self._val, os.getpid()
|
||||
|
||||
D.deploy(val="1")
|
||||
val1, pid1 = ray.get(D.get_handle().remote())
|
||||
assert val1 == "1"
|
||||
|
||||
del D
|
||||
|
||||
D2 = serve.get_deployment(name=name)
|
||||
D2.deploy()
|
||||
val2, pid2 = ray.get(D2.get_handle().remote())
|
||||
assert val2 == "1"
|
||||
assert pid2 != pid1
|
||||
|
||||
D2 = serve.get_deployment(name=name)
|
||||
D2.deploy(val="2")
|
||||
val3, pid3 = ray.get(D2.get_handle().remote())
|
||||
assert val3 == "2"
|
||||
assert pid3 != pid2
|
||||
|
||||
|
||||
def test_scale_replicas(serve_instance):
|
||||
name = "test"
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue