import hashlib
import logging
import os
import random
import shutil
import signal
import subprocess
import warnings
from collections import namedtuple

import binfootprint as bf
import h5py
import numpy as np
import sqlitedict as sqd

import combLib
import signalDelay

log = logging.getLogger(__name__)

# Composite key of one HI run.  Eta and EtaTherm are the noise objects;
# EtaTherm may be falsy (HIMetaData.get_HIData then uses size_temp_y = 1).
HIMetaKey_type = namedtuple("HIMetaKey_type", "HiP IntP SysP Eta EtaTherm")

# Accepted values of HiP.result_type; they select which parts of the
# integrated state are written to the HIData file.
RESULT_TYPE_ZEROTH_ORDER_ONLY = "ZEROTH_ORDER_ONLY"
RESULT_TYPE_ZEROTH_ORDER_AND_ETA_LAMBDA = "ZEROTH_ORDER_AND_ETA_LAMBDA"
RESULT_TYPE_ALL = "ALL"

# Alphabet for get_rand_file_name: digits, lowercase and uppercase letters.
CHAR_SET = (
    "0123456789"
    "abcdefghijklmnopqrstuvwxyz"
    "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
)


def projector_psi_t(psi_t, normed=False):
    """
    Return the time-resolved projectors |psi(t)><psi(t)|.

    assume shape (len(t), dim_H_sys)

    Parameters
    ----------
    psi_t : ndarray, shape (len(t), dim_H_sys)
        State vector at each time step.
    normed : bool
        If True, each projector is divided by <psi(t)|psi(t)>.

    Returns
    -------
    ndarray, shape (len(t), dim_H_sys, dim_H_sys)
        Entry [t, i, j] is psi_t[t, i] * conj(psi_t[t, j]) (optionally
        divided by the squared norm at time t).
    """
    psi_t = np.asarray(psi_t)
    ket = psi_t[:, :, np.newaxis]
    bra = np.conj(psi_t)[:, np.newaxis, :]
    outer = ket * bra
    if not normed:
        return outer

    norm_sq = np.sum(np.conj(psi_t) * psi_t, axis=1)[:, np.newaxis, np.newaxis]
    return outer / norm_sq


class HiP(object):
    """
    a purely readable (non binary) data object

    Hierarchy parameters of an HI run.  ``__bfkey__`` lists the values that
    enter the binary footprint (and therefore the hash) of the key.
    ``accum_only`` and ``rand_skip`` are appended only when set -- presumably
    so that keys created before these fields existed keep their hash
    (NOTE(review): confirm).  ``rand_skip`` may only be set when
    ``accum_only`` is set; otherwise ValueError is raised.
    """

    __slots__ = [
        "k_max",
        "g_scale",
        "sample_method",
        "seed",
        "nonlinear",
        "normalized",
        "terminator",
        "result_type",
        "accum_only",
        "rand_skip",
    ]

    def __init__(
        self,
        k_max,
        g_scale=None,
        sample_method="random",
        seed=0,
        nonlinear=True,
        normalized=False,
        terminator=False,
        result_type=RESULT_TYPE_ZEROTH_ORDER_ONLY,
        accum_only=None,
        rand_skip=None,
    ):
        self.k_max = k_max
        self.g_scale = g_scale
        self.sample_method = sample_method
        self.seed = seed
        self.nonlinear = nonlinear
        self.normalized = normalized
        self.terminator = terminator
        self.result_type = result_type
        self.accum_only = accum_only
        self.rand_skip = rand_skip
        self._check_accum_rand_skip()

    def _check_accum_rand_skip(self):
        """Raise ValueError if rand_skip is set while accum_only is not."""
        if self.accum_only is None and self.rand_skip is not None:
            raise ValueError(
                "if accum_only is 'None' (not set) rand_skip must also be 'None'"
            )

    def __bfkey__(self):
        key = [
            self.k_max,
            self.g_scale,
            self.sample_method,
            self.seed,
            self.nonlinear,
            self.normalized,
            self.terminator,
            self.result_type,
        ]
        if self.accum_only is None:
            self._check_accum_rand_skip()
            return key

        key.append(self.accum_only)
        if self.rand_skip is not None:
            key.append(self.rand_skip)
        return key

    def __repr__(self):
        # every label is left-justified to the width of "sample_method"
        return "".join(
            "{:<13}: {}\n".format(name, getattr(self, name)) for name in HiP.__slots__
        )


