fix repeatedly adding the final value

This commit is contained in:
Valentin Boettcher 2022-05-16 19:59:03 +02:00
parent 7d8be96fc8
commit 8fac8a2926

View file

@ -601,25 +601,6 @@ def _ensemble_remote_function(function, chunk: tuple, index: int):
return res, index
def process_chunks(highest_index, aggregate, chunks, every, results, N):
    """Fold consecutively indexed chunks into ``aggregate``.

    Starting at ``highest_index``, chunks are consumed in index order for
    as long as the next index is present in ``chunks``; the first missing
    index stops processing.  Every consumed chunk is removed from
    ``chunks``.

    After each single update a snapshot
    ``(aggregate.n, aggregate.mean.copy(), aggregate.ensemble_std.copy())``
    is appended to ``results`` when either

    * ``every`` is not ``None`` and ``aggregate.n`` is a multiple of
      ``every``, or
    * ``aggregate.n`` equals ``N``.

    :param highest_index: index of the next chunk to be processed
    :param aggregate: object providing ``update(res)`` as well as the
        attributes ``n``, ``mean`` and ``ensemble_std``
    :param chunks: mapping from chunk index to an iterable of results;
        mutated in place
    :param every: snapshot interval, or ``None`` to disable interval
        snapshots
    :param results: list collecting the snapshots; mutated in place
    :param N: value of ``aggregate.n`` at which a snapshot is always taken

    :returns: tuple of the first index that was *not* available in
        ``chunks`` and the (same) ``results`` list
    """

    while highest_index in chunks:
        # Take the chunk out of the pending map before folding it in, so a
        # failure part-way through cannot leave a half-consumed chunk
        # behind to be counted a second time.
        current_chunk = chunks.pop(highest_index)

        for res in current_chunk:
            aggregate.update(res)

            # Explicit grouping: ``and`` binds tighter than ``or``, so the
            # ``N`` check applies even when ``every`` is ``None``.
            at_interval = every is not None and (aggregate.n % every) == 0
            if at_interval or aggregate.n == N:
                results.append(
                    (
                        aggregate.n,
                        aggregate.mean.copy(),
                        aggregate.ensemble_std.copy(),
                    )
                )

        highest_index += 1

    return highest_index, results
def ensemble_mean(
arg_iter: Iterator[Any],
function: Callable[..., np.ndarray],
@ -627,7 +608,7 @@ def ensemble_mean(
every: Optional[int] = None,
save: Optional[str] = None,
overwrite_cache: bool = False,
chunk_size: Optinal[int] = None,
chunk_size: Optional[int] = None,
) -> EnsembleValue:
results = []
first_result = function(next(arg_iter))
@ -714,7 +695,7 @@ def ensemble_mean(
except StopIteration:
next_val = None
if len(processing_refs) > in_flight:
if len(processing_refs) > in_flight or not next_val:
finished, processing_refs = ray.wait(
processing_refs,
num_returns=len(processing_refs) - in_flight
@ -733,9 +714,28 @@ def ensemble_mean(
finished = []
if has_downloaded:
highest_index, results = process_chunks(
highest_index, aggregate, chunks, every, results, N
)
while highest_index in chunks:
next_chunk = chunks[highest_index]
del chunks[highest_index]
len_chunk = len(next_chunk) - 1
for i, res in enumerate(next_chunk):
aggregate.update(res)
if (
every is not None
and (aggregate.n % every) == 0
or aggregate.n == N
or (not next_val and not chunks and i == len_chunk)
):
results.append(
(
aggregate.n,
aggregate.mean.copy(),
aggregate.ensemble_std.copy(),
)
)
highest_index += 1
if next_val:
chunk_ref = ray.put(next_val[0])