Commit 50229f6d authored by Christian Chapman-Bird's avatar Christian Chapman-Bird
Browse files

Added mismatch heatmap corner plot, for assessing accuracy across the parameter space.

parent e6df6b5d
Loading
Loading
Loading
Loading
+126 −0
Original line number Diff line number Diff line
@@ -3,6 +3,8 @@ import numpy as np
import matplotlib.pyplot as plt
import time
from EMRI_DET.utilities import norm_inputs, unnorm, get_script_path
import seaborn as sns
sns.set_theme()


def run_on_dataset(model, test_data, n_batches=1, device=None, y_transform_fn=None, runtime=False):
@@ -162,3 +164,127 @@ def plot_difference_histogram(comparison_sets, model_name, xlabel, title=None, t

    plt.savefig(get_script_path() + f'/../models/{model_name}_hist_{logname}{diff}diff.png', **save_kwargs)
    plt.close()


def grid_heatmap_corner(dataframe, truth_column, pred_column, log=True, ratio=False, save=True, save_fp='heatmap_corner.png', savefig_kwargs={}):
    """
    Only for grids (in long form as a pandas dataframe, not Nd array)! Don't use for uniform data.
    Args:
        dataframe:
        truth_column:
        pred_column:
        log:
        ratio:

    Returns:

    """
    subframe = dataframe.drop([truth_column,pred_column],axis=1)
    # dimensions = np.array([len(np.unique(col)) for col in subframe])  # dimensions of the grid
    param_mins = [np.min(subframe[col].to_numpy()) for col in subframe]

    param_maxes = [np.max(subframe[col].to_numpy()) for col in subframe]
    titles = list(subframe)  # column headings
    nparams = len(titles)
    # plotmaps = np.zeros(shape=tuple(dimensions))
    plotmaps = []
    for i in range(nparams):
        col1 = subframe[titles[i]].to_numpy()
        gridvalues1 = np.unique(col1)
        for j in range(i+1,nparams):
            col2 = subframe[titles[j]].to_numpy()
            gridvalues2 = np.unique(col2)

            heatmap_here = np.zeros((gridvalues1.size,gridvalues2.size))
            for k,gv1 in enumerate(gridvalues1):
                for l,gv2 in enumerate(gridvalues2):
                    inds_to_keep = np.where(np.logical_and(col1 == gv1,col2 == gv2))
                    truths = dataframe[truth_column].to_numpy()[inds_to_keep]
                    preds = dataframe[pred_column].to_numpy()[inds_to_keep]

                    if ratio:
                        temp = preds/truths
                        if log:
                            temp = np.log10(temp)
                    else:
                        temp = preds - truths
                        if log:
                            temp = np.log10(abs(temp))

                    heatmap_here[k,l] = np.mean(temp)
            plotmaps.append(heatmap_here)

    vmin = np.min([np.min(hmap) for hmap in plotmaps])
    vmax = np.max([np.max(hmap) for hmap in plotmaps])

    fig, ax = plt.subplots(ncols=5, nrows=5, figsize=(8 + int(2*nparams),6 + int(2*nparams)))
    plt.subplots_adjust(wspace=0.05, hspace=0.05)
    num = 0
    for i in range(nparams):
        for j in range(i+1,nparams):
            temp_im = ax[j,i].imshow(plotmaps[num].T, origin='lower', aspect='auto',
                                     extent=[param_mins[i],param_maxes[i],param_mins[j],param_maxes[j]],
                                     vmin=vmin, vmax=vmax)
            ax[i,j].axis('off')
            ax[j,i].grid(False)
            if i != 0:
                ax[j,i].yaxis.set_visible(False)
            if j != 4:
                ax[j,i].xaxis.set_visible(False)

            num += 1

    for i in range(nparams):
        ax[-1, i].set_xlabel(titles[i])

    for i in range(1, nparams):
        ax[i, 0].set_ylabel(titles[i])


    singles = []
    for i in range(nparams):
        col1 = subframe[titles[i]].to_numpy()
        gridvalues1 = np.unique(col1)

        this_line = np.zeros(len(gridvalues1))
        for k,gv1 in enumerate(gridvalues1):
            inds_to_keep = np.where(col1 == gv1)
            truths = dataframe[truth_column].to_numpy()[inds_to_keep]
            preds = dataframe[pred_column].to_numpy()[inds_to_keep]

            if ratio:
                temp = preds/truths
                if log:
                    temp = np.log10(temp)
            else:
                temp = preds - truths
                if log:
                    temp = np.log10(abs(temp))
            this_line[k] = np.mean(temp)

        singles.append(this_line)

    for i in range(nparams):
        xvals = np.unique(subframe[titles[i]].to_numpy())
        ax[i, i].plot(xvals, singles[i], linestyle='dashed',marker='x',c='k')
        ax[i,i].set_title(titles[i])
        ax[i,i].yaxis.tick_right()
        ax[i,i].grid('on')
        ax[i, i].tick_params(axis="both", which='both', direction='in')

        if i != nparams-1:
            ax[i,i].tick_params(labelbottom=False)

    cbarlabel = ''
    if log:
        cbarlabel += 'log10'
    if ratio:
        cbarlabel+='(ratio)'
    else:
        cbarlabel+='|diff|'
    fig.colorbar(temp_im, ax=ax.ravel().tolist(), location='left', label=cbarlabel, pad=0.08)

    if save:
        fig.savefig(save_fp, **savefig_kwargs)
    else:
        plt.show()
 No newline at end of file
+29 −0
Original line number Diff line number Diff line
import numpy as np
import pandas as pd
from EMRI_DET.validate import grid_heatmap_corner


def function_1(x1,x2,x3,x4,x5):
    return x1/x2/x3*x4*x5


def function_2(x1,x2,x3,x4,x5):
    return x1/x2/x3*x4*x5**2


vals = np.linspace(2,10,10)
combos = [[v1,v2,v3,v4,v5] for v1 in vals for v2 in vals for v3 in vals for v4 in vals for v5 in vals]

titles = ['x1','x2','x3','x4','x5','Truth','Pred']
out = np.zeros(shape=(len(combos),7))
for i,combo in enumerate(combos):
    truth = function_1(*combo)
    pred = function_2(*combo)
    out[i,:5] = combo
    out[i,5] = truth
    out[i,6] = pred

df = pd.DataFrame(out,columns=titles)

grid_heatmap_corner(df,'Truth','Pred',log=True, ratio=True, save=False)
+1 −0
Original line number Diff line number Diff line
@@ -7,4 +7,5 @@ setup(
   author='Christian Chapman-Bird',
   author_email='c.chapman-bird.1@research.gla.ac.uk',
   packages=['EMRI_DET','EMRI_DET.nn'],
   zip_safe=False,
)