diff --git a/python/ray/serve/api.py b/python/ray/serve/api.py index 7601a7cd8..58181eff6 100644 --- a/python/ray/serve/api.py +++ b/python/ray/serve/api.py @@ -1309,6 +1309,9 @@ class Deployment: if version is None: version = self._version + if prev_version is None: + prev_version = self._prev_version + if init_args is None: init_args = self._init_args @@ -1350,6 +1353,61 @@ class Deployment: _internal=True, ) + @PublicAPI(stability="alpha") + def set_options( + self, + func_or_class: Optional[Callable] = None, + name: Optional[str] = None, + version: Optional[str] = None, + prev_version: Optional[str] = None, + init_args: Optional[Tuple[Any]] = None, + init_kwargs: Optional[Dict[Any, Any]] = None, + route_prefix: Union[str, None, DEFAULT] = DEFAULT.VALUE, + num_replicas: Optional[int] = None, + ray_actor_options: Optional[Dict] = None, + user_config: Optional[Any] = None, + max_concurrent_queries: Optional[int] = None, + _autoscaling_config: Optional[Union[Dict, AutoscalingConfig]] = None, + _graceful_shutdown_wait_loop_s: Optional[float] = None, + _graceful_shutdown_timeout_s: Optional[float] = None, + _health_check_period_s: Optional[float] = None, + _health_check_timeout_s: Optional[float] = None, + ) -> None: + """Overwrite this deployment's options. Mutates the deployment. + + Only those options passed in will be updated, all others will remain + unchanged. + """ + + validated = self.options( + func_or_class=func_or_class, + name=name, + version=version, + prev_version=prev_version, + init_args=init_args, + init_kwargs=init_kwargs, + route_prefix=route_prefix, + num_replicas=num_replicas, + ray_actor_options=ray_actor_options, + user_config=user_config, + max_concurrent_queries=max_concurrent_queries, + _autoscaling_config=_autoscaling_config, + _graceful_shutdown_wait_loop_s=_graceful_shutdown_wait_loop_s, + _graceful_shutdown_timeout_s=_graceful_shutdown_timeout_s, + _health_check_period_s=_health_check_period_s, + _health_check_timeout_s=_health_check_timeout_s, + ) + + self._func_or_class = validated._func_or_class + self._name = validated._name + self._version = validated._version + self._prev_version = validated._prev_version + self._init_args = validated._init_args + self._init_kwargs = validated._init_kwargs + self._route_prefix = validated._route_prefix + self._ray_actor_options = validated._ray_actor_options + self._config = validated._config + def __eq__(self, other): return all( [ diff --git a/python/ray/serve/tests/test_api.py b/python/ray/serve/tests/test_api.py index e8d5736dc..c9c1f4b3a 100644 --- a/python/ray/serve/tests/test_api.py +++ b/python/ray/serve/tests/test_api.py @@ -336,6 +336,44 @@ def test_shutdown_destructor(serve_instance): B.delete() +class TestSetOptions: + def test_set_options_basic(self): + @serve.deployment( + num_replicas=4, + max_concurrent_queries=3, + prev_version="abcd", + ray_actor_options={"num_cpus": 2}, + _health_check_timeout_s=17, + ) + def f(): + pass + + f.set_options( + num_replicas=9, + prev_version="abcd", + version="efgh", + ray_actor_options={"num_gpus": 3}, + ) + + assert f.num_replicas == 9 + assert f.max_concurrent_queries == 3 + assert f.prev_version == "abcd" + assert f.version == "efgh" + assert f.ray_actor_options == {"num_gpus": 3} + assert f._config.health_check_timeout_s == 17 + + def test_set_options_validation(self): + @serve.deployment + def f(): + pass + + with pytest.raises(TypeError): + f.set_options(init_args=-4) + + with pytest.raises(ValueError): + f.set_options(max_concurrent_queries=-4) + + if __name__ == "__main__": import sys