class IntP(object):
    """
    a purely readable (non binary) data object

    Integration parameters of an HI run.  ``t_steps_skip`` enters the key
    (``__bfkey__``) only when it differs from its default of 1.
    """

    __slots__ = [
        "t_max",
        "t_steps",
        "integrator_name",
        "atol",
        "rtol",
        "order",
        "nsteps",
        "method",
        "t_steps_skip",
    ]

    # fields that always enter the key, in key order
    _KEY_FIELDS = (
        "t_max",
        "t_steps",
        "integrator_name",
        "atol",
        "rtol",
        "order",
        "nsteps",
        "method",
    )

    def __init__(
        self,
        t_max,
        t_steps,
        integrator_name="zvode",
        atol=1e-8,
        rtol=1e-8,
        order=5,
        nsteps=5000,
        method="bdf",
        t_steps_skip=1,
    ):
        self.t_max = t_max
        self.t_steps = t_steps
        self.integrator_name = integrator_name
        self.atol = atol
        self.rtol = rtol
        self.order = order
        self.nsteps = nsteps
        self.method = method
        self.t_steps_skip = t_steps_skip

    def __bfkey__(self):
        key = [getattr(self, name) for name in IntP._KEY_FIELDS]
        if self.t_steps_skip == 1:
            return key
        key.append(self.t_steps_skip)
        return key

    def __repr__(self):
        # every label is left-justified to the width of "integrator_name"
        return "".join(
            "{:<15}: {}\n".format(name, getattr(self, name))
            for name in IntP.__slots__
        )


class SysP(object):
    """
    System parameters of an HI run.

    Either ``gw_hash`` or both ``g`` and ``w`` must be given (ValueError
    otherwise).  ``len_gw``, ``gw_info``, ``T`` and ``T_method`` are
    informational only and do not enter the key; per the original author,
    T is included either in g/w or in EtaTherm.
    """

    __slots__ = [
        "H_sys",
        "L",
        "psi0",
        "g",
        "w",
        "H_dynamic",
        "bcf_scale",
        "gw_hash",
        "len_gw",
        "gw_info",
        "T",
        "T_method",
    ]

    # fields that enter the key (and the repr), in order
    _KEY_FIELDS = (
        "H_sys",
        "L",
        "psi0",
        "g",
        "w",
        "H_dynamic",
        "bcf_scale",
        "gw_hash",
    )

    def __init__(
        self,
        H_sys,
        L,
        psi0,
        g,
        w,
        H_dynamic,
        bcf_scale,
        gw_hash,
        len_gw,
        gw_info=None,
        T=None,
        T_method=None,
    ):
        self.H_sys = H_sys
        self.L = L
        self.psi0 = psi0
        self.g = g
        self.w = w
        self.H_dynamic = H_dynamic
        self.bcf_scale = bcf_scale
        self.gw_hash = gw_hash
        self.len_gw = len_gw
        self.gw_info = gw_info
        self.T = T
        self.T_method = T_method

        missing_gw = self.g is None or self.w is None
        if self.gw_hash is None and missing_gw:
            raise ValueError("specify either g/w or gw_hash")

    def __bfkey__(self):
        return [getattr(self, name) for name in SysP._KEY_FIELDS]

    def __repr__(self):
        # every label is left-justified to the width of "H_dynamic"
        return "".join(
            "{:<9}: {}\n".format(name, getattr(self, name))
            for name in SysP._KEY_FIELDS
        )


# ASCII codes of '0'-'9', 'A'-'Z' and 'a'-'z'
RAND_STR_ASCII_IDX_LIST = [*range(48, 58), *range(65, 91), *range(97, 123)]


def rand_str(l=8):
    """Return ``l`` random alphanumeric ASCII characters drawn via ``random.choice``."""
    return "".join(chr(random.choice(RAND_STR_ASCII_IDX_LIST)) for _ in range(l))


# initial number of rows of the per-sample datasets / of the rho_t_accum_part
# snapshots in a freshly created HIData file
HIData_default_size_stoc_traj = 10
HIData_default_size_rho_t_accum_part = 10


def is_int_power(x, b=2):
    """
    Return the integer ``n`` with ``b**n == x``, or None if there is none.

    The exponent is estimated in floating point and then verified exactly.
    The estimate must be *rounded*, not truncated: e.g.
    log(1000)/log(10) evaluates to 2.9999999999999996, which truncation
    would turn into 2 and so wrongly reject 1000 == 10**3.

    Parameters
    ----------
    x : number
        Value to test.  Non-positive values are never powers of a positive
        base; None is returned for them instead of letting log() produce
        -inf / nan, which made int() raise.
    b : number
        Base, default 2.

    Returns
    -------
    int or None
    """
    if x <= 0:
        return None
    n = int(round(np.log(x) / np.log(b)))
    if b ** n == x:
        return n
    return None


