feature: implement extra modes

This commit is contained in:
Valentin Boettcher 2024-05-17 20:14:25 -04:00
parent 8beb82df15
commit e3932ca3b3
8 changed files with 95 additions and 44 deletions

View file

View file

@ -1,27 +1,25 @@
from rabifun.plots import plot_simulation_result
from rabifun.system import * from rabifun.system import *
from rabifun.plots import * from rabifun.plots import *
from rabifun.utilities import * from rabifun.utilities import *
from plot_utils import autoclose from plot_utils import wrap_plot
# %% interactive # %% interactive
@autoclose
def transient_rabi(): def transient_rabi():
"""A transient rabi oscillation without noise.""" """A transient rabi oscillation without noise."""
params = Params(η=0.001, d=0.1, laser_detuning=0, Δ=0.05) params = Params(η=0.0001, δ=1 / 4, d=0.1, laser_detuning=0.01, Δ=0.005, N=2)
t = time_axis(params, 2, 1) t = time_axis(params, 3, 0.1)
solution = solve(t, params) solution = solve(t, params)
signal = output_signal(t, solution.y, params.laser_detuning) signal = output_signal(t, solution.y, params.laser_detuning)
f, (_, ax) = plot_simulation_result(t, signal, params) f, (_, ax) = plot_simulation_result(make_figure(), t, signal, params)
plot_sidebands(ax, params) plot_sidebands(ax, params)
# ax.set_xlim(0.73, 0.77)
f.suptitle("Transient Rabi oscillation") f.suptitle("Transient Rabi oscillation")
@autoclose
def steady_rabi(): def steady_rabi():
"""A steady state rabi oscillation without noise.""" """A steady state rabi oscillation without noise."""
@ -32,14 +30,13 @@ def steady_rabi():
signal = output_signal(t, solution.y, params.laser_detuning) signal = output_signal(t, solution.y, params.laser_detuning)
f, (_, ax) = plot_simulation_result( f, (_, ax) = plot_simulation_result(
t, signal, params, window=(params.lifetimes(8), t[-1]) make_figure(), t, signal, params, window=(params.lifetimes(8), t[-1])
) )
plot_sidebands(ax, params) plot_sidebands(ax, params)
f.suptitle("Steady State Rabi oscillation. No Rabi Sidebands.") f.suptitle("Steady State Rabi oscillation. No Rabi Sidebands.")
@autoclose
def noisy_transient_rabi(): def noisy_transient_rabi():
"""A transient rabi oscillation with noise.""" """A transient rabi oscillation with noise."""
@ -51,7 +48,7 @@ def noisy_transient_rabi():
noise_strength = 0.1 noise_strength = 0.1
signal = add_noise(signal, noise_strength) signal = add_noise(signal, noise_strength)
f, (_, ax) = plot_simulation_result(t, signal, params) f, (_, ax) = plot_simulation_result(make_figure(), t, signal, params)
plot_sidebands(ax, params) plot_sidebands(ax, params)
f.suptitle(f"Transient Rabi oscillation with noise strength {noise_strength}.") f.suptitle(f"Transient Rabi oscillation with noise strength {noise_strength}.")
@ -63,7 +60,8 @@ def ringdown_after_rabi():
off_lifetime = 4 off_lifetime = 4
laser_detuning = 0.1 laser_detuning = 0.1
params = Params(η=0.001, d=0.1, laser_detuning=laser_detuning, Δ=0.5) params = Params(η=0.0001, d=0.01, laser_detuning=laser_detuning, Δ=0.00, N=4)
params.laser_off_time = params.lifetimes(off_lifetime) params.laser_off_time = params.lifetimes(off_lifetime)
params.drive_off_time = params.lifetimes(off_lifetime) params.drive_off_time = params.lifetimes(off_lifetime)
@ -75,10 +73,10 @@ def ringdown_after_rabi():
# signal = add_noise(signal, noise_strength) # signal = add_noise(signal, noise_strength)
f, (_, fftax) = plot_simulation_result( f, (_, fftax) = plot_simulation_result(
t, signal, params, window=(params.lifetimes(off_lifetime), t[-1]) make_figure(), t, signal, params, window=(params.lifetimes(off_lifetime), t[-1])
) )
fftax.axvline(params.Ω - params.laser_detuning, color="black") fftax.axvline(params.Ω - params.δ - params.laser_detuning, color="black")
fftax.axvline(params.laser_detuning, color="black") fftax.axvline(params.laser_detuning, color="black")
f.suptitle(f"Ringdown after rabi osci EOM after {off_lifetime} lifetimes.") f.suptitle(f"Ringdown after rabi osci EOM after {off_lifetime} lifetimes.")

