mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
244 lines
8.3 KiB
Python
244 lines
8.3 KiB
Python
"""
|
|
Optional utils module contains utility methods
|
|
that require optional dependencies.
|
|
"""
|
|
import asyncio
|
|
import collections
|
|
import functools
|
|
import inspect
|
|
import json
|
|
import logging
|
|
import os
|
|
import time
|
|
import traceback
|
|
from collections import namedtuple
|
|
from typing import Any
|
|
|
|
import ray.dashboard.consts as dashboard_consts
|
|
from ray.ray_constants import env_bool
|
|
|
|
try:
|
|
create_task = asyncio.create_task
|
|
except AttributeError:
|
|
create_task = asyncio.ensure_future
|
|
|
|
# All third-party dependencies that are not included in the minimal Ray
|
|
# installation must be included in this file. This allows us to determine if
|
|
# the agent has the necessary dependencies to be started.
|
|
from ray.dashboard.optional_deps import aiohttp, hdrs, PathLike, RouteDef
|
|
from ray.dashboard.utils import to_google_style, CustomEncoder
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class ClassMethodRouteTable:
|
|
"""A helper class to bind http route to class method."""
|
|
|
|
_bind_map = collections.defaultdict(dict)
|
|
_routes = aiohttp.web.RouteTableDef()
|
|
|
|
class _BindInfo:
|
|
def __init__(self, filename, lineno, instance):
|
|
self.filename = filename
|
|
self.lineno = lineno
|
|
self.instance = instance
|
|
|
|
@classmethod
|
|
def routes(cls):
|
|
return cls._routes
|
|
|
|
@classmethod
|
|
def bound_routes(cls):
|
|
bound_items = []
|
|
for r in cls._routes._items:
|
|
if isinstance(r, RouteDef):
|
|
route_method = getattr(r.handler, "__route_method__")
|
|
route_path = getattr(r.handler, "__route_path__")
|
|
instance = cls._bind_map[route_method][route_path].instance
|
|
if instance is not None:
|
|
bound_items.append(r)
|
|
else:
|
|
bound_items.append(r)
|
|
routes = aiohttp.web.RouteTableDef()
|
|
routes._items = bound_items
|
|
return routes
|
|
|
|
@classmethod
|
|
def _register_route(cls, method, path, **kwargs):
|
|
def _wrapper(handler):
|
|
if path in cls._bind_map[method]:
|
|
bind_info = cls._bind_map[method][path]
|
|
raise Exception(
|
|
f"Duplicated route path: {path}, "
|
|
f"previous one registered at "
|
|
f"{bind_info.filename}:{bind_info.lineno}"
|
|
)
|
|
|
|
bind_info = cls._BindInfo(
|
|
handler.__code__.co_filename, handler.__code__.co_firstlineno, None
|
|
)
|
|
|
|
@functools.wraps(handler)
|
|
async def _handler_route(*args) -> aiohttp.web.Response:
|
|
try:
|
|
# Make the route handler as a bound method.
|
|
# The args may be:
|
|
# * (Request, )
|
|
# * (self, Request)
|
|
req = args[-1]
|
|
return await handler(bind_info.instance, req)
|
|
except Exception:
|
|
logger.exception("Handle %s %s failed.", method, path)
|
|
return rest_response(success=False, message=traceback.format_exc())
|
|
|
|
cls._bind_map[method][path] = bind_info
|
|
_handler_route.__route_method__ = method
|
|
_handler_route.__route_path__ = path
|
|
return cls._routes.route(method, path, **kwargs)(_handler_route)
|
|
|
|
return _wrapper
|
|
|
|
@classmethod
|
|
def head(cls, path, **kwargs):
|
|
return cls._register_route(hdrs.METH_HEAD, path, **kwargs)
|
|
|
|
@classmethod
|
|
def get(cls, path, **kwargs):
|
|
return cls._register_route(hdrs.METH_GET, path, **kwargs)
|
|
|
|
@classmethod
|
|
def post(cls, path, **kwargs):
|
|
return cls._register_route(hdrs.METH_POST, path, **kwargs)
|
|
|
|
@classmethod
|
|
def put(cls, path, **kwargs):
|
|
return cls._register_route(hdrs.METH_PUT, path, **kwargs)
|
|
|
|
@classmethod
|
|
def patch(cls, path, **kwargs):
|
|
return cls._register_route(hdrs.METH_PATCH, path, **kwargs)
|
|
|
|
@classmethod
|
|
def delete(cls, path, **kwargs):
|
|
return cls._register_route(hdrs.METH_DELETE, path, **kwargs)
|
|
|
|
@classmethod
|
|
def view(cls, path, **kwargs):
|
|
return cls._register_route(hdrs.METH_ANY, path, **kwargs)
|
|
|
|
@classmethod
|
|
def static(cls, prefix: str, path: PathLike, **kwargs: Any) -> None:
|
|
cls._routes.static(prefix, path, **kwargs)
|
|
|
|
@classmethod
|
|
def bind(cls, instance):
|
|
def predicate(o):
|
|
if inspect.ismethod(o):
|
|
return hasattr(o, "__route_method__") and hasattr(o, "__route_path__")
|
|
return False
|
|
|
|
handler_routes = inspect.getmembers(instance, predicate)
|
|
for _, h in handler_routes:
|
|
cls._bind_map[h.__func__.__route_method__][
|
|
h.__func__.__route_path__
|
|
].instance = instance
|
|
|
|
|
|
def rest_response(
|
|
success, message, convert_google_style=True, **kwargs
|
|
) -> aiohttp.web.Response:
|
|
# In the dev context we allow a dev server running on a
|
|
# different port to consume the API, meaning we need to allow
|
|
# cross-origin access
|
|
if os.environ.get("RAY_DASHBOARD_DEV") == "1":
|
|
headers = {"Access-Control-Allow-Origin": "*"}
|
|
else:
|
|
headers = {}
|
|
return aiohttp.web.json_response(
|
|
{
|
|
"result": success,
|
|
"msg": message,
|
|
"data": to_google_style(kwargs) if convert_google_style else kwargs,
|
|
},
|
|
dumps=functools.partial(json.dumps, cls=CustomEncoder),
|
|
headers=headers,
|
|
)
|
|
|
|
|
|
# The cache value type used by aiohttp_cache.
|
|
_AiohttpCacheValue = namedtuple("AiohttpCacheValue", ["data", "expiration", "task"])
|
|
# The methods with no request body used by aiohttp_cache.
|
|
_AIOHTTP_CACHE_NOBODY_METHODS = {hdrs.METH_GET, hdrs.METH_DELETE}
|
|
|
|
|
|
def aiohttp_cache(
|
|
ttl_seconds=dashboard_consts.AIOHTTP_CACHE_TTL_SECONDS,
|
|
maxsize=dashboard_consts.AIOHTTP_CACHE_MAX_SIZE,
|
|
enable=not env_bool(dashboard_consts.AIOHTTP_CACHE_DISABLE_ENVIRONMENT_KEY, False),
|
|
):
|
|
assert maxsize > 0
|
|
cache = collections.OrderedDict()
|
|
|
|
def _wrapper(handler):
|
|
if enable:
|
|
|
|
@functools.wraps(handler)
|
|
async def _cache_handler(*args) -> aiohttp.web.Response:
|
|
# Make the route handler as a bound method.
|
|
# The args may be:
|
|
# * (Request, )
|
|
# * (self, Request)
|
|
req = args[-1]
|
|
# Make key.
|
|
if req.method in _AIOHTTP_CACHE_NOBODY_METHODS:
|
|
key = req.path_qs
|
|
else:
|
|
key = (req.path_qs, await req.read())
|
|
# Query cache.
|
|
value = cache.get(key)
|
|
if value is not None:
|
|
cache.move_to_end(key)
|
|
if not value.task.done() or value.expiration >= time.time():
|
|
# Update task not done or the data is not expired.
|
|
return aiohttp.web.Response(**value.data)
|
|
|
|
def _update_cache(task):
|
|
try:
|
|
response = task.result()
|
|
except Exception:
|
|
response = rest_response(
|
|
success=False, message=traceback.format_exc()
|
|
)
|
|
data = {
|
|
"status": response.status,
|
|
"headers": dict(response.headers),
|
|
"body": response.body,
|
|
}
|
|
cache[key] = _AiohttpCacheValue(
|
|
data, time.time() + ttl_seconds, task
|
|
)
|
|
cache.move_to_end(key)
|
|
if len(cache) > maxsize:
|
|
cache.popitem(last=False)
|
|
return response
|
|
|
|
task = create_task(handler(*args))
|
|
task.add_done_callback(_update_cache)
|
|
if value is None:
|
|
return await task
|
|
else:
|
|
return aiohttp.web.Response(**value.data)
|
|
|
|
suffix = f"[cache ttl={ttl_seconds}, max_size={maxsize}]"
|
|
_cache_handler.__name__ += suffix
|
|
_cache_handler.__qualname__ += suffix
|
|
return _cache_handler
|
|
else:
|
|
return handler
|
|
|
|
if inspect.iscoroutinefunction(ttl_seconds):
|
|
target_func = ttl_seconds
|
|
ttl_seconds = dashboard_consts.AIOHTTP_CACHE_TTL_SECONDS
|
|
return _wrapper(target_func)
|
|
else:
|
|
return _wrapper
|