def file_does_not_exists_or_is_empty(fname):
    """Return True if ``fname`` does not exist or has size zero, else False."""
    return (not os.path.exists(fname)) or os.path.getsize(fname) == 0


class HIData(object):
    """
    HDF5-backed storage for the results of a single HI run.

    The file (created by ``init_file`` when missing or empty) contains:

    * ``stoc_traj`` -- stored trajectories, shape (n, size_t, size_sys)
    * ``y`` / ``temp_y`` -- per-sample vectors of length size_y / size_temp_y
    * ``aux_states`` -- only if size_aux_state != 0, (n, size_t, size_aux_state)
    * ``stoc_proc`` -- only if num_bcf_terms != 0, (n, size_t, num_bcf_terms)
    * ``tracker`` -- one bool per sample index, True once the sample is stored
    * ``rho_t_accum`` -- running sum of the projectors |psi(t)><psi(t)|
    * ``rho_t_accum_part`` / ``rho_t_accum_part_tracker`` -- snapshots of the
      averaged rho(t), taken whenever the sample count becomes 2**n
    * ``samples`` / ``largest_idx`` -- number of accumulated samples and the
      largest sample index seen
    * ``time`` / ``time_set`` -- the time grid and whether it has been set

    In ``accum_only`` mode only sample index 0 is stored in full; all other
    samples only contribute to ``rho_t_accum``.
    """

    def __init__(
        self,
        hdf5_name,
        size_sys,
        size_t,
        size_y,
        size_temp_y,
        size_aux_state=0,
        num_bcf_terms=0,
        accum_only=False,
        read_only=False,
    ):
        """
        Open ``hdf5_name``, creating and initializing it if it is missing or empty.

        Parameters
        ----------
        hdf5_name : str
            Path of the HDF5 file.
        size_sys : int
            Number of system components stored per time step.
        size_t : int
            Number of time steps.
        size_y, size_temp_y : int
            Lengths of the per-sample ``y`` / ``temp_y`` vectors.
        size_aux_state : int
            Number of auxiliary-state components per time step (0: none stored).
        num_bcf_terms : int
            Number of stochastic-process components per time step (0: none stored).
        accum_only : bool
            Store only sample 0 in full (see class docstring).
        read_only : bool
            Open in read-only SWMR mode and skip the file-version check.

        Raises
        ------
        KeyError
            If an expected dataset is missing from an existing file.
        TypeError
            If the file contains aux_states / stoc_proc although the
            corresponding size is 0.
        """
        self.hdf5_name = hdf5_name
        self.accum_only = accum_only

        if file_does_not_exists_or_is_empty(hdf5_name):
            self.init_file(
                size_sys, size_t, size_y, size_temp_y, size_aux_state, num_bcf_terms
            )
        elif not read_only:
            try:
                self._test_file_version(hdf5_name)
            except Exception as e:
                print("test_file_version FAILED with exception {}".format(e))
                # interactive escape hatch kept from the original implementation
                r = input("to ignore the error type y")
                if r != "y":
                    raise

        if read_only:
            self.h5File = h5py.File(self.hdf5_name, "r", swmr=True, libver="latest")
        else:
            try:
                self.h5File = h5py.File(self.hdf5_name, "r+", libver="latest")
            except OSError:
                print("FAILED to open h5 file '{}'".format(self.hdf5_name))
                raise
            self.h5File.swmr_mode = True

        try:
            self.stoc_traj = self.h5File["/stoc_traj"]
            self.rho_t_accum = self.h5File["/rho_t_accum"]
            self.rho_t_accum_part = self.h5File["/rho_t_accum_part"]
            self.rho_t_accum_part_tracker = self.h5File["/rho_t_accum_part_tracker"]
            self.samples = self.h5File["/samples"]
            self.largest_idx = self.h5File["/largest_idx"]
            self.tracker = self.h5File["/tracker"]
            self.y = self.h5File["/y"]
            self.temp_y = self.h5File["/temp_y"]

            if size_aux_state != 0:
                self.aux_states = self.h5File["/aux_states"]
            else:
                self.aux_states = None
                if "aux_states" in self.h5File:
                    raise TypeError(
                        "HIData with aux_states=0 finds h5 file with /aux_states"
                    )

            if num_bcf_terms != 0:
                self.stoc_proc = self.h5File["/stoc_proc"]
            else:
                self.stoc_proc = None
                if "stoc_proc" in self.h5File:
                    raise TypeError(
                        "HIData init FAILED: num_bcf_terms=0 but h5 file {} has /stoc_proc".format(
                            self.hdf5_name
                        )
                    )

            self.time = self.h5File["/time"]
            self.time_set = self.h5File["/time_set"]
        except KeyError:
            print("KeyError in  hdf5 file '{}'".format(hdf5_name))
            # do not leak the open file handle when construction fails
            self.h5File.close()
            raise
        except TypeError:
            self.h5File.close()
            raise

        self.size_t = size_t
        self.size_sys = size_sys
        self.size_aux_state_plus_sys = size_aux_state + size_sys
        self.num_bcf_terms = num_bcf_terms
        self.size_y = size_y
        self.size_temp_y = size_temp_y

        self._idx_cnt = len(self.tracker)
        self._idx_rho_t_accum_part_tracker_cnt = len(self.rho_t_accum_part_tracker)

    def init_file(
        self, size_sys, size_t, size_y, size_temp_y, size_aux_state=0, num_bcf_terms=0
    ):
        """Create (truncating, mode 'w') the HDF5 file with all datasets at their initial size."""
        with h5py.File(self.hdf5_name, "w", libver="latest") as h5File:
            if not self.accum_only:
                size_stoc_traj = HIData_default_size_stoc_traj
            else:
                # need at least one stoch traj to show k_max convergence
                size_stoc_traj = 1

            # the per-sample datasets below use size_stoc_traj, which is
            # reduced to 1 in accum_only mode
            h5File.create_dataset(
                "stoc_traj",
                (size_stoc_traj, size_t, size_sys),
                dtype=np.complex128,
                maxshape=(None, size_t, size_sys),
                chunks=(1, size_t, size_sys),
            )
            h5File.create_dataset(
                "y",
                (size_stoc_traj, size_y),
                dtype=np.complex128,
                maxshape=(None, size_y),
                chunks=(1, size_y),
            )
            h5File.create_dataset(
                "temp_y",
                (size_stoc_traj, size_temp_y),
                dtype=np.complex128,
                maxshape=(None, size_temp_y),
                chunks=(1, size_temp_y),
            )
            h5File.create_dataset(
                "rho_t_accum", (size_t, size_sys, size_sys), dtype=np.complex128
            )
            h5File.create_dataset(
                "rho_t_accum_part",
                (HIData_default_size_rho_t_accum_part, size_t, size_sys, size_sys),
                dtype=np.complex128,
                maxshape=(None, size_t, size_sys, size_sys),
                chunks=(1, size_t, size_sys, size_sys),
            )

            h5File.create_dataset(
                "rho_t_accum_part_tracker",
                data=HIData_default_size_rho_t_accum_part * [False],
                dtype="bool",
                maxshape=(None,),
            )

            h5File.create_dataset("samples", (1,), dtype=np.uint32)
            h5File.create_dataset("largest_idx", (1,), dtype=np.uint32)

            h5File.create_dataset(
                "tracker",
                data=HIData_default_size_stoc_traj * [False],
                dtype="bool",
                maxshape=(None,),
            )

            if size_aux_state != 0:
                log.debug(
                    "create aux_states with shape %s",
                    (size_stoc_traj, size_t, size_aux_state),
                )
                h5File.create_dataset(
                    "aux_states",
                    (size_stoc_traj, size_t, size_aux_state),
                    dtype=np.complex128,
                    maxshape=(None, size_t, size_aux_state),
                    chunks=(1, size_t, size_aux_state),
                )
            if num_bcf_terms != 0:
                h5File.create_dataset(
                    "stoc_proc",
                    (size_stoc_traj, size_t, num_bcf_terms),
                    dtype=np.complex128,
                    maxshape=(None, size_t, num_bcf_terms),
                    chunks=(1, size_t, num_bcf_terms),
                )

            h5File.create_dataset("time", (size_t,), dtype=np.float64)
            # plain "bool": the alias np.bool was removed in NumPy 1.24
            h5File.create_dataset("time_set", (1,), dtype="bool")
            h5File["/time_set"][0] = False

    def _test_file_version(self, hdf5_name):
        """
        Convert ``hdf5_name`` to the latest HDF5 format if SWMR cannot be enabled
        because of an old superblock version.

        Nothing is done while another process has the file open, under the
        assumption that such a process has already converted it.
        """
        p = get_process_accessing_file(hdf5_name)
        if p:
            # another process accesses the file, assume that the file has already the new format,
            # since that other process has already changed it
            return

        with h5py.File(hdf5_name, libver="latest") as h5File:
            try:
                h5File.swmr_mode = True
            except ValueError as e:
                if str(e).startswith("File superblock version should be at least 3"):
                    print(
                        "got Value Error with msg 'File superblock version should be at least 3' -> change h5 file to new version"
                    )
                    h5File.close()
                    change_file_version_to_latest(hdf5_name)
                else:
                    log.debug("enabling swmr_mode on '%s' failed: %s", hdf5_name, e)
            except Exception as e:
                # best effort check: report instead of silently swallowing
                log.debug("enabling swmr_mode on '%s' failed: %s", hdf5_name, e)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()

    def close(self):
        """Close the underlying HDF5 file."""
        self.h5File.close()

    def has_sample(self, idx):
        """Return whether sample ``idx`` has been stored (False beyond the tracker size)."""
        if idx < self._idx_cnt:
            return self.tracker[idx]
        return False

    def _resize(self, size):
        """Resize the tracker (and, unless accum_only, all per-sample datasets) to ``size`` rows."""
        self.tracker.resize(size=(size,))
        if not self.accum_only:
            self.stoc_traj.resize(size=(size, self.size_t, self.size_sys))
            self.y.resize(size=(size, self.size_y))
            self.temp_y.resize(size=(size, self.size_temp_y))
            if self.aux_states is not None:
                self.aux_states.resize(
                    size=(
                        size,
                        self.size_t,
                        self.size_aux_state_plus_sys - self.size_sys,
                    )
                )
            if self.stoc_proc is not None:
                # stoc_proc is created chunked with maxshape (None, ...) in
                # init_file, like aux_states; without this resize writing
                # stoc_proc[idx] beyond the initial size fails.
                self.stoc_proc.resize(size=(size, self.size_t, self.num_bcf_terms))
        self._idx_cnt = size

    def _inc_size(self, idx):
        """Grow the per-sample storage by doubling until index ``idx`` fits."""
        if self._idx_cnt <= idx:
            new_idx_cnt = 2 * self._idx_cnt
            while new_idx_cnt <= idx:
                new_idx_cnt *= 2
            self._resize(new_idx_cnt)

    def _resize_rho_t_accum_part(self, size):
        """Resize the rho(t) snapshot datasets to ``size`` entries."""
        self.rho_t_accum_part_tracker.resize(size=(size,))
        self.rho_t_accum_part.resize(
            size=(size, self.size_t, self.size_sys, self.size_sys)
        )
        self._idx_rho_t_accum_part_tracker_cnt = size

    def _inc_rho_t_accum_part_tracker_size(self, n):
        """Grow the snapshot storage in steps of the default size until index ``n`` fits."""
        if self._idx_rho_t_accum_part_tracker_cnt <= n:
            new_idx_cnt = (
                self._idx_rho_t_accum_part_tracker_cnt
                + HIData_default_size_rho_t_accum_part
            )
            while new_idx_cnt <= n:
                new_idx_cnt += HIData_default_size_rho_t_accum_part
            self._resize_rho_t_accum_part(new_idx_cnt)

    def has_rho_t_accum_part(self, n):
        """Return whether snapshot ``n`` (taken at 2**n samples) exists."""
        if n < self._idx_rho_t_accum_part_tracker_cnt:
            return self.rho_t_accum_part_tracker[n]
        return False

    def set_time(self, time, force=False):
        """Store the time grid; ignored with a warning if already set unless ``force``."""
        if (not self.time_set[0]) or force:
            self.time[:] = time
            self.time_set[0] = True
        else:
            warnings.warn("time has already been set (ignore new time)")

    def get_time(self):
        """Return the stored time grid; raise RuntimeError if it was never set."""
        if not self.time_set[0]:
            raise RuntimeError("can not get time, time has not been set yet.")
        return self.time[:]

    def new_samples(self, idx, psi_all, result_type, normed, y, temp_y):
        """
        Store sample ``idx`` and add its projector to the accumulated rho(t).

        ``psi_all`` has shape (size_t, c); its first size_sys columns are the
        system state.  Depending on ``result_type`` the following columns
        hold the auxiliary states and/or the stochastic process values.
        The bookkeeping updates run inside a SIGINT/SIGTERM delay block.
        """
        self._inc_size(idx)
        if (not self.accum_only) or (idx == 0):
            c = psi_all.shape[1]
            self.y[idx] = y
            self.temp_y[idx] = temp_y
            if result_type == RESULT_TYPE_ZEROTH_ORDER_ONLY:
                self.stoc_traj[idx] = psi_all
            elif result_type == RESULT_TYPE_ZEROTH_ORDER_AND_ETA_LAMBDA:
                self.stoc_traj[idx] = psi_all[:, : self.size_sys]
                if c > self.size_sys:  # the linear HI has no stoc_proc data
                    self.stoc_proc[idx] = psi_all[:, self.size_sys :]
            elif result_type == RESULT_TYPE_ALL:
                self.stoc_traj[idx] = psi_all[:, : self.size_sys]
                self.aux_states[idx] = psi_all[
                    :, self.size_sys : self.size_aux_state_plus_sys
                ]
                if (
                    c > self.size_aux_state_plus_sys
                ):  # the linear HI has no stoc_proc data
                    # write row idx of the dataset (previously the attribute
                    # was rebound to the array, so nothing was stored)
                    self.stoc_proc[idx] = psi_all[:, self.size_aux_state_plus_sys :]

        # a snapshot is taken when the sample count after this call is 2**n
        n = is_int_power(self.samples[0] + 1, b=2)
        if n is not None:
            self._inc_rho_t_accum_part_tracker_size(n)

        with signalDelay.sig_delay([signal.SIGINT, signal.SIGTERM]):
            self.tracker[idx] = True
            self.largest_idx[0] = max(self.largest_idx[0], idx)

            self.rho_t_accum[:] += projector_psi_t(
                psi_all[:, : self.size_sys], normed=normed
            )
            self.samples[0] += 1

            if n is not None:
                self.rho_t_accum_part_tracker[n] = True
                self.rho_t_accum_part[n] = self.get_rho_t()

            self.h5File.flush()

    def get_rho_t(self, res=None):
        """
        Return the averaged rho(t) = rho_t_accum / samples.

        If ``res`` is given, the un-normalized accumulated sum is read into it;
        the returned array is always a new array.
        """
        if res is None:
            res = np.empty(
                shape=(self.size_t, self.size_sys, self.size_sys), dtype=np.complex128
            )
        self.rho_t_accum.read_direct(dest=res)
        return res / self.get_samples()

    def get_rho_t_part(self, n):
        """Return snapshot ``n``; raise RuntimeError if it does not exist yet."""
        if self.has_rho_t_accum_part(n):
            return self.rho_t_accum_part[n]
        raise RuntimeError(
            "rho_t_accum_part with index {} has not been chrunched yet".format(n)
        )

    def get_samples(self):
        """Return the number of accumulated samples."""
        return self.samples[0]

    def clear(self):
        """Reset all datasets to their initial size and mark every result as absent."""
        self._resize(HIData_default_size_stoc_traj)
        self._resize_rho_t_accum_part(HIData_default_size_rho_t_accum_part)
        # reset the flag; previously time[0] was overwritten instead
        self.time_set[0] = False
        self.rho_t_accum[:] = 0j
        self.samples[0] = 0
        self.tracker[:] = False
        self.rho_t_accum_part_tracker[:] = False
        self.largest_idx[0] = 0

    def get_stoc_traj(self, idx):
        """Return stored trajectory ``idx``; raise RuntimeError if it is absent."""
        if self.has_sample(idx):
            return self.stoc_traj[idx]
        raise RuntimeError(
            "sample with idx {} has not yet been chrunched".format(idx)
        )

    def get_sub_rho_t(self, idx_low, idx_high, normed, overwrite=False):
        """
        Return ``(rho_t, smp)`` averaged over the stored samples in [idx_low, idx_high).

        The result is cached in the file as dataset ``"<idx_low>_<idx_high>"``
        (with attribute ``smp``) and recomputed only if ``overwrite`` is set.
        NOTE(review): if no sample in the range is stored, smp == 0 and the
        cached rho_t is NaN.
        """
        name = "{}_{}".format(int(idx_low), int(idx_high))
        if overwrite and name in self.h5File:
            del self.h5File[name]

        if name not in self.h5File:
            smp = 0
            rho_t_accum = np.zeros(
                shape=(self.size_t, self.size_sys, self.size_sys), dtype=np.complex128
            )
            for i in range(idx_low, idx_high):
                if self.has_sample(i):
                    smp += 1
                    rho_t_accum += projector_psi_t(self.stoc_traj[i], normed=normed)
            rho_t = rho_t_accum / smp
            h5_data = self.h5File.create_dataset(
                name,
                shape=(self.size_t, self.size_sys, self.size_sys),
                dtype=np.complex128,
            )
            h5_data[:] = rho_t
            h5_data.attrs["smp"] = smp
        else:
            rho_t = np.empty(
                shape=(self.size_t, self.size_sys, self.size_sys), dtype=np.complex128
            )
            self.h5File[name].read_direct(dest=rho_t)
            smp = self.h5File[name].attrs["smp"]

        return rho_t, smp

    def rewrite_rho_t(self, normed):
        """Recompute rho_t_accum and samples from all stored trajectories."""
        smp = 0
        rho_t_accum = np.zeros(
            shape=(self.size_t, self.size_sys, self.size_sys), dtype=np.complex128
        )
        for i in range(self.largest_idx[0] + 1):
            if self.has_sample(i):
                smp += 1
                rho_t_accum += projector_psi_t(self.stoc_traj[i], normed=normed)

        with signalDelay.sig_delay([signal.SIGINT, signal.SIGTERM]):
            self.samples[0] = smp
            self.rho_t_accum[:] = rho_t_accum


