add dependency change detection

This commit is contained in:
Valentin Boettcher 2021-11-11 16:09:04 +01:00
parent ca79d784b6
commit 5d0ea7b1b0
3 changed files with 179 additions and 0 deletions

View file

@ -0,0 +1,165 @@
import os
import sys
import json
import hashlib
from typing import List, Optional, Callable, Dict, Any, Tuple, Iterable
from types import ModuleType
import pickle
def sha1_file(filepath: str) -> str:
sha = hashlib.sha1()
with open(filepath, "rb") as f:
while True:
block = f.read(2 ** 10) # Magic number: one-megabyte blocks.
if not block:
break
sha.update(block)
return sha.hexdigest()
def join_hashes(hashes: Iterable[str]) -> str:
sha = hashlib.sha1()
sha.update(("".join(hashes)).encode("utf-8"))
return sha.hexdigest()
def module_path(mod: ModuleType) -> str:
return os.path.dirname(mod.__file__)
def hash_dir(
dir_path: str, rem_prefix: str = "", add_prefix: str = ""
) -> Tuple[str, Dict[str, str]]:
hashes: Dict[str, str] = dict()
for path, dirs, files in os.walk(dir_path):
for file in sorted(
files
): # we sort to guarantee that files will always go in the same order
full = os.path.join(path, file)
hashes[add_prefix + full.removeprefix(rem_prefix)] = sha1_file(full)
for dir in sorted(
dirs
): # we sort to guarantee that dirs will always go in the same order
if "." in dir or "__pycache__" in dir:
continue
full = os.path.join(path, dir)
hash, sub_hashes = hash_dir(full, rem_prefix, add_prefix)
hashes[add_prefix + full.removeprefix(rem_prefix)] = hash
hashes |= sub_hashes
break # we only need one iteration - to get files and dirs in current directory
return join_hashes(hashes.values()), hashes
class Dependencies:
def __init__(
self,
modules: List[ModuleType] = [],
files: List[str] = [],
dirs: List[str] = [],
conditions: List[Tuple[str, Callable[[Any], bool]]] = [],
hash_file: str = "",
):
self._modules = modules
self._dirs = dirs
self._files = files
self._conditions = conditions
self._hash_file = hash_file
if hash_file == "":
self._hash_file = os.path.join(sys.path[0], ".dep_hash")
def write_hash(self) -> None:
with open(self._hash_file, "w") as f:
json.dump(self.get_hash(), f, indent=2)
def load_hash(self) -> Tuple[str, Dict[str, str]]:
with open(self._hash_file, "r") as f:
return json.load(f)
def _is_fresh(self, hash: str, saved_hash: str) -> bool:
return (hash == saved_hash) and all(
[cond(self) for _, cond in self._conditions]
)
def is_fresh(self) -> bool:
try:
saved_hash = self.load_hash()
except FileNotFoundError:
return True
return self._is_fresh(saved_hash[0], self.get_hash()[0])
def get_hash(self) -> Tuple[str, Dict[str, str]]:
hashes: Dict[str, str] = dict()
for dir in self._dirs:
hash, sub_hashes = hash_dir(dir)
hashes[dir] = hash
hashes |= sub_hashes
names: List[str] = []
for mod in self._modules:
path = module_path(mod)
name = mod.__name__
while name in names:
name += "_"
names.append(name)
mod_id = f"<{name}>"
hash, sub_hashes = hash_dir(path, path, mod_id)
hashes[mod_id] = hash
hashes |= sub_hashes
for file in self._files:
hashes[file] = sha1_file(file)
hash = join_hashes(hashes.values())
return hash, hashes
def report(self) -> None:
hash, hashes = self.get_hash()
try:
saved_hash, saved_hashes = self.load_hash()
except FileNotFoundError:
print("No previous hash data found!")
return
fresh = self._is_fresh(hash[0], saved_hash[0])
print("Is fresh:", fresh)
print("Overall Hash:", hash)
if not fresh:
for name, hash in hashes.items():
other_hash = saved_hashes.get(name, "")
if other_hash != hash:
print("Deviation:", name)
print(" was: ", other_hash)
print(" is : ", hash)
for name, cond in self._conditions:
success = cond(self)
if not success:
print(f"Condition '{name}' failed!")
@property
def modules(self):
return self._modules
@property
def files(self):
return self._files
@property
def dirs(self):
return self._dirs
@property
def conditions(self):
return self._conditions

View file

@ -0,0 +1,14 @@
import stg_helper
from types import ModuleType
from typing import Callable
def get_n_samples(stg: ModuleType) -> int:
"""Get the number of samples from ``stg``."""
with stg_helper.get_hierarchy_data(stg, read_only=True) as hd:
return hd.get_samples()
def has_all_samples(stg: ModuleType) -> bool:
return stg.__HI_number_of_samples == get_n_samples(stg)
def has_all_samples_checker(stg: ModuleType) -> Callable[..., bool]:
return "Has all samples?", lambda _: has_all_samples(stg)