# -------------------------------------------------------------------------
#     This file is part of mMass - the spectrum analysis tool for MS.
#     Copyright (C) 2005-07 Martin Strohalm <mmass@biographics.cz>

#     This program is free software; you can redistribute it and/or modify
#     it under the terms of the GNU General Public License as published by
#     the Free Software Foundation; either version 2 of the License, or
#     (at your option) any later version.

#     This program is distributed in the hope that it will be useful,
#     but WITHOUT ANY WARRANTY; without even the implied warranty of
#     MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#     GNU General Public License for more details.

#     Complete text of GNU GPL can be found in the file LICENSE in the
#     main directory of the program
# -------------------------------------------------------------------------

# Function: Load and parse data from mzXML format.

# load libs
import wx
import os.path
import xml.dom.minidom
import base64
import struct

# load modules
from dlg_select_msscan import dlgSelectScan


class mzXMLDoc:
    """ Get and format data from mzXML document.

    Parsed values are collected in the 'data' dictionary. Loading methods
    return that dictionary on success, False on error and None when the
    user cancels the scan selection dialog.
    """

    # placeholder used for scan values not available in the document
    _noValue = '---'

    # ----
    def __init__(self, parent):
        """ Initialise empty document data.

        parent: object passed as parent to the dialogs raised by this class.
        """

        self.parent = parent
        self.data = {
                    'docType':'mzXML',
                    'scanID':'',
                    'date':'',
                    'operator':'',
                    'institution':'',
                    'contact':'',
                    'instrument':'',
                    'notes':'',
                    'peaklist':[],
                    'spectrum':[]
                    }

        # name of the element requested by getElement(); while it is not
        # set, handleSpectrum() asks the user how to import the data
        self.elmName = None
    # ----


    # ----
    def getDocument(self, path):
        """ Read and parse all data from document.

        path: file name (or file object) handed to xml.dom.minidom.parse().
        Returns the data dictionary, False if the document cannot be parsed
        or its scan data are unusable, None if no scan was selected.
        """

        # parse XML
        try:
            document = xml.dom.minidom.parse(path)
        except Exception:
            return False

        # no msRun tag found
        msRun = document.getElementsByTagName('msRun')
        if not msRun:
            return False

        # get spectrum
        status = self.handleSpectrumList(msRun[0])

        # error in data
        if status is False:
            return False

        # no spectrum selected
        elif status is None:
            return None

        # get description
        msInstrument = document.getElementsByTagName('msInstrument')
        if msInstrument:
            self.handleDescription(msInstrument[0])

        return self.data
    # ----


    # ----
    def getElement(self, name, path):
        """ Read and parse selected elements' data from document.

        name: 'description', 'spectrum' or 'peaklist'.
        path: file name (or file object) handed to xml.dom.minidom.parse().
        Returns the data dictionary (unchanged if the document has no
        matching element), False on error or unknown name, None if no scan
        was selected.
        """

        # select source tag; unknown names are reported as an error
        # instead of failing later on an undefined variable
        if name == 'description':
            tagName = 'msInstrument'
        elif name in ('spectrum', 'peaklist'):
            tagName = 'msRun'
        else:
            return False

        self.elmName = name

        # parse XML
        try:
            document = xml.dom.minidom.parse(path)
        except Exception:
            return False

        # get data
        element = document.getElementsByTagName(tagName)
        if not element:
            return self.data

        # get description
        if name == 'description':
            if self.handleDescription(element[0]) is False:
                return False

        # get spectrum
        else:
            status = self.handleSpectrumList(element[0])

            # error in data
            if status is False:
                return False

            # no spectrum selected
            elif status is None:
                return None

        return self.data
    # ----


    # ----
    def handleDescription(self, elements):
        """ Get document description from <msInstrument> element.

        Fills 'instrument', 'operator' and 'contact' in the data dictionary.
        Always returns True.
        """

        # instrument is described by 'value' attributes of following tags
        instrument = []
        for tagName in ('msManufacturer', 'msModel', 'msIonisation', 'msMassAnalyzer'):
            found = elements.getElementsByTagName(tagName)
            if found:
                value = found[0].getAttribute('value')
                if value:
                    instrument.append(value)

        # separate the parts by space to keep them readable
        if instrument:
            self.data['instrument'] = ' '.join(instrument)

        # operator
        operator = elements.getElementsByTagName('operator')
        if operator:
            self.data['operator'] = self._joinAttributes(operator[0], ('first', 'last'))
            self.data['contact'] = self._joinAttributes(operator[0], ('phone', 'email', 'URI'))

        return True
    # ----


    # ----
    def _joinAttributes(self, element, names):
        """ Join non-empty values of selected attributes by single space. """

        values = [element.getAttribute(name) for name in names]
        return ' '.join([value for value in values if value])
    # ----


    # ----
    def handleSpectrumList(self, elements):
        """ Get scan data from <msRun> element.

        A single scan is loaded directly, otherwise the user selects one
        in the scan selection dialog.
        Returns True if a scan was loaded, False on missing or corrupted
        data, None if the dialog was cancelled.
        """

        # get all spectra
        spectra = elements.getElementsByTagName('scan')
        if not spectra:
            return False

        # get one spectrum
        if len(spectra) == 1:
            if not self.handleSpectrum(spectra[0]):
                return False
            return True

        # get spectrum from list
        scans = self.getScans(spectra)
        dlg = dlgSelectScan(self.parent, scans)
        try:
            if dlg.ShowModal() != wx.ID_OK:
                return None
            scanID = dlg.selectedScan
        finally:
            dlg.Destroy()

        # get data
        for scan in spectra:
            if scan.getAttribute('num') == scanID:
                if not self.handleSpectrum(scan):
                    return False
                self.data['scanID'] = scanID
                return True

        # selected ID does not match any scan - nothing was loaded
        return False
    # ----


    # ----
    def handleSpectrum(self, spectrum):
        """ Get spectrum data from <scan> element.

        Decodes base64 content of the first <peaks> tag to m/z-intensity
        pairs, stores them as 'spectrum' or 'peaklist' and appends scan
        info to 'notes'.
        Returns True on success, False if the data are missing or corrupted.
        """

        # get data element
        peaks = spectrum.getElementsByTagName('peaks')
        if not peaks:
            return False

        # get endian or use default(!)
        byteOrders = {'network':'!', 'little':'<', 'big':'>'}
        endian = byteOrders.get(peaks[0].getAttribute('byteOrder'), '!')

        # get precision or use default (32-bit floats)
        if peaks[0].getAttribute('precision') == '64':
            typeCode = 'd'
        else:
            typeCode = 'f'

        # get raw data
        data = self.getText(peaks[0].childNodes)

        # decode data
        try:
            data = base64.b64decode(data)
        except (TypeError, ValueError):
            return False

        # check data length - whole number of values is expected
        itemSize = struct.calcsize(endian + typeCode)
        if not data or len(data) % itemSize:
            return False

        # convert from binary format (floor division keeps the count integer)
        pointsCount = len(data) // itemSize
        try:
            data = struct.unpack(endian + typeCode*pointsCount, data)
        except struct.error:
            return False

        # split data to m/z and intensity
        mzData = data[::2]
        intData = data[1::2]

        # check data
        if not mzData or not intData or (len(mzData) != len(intData)):
            return False

        # "zip" mzData and intData to a real list
        formatedData = list(zip(mzData, intData))

        # set data as spectrum or peaklist
        if not self.elmName:
            dlg = wx.MessageDialog(self.parent, "Import data as continuous spectrum?\n(Press 'No' to import as descrete peaklist points.)", "Import as spectrum?", wx.YES_NO|wx.YES_DEFAULT|wx.ICON_QUESTION)
            button = dlg.ShowModal()
            dlg.Destroy()
            if button == wx.ID_YES:
                self.data['spectrum'] = formatedData
            else:
                self.data['peaklist'] = self.convertSpectrumToPeaklist(formatedData)
        elif self.elmName == 'spectrum':
            self.data['spectrum'] = formatedData
        elif self.elmName == 'peaklist':
            self.data['peaklist'] = self.convertSpectrumToPeaklist(formatedData)

        # get precursor info for MS/MS data
        scanInfo = self.getScanInfo(spectrum)
        if scanInfo['time'] != self._noValue:
            self.data['notes'] += '\n-----\nTime: %s' % (scanInfo['time'])
        if scanInfo['level'] not in (self._noValue, '1'):
            self.data['notes'] += '\nMS Level: %s' % (scanInfo['level'])
            self.data['notes'] += '\nPrecursor Mass: %s' % (scanInfo['mz'])
            self.data['notes'] += '\nPrecursor Charge: %s' % (scanInfo['charge'])
            self.data['notes'] += '\nPrecursor Polarity: %s' % (scanInfo['polarity'])

        return True
    # ----


    # ----
    def getScans(self, spectra):
        """ Get basic info about all the ms scans.

        Returns one list per scan: ID, time, range, points, MS level,
        prec. mass, prec. charge, spec. type.
        """

        # get list of scans
        scans = []
        for scan in spectra:
            scanInfo = self.getScanInfo(scan)
            scans.append([
                scanInfo['id'],
                scanInfo['time'],
                scanInfo['range'],
                scanInfo['points'],
                scanInfo['level'],
                scanInfo['mz'],
                scanInfo['charge'],
                scanInfo['type']
                ])

        return scans
    # ----


    # ----
    def getScanInfo(self, scan):
        """ Get basic info about selected scan.

        Returns dictionary of strings read from <scan> attributes. Values
        not present in the document are set to '---', except the 'id',
        which is always the raw 'num' attribute.
        """

        scanInfo = {}
        for key in ('type', 'level', 'range', 'points', 'polarity', 'time', 'mz', 'charge', 'method'):
            scanInfo[key] = self._noValue

        # get ID
        scanInfo['id'] = scan.getAttribute('num')

        # get msLevel, number of points, polarity and retention time
        # (missing attribute is read as empty string - keep placeholder)
        attributes = (
            ('level', 'msLevel'),
            ('points', 'peaksCount'),
            ('polarity', 'polarity'),
            ('time', 'retentionTime')
            )
        for key, attribute in attributes:
            value = scan.getAttribute(attribute)
            if value:
                scanInfo[key] = value

        # get range
        lowMz = scan.getAttribute('lowMz')
        highMz = scan.getAttribute('highMz')
        if lowMz or highMz:
            try:
                scanInfo['range'] = '%d - %d' % (float(lowMz), float(highMz))
            except (ValueError, OverflowError):
                scanInfo['range'] = '%s - %s' % (lowMz, highMz)

        # find precursor params
        if scanInfo['level'] not in (self._noValue, '1'):
            precursorMz = scan.getElementsByTagName('precursorMz')
            if precursorMz:

                # get m/z
                mz = self.getText(precursorMz[0].childNodes)
                if mz:
                    scanInfo['mz'] = mz

                # get charge
                charge = precursorMz[0].getAttribute('precursorCharge')
                if charge:
                    scanInfo['charge'] = charge

        return scanInfo
    # ----


    # ----
    def getText(self, nodelist):
        """ Get text from node list.

        Returns joined data of all the text nodes, other nodes are skipped.
        """

        return ''.join([node.data for node in nodelist if node.nodeType == node.TEXT_NODE])
    # ----


    # ----
    def convertSpectrumToPeaklist(self, spectrum):
        """ Convert spectrum to peaklist.

        Each spectrum point is converted to the list [m/z, intensity, '', 0].
        """

        # NOTE(review): the last two items are presumably defaults expected
        # by the rest of the application - confirm against peaklist users
        return [[point[0], point[1], '', 0] for point in spectrum]
    # ----
