mirror of
https://github.com/vale981/hopsflow
synced 2025-03-04 16:31:38 -05:00
fix repeatedly adding the final value
This commit is contained in:
parent
7d8be96fc8
commit
8fac8a2926
1 changed files with 24 additions and 24 deletions
|
@ -601,25 +601,6 @@ def _ensemble_remote_function(function, chunk: tuple, index: int):
|
|||
return res, index
|
||||
|
||||
|
||||
def process_chunks(highest_index, aggregate, chunks, every, results, N):
|
||||
while highest_index in chunks:
|
||||
for res in chunks[highest_index]:
|
||||
aggregate.update(res)
|
||||
if every is not None and (aggregate.n % every) == 0 or aggregate.n == N:
|
||||
results.append(
|
||||
(
|
||||
aggregate.n,
|
||||
aggregate.mean.copy(),
|
||||
aggregate.ensemble_std.copy(),
|
||||
)
|
||||
)
|
||||
|
||||
del chunks[highest_index]
|
||||
highest_index += 1
|
||||
|
||||
return highest_index, results
|
||||
|
||||
|
||||
def ensemble_mean(
|
||||
arg_iter: Iterator[Any],
|
||||
function: Callable[..., np.ndarray],
|
||||
|
@ -627,7 +608,7 @@ def ensemble_mean(
|
|||
every: Optional[int] = None,
|
||||
save: Optional[str] = None,
|
||||
overwrite_cache: bool = False,
|
||||
chunk_size: Optinal[int] = None,
|
||||
chunk_size: Optional[int] = None,
|
||||
) -> EnsembleValue:
|
||||
results = []
|
||||
first_result = function(next(arg_iter))
|
||||
|
@ -714,7 +695,7 @@ def ensemble_mean(
|
|||
except StopIteration:
|
||||
next_val = None
|
||||
|
||||
if len(processing_refs) > in_flight:
|
||||
if len(processing_refs) > in_flight or not next_val:
|
||||
finished, processing_refs = ray.wait(
|
||||
processing_refs,
|
||||
num_returns=len(processing_refs) - in_flight
|
||||
|
@ -733,9 +714,28 @@ def ensemble_mean(
|
|||
|
||||
finished = []
|
||||
if has_downloaded:
|
||||
highest_index, results = process_chunks(
|
||||
highest_index, aggregate, chunks, every, results, N
|
||||
)
|
||||
while highest_index in chunks:
|
||||
next_chunk = chunks[highest_index]
|
||||
del chunks[highest_index]
|
||||
|
||||
len_chunk = len(next_chunk) - 1
|
||||
for i, res in enumerate(next_chunk):
|
||||
aggregate.update(res)
|
||||
if (
|
||||
every is not None
|
||||
and (aggregate.n % every) == 0
|
||||
or aggregate.n == N
|
||||
or (not next_val and not chunks and i == len_chunk)
|
||||
):
|
||||
results.append(
|
||||
(
|
||||
aggregate.n,
|
||||
aggregate.mean.copy(),
|
||||
aggregate.ensemble_std.copy(),
|
||||
)
|
||||
)
|
||||
|
||||
highest_index += 1
|
||||
|
||||
if next_val:
|
||||
chunk_ref = ray.put(next_val[0])
|
||||
|
|
Loading…
Add table
Reference in a new issue