# cena_energy_spectra_fit2

''' CENA energy spectra fitting: compare bi-power-law fits done in
log-flux space with fits done in count-rate space.

Only the SV-2 energy range (energy steps 4-11) is used.

Run as:

python %s 2009-04-23T07:38:58 2009-04-23T07:58:59

Options:

-v              Verbose mode.
-o FILENAME     Plot to the given file (no plot is made without it).

.. note::

    110804: The underlying module supports using only SV-2 table data
    (:meth:`CenaDataBank.fit_in_cnt_sv2`). However, exposing this through
    the command-line interface would be time consuming, so use it from
    the IPython command line instead. An example is shown here.

    >>> import dateutil.parser
    >>> import cena_energy_spectra_fit2
    >>> 
    >>> print("### This is for nominal data.")
    >>> t0 = dateutil.parser.parse('2009-02-06 05:13:11.078000')
    >>> t1 = dateutil.parser.parse('2009-02-06 05:32:49.086000')
    >>> bank = cena_energy_spectra_fit2.CenaDataBank(t0, t1)
    >>> pcnt1 = bank.fit_in_cnt()
    >>> print(pcnt1)
    >>> pcnt2 = bank.fit_in_cnt_sv2()
    >>> print(pcnt2)
    >>> 
    >>> print("### This is only non-sv2 data.")
    >>> t0 = dateutil.parser.parse('2009-07-29 06:21:07.265000')
    >>> t1 = dateutil.parser.parse('2009-07-29 06:42:05.303000')
    >>> bank = cena_energy_spectra_fit2.CenaDataBank(t0, t1)
    >>> pcnt1 = bank.fit_in_cnt()
    >>> print(pcnt1)
    >>> pcnt2 = bank.fit_in_cnt_sv2()
    >>> print(pcnt2)
    >>> 
    >>> print("### This is for mix of sv2 and others")
    >>> t0 = dateutil.parser.parse('2009-07-29 08:29:11.415000')
    >>> t1 = dateutil.parser.parse('2009-07-29 08:50:00.121000')
    >>> bank = cena_energy_spectra_fit2.CenaDataBank(t0, t1)
    >>> pcnt1 = bank.fit_in_cnt()
    >>> print(pcnt1)
    >>> pcnt2 = bank.fit_in_cnt_sv2()
    >>> print(pcnt2)

''' % (__file__, )

import matplotlib
matplotlib.use('Agg')

import sys
from optparse import OptionParser
import datetime
import dateutil.parser
import logging
# Install a default root handler so that log records from this script are shown.
logging.basicConfig()
# Module-level logger; its level is raised to DEBUG by the ``-v`` option.
logger = logging.getLogger(__name__)

import numpy as np
import scipy.optimize
import matplotlib.pyplot as plt

import irfpy.cena.cena_mass2
import irfpy.cena.energy
import irfpy.cena.cena_flux

import irfpy.util.bipower

# Matplotlib single-letter colour codes; the plots index this by channel number.
clr = list('bgrcmyk')

