diff --git a/hopsflow/hopsflow.py b/hopsflow/hopsflow.py index 7529b04..47588a3 100644 --- a/hopsflow/hopsflow.py +++ b/hopsflow/hopsflow.py @@ -370,7 +370,7 @@ def heat_flow_ensemble( def _interaction_energy_ensemble_body( ψs: Union[Tuple[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray, int]], params: SystemParams, - thermal: ThermalParams, + thermal: Optional[ThermalParams], ) -> np.ndarray: ψ_0, ψ_1 = ψs[0:2] @@ -381,7 +381,7 @@ def _interaction_energy_ensemble_body( run = HOPSRun(ψ_0, ψ_1, params) energy = interaction_energy_coupling(run, params) - if isinstance(ys, int): + if thermal and (ys is not None): therm_run = ThermalRunParams(thermal, ys) energy += interaction_energy_therm(run, therm_run)