switch hopsflow to ray

This commit is contained in:
Valentin Boettcher 2022-03-29 18:51:30 +02:00
parent 679f5be2f2
commit f5a9889f8e
2 changed files with 74 additions and 96 deletions

View file

@ -12,9 +12,10 @@ from . import util
from typing import Optional, Tuple, Iterator, Union
from stocproc import StocProc
import itertools
import ray
###############################################################################
# Interface/Parameter Object#
# Interface/Parameter Object #
###############################################################################
@ -368,33 +369,10 @@ def interaction_energy_therm(run: HOPSRun, therm_run: ThermalRunParams) -> np.nd
###############################################################################
def _heat_flow_ensemble_body(
ψs: tuple[np.ndarray, np.ndarray, int],
params: SystemParams,
thermal: Optional[ThermalParams],
only_therm: bool,
):
ψ_0, ψ_1, seed = ψs
run = HOPSRun(ψ_0, ψ_1, params)
flow = (
flow_trajectory_coupling(run, params)
if not only_therm
else np.zeros(ψ_0.shape[0])
)
if thermal is not None:
therm_run = ThermalRunParams(thermal, seed)
flow += flow_trajectory_therm(run, therm_run)
return flow
def heat_flow_ensemble(
ψ_0s: Iterator[np.ndarray],
ψ_1s: Iterator[np.ndarray],
params: SystemParams,
N: Optional[int],
therm_args: Optional[Tuple[Iterator[np.ndarray], ThermalParams]] = None,
only_therm: bool = False,
**kwargs,
@ -417,39 +395,43 @@ def heat_flow_ensemble(
if therm_args is None and only_therm:
raise ValueError("Can't calculate only thermal part if therm_args are None.")
thermal = therm_args[1] if therm_args else None
params_ref = ray.put(params)
thermal_ref = ray.put(thermal)
def flow_worker(ψs: tuple[np.ndarray, np.ndarray, int]):
ψ_0, ψ_1, seed = ψs
params = ray.get(params_ref)
thermal = ray.get(thermal_ref)
run = HOPSRun(ψ_0, ψ_1, params)
flow = (
flow_trajectory_coupling(run, params)
if not only_therm
else np.zeros(ψ_0.shape[0])
)
if thermal is not None:
therm_run = ThermalRunParams(thermal, seed)
flow += flow_trajectory_therm(run, therm_run)
return flow
return util.ensemble_mean(
iter(zip(ψ_0s, ψ_1s, therm_args[0]))
if therm_args
else iter(zip(ψ_0s, ψ_1s, itertools.repeat(0))),
_heat_flow_ensemble_body,
N,
(params, therm_args[1] if therm_args else None, only_therm),
flow_worker,
**kwargs,
)
def _interaction_energy_ensemble_body(
ψs: Tuple[np.ndarray, np.ndarray, int],
params: SystemParams,
thermal: Optional[ThermalParams],
) -> np.ndarray:
ψ_0, ψ_1, seeds = ψs
run = HOPSRun(ψ_0, ψ_1, params)
energy = interaction_energy_coupling(run, params)
if thermal is not None:
therm_run = ThermalRunParams(thermal, seeds)
energy += interaction_energy_therm(run, therm_run)
return energy
def interaction_energy_ensemble(
ψ_0s: Iterator[np.ndarray],
ψ_1s: Iterator[np.ndarray],
params: SystemParams,
N: Optional[int],
therm_args: Optional[Tuple[Iterator[int], ThermalParams]] = None,
**kwargs,
) -> util.EnsembleReturn:
@ -467,13 +449,32 @@ def interaction_energy_ensemble(
:returns: the value of the flow for each time step
"""
thermal = therm_args[1] if therm_args else None
params_ref = ray.put(params)
thermal_ref = ray.put(thermal)
def interaction_energy_task(
ψs: Tuple[np.ndarray, np.ndarray, int],
) -> np.ndarray:
ψ_0, ψ_1, seeds = ψs
params = ray.get(params_ref)
thermal = ray.get(thermal_ref)
run = HOPSRun(ψ_0, ψ_1, params)
energy = interaction_energy_coupling(run, params)
if thermal is not None:
therm_run = ThermalRunParams(thermal, seeds)
energy += interaction_energy_therm(run, therm_run)
return energy
return util.ensemble_mean(
iter(zip(ψ_0s, ψ_1s, therm_args[0]))
if therm_args
else iter(zip(ψ_0s, ψ_1s, itertools.repeat(0))),
_interaction_energy_ensemble_body,
N,
(params, therm_args[1] if therm_args else None),
interaction_energy_task,
**kwargs,
)

View file

@ -18,6 +18,8 @@ import json
from functools import singledispatch, singledispatchmethod
from scipy.stats import NumericalInverseHermite
import copy
import ray
Aggregate = tuple[int, np.ndarray, np.ndarray]
EnsembleReturn = Union[Aggregate, list[Aggregate]]
@ -344,7 +346,6 @@ def operator_expectation(
def operator_expectation_ensemble(
ψs: Iterator[np.ndarray],
op: np.ndarray,
N: Optional[int],
normalize: bool = False,
real: bool = False,
**kwargs,
@ -362,9 +363,10 @@ def operator_expectation_ensemble(
:returns: the expectation value
"""
return ensemble_mean(
ψs, sandwhich_operator, N, const_args=(op, normalize, real), **kwargs
)
def op_exp_task(ψ):
return sandwhich_operator(ψ, op, normalize, real)
return ensemble_mean(ψs, op_exp_task, **kwargs)
def mulitply_hierarchy(left: np.ndarray, right: np.ndarray) -> np.ndarray:
@ -429,29 +431,6 @@ def integrate_array(
# Ensemble Mean #
###############################################################################
_ENSEMBLE_MEAN_ARGS: tuple = tuple()
_ENSEMBLE_MEAN_KWARGS: dict = dict()
def _ensemble_mean_call(arg) -> np.ndarray:
global _ENSEMBLE_MEAN_ARGS
global _ENSEMBLE_MEAN_KWARGS
return _ENSEMBLE_FUNC(arg, *_ENSEMBLE_MEAN_ARGS, **_ENSEMBLE_MEAN_KWARGS)
def _ensemble_mean_init(func: Callable, args: tuple, kwargs: dict):
global _ENSEMBLE_FUNC
global _ENSEMBLE_MEAN_ARGS
global _ENSEMBLE_MEAN_KWARGS
_ENSEMBLE_FUNC = func
_ENSEMBLE_MEAN_ARGS = args
_ENSEMBLE_MEAN_KWARGS = kwargs
# TODO: Use paramspec
class WelfordAggregator:
__slots__ = ["n", "mean", "_m_2"]
@ -521,24 +500,19 @@ def ensemble_mean(
arg_iter: Iterator[Any],
function: Callable[..., np.ndarray],
N: Optional[int] = None,
const_args: tuple = tuple(),
const_kwargs: dict = dict(),
n_proc: Optional[int] = None,
every: Optional[int] = None,
save: Optional[str] = None,
overwrite_cache: bool = False,
) -> EnsembleReturn:
results = []
aggregate = WelfordAggregator(function(next(arg_iter), *const_args))
aggregate = WelfordAggregator(function(next(arg_iter)))
path = None
json_meta_info = json.dumps(
dict(
N=N,
every=every,
const_args=const_args,
const_kwargs=const_kwargs,
function_name=function.__name__,
first_iterator_value=aggregate.mean,
),
@ -563,26 +537,29 @@ def ensemble_mean(
results = [(1, aggregate.mean, np.zeros_like(aggregate.mean))]
return results if every else results[0]
if not n_proc:
n_proc = multiprocessing.cpu_count()
remote_function = ray.remote(function)
with multiprocessing.Pool(
processes=n_proc,
initializer=_ensemble_mean_init,
initargs=(function, const_args, const_kwargs),
) as pool:
result_iter = pool.imap_unordered(
_ensemble_mean_call,
handles = [
remote_function.remote(arg)
for arg in tqdm(
itertools.islice(arg_iter, None, N - 1 if N else None),
total=N - 1 if N else None,
desc="Loading",
)
]
for res in tqdm(result_iter, total=(N - 1) if N else None):
aggregate.update(res)
progress = tqdm(total=len(handles), desc="Processing")
if every is not None and (aggregate.n % every) == 0 or aggregate.n == N:
results.append(
(aggregate.n, aggregate.mean.copy(), aggregate.ensemble_std.copy())
)
while len(handles):
done_id, handles = ray.wait(handles, fetch_local=True)
res = ray.get(done_id[0])
aggregate.update(res)
progress.update()
if every is not None and (aggregate.n % every) == 0 or aggregate.n == N:
results.append(
(aggregate.n, aggregate.mean.copy(), aggregate.ensemble_std.copy())
)
if not every:
results = results[-1]