Commit 76d9ab5f authored by Christian Chapman-Bird's avatar Christian Chapman-Bird
Browse files

added distances to both populations and validate.run_on_dataset

parent 48ddbb7e
Loading
Loading
Loading
Loading
+8 −2
Original line number Diff line number Diff line
@@ -7,13 +7,14 @@ import seaborn as sns
sns.set_theme()


def run_on_dataset(model, test_data, n_batches=1, device=None, y_transform_fn=None, runtime=False):
def run_on_dataset(model, test_data, distances=None, n_batches=1, device=None, y_transform_fn=None, runtime=False):
    """
    Get the re-processed output of the supplied model on a set of supplied test data.

    Args:
        model (LinearModel): Model to test on `test_data`
        test_data (2-tuple/list): Tuple or list of the 'features' and their corresponding 'labels'.
        distances (ndarray): List of luminosity distance measurements for the input events. If None, results will not be scaled by luminosity distance. Note that this scaling is applied after the data is converted with y_transform_fn.
        n_batches (int, optional): Number of batches to process the input data in. Defaults to 1 (the entire dataset).
        device (string, optional): Device to attach the input model to, if it is not attached already.
        y_transform_fn (function, optional): If the labels/ydata have been pre-processed with a function (e.g. log),
@@ -61,6 +62,11 @@ def run_on_dataset(model, test_data, n_batches=1, device=None, y_transform_fn=No
    if y_transform_fn is not None:
        out_unnorm = y_transform_fn(out_unnorm)
    
    if distances is None:
        distances = np.ones(xdata.shape[0]) * 0.5
    
    out_unnorm *= (0.5/distances)
    
    outputs = (out_unnorm,)

    if runtime:
+32 −44
Original line number Diff line number Diff line
@@ -152,7 +152,32 @@ def convert_M(M,z, to_source=False):
    else:
        return M*(1+z)

# def draw_population_params(num, M_min=5e4, M_max=5e7, M_lam=2, z_max=20, spin_mu=0.98, spin_std=0.01, e_min=0.01,e_max=0.3):
def draw_population_params(num, M_min=5e4, M_max=5e7, M_lam=2, z_max=4, spin_mu=0.98, spin_std=0.01, e_min=0.01, e_max=0.3):
    '''
    Draw sets of EMRI parameters from population distributions.
    :param num: Number of events to draw.
    :param M_min: Minimum mass passed to the mass population.
    :param M_max: Maximum mass passed to the mass population.
    :param M_lam: Value passed as `lam` to the mass population.
    :param z_max: Maximum redshift passed to the redshift population.
    :param spin_mu: Location (`loc`) of the truncated normal used for the spins.
    :param spin_std: Scale (`scale`) of the truncated normal used for the spins.
    :param e_min: Lower bound of the uniform eccentricity draw.
    :param e_max: Upper bound of the uniform eccentricity draw.
    :return: Column-stacked numpy array with one row per event and the columns
        (mass from convert_M(..., to_source=False), secondary mass, spin,
        luminosity distance from dL, eccentricity).
    '''
    # Population objects for the primary mass and the redshift.
    mass_dist = mass_pop(M_min=M_min, M_max=M_max, lam=M_lam)
    z_dist = redshift_pop(z_max=z_max)

    # scipy's truncnorm expects its clip points measured in units of `scale`
    # relative to `loc`, so the spin limits 0 and 1 are standardised first.
    lower = (0 - spin_mu) / spin_std
    upper = (1 - spin_mu) / spin_std
    spin_dist = truncnorm(lower, upper, loc=spin_mu, scale=spin_std)

    # NOTE(review): the draws are kept in this exact order; they presumably share
    # numpy's global random state, so reordering would change seeded results.
    primary_masses = mass_dist.rvs(size=num)
    # Secondary masses are log-uniform between 10**-0.3 and 10**2.
    secondary_masses = 10**np.random.uniform(-0.3, 2, num)
    z_samples = z_dist.rvs(size=num)
    spin_samples = spin_dist.rvs(size=num)

    # dL is evaluated one redshift at a time and stored in a preallocated array.
    dist_samples = np.zeros(num)
    for idx, z_val in enumerate(z_samples):
        dist_samples[idx] = dL(z_val)
    shifted_masses = convert_M(primary_masses, z_samples, to_source=False)

    ecc_samples = np.random.uniform(e_min, e_max, num)

    return np.column_stack((shifted_masses, secondary_masses, spin_samples, dist_samples, ecc_samples))


# def draw_population_params(num, M_min=5e4, M_max=5e7, M_lam=2, spin_mu=0.98, spin_std=0.01, e_min=0.01,e_max=0.3):
#     '''
#     Draw sets of EMRI parameters from population distributions.
#     :param num: Number of events to draw.
@@ -164,58 +189,21 @@ def convert_M(M,z, to_source=False):
#     :param spin_std: Standard deviation for spin population.
#     :return: Numpy array of shape (num, 5) that consists of the rows (M, a, d_L, z, M_s)
#     '''
#

#     #initialise populations
#

#     mpop = mass_pop(M_min=M_min, M_max=M_max, lam=M_lam)
#     redpop = redshift_pop(z_max=z_max)
#
#     # spin population with scipy truncnorm
#     l1, l2 = (0 - spin_mu) / spin_std, (1 - spin_mu) / spin_std  # rescale the limits (from scipy.stats.truncnorm docs)
#     spinpop = truncnorm(l1, l2, loc=spin_mu,scale=spin_std)
#

#     masses = mpop.rvs(size=num)
#     redshifts = redpop.rvs(size=num)
#     second_ms = 10**np.random.uniform(-0.3,2,num)
#     spins = spinpop.rvs(size=num)
#
#     lum_dists = np.zeros(num)
#     for i, z in enumerate(redshifts):
#         lum_dists[i] = dL(z)
#     red_masses = convert_M(masses, redshifts, to_source=False)
#
#     eccentricities = np.random.uniform(e_min, e_max, num)
#
#     out = np.column_stack((red_masses, spins, lum_dists, redshifts, masses, eccentricities))
#     return out


def draw_population_params(num, M_min=5e4, M_max=5e7, M_lam=2, spin_mu=0.98, spin_std=0.01, e_min=0.01,e_max=0.3):
    '''
    Draw sets of EMRI parameters from population distributions.
    :param num: Number of events to draw.
    :param M_min: Minimum mass for mass population.
    :param M_max: Maximum mass for mass population.
    :param M_lam: Spectral index for mass population power law.
    :param z_max: Maximum redshift for redshift population.
    :param spin_mu: Mean for spin population.
    :param spin_std: Standard deviation for spin population.
    :return: Numpy array of shape (num, 5) that consists of the rows (M, a, d_L, z, M_s)
    '''

    #initialise populations

    mpop = mass_pop(M_min=M_min, M_max=M_max, lam=M_lam)
    # spin population with scipy truncnorm
    l1, l2 = (0 - spin_mu) / spin_std, (1 - spin_mu) / spin_std  # rescale the limits (from scipy.stats.truncnorm docs)
    spinpop = truncnorm(l1, l2, loc=spin_mu,scale=spin_std)

    masses = mpop.rvs(size=num)
    second_ms = 10**np.random.uniform(-0.3,2,num)
    spins = spinpop.rvs(size=num)
    eccentricities = np.random.uniform(e_min, e_max, num)

    out = np.column_stack((masses, second_ms, spins, eccentricities))
    return out
#     out = np.column_stack((masses, second_ms, spins, eccentricities))
#     return out


def draw_other_params(num):
+19 −15
Original line number Diff line number Diff line
@@ -14,8 +14,8 @@ from populations import draw_population_params, draw_other_params
mpool = cupy.get_default_memory_pool()
mpool2 = cupy.get_default_pinned_memory_pool()

outdir = '../emri_data/schwarz_posY_population_/{}'
Path('../emri_data/schwarz_posY_population/').mkdir(parents=True, exist_ok=True)
outdir = '../emri_data/schwarz_posY_population_wdists/{}'
Path('../emri_data/schwarz_posY_population_wdists/').mkdir(parents=True, exist_ok=True)

use_gpu = True
use_schwarz_separatrix = True  # set to False for negative Y, as Kerr separatrix exceeds Schwarz in this case
@@ -67,7 +67,7 @@ def wform_snr_at_t(wform):
    return snrval


cols = ['logM', 'logq', 'a', 'p0', 'e', 'Y0', 'thetaS', 'phiS', 'thetaK', 't', 'SNR']
cols = ['logM', 'logq', 'a', 'p0', 'e', 'Y0', 'dL', 'thetaS', 'phiS', 'thetaK', 't', 'SNR']
try:  # pick up where we left off.
    dataframe = pd.read_csv(outdir.format('samp_dataframe.csv'))
    place = np.load(outdir.format('samp_position.npy')).item()
@@ -80,7 +80,7 @@ except:
    wavegen_times = []
    process_times = []

total = int(1e6)  # we can stop any time before this, though.
total = int(5e3)  # we can stop any time before this, though.
per_batch = 100
batches = total // per_batch
while place < batches:
@@ -92,7 +92,7 @@ while place < batches:
    wave_time_tot = 0
    i = 0

    this_chunk = np.zeros((per_batch, 11))
    this_chunk = np.zeros((per_batch, 12))

    while i < per_batch:
        stdout.write(
@@ -101,19 +101,24 @@ while place < batches:

        # update injection parameters

        M, mu, a, e = draw_population_params(1,M_min=5e4,M_max=5e7,M_lam=2)[0]
        M, mu, a, dL, e = draw_population_params(1,M_min=8e4,M_max=5e7,M_lam=2, e_min=0.08, spin_mu=0.5)[0]
        Y0, qS, phiS, qK, tplunge = draw_other_params(1)[0]

        if M / mu < 5e4:
            continue
        injection_params[0] = M
        injection_params[1] = mu
        injection_params[2] = a
        injection_params[4] = e
        injection_params[5] = Y0
        injection_params[6] = dVals[0]  # fiducial value (we will rescale for the others)

        injection_params[6] = dL  # fiducial value (we will rescale for the others)
        try:
            root = brentq_p_at_t(traj, T, traj_args=np.take(injection_params, traj_inds).tolist(),
                             traj_kwargs={'max_init_len': int(1e9)}, kerr_separatrix=not use_schwarz_separatrix,
                             xtol=1e-12)
                             xtol=1e-8)
        except:
            print(injection_params)
            raise

        injection_params[3] = root

        injection_params[7] = qS
@@ -132,9 +137,8 @@ while place < batches:

        this_chunk[i, 0] = np.log10(injection_params[0])
        this_chunk[i, 1] = np.log10(injection_params[0] / injection_params[1])
        this_chunk[i, 2:6] = injection_params[2:6]
        this_chunk[i, 6:9] = injection_params[7:10]
        this_chunk[i, 9] = tplunge
        this_chunk[i, 2:10] = injection_params[2:10]
        this_chunk[i, 10] = tplunge
        i += 1

    process_start = time.perf_counter()
@@ -145,7 +149,7 @@ while place < batches:
        mpool.free_all_blocks()
        mpool2.free_all_blocks()

    this_chunk[:, 10] = out_list
    this_chunk[:, 11] = out_list
    process_end = time.perf_counter()
    process_time = process_end - process_start