do not use the input data for the cache key -> the input data is no longer required to hit the cache

Valentin Boettcher 2022-07-22 17:06:11 +02:00
parent 62fea6ba82
commit 05e00f0f64


@@ -714,8 +714,6 @@ def ensemble_mean(
     gc_sleep: float = 0.1,
 ) -> EnsembleValue:
     results = []
-    first_result = function(next(arg_iter))
-    aggregate = WelfordAggregator(function(next(arg_iter)))
 
     path = None
     json_meta_info = json.dumps(
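Note on this first hunk: the removed lines eagerly drew from arg_iter before the cache was even consulted, and as written they consumed two samples, one for first_result and a second to seed the aggregator. The third hunk below moves the seeding after the cache lookup and reuses first_result for both. For context, a minimal sketch of Welford-style online mean/variance aggregation; the class and property names here are hypothetical and only assume that the project's WelfordAggregator follows the standard algorithm:

import numpy as np

class WelfordSketch:
    """Online mean/variance accumulator (standard Welford algorithm)."""

    def __init__(self, first_sample):
        self.n = 1
        self.mean = np.array(first_sample, dtype=np.float64)
        self._m2 = np.zeros_like(self.mean)  # running sum of squared deviations

    def update(self, sample):
        self.n += 1
        delta = sample - self.mean
        self.mean += delta / self.n
        self._m2 += delta * (sample - self.mean)  # second factor uses the updated mean

    @property
    def sample_variance(self):
        # unbiased variance; zero for a single sample, matching the
        # np.zeros_like(aggregate.mean) fallback in the N == 1 branch below
        if self.n < 2:
            return np.zeros_like(self.mean)
        return self._m2 / (self.n - 1)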
@@ -728,19 +726,6 @@
         cls=JSONEncoder,
         ensure_ascii=False,
     ).encode("utf-8")
-    json_meta_info_old = json.dumps(
-        dict(
-            N=N,
-            every=every,
-            function_name=function.__name__,
-            first_iterator_value="<not serializable>",
-        ),
-        cls=JSONEncoder,
-        ensure_ascii=False,
-        default=lambda obj: obj.__dict__
-        if hasattr(obj, "__dict__")
-        else "<not serializable>",
-    ).encode("utf-8")
 
     if save:
         key = hashlib.sha256(json_meta_info).hexdigest()
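This second hunk drops the legacy metadata blob: with first_iterator_value gone from the key material, the SHA-256 cache key depends on metadata alone and can be computed without drawing anything from arg_iter. A minimal sketch of the scheme, assuming a plain dict of metadata (the full dict passed to json.dumps is not visible in this hunk, and sort_keys is an added assumption for digest stability):

import hashlib
import json

def cache_key(N, every, function_name):
    meta = json.dumps(
        dict(N=N, every=every, function_name=function_name),
        ensure_ascii=False,
        sort_keys=True,  # assumption: not in the diff, keeps the digest order-stable
    ).encode("utf-8")
    return hashlib.sha256(meta).hexdigest()

# Deterministic: the same metadata always maps to the same cache file,
# so the lookup can happen before any ensemble member is computed.
key = cache_key(1000, 10, "energy")  # illustrative values
print(f"results/run_energy_1000_10_{key}.npy")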
@@ -748,17 +733,13 @@
             f"{save}_{function.__name__}_{N}_{every}_{key}.npy"
         )
 
-        key_old = hashlib.sha256(json_meta_info_old).hexdigest()
-        path_old = Path("results") / Path(
-            f"{save}_{function.__name__}_{N}_{every}_{key_old}.npy"
-        )
-
-        if path_old.exists():
-            shutil.move(path_old, path)
-
         if not overwrite_cache and path.exists():
             logging.debug(f"Loading cache from: {path}")
-            return EnsembleValue(np.load(str(path), allow_pickle=True))
+            results = np.load(str(path), allow_pickle=True)
+            return EnsembleValue([tuple(res) for res in results])
+
+    first_result = function(next(arg_iter))
+    aggregate = WelfordAggregator(first_result)
 
     if N == 1:
         return EnsembleValue([(1, aggregate.mean, np.zeros_like(aggregate.mean))])
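Besides removing the old-key migration shim (key_old/path_old/shutil.move), this hunk changes the load path: the cached results are stored as a NumPy object array, and np.load hands the rows back as object arrays rather than tuples, so each row is re-tupled before constructing EnsembleValue. A minimal sketch of that save/load round-trip (file name and values illustrative):

import numpy as np

results = [
    (1, np.zeros(3), np.zeros(3)),        # (n, mean, std) per checkpoint
    (2, np.ones(3), np.full(3, 0.1)),
]
np.save("cache_demo.npy", np.array(results, dtype=object))

loaded = np.load("cache_demo.npy", allow_pickle=True)
restored = [tuple(res) for res in loaded]  # rows come back as object arrays
assert restored[0][0] == 1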
@@ -767,7 +748,7 @@
     chunk_size = max(100000 // (first_result.size * first_result.itemsize), 1)
     logging.debug(f"Setting chunk size to {chunk_size}.")
 
-    num_chunks = math.ceil(N / chunk_size) if N is not None else None
+    num_chunks = math.ceil((N - 1) / chunk_size) if N is not None else None
     chunk_iterator = iter(
         tqdm(
             zip(
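The final hunk keeps the chunk count consistent with the new control flow: chunk_size targets roughly 100 kB of sample data per chunk (100000 bytes divided by the bytes per sample), and since the first of the N samples is now consumed to seed the aggregator, only N - 1 samples remain to be chunked. A worked check of the off-by-one:

import math

N, chunk_size = 9, 4
old = math.ceil(N / chunk_size)        # ceil(9 / 4) = 3 chunks
new = math.ceil((N - 1) / chunk_size)  # ceil(8 / 4) = 2 chunks
assert (old, new) == (3, 2)
# With the first sample already consumed, only 8 samples remain;
# the old formula would have scheduled a third, empty chunk.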