support not dumping the aggregate at every step

This commit is contained in:
Valentin Boettcher 2022-12-16 14:05:39 -05:00
parent def9e607d8
commit 86ef9fa646
No known key found for this signature in database
GPG key ID: E034E12B7AF56ACE

View file

@ -421,8 +421,12 @@ class Model(ABC):
stream_pipe: str = "results.fifo", stream_pipe: str = "results.fifo",
results_directory: str = "results", results_directory: str = "results",
**kwargs, **kwargs,
) -> Optional[ ) -> tuple[
tuple[EnsembleValue, EnsembleValue, EnsembleValue, EnsembleValue, EnsembleValue] Optional[EnsembleValue],
Optional[EnsembleValue],
Optional[EnsembleValue],
Optional[EnsembleValue],
Optional[EnsembleValue],
]: ]:
"""Calculates the bath energy flow, the interaction energy, """Calculates the bath energy flow, the interaction energy,
the interaction power, the system energy and the system power the interaction power, the system energy and the system power
@ -454,13 +458,19 @@ class Model(ABC):
self.system.derivative(), self.t, normalize=True, real=True self.system.derivative(), self.t, normalize=True, real=True
) )
flow, interaction, interaction_power, system, system_power = ( aggregates = [None for _ in range(5)]
None, paths = [
None, os.path.join(results_directory, path)
None, for path in [
None, self.online_flow_name,
None, self.online_interaction_name,
) self.online_interaction_power_name,
self.online_system_name,
self.online_system_power_name,
]
]
flow, interaction, interaction_power, system, system_power = aggregates
with open(stream_pipe, "rb") as fifo: with open(stream_pipe, "rb") as fifo:
while True: while True:
@ -473,52 +483,36 @@ class Model(ABC):
_, _,
rng_seed, rng_seed,
) = pickle.load(fifo) ) = pickle.load(fifo)
flow = hopsflow.util.ensemble_mean_online(
(psi0, aux_states, rng_seed),
os.path.join(results_directory, self.online_flow_name),
flow_worker,
idx,
**kwargs,
)
interaction = hopsflow.util.ensemble_mean_online( for path, (i, aggregator), args in zip(
(psi0, aux_states, rng_seed), paths,
os.path.join(results_directory, self.online_interaction_name), enumerate(aggregates),
interaction_worker, [
idx, ((psi0, aux_states, rng_seed), flow_worker),
**kwargs, ((psi0, aux_states, rng_seed), interaction_worker),
) ((psi0, aux_states, rng_seed), interaction_power_worker),
((psi0), system_worker),
((psi0), system_power_worker),
],
):
interaction_power = hopsflow.util.ensemble_mean_online( aggregates[i] = hopsflow.util.ensemble_mean_online(
(psi0, aux_states, rng_seed), *args, save=path, aggregator=aggregator, i=idx, **kwargs
os.path.join(
results_directory, self.online_interaction_power_name
),
interaction_power_worker,
idx,
**kwargs,
)
system = hopsflow.util.ensemble_mean_online(
(psi0),
os.path.join(results_directory, self.online_system_name),
system_worker,
idx,
**kwargs,
)
system_power = hopsflow.util.ensemble_mean_online(
(psi0),
os.path.join(results_directory, self.online_system_power_name),
system_power_worker,
idx,
**kwargs,
) )
except EOFError: except EOFError:
break break
return flow, interaction, interaction_power, system, system_power for path, aggregate in zip(paths, aggregates):
if aggregate is not None:
aggregate.dump(path)
return tuple(
[
(aggregate.ensemble_value if aggregate else None)
for aggregate in aggregates
]
)
def interaction_energy( def interaction_energy(
self, data: Optional[HIData] = None, results_path: str = "results", **kwargs self, data: Optional[HIData] = None, results_path: str = "results", **kwargs