Commit 1e645fca authored by Christian Chapman-Bird's avatar Christian Chapman-Bird
Browse files

testing script for real data using package validate functions

parent c6829a71
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -5,6 +5,7 @@ from EMRI_DET.utilities import norm, norm_inputs, unnorm_inputs, unnorm, get_scr
from sys import stdout
from pathlib import Path


def model_train_test(data, model, device, n_epochs, n_batches, loss_function, learning_rate, verbose=False, return_losses=False):
    xtrain, ytrain, xtest, ytest = data
    model.to(device)
+32 −0
Original line number Diff line number Diff line
import torch
import numpy as np
import matplotlib.pyplot as plt
import time
from EMRI_DET.utilities import norm, norm_inputs, unnorm
from EMRI_DET.validate import run_on_dataset, compute_rmse, grid_heatmap_corner
from EMRI_DET.nn.model_creation import load_mlp
import pandas as pd

device = 'cuda:0'
model_name = 'silu_1'
mlp = load_mlp(model_name, get_state_dict=True).to(device)
mlp.eval()

df = pd.read_csv('../schwarz_data/grid_dataframe.csv')
x_inds = [0, 1, 2, 3, 4, 5, 6, 7, 8]  # [0,1,2,4,5,6,7,8,9]
y_inds = [9]

xdata = df.iloc[:,x_inds].to_numpy()
ydata = np.log(df.iloc[:,y_inds].to_numpy())

xmeanstd = np.load(f'../models/{model_name}/xdata_mean_std.npy')
ymeanstd = np.load(f'../models/{model_name}/ydata_mean_std.npy')

net_out = run_on_dataset(mlp,[xdata, ydata],device=device,y_transform_fn=np.exp)

rmse = compute_rmse([np.exp(ydata),np.exp(net_out)])
print('RMSE: ', rmse)

big_df = df.assign(pred=net_out)

grid_heatmap_corner(big_df,'SNR','pred')
+6.74 KiB (29.1 KiB)
Loading image diff...
+56 −13

File changed.

Preview size limit exceeded, changes collapsed.