import logging
import os
from abc import ABC, abstractmethod
from os import path

from ray.streaming.config import ConfigHelper, Config

logger = logging.getLogger(__name__)


class ContextBackend(ABC):
    @abstractmethod
    def get(self, key):
        pass

    @abstractmethod
    def put(self, key, value):
        pass

    @abstractmethod
    def remove(self, key):
        pass


class MemoryContextBackend(ContextBackend):
    def __init__(self, conf):
        self.__dic = dict()

    def get(self, key):
        return self.__dic.get(key)

    def put(self, key, value):
        self.__dic[key] = value

    def remove(self, key):
        if key in self.__dic:
            del self.__dic[key]


class LocalFileContextBackend(ContextBackend):
    def __init__(self, conf):
        self.__dir = ConfigHelper.get_cp_local_file_root_dir(conf)
        logger.info("Start init local file state backend, root_dir={}.".format(
            self.__dir))
        try:
            os.mkdir(self.__dir)
        except FileExistsError:
            logger.info("dir already exists, skipped.")

    def put(self, key, value):
        logger.info("Put value of key {} start.".format(key))
        with open(self.__gen_file_path(key), "wb") as f:
            f.write(value)

    def get(self, key):
        logger.info("Get value of key {} start.".format(key))
        full_path = self.__gen_file_path(key)
        if not os.path.isfile(full_path):
            return None
        with open(full_path, "rb") as f:
            return f.read()

    def remove(self, key):
        logger.info("Remove value of key {} start.".format(key))
        try:
            os.remove(self.__gen_file_path(key))
        except Exception:
            # ignore exception
            pass

    def rename(self, src, dst):
        logger.info("rename {} to {}".format(src, dst))
        os.rename(self.__gen_file_path(src), self.__gen_file_path(dst))

    def exists(self, key) -> bool:
        return os.path.exists(key)

    def __gen_file_path(self, key):
        return path.join(self.__dir, key)


class AtomicFsContextBackend(LocalFileContextBackend):
    def __init__(self, conf):
        super().__init__(conf)
        self.__tmp_flag = "_tmp"

    def put(self, key, value):
        tmp_key = key + self.__tmp_flag
        if super().exists(tmp_key) and not super().exists(key):
            super().rename(tmp_key, key)
        super().put(tmp_key, value)
        super().remove(key)
        super().rename(tmp_key, key)

    def get(self, key):
        tmp_key = key + self.__tmp_flag
        if super().exists(tmp_key) and not super().exists(key):
            return super().get(tmp_key)
        return super().get(key)

    def remove(self, key):
        tmp_key = key + self.__tmp_flag
        if super().exists(tmp_key):
            super().remove(tmp_key)
        super().remove(key)


class ContextBackendFactory:
    @staticmethod
    def get_context_backend(worker_config) -> ContextBackend:
        backend_type = ConfigHelper.get_cp_context_backend_type(worker_config)
        context_backend = None
        if backend_type == Config.CP_STATE_BACKEND_LOCAL_FILE:
            context_backend = AtomicFsContextBackend(worker_config)
        elif backend_type == Config.CP_STATE_BACKEND_MEMORY:
            context_backend = MemoryContextBackend(worker_config)
        return context_backend