Commit e9cea70e authored by Christian Chapman-Bird
Browse files

Backup. Tidying run_on_dataset

parent 9aff1eb1
Loading
Loading
Loading
Loading
+5 −1
Original line number Diff line number Diff line
@@ -4,7 +4,7 @@ from torch.nn.init import xavier_uniform_
from EMRI_DET.utilities import get_script_path
import dill as pickle
from pathlib import Path

import numpy as np

class LinearModel(nn.Module):
    def __init__(self, in_features, out_features, neurons, n_layers, activation, name, out_activation=None, initialisation=xavier_uniform_, use_dropout=False,drop_p=0.25, use_bn=False):
@@ -57,4 +57,8 @@ def load_mlp(model_name, device, get_state_dict=False, outdir='../models'):
    model = pickle.load(open(get_script_path()+f'/{outdir}/{model_name}/function.pickle', "rb"))  # load blank model
    if get_state_dict:
        model.load_state_dict(torch.load(open(get_script_path()+f'/{outdir}/{model_name}/model.pth', "rb"), map_location=device))
    xscalevals = np.load(get_script_path() + f'/{outdir}/{model.name}/xdata_inputs.npy')
    yscalevals = np.load(get_script_path() + f'/{outdir}/{model.name}/ydata_inputs.npy')
    model.xscalevals = xscalevals
    model.yscalevals = yscalevals
    return model
+16 −21
Original line number Diff line number Diff line
@@ -6,13 +6,14 @@ from EMRI_DET.utilities import norm_inputs, unnorm, get_script_path
import seaborn as sns


