Commit d409c132 authored by Christian Chapman-Bird's avatar Christian Chapman-Bird
Browse files

Changed package name

parent d9bd22ca
Loading
Loading
Loading
Loading

emri_comfi/model.py

0 → 100644
+320 −0
Original line number Diff line number Diff line
from functools import partial
import numpy as np
from scipy.integrate import quad
from scipy.constants import c
try:
    import cupy as cp
    xp = cp
    from cupyx.scipy.ndimage import map_coordinates
    from cupyx.scipy.special import erf
except ImportError:
    xp = np
    from scipy.ndimage import map_coordinates
    from scipy.special import erf

def trapz(y, x=None, dx=1.0, axis=-1):
    y = xp.asanyarray(y)
    if x is None:
        d = dx
    else:
        x = xp.asanyarray(x)
        if x.ndim == 1:
            d = xp.diff(x)
            # reshape to correct shape
            shape = [1] * y.ndim
            shape[axis] = d.shape[0]
            d = d.reshape(shape)
        else:
            d = xp.diff(x, axis=axis)
    ndim = y.ndim
    slice1 = [slice(None)] * ndim
    slice2 = [slice(None)] * ndim
    slice1[axis] = slice(1, None)
    slice2[axis] = slice(None, -1)
    product = d * (y[tuple(slice1)] + y[tuple(slice2)]) / 2.0
    try:
        ret = product.sum(axis)
    except ValueError:
        ret = xp.add.reduce(product, axis)
    return ret

def trapz_cumsum(y, x=None, dx=1.0, axis=-1, flipped=False):
    y = xp.asanyarray(y)
    if x is None:
        d = dx
    else:
        x = xp.asanyarray(x)
        if x.ndim == 1:
            d = xp.diff(x)
            # reshape to correct shape
            shape = [1] * y.ndim
            shape[axis] = d.shape[0]
            d = d.reshape(shape)
        else:
            d = xp.diff(x, axis=axis)
    ndim = y.ndim
    slice1 = [slice(None)] * ndim
    slice2 = [slice(None)] * ndim
    slice1[axis] = slice(1, None)
    slice2[axis] = slice(None, -1)
    product = d * (y[tuple(slice1)] + y[tuple(slice2)]) / 2.0
    if not flipped:
        ret = xp.cumsum(product, axis=axis)
    elif flipped:
        ret = xp.cumsum(xp.flip(product, axis=axis), axis=axis)
    return ret

def dL(z):
    '''
    Luminosity distance from redshift assuming default cosmology.
    :param z: Redshift
    :return: dL, in Gpc.
    '''
    h = 0.6774
    omega_m = 0.3089
    omega_lambda = 1 - omega_m
    dH = 1e-5 * c / h

    def E(z):
        return np.sqrt((omega_m * (1 + z) ** (3) + omega_lambda))
    def I(z):
        fact = lambda x: 1 / E(x)
        integral = quad(fact, 0, z)
        return integral[0]

    return (1+z) * dH * I(z) * 1e-3

def truncnorm(xx, mu, sigma, high, low):
    # breakpoint()
    x_op = xp.repeat(xx, mu.size).reshape((xx.size,mu.size)).T
    hi_op = xp.repeat(high, xx.size).reshape((high.size,xx.size))
    lo_op = xp.repeat(low, xx.size).reshape((low.size,xx.size))

    norm = 2**0.5 / np.pi**0.5 / sigma
    norm /= erf((high - mu) / 2**0.5 / sigma) + erf((mu - low) / 2**0.5 / sigma)  #vector of norms
    try:
        prob = xp.exp(-xp.power(xx[None,:] - mu[:,None], 2) / (2 * sigma[:,None]**2)) # array of dims len(xx) * len(mu)
        prob *= norm[:,None]  # should be fine considering dimensionality
        prob[x_op < lo_op] = 0
        prob[x_op > hi_op] = 0
    except IndexError:
        prob = xp.exp(-xp.power(xx - mu, 2) / (2 * sigma**2)) # vector of len(xx)
        prob *= norm
        prob *= (xx <= high) & (xx >= low)
    return prob