class CenaDataBank:
    ''' Cache CENA H-ENA data for a time range and fit its energy spectra.

    For performance and readability, the data access and fitting
    functions are collected here. Observation times, flux and count
    rate arrays are loaded lazily on first access and cached on the
    instance.
    '''
    # Converters between count rate and flux, shared by all instances.
    c2f = irfpy.cena.cena_flux.Count2Flux()
    f2c = irfpy.cena.cena_flux.Flux2Count_Simple()

    # Initial guess [log(k0), r0, log(k1), r1] for the bi-power-law fits.
    _INITIAL_PRMS = (4, 0, 12, -4)

    def __init__(self, t0, t1):
        ''' Create a data bank for the time range between t0 and t1.

        No data is read here; it is loaded on first access.

        :param t0: Start time (datetime).
        :param t1: End time (datetime).
        '''
        self.logger = logging.getLogger(self.__class__.__name__)
        self.logger.setLevel(logging.DEBUG)

        self.t0 = t0
        self.t1 = t1
        self.__tlist = None

        self.__fluxlist = None
        self.__cntlist = None
        self.__cntlist_sv2 = None

    def getobstime(self):
        ''' Return the observation time list (cached after the first call).
        '''
        # ``is None``: the cached value may be an array, for which ``== None``
        # would compare elementwise.
        if self.__tlist is None:
            self.__tlist = irfpy.cena.cena_mass2.getobstime(
                                      timerange=[self.t0, self.t1])
        return self.__tlist

    def get_energy_table(self):
        ''' Return the energy table.

        Just return :meth:`irfpy.cena.energy.getEnergyE16`.
        '''
        return irfpy.cena.energy.getEnergyE16()

    def get_fluxlist(self):
        ''' Return the list of the differential flux of H-ENA.

        :returns: numpy.ma array of H-ENA with shape of 16x7xN, where N is
            the number of observation times. Times whose data could not be
            read (``RuntimeError``) are masked.
        '''
        # A masked array compared with ``== None`` yields an array, whose
        # truth value is ambiguous, so the cache test must use ``is None``.
        if self.__fluxlist is None:
            tlist = self.getobstime()
            fluxlist = np.ma.zeros([16, 7, len(tlist)])

            for ind, t in enumerate(tlist):
                try:
                    d = irfpy.cena.cena_mass2.getHdefluxE16(t).getData()
                except RuntimeError:
                    # Raised if the data is not in the right format or etc.
                    fluxlist[:, :, ind] = float('nan')
                else:
                    fluxlist[:, :, ind] = d[:, :]

            self.__fluxlist = np.ma.masked_invalid(fluxlist)

        return self.__fluxlist

    def get_cntlist(self):
        ''' Return the list of the count rate.

        Counts are integrated over mass channels 0-63.

        :returns: numpy.ma array of H-ENA with shape of 16x7xN. Times whose
            data is invalid (``RuntimeError``) are masked.
        '''
        if self.__cntlist is None:
            tlist = self.getobstime()
            cntlist = np.ma.zeros([16, 7, len(tlist)])

            for ind, t in enumerate(tlist):
                try:
                    d = irfpy.cena.cena_mass2.getdataE16(t).getData()
                except RuntimeError:
                    # Raised if the data is invalid.
                    cntlist[:, :, ind] = float('nan')
                else:
                    # Integrate over H-channel (mass channels 0-63).
                    cntlist[:, :, ind] = d[:, :, 0:64].sum(2)

            self.__cntlist = np.ma.masked_invalid(cntlist)

        return self.__cntlist

    def get_cntlist_sv2(self):
        ''' Return the list of the count rate for SV table 2.

        A time is accepted only if none of the energy steps 4-11 (the range
        used by the fits) is masked. Other times are masked in the result.

        :returns: A tuple of the count list (numpy.ma array, 16x7xN) and the
            list of accepted times.
        '''
        if self.__cntlist_sv2 is None:
            tlist = self.getobstime()
            tlist_valid = []
            cntlist = np.ma.zeros([16, 7, len(tlist)])

            for ind, t in enumerate(tlist):
                try:
                    d = irfpy.cena.cena_mass2.getdataE16(t).getData()
                except RuntimeError:
                    cntlist[:, :, ind] = float('nan')
                    continue

                d = d[:, :, 0:64].sum(2)  # Integ. over mass. Now 16x7
                # Check the same energy steps (4-11) that the fits use.
                if np.ma.is_masked(d[4:12, :]):
                    self.logger.warning(
                        'Time %s (ind=%d) has invalid data than Sv-2 or saturated',
                        t, ind)
                    cntlist[:, :, ind] = float('nan')
                    continue

                cntlist[:, :, ind] = d[:, :]
                tlist_valid.append(t)

            cntlist = np.ma.masked_invalid(cntlist)

            self.__cntlist_sv2 = cntlist, tlist_valid

        return self.__cntlist_sv2

    @classmethod
    def get_fitfunc_in_log(cls):
        ''' Return the bi-power-law fit function in flux space.

        :returns: A function f(p, x) returning a list of flux values, one per
            energy in ``x``, where ``p`` holds the four parameters passed to
            :func:`irfpy.util.bipower.mklnfunc`.
        '''
        def fitfunc(p, x):
            # Build the bi-power-law function once per call, not per energy.
            ff = irfpy.util.bipower.mklnfunc(p[0], p[1], p[2], p[3])
            return [ff(ix) for ix in x]

        return fitfunc

    def _warn_if_not_converged(self, ier, ch):
        ''' Log a warning if leastsq's ``ier`` flag reports no solution. '''
        # scipy.optimize.leastsq signals a found solution with ier in 1..4.
        if ier not in (1, 2, 3, 4):
            self.logger.warning(
                'leastsq found no solution for CH-%d (ier=%s) in %s-%s',
                ch, ier, self.t0, self.t1)

    def fit_in_log(self):
        ''' Fit in the log space using energy range 4-11.

        Return a dictionary which contains data with keys below.

        -   ``2``, ``3`` and ``4``: have 4 parameters each.
            [log(k0), r0, log(k1), r1] for corresponding channels.
            On failure, the parameters are NaN.
        -   ``"ndat2"``, ``"ndat3"`` and ``"ndat4"``: Data length used for fitting.
            Each have 8 element integer.
        -   ``"flux2"``, ``"flux3"`` and ``"flux4"``: Flux used for fitting.
            Each have 8 elements.
        '''
        fluxlist = self.get_fluxlist()  # Flux is 16x7xN data
        aveflux = fluxlist.mean(2)   # Average over the time.
        ndata = fluxlist.count(2)
        # Choose only SV-2
        etbl2 = self.get_energy_table()[4:12]
        flux2 = aveflux[4:12, :]

        # Fit function converts energy to flux in the linear space.
        fitfunc = self.get_fitfunc_in_log()

        def errfunc(p, x, y):
            # Error function is evaluated in the log space.
            return np.log(y) - np.log(fitfunc(p, x))

        prmset = {}

        for ch in (2, 3, 4):  # Central three channels.
            flux = flux2[:, ch]

            try:
                prms, success = scipy.optimize.leastsq(
                    errfunc, self._INITIAL_PRMS, args=(etbl2, flux))
            except Exception:
                self.logger.error('Fail optimization flx %s-%s',
                                  self.t0, self.t1, exc_info=True)
                prms = [float('nan')] * 4
            else:
                self.logger.info('--- CH-%d: (%f, %f) + (%f, %f)',
                                 ch, prms[0], prms[1], prms[2], prms[3])
                self._warn_if_not_converged(success, ch)

            prmset[ch] = prms
            prmset['ndat%01d' % ch] = ndata[4:12, ch]
            prmset['flux%01d' % ch] = flux

        return prmset

    @classmethod
    def get_fitfunc_in_cnt(cls, channel):
        ''' Return a fit function for count rate space.

        :param channel: Direction channel used for the flux-to-count
            conversion.
        :returns: A function to be fitted in count rate space.
            The returned function forms f(p, x) where p is the
            parameter which defines the bi-power law, and x is
            for energy. ``x[i]`` is taken to be the energy of energy
            step ``i + 4``. The result is a numpy array with ``len(x)``
            elements.
        '''

        def fitfunc(p, x):
            # Here p is parameter (4 elems) and x is energy (normally 8 elems).
            f2c = cls.f2c
            ff = irfpy.util.bipower.mklnfunc(p[0], p[1], p[2], p[3])
            cnts = np.zeros([len(x)])
            for index, ix in enumerate(x):
                cnts[index] = f2c.getCount(ff(ix), index + 4, channel)
            return cnts

        return fitfunc

    def _fit_cnt(self, cntlist):
        ''' Fit the time-averaged count rate of energy steps 4-11.

        Shared implementation of :meth:`fit_in_cnt` and :meth:`fit_in_cnt_sv2`.

        :param cntlist: numpy.ma count-rate array with shape 16x7xN.
        :returns: Dictionary with keys ``2``, ``3``, ``4`` (parameters, NaN
            on failure), ``"ndat<ch>"`` and ``"count<ch>"``.
        '''
        etbl2 = self.get_energy_table()[4:12]
        cnt2 = cntlist.mean(2)  # Average over time.
        ndata = cntlist.count(2)

        prmset = {}
        for ich in (2, 3, 4):
            cnt = cnt2[4:12, ich]

            # Fit function is in the count rate space.
            fitfunc = self.get_fitfunc_in_cnt(ich)

            # Error function. ``fitfunc`` is bound as a default argument so
            # the closure does not depend on the loop variable.
            def errfunc(p, x, y, fitfunc=fitfunc):
                return y - fitfunc(p, x)

            try:
                prms2, success = scipy.optimize.leastsq(
                    errfunc, self._INITIAL_PRMS, args=(etbl2, cnt))
            except Exception:
                self.logger.error('Fail optimization cnt %s-%s',
                                  self.t0, self.t1, exc_info=True)
                prms2 = [float('nan')] * 4
            else:
                self.logger.info('--- CH-%d: (%f, %f) + (%f, %f), succ=%d',
                                 ich, prms2[0], prms2[1], prms2[2], prms2[3],
                                 success)
                self._warn_if_not_converged(success, ich)

            prmset[ich] = prms2
            prmset['ndat%01d' % ich] = ndata[4:12, ich]
            prmset['count%01d' % ich] = cnt

        return prmset

    def fit_in_cnt_sv2(self):
        ''' Fit in the count rate space using only data in SV-table 2.

        This is a limited version of :meth:`fit_in_cnt`.
        The :meth:`fit_in_cnt` will return the fitted data
        using energy spectra corresponding to 4-11 if SV table is 1 or 3.
        However, this fit_in_cnt_sv2 will return the fitted data
        using only SV-table 2 data.

        Return a dictionary which contains data with keys below.

        -   ``2``, ``3`` and ``4``: have 4 parameters each.
            [log(k0), r0, log(k1), r1] for corresponding channels.
        -   ``"ndat2"``, ``"ndat3"`` and ``"ndat4"``: Data length used for fitting.
            Each have 8 element integer.
        -   ``"count2"``, ``"count3"`` and ``"count4"``: Counts used for fitting.
            Each have 8 elements.
        -   ``"tlist"`` returns the list of the time used.
        '''
        cntlist, tlist = self.get_cntlist_sv2()
        prmset = self._fit_cnt(cntlist)
        prmset['tlist'] = tlist
        return prmset

    def fit_in_cnt(self):
        ''' Fit in the count rate space using energy range 4-11.

        Return a dictionary which contains data with keys below.

        -   ``2``, ``3`` and ``4``: have 4 parameters each.
            [log(k0), r0, log(k1), r1] for corresponding channels.
        -   ``"ndat2"``, ``"ndat3"`` and ``"ndat4"``: Data length used for fitting.
            Each have 8 element integer.
        -   ``"count2"``, ``"count3"`` and ``"count4"``: Counts used for fitting.
            Each have 8 elements.
        '''
        return self._fit_cnt(self.get_cntlist())


