Commit 82db3bc3 authored by Christian Chapman-Bird's avatar Christian Chapman-Bird
Browse files

population sampling script added

parent 480c8a69
Loading
Loading
Loading
Loading
+165 −0
Original line number Diff line number Diff line
import cupy
import numpy as np
from few.waveform import GenerateEMRIWaveform, EMRIInspiral
from utility import snr, brentq_p_at_t
from pathlib import Path
from sys import stdout
import cupy as cp
from scipy.constants import year
import time
import pandas as pd
from populations import draw_population_params, draw_other_params


mpool = cupy.get_default_memory_pool()
mpool2 = cupy.get_default_pinned_memory_pool()

outdir = '../emri_data/schwarz_posY_population_/{}'
Path('../emri_data/schwarz_posY_population/').mkdir(parents=True, exist_ok=True)

use_gpu = True
use_schwarz_separatrix = True  # set to False for negative Y, as Kerr separatrix exceeds Schwarz in this case
kerr = GenerateEMRIWaveform(
    "Pn5AAKWaveform",
    sum_kwargs=dict(pad_output=True),
    inspiral_kwargs={"max_init_len": int(1e6), "enforce_schwarz_sep": use_schwarz_separatrix},
    use_gpu=use_gpu,
    return_list=True,
)

kerr_not_list = GenerateEMRIWaveform(
    "Pn5AAKWaveform",
    sum_kwargs=dict(pad_output=True),
    inspiral_kwargs={"max_init_len": int(1e6), "enforce_schwarz_sep": use_schwarz_separatrix},
    use_gpu=use_gpu,
    return_list=False,
)

# Max waveform duration and sample spacing
T = 4  # years
Tsec = T * year
dt = 15.0

# for SNR and waveform calculations
inner_product_kwargs = dict(dt=dt, PSD="cornish_lisa_psd", use_gpu=True)
waveform_kwargs = {"T": T, "dt": dt}

nDists = 10  # dont touch
dVals = np.linspace(0, 5, nDists + 1)[1:]

# initialise trajectory module
traj = EMRIInspiral(func="pn5",inspiral_kwargs={"max_init_len": int(1e9)}, enforce_schwarz_sep=use_schwarz_separatrix)

traj_inds = np.array([0, 1, 2, 4, 5, 11, 12, 13]).astype(np.int32)  # parameter indices needed for trajectory calc

num_wform_samples = int(T * year / dt)

injection_params = np.zeros(14)


def wform_snr_at_t(wform):
    plunge_time = wform[1]
    wform = wform[0]
    ind_to_cut = int((1 - plunge_time / 4) * num_wform_samples)
    cut_waveform = wform[ind_to_cut:]  # truncate waveform
    snrval = snr([cp.array(cut_waveform.real), cp.array(cut_waveform.imag)],
                 **inner_product_kwargs).get()  # need to send waveform to gpu for SNR calc
    return snrval


cols = ['logM', 'logq', 'a', 'p0', 'e', 'Y0', 'thetaS', 'phiS', 'thetaK', 't', 'SNR']
try:  # pick up where we left off.
    dataframe = pd.read_csv(outdir.format('samp_dataframe.csv'))
    place = np.load(outdir.format('samp_position.npy')).item()
    wavegen_times = np.load(outdir.format('samp_waveform_times.npy')).tolist()
    process_times = np.load(outdir.format('samp_processing_times.npy')).tolist()

except:
    dataframe = pd.DataFrame(columns=cols)
    place = 0
    wavegen_times = []
    process_times = []

total = int(1e6)  # we can stop any time before this, though.
per_batch = 100
batches = total // per_batch
while place < batches:
    batched_forms = []
    stdout.write(
        f'\rWaveform sets generated: {place:d} out of {batches:d}, or {100 * place / (batches - 1):.2f}% done. ')
    stdout.flush()

    wave_time_tot = 0
    i = 0

    this_chunk = np.zeros((per_batch, 11))

    while i < per_batch:
        stdout.write(
            f'\rWaveform sets generated: {place + 1:d} out of {batches:d}, or {100 * place / (batches - 1):.2f}% done. Set progress: {100 * i / (per_batch - 1):.2f}%')
        stdout.flush()

        # update injection parameters

        M, mu, a, e = draw_population_params(1,M_min=5e4,M_max=5e7,M_lam=2)[0]
        Y0, qS, phiS, qK, tplunge = draw_other_params(1)[0]

        injection_params[0] = M
        injection_params[1] = mu
        injection_params[2] = a
        injection_params[4] = e
        injection_params[5] = Y0
        injection_params[6] = dVals[0]  # fiducial value (we will rescale for the others)

        root = brentq_p_at_t(traj, T, traj_args=np.take(injection_params, traj_inds).tolist(),
                             traj_kwargs={'max_init_len': int(1e9)}, kerr_separatrix=not use_schwarz_separatrix,
                             xtol=1e-12)
        injection_params[3] = root

        injection_params[7] = qS
        injection_params[9] = qK
        injection_params[8] = phiS

        wave_start = time.perf_counter()
        check_sig = kerr_not_list(*injection_params, **waveform_kwargs).get()
        wave_end = time.perf_counter()
        wave_time = wave_end - wave_start
        wave_time_tot += wave_time

        batched_forms.append([check_sig, tplunge])  # package up with plunge time for easier processing later

        # keep hold of parameter information for NN

        this_chunk[i, 0] = np.log10(injection_params[0])
        this_chunk[i, 1] = np.log10(injection_params[0] / injection_params[1])
        this_chunk[i, 2:6] = injection_params[2:6]
        this_chunk[i, 6:9] = injection_params[7:10]
        this_chunk[i, 9] = tplunge
        i += 1

    process_start = time.perf_counter()
    out_list = []
    for h_wform in batched_forms:
        out_list.append(wform_snr_at_t(h_wform))
        cp.fft.config.get_plan_cache().clear()
        mpool.free_all_blocks()
        mpool2.free_all_blocks()

    this_chunk[:, 10] = out_list
    process_end = time.perf_counter()
    process_time = process_end - process_start

    wavegen_times.append(wave_time_tot / per_batch)

    process_times.append(process_time / per_batch)

    # cache our progress
    little_dataframe = pd.DataFrame(this_chunk, columns=cols)
    dataframe = dataframe.append(little_dataframe, ignore_index=True)

    dataframe.to_csv(outdir.format('samp_dataframe.csv'), index=False)
    np.save(outdir.format('samp_waveform_times.npy'), arr=np.array(wavegen_times))
    np.save(outdir.format('samp_processing_times.npy'), arr=np.array(process_times))
    np.save(outdir.format('samp_position.npy'), arr=np.array(place))

    place += 1