def powerlaw(xx, lam, xmin, xmax):
    x_op = xp.repeat(xx, lam.size).reshape((xx.size,lam.size)).T
    hi_op = xp.repeat(xmax, xx.size).reshape((xmax.size,xx.size))
    lo_op = xp.repeat(xmin, xx.size).reshape((xmin.size,xx.size))

    norm = (1+lam)/(xmax**(1+lam) - xmin**(1+lam)) # vector of norms
    try:
        out =  xx[None,:]**lam[:,None] * norm[:,None] # array of dims len(xx) * len(lam)
        out[x_op < lo_op] = 0
        out[x_op > hi_op] = 0
    except IndexError:
        out =  xx**lam * norm # array of dims len(xx) * len(lam)
        out *= (xx <= xmax) & (xx >= xmin)
    return out

def smooth_exp(m, smooth_scale):
    return xp.exp((smooth_scale/m)+(smooth_scale/(m-smooth_scale)))

def smoothing(masses, mmin, delta_m):
    big_masses = xp.repeat(masses, delta_m.size).reshape((masses.size, delta_m.size)).T
    lows = xp.repeat(mmin, masses.size).reshape((mmin.size,masses.size))
    deltas = xp.repeat(delta_m, masses.size).reshape((delta_m.size,masses.size))
    try:
        ans = (1 + smooth_exp(masses[None,:] - mmin[:,None], delta_m[:,None]))**-1
        ans[big_masses < lows] = 0
        ans[big_masses > lows + deltas] = 1
    except IndexError:
        ans = (1 + smooth_exp(masses - mmin, delta_m))**-1
        ans[masses < mmin] = 0
        ans[masses > mmin + delta_m] = 1
    return ans

def ligo_ppop(m, parameters):
    tcs = two_component_single(m, parameters['alpha'],parameters['mmin'],parameters['mmax'],parameters['lam'],parameters['mpp'],parameters['sigpp'])
    smth = smoothing(m, parameters['mmin'], parameters['delta_m'])
    return tcs * smth  # NOT normalised!

def two_component_single(
    mass, alpha, mmin, mmax, lam, mpp, sigpp, gaussian_mass_maximum=100
):
    p_pow = powerlaw(mass, -alpha, mmin, mmax)  # 2d array, dims N_samp * N_hp
    p_norm = truncnorm(mass, mu=mpp, sigma=sigpp, high=xp.asarray(gaussian_mass_maximum), low=mmin)  # same as above

    try:
        prob = (1 - lam[:,None]) * p_pow + lam[:,None] * p_norm
    except IndexError:
        prob = (1 - lam) * p_pow + lam * p_norm
    return prob

def p1(m, p1norm, truths):
    p = ligo_ppop(m, truths)
    return p/p1norm
    
def p2_no_m1(m, beta, mmin, delta_m):
    return m**beta * smoothing(m, mmin, delta_m)

def p2_total(m, Kfunc, p1norm, beta, mmin, delta_m):
    return p2_no_m1(m, beta, mmin, delta_m)/p1norm * Kfunc(m) # integral is a function of m

def construct_I(mvec, dm, beta, mmin, delta_m):
    prob = p2_no_m1(mvec, beta, mmin, delta_m)
    cumulative_integral = xp.append(0,trapz_cumsum(prob, dx=dm, axis=-1))
    return lambda m: xp.interp(m, mvec, cumulative_integral)

def construct_K(mvec, dm, p1probs, beta, mmin, delta_m):
    Ifunc = construct_I(mvec, dm, beta, mmin, delta_m)
    prob = xp.nan_to_num(p1probs / Ifunc(mvec))
    cumulative_integral = xp.flip(xp.append(0,trapz_cumsum(prob, dx=dm, axis=-1, flipped=True)),axis=-1) # flip back
    return lambda m: xp.interp(m, mvec, cumulative_integral)