View file

@ -11,19 +11,21 @@ path = (
) )
scan = ScanData.from_dir(path, truncation=[0, 50]) scan = ScanData.from_dir(path, truncation=[0, 50])
# %% interactive
STEPS = [2, 3, 5] STEPS = [2, 3, 5]
fig, ax = plot_scan(scan, smoothe_output=500, normalize=True, laser=True, steps=True) fig = plt.figure("interactive")
ax, *axs = fig.subplots(1, len(STEPS))
plot_scan(scan, smoothe_output=500, normalize=True, laser=True, steps=True, ax=ax)
for STEP in STEPS: for ax, STEP in zip(axs, STEPS):
time, output, _ = scan.for_step(step=STEP) time, output, _ = scan.for_step(step=STEP)
t, o, params, cov, scaled = fit_transient(time, output, window_size=100) t, o, params, cov, scaled = fit_transient(time, output, window_size=100)
plt.figure() ax.plot(t, o)
plt.plot(t, o) ax.plot(t, transient_model(t, *params))
plt.plot(t, transient_model(t, *params)) ax.set_title(
plt.title(
f"Transient 2, γ={scaled[1] / 10**3:.2f}kHz ({cov[1] / 10**3:.2f}kHz)\n ω/2π={scaled[0] / (2*np.pi * 10**3):.5f}kHz\n step={STEP}" f"Transient 2, γ={scaled[1] / 10**3:.2f}kHz ({cov[1] / 10**3:.2f}kHz)\n ω/2π={scaled[0] / (2*np.pi * 10**3):.5f}kHz\n step={STEP}"
) )
freq_unit = params[1] / scaled[1] freq_unit = params[1] / scaled[1]
plt.plot(t, np.sin(2 * np.pi * 4 * 10**4 * t * freq_unit), alpha=0.1) ax.plot(t, np.sin(2 * np.pi * 4 * 10**4 * t * freq_unit), alpha=0.1)

View file

@ -14,7 +14,9 @@ scan = ScanData.from_dir(path, truncation=[0, 50])
STEPS = [2, 33, 12] STEPS = [2, 33, 12]
# %% Set Up Figures # %% Set Up Figures
fig, (ax1, *axs) = plt.subplots(nrows=1, ncols=len(STEPS) + 1) fig = plt.figure("interactive")
fig.clf()
(ax1, *axs) = fig.subplots(nrows=1, ncols=len(STEPS) + 1)
# %% Plot scan # %% Plot scan
plot_scan(scan, smoothe_output=100, normalize=True, laser=False, steps=True, ax=ax1) plot_scan(scan, smoothe_output=100, normalize=True, laser=False, steps=True, ax=ax1)

View file

