Commit 293569a4 authored by Christian Chapman-Bird's avatar Christian Chapman-Bird
Browse files

Dataloading for first batch of schwarz data updated

parent 9d2d3461
Loading
Loading
Loading
Loading
+2 −2
Original line number Diff line number Diff line
@@ -54,7 +54,7 @@ def model_train_test(data, model, device, n_epochs, n_batches, loss_function, le
                    for param in model.parameters():
                        param.grad = None
                    outputs = model(inputs[i * ytrainsize // n_batches:(i+1)*ytrainsize // n_batches])
                    loss = torch.sqrt(loss_function(outputs, targets[i * ytrainsize // n_batches: (i+1)*ytrainsize // n_batches]))
                    loss = loss_function(outputs, targets[i * ytrainsize // n_batches: (i+1)*ytrainsize // n_batches])
                    loss.backward()
                    optimizer.step()
                    current_loss += loss.item()
@@ -74,7 +74,7 @@ def model_train_test(data, model, device, n_epochs, n_batches, loss_function, le

                    for i in range(n_batches):
                        outputs = model(inputs[i * ytestsize // n_batches: (i+1)*ytestsize // n_batches])
                        loss = torch.sqrt(loss_function(outputs, targets[i * ytestsize // n_batches: (i+1)*ytestsize // n_batches]))
                        loss = loss_function(outputs, targets[i * ytestsize // n_batches: (i+1)*ytestsize // n_batches])
                        current_loss += loss.item()

                    test_losses.append(current_loss / n_batches)
+0 −0

File added.

Preview suppressed by a .gitattributes entry or the file's encoding is unsupported.

+0 −0

File added.

Preview suppressed by a .gitattributes entry or the file's encoding is unsupported.

+62 −0
Original line number Diff line number Diff line
import numpy as np
from pathlib import Path
import pandas as pd

data_directory = '../schwarz_data/{}'

# Grid
snr_grid = np.load(data_directory.format('grid_snrs.npy'))
intrinsics = np.load(data_directory.format('grid_intrinsics.npy'))
angulars = np.load(data_directory.format('grid_extrinsics.npy'))
plunges = np.load(data_directory.format('grid_plunges.npy'))

intrinsics[:,:2] = np.log10(intrinsics[:,:2])

grid_out = np.zeros(shape=(snr_grid.size,10))
flat_snr_grid = snr_grid.flatten()

num = 0
for intrinsic_set in intrinsics:
    for angular_set in angulars:
        for time in plunges:
            out = np.zeros(9)
            out[:5] = intrinsic_set
            out[5:8] = angular_set
            out[8] = time

            grid_out[num,:9] = out

            snr_here = flat_snr_grid[num]
            if snr_here == 0:
                snr_here += 1e-6
            grid_out[num,9] = snr_here
            num += 1

save_dir = '../schwarz_data/{}'
Path('../schwarz_data/').mkdir(parents=True, exist_ok=True)

cols = ['logM','logq','a','e','Y0','thetaS','phiS','thetaK','t','SNR']
df_out = pd.DataFrame(grid_out, columns=cols)
df_out.to_csv(data_directory.format('grid_dataframe.csv'), index=False)

# Samples
snr_list = np.load(data_directory.format('samp_snrs.npy'))
inds_to_keep = ~np.isnan(snr_list)

intrinsics = np.load(data_directory.format('samp_intrinsics.npy'))
angulars = np.load(data_directory.format('samp_extrinsics.npy'))
intrinsics[:,:2] = np.log10(intrinsics[:,:2])

samp_out = np.zeros(shape=(snr_list.size,10))

for i in range(snr_list.size):
    here = np.zeros(9)
    here[:5] = intrinsics[i,:5]
    here[5:8] = angulars[i,:]
    here[8] = intrinsics[i,5]
    samp_out[i,:9] = here
    samp_out[i,9] = snr_list[i]

samp_out = samp_out[inds_to_keep,:]
samp_df_out = pd.DataFrame(samp_out, columns=cols)
samp_df_out.to_csv(data_directory.format('samp_dataframe.csv'), index=False)
Loading