feature: implement extra modes

This commit is contained in:
Valentin Boettcher 2024-05-17 20:14:25 -04:00
parent 8beb82df15
commit e3932ca3b3
8 changed files with 95 additions and 44 deletions

View file

View file

@ -1,27 +1,25 @@
from rabifun.plots import plot_simulation_result
from rabifun.system import *
from rabifun.plots import *
from rabifun.utilities import *
from plot_utils import autoclose
from plot_utils import wrap_plot
# %% interactive
@autoclose
def transient_rabi():
"""A transient rabi oscillation without noise."""
params = Params(η=0.001, d=0.1, laser_detuning=0, Δ=0.05)
t = time_axis(params, 2, 1)
params = Params(η=0.0001, δ=1 / 4, d=0.1, laser_detuning=0.01, Δ=0.005, N=2)
t = time_axis(params, 3, 0.1)
solution = solve(t, params)
signal = output_signal(t, solution.y, params.laser_detuning)
f, (_, ax) = plot_simulation_result(t, signal, params)
f, (_, ax) = plot_simulation_result(make_figure(), t, signal, params)
plot_sidebands(ax, params)
# ax.set_xlim(0.73, 0.77)
f.suptitle("Transient Rabi oscillation")
@autoclose
def steady_rabi():
"""A steady state rabi oscillation without noise."""
@ -32,14 +30,13 @@ def steady_rabi():
signal = output_signal(t, solution.y, params.laser_detuning)
f, (_, ax) = plot_simulation_result(
t, signal, params, window=(params.lifetimes(8), t[-1])
make_figure(), t, signal, params, window=(params.lifetimes(8), t[-1])
)
plot_sidebands(ax, params)
f.suptitle("Steady State Rabi oscillation. No Rabi Sidebands.")
@autoclose
def noisy_transient_rabi():
"""A transient rabi oscillation with noise."""
@ -51,7 +48,7 @@ def noisy_transient_rabi():
noise_strength = 0.1
signal = add_noise(signal, noise_strength)
f, (_, ax) = plot_simulation_result(t, signal, params)
f, (_, ax) = plot_simulation_result(make_figure(), t, signal, params)
plot_sidebands(ax, params)
f.suptitle(f"Transient Rabi oscillation with noise strength {noise_strength}.")
@ -63,7 +60,8 @@ def ringdown_after_rabi():
off_lifetime = 4
laser_detuning = 0.1
params = Params(η=0.001, d=0.1, laser_detuning=laser_detuning, Δ=0.5)
params = Params(η=0.0001, d=0.01, laser_detuning=laser_detuning, Δ=0.00, N=4)
params.laser_off_time = params.lifetimes(off_lifetime)
params.drive_off_time = params.lifetimes(off_lifetime)
@ -75,10 +73,10 @@ def ringdown_after_rabi():
# signal = add_noise(signal, noise_strength)
f, (_, fftax) = plot_simulation_result(
t, signal, params, window=(params.lifetimes(off_lifetime), t[-1])
make_figure(), t, signal, params, window=(params.lifetimes(off_lifetime), t[-1])
)
fftax.axvline(params.Ω - params.laser_detuning, color="black")
fftax.axvline(params.Ω - params.δ - params.laser_detuning, color="black")
fftax.axvline(params.laser_detuning, color="black")
f.suptitle(f"Ringdown after rabi osci EOM after {off_lifetime} lifetimes.")

View file