def ligo_combined_pm(m, mvec,dm, truths, return_pone_ptwo=False):
    pm1 =ligo_ppop(mvec, parameters=truths)  # cache + get p1 normalisation
    pm1_norm = trapz(pm1, dx=dm)
    pone = p1(m, pm1_norm, truths)

    Kfunc = construct_K(mvec, dm, pm1, truths["beta"], truths["mmin"], truths["delta_m"])
    ptwo = p2_total(m, Kfunc, pm1_norm, truths["beta"], truths["mmin"], truths["delta_m"])

    if return_pone_ptwo:
        return 0.5*(pone + ptwo), pone, ptwo
    else:
        return 0.5*(pone + ptwo)

def dVdz(z):
    h = 0.6774
    omega_m = 0.3089
    omega_lambda = 1 - omega_m
    def E(z):
        return np.sqrt((omega_m * (1 + z) ** (3) + omega_lambda))
    dh = 3000/h # Mpc
    dm = dh * quad(lambda x: 1/E(x),0,z)[0]
    ez = E(z)
    return 4*np.pi * dh * dm**2 / ez / (1+z) # 1+z maps time properly to redshift

class OverallDistribution(): # bringing together all of the painful spaghetti
    def __init__(self, mmin, mmax, mumin, mumax, zmax, dVdz_spline=None, Np=10):
        self.mmin = mmin
        self.mmax = mmax
        self.mumin = mumin  # these should agree with the priors on the mu distribution to provide sufficient coverage.
        self.mumax = mumax
        self.zmax = zmax 
        self.Np = Np  # adjustable parameter for the rate, if one is so inclined.
        self._dVdz = dVdz_spline
        self._prepare_p0()
        self.cache_preliminaries()

    def __call__(self, M, mu, z, pmu, pmu_params={}): # compute probability!
        norm = self.get_rate(pmu, pmu_params)
        dn_part = self.dVdz_wrap(z) * self._mbh_distribution(M) * pmu(mu, **pmu_params)
        gam = self._gamma(mu, M)
        other_part = self._p0(z, M) * gam * self._kappa(mu, M, gam) * self._R0(M)
        # if other_part.shape[0] == 1:
        #     other_part = other_part[0]
        try:
            return dn_part * other_part / norm
        except ValueError:
            return dn_part * other_part / norm[:,None]

    def cache_preliminaries(self):
        #cache some preliminaries for the normalisation
        self.mvec = xp.linspace(self.mmin, self.mmax, 1000)   
        self.muvec = xp.linspace(self.mumin, self.mumax, 1100) 
        self.zvec = xp.linspace(0., self.zmax, 900)
        self.dm = self.mvec[1]-self.mvec[0]
        self.dmu = self.muvec[1]-self.muvec[0]
        self.dz = self.zvec[1] - self.zvec[0]
        self.get_dVdz_norm()
        self.get_Mnorm()
        self.N = self.dVdznorm * self.Mnorm

        p0star = self._integrate_over_z() # integrate over z
        thr_gamma = self._gamma(self.muvec[:,None], self.mvec)
        thr_kappa = self._kappa(self.muvec[:,None], self.mvec, thr_gamma)
        self.R = self._R0(self.mvec)*thr_gamma*thr_kappa
        all_but_mu = self._mbh_distribution(self.mvec)*p0star*self.R  # 2d
        self.integrated_over_M = trapz(all_but_mu,dx=self.dm,axis=1)  # 1d, ready to be normalised by the rate

    def get_rate(self, mu_probdist, mu_probdist_kwargs={}):  # the normalisation for the distribution
            muprobs = mu_probdist(self.muvec, **mu_probdist_kwargs)  # assume this is normalised
            combined = self.integrated_over_M * muprobs
            return self.N * trapz(combined, dx=self.dmu, axis=-1)# * 1600/1520
    
    def dVdz_wrap(self, z):
        try:
            z = z.get()
        except:
            pass
        vals = self._dVdz(z)
        return xp.asarray(vals)/self.dVdznorm

    def get_dVdz_norm(self):
        self.dVdznorm = 1.
        out = self.dVdz_wrap(self.zvec)
        self.dVdznorm = trapz(out, dx=self.dz) 

    def get_Mnorm(self):
        self.Mnorm = 1.
        out = self._mbh_distribution(self.mvec)
        self.Mnorm = trapz(out, dx=self.dm)

    def _mbh_distribution(self, M):
        return 0.0055 * (M/3e6)**-0.3 / M / self.Mnorm  # from babak et al. 2017, corrected from dn/dlogm to dn/dm

    def _tEMRI(self,M):  # sigmoid fitted to data
        return 10.65 / (1 + xp.exp(3.85*(xp.log10(M)-5.82))) + 2.74  # Gyr

    def _prepare_p0(self):
        p0_logmasses = xp.linspace(4.5, 7.5, 7)
        p0_outputs = xp.asarray([[0.999,0.999,0.999,0.998,0.998,0.998,0.997],[0.997,0.991, 0.98,0.96,0.95,0.94,0.9],[0.23,0.14, 0.18, 0.295, 0.4, 0.53, 0.62],[0.08,0.09,0.15, 0.29, 0.47, 0.63, 0.67],[0.018, 0.023,0.07, 0.18, 0.3, 0.43, 0.5],np.array([0.08,0.09,0.15, 0.29, 0.47, 0.63, 0.67])**3,np.array([0.08,0.09,0.15, 0.29, 0.47, 0.63, 0.67])**4])
        p0_redshifts = xp.linspace(0, 3, 7)
        self._p0_spline = partial(map_coordinates, xp.log(p0_outputs).T)
        self._p0_zmin = p0_redshifts[0].item()
        self._p0_zmax = p0_redshifts[-1].item()
        self._p0_Mmin = p0_logmasses[0].item()
        self._p0_Mmax = p0_logmasses[-1]

    def _p0(self,z,M):
        # rescale to limits
        z = xp.asarray((z - self._p0_zmin) / (self._p0_zmax - self._p0_zmin)) * 6
        M = xp.asarray((xp.log10(M) - self._p0_Mmin) / (self._p0_Mmax - self._p0_Mmin)) * 6
        M[M<0] = 0
        M[M>6] = 6  # keep M within the grid - no extrap.
        reshape = False
        if M.ndim == 2:
            reshape = True
            nz = z.size
            nm = M.size
            reshap = (nm, nz)
            M = M.repeat(nz, axis=1).flatten()
            z = z[:,None].repeat(nm,axis=1).T.flatten()
        zM = xp.vstack((z, M))
        outs = xp.exp(self._p0_spline(zM, cval=0.))
        if reshape:
            outs = outs.reshape(reshap)
        return outs

    def _p0_z_integrand(self, z, M):
        return self.dVdz_wrap(z)*self._p0(z, M)

    def _integrate_over_z(self):
        zints = xp.asarray([self._p0_z_integrand(xp.repeat(zv,self.mvec.size), self.mvec) for zv in self.zvec])
        return xp.asarray(trapz(zints, dx=self.dz, axis=0))

    def _R0(self,M):
        return 300*(M/1e6)**-0.19  # units of Gyr^-1

    def _gamma(self, mu, M):
        out = 1.2/(1+self.Np) * 10 * (M/1e6)**0.06 / mu#[None,:]
        out[out>1] = 1
        return out

    def _kappa(self, mu, M, gam):
        out = xp.exp(-1) * M / self._R0(M) / self._tEMRI(M) / (1+self.Np) / gam / mu#[None,:]
        out[out>1]=1
        return out