class HIMetaData(object):
    """
    Index of HIData files stored in the directory ``<hid_path>/__<hid_name>``.

    Keys are serialized with binfootprint and hashed with SHA-384.  A
    SqliteDict (``<name>.sqld`` in that directory) maps each hash to
    ``(hdf5_name, key)``, where hdf5_name is a unique random file name.
    """

    def __init__(self, hid_name, hid_path):
        self.name = hid_name
        self.path = os.path.join(hid_path, "__" + self.name)
        if os.path.exists(self.path):
            if not os.path.isdir(self.path):
                raise NotADirectoryError(
                    "the path '{}' exists but is not a directory".format(self.path)
                )
        else:
            os.mkdir(self.path)
        self._fname = os.path.join(self.path, self.name + ".sqld")
        # current length of the random part of new file names
        self._l = 8
        self.db = sqd.SqliteDict(filename=self._fname, autocommit=False)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()

    def close(self):
        """Close the underlying SqliteDict."""
        self.db.close()

    def _new_rand_file_name(self, pre, end):
        """
        Create an empty file ``pre + <random> + end`` in self.path and return its name.

        Exclusive creation (mode 'x') guarantees the name is unused; after more
        than 10 collisions the random part is made 2 characters longer.
        """
        if not os.path.exists(self.path):
            os.mkdir(self.path)

        collisions = 0
        while True:
            fname = pre + rand_str(self._l) + end
            full_name = os.path.join(self.path, fname)

            try:
                open(full_name, "x").close()
                return fname
            except FileExistsError:
                pass

            collisions += 1
            if collisions > 10:
                self._l += 2
                collisions = 0
                print("INFO: increase random file name length to", self._l)

    def get_hashed_key(self, key):
        """Return the SHA-384 hex digest of the binfootprint serialization of ``key``."""
        binkey = bf.dump(key)
        return hashlib.sha384(binkey).hexdigest()

    def get_HIData_fname(self, key):
        """Return the file name registered for ``key``, creating and committing a new entry if needed."""
        hashed_key = self.get_hashed_key(key)
        try:
            hdf5_name = self.db[hashed_key][0]
        except KeyError:
            hdf5_name = self._new_rand_file_name(pre=self.name + "_", end=".h5")
            self.db[hashed_key] = (hdf5_name, key)
            self.db.commit()

        return hdf5_name

    def get_HIData(self, key, read_only=False):
        """
        Return the HIData object belonging to ``key`` (a HIMetaKey_type-like object).

        Raises
        ------
        ValueError
            If ``key.HiP.result_type`` is not one of the RESULT_TYPE_* values.
            This is checked before any file or database entry is created.
        """
        result_type = key.HiP.result_type
        if result_type == RESULT_TYPE_ZEROTH_ORDER_ONLY:
            size_aux_state = 0
            num_bcf_terms = 0
        elif result_type == RESULT_TYPE_ZEROTH_ORDER_AND_ETA_LAMBDA:
            size_aux_state = 0
            num_bcf_terms = key.SysP.len_gw
        elif result_type == RESULT_TYPE_ALL:
            num_bcf_terms = key.SysP.len_gw
            size_aux_state = combLib.number_of_all_combinations_old(
                n=num_bcf_terms, k_max=key.HiP.k_max
            )
        else:
            # previously this surfaced as an UnboundLocalError after an
            # orphan database entry and file had already been created
            raise ValueError("unknown result_type '{}'".format(result_type))

        if not key.HiP.nonlinear:
            num_bcf_terms = 0

        # HiP objects lacking the accum_only attribute count as accum_only=False
        accum_only = getattr(key.HiP, "accum_only", False)

        hdf5_name = self.get_HIData_fname(key)
        return HIData(
            os.path.join(self.path, hdf5_name),
            size_sys=key.SysP.H_sys.shape[0],
            size_t=key.IntP.t_steps,
            size_aux_state=size_aux_state,
            num_bcf_terms=num_bcf_terms,
            accum_only=accum_only,
            read_only=read_only,
            size_y=key.Eta.get_num_y(),
            size_temp_y=key.EtaTherm.get_num_y() if key.EtaTherm else 1,
        )


