Commit 12a52393 authored by Christian Chapman-Bird's avatar Christian Chapman-Bird
Browse files

bugfix

parent 71a551a9
Loading
Loading
Loading
Loading
+2 −2
Original line number Diff line number Diff line
@@ -18,10 +18,10 @@ def model_train_test(data, model, device, n_epochs, n_batches, loss_function, op
    path = get_script_path()
    norm_type = model.norm_type
    Path(get_script_path()+f'/../models/{name}/').mkdir(parents=True, exist_ok=True)
    if norm_type = 'z-score':
    if norm_type == 'z-score':
        np.save(path+'/../models/'+name+'/xdata_inputs.npy',np.array([xtrain.mean(axis=0), xtrain.std(axis=0)]))
        np.save(path+'/../models/'+name+'/ydata_inputs.npy',np.array([ytrain.mean(), ytrain.std()]))
    elif norm_type = 'uniform':
    elif norm_type == 'uniform':
        np.save(path+'/../models/'+name+'/xdata_inputs.npy',np.array([np.min(xtrain,axis=0), np.max(xtrain,axis=0)]))
        np.save(path+'/../models/'+name+'/ydata_inputs.npy',np.array([np.min(ytrain), np.max(ytrain)]))
    xtest = torch.from_numpy(norm_inputs(xtest, ref_dataframe=xtrain, norm_type=norm_type)).to(device).float()
+16 KiB

File added.

No diff preview for this file type.