@ -1,14 +1,37 @@
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from typing import Callable, Any
from functools import wraps
from typing_extensions import ParamSpec, TypeVar, Concatenate
P = ParamSpec("P")
R = TypeVar("R")
def wrap_plot(f): def make_figure(fig_name: str = "interactive", *args, **kwargs):
def wrapped(*args, ax=None, setup_function=plt.subplots, **kwargs): fig = plt.figure(fig_name, *args, **kwargs)
fig = None fig.clf()
if not ax: return fig
fig, ax = setup_function()
ret_val = f(*args, ax=ax, **kwargs)
return (fig, ax, ret_val) if ret_val else (fig, ax) def wrap_plot(
plot_function: Callable[Concatenate[plt.Figure | None, P], R], # pyright: ignore [reportPrivateImportUsage]
) -> Callable[Concatenate[plt.Figure | None, P], tuple[plt.Figure, R]]: # pyright: ignore [reportPrivateImportUsage]
"""Decorator to wrap a plot function to inject the correct figure
for interactive use. The function that this decorator wraps
should accept the figure as first argument.
:param fig_name: Name of the figure to create. By default it is
"interactive", so that one plot window will be reused.
:param setup_function: Function that returns a figure to use. If
it is provided, the ``fig_name`` will be ignored.
"""
def wrapped(fig, *args: P.args, **kwargs: P.kwargs):
if fig is None:
fig = make_figure()
ret_val = plot_function(fig, *args, **kwargs)
return (fig, ret_val)
return wrapped return wrapped

View file