def run_on_dataset(model, test_data, distances=None, n_batches=1, device=None, y_transform_fn=None, runtime=False, outdir='../models'):
def run_on_dataset(model, xdata, distances=None, n_batches=1, device=None, y_transform_fn=None, runtime=False,
                    eval_model = True):
    """
    Get the re-processed output of the supplied model on a set of supplied test data.

    Args:
        model (LinearModel): Model to test on `test_data`
        test_data (2-tuple/list): Tuple or list of the 'features' and their corresponding 'labels'.
        model (LinearModel): Model to test on `xdata`
        xdata (ndarray): Array of features to test against
        distances (ndarray): List of luminosity distance measurements for the input events. If None, results will not be scaled by luminosity distance. Note that this scaling is applied after the data is converted with y_transform_fn.
        n_batches (int, optional): Number of batches to process the input data in. Defaults to 1 (the entire dataset).
        device (string, optional): Device to attach the input model to, if it is not attached already.
@@ -32,46 +33,40 @@ def run_on_dataset(model, test_data, distances=None, n_batches=1, device=None, y
    """
    if device is not None:
        model = model.to(device)
    if eval_model:
        model.eval()

    xdata, ydata = test_data

    xscalevals = np.load(get_script_path() + f'/{outdir}/{model.name}/xdata_inputs.npy')
    yscalevals = np.load(get_script_path() + f'/{outdir}/{model.name}/ydata_inputs.npy')

    test_input = torch.Tensor(xdata)
    normed_input = norm_inputs(test_input, ref_inputs=xscalevals,norm_type=model.norm_type).float().to(device)
    test_input = torch.from_numpy(xdata).to(device)
    normed_input = norm_inputs(test_input, ref_inputs=model.xscalevals,norm_type=model.norm_type).float().to(device)

    if runtime:
        st = time.perf_counter()
    with torch.no_grad():
        out = []
        for i in range(n_batches):
            output = model(normed_input[i * ydata.size // n_batches: (i + 1) * ydata.size // n_batches])
            output = model(normed_input[i * xdata.shape[0] // n_batches: (i + 1) * xdata.shape[0] // n_batches])
            out.append(output.detach().cpu().numpy())

    if runtime:
        et = time.perf_counter()
        total_time = et - st
        per_point = (et - st) / ydata.size
        per_point = (et - st) / xdata.shape[0]

    output = np.concatenate(out)
    out_unnorm = unnorm(output, ref_inputs=yscalevals,norm_type=model.norm_type)
    out_unnorm = unnorm(output, ref_inputs=model.yscalevals,norm_type=model.norm_type)

    if y_transform_fn is not None:
        out_unnorm = y_transform_fn(out_unnorm)
    
    if distances is None:
        distances = np.ones(xdata.shape[0]) * 0.5
        distances = np.ones(xdata.shape[0])
    
    out_unnorm *= (0.5/distances)[:,None]
    if ydata.ndim == 1:
    out_unnorm *= (1/distances)[:,None]
    if output.ndim == 1:
        out_unnorm = out_unnorm.flatten()
    outputs = (out_unnorm,)

    outputs = out_unnorm
    if runtime:
        outputs += (total_time, per_point,)

        outputs = (out_unnorm, total_time, per_point,)
    return outputs


+1.15 MiB

File added.

No diff preview for this file type.

+21.2 KiB
Loading image diff...
+179 −0
Original line number Diff line number Diff line
import cupy
import numpy as np
from few.waveform import GenerateEMRIWaveform, EMRIInspiral
from utility import snr, brentq_p_at_t
from pathlib import Path
from sys import stdout
import cupy as cp
from scipy.constants import year
import time
import pandas as pd
from pathlib import PurePath

# Handles to CuPy's default device and pinned-host memory pools; the sampling loop
# calls free_all_blocks() on both after each pass over the batch.
mpool = cupy.get_default_memory_pool()
mpool2 = cupy.get_default_pinned_memory_pool()

# Directory that receives the sampled dataframe; created (with parents) if it is missing.
outdir = PurePath('/scratch/wiay/christian/torch/EMRI_DET/emri_data/plunge_post_window/first_attempt')
Path(outdir).mkdir(parents=True, exist_ok=True)

use_gpu = True
# Set to False for negative Y, as the Kerr separatrix exceeds the Schwarzschild one in this case.
use_schwarz_separatrix = False

# Selects the sign applied to the sampled Y0 and the prefix of the output file name.
prograde = False

# Waveform generator configured to return a single array rather than a list.
kerr_not_list = GenerateEMRIWaveform(
    "Pn5AAKWaveform",
    sum_kwargs={"pad_output": True},
    inspiral_kwargs=dict(max_init_len=int(1e9), enforce_schwarz_sep=use_schwarz_separatrix),
    use_gpu=use_gpu,
    return_list=False,
)

# Max waveform duration and sample spacing
T = 10  # years
dt = 15.0  # presumably seconds (it divides T * year below) -- TODO confirm
Tsec = T * year

# Keyword sets shared by every SNR and waveform call.
inner_product_kwargs = {"dt": dt, "PSD": "cornish_lisa_psd", "use_gpu": use_gpu}
waveform_kwargs = dict(T=T, dt=dt)

# Trajectory module used by the root-finder for the initial separation.
traj = EMRIInspiral(func="pn5", inspiral_kwargs=dict(max_init_len=int(1e9)), enforce_schwarz_sep=use_schwarz_separatrix)

# Indices into `injection_params` of the parameters needed for the trajectory calculation.
traj_inds = np.array([0, 1, 2, 4, 5, 11, 12, 13], dtype=np.int32)

num_wform_samples = int(Tsec / dt)

# Scratch parameter vector, overwritten in place for every sampled source.
injection_params = np.zeros(14)


def cut_wform_for_plunge_at_time(signal, time, T, dt, window_size, seconds_per_year=3.154e7):
    """
    Slice a fixed-length window out of a waveform, positioned relative to its end.

    The nominal waveform is taken to hold ``int(T * seconds_per_year / dt)`` samples.
    The returned slice runs from ``time`` years before the final nominal sample up to
    ``time - window_size`` years before it. Whatever part of that window falls outside
    the nominal span is dropped, so the result can be shorter than ``window_size``
    years, or empty.

    Args:
        signal (ndarray): Waveform samples to slice. Only slicing is applied to it.
        time (float): Offset, in years before the end of the nominal span, at which the window starts.
            Note that inside this function the name shadows the `time` module.
        T (float): Nominal waveform duration in years.
        dt (float): Sample spacing, in the same time unit as `seconds_per_year` (seconds by default).
        window_size (float): Length of the window in years.
        seconds_per_year (float, optional): Conversion factor used to count samples. Defaults to 3.154e7;
            pass e.g. `scipy.constants.year` to match a waveform generated with that constant.

    Returns:
        ndarray: `signal[start:end]` for the computed window (possibly empty).
    """
    total_samples = int(T * seconds_per_year / dt)

    # Offsets are counted backwards from the last nominal sample. Both are clipped to
    # [0, total_samples]: without the upper clip on the trailing edge a window lying
    # entirely before the data would yield a negative `end`, which slicing would
    # silently reinterpret as "count from the end" and return unrelated samples.
    trailing_offset = int(np.clip((time - window_size) * total_samples / T, 0, total_samples))
    leading_offset = int(np.clip(time * total_samples / T, 0, total_samples))

    start = total_samples - leading_offset
    end = total_samples - trailing_offset
    return signal[start:end]


def wform_snr_at_t(wform, plunge_time):
    """
    Compute the SNR of the 4-year window of `wform` selected by `plunge_time`.

    The waveform is truncated with `cut_wform_for_plunge_at_time` using the module-level
    `T` and `dt`. An empty window gives an SNR of 0; otherwise the real and imaginary
    parts are copied into CuPy arrays and passed to `snr` with `inner_product_kwargs`.

    Args:
        wform (ndarray): Complex waveform samples (its `.real` and `.imag` are used).
        plunge_time (float): Time passed to `cut_wform_for_plunge_at_time` to position the window.

    Returns:
        The value returned by `snr(...).get()`, or `0.` when the window is empty.
    """
    segment = cut_wform_for_plunge_at_time(wform, plunge_time, T, dt, window_size=4)  # truncate waveform

    # Nothing left after truncation: no signal, so no SNR to compute.
    if len(segment) == 0:
        return 0.

    # need to send waveform to gpu for SNR calc
    channels = [cp.array(segment.real), cp.array(segment.imag)]
    return snr(channels, **inner_product_kwargs).get()

total = int(1e5)  # we can stop any time before this, though.
per_batch = 100
batches = total // per_batch
plunge_times_to_run = 20  # number of random plunge times drawn for every generated waveform

# Sign applied to the sampled Y0 and prefix of the output file name.
if prograde:
    mult = 1
    add = 'prograde_'
else:
    mult = -1
    add = 'retrograde_'


cols = ['logM', 'logq', 'a', 'p0', 'e', 'Y0', 'thetaS', 'phiS-phiK', 'thetaK', 't', 'SNR', 'wave_runtime', 'traj_runtime']
csv_path = outdir / (add + 'samp_dataframe.csv')
try:  # pick up where we left off.
    dataframe = pd.read_csv(csv_path)
except (FileNotFoundError, pd.errors.EmptyDataError):
    # Only a missing or empty file means "start fresh". Any other read failure must
    # propagate: carrying on from zero would overwrite the existing results below.
    dataframe = pd.DataFrame(columns=cols)
    place = 0
else:
    # Each completed batch contributes per_batch * plunge_times_to_run rows.
    place = int(len(dataframe) / per_batch / plunge_times_to_run)

print(dataframe)

# Denominators for the progress read-outs, guarded so a single batch / single sample
# configuration does not divide by zero.
batch_denom = max(batches - 1, 1)
sample_denom = max(per_batch - 1, 1)

while place < batches:
    batched_forms = []
    stdout.write(
        f'\rWaveform sets generated: {place:d} out of {batches:d}, or {100 * place / batch_denom:.2f}% done. ')
    stdout.flush()

    i = 0

    # One row per accepted source; columns follow `cols` (t and SNR are filled in later).
    this_chunk = np.zeros((per_batch, 13))

    while i < per_batch:
        # update injection parameters
        injection_params[0] = 10 ** np.random.uniform(np.log10(8e4), np.log10(5e7))
        mass_ratio = 10**(np.random.uniform(-4,np.log10(2e-8)))
        injection_params[1] = injection_params[0] * mass_ratio
        # Redraw when the implied secondary mass falls outside [0.5, 100].
        if injection_params[1] < 0.5 or injection_params[1] > 100:
            continue
        injection_params[2] = np.random.uniform(0.01, 0.9999)
        injection_params[4] = np.random.uniform(0.08, 0.5)
        injection_params[5] = np.random.uniform(0.25,0.99) * mult
        injection_params[6] = 1  # fiducial value (we will rescale for the others)
        injection_params[7] = np.arccos(np.random.uniform(-1,1))
        injection_params[9] = np.arccos(np.random.uniform(-1,1))
        injection_params[8] = np.random.uniform(0, 2 * np.pi)

        stdout.write(
            f'\rWaveform sets generated: {place + 1:d} out of {batches:d}, or {100 * place / batch_denom:.2f}% done. Set progress: {100 * i / sample_denom:.2f}%')
        stdout.flush()

        # Root-find the initial separation for this draw; failures are reported and the draw is discarded.
        traj_start = time.perf_counter()
        try:
            root, diff = brentq_p_at_t(traj, T, traj_args=np.take(injection_params, traj_inds).tolist(),
                                traj_kwargs={'max_init_len': int(1e9)}, kerr_separatrix=True,
                                xtol=1e-10, error_handle=True, return_diff=True)
        except KeyboardInterrupt:
            raise
        except Exception as e:
            print(e)
            continue
        #stdout.write(f' | p0 found, diff={abs(diff) / 86400:.2e} days')
        if abs(diff)/86400 >  10:  # reject the draw when `diff` exceeds 10 days (86400 s each)
            stdout.write(f' --- Rejected - logM={np.log10(injection_params[0])}, a={injection_params[2]}')
            continue
        traj_end = time.perf_counter()
        traj_time = traj_end - traj_start
        injection_params[3] = root

        wave_start = time.perf_counter()
        try:
            check_sig = kerr_not_list(*injection_params, **waveform_kwargs).get()
        except KeyboardInterrupt:
            raise
        except SystemError:
            print('ERROR FOUND: waveform generator.')
            continue
        wave_end = time.perf_counter()
        wave_time = wave_end - wave_start
        batched_forms.append(check_sig)  # keep the host-side waveform for the SNR passes below

        # keep hold of parameter information for NN

        this_chunk[i, 0] = np.log10(injection_params[0])
        this_chunk[i, 1] = np.log10(mass_ratio)
        this_chunk[i, 2:6] = injection_params[2:6]
        this_chunk[i, 6:9] = injection_params[7:10]
        this_chunk[i,11] = wave_time
        this_chunk[i,12] = traj_time
        i += 1

    # Re-use every waveform of the batch with several independently drawn plunge times.
    new_frames = []
    for _ in range(plunge_times_to_run):
        time_chunk = this_chunk.copy()
        plunge_times = np.random.uniform(0,T,len(batched_forms))
        out_list = [wform_snr_at_t(h_wform, plunge_time)
                    for h_wform, plunge_time in zip(batched_forms, plunge_times)]

        # Release cached FFT plans and pooled GPU / pinned memory before the next pass.
        cp.fft.config.get_plan_cache().clear()
        mpool.free_all_blocks()
        mpool2.free_all_blocks()
        time_chunk[:, 9] = plunge_times
        time_chunk[:, 10] = out_list

        new_frames.append(pd.DataFrame(time_chunk, columns=cols))

    # cache our progress
    # A single concat per batch replaces DataFrame.append (removed in pandas 2.0). The
    # still-empty initial frame is left out so it cannot influence the column dtypes.
    frames = new_frames if dataframe.empty else [dataframe, *new_frames]
    dataframe = pd.concat(frames, ignore_index=True)

    dataframe.to_csv(csv_path, index=False)

    place += 1
Loading