Commit b911fc4d authored by Christian Chapman-Bird's avatar Christian Chapman-Bird
Browse files

Some bugfixes. Added confusion matrix optional output for classifier accuracy test.

parent 73021811
Loading
Loading
Loading
Loading
+17 −7
Original line number Diff line number Diff line
@@ -82,16 +82,24 @@ def compute_rmse(comparison_sets):
    return rmse


def test_threshold_accuracy(comparison_sets, threshold):
def test_threshold_accuracy(comparison_sets, threshold, confusion_matrix=False):
    truth, pred = comparison_sets
    out_classified = np.zeros(shape=pred.size)
    out_classified[pred.flatten() >= threshold] = 1

    truth_classified = np.zeros(shape=truth.size)
    truth_classified[truth >= threshold] = 1
    truth_classified[truth.flatten() >= threshold] = 1

    if not confusion_matrix:
        return 1 - np.mean(np.abs(out_classified - truth_classified))
    else:
        confmat = np.zeros((2,2))
        confmat[0,0] = np.sum(np.logical_and(out_classified==0,truth_classified==0))
        confmat[0,1] = np.sum(np.logical_and(out_classified==0,truth_classified==1))
        confmat[1,0] = np.sum(np.logical_and(out_classified==1,truth_classified==0))
        confmat[1,1] = np.sum(np.logical_and(out_classified==1,truth_classified==1))

        return (1-np.mean(np.abs(out_classified-truth_classified)),confmat)

def plot_histograms(comparison_sets, model_name, xlabel, title=None, title_kwargs={}, xlabel_kwargs={}, log=True,
                    fig_kwargs={}, plot_kwargs={}, save_kwargs={}, legend_kwargs={}):
@@ -208,9 +216,9 @@ def grid_heatmap_corner(dataframe, truth_column, pred_column, log=True, ratio=Fa
                            temp = np.log10(temp)
                    else:
                        temp = preds - truths
                        temp = np.mean(temp)
                        temp = np.mean(abs(temp))
                        if log:
                            temp = np.log10(abs(temp))
                            temp = np.log10(temp)

                    heatmap_here[k,l] = temp
            plotmaps.append(heatmap_here)
@@ -255,13 +263,15 @@ def grid_heatmap_corner(dataframe, truth_column, pred_column, log=True, ratio=Fa

            if ratio:
                temp = preds/truths
                temp = np.mean(temp)
                if log:
                    temp = np.log10(temp)
            else:
                temp = preds - truths
                temp = np.mean(abs(temp))
                if log:
                    temp = np.log10(abs(temp))
            this_line[k] = np.mean(temp)
                    temp = np.log10(temp)
            this_line[k] = temp

        singles.append(this_line)