def get_process_accessing_file(fname):
    """
    Return the PIDs of the processes that have ``fname`` open, as reported by ``lsof``.

    Parameters
    ----------
    fname : str
        Path of an existing file.

    Returns
    -------
    list of int
        One entry per line of lsof output; a process may appear more than
        once.  An empty list is returned when lsof reports nothing or cannot
        be executed.

    Raises
    ------
    FileNotFoundError
        If ``fname`` does not exist.
    RuntimeError
        If the lsof output has no PID column.
    """
    if not os.path.exists(fname):
        raise FileNotFoundError(fname)

    cmd_str = 'lsof "{}"'.format(fname)  # only used for log messages
    try:
        # argument list, no shell: file names containing quotes, '$' or
        # backticks are passed through verbatim; "--" ends lsof's options
        r = subprocess.run(
            ["lsof", "--", fname],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            encoding="utf8",
            errors="replace",
        )
    except FileNotFoundError:
        # same outcome as before, when the shell reported "lsof: not found"
        log.info("command '{}' could not be run (lsof not found)".format(cmd_str))
        return []

    if r.stderr != "":
        log.info("command '{}' stderr:\n{}".format(cmd_str, r.stderr))

    rows = [line.split() for line in r.stdout.splitlines() if line.strip()]
    if not rows:
        # failure, also happens when no process was found
        log.info(
            "lsof has non-zero return code and empty stdout -> assume not process has access to file"
        )
        return []

    if r.returncode != 0:
        # previously this case fell through and returned None
        log.warning(
            "command '{}' returned {} but produced output, parsing it anyway".format(
                cmd_str, r.returncode
            )
        )

    head = rows[0]
    try:
        idx_PID = head.index("PID")
    except ValueError:
        raise RuntimeError(
            "unexpected output of '{}': no PID column in {}".format(cmd_str, head)
        ) from None

    return [int(row[idx_PID]) for row in rows[1:]]


