support not dumping the aggregate at every step

This commit is contained in:
Valentin Boettcher 2022-12-16 14:05:39 -05:00
parent def9e607d8
commit 86ef9fa646
No known key found for this signature in database
GPG key ID: E034E12B7AF56ACE

View file

@ -421,8 +421,12 @@ class Model(ABC):
stream_pipe: str = "results.fifo",
results_directory: str = "results",
**kwargs,
) -> Optional[
tuple[EnsembleValue, EnsembleValue, EnsembleValue, EnsembleValue, EnsembleValue]
) -> tuple[
Optional[EnsembleValue],
Optional[EnsembleValue],
Optional[EnsembleValue],
Optional[EnsembleValue],
Optional[EnsembleValue],
]:
"""Calculates the bath energy flow, the interaction energy,
the interaction power, the system energy and the system power
@ -454,13 +458,19 @@ class Model(ABC):
self.system.derivative(), self.t, normalize=True, real=True
)
flow, interaction, interaction_power, system, system_power = (
None,
None,
None,
None,
None,
)
aggregates = [None for _ in range(5)]
paths = [
os.path.join(results_directory, path)
for path in [
self.online_flow_name,
self.online_interaction_name,
self.online_interaction_power_name,
self.online_system_name,
self.online_system_power_name,
]
]
flow, interaction, interaction_power, system, system_power = aggregates
with open(stream_pipe, "rb") as fifo:
while True:
@ -473,52 +483,36 @@ class Model(ABC):
_,
rng_seed,
) = pickle.load(fifo)
flow = hopsflow.util.ensemble_mean_online(
(psi0, aux_states, rng_seed),
os.path.join(results_directory, self.online_flow_name),
flow_worker,
idx,
**kwargs,
)
interaction = hopsflow.util.ensemble_mean_online(
(psi0, aux_states, rng_seed),
os.path.join(results_directory, self.online_interaction_name),
interaction_worker,
idx,
**kwargs,
)
for path, (i, aggregator), args in zip(
paths,
enumerate(aggregates),
[
((psi0, aux_states, rng_seed), flow_worker),
((psi0, aux_states, rng_seed), interaction_worker),
((psi0, aux_states, rng_seed), interaction_power_worker),
((psi0), system_worker),
((psi0), system_power_worker),
],
):
interaction_power = hopsflow.util.ensemble_mean_online(
(psi0, aux_states, rng_seed),
os.path.join(
results_directory, self.online_interaction_power_name
),
interaction_power_worker,
idx,
**kwargs,
)
system = hopsflow.util.ensemble_mean_online(
(psi0),
os.path.join(results_directory, self.online_system_name),
system_worker,
idx,
**kwargs,
)
system_power = hopsflow.util.ensemble_mean_online(
(psi0),
os.path.join(results_directory, self.online_system_power_name),
system_power_worker,
idx,
**kwargs,
)
aggregates[i] = hopsflow.util.ensemble_mean_online(
*args, save=path, aggregator=aggregator, i=idx, **kwargs
)
except EOFError:
break
return flow, interaction, interaction_power, system, system_power
for path, aggregate in zip(paths, aggregates):
if aggregate is not None:
aggregate.dump(path)
return tuple(
[
(aggregate.ensemble_value if aggregate else None)
for aggregate in aggregates
]
)
def interaction_energy(
self, data: Optional[HIData] = None, results_path: str = "results", **kwargs