# Reconstructed from a git patch (GitLab export):
#   commit  27052988790dfa00cd88beab862b17b9beddb44b
#   author  Christian Chapman-Bird
#   date    Wed, 3 Nov 2021 16:58:30 +0000
#   subject [PATCH] Initial base for population container code.
#   file    src/populations/popBase.py (new file)
"""Containers for factorised population models built from 1-D components."""

import numpy as np
import pandas as pd
from scipy.integrate import quad


class Population:
    """Product of independent :class:`Subpopulation` components.

    The joint density is the product of the component densities, and joint
    samples are obtained by sampling every component independently.
    """

    def __init__(self, populations_to_add=None):
        """Create a population, optionally seeded with components.

        Parameters
        ----------
        populations_to_add : iterable of Subpopulation, optional
            Components to register, in order.

        Raises
        ------
        TypeError
            If any supplied component is not a ``Subpopulation``.
        """
        self.populations = []
        if populations_to_add is not None:
            for pop in populations_to_add:
                if not isinstance(pop, Subpopulation):
                    raise TypeError('You must construct populations as Subpopulation instances.')
                self.add_population(pop)

    def add_population(self, pop):
        """Append ``pop`` to the list of components (no type check is done here)."""
        self.populations.append(pop)

    def pdf(self, parameters):
        """Evaluate the joint density.

        Parameters
        ----------
        parameters : dict
            Maps parameter names to values. Each component receives only the
            entries whose keys appear in its ``param_names``.

        Returns
        -------
        The product of ``get_pdf`` over all components (``1`` if there are
        no components).
        """
        prod = 1
        for pop in self.populations:
            param_subset = {key: value for key, value in parameters.items()
                            if key in pop.param_names}
            # Out-of-place multiply: avoids in-place dtype-casting errors when
            # components return arrays of different dtypes.
            prod = prod * pop.get_pdf(param_subset)
        return prod

    def sample(self, num):
        """Draw ``num`` independent joint samples.

        Returns
        -------
        pandas.DataFrame
            One column per parameter name (in component order), ``num`` rows.
        """
        columns = {}
        for pop in self.populations:
            # Components return one row per parameter, i.e. shape (ndim, num).
            draws = np.atleast_2d(pop.sample(num))
            for name, values in zip(pop.param_names, draws):
                columns[name] = values
        return pd.DataFrame(columns)


class Subpopulation:
    """A single (currently one-dimensional) population component.

    The density function is always called with keyword arguments only: the
    parameter under its name from ``param_names`` plus every entry of
    ``hypers``.
    """

    def __init__(self, pdf, param_names, hypers, lower_limits, upper_limits, normalise=False):
        """
        Parameters
        ----------
        pdf : callable
            Density function, called as ``pdf(<param>=value, **hypers)``.
        param_names : str or list/tuple of str
            Name(s) of the parameter(s) the density depends on.
        hypers : dict
            Hyperparameters forwarded to ``pdf`` as keyword arguments.
        lower_limits, upper_limits : dict
            Map each parameter name to its lower / upper bound.
        normalise : bool, optional
            If true, the density is numerically normalised over the limits.

        Raises
        ------
        NotImplementedError
            If more than one dimension is requested.
        TypeError
            If ``param_names`` or ``hypers`` has an unsupported type.
        """
        super().__init__()
        self._pdf = pdf
        self.normnum = 1
        self.normalise = normalise

        self.lower_limits = lower_limits
        self.upper_limits = upper_limits

        self.ndim = len(upper_limits)
        if self.ndim > 1:
            raise NotImplementedError('Multivariate populations are not yet implemented.')

        if isinstance(param_names, str):
            self.param_names = [param_names]
        elif isinstance(param_names, list):
            self.param_names = param_names
        elif isinstance(param_names, tuple):
            self.param_names = list(param_names)
        else:
            # Previously this case left ``param_names`` unset and failed later.
            raise TypeError('Parameter names must be a string or a list of strings.')

        if not isinstance(hypers, dict):
            raise TypeError('Hyperparameters must be input as a dictionary.')
        self.hypers = hypers

        self.renormalise()

    def _evaluate(self, values):
        """Unnormalised density of the single parameter at ``values``."""
        return self._pdf(**{self.param_names[0]: values}, **self.hypers)

    def get_pdf(self, params):
        """Return the density for ``params`` (a dict of keyword arguments),
        divided by the current normalisation constant."""
        return self._pdf(**params, **self.hypers) / self.normnum

    def renormalise(self):
        """Recompute the normalisation constant if normalisation is enabled.

        The *unnormalised* density is integrated over the parameter limits,
        so repeated calls (e.g. after :meth:`set_hypers`) do not compound the
        previous constant.

        Raises
        ------
        NotImplementedError
            If the component is not one-dimensional.
        ValueError
            If the integral is not finite and positive.
        """
        if not self.normalise:
            return
        if self.ndim != 1:
            raise NotImplementedError

        name = self.param_names[0]
        normfactor = quad(self._evaluate, self.lower_limits[name], self.upper_limits[name])[0]
        if not (np.isfinite(normfactor) and normfactor > 0):
            raise ValueError('The density does not have a finite, positive integral over its limits.')
        self.normnum = normfactor

    def sample(self, n, grid_size=1000):
        """Draw ``n`` samples of the parameter between its limits.

        The density is tabulated at the centres of ``grid_size`` equal-width
        bins and sampled as a piecewise-constant distribution; the density
        function must therefore accept a NumPy array for the parameter.

        Returns
        -------
        numpy.ndarray
            Array of shape ``(1, n)`` holding parameter values in
            ``[lower, upper)``.
        """
        name = self.param_names[0]
        low = self.lower_limits[name]
        high = self.upper_limits[name]

        width = (high - low) / grid_size
        centres = low + (np.arange(grid_size) + 0.5) * width
        # broadcast_to lets a density that returns a scalar (flat pdf) work.
        density = np.broadcast_to(np.asarray(self._evaluate(centres), dtype=float),
                                  centres.shape)
        # Fractional bin index -> parameter value.
        dist = Distribution(density, transform=lambda idx: low + idx * width)
        return dist(n)

    def update_limits(self):
        """Overwrite limits from hyperparameters sharing a parameter's name.

        NOTE(review): both the lower and the upper limit are set to the same
        hyperparameter value, exactly as in the original code. This looks
        like a placeholder - confirm the intended mapping.
        """
        for pm in self.param_names:
            if pm in self.hypers:
                self.lower_limits[pm] = self.hypers[pm]
                self.upper_limits[pm] = self.hypers[pm]

    def set_hypers(self, hypers):
        """Replace the hyperparameters, then refresh limits and normalisation.

        Raises
        ------
        TypeError
            If ``hypers`` is not a dictionary (same rule as the constructor).
        """
        if not isinstance(hypers, dict):
            raise TypeError('Hyperparameters must be input as a dictionary.')
        self.hypers = hypers
        self.update_limits()
        self.renormalise()