+68 −0
Original line number Diff line number Diff line
import torch
import torch.nn as nn
from torch.nn.init import xavier_uniform_
from emri_comfi.nn.utilities import get_script_path
import dill as pickle
from pathlib import Path
import numpy as np

class LinearModel(nn.Module):
    def __init__(self, in_features, out_features, neurons, n_layers, activation, name, out_activation=None, initialisation=xavier_uniform_, use_dropout=False,drop_p=0.25, use_bn=False):
        super().__init__()
        self.initial = initialisation
        self.name = name

        layers = [nn.Linear(in_features, neurons[0]), activation()]
        for i in range(n_layers - 1):
            layers.append(nn.Linear(neurons[i], neurons[i + 1]))

            if use_dropout:
                layers.append(nn.Dropout(drop_p))  
            layers.append(activation())
            if use_bn:
                layers.append(nn.BatchNorm1d(num_features=neurons[i+1]))
                
        layers.append(nn.Linear(neurons[-1], out_features))
        if out_activation is not None:
            layers.append(out_activation())
        
        self.layers = nn.Sequential(*layers)
        self.layers.apply(self.init_weights)

    def forward(self, x):
        return self.layers(x)

    def init_weights(self, m):
        if isinstance(m, nn.Linear):
            self.initial(m.weight)


