Verified Commit c6fd2058 authored by Daniel Williams's avatar Daniel Williams
Browse files

Refactored code to ensure that the method of sources is used to generate a...

Refactored code to ensure that the  method of sources is used to generate a signal, and not the MDC generation method.
parent 5440e3ba
Loading
Loading
Loading
Loading
+29 −20
Original line number Diff line number Diff line
@@ -47,6 +47,29 @@ import matplotlib.pyplot as plt
import re
import random

from minke import sources
sourcemap = {}
for classin in dir(sources):
    classin = sources.__dict__[classin]
    if hasattr(classin, "waveform"):
        sourcemap[classin.waveform] = classin
        
def source_from_row(row):
    waveform = row.waveform
    sourceobj = sourcemap[row.waveform].__new__(sourcemap[row.waveform])
    sourceobj.numrel_data = str("")
    params = {}
    for attr in dir(row):
        if not attr[0] == "_" and not attr[:3] =="get":
            #print attr
            params[attr] = getattr(row, attr)
            setattr(sourceobj, attr, getattr(row, attr))
    sourceobj.params = params
    try:
        sourceobj.time = row.time_geocent_gps
    except:
        pass
    return sourceobj

table_types = {
    # Ad-Hoc
@@ -256,7 +279,9 @@ class MDCSet():
                self.numrel_file = str(sim_burst_table.waveform)
                sim_burst_table.waveform = "Dimmelmeier+08"

            self.waveforms.append(simrow)#_burst_table)
            self.waveforms.append(source_from_row(simrow))
            #self.waveforms.append(simrow)#_burst_table)
            #self.waveforms.
            if full:
                self._measure_hrss(i)
                self._measure_egw_rsq(i)
@@ -292,27 +317,10 @@ class MDCSet():
        hx0 : 
            A copy of the strain in the x polarisation
        """
        # This is a temporary kludge to allow LALSimulation to
        # be bypassed for pre-calculated waveforms.
        # A more robust solution should be considered.
        exceptions = ["Ott+13", "Mueller+12", "Scheidegger+10", "ADI"]
        row = self.waveforms[row]
        swig_row = lalburst.CreateSimBurst()
        for a in lsctables.SimBurstTable.validcolumns.keys():
            try:
                setattr(swig_row, a, getattr( row, a ))
            except AttributeError: continue # we didn't define it
            except TypeError: 
                #print a, getattr(row,a)
                continue # the structure is different than the TableRow
        theta, phi = np.cos(swig_row.incl), swig_row.phi
        swig_row.numrel_data = str(row.numrel_data)
        
        hp, hx = lalburst.GenerateSimBurst(swig_row, 1.0/rate)
        hp0, hx0 = lalburst.GenerateSimBurst(swig_row, 1.0/rate)
        hp, hx, hp0, hx0 = row._generate()
        return hp, hx, hp0, hx0
    
    
    def _getDetector(self, det):
        """
        A method to return a LALDetector object corresponding to a detector's
@@ -434,7 +442,8 @@ class MDCSet():
        hphx : float
            The hrss of |HpHx| 
        """
        hp, hx, hp0, hx0 = self._generate_burst(row)# self.hp, self.hx, self.hp0, self.hx0
        row = self.waveforms[row]
        hp, hx, hp0, hx0 = row._generate() #self._generate_burst(row)# self.hp, self.hx, self.hp0, self.hx0

        hp0.data.data *= 0
        hx0.data.data *= 0
+94 −45
Original line number Diff line number Diff line
@@ -36,6 +36,9 @@ except ImportError:

import matplotlib.pyplot as plt




class Waveform(object):
    """
    Generic container for different source types. 
@@ -43,8 +46,8 @@ class Waveform(object):
    In the future, different sources should subclass this and override the generation routines.
    """
    
    sim = lsctables.New(lsctables.SimBurstTable)
    table_type = lsctables.SimBurstTable
    sim = lsctables.New(table_type)
    
    numrel_data = []
    waveform = "Generic"
@@ -52,7 +55,7 @@ class Waveform(object):

    def _clear_params(self):
        self.params = {}
        for a in lsctables.SimBurstTable.validcolumns.keys():
        for a in self.table_type.validcolumns.keys():
            self.params[a] = None
        

@@ -126,52 +129,39 @@ class Waveform(object):
           hx0 : 
               A copy of the strain in the x polarisation 
        """
        row = self._row() 
        swig_row = lalburst.CreateSimBurst() 
        for a in lsctables.SimBurstTable.validcolumns.keys(): 
            try:
                setattr(swig_row, a, getattr( row, a )) 
            except AttributeError: 
                continue 
            except TypeError: 
                continue 
            try:
                swig_row.numrel_data = row.numrel_data 
            except: pass
        burstobj = self._burstobj()
                
        hp, hx = lalburst.GenerateSimBurst(swig_row, 1.0/rate) 
        # FIXME: Totally inefficent --- but can we deep copy a SWIG SimBurst?  
        # DW: I tried that, and it doesn't seem to work :/
        hp, hx = lalburst.GenerateSimBurst(burstobj, 1.0/rate) 
        if not half :
            hp0, hx0 = lalburst.GenerateSimBurst(swig_row, 1.0/rate) 
            hp0, hx0 = lalburst.GenerateSimBurst(burstobj, 1.0/rate) 
        else: 
            hp0, hx0 = hp, hx
        
        # detrend supernova waveforms
        if hasattr(self, "supernova"):
            hp.data.data, hx.data.data, hp0.data.data, hx0.data.data = scipy.signal.detrend(hp.data.data), scipy.signal.detrend(hx.data.data), scipy.signal.detrend(hp0.data.data), scipy.signal.detrend(hx0.data.data)
            # Rescale for a given distance 
        if row.amplitude and hasattr(self, "supernova"): 
            rescale = 1.0 / (self.file_distance / row.amplitude)
            hp.data.data, hx.data.data, hp0.data.data, hx0.data.data = hp.data.data * rescale, hx.data.data * rescale, hp0.data.data * rescale, hx0.data.data * rescale
        return hp, hx, hp0, hx0 

            if hasattr(self, "has_memory"):
                # Apply the tail correction for memory
                tail_hp = self.generate_tail(length = 1, h_max = hp.data.data[-1])
                tail_hx = self.generate_tail(length = 1, h_max = hx.data.data[-1])
    def _burstobj(self):
        """
        Generate a SimBurst object for this waveform.
        """
        swig_row = self._row()
        burstobj = lalburst.CreateSimBurst()
        
                hp_data = np.append(hp.data.data,tail_hp.data)
                hx_data = np.append(hp.data.data,tail_hx.data)
        for a in self.table_type.validcolumns.keys():
            try:
                setattr(burstobj, a, getattr(swig_row,a))
            except AttributeError:
                continue
            except TypeError: 
                continue

                tail_hp = lal.CreateREAL8Vector(len(hp_data))
                tail_hp.data = hp_data
                tail_hx = lal.CreateREAL8Vector(len(hx_data))
                tail_hx.data = hx_data
        burstobj.waveform = str(self.waveform)
            
                hp.data = tail_hp
                hx.data = tail_hx
        if swig_row.numrel_data:
            burstobj.numrel_data = str(swig_row.numrel_data)
        else:
            burstobj.numrel_data = str("")

        return hp, hx, hp0, hx0 
        return burstobj
    
    def _generate_for_detector(self, ifos, sample_rate = 16384.0, nsamp = 2000):
        data = []
@@ -209,10 +199,16 @@ class Waveform(object):
        for a in self.table_type.validcolumns.keys():
            setattr(row, a, self.params[a])

        if self.numrel_data:
            row.numrel_data = str(self.numrel_data)
        else:
            row.numrel_data = ""
            
        row.waveform = self.waveform
        # Fill in the time
        row.set_time_geocent(GPS(float(self.time)))
        # Get the sky locations
        if not row.ra:
            row.ra, row.dec, row.psi = self.sky_dist()
        row.simulation_id = sim.get_next_id()
        row.waveform_number = random.randint(0,int(2**32)-1)
@@ -459,6 +455,59 @@ class Supernova(Waveform):

        return Hlm

    def _generate(self, rate=16384.0, half=False, distance=None): 
        """
        Generate the burst described in a given row, so that it can be
        measured.
        
        Parameters 
        ---------- 
        rate : float 
           The sampling rate of the signal, in Hz. 
           Defaults to 16384.0Hz
            
        half : bool 
           Only compute the hp and hx once if this is true;
           these are only required if you need to compute the cross
           products. Defaults to False.

        Returns 
        ------- 
           hp : 
              The strain in the + polarisation 
           hx : 
              The strain in the x polarisation
           hp0 : 
              A copy of the strain in the + polarisation 
           hx0 : 
               A copy of the strain in the x polarisation 
        """
        burstobj = self._burstobj()
                
        hp, hx = lalburst.GenerateSimBurst(burstobj, 1.0/rate) 
        if not half :
            hp0, hx0 = lalburst.GenerateSimBurst(burstobj, 1.0/rate) 
        else: 
            hp0, hx0 = hp, hx
        
        # detrend supernova waveforms
        if hasattr(self, "supernova"):
            hp.data.data, hx.data.data, hp0.data.data, hx0.data.data = scipy.signal.detrend(hp.data.data), scipy.signal.detrend(hx.data.data), scipy.signal.detrend(hp0.data.data), scipy.signal.detrend(hx0.data.data)
            # Rescale for a given distance 
        if burstobj.amplitude: 
            rescale = 1.0 / (self.file_distance / burstobj.amplitude)
            hp.data.data, hx.data.data, hp0.data.data, hx0.data.data = hp.data.data * rescale, hx.data.data * rescale, hp0.data.data * rescale, hx0.data.data * rescale

            if hasattr(self, "has_memory"):
                # Apply the tail correction for memory
                tail_hp = self.generate_tail(length = 1, h_max = hp.data.data[-1])
                tail_hx = self.generate_tail(length = 1, h_max = hx.data.data[-1])

                hp.data.data = np.append(hp.data.data(tail_hp))
                hx.data.data = np.append(hp.data.data(tail_hx))
        
        return hp, hx, hp0, hx0 
    
    def generate_tail(self, sampling=16384.0, length = 1, h_max = 1e-23):
        """Generate a "low frequency tail" to append to the end of the
        waveform to overcome problems related to memory in the
+43 −6
Original line number Diff line number Diff line
@@ -12,6 +12,7 @@ import unittest

import minke
from minke import mdctools, sources
import numpy as np

class TestMinke(unittest.TestCase):

@@ -24,8 +25,7 @@ class TestMinke(unittest.TestCase):
    # Check that failure occurs if the wrong waveforms are inserted into the wrong tables

    def test_insert_wrong_waveform_to_table(self):
        """Test whether adding a ringdown waveform to a SimBurstTable throws
        an error.
        """Test whether adding a ringdown waveform to a SimBurstTable throws an error.
        """
        mdcset = mdctools.MDCSet(["L1"], table_type = "burst")
        ring = sources.Ringdown()
@@ -34,8 +34,7 @@ class TestMinke(unittest.TestCase):
            mdcset + ring

    def test_insert_burst_waveform_to_burst_table(self):
        """
        Test whether inserting the correct type of waveform into a table works
        """Test whether inserting the correct type of waveform into a table works
        """
        mdcset = mdctools.MDCSet(["L1"], table_type = "burst")
        waveform = sources.Gaussian(0,0,0)
@@ -47,8 +46,7 @@ class TestMinke(unittest.TestCase):
    # Check that the correct XML files are produced for different table types

    def test_write_simbursttable(self):
        """
        Write out a simburst table xml file
        """Test writing out a simburst table xml file
        """
        mdcset = mdctools.MDCSet(["L1"], table_type = "burst")
        waveform = sources.Gaussian(0.1,1e-23,1000)
@@ -87,6 +85,45 @@ class TestMinke(unittest.TestCase):
        self.assertEqual(len(mdcset.waveforms), 1)


class TestMDC(unittest.TestCase):

    def test_Gaussian_Waveform_generation_in_MDC(self):
        """
        Check that waveforms inside an MDC are generated.
        """
        ga = sources.Gaussian(1,1,1)

        mdcset = mdctools.MDCSet(["L1"])
        mdcset + ga
        
        data = mdcset._generate_burst(0)

        gadata = np.array([  3.22991378e-57,   1.68441642e-25,   1.43167033e-23,
         6.20128276e-22,   1.92272388e-20,   4.74643578e-19,
         9.78126889e-18,   1.72560021e-16,   2.64547116e-15,
         3.55831241e-14,   4.22637998e-13,   4.45285513e-12,
         4.17503728e-11,   3.49184250e-10,   2.60952919e-09,
         1.74465929e-08,   1.04437791e-07,   5.60038118e-07,
         2.70506254e-06,   1.18998283e-05,   4.76932871e-05,
         1.74151408e-04,   5.79361856e-04,   1.75600595e-03,
         4.84903412e-03,   1.21993775e-02,   2.79623250e-02,
         5.83931709e-02,   1.11097431e-01,   1.92574662e-01,
         3.04121728e-01,   4.37571401e-01,   5.73592657e-01,
         6.85032870e-01,   7.45370866e-01,   7.38901572e-01,
         6.67350426e-01,   5.49129112e-01,   4.11669004e-01,
         2.81173926e-01,   1.74966576e-01,   9.91946657e-02,
         5.12359410e-02,   2.41109482e-02,   1.03372980e-02,
         4.03787583e-03,   1.43698441e-03,   4.65912462e-04,
         1.37628942e-04,   3.70397795e-05,   9.08197260e-06,
         2.02882766e-06,   4.12393041e-07,   7.53383218e-08,
         1.23283052e-08,   1.80606751e-09,   2.36657746e-10,
         2.77012812e-11,   2.89124066e-12,   2.68404945e-13,
         2.20863500e-14,   1.60321302e-15,   1.01948770e-16,
         5.62074344e-18,   2.64306001e-19,   1.03069500e-20,
         3.15739881e-22,   6.68755767e-24,   6.21121403e-26])

        np.testing.assert_array_almost_equal(data[0].data.data[::5000], gadata)
        
        
if __name__ == '__main__':
    import sys
+93 −1
Original line number Diff line number Diff line
@@ -37,7 +37,99 @@ def download_nr(url):

    return fname.strip(".gz")

class TestMinkeSources(unittest.TestCase):
class TestMinkeAdHocSources(unittest.TestCase):
    """
    Tests for the adhoc analytical waveforms
    """

    def setUp(self):
        """
        Set everything up to make things work.
        """
        self.mdcset = mdctools.MDCSet(['L1', 'H1'])
        self.times = distribution.even_time(start = 1126620016, stop = 1136995216, rate = 630720, jitter = 20)
        self.angles = distribution.supernova_angle(len(self.times))

    def test_Gaussian_Waveform_generation(self):
        """Test that Gaussian waveforms are generated sensibly"""
        ga = sources.Gaussian(1,1,1)
        data = ga._generate()

        gadata = np.array([  3.22991378e-57,   1.68441642e-25,   1.43167033e-23,
         6.20128276e-22,   1.92272388e-20,   4.74643578e-19,
         9.78126889e-18,   1.72560021e-16,   2.64547116e-15,
         3.55831241e-14,   4.22637998e-13,   4.45285513e-12,
         4.17503728e-11,   3.49184250e-10,   2.60952919e-09,
         1.74465929e-08,   1.04437791e-07,   5.60038118e-07,
         2.70506254e-06,   1.18998283e-05,   4.76932871e-05,
         1.74151408e-04,   5.79361856e-04,   1.75600595e-03,
         4.84903412e-03,   1.21993775e-02,   2.79623250e-02,
         5.83931709e-02,   1.11097431e-01,   1.92574662e-01,
         3.04121728e-01,   4.37571401e-01,   5.73592657e-01,
         6.85032870e-01,   7.45370866e-01,   7.38901572e-01,
         6.67350426e-01,   5.49129112e-01,   4.11669004e-01,
         2.81173926e-01,   1.74966576e-01,   9.91946657e-02,
         5.12359410e-02,   2.41109482e-02,   1.03372980e-02,
         4.03787583e-03,   1.43698441e-03,   4.65912462e-04,
         1.37628942e-04,   3.70397795e-05,   9.08197260e-06,
         2.02882766e-06,   4.12393041e-07,   7.53383218e-08,
         1.23283052e-08,   1.80606751e-09,   2.36657746e-10,
         2.77012812e-11,   2.89124066e-12,   2.68404945e-13,
         2.20863500e-14,   1.60321302e-15,   1.01948770e-16,
         5.62074344e-18,   2.64306001e-19,   1.03069500e-20,
         3.15739881e-22,   6.68755767e-24,   6.21121403e-26])

        np.testing.assert_array_almost_equal(data[0].data.data[::5000], gadata)
        
    def test_SG_Waveform_generation(self):
        """
        Regression test for SineGaussian Waveforms.
        """
        
        sg = sources.SineGaussian(10,1,1,"linear",1, seed = 0)
        data = sg._generate()

        sgdata = np.array([ -8.73180626e-58,  -1.93967534e-26,   5.31448156e-25,
         2.85625032e-24,  -1.16897801e-22,   4.54221716e-22,
         7.30222266e-21,  -7.05284947e-20,  -7.18443386e-20,
         3.83540555e-18,  -1.35146997e-17,  -9.19982707e-17,
         8.18155341e-16,  -1.75844610e-16,  -2.20122350e-14,
         7.73708586e-14,   2.54778729e-13,  -2.36261229e-12,
         2.01404378e-12,   3.49131470e-11,  -1.23199563e-10,
        -1.89419681e-10,   2.06578175e-09,  -2.48373900e-09,
        -1.71767992e-08,   6.11421635e-08,   3.47780769e-08,
        -5.83248940e-07,   8.10839193e-07,   2.72782093e-06,
        -1.00820213e-05,  -1.30630395e-07,   5.80605496e-05,
        -8.87448223e-05,  -1.55818895e-04,   6.31437341e-04,
        -2.04866297e-04,  -2.23611311e-03,   3.62241631e-03,
         3.19299434e-03,  -1.53446556e-02,   8.16204089e-03,
         3.31744792e-02,  -5.62402740e-02,  -2.09500348e-02,
         1.44826566e-01,  -9.53929854e-02,  -1.87967432e-01,
         3.35588750e-01,   2.19275523e-02,  -5.30630637e-01,
         3.90608881e-01,   4.00367320e-01,  -7.74119199e-01,
         9.53218059e-02,   7.53180228e-01,  -5.91727767e-01,
        -3.10327911e-01,   6.92512672e-01,  -1.65601830e-01,
        -4.12458030e-01,   3.39597507e-01,   8.05475578e-02,
        -2.40582205e-01,   7.46124228e-02,   8.65034217e-02,
        -7.47429394e-02,  -4.86389437e-03,   3.24540997e-02,
        -1.14700762e-02,  -6.85576126e-03,   6.35159480e-03,
        -2.55523511e-04,  -1.69725819e-03,   6.45883694e-04,
         1.99966867e-04,  -2.09187908e-04,   2.19763186e-05,
         3.42911180e-05,  -1.37134844e-05,  -2.01577460e-06,
         2.67493920e-06,  -3.87356579e-07,  -2.65313667e-07,
         1.10113354e-07,   5.46563207e-09,  -1.26429809e-08,
         2.07544399e-09,   6.97707443e-10,  -2.97465468e-10,
         2.38449797e-12,   1.95277082e-11,  -3.35619371e-12,
        -5.64458960e-13,   2.55522574e-13,  -1.03406456e-14,
        -9.34634494e-15,   1.60335007e-15,   1.24686211e-16,
        -6.41126505e-17,   3.59252308e-18,   1.20462046e-18,
        -1.94100550e-19,  -5.20175883e-21,   3.52788114e-21,
        -1.92473065e-22,  -2.39883058e-23,   2.65759767e-24,
         2.11709555e-27,  -2.28657022e-27])

        np.testing.assert_array_almost_equal(data[0].data.data[::5000], sgdata)

class TestMinkeSupernovaSources(unittest.TestCase):
    def setUp(self):
        """
        Set things up for the tests by making the MDC set, and defining the various parameter distributions.