From e3932ca3b399141c79440400cac8af76ccaf5e70 Mon Sep 17 00:00:00 2001 From: Valentin Boettcher Date: Fri, 17 May 2024 20:14:25 -0400 Subject: [PATCH] feature: implement extra modes --- rabifun/plots.py | 0 scripts/experiments/002_rabi_detuning_scan.py | 24 ++++++------ scripts/transients_without_amplification.py | 16 ++++---- scripts/weird_oscillations.py | 4 +- src/plot_utils.py | 37 +++++++++++++++---- src/rabifun/plots.py | 21 ++++++++--- src/rabifun/system.py | 36 +++++++++++++----- src/ringfit/plotting.py | 1 - 8 files changed, 95 insertions(+), 44 deletions(-) delete mode 100644 rabifun/plots.py diff --git a/rabifun/plots.py b/rabifun/plots.py deleted file mode 100644 index e69de29..0000000 diff --git a/scripts/experiments/002_rabi_detuning_scan.py b/scripts/experiments/002_rabi_detuning_scan.py index 035a646..96bb982 100644 --- a/scripts/experiments/002_rabi_detuning_scan.py +++ b/scripts/experiments/002_rabi_detuning_scan.py @@ -1,27 +1,25 @@ -from rabifun.plots import plot_simulation_result from rabifun.system import * from rabifun.plots import * from rabifun.utilities import * -from plot_utils import autoclose +from plot_utils import wrap_plot # %% interactive -@autoclose def transient_rabi(): """A transient rabi oscillation without noise.""" - params = Params(η=0.001, d=0.1, laser_detuning=0, Δ=0.05) - t = time_axis(params, 2, 1) + params = Params(η=0.0001, δ=1 / 4, d=0.1, laser_detuning=0.01, Δ=0.005, N=2) + t = time_axis(params, 3, 0.1) solution = solve(t, params) signal = output_signal(t, solution.y, params.laser_detuning) - f, (_, ax) = plot_simulation_result(t, signal, params) + f, (_, ax) = plot_simulation_result(make_figure(), t, signal, params) plot_sidebands(ax, params) + # ax.set_xlim(0.73, 0.77) f.suptitle("Transient Rabi oscillation") -@autoclose def steady_rabi(): """A steady state rabi oscillation without noise.""" @@ -32,14 +30,13 @@ def steady_rabi(): signal = output_signal(t, solution.y, params.laser_detuning) f, (_, ax) = plot_simulation_result( - t, signal, params, window=(params.lifetimes(8), t[-1]) + make_figure(), t, signal, params, window=(params.lifetimes(8), t[-1]) ) plot_sidebands(ax, params) f.suptitle("Steady State Rabi oscillation. No Rabi Sidebands.") -@autoclose def noisy_transient_rabi(): """A transient rabi oscillation with noise.""" @@ -51,7 +48,7 @@ def noisy_transient_rabi(): noise_strength = 0.1 signal = add_noise(signal, noise_strength) - f, (_, ax) = plot_simulation_result(t, signal, params) + f, (_, ax) = plot_simulation_result(make_figure(), t, signal, params) plot_sidebands(ax, params) f.suptitle(f"Transient Rabi oscillation with noise strength {noise_strength}.") @@ -63,7 +60,8 @@ def ringdown_after_rabi(): off_lifetime = 4 laser_detuning = 0.1 - params = Params(η=0.001, d=0.1, laser_detuning=laser_detuning, Δ=0.5) + params = Params(η=0.0001, d=0.01, laser_detuning=laser_detuning, Δ=0.00, N=4) + params.laser_off_time = params.lifetimes(off_lifetime) params.drive_off_time = params.lifetimes(off_lifetime) @@ -75,10 +73,10 @@ def ringdown_after_rabi(): # signal = add_noise(signal, noise_strength) f, (_, fftax) = plot_simulation_result( - t, signal, params, window=(params.lifetimes(off_lifetime), t[-1]) + make_figure(), t, signal, params, window=(params.lifetimes(off_lifetime), t[-1]) ) - fftax.axvline(params.Ω - params.laser_detuning, color="black") + fftax.axvline(params.Ω - params.δ - params.laser_detuning, color="black") fftax.axvline(params.laser_detuning, color="black") f.suptitle(f"Ringdown after rabi osci EOM after {off_lifetime} lifetimes.") diff --git a/scripts/transients_without_amplification.py b/scripts/transients_without_amplification.py index 63866aa..ccc5f86 100644 --- a/scripts/transients_without_amplification.py +++ b/scripts/transients_without_amplification.py @@ -11,19 +11,21 @@ path = ( ) scan = ScanData.from_dir(path, truncation=[0, 50]) +# %% interactive STEPS = [2, 3, 5] -fig, ax = plot_scan(scan, smoothe_output=500, normalize=True, laser=True, steps=True) +fig = plt.figure("interactive") +ax, *axs = fig.subplots(1, len(STEPS)) +plot_scan(scan, smoothe_output=500, normalize=True, laser=True, steps=True, ax=ax) -for STEP in STEPS: +for ax, STEP in zip(axs, STEPS): time, output, _ = scan.for_step(step=STEP) t, o, params, cov, scaled = fit_transient(time, output, window_size=100) - plt.figure() - plt.plot(t, o) - plt.plot(t, transient_model(t, *params)) - plt.title( + ax.plot(t, o) + ax.plot(t, transient_model(t, *params)) + ax.set_title( f"Transient 2, γ={scaled[1] / 10**3:.2f}kHz ({cov[1] / 10**3:.2f}kHz)\n ω/2π={scaled[0] / (2*np.pi * 10**3):.5f}kHz\n step={STEP}" ) freq_unit = params[1] / scaled[1] - plt.plot(t, np.sin(2 * np.pi * 4 * 10**4 * t * freq_unit), alpha=0.1) + ax.plot(t, np.sin(2 * np.pi * 4 * 10**4 * t * freq_unit), alpha=0.1) diff --git a/scripts/weird_oscillations.py b/scripts/weird_oscillations.py index b36ca03..47a3b0d 100644 --- a/scripts/weird_oscillations.py +++ b/scripts/weird_oscillations.py @@ -14,7 +14,9 @@ scan = ScanData.from_dir(path, truncation=[0, 50]) STEPS = [2, 33, 12] # %% Set Up Figures -fig, (ax1, *axs) = plt.subplots(nrows=1, ncols=len(STEPS) + 1) +fig = plt.figure("interactive") +fig.clf() +(ax1, *axs) = fig.subplots(nrows=1, ncols=len(STEPS) + 1) # %% Plot scan plot_scan(scan, smoothe_output=100, normalize=True, laser=False, steps=True, ax=ax1) diff --git a/src/plot_utils.py b/src/plot_utils.py index b02374b..b839d51 100644 --- a/src/plot_utils.py +++ b/src/plot_utils.py @@ -1,14 +1,37 @@ import matplotlib.pyplot as plt +from typing import Callable, Any +from functools import wraps +from typing_extensions import ParamSpec, TypeVar, Concatenate + +P = ParamSpec("P") +R = TypeVar("R") -def wrap_plot(f): - def wrapped(*args, ax=None, setup_function=plt.subplots, **kwargs): - fig = None - if not ax: - fig, ax = setup_function() +def make_figure(fig_name: str = "interactive", *args, **kwargs): + fig = plt.figure(fig_name, *args, **kwargs) + fig.clf() + return fig - ret_val = f(*args, ax=ax, **kwargs) - return (fig, ax, ret_val) if ret_val else (fig, ax) + +def wrap_plot( + plot_function: Callable[Concatenate[plt.Figure | None, P], R], # pyright: ignore [reportPrivateImportUsage] +) -> Callable[Concatenate[plt.Figure | None, P], tuple[plt.Figure, R]]: # pyright: ignore [reportPrivateImportUsage] + """Decorator to wrap a plot function to inject the correct figure + for interactive use. The function that this decorator wraps + should accept the figure as first argument. + + :param fig_name: Name of the figure to create. By default it is + "interactive", so that one plot window will be reused. + :param setup_function: Function that returns a figure to use. If + it is provided, the ``fig_name`` will be ignored. + """ + + def wrapped(fig, *args: P.args, **kwargs: P.kwargs): + if fig is None: + fig = make_figure() + + ret_val = plot_function(fig, *args, **kwargs) + return (fig, ret_val) return wrapped diff --git a/src/rabifun/plots.py b/src/rabifun/plots.py index 6784377..5a2f664 100644 --- a/src/rabifun/plots.py +++ b/src/rabifun/plots.py @@ -5,8 +5,13 @@ import matplotlib.pyplot as plt import numpy as np +@wrap_plot def plot_simulation_result( - t: np.ndarray, signal: np.ndarray, params: Params, window=None + fig, + t: np.ndarray, + signal: np.ndarray, + params: Params, + window=None, ): """Plot the simulation result. The signal is plotted in the first axis and the Fourier transform is plotted in the second axis. @@ -19,7 +24,7 @@ def plot_simulation_result( :returns: figure and axes """ - f, (ax1, ax2) = plt.subplots(2, 1) + (ax1, ax2) = fig.subplots(2, 1) ax1.plot(t, signal) ax1.set_title(f"Output signal\n {params}") @@ -43,7 +48,7 @@ def plot_simulation_result( ax3.plot(freq, np.angle(fft), linestyle="--", color="C2", alpha=0.5, zorder=-10) ax3.set_ylabel("Phase") - return f, (ax1, ax2) + return (ax1, ax2) def plot_sidebands(ax, params: Params): @@ -55,13 +60,17 @@ def plot_sidebands(ax, params: Params): energy = params.rabi_splitting first_sidebands = np.abs( - -params.laser_detuning + np.array([1, -1]) * energy / 2 - params.Δ / 2 + -params.laser_detuning + np.array([1, -1]) * energy / 2 + params.Δ / 2 ) second_sidebands = ( - params.Ω - params.laser_detuning + np.array([1, -1]) * energy / 2 - params.Δ / 2 + params.Ω + - params.δ + - params.laser_detuning + + np.array([1, -1]) * energy / 2 + - params.Δ / 2 ) - ax.axvline(params.Ω - params.Δ, color="black", label="steady state") + ax.axvline(params.ω_eom, color="black", label="steady state") for n, sideband in enumerate(first_sidebands): ax.axvline( diff --git a/src/rabifun/system.py b/src/rabifun/system.py index 2978582..fe0b32d 100644 --- a/src/rabifun/system.py +++ b/src/rabifun/system.py @@ -23,6 +23,9 @@ class Params: Δ: float = 0.0 """Detuning of the EOM drive.""" + δ: float = 1 / 4 + """Mode splitting.""" + laser_detuning: float = 0.0 """Detuning of the laser relative to the _A_ mode.""" @@ -32,6 +35,9 @@ class Params: drive_off_time: float | None = None """Time at which the drive is turned off.""" + rwa: bool = False + """Whether to use the rotating wave approximation.""" + def periods(self, n: float): return n * 2 * np.pi / self.Ω @@ -42,12 +48,19 @@ class Params: def rabi_splitting(self): return np.sqrt(self.d**2 + self.Δ**2) + @property + def ω_eom(self): + return self.Ω - self.δ - self.Δ + class RuntimeParams: """Secondary Parameters that are required to run the simulation.""" def __init__(self, params: Params): - self.Ωs = np.arange(0, params.N) * params.Ω - 1j * np.repeat(params.η, params.N) + Ωs = np.arange(0, params.N) * params.Ω - 1j * np.repeat(params.η, params.N) + Ωs[1:] -= params.δ + + self.Ωs = Ωs def time_axis(params: Params, lifetimes: float, resolution: float = 1): @@ -67,21 +80,23 @@ def time_axis(params: Params, lifetimes: float, resolution: float = 1): ) -def eom_drive(t, x, d, Δ, Ω): +def eom_drive(t, x, d, ω, rwa): """The electrooptical modulation drive. :param t: time :param x: amplitudes :param d: drive amplitude - :param Δ: detuning - :param Ω: FSR + :param ω: drive frequency + :param rwa: whether to use the rotating wave approximation """ + stacked = np.repeat([x], len(x), 0) - np.fill_diagonal(stacked, 0) - stacked = np.sum(stacked, axis=1) - driven_x = d * np.sin((Ω - Δ) * t) * stacked + driven_x = d * np.sin(ω * t) * stacked + + if rwa and len(x) > 2: + driven_x[2:] = 0 return driven_x @@ -93,10 +108,12 @@ def make_righthand_side(runtime_params: RuntimeParams, params: Params): differential = runtime_params.Ωs * x if (params.drive_off_time is None) or (t < params.drive_off_time): - differential += eom_drive(t, x, params.d, params.Δ, params.Ω) + differential += eom_drive(t, x, params.d, params.ω_eom, params.rwa) if (params.laser_off_time is None) or (t < params.laser_off_time): - differential += np.exp(-1j * params.laser_detuning * t) + laser = np.exp(-1j * params.laser_detuning * t) + differential[0] += laser / np.sqrt(2) + differential[1:] += laser return -1j * differential @@ -120,6 +137,7 @@ def solve(t: np.ndarray, params: Params): (np.min(t), np.max(t)), initial, vectorized=False, + # max_step=0.1 * np.pi / (params.Ω * params.N), t_eval=t, ) diff --git a/src/ringfit/plotting.py b/src/ringfit/plotting.py index 3ed883d..2c4fd53 100644 --- a/src/ringfit/plotting.py +++ b/src/ringfit/plotting.py @@ -25,7 +25,6 @@ def fancy_error(x, y, err, ax, **kwargs): return line, err -@wrap_plot def plot_scan( data: data.ScanData, laser=False,