Source code for gwtails.tails

#############################################################################
##
##      Filename: tails.py
##
##      Author: Tousif Islam
##
##      Created: 12-07-2023
##
##      Modified: 09-27-2025
##
#############################################################################

import numpy as np

import matplotlib
import matplotlib.pyplot as plt
from matplotlib import cm
import matplotlib.patches as mpatches


def _apply_plot_style():
    """Install the gwtails default look on matplotlib's global rcParams.

    Sets STIX fonts for text and math, a unit axes line width, a 6x6 default
    figure size and a font size of 18. The settings are global: they affect
    every figure created after this call.
    """
    # ``plt.rcParams`` is the very same object as ``matplotlib.rcParams``,
    # so all five settings can be written through one name.
    style = {
        'mathtext.fontset': 'stix',
        'font.family': 'STIXGeneral',
        'axes.linewidth': 1,
        "figure.figsize": (6, 6),
        'font.size': '18',
    }
    for key, value in style.items():
        matplotlib.rcParams[key] = value

from scipy.signal import find_peaks
from scipy.optimize import curve_fit
import gwtools

class PostMergerAmplitudeFit:
    """
    Class to fit post-merger late-time tail data.

    The amplitude ``|f|`` of a complex waveform mode is shifted in time so
    that the peak of ``|f|**2`` sits at ``t = 0``, resampled on a uniform
    grid (step 0.1) through ``gwtools.interpolate_h`` and then, depending on
    which windows are supplied, fitted with

    * an exponential decay, see :meth:`qnm_fit_func`;
    * a decaying power law, see :meth:`tail_fit_func`;
    * the interference of both, see :meth:`qnm_and_tail_fit_func`.

    Attributes set by ``__init__`` (those marked *optional* exist only when
    the corresponding fit was requested):

    raw_time, raw_h : rescaled / junk-trimmed input data
    peak_time : time of the amplitude peak in ``raw_time`` coordinates
    t_interp, h_interp : peak-aligned data on the uniform grid
    popt_tail, pcov_tail : *optional* tail fit (``Atail, c, n``)
    t_envelop, A_envelop : *optional* envelope points used for the tail fit
    popt_qnm, pcov_qnm : *optional* QNM fit (``Aqnm, tau``)
    popt_cross_terms, pcov_cross_terms : *optional* (``phi_tail, omega``)
    """

    def __init__(self, t=None, f=None, filename=None, qinput=1000,
                 throw_junk_until_indx=None, qnm_fit_window=None,
                 tail_fit_window=None, crossterm_fit_window=None,
                 fit_tail_envelop=False, m1_to_M_scale=True):
        """
        Parameters
        ----------
        t : array-like, optional
            Time array. Default is None - in that case you should provide a
            filename that contains the strain/psi4.
        f : array-like, optional
            Function array - it could be strain h_lm or psi4_lm.
            Default is None - in that case you should provide a filename
            that contains the strain/psi4.
        filename : str, optional
            Name of the waveform data file; should be in .txt or .dat format;
            columns should have the following format: time, h_real, h_imag, ..
            Default is None - in that case you should provide both the time
            't' and the strain/psi4 'f'. Takes precedence over 't' and 'f'
            when both are given.
        qinput : float, default 1000
            Mass ratio value with qinput>=1. Only used when
            ``m1_to_M_scale`` is True.
        throw_junk_until_indx : int, optional
            Last index until which data should be discarded before applying
            qnm or tail fits; Default is None.
        qnm_fit_window : list, optional
            Provide a window for fitting qnm coefficients only;
            e.g. [200,800] (in M); Default is None (in that case no qnm fit
            is performed).
        tail_fit_window : list, optional
            Provide a window for fitting tail coefficients only;
            e.g. [200,800] (in M); Default is None (in that case no tail fit
            is performed).
        crossterm_fit_window : list, optional
            Provide a window for fitting cross-term coefficients only;
            e.g. [100,200] (in M). The cross-term fit is performed only when
            both ``qnm_fit_window`` and ``tail_fit_window`` are given; if
            this argument is None the window then defaults to
            ``[qnm_fit_window[1] + 15, tail_fit_window[0] - 15]``.
        fit_tail_envelop : bool, default False
            Whether the data has oscillatory tail behaviour - in which case,
            we fit the tail envelope (local maxima of the amplitude).
        m1_to_M_scale : bool, default True
            Whether a mass-scale transformation from m1 (mass of the
            primary) to total mass M is needed.

        Raises
        ------
        ValueError
            If neither ``filename`` nor both ``t`` and ``f`` are given, or
            if a mass rescaling is requested with a non-positive ``qinput``.
        """
        # Handle data input: either from file or direct provision
        if filename is not None:
            self.filename = filename
            self.raw_time, self.raw_h = self._read_data_from_file()
        elif t is not None and f is not None:
            # 't' and 'f' are documented as array-like; the rescaling and
            # slicing below need real ndarrays (a list would fail on '*').
            self.raw_time, self.raw_h = np.asarray(t), np.asarray(f)
        else:
            raise ValueError("Either provide 'filename' or both 't' and 'f' arrays")
        print("..... time and strain/psi4 data is loaded")

        # Mass-scale conversion: both time and amplitude are multiplied by
        # m1/M = 1 / (1 + 1/q).
        if m1_to_M_scale:
            if qinput <= 0:
                raise ValueError("'qinput' must be positive (expected qinput >= 1)")
            m1toM = 1 / (1 + 1/qinput)
            self.raw_time, self.raw_h = self.raw_time * m1toM, self.raw_h * m1toM
            print("..... waveform mass-scale changed from m1 to M")

        # Remove junk data from start
        self.throw_junk_until_indx = throw_junk_until_indx
        if throw_junk_until_indx is not None:
            self.raw_time = self.raw_time[throw_junk_until_indx:]
            self.raw_h = self.raw_h[throw_junk_until_indx:]
            print("..... junk data is removed from the start")

        # Set merger time coordinate and interpolate
        self.peak_time = self._get_peak_time(self.raw_time, self.raw_h)
        self.t_interp, self.h_interp = self._get_interp_data()
        print("..... strain/psi4 is cast onto a common time grid")

        # Store and perform tail fits
        self.tail_fit_window = tail_fit_window
        self.fit_tail_envelop = fit_tail_envelop
        if tail_fit_window is not None:
            print("..... fitting the tails with a decaying power law")
            if fit_tail_envelop:
                print('.. tail data has oscillations - fitting tail envelop instead')
                (self.popt_tail, self.pcov_tail,
                 self.t_envelop, self.A_envelop) = self._fit_tail_envelop_data()
            else:
                self.popt_tail, self.pcov_tail = self._fit_tail_data()

        # Store and perform QNM fits
        self.qnm_fit_window = qnm_fit_window
        if qnm_fit_window is not None:
            print("..... fitting QNM with the 220 mode")
            self.popt_qnm, self.pcov_qnm = self._fit_qnm_data()

        # Store and perform oscillatory intermediate fits. The cross-term
        # model reuses popt_qnm and popt_tail, hence both fits must exist.
        self.oscillatory_fit_window = crossterm_fit_window
        if tail_fit_window is not None and qnm_fit_window is not None:
            print("..... fitting oscillatory intermediate part")
            if crossterm_fit_window is None:
                # Default: the gap between the two windows, with a margin
                # of 15 on each side.
                self.oscillatory_fit_window = [qnm_fit_window[1]+15,
                                               tail_fit_window[0]-15]
            self.popt_cross_terms, self.pcov_cross_terms = self._fit_cross_terms_data()

    def _read_data_from_file(self):
        """Reads data directly from file (filename should be absolute path).

        Returns the first column as time and ``col1 - 1j*col2`` as the
        complex mode; any further columns are ignored.
        """
        data = np.loadtxt(self.filename, unpack=True)
        print(f'.......... Data read. Shape: {data.shape}')
        return data[0], data[1] - 1j*data[2]

    def plot_raw_amplitude(self):
        """Plots raw amplitudes without time shift or interpolation."""
        _apply_plot_style()
        plt.figure(figsize=(6, 6))
        plt.semilogy(self.raw_time, abs(self.raw_h), color='royalblue')
        plt.xlabel('raw time')
        plt.ylabel('raw amplitude')

    def _get_peak_via_quadratic_fit(self, t, func):
        """Finds peak using quadratic fit around maximum value.

        A parabola is least-squares fitted to the five samples centred on
        the discrete maximum of ``func`` (the centre is moved inwards if the
        maximum lies within two samples of either end).

        Returns
        -------
        (peak_time, peak_value) : tuple of float
            Location and height of the parabola's vertex.

        Raises
        ------
        ValueError
            If fewer than five samples are supplied.
        """
        t = np.asarray(t)
        func = np.asarray(func)
        if len(t) < 5:
            raise ValueError("need at least 5 samples for the quadratic peak fit")
        index = int(np.clip(np.argmax(func), 2, len(t) - 3))

        # Setup quadratic fit around peak (time measured from t[index])
        t_local = t[index-2:index+3] - t[index]
        f_local = func[index-2:index+3]

        # Solve normal equations (X X^T) c = X f for c0 + c1*t + c2*t**2;
        # solve() avoids forming the explicit inverse.
        X = np.vstack([np.ones(5), t_local, t_local**2])
        coefs = np.linalg.solve(X @ X.T, X @ f_local)

        # Vertex of the parabola
        peak_offset = -coefs[1] / (2 * coefs[2])
        peak_amp = coefs[0] - coefs[1]**2 / (4 * coefs[2])
        return t[index] + peak_offset, peak_amp

    def _get_peak_time(self, t, modes):
        """Finds peak time using quadratic fit around maximum of |modes|**2."""
        return self._get_peak_via_quadratic_fit(t, abs(modes)**2)[0]

    def _get_interp_data(self):
        """Interpolate the data after shifting the peak to t = 0.

        The target grid starts 50 before the peak, stops 100 before the last
        (shifted) sample and has a spacing of 0.1.
        """
        t_transform = self.raw_time - self.peak_time
        t_interp = np.arange(-50, np.max(t_transform)-100, 0.1)
        h_interp = gwtools.interpolate_h(t_transform, self.raw_h, t_interp)
        return t_interp, h_interp

    def plot_interpolated_amplitude(self):
        """Plots interpolated amplitudes after time shift or junk removal."""
        _apply_plot_style()
        plt.figure(figsize=(6, 6))
        plt.semilogy(self.t_interp, abs(self.h_interp), color='royalblue')
        plt.xlabel('time')
        plt.ylabel('amplitude')

    @staticmethod
    def freq(time, h, s=0.1):
        """Calculate frequency from phase derivative.

        Declared static because it takes no ``self``; it can be called as
        ``PostMergerAmplitudeFit.freq(time, h)`` or on an instance.

        Parameters
        ----------
        time : array-like
            Sample times.
        h : array-like
            Complex mode.
        s : float, default 0.1
            Forwarded to ``gwtools.interpolant_h``.

        Returns
        -------
        Absolute value of the first derivative of the interpolated phase,
        evaluated at ``time``.
        """
        # Local import: the module does not import splev at file level.
        from scipy.interpolate import splev
        # NOTE(review): assumes gwtools.interpolant_h returns a spline
        # representation accepted by splev - confirm against gwtools.
        phase_interp = gwtools.interpolant_h(time, gwtools.phase(h), s=s)
        return abs(splev(time, phase_interp, der=1))

    def _cut_data(self, tini, tout):
        """Extract (time, amplitude) within ``tini <= t <= tout``."""
        mask = (self.t_interp >= tini) & (self.t_interp <= tout)
        return self.t_interp[mask], abs(self.h_interp[mask])

    def qnm_fit_func(self, t, Aqnm, tau):
        """QNM exponential decay function: ``Aqnm * exp(-t/tau)``."""
        return Aqnm * np.exp(-t/tau)

    def tail_fit_func(self, t, Atail, c, n):
        """Tail power-law decay function: ``Atail / (t + c)**n``."""
        return Atail / (t + c)**n

    def qnm_and_tail_fit_func(self, t, phi_tail, omega):
        """Combined QNM + tail + cross-terms fitting function.

        With ``q`` and ``p`` the already fitted QNM and tail amplitudes
        (from ``popt_qnm`` / ``popt_tail``) this returns
        ``sqrt(q**2 + p**2 + 2*q*p*cos(phi_tail + omega*t))``.
        """
        qnm_amp = self.qnm_fit_func(t, *self.popt_qnm)
        tail_amp = self.tail_fit_func(t, *self.popt_tail)
        cross = 2 * qnm_amp * tail_amp * np.cos(phi_tail + omega*t)
        return np.sqrt(qnm_amp**2 + tail_amp**2 + cross)

    def _fit_tail_data(self):
        """Fit tail data with power-law decay; returns ``(popt, pcov)``."""
        xdata, ydata = self._cut_data(*self.tail_fit_window)
        return curve_fit(self.tail_fit_func, xdata, ydata, maxfev=10000)

    def _fit_tail_envelop_data(self):
        """Fit envelope of oscillatory tail data.

        The envelope is the set of local maxima of ``log10(amplitude)``
        inside the tail window. Returns ``(popt, pcov, t_env, A_env)``.
        """
        xdata, ydata = self._cut_data(*self.tail_fit_window)
        peaks = find_peaks(np.log10(ydata))[0]
        x_env, y_env = xdata[peaks], ydata[peaks]
        popt, pcov = curve_fit(self.tail_fit_func, x_env, y_env, maxfev=10000)
        return popt, pcov, x_env, y_env

    def _fit_qnm_data(self):
        """Fit QNM data with exponential decay; returns ``(popt, pcov)``."""
        xdata, ydata = self._cut_data(*self.qnm_fit_window)
        return curve_fit(self.qnm_fit_func, xdata, ydata, maxfev=10000)

    def _fit_cross_terms_data(self):
        """Fit oscillatory intermediate region; returns ``(popt, pcov)``."""
        xdata, ydata = self._cut_data(*self.oscillatory_fit_window)
        return curve_fit(self.qnm_and_tail_fit_func, xdata, ydata, maxfev=10000)

    def _default_plot_times(self):
        """Default abscissa for the fit curves: from 10 in steps of 0.1.

        Ends at the upper edge of the tail window when one was given,
        otherwise at the end of the interpolated grid (so a QNM-only fit
        can still be plotted).
        """
        if self.tail_fit_window is not None:
            t_end = self.tail_fit_window[1]
        else:
            t_end = np.max(self.t_interp)
        return np.arange(10, t_end, 0.1)

    def _amplitude_floor(self):
        """Largest power of ten not exceeding the smallest data amplitude."""
        return 10**np.floor(np.log10(np.min(np.abs(self.h_interp))))

    def _plot_qnm_fit(self, t_plot=None, xmax=None):
        """Plot QNM fit comparison."""
        _apply_plot_style()
        if t_plot is None:
            t_plot = self._default_plot_times()
        y_min = self._amplitude_floor()

        plt.semilogy(self.t_interp, abs(self.h_interp), color='grey', lw=3,
                     alpha=0.4, label='Data')
        plt.semilogy(t_plot, self.qnm_fit_func(t_plot, *self.popt_qnm), '--',
                     label='QNM Fit')
        plt.xlim(0, 1020 if xmax is None else xmax)
        plt.ylim(y_min, 1e1)
        plt.xlabel('t')
        plt.ylabel('Amplitude')
        plt.legend(fontsize=14)
        plt.show()

    def _plot_tail_fits(self, t_plot=None, xmax=None):
        """Plot tail fit comparison."""
        _apply_plot_style()
        if t_plot is None:
            t_plot = self._default_plot_times()
        y_min = self._amplitude_floor()

        plt.semilogy(self.t_interp, abs(self.h_interp), color='grey', lw=4,
                     alpha=0.4, label='Data')
        plt.semilogy(t_plot, self.tail_fit_func(t_plot, *self.popt_tail), '--',
                     label='Tail Fit')
        plt.xlim(-100, 1020 if xmax is None else xmax)
        plt.ylim(y_min, 1e1)
        plt.xlabel('time')
        plt.ylabel('Amplitude')
        plt.legend(fontsize=14)
        plt.tight_layout()
        plt.show()

    def _plot_all_fits(self, t_plot=None, xmax=None):
        """Plot all fits in a two-panel comparison.

        Requires the tail, QNM and cross-term fits to have been performed.
        """
        _apply_plot_style()
        if t_plot is None:
            t_plot = self._default_plot_times()
        y_min = self._amplitude_floor()
        x_hi = 1020 if xmax is None else xmax

        plt.figure(figsize=(14, 6))

        # Left panel: Individual fits
        plt.subplot(121)
        plt.semilogy(self.t_interp, abs(self.h_interp), color='grey', lw=4,
                     alpha=0.4, label='Data')
        plt.semilogy(t_plot, self.tail_fit_func(t_plot, *self.popt_tail), '--',
                     label='Tail Fit')
        plt.semilogy(t_plot, self.qnm_fit_func(t_plot, *self.popt_qnm), '--',
                     label='QNM Fit')
        plt.xlim(-100, x_hi)
        plt.ylim(y_min, 1e1)
        plt.xlabel('time')
        plt.ylabel('Amplitude')
        plt.legend(fontsize=14)

        # Right panel: Combined fit
        plt.subplot(122)
        plt.semilogy(self.t_interp, abs(self.h_interp), color='grey', lw=4,
                     alpha=0.4, label='Data')
        plt.semilogy(t_plot,
                     self.qnm_and_tail_fit_func(t_plot, *self.popt_cross_terms),
                     '--', color='g', label='QNM+Tail Fit')
        plt.xlim(-100, x_hi)
        plt.ylim(y_min, 1e1)
        plt.xlabel('time')
        plt.ylabel('Amplitude')
        plt.legend(fontsize=14)

        plt.tight_layout()
        plt.show()

    def _print_all_fits(self):
        """Print all fitted parameters that are available."""
        if self.tail_fit_window is not None:
            print(f'Atail : {self.popt_tail[0]:.9e}')
            print(f'c : {self.popt_tail[1]:.9e}')
            print(f'n : {self.popt_tail[2]:.9e}')

        if self.qnm_fit_window is not None:
            print(f'Aqnm : {self.popt_qnm[0]:.9f}')
            print(f'tau : {self.popt_qnm[1]:.9f}')

        if self.tail_fit_window is not None and self.qnm_fit_window is not None:
            if hasattr(self, 'popt_cross_terms'):
                print(f'phi_tail : {self.popt_cross_terms[0]:.9f}')
                print(f'omega : {self.popt_cross_terms[1]:.9f}')