def create_mlp(input_features, output_features, neurons, layers, activation, model_name, out_activation=None, init=xavier_uniform_, device=None, norm_type='z-score', use_dropout=False,drop_p=0.25, use_bn=False, outdir='../models'):
    if isinstance(neurons, list):
        if len(neurons) != layers:
            raise RuntimeError('Length of neuron vector does not equal number of hidden layers.')
    else:
        neurons = [neurons, ]
    model = LinearModel(input_features, output_features, neurons, layers, activation, model_name, initialisation=init, use_dropout=use_dropout,drop_p=drop_p,use_bn=use_bn, out_activation=out_activation)
    model.norm_type=norm_type
    Path(get_script_path()+f'/{outdir}/{model_name}/').mkdir(parents=True, exist_ok=True)
    pickle.dump(model, open(get_script_path()+f'/{outdir}/{model_name}/function.pickle', "wb"), pickle.HIGHEST_PROTOCOL)  # save blank model

    if device is not None:
        model = model.to(device)
    return model


def load_mlp(model_name, device, get_state_dict=False, outdir='../models'):
    model = pickle.load(open(get_script_path()+f'/{outdir}/{model_name}/function.pickle', "rb"))  # load blank model
    if get_state_dict:
        model.load_state_dict(torch.load(open(get_script_path()+f'/{outdir}/{model_name}/model.pth', "rb"), map_location=device))
    if model.norm_type is not None:
        xscalevals = np.load(get_script_path() + f'/{outdir}/{model.name}/xdata_inputs.npy')
        yscalevals = np.load(get_script_path() + f'/{outdir}/{model.name}/ydata_inputs.npy')
    else:
        xscalevals = None
        yscalevals = None
    model.xscalevals = xscalevals
    model.yscalevals = yscalevals
    return model
+142 −0
Original line number Diff line number Diff line
import torch
import numpy as np
import matplotlib.pyplot as plt
from emri_comfi.nn.utilities import norm, norm_inputs, unnorm_inputs, unnorm, get_script_path
from sys import stdout
from pathlib import Path

