Skip to content
Snippets Groups Projects
Commit 27052988 authored by Christian Chapman-Bird's avatar Christian Chapman-Bird
Browse files

Initial base for population container code.

parent c2c1a486
No related branches found
No related tags found
No related merge requests found
import pandas as pd
import numpy as np
from scipy.integrate import quad
class Population:
    """Container that combines independent sub-populations.

    Each member must expose ``param_names`` (the names of the parameters it
    models), ``get_pdf(params)`` and ``sample(n)``. The joint density is the
    product of the members' densities.
    """

    def __init__(self, populations_to_add=None):
        """Create the container, optionally pre-filled.

        Parameters
        ----------
        populations_to_add : iterable of Subpopulation, optional
            Sub-populations to register immediately.

        Raises
        ------
        TypeError
            If any supplied element is not a ``Subpopulation`` instance.
        """
        self.populations = []
        if populations_to_add is not None:
            for pop in populations_to_add:
                if not isinstance(pop, Subpopulation):
                    raise TypeError('You must construct populations as Subpopulation instances.')
                self.add_population(pop)

    def add_population(self, pop):
        """Append one sub-population to the container."""
        self.populations.append(pop)

    def pdf(self, parameters):
        """Evaluate the joint density.

        Parameters
        ----------
        parameters : dict
            Maps parameter names to values. Each sub-population receives only
            the entries whose keys appear in its ``param_names``.

        Returns
        -------
        The product of every sub-population's ``get_pdf`` value (``1`` when
        the container is empty).
        """
        prod = 1
        for pop in self.populations:
            param_subset = {
                key: value
                for key, value in parameters.items()
                if key in pop.param_names
            }
            prod *= pop.get_pdf(param_subset)
        return prod

    def sample(self, num):
        """Draw ``num`` samples from every sub-population.

        Parameters
        ----------
        num : int
            Number of samples (rows) to draw.

        Returns
        -------
        pandas.DataFrame
            One column per parameter name and ``num`` rows.

        Raises
        ------
        ValueError
            If a sampler's output cannot be arranged as ``num`` values for
            each of its parameters.
        """
        columns = {}
        for pop in self.populations:
            names = list(pop.param_names)
            draws = np.asarray(pop.sample(num))
            # Samplers may hand back (n_params, num), (num, n_params) or, for
            # a single parameter, a flat array; normalise to (n_params, num).
            if draws.shape != (len(names), num):
                draws = draws.reshape(num, len(names)).T
            for name, values in zip(names, draws):
                columns[name] = values
        return pd.DataFrame(columns)
class Subpopulation:
    """One-dimensional population described by a user-supplied density.

    Specify the density function with keyword arguments only: it is called as
    ``pdf(**params, **hypers)``, i.e. with the model parameter(s) and every
    hyperparameter passed by name.
    """

    def __init__(self, pdf, param_names, hypers, lower_limits, upper_limits, normalise=False):
        """Set up the sub-population and, if requested, normalise it.

        Parameters
        ----------
        pdf : callable
            Density function accepting the parameter and hyperparameters as
            keyword arguments.
        param_names : str or sequence of str
            Name(s) of the parameter(s) modelled; stored as a list.
        hypers : dict
            Hyperparameter names mapped to values.
        lower_limits, upper_limits : dict
            Parameter names mapped to the lower / upper bound of the support.
        normalise : bool, optional
            When true, the density is divided by its integral over the limits.

        Raises
        ------
        NotImplementedError
            If more than one upper limit is given (multivariate case).
        TypeError
            If ``hypers`` is not a dictionary.
        """
        super().__init__()
        self._pdf = pdf
        self.normnum = 1
        self.normalise = normalise
        self.lower_limits = lower_limits
        self.upper_limits = upper_limits
        self.ndim = len(upper_limits)
        if self.ndim > 1:
            raise NotImplementedError('Multivariate populations are not yet implemented.')
        if isinstance(param_names, str):
            self.param_names = [param_names]
        elif isinstance(param_names, list):
            self.param_names = param_names
        else:
            # Tuples and other iterables previously left param_names unset.
            self.param_names = list(param_names)
        if not isinstance(hypers, dict):
            raise TypeError('Hyperparameters must be input as a dictionary.')
        self.hypers = hypers
        self.renormalise()

    def get_pdf(self, params):
        """Return the density at ``params`` divided by the normalisation.

        Parameters
        ----------
        params : dict
            Parameter names mapped to the value(s) at which to evaluate.
        """
        return self._pdf(**params, **self.hypers) / self.normnum

    def renormalise(self):
        """Recompute the normalisation constant when ``normalise`` is set.

        The raw density (not the already-normalised one) is integrated over
        ``[lower_limit, upper_limit]`` so repeated calls stay correct.

        Raises
        ------
        NotImplementedError
            If the population is not one-dimensional.
        """
        if not self.normalise:
            return
        if self.ndim != 1:
            raise NotImplementedError
        name = self.param_names[0]

        def integrand(value):
            # quad supplies a bare float; the density wants keyword arguments.
            return self._pdf(**{name: value}, **self.hypers)

        self.normnum = quad(integrand, self.lower_limits[name], self.upper_limits[name])[0]

    def sample(self, n, num_bins=1000):
        """Draw ``n`` parameter values by discretised inverse-CDF sampling.

        The support is cut into ``num_bins`` equal bins, the density is
        evaluated at the bin centres (so it must accept an array), and
        ``Distribution`` picks bins and positions within them.

        Parameters
        ----------
        n : int
            Number of samples.
        num_bins : int, optional
            Number of bins used to discretise the density.

        Returns
        -------
        numpy.ndarray
            Array of shape ``(1, n)`` with values in ``[lower, upper)``.
        """
        name = self.param_names[0]
        low = self.lower_limits[name]
        high = self.upper_limits[name]
        width = (high - low) / num_bins
        centres = low + (np.arange(num_bins) + 0.5) * width
        # Broadcast so densities that return a scalar for array input work.
        density = np.broadcast_to(
            np.asarray(self.get_pdf({name: centres}), dtype=float), centres.shape
        )
        # Fractional bin index -> parameter value.
        dist = Distribution(density, transform=lambda idx: low + idx * width)
        return dist(n)

    def update_limits(self):
        """Copy same-named hyperparameters into both limit dictionaries.

        For every parameter whose name is also a hyperparameter key, the lower
        and the upper limit are both set to that hyperparameter's value.
        """
        # NOTE(review): this makes lower == upper for such parameters, which
        # looks like a placeholder -- confirm the intended hyper -> limit map.
        for pm in self.param_names:
            if pm in self.hypers:
                self.lower_limits[pm] = self.hypers[pm]
                self.upper_limits[pm] = self.hypers[pm]

    def set_hypers(self, hypers):
        """Replace the hyperparameters, then refresh limits and normalisation."""
        self.hypers = hypers
        self.update_limits()
        self.renormalise()