@ -11,19 +11,21 @@ path = (
)
scan = ScanData.from_dir(path, truncation=[0, 50])
# %% interactive
STEPS = [2, 3, 5]
fig, ax = plot_scan(scan, smoothe_output=500, normalize=True, laser=True, steps=True)
fig = plt.figure("interactive")
ax, *axs = fig.subplots(1, len(STEPS))
plot_scan(scan, smoothe_output=500, normalize=True, laser=True, steps=True, ax=ax)
for STEP in STEPS:
for ax, STEP in zip(axs, STEPS):
time, output, _ = scan.for_step(step=STEP)
t, o, params, cov, scaled = fit_transient(time, output, window_size=100)
plt.figure()
plt.plot(t, o)
plt.plot(t, transient_model(t, *params))
plt.title(
ax.plot(t, o)
ax.plot(t, transient_model(t, *params))
ax.set_title(
f"Transient 2, γ={scaled[1] / 10**3:.2f}kHz ({cov[1] / 10**3:.2f}kHz)\n ω/2π={scaled[0] / (2*np.pi * 10**3):.5f}kHz\n step={STEP}"
)
freq_unit = params[1] / scaled[1]
plt.plot(t, np.sin(2 * np.pi * 4 * 10**4 * t * freq_unit), alpha=0.1)
ax.plot(t, np.sin(2 * np.pi * 4 * 10**4 * t * freq_unit), alpha=0.1)

View file

@ -14,7 +14,9 @@ scan = ScanData.from_dir(path, truncation=[0, 50])
STEPS = [2, 33, 12]
# %% Set Up Figures
fig, (ax1, *axs) = plt.subplots(nrows=1, ncols=len(STEPS) + 1)
fig = plt.figure("interactive")
fig.clf()
(ax1, *axs) = fig.subplots(nrows=1, ncols=len(STEPS) + 1)
# %% Plot scan
plot_scan(scan, smoothe_output=100, normalize=True, laser=False, steps=True, ax=ax1)

View file

@ -1,14 +1,37 @@
import matplotlib.pyplot as plt
from typing import Callable, Any
from functools import wraps
from typing_extensions import ParamSpec, TypeVar, Concatenate
P = ParamSpec("P")
R = TypeVar("R")
def wrap_plot(f):
def wrapped(*args, ax=None, setup_function=plt.subplots, **kwargs):
fig = None
if not ax:
fig, ax = setup_function()
def make_figure(fig_name: str = "interactive", *args, **kwargs):
fig = plt.figure(fig_name, *args, **kwargs)
fig.clf()
return fig
ret_val = f(*args, ax=ax, **kwargs)
return (fig, ax, ret_val) if ret_val else (fig, ax)
def wrap_plot(
plot_function: Callable[Concatenate[plt.Figure | None, P], R], # pyright: ignore [reportPrivateImportUsage]
) -> Callable[Concatenate[plt.Figure | None, P], tuple[plt.Figure, R]]: # pyright: ignore [reportPrivateImportUsage]
"""Decorator to wrap a plot function to inject the correct figure
for interactive use. The function that this decorator wraps
should accept the figure as first argument.
:param fig_name: Name of the figure to create. By default it is
"interactive", so that one plot window will be reused.
:param setup_function: Function that returns a figure to use. If
it is provided, the ``fig_name`` will be ignored.
"""
def wrapped(fig, *args: P.args, **kwargs: P.kwargs):
if fig is None:
fig = make_figure()
ret_val = plot_function(fig, *args, **kwargs)
return (fig, ret_val)
return wrapped

View file

@ -5,8 +5,13 @@ import matplotlib.pyplot as plt
import numpy as np
@wrap_plot
def plot_simulation_result(
t: np.ndarray, signal: np.ndarray, params: Params, window=None
fig,
t: np.ndarray,
signal: np.ndarray,
params: Params,
window=None,
):
"""Plot the simulation result. The signal is plotted in the first axis
and the Fourier transform is plotted in the second axis.
@ -19,7 +24,7 @@ def plot_simulation_result(
:returns: figure and axes
"""
f, (ax1, ax2) = plt.subplots(2, 1)
(ax1, ax2) = fig.subplots(2, 1)
ax1.plot(t, signal)
ax1.set_title(f"Output signal\n {params}")
@ -43,7 +48,7 @@ def plot_simulation_result(
ax3.plot(freq, np.angle(fft), linestyle="--", color="C2", alpha=0.5, zorder=-10)
ax3.set_ylabel("Phase")
return f, (ax1, ax2)
return (ax1, ax2)
def plot_sidebands(ax, params: Params):
@ -55,13 +60,17 @@ def plot_sidebands(ax, params: Params):
energy = params.rabi_splitting
first_sidebands = np.abs(
-params.laser_detuning + np.array([1, -1]) * energy / 2 - params.Δ / 2
-params.laser_detuning + np.array([1, -1]) * energy / 2 + params.Δ / 2
)
second_sidebands = (
params.Ω - params.laser_detuning + np.array([1, -1]) * energy / 2 - params.Δ / 2
params.Ω
- params.δ
- params.laser_detuning
+ np.array([1, -1]) * energy / 2
- params.Δ / 2
)
ax.axvline(params.Ω - params.Δ, color="black", label="steady state")
ax.axvline(params.ω_eom, color="black", label="steady state")
for n, sideband in enumerate(first_sidebands):
ax.axvline(

View file

@ -23,6 +23,9 @@ class Params:
Δ: float = 0.0
"""Detuning of the EOM drive."""
δ: float = 1 / 4
"""Mode splitting."""
laser_detuning: float = 0.0
"""Detuning of the laser relative to the _A_ mode."""
@ -32,6 +35,9 @@ class Params:
drive_off_time: float | None = None
"""Time at which the drive is turned off."""
rwa: bool = False
"""Whether to use the rotating wave approximation."""
def periods(self, n: float):
return n * 2 * np.pi / self.Ω
@ -42,12 +48,19 @@ class Params:
def rabi_splitting(self):
return np.sqrt(self.d**2 + self.Δ**2)
@property
def ω_eom(self):
return self.Ω - self.δ - self.Δ
class RuntimeParams:
"""Secondary Parameters that are required to run the simulation."""
def __init__(self, params: Params):
self.Ωs = np.arange(0, params.N) * params.Ω - 1j * np.repeat(params.η, params.N)
Ωs = np.arange(0, params.N) * params.Ω - 1j * np.repeat(params.η, params.N)
Ωs[1:] -= params.δ
self.Ωs = Ωs
def time_axis(params: Params, lifetimes: float, resolution: float = 1):
@ -67,21 +80,23 @@ def time_axis(params: Params, lifetimes: float, resolution: float = 1):
)
def eom_drive(t, x, d, Δ, Ω):
def eom_drive(t, x, d, ω, rwa):
"""The electrooptical modulation drive.
:param t: time
:param x: amplitudes
:param d: drive amplitude
:param Δ: detuning
:param Ω: FSR
:param ω: drive frequency
:param rwa: whether to use the rotating wave approximation
"""
stacked = np.repeat([x], len(x), 0)
np.fill_diagonal(stacked, 0)
stacked = np.sum(stacked, axis=1)
driven_x = d * np.sin((Ω - Δ) * t) * stacked
driven_x = d * np.sin(ω * t) * stacked
if rwa and len(x) > 2:
driven_x[2:] = 0
return driven_x
@ -93,10 +108,12 @@ def make_righthand_side(runtime_params: RuntimeParams, params: Params):
differential = runtime_params.Ωs * x
if (params.drive_off_time is None) or (t < params.drive_off_time):
differential += eom_drive(t, x, params.d, params.Δ, params.Ω)
differential += eom_drive(t, x, params.d, params.ω_eom, params.rwa)
if (params.laser_off_time is None) or (t < params.laser_off_time):
differential += np.exp(-1j * params.laser_detuning * t)
laser = np.exp(-1j * params.laser_detuning * t)
differential[0] += laser / np.sqrt(2)
differential[1:] += laser
return -1j * differential
@ -120,6 +137,7 @@ def solve(t: np.ndarray, params: Params):
(np.min(t), np.max(t)),
initial,
vectorized=False,
# max_step=0.1 * np.pi / (params.Ω * params.N),
t_eval=t,
)

View file

@ -25,7 +25,6 @@ def fancy_error(x, y, err, ax, **kwargs):
return line, err
@wrap_plot
def plot_scan(
data: data.ScanData,
laser=False,