def do_fitting(t0, t1):
    ''' Do fitting.

    Create a :class:`CenaDataBank` for the time range from t0 to t1, and
    call :meth:`CenaDataBank.fit_in_log` and :meth:`CenaDataBank.fit_in_cnt`.

    :param t0: Start time (datetime).
    :param t1: End time (datetime).
    :returns: ``(bank, plog, pcnt)``: the data bank, the flux (log) space
        fit result and the count-rate space fit result.
    '''
    logger.debug('Start = %s', t0)
    logger.debug('End = %s', t1)

    bank = CenaDataBank(t0, t1)

    # Observation time (loaded here, and cached in the bank).
    logger.debug('Length = %d', len(bank.getobstime()))

    # Fitting using data bank.
    plog = bank.fit_in_log()
    pcnt = bank.fit_in_cnt()

    return bank, plog, pcnt

def do_plotting(bank, plog, pcnt, figname=None):
    ''' Plot the fitted spectra of channels 2, 3 and 4 for comparison.

    Four panels are drawn: flux fit in flux space, flux fit in count
    space, count fit in count space and count fit in flux space.

    :param bank: :class:`CenaDataBank` that provided the data.
    :param plog: Result of :meth:`CenaDataBank.fit_in_log`.
    :param pcnt: Result of :meth:`CenaDataBank.fit_in_cnt`.
    :param figname: If not None, the figure is saved to this file and then
        closed to release its memory.
    :returns: The tuple ``(bank, plog, pcnt)`` as given.
    '''
    # Energy table
    etbl = bank.get_energy_table()
    etbl2 = etbl[4:12]
    logger.debug('Energy = %f - %f', etbl[0], etbl[-1])
    etblf = np.logspace(1, 3)  # Dense energy grid (10-1000) for fit curves.

    # Time-averaged flux and count rate, energy steps 4-11.
    flux2 = bank.get_fluxlist().mean(2)[4:12, :]
    cnt2 = bank.get_cntlist().mean(2)[4:12, :]

    # Flux-space bi-power-law function; shared by the panels below.
    flxfitfunc = bank.get_fitfunc_in_log()
    channels = (2, 3, 4)

    fig = plt.figure(figsize=(12, 10))

    ax1 = fig.add_subplot(221)   # Flux fit, flux plot
    ax1_ = fig.add_subplot(222)  # Flux fit, count plot
    ax2 = fig.add_subplot(223)   # Count rate fit, count rate plot
    ax2_ = fig.add_subplot(224)  # Count rate fit, flux plot

    # Flux fitting plotted in flux space
    for ich in channels:
        ax1.plot(etbl2, flux2[:, ich], clr[ich] + 'o', label='CH-%1d' % ich)
        ax1.plot(etblf, flxfitfunc(plog[ich], etblf), clr[ich] + '-')

    ax1.set_xscale('log')
    ax1.set_yscale('log')
    ax1.legend(loc='lower left')
    ax1.set_title('Flux fit, flux plot')

    # Flux fitting plotted in count space
    for ich in channels:
        ax1_.plot(etbl2, cnt2[:, ich], clr[ich] + 'o', label='CH-%1d' % ich)
        flxfit = flxfitfunc(plog[ich], etbl2)
        # Corresponding count rate; energy index 0 is energy step 4.
        cntfit = np.zeros([len(flxfit)])
        for estep, fl in enumerate(flxfit):
            cntfit[estep] = bank.f2c.getCount(fl, estep + 4, ich)
        ax1_.plot(etbl2, cntfit, clr[ich] + '-')

    ax1_.set_xscale('log')
    ax1_.set_title('Flux fit, count plot')

    # Count fitting plotted in count rate
    for ich in channels:
        ax2.plot(etbl2, cnt2[:, ich], clr[ich] + 'o', label='CH-%1d' % ich)
        cntfitfunc = bank.get_fitfunc_in_cnt(ich)
        ax2.plot(etbl2, cntfitfunc(pcnt[ich], etbl2), clr[ich])

    ax2.set_xscale('log')
    ax2.set_title('Count fit, count plot')

    # Count fitting plotted in flux space (same bi-power parameterization).
    for ich in channels:
        ax2_.plot(etbl2, flux2[:, ich], clr[ich] + 'o', label='CH-%1d' % ich)
        ax2_.plot(etblf, flxfitfunc(pcnt[ich], etblf), clr[ich] + '-')

    ax2_.set_xscale('log')
    ax2_.set_yscale('log')
    ax2_.set_title('Count fit, flux plot')

    if figname is not None:
        fig.savefig(figname)
        # pyplot keeps every figure alive until closed; release it.
        plt.close(fig)

    return (bank, plog, pcnt)