class Distribution:
    """Draw samples from a discretised probability density.

    Sampling works by inverting the discrete cumulative distribution function.
    The pdf can be sorted by magnitude first to limit numerical error in the
    cumulative sum. Sorting is the default: for big density arrays with high
    contrast it is necessary, and for small ones the overhead is minimal.

    Calling the object returns (optionally interpolated and transformed)
    indices into the density array, as an array of shape ``(ndim, n)``.

    Adapted from https://stackoverflow.com/a/21101584

    Example::

        x = np.linspace(-100, 100, 512)
        p = np.exp(-x**2)
        pdf = p[:, None] * p[None, :]  # 2d gaussian
        dist = Distribution(pdf, transform=lambda i: i - 256)
        print(dist(1000000).mean(axis=1))  # should be in the 1/sqrt(1e6) range
        import matplotlib.pyplot as pp
        pp.scatter(*dist(1000))
        pp.show()
    """

    def __init__(self, pdf, sort=True, interpolation=True, transform=lambda x: x):
        """Store the density and build its cumulative sum.

        Parameters
        ----------
        pdf : array_like
            Non-negative density values of any dimensionality; it need not
            sum to one.
        sort : bool, optional
            Sort the flattened pdf by magnitude before accumulating.
        interpolation : bool, optional
            Add uniform jitter in ``[0, 1)`` to the drawn indices.
        transform : callable, optional
            Applied to the index array before it is returned.

        Raises
        ------
        ValueError
            If the pdf is empty, contains negative (or NaN) values, or has no
            positive total mass.
        """
        pdf = np.asarray(pdf, dtype=float)
        if pdf.size == 0:
            raise ValueError('pdf must contain at least one value')
        # a pdf can not be negative (explicit raise: asserts vanish under -O)
        if not np.all(pdf >= 0):
            raise ValueError('pdf must be non-negative everywhere')
        self.shape = pdf.shape
        self.pdf = pdf.ravel()
        self.sort = sort
        self.interpolation = interpolation
        self.transform = transform
        # sort the pdf by magnitude
        if self.sort:
            self.sortindex = np.argsort(self.pdf, axis=None)
            self.pdf = self.pdf[self.sortindex]
        # construct the cumulative distribution function
        self.cdf = np.cumsum(self.pdf)
        if not self.cdf[-1] > 0:
            raise ValueError('pdf must have positive total mass')

    @property
    def ndim(self):
        """Number of dimensions of the original density array."""
        return len(self.shape)

    @property
    def sum(self):
        """Sum of all pdf values; the pdf need not sum to one, and is implicitly normalized."""
        return self.cdf[-1]

    def __call__(self, n):
        """Draw ``n`` samples.

        Returns
        -------
        The result of ``transform`` applied to an ``(ndim, n)`` array of
        indices (floats when ``interpolation`` is enabled).
        """
        # pick numbers which are uniformly random over the cumulative
        # distribution function, kept strictly below the total so the lookup
        # always lands inside the array
        choice = np.random.uniform(high=self.sum, size=n)
        choice = np.minimum(choice, np.nextafter(self.sum, 0.0))
        # first cell whose cumulative value exceeds the draw; side='right'
        # skips zero-density plateaus, so empty cells are never selected
        index = np.searchsorted(self.cdf, choice, side='right')
        # if necessary, map the indices back to their original ordering
        if self.sort:
            index = self.sortindex[index]
        # map back to multi-dimensional indexing
        index = np.vstack(np.unravel_index(index, self.shape))
        # is this a discrete or piecewise continuous distribution?
        if self.interpolation:
            index = index + np.random.uniform(size=index.shape)
        return self.transform(index)
# if __name__ == '__main__':
# x = np.linspace(-100, 100, 10000)
# p = np.exp(-x**2)
# pdf = p[:,None]*p[None,:] #2d gaussian
# dist = Distribution(pdf, transform=lambda i:i - x.size/2)
# print(dist(1000000).mean(axis=1)) #should be in the 1/sqrt(1e6) range
# import matplotlib.pyplot as pp
# pp.hist2d(*dist(10000000),bins=100)
# # pp.scatter(*dist(100000))
# pp.show()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment