# Mirror of https://github.com/vale981/ray (synced 2025-03-06 10:31:39 -05:00).
# 117 lines, 3.3 KiB, Python.
import logging
|
|
import os
|
|
from abc import ABC, abstractmethod
|
|
from os import path
|
|
|
|
from ray.streaming.config import ConfigHelper, Config
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class ContextBackend(ABC):
    """Abstract key/value storage interface for context backends.

    Concrete backends must implement ``get``, ``put`` and ``remove``; the
    class cannot be instantiated until all three are overridden.
    """

    @abstractmethod
    def get(self, key):
        """Return the value stored under ``key``."""

    @abstractmethod
    def put(self, key, value):
        """Store ``value`` under ``key``."""

    @abstractmethod
    def remove(self, key):
        """Delete whatever is stored under ``key``."""
|
|
|
|
|
|
class MemoryContextBackend(ContextBackend):
    """Context backend that keeps every entry in an in-process dict."""

    def __init__(self, conf):
        """Create an empty store; ``conf`` is accepted but not read here."""
        self.__dic = {}

    def get(self, key):
        """Return the value stored under ``key``, or ``None`` if absent."""
        return self.__dic.get(key)

    def put(self, key, value):
        """Store ``value`` under ``key``, replacing any previous value."""
        self.__dic[key] = value

    def remove(self, key):
        """Drop ``key`` from the store; a missing key is silently ignored."""
        self.__dic.pop(key, None)
|
|
|
|
|
|
class LocalFileContextBackend(ContextBackend):
    """Context backend that stores each key as one file in a root directory.

    The root directory is obtained from the config through
    ``ConfigHelper.get_cp_local_file_root_dir`` and is created when the
    backend is constructed. Files are opened in binary mode, so values are
    written and returned as bytes.
    """

    def __init__(self, conf):
        """Resolve the root directory from ``conf`` and make sure it exists."""
        self.__dir = ConfigHelper.get_cp_local_file_root_dir(conf)
        logger.info("Start init local file state backend, root_dir={}.".format(
            self.__dir))
        try:
            # makedirs instead of mkdir: a root dir whose parent directories
            # do not exist yet can still be created.
            os.makedirs(self.__dir)
        except FileExistsError:
            logger.info("dir already exists, skipped.")

    def put(self, key, value):
        """Write ``value`` (bytes-like) to the file for ``key``."""
        logger.info("Put value of key {} start.".format(key))
        with open(self.__gen_file_path(key), "wb") as f:
            f.write(value)

    def get(self, key):
        """Return the bytes stored for ``key``, or ``None`` if no such file."""
        logger.info("Get value of key {} start.".format(key))
        full_path = self.__gen_file_path(key)
        if not os.path.isfile(full_path):
            return None
        with open(full_path, "rb") as f:
            return f.read()

    def remove(self, key):
        """Delete the file for ``key``; never raises (best effort)."""
        logger.info("Remove value of key {} start.".format(key))
        try:
            os.remove(self.__gen_file_path(key))
        except FileNotFoundError:
            # Removing a key that was never stored is a no-op.
            pass
        except Exception as e:
            # Removal stays best-effort, but unexpected failures are logged
            # instead of being swallowed silently.
            logger.warning("Remove value of key {} failed: {}".format(key, e))

    def rename(self, src, dst):
        """Rename the file for key ``src`` to the file for key ``dst``."""
        logger.info("rename {} to {}".format(src, dst))
        os.rename(self.__gen_file_path(src), self.__gen_file_path(dst))

    def exists(self, key) -> bool:
        """Return whether a file for ``key`` exists under the root directory."""
        # Resolve the key against the root dir like every other method does;
        # checking the bare key would look in the process working directory.
        return os.path.exists(self.__gen_file_path(key))

    def __gen_file_path(self, key):
        """Map ``key`` to its path inside the root directory."""
        return path.join(self.__dir, key)
|
|
|
|
|
|
class AtomicFsContextBackend(LocalFileContextBackend):
    """Local-file backend that writes through a staging file.

    A value is first written under ``key + "_tmp"`` and then renamed onto
    ``key``. Reads and writes check for a staging file that has no final
    file next to it and use it in place of the final file.
    """

    def __init__(self, conf):
        """Initialise the parent backend and the staging-file suffix."""
        super().__init__(conf)
        self.__tmp_flag = "_tmp"

    def put(self, key, value):
        """Write ``value`` to the staging file, then move it onto ``key``."""
        parent = super()
        staged_key = key + self.__tmp_flag
        # A staging file without a final file is promoted before the new
        # write reuses the staging name.
        if parent.exists(staged_key):
            if not parent.exists(key):
                parent.rename(staged_key, key)
        parent.put(staged_key, value)
        # Drop the old final file before renaming the staging file onto it.
        parent.remove(key)
        parent.rename(staged_key, key)

    def get(self, key):
        """Read ``key``, or its staging file if only the staging file exists."""
        parent = super()
        staged_key = key + self.__tmp_flag
        only_staged = parent.exists(staged_key) and not parent.exists(key)
        return parent.get(staged_key if only_staged else key)

    def remove(self, key):
        """Remove the staging file for ``key`` (if present) and ``key`` itself."""
        parent = super()
        staged_key = key + self.__tmp_flag
        if parent.exists(staged_key):
            parent.remove(staged_key)
        parent.remove(key)
|
|
|
|
|
|
class ContextBackendFactory:
    """Builds the context backend selected by the worker configuration."""

    @staticmethod
    def get_context_backend(worker_config) -> ContextBackend:
        """Return a backend instance for the type configured in ``worker_config``.

        The type is read with ``ConfigHelper.get_cp_context_backend_type``.
        ``None`` is returned when it matches neither the local-file nor the
        memory backend constant.
        """
        kind = ConfigHelper.get_cp_context_backend_type(worker_config)
        if kind == Config.CP_STATE_BACKEND_LOCAL_FILE:
            return AtomicFsContextBackend(worker_config)
        if kind == Config.CP_STATE_BACKEND_MEMORY:
            return MemoryContextBackend(worker_config)
        return None
|