@ -5,8 +5,13 @@ import matplotlib.pyplot as plt
import numpy as np import numpy as np
@wrap_plot
def plot_simulation_result( def plot_simulation_result(
t: np.ndarray, signal: np.ndarray, params: Params, window=None fig,
t: np.ndarray,
signal: np.ndarray,
params: Params,
window=None,
): ):
"""Plot the simulation result. The signal is plotted in the first axis """Plot the simulation result. The signal is plotted in the first axis
and the Fourier transform is plotted in the second axis. and the Fourier transform is plotted in the second axis.
@ -19,7 +24,7 @@ def plot_simulation_result(
:returns: figure and axes :returns: figure and axes
""" """
f, (ax1, ax2) = plt.subplots(2, 1) (ax1, ax2) = fig.subplots(2, 1)
ax1.plot(t, signal) ax1.plot(t, signal)
ax1.set_title(f"Output signal\n {params}") ax1.set_title(f"Output signal\n {params}")
@ -43,7 +48,7 @@ def plot_simulation_result(
ax3.plot(freq, np.angle(fft), linestyle="--", color="C2", alpha=0.5, zorder=-10) ax3.plot(freq, np.angle(fft), linestyle="--", color="C2", alpha=0.5, zorder=-10)
ax3.set_ylabel("Phase") ax3.set_ylabel("Phase")
return f, (ax1, ax2) return (ax1, ax2)
def plot_sidebands(ax, params: Params): def plot_sidebands(ax, params: Params):
@ -55,13 +60,17 @@ def plot_sidebands(ax, params: Params):
energy = params.rabi_splitting energy = params.rabi_splitting
first_sidebands = np.abs( first_sidebands = np.abs(
-params.laser_detuning + np.array([1, -1]) * energy / 2 - params.Δ / 2 -params.laser_detuning + np.array([1, -1]) * energy / 2 + params.Δ / 2
) )
second_sidebands = ( second_sidebands = (
params.Ω - params.laser_detuning + np.array([1, -1]) * energy / 2 - params.Δ / 2 params.Ω
- params.δ
- params.laser_detuning
+ np.array([1, -1]) * energy / 2
- params.Δ / 2
) )
ax.axvline(params.Ω - params.Δ, color="black", label="steady state") ax.axvline(params.ω_eom, color="black", label="steady state")
for n, sideband in enumerate(first_sidebands): for n, sideband in enumerate(first_sidebands):
ax.axvline( ax.axvline(

View file

@ -23,6 +23,9 @@ class Params:
Δ: float = 0.0 Δ: float = 0.0
"""Detuning of the EOM drive.""" """Detuning of the EOM drive."""
δ: float = 1 / 4
"""Mode splitting."""
laser_detuning: float = 0.0 laser_detuning: float = 0.0
"""Detuning of the laser relative to the _A_ mode.""" """Detuning of the laser relative to the _A_ mode."""
@ -32,6 +35,9 @@ class Params:
drive_off_time: float | None = None drive_off_time: float | None = None
"""Time at which the drive is turned off.""" """Time at which the drive is turned off."""
rwa: bool = False
"""Whether to use the rotating wave approximation."""
def periods(self, n: float): def periods(self, n: float):
return n * 2 * np.pi / self.Ω return n * 2 * np.pi / self.Ω
@ -42,12 +48,19 @@ class Params:
def rabi_splitting(self): def rabi_splitting(self):
return np.sqrt(self.d**2 + self.Δ**2) return np.sqrt(self.d**2 + self.Δ**2)
@property
def ω_eom(self):
return self.Ω - self.δ - self.Δ
class RuntimeParams: class RuntimeParams:
"""Secondary Parameters that are required to run the simulation.""" """Secondary Parameters that are required to run the simulation."""
def __init__(self, params: Params): def __init__(self, params: Params):
self.Ωs = np.arange(0, params.N) * params.Ω - 1j * np.repeat(params.η, params.N) Ωs = np.arange(0, params.N) * params.Ω - 1j * np.repeat(params.η, params.N)
Ωs[1:] -= params.δ
self.Ωs = Ωs
def time_axis(params: Params, lifetimes: float, resolution: float = 1): def time_axis(params: Params, lifetimes: float, resolution: float = 1):
@ -67,21 +80,23 @@ def time_axis(params: Params, lifetimes: float, resolution: float = 1):
) )
def eom_drive(t, x, d, Δ, Ω): def eom_drive(t, x, d, ω, rwa):
"""The electrooptical modulation drive. """The electrooptical modulation drive.
:param t: time :param t: time
:param x: amplitudes :param x: amplitudes
:param d: drive amplitude :param d: drive amplitude
:param Δ: detuning :param ω: drive frequency
:param Ω: FSR :param rwa: whether to use the rotating wave approximation
""" """
stacked = np.repeat([x], len(x), 0) stacked = np.repeat([x], len(x), 0)
np.fill_diagonal(stacked, 0)
stacked = np.sum(stacked, axis=1) stacked = np.sum(stacked, axis=1)
driven_x = d * np.sin((Ω - Δ) * t) * stacked driven_x = d * np.sin(ω * t) * stacked
if rwa and len(x) > 2:
driven_x[2:] = 0
return driven_x return driven_x
@ -93,10 +108,12 @@ def make_righthand_side(runtime_params: RuntimeParams, params: Params):
differential = runtime_params.Ωs * x differential = runtime_params.Ωs * x
if (params.drive_off_time is None) or (t < params.drive_off_time): if (params.drive_off_time is None) or (t < params.drive_off_time):
differential += eom_drive(t, x, params.d, params.Δ, params.Ω) differential += eom_drive(t, x, params.d, params.ω_eom, params.rwa)
if (params.laser_off_time is None) or (t < params.laser_off_time): if (params.laser_off_time is None) or (t < params.laser_off_time):
differential += np.exp(-1j * params.laser_detuning * t) laser = np.exp(-1j * params.laser_detuning * t)
differential[0] += laser / np.sqrt(2)
differential[1:] += laser
return -1j * differential return -1j * differential
@ -120,6 +137,7 @@ def solve(t: np.ndarray, params: Params):
(np.min(t), np.max(t)), (np.min(t), np.max(t)),
initial, initial,
vectorized=False, vectorized=False,
# max_step=0.1 * np.pi / (params.Ω * params.N),
t_eval=t, t_eval=t,
) )

View file

@ -25,7 +25,6 @@ def fancy_error(x, y, err, ax, **kwargs):
return line, err return line, err
@wrap_plot
def plot_scan( def plot_scan(
data: data.ScanData, data: data.ScanData,
laser=False, laser=False,