class Distribution:
    """
    Draws samples from a discretised probability distribution by inverting
    its cumulative distribution function.

    The pdf can be sorted first to prevent numerical error in the cumulative
    sum. This is the default: for big density arrays with high contrast it is
    absolutely necessary, and for small ones the overhead is minimal.

    A call to a distribution object returns indices into the density array
    (one row per array dimension), passed through ``transform``.

    Adapted from https://stackoverflow.com/a/21101584

        x = np.linspace(-100, 100, 512)
        p = np.exp(-x**2)
        pdf = p[:, None] * p[None, :]  # 2d gaussian
        dist = Distribution(pdf, transform=lambda i: i - 256)
        print(dist(1000000).mean(axis=1))  # should be in the 1/sqrt(1e6) range
        import matplotlib.pyplot as pp
        pp.scatter(*dist(1000))
        pp.show()
    """

    def __init__(self, pdf, sort=True, interpolation=True, transform=lambda x: x):
        """
        Parameters
        ----------
        pdf : array_like
            Non-negative weights; they need not sum to one.
        sort : bool, optional
            Sort the weights by magnitude before accumulating them.
        interpolation : bool, optional
            If true, add a uniform offset in ``[0, 1)`` to every drawn index
            (piecewise-constant rather than discrete distribution).
        transform : callable, optional
            Applied to the index array before it is returned.

        Raises
        ------
        ValueError
            If ``pdf`` is empty, contains negative or NaN entries, or does
            not have a finite, positive total.
        """
        pdf = np.asarray(pdf)
        self.shape = pdf.shape
        self.pdf = pdf.ravel()
        self.sort = sort
        self.interpolation = interpolation
        self.transform = transform

        if self.pdf.size == 0:
            raise ValueError('pdf must contain at least one entry.')
        # A pdf cannot be negative. Explicit raise: asserts vanish under -O.
        if not np.all(self.pdf >= 0):
            raise ValueError('pdf must not contain negative or NaN entries.')

        # Sort the pdf by magnitude.
        if self.sort:
            self.sortindex = np.argsort(self.pdf, axis=None)
            self.pdf = self.pdf[self.sortindex]
        # Construct the cumulative distribution function.
        self.cdf = np.cumsum(self.pdf)

        if not (np.isfinite(self.cdf[-1]) and self.cdf[-1] > 0):
            raise ValueError('pdf must have a finite, positive total.')

    @property
    def ndim(self):
        """Number of dimensions of the original density array."""
        return len(self.shape)

    @property
    def sum(self):
        """Total of all pdf values (last CDF entry); the pdf need not sum to
        one and is implicitly normalised."""
        return self.cdf[-1]

    def __call__(self, n):
        """Draw ``n`` samples; returns ``transform`` applied to an index array
        of shape ``(ndim, n)``."""
        # Pick numbers uniformly over the range of the cumulative distribution.
        choice = np.random.uniform(high=self.sum, size=n)
        # First entry whose cumulative weight exceeds the draw. side='right'
        # skips zero-weight entries; the clip guards against a draw that
        # rounds up to the total.
        index = np.searchsorted(self.cdf, choice, side='right')
        index = np.minimum(index, self.cdf.size - 1)
        # If necessary, map the indices back to their original ordering.
        if self.sort:
            index = self.sortindex[index]
        # Map back to multi-dimensional indexing.
        index = np.unravel_index(index, self.shape)
        index = np.vstack(index)
        # Is this a discrete or piecewise continuous distribution?
        if self.interpolation:
            index = index + np.random.uniform(size=index.shape)
        return self.transform(index)

# Example usage (kept from the original patch, left disabled):
# if __name__ == '__main__':
#     x = np.linspace(-100, 100, 10000)
#     p = np.exp(-x**2)
#     pdf = p[:, None] * p[None, :]  # 2d gaussian
#     dist = Distribution(pdf, transform=lambda i: i - x.size / 2)
#     print(dist(1000000).mean(axis=1))  # should be in the 1/sqrt(1e6) range
#     import matplotlib.pyplot as pp
#     pp.hist2d(*dist(10000000), bins=100)
#     # pp.scatter(*dist(100000))
#     pp.show()