def main(t0, t1, figname=None):
    ''' Fit the spectra between t0 and t1, and optionally plot them.

    :param t0: Start time (datetime).
    :param t1: End time (datetime).
    :param figname: If not None, plot the fitting results to this file
        with :func:`do_plotting`.
    :returns: ``(bank, plog, pcnt)`` as returned by :func:`do_fitting`.
    '''
    bank, plog, pcnt = do_fitting(t0, t1)

    if figname is not None:
        do_plotting(bank, plog, pcnt, figname=figname)

    return bank, plog, pcnt


if __name__ == '__main__':
    usage = "USAGE: %prog [options] start_time stop_time"
    parser = OptionParser(usage=usage)

    parser.add_option('-v', action='store_true', dest='verbose', default=False,
                      help='verbose mode')
    parser.add_option('-o', action='store', dest='outfile', default=None,
                      metavar='FILENAME', help='plot to the given file')

    options, args = parser.parse_args()

    if len(args) != 2:
        logger.error('!!!! Illegal number of argument. t0 and t1 should be specified.')
        parser.print_help()
        sys.exit(-5)

    t0 = dateutil.parser.parse(args[0])
    t1 = dateutil.parser.parse(args[1])

    if options.verbose:
        logger.setLevel(logging.DEBUG)

    # Use the module logger: ``-v`` raises only its level, so root-logger
    # debug calls would stay silent.
    logger.debug('Start = %s', t0)
    logger.debug('End   = %s', t1)

    # Kept in a name for interactive inspection (e.g. ``python -i``).
    retvals = main(t0, t1, figname=options.outfile)