def get_rand_file_name(l=12, must_not_exists=True):
    """
    Return a random name of ``l`` characters drawn from CHAR_SET via np.random.

    If ``must_not_exists`` is True, names that already exist (relative to the
    current working directory) are rejected and a new one is drawn.
    """
    alphabet_size = len(CHAR_SET)
    while True:
        candidate = "".join(
            CHAR_SET[np.random.randint(0, alphabet_size)] for _ in range(l)
        )
        if not must_not_exists or not os.path.exists(candidate):
            return candidate


def change_file_version_to_latest(h5fname):
    """
    Rewrite ``h5fname`` in place using the latest HDF5 file format.

    All top-level objects and root attributes are copied into a new file
    created with ``libver="latest"``, which then replaces the original.  The
    temporary file is created next to ``h5fname`` (not in the current
    working directory), so the final moves stay on the same filesystem.  It
    is removed again if copying fails.

    Raises
    ------
    RuntimeError
        If some process currently has ``h5fname`` open.
    """
    pid_list = get_process_accessing_file(h5fname)
    if len(pid_list) > 0:
        raise RuntimeError(
            "can not change file version! the following processes have access to the file: {}".format(
                pid_list
            )
        )

    h5dir = os.path.dirname(os.path.abspath(h5fname))
    while True:
        rand_fname = get_rand_file_name(must_not_exists=False)
        tmp_fname = os.path.join(h5dir, rand_fname)
        if not os.path.exists(tmp_fname):
            break

    try:
        with h5py.File(tmp_fname, "w", libver="latest") as f_new:
            with h5py.File(h5fname, "r") as f_old:
                for name in f_old:
                    f_old.copy(
                        name,
                        f_new["/"],
                        shallow=False,
                        expand_soft=False,
                        expand_external=False,
                        expand_refs=False,
                        without_attrs=False,
                    )
                for k, v in f_old.attrs.items():
                    f_new.attrs[k] = v
    except BaseException:
        # do not leave a half-written copy behind
        if os.path.exists(tmp_fname):
            os.remove(tmp_fname)
        raise

    # keep the old file as a backup until the new one is in place
    backup_fname = h5fname + rand_fname + ".old"
    shutil.move(h5fname, backup_fname)
    shutil.move(tmp_fname, h5fname)
    os.remove(backup_fname)
    print("updated h5 file {} to latest version".format(os.path.abspath(h5fname)))

    assert not os.path.exists(tmp_fname)