diff --git a/python/ray/serve/api.py b/python/ray/serve/api.py index aaff4737e..fc63fac16 100644 --- a/python/ray/serve/api.py +++ b/python/ray/serve/api.py @@ -7,7 +7,8 @@ from multiprocessing import cpu_count import ray from ray.serve.constants import (DEFAULT_HTTP_HOST, DEFAULT_HTTP_PORT, SERVE_MASTER_NAME) -from ray.serve.global_state import GlobalState, ServeMaster +from ray.serve.master import ServeMaster +from ray.serve.handle import RayServeHandle from ray.serve.kv_store_service import SQLiteKVStore from ray.serve.task_runner import RayServeMixin, TaskRunnerActor from ray.serve.utils import block_until_http_ready @@ -17,20 +18,20 @@ from ray.serve.policy import RoutePolicy from ray.serve.queues import Query from ray.serve.request_params import RequestMetadata -global_state = None +master_actor = None -def _get_global_state(): - """Used for internal purpose. Because just import serve.global_state - will always reference the original None object +def _get_master_actor(): + """Used for internal purpose because using just import serve.global_state + will always reference the original None object. """ - return global_state + return master_actor def _ensure_connected(f): @wraps(f) def check(*args, **kwargs): - if _get_global_state() is None: + if _get_master_actor() is None: raise RayServeException("Please run serve.init to initialize or " "connect to existing ray serve cluster.") return f(*args, **kwargs) @@ -103,9 +104,8 @@ def init( the backend for a service. (Default: RoutePolicy.Random) policy_kwargs: Arguments required to instantiate a queueing policy """ - global global_state - # Noop if global_state is no longer None - if global_state is not None: + global master_actor + if master_actor is not None: return # Initialize ray if needed. @@ -118,8 +118,7 @@ def init( # Try to get serve master actor if it exists try: - ray.util.get_actor(SERVE_MASTER_NAME) - global_state = GlobalState() + master_actor = ray.util.get_actor(SERVE_MASTER_NAME) return except ValueError: pass @@ -135,20 +134,20 @@ def init( _, kv_store_path = mkstemp() # Serve has not been initialized, perform init sequence - # TODO move the db to session_dir + # TODO move the db to session_dir. # ray.worker._global_node.address_info["session_dir"] def kv_store_connector(namespace): return SQLiteKVStore(namespace, db_path=kv_store_path) - master = ServeMaster.options( + master_actor = ServeMaster.options( detached=True, name=SERVE_MASTER_NAME).remote(kv_store_connector) - ray.get(master.start_router.remote(queueing_policy.value, policy_kwargs)) + ray.get( + master_actor.start_router.remote(queueing_policy.value, policy_kwargs)) - global_state = GlobalState(master) - ray.get(master.start_metric_monitor.remote(gc_window_seconds)) + ray.get(master_actor.start_metric_monitor.remote(gc_window_seconds)) if start_server: - ray.get(master.start_http_proxy.remote(http_host, http_port)) + ray.get(master_actor.start_http_proxy.remote(http_host, http_port)) if start_server and blocking: block_until_http_ready("http://{}:{}/-/routes".format( @@ -168,8 +167,8 @@ def create_endpoint(endpoint_name, route=None, methods=["GET"]): registered before returning """ ray.get( - global_state.master_actor.create_endpoint.remote( - route, endpoint_name, [m.upper() for m in methods])) + master_actor.create_endpoint.remote(route, endpoint_name, + [m.upper() for m in methods])) @_ensure_connected @@ -181,8 +180,7 @@ def set_backend_config(backend_tag, backend_config): backend_config(BackendConfig) : Desired backend configuration. """ ray.get( - global_state.master_actor.set_backend_config.remote( - backend_tag, backend_config)) + master_actor.set_backend_config.remote(backend_tag, backend_config)) @_ensure_connected @@ -192,8 +190,7 @@ def get_backend_config(backend_tag): Args: backend_tag(str): A registered backend. """ - return ray.get( - global_state.master_actor.get_backend_config.remote(backend_tag)) + return ray.get(master_actor.get_backend_config.remote(backend_tag)) def _backend_accept_batch(func_or_class): @@ -258,8 +255,8 @@ def create_backend(func_or_class, type(func_or_class))) ray.get( - global_state.master_actor.create_backend.remote( - backend_tag, creator, backend_config, arg_list)) + master_actor.create_backend.remote(backend_tag, creator, + backend_config, arg_list)) @_ensure_connected @@ -295,8 +292,8 @@ def split(endpoint_name, traffic_policy_dictionary): to their traffic weights. The weights must sum to 1. """ ray.get( - global_state.master_actor.split_traffic.remote( - endpoint_name, traffic_policy_dictionary)) + master_actor.split_traffic.remote(endpoint_name, + traffic_policy_dictionary)) @_ensure_connected @@ -319,13 +316,11 @@ def get_handle(endpoint_name, RayServeHandle """ if not missing_ok: - assert endpoint_name in global_state.get_all_endpoints() - - # Delay import due to it's dependency on global_state - from ray.serve.handle import RayServeHandle + assert endpoint_name in ray.get( + master_actor.get_all_endpoints.remote()) return RayServeHandle( - global_state.get_router(), + ray.get(master_actor.get_router.remote())[0], endpoint_name, relative_slo_ms, absolute_slo_ms, @@ -344,7 +339,7 @@ def stat(percentiles=[50, 90, 95], The longest aggregation window must be shorter or equal to the gc_window_seconds. """ - monitor = global_state.get_metric_monitor() + [monitor] = ray.get(master_actor.get_metric_monitor.remote()) return ray.get(monitor.collect.remote(percentiles, agg_windows_seconds)) diff --git a/python/ray/serve/handle.py b/python/ray/serve/handle.py index ccb499197..366961eed 100644 --- a/python/ray/serve/handle.py +++ b/python/ray/serve/handle.py @@ -1,3 +1,4 @@ +import ray from ray import serve from ray.serve.context import TaskContext from ray.serve.exceptions import RayServeException @@ -105,9 +106,13 @@ class RayServeHandle: def get_http_endpoint(self): return DEFAULT_HTTP_ADDRESS + def get_traffic_policy(self): + master_actor = serve.api._get_master_actor() + return ray.get( + master_actor.get_traffic_policy.remote(self.endpoint_name)) + def _ensure_backend_unique(self, backend_tag=None): - global_state = serve.api._get_global_state() - traffic_policy = global_state.get_traffic_policy(self.endpoint_name) + traffic_policy = self.get_traffic_policy() if backend_tag is None: assert len(traffic_policy) == 1, ( "Multiple backends detected. " diff --git a/python/ray/serve/global_state.py b/python/ray/serve/master.py similarity index 91% rename from python/ray/serve/global_state.py rename to python/ray/serve/master.py index aefa6e451..6048b6f95 100644 --- a/python/ray/serve/global_state.py +++ b/python/ray/serve/master.py @@ -1,6 +1,6 @@ import ray from ray.serve.backend_config import BackendConfig -from ray.serve.constants import (SERVE_MASTER_NAME, ASYNC_CONCURRENCY) +from ray.serve.constants import ASYNC_CONCURRENCY from ray.serve.exceptions import batch_annotation_not_found from ray.serve.http_proxy import HTTPProxyActor from ray.serve.kv_store_service import (BackendTable, RoutingTable, @@ -232,31 +232,3 @@ class ServeMaster: ), "Backend {} is not registered.".format(backend_tag) backend_config_dict = self.backend_table.get_info(backend_tag) return BackendConfig(**backend_config_dict) - - -class GlobalState: - """Encapsulate all global state in the serving system. - - The information is fetch lazily from - 1. A collection of namespaced key value stores - 2. A actor supervisor service - """ - - def __init__(self, master_actor=None): - # Get actor nursery handle. - if master_actor is None: - master_actor = ray.util.get_actor(SERVE_MASTER_NAME) - self.master_actor = master_actor - - def get_router(self): - return ray.get(self.master_actor.get_router.remote())[0] - - def get_metric_monitor(self): - return ray.get(self.master_actor.get_metric_monitor.remote())[0] - - def get_traffic_policy(self, endpoint_name): - return ray.get( - self.master_actor.get_traffic_policy.remote(endpoint_name)) - - def get_all_endpoints(self): - return ray.get(self.master_actor.get_all_endpoints.remote()) diff --git a/python/ray/serve/tests/test_api.py b/python/ray/serve/tests/test_api.py index a93219396..5cf201524 100644 --- a/python/ray/serve/tests/test_api.py +++ b/python/ray/serve/tests/test_api.py @@ -10,7 +10,7 @@ from ray.serve.handle import RayServeHandle def test_e2e(serve_instance): - serve.init() # so we have access to global state + serve.init() serve.create_endpoint("endpoint", "/api", methods=["GET", "POST"]) retry_count = 5 @@ -181,9 +181,9 @@ def test_killing_replicas(serve_instance): serve.create_endpoint("simple", "/simple") b_config = BackendConfig(num_replicas=3, num_cpus=2) serve.create_backend(Simple, "simple:v1", backend_config=b_config) - global_state = serve.api._get_global_state() + master_actor = serve.api._get_master_actor() old_replica_tag_list = ray.get( - global_state.master_actor._list_replicas.remote("simple:v1")) + master_actor._list_replicas.remote("simple:v1")) bnew_config = serve.get_backend_config("simple:v1") # change the config @@ -191,9 +191,9 @@ def test_killing_replicas(serve_instance): # set the config serve.set_backend_config("simple:v1", bnew_config) new_replica_tag_list = ray.get( - global_state.master_actor._list_replicas.remote("simple:v1")) + master_actor._list_replicas.remote("simple:v1")) new_all_tag_list = list( - ray.get(global_state.master_actor.get_all_handles.remote()).keys()) + ray.get(master_actor.get_all_handles.remote()).keys()) # the new_replica_tag_list must be subset of all_tag_list assert set(new_replica_tag_list) <= set(new_all_tag_list) @@ -215,9 +215,9 @@ def test_not_killing_replicas(serve_instance): serve.create_endpoint("bsimple", "/bsimple") b_config = BackendConfig(num_replicas=3, max_batch_size=2) serve.create_backend(BatchSimple, "bsimple:v1", backend_config=b_config) - global_state = serve.api._get_global_state() + master_actor = serve.api._get_master_actor() old_replica_tag_list = ray.get( - global_state.master_actor._list_replicas.remote("bsimple:v1")) + master_actor._list_replicas.remote("bsimple:v1")) bnew_config = serve.get_backend_config("bsimple:v1") # change the config @@ -225,9 +225,9 @@ def test_not_killing_replicas(serve_instance): # set the config serve.set_backend_config("bsimple:v1", bnew_config) new_replica_tag_list = ray.get( - global_state.master_actor._list_replicas.remote("bsimple:v1")) + master_actor._list_replicas.remote("bsimple:v1")) new_all_tag_list = list( - ray.get(global_state.master_actor.get_all_handles.remote()).keys()) + ray.get(master_actor.get_all_handles.remote()).keys()) # the old and new replica tag list should be identical # and should be subset of all_tag_list