Commit c6829a71 authored by Christian Chapman-Bird's avatar Christian Chapman-Bird
Browse files

directory fix for using model not created with create_mlp.

parent 293569a4
Loading
Loading
Loading
Loading
+3 −2
Original line number Diff line number Diff line
@@ -3,14 +3,15 @@ import numpy as np
import matplotlib.pyplot as plt
from EMRI_DET.utilities import norm, norm_inputs, unnorm_inputs, unnorm, get_script_path
from sys import stdout

from pathlib import Path

def model_train_test(data, model, device, n_epochs, n_batches, loss_function, learning_rate, verbose=False, return_losses=False):
    xtrain, ytrain, xtest, ytest = data
    model.to(device)

    name = model.name
    path = get_script_path()

    Path(get_script_path()+f'/../models/{name}/').mkdir(parents=True, exist_ok=True)
    np.save(path+'/../models/'+name+'/xdata_mean_std.npy',np.array([xtrain.mean(axis=0), xtrain.std(axis=0)]))
    np.save(path+'/../models/'+name+'/ydata_mean_std.npy',np.array([ytrain.mean(), ytrain.std()]))

+4 −4
Original line number Diff line number Diff line
@@ -6,10 +6,10 @@ import pandas as pd

if __name__ == '__main__':
    device = "cuda:0"
    fp = '../schwarz_negY/{}'
    fp = '../schwarz_data/{}'
    #['logM', 'logq', 'a', 'p0', 'e', 'Y0', 'thetaS', 'phiS', 'thetaK', 't', 'SNR']
    train_inds = [0,1,2,4,5,6,7,8,9]
    test_inds = [10]
    train_inds = [0,1,2,3,4,5,6,7,8]#[0,1,2,4,5,6,7,8,9]
    test_inds = [9]

    traindata = pd.read_csv(fp.format('samp_dataframe.csv'))
    xtrain = traindata.iloc[:,train_inds].to_numpy()
@@ -23,7 +23,7 @@ if __name__ == '__main__':
    out_features = 1
    layers = 4
    neurons = [256,128,64,32]
    activation = nn.Tanh
    activation = nn.SiLU
    model = create_mlp(input_features=in_features,output_features=out_features,neurons=neurons,layers=layers,activation=activation,device=device, model_name='m1')

    data = [xtrain, ytrain, xtest, ytest]
+22.3 KiB
Loading image diff...
+19.3 KiB

File added.

No diff preview for this file type.

+160 B

File added.

No diff preview for this file type.

Loading