From ffeb1c63890025041fdcc078079a42be2fe12e7a Mon Sep 17 00:00:00 2001 From: Christian Chapman-Bird Date: Tue, 23 Nov 2021 17:34:15 +0000 Subject: [PATCH] out activation properly added --- EMRI_DET/nn/model_creation.py | 2 +- EMRI_DET/validate.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/EMRI_DET/nn/model_creation.py b/EMRI_DET/nn/model_creation.py index 645ae69..42ae684 100644 --- a/EMRI_DET/nn/model_creation.py +++ b/EMRI_DET/nn/model_creation.py @@ -43,7 +43,7 @@ def create_mlp(input_features, output_features, neurons, layers, activation, mod raise RuntimeError('Length of neuron vector does not equal number of hidden layers.') else: neurons = [neurons, ] - model = LinearModel(input_features, output_features, neurons, layers, activation, model_name, initialisation=init, use_dropout=use_dropout,drop_p=drop_p,use_bn=use_bn) + model = LinearModel(input_features, output_features, neurons, layers, activation, model_name, initialisation=init, use_dropout=use_dropout,drop_p=drop_p,use_bn=use_bn, out_activation=out_activation) model.norm_type=norm_type Path(get_script_path()+f'/../models/{model_name}/').mkdir(parents=True, exist_ok=True) pickle.dump(model, open(get_script_path()+f'/../models/{model_name}/function.pickle', "wb"), pickle.HIGHEST_PROTOCOL) # save blank model diff --git a/EMRI_DET/validate.py b/EMRI_DET/validate.py index 7a55764..5633d12 100644 --- a/EMRI_DET/validate.py +++ b/EMRI_DET/validate.py @@ -66,7 +66,8 @@ def run_on_dataset(model, test_data, distances=None, n_batches=1, device=None, y distances = np.ones(xdata.shape[0]) * 0.5 out_unnorm *= (0.5/distances)[:,None] - + if ydata.ndim == 1: + out_unnorm = out_unnorm.flatten() outputs = (out_unnorm,) if runtime: -- GitLab