diff --git a/python/ray/serve/backend_state.py b/python/ray/serve/backend_state.py index 9d3fb6e4f..a72d24979 100644 --- a/python/ray/serve/backend_state.py +++ b/python/ray/serve/backend_state.py @@ -99,7 +99,7 @@ class ActorReplicaWrapper: except ValueError: logger.debug("Starting replica '{}' for backend '{}'.".format( self._replica_tag, self._backend_tag)) - self._actor_handle = ray.remote(backend_info.worker_class).options( + self._actor_handle = backend_info.actor_def.options( name=self._actor_name, lifetime="detached" if self._detached else None, placement_group=self._placement_group, diff --git a/python/ray/serve/common.py b/python/ray/serve/common.py index dafdf7a05..b73466b93 100644 --- a/python/ray/serve/common.py +++ b/python/ray/serve/common.py @@ -1,10 +1,11 @@ from dataclasses import dataclass, field from pydantic import BaseModel -from typing import Any, Dict, List, Optional +from typing import Dict, List, Optional from uuid import UUID import numpy as np +from ray.actor import ActorClass from ray.serve.config import BackendConfig, ReplicaConfig BackendTag = str @@ -24,9 +25,7 @@ class EndpointInfo: class BackendInfo(BaseModel): - # TODO(architkulkarni): Add type hint for worker_class after upgrading - # cloudpickle and adding types to RayServeWrappedReplica - worker_class: Any + actor_def: Optional[ActorClass] version: Optional[str] backend_config: BackendConfig replica_config: ReplicaConfig diff --git a/python/ray/serve/controller.py b/python/ray/serve/controller.py index e1b52a720..4fed74e01 100644 --- a/python/ray/serve/controller.py +++ b/python/ray/serve/controller.py @@ -212,8 +212,9 @@ class ServeController: """Register a new backend under the specified tag.""" async with self.write_lock: backend_info = BackendInfo( - worker_class=create_backend_replica( - backend_tag, replica_config.serialized_backend_def), + actor_def=ray.remote( + create_backend_replica( + backend_tag, replica_config.serialized_backend_def)), version=RESERVED_VERSION_TAG, backend_config=backend_config, replica_config=replica_config) @@ -244,7 +245,7 @@ class ServeController: raise ValueError(f"Backend {backend_tag} is not registered.") backend_info = BackendInfo( - worker_class=existing_info.worker_class, + actor_def=existing_info.actor_def, version=existing_info.version, backend_config=existing_info.backend_config.copy( update=config_options.dict(exclude_unset=True)), @@ -279,8 +280,9 @@ class ServeController: async with self.write_lock: backend_info = BackendInfo( - worker_class=create_backend_replica( - name, replica_config.serialized_backend_def), + actor_def=ray.remote( + create_backend_replica( + name, replica_config.serialized_backend_def)), version=version, backend_config=backend_config, replica_config=replica_config) diff --git a/python/ray/serve/tests/test_backend_state.py b/python/ray/serve/tests/test_backend_state.py index 8447aaaae..fe35793e0 100644 --- a/python/ray/serve/tests/test_backend_state.py +++ b/python/ray/serve/tests/test_backend_state.py @@ -95,7 +95,7 @@ def backend_info(version: Optional[str] = None, num_replicas: Optional[int] = 1, **config_opts) -> BackendInfo: return BackendInfo( - worker_class=None, + actor_def=None, version=version, backend_config=BackendConfig(num_replicas=num_replicas, **config_opts), replica_config=ReplicaConfig(lambda x: x))