def model_train_test(data, model, device, n_epochs, n_batches, loss_function, optimizer, verbose=False, return_losses=False, update_every=None, n_test_batches=None, save_best=False, scheduler=None, outdir='../models'):
    
    if n_test_batches is None:
        n_test_batches = n_batches
    
    xtrain, ytrain, xtest, ytest = data
    model.to(device)

    name = model.name
    path = get_script_path()
    norm_type = model.norm_type
    Path(get_script_path()+f'/{outdir}/{name}/').mkdir(parents=True, exist_ok=True)
    if norm_type == 'z-score':
        np.save(path+f'/{outdir}/{name}/xdata_inputs.npy',np.array([xtrain.mean(axis=0), xtrain.std(axis=0)]))
        np.save(path+f'/{outdir}/{name}/ydata_inputs.npy',np.array([ytrain.mean(), ytrain.std()]))
    elif norm_type == 'uniform':
        np.save(path+f'/{outdir}/{name}/xdata_inputs.npy',np.array([np.min(xtrain,axis=0), np.max(xtrain,axis=0)]))
        np.save(path+f'/{outdir}/{name}/ydata_inputs.npy',np.array([np.min(ytrain), np.max(ytrain)]))
    elif norm_type is None:
        pass
    xtest = torch.from_numpy(norm_inputs(xtest, ref_dataframe=xtrain, norm_type=norm_type)).to(device).float()
    ytest = torch.from_numpy(norm(ytest, ref_dataframe=ytrain, norm_type=norm_type)).to(device).float()
    xtrain = torch.from_numpy(norm_inputs(xtrain, ref_dataframe=xtrain, norm_type=norm_type)).to(device).float()
    ytrain = torch.from_numpy(norm(ytrain, ref_dataframe=ytrain, norm_type=norm_type)).to(device).float()

    ytrainsize = len(ytrain)
    ytestsize = len(ytest)

    train_losses = []
    test_losses = []
    rate = []
    # Run the training loop

    datasets = {"train": [xtrain, ytrain], "test": [xtest, ytest]}

    cutoff_LR = n_epochs - 50
    lowest_loss = 1e5
    for epoch in range(n_epochs):  # 5 epochs at maximum
        # Print epoch
        for phase in ['train','test']:
            if phase == 'train':
                model.train(True)
                shuffled_inds = torch.randperm(ytrainsize)

                # Set current loss value
                current_loss = 0.0

                # Iterate over the DataLoader for training data
                # Get and prepare inputs
                inputs, targets = datasets[phase]
                inputs = inputs[shuffled_inds]
                targets = targets[shuffled_inds]

                #targets = targets.reshape((targets.shape[0], 1))

                for i in range(n_batches):
                    for param in model.parameters():
                        param.grad = None
                    outputs = model(inputs[i * ytrainsize // n_batches:(i+1)*ytrainsize // n_batches])
                    loss = loss_function(outputs, targets[i * ytrainsize // n_batches: (i+1)*ytrainsize // n_batches])
                    loss.backward()
                    optimizer.step()
                    if scheduler is not None:
                        scheduler.step()
                    current_loss += loss.item()

                train_losses.append(current_loss / n_batches)

            else:
                with torch.no_grad():
                    model.train(False)
                    shuffled_inds = torch.randperm(ytestsize)
                    current_loss = 0.0
                    inputs, targets = datasets[phase]
                    inputs = inputs[shuffled_inds]
                    targets = targets[shuffled_inds]

#                     targets = targets.reshape((targets.shape[0], 1))

                    for i in range(n_test_batches):
                        outputs = model(inputs[i * ytestsize // n_batches: (i+1)*ytestsize // n_batches])
                        loss = loss_function(outputs, targets[i * ytestsize // n_batches: (i+1)*ytestsize // n_batches])
                        current_loss += loss.item()

                    test_losses.append(current_loss / n_test_batches)
        if test_losses[-1] < lowest_loss:
            lowest_loss = test_losses[-1]
            if save_best:
                torch.save(model.state_dict(),path+f'/{outdir}/{name}/model.pth')
                
#         if epoch >= cutoff_LR:
#             scheduler.step()
#             rate.append(scheduler.get_last_lr()[0])
#         else:
#             rate.append(learning_rate)
        if verbose:
            stdout.write(f'\rEpoch: {epoch} | Train loss: {train_losses[-1]:.3e} | Test loss: {test_losses[-1]:.3e} ')
        if update_every is not None:
            if epoch % update_every == 0:
                epochs = np.arange(epoch+1)
                plt.semilogy(epochs, train_losses, label='Train')
                plt.semilogy(epochs, test_losses, label='Test')
                plt.legend()
                plt.xlabel('Epochs')
                plt.ylabel('Loss')
                plt.title('Train and Test Loss Across Train Epochs')
                plt.savefig(path+f'/{outdir}/{name}/losses.png')
                #plt.show()
                plt.close()
                
                if not save_best:
                    torch.save(model.state_dict(),path+f'/{outdir}/{name}/model.pth')

        
    if verbose:
        print('\nTraining complete - saving.')
    
    if not save_best:
        torch.save(model.state_dict(),path+f'/{outdir}/{name}/model.pth')

    epochs = np.arange(n_epochs)
    plt.semilogy(epochs, train_losses, label='Train')
    plt.semilogy(epochs, test_losses, label='Test')
    plt.legend()
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.title('Train and Test Loss Across Train Epochs')
    plt.savefig(path+f'/{outdir}/{name}/losses.png')
    #plt.show()
    plt.close()

    out = (model,)
    if return_losses:
        out += (train_losses, test_losses,)
    return out
+87 −0

File added.

Preview size limit exceeded, changes collapsed.

+97 −0

File added.

Preview size limit exceeded, changes collapsed.

Loading