#
#  Copyright (C) 2016 STFC Rutherford Appleton Laboratory, UK.
#
#  Author: David Waterman
#  Acknowledgements: based on code by Graeme Winter and Martin Noble.
#

from CCP4PluginScript import CPluginScript
from CCP4ErrorHandling import *
import os, glob, shutil
import CCP4Utils
from lxml import etree
import CCP4Container
import platform
import json

class Cxia2_dials(CPluginScript):
    """Plugin script that runs xia2 with the DIALS pipeline and harvests
    its results.

    makeCommandAndScript writes the set control parameters to a PHIL file
    and builds the xia2 command line. processOutputFiles then packs
    xia2.txt / xia2.error / xia2.json and the Pointless, Aimless and
    truncate XML into the program XML, copies the integrated MTZ files and
    splits each merged '*free.mtz' file into observation and FreeR files.
    """

    TASKTITLE = 'Data processing with xia2/dials'
    TASKNAME = 'xia2_dials'
    # The xia2 launcher is a batch file on Windows
    TASKCOMMAND = 'xia2.bat' if platform.system() == 'Windows' else 'xia2'
    TASKMODULE = 'data_processing'
    TASKVERSION = 0.0
    ERROR_CODES = {
        200:{'description':'Failed harvesting integrated data'},
        201:{'description':'Failed scaled data'},
        202:{'description':'Failed harvesting pointless XML'},
        203:{'description':'Failed harvesting aimless xml'},
        204:{'description':'Failed harvesting truncate xml'}}
    PERFORMANCECLASS = 'CDataReductionPerformance'
    ASYNCHRONOUS = True
    WHATNEXT = ['phaser_pipeline','molrep_mr','crank2','ShelxCD']
    MAINTAINER = 'ccp4@stfc.ac.uk'

    def _work_path(self, *parts):
        """Return the normalised path of *parts* joined onto the job's
        work directory."""
        return os.path.normpath(os.path.join(self.getWorkDirectory(), *parts))

    def extract_parameters(self, container):
        """Walk through a container locating parameters that have been set
        and return a list of (name, value) pairs.

        Nested containers are descended into recursively. Each name is the
        object name with '__' replaced by '.', and each value is the
        parameter's string form re-joined as a whitespace-separated list.
        """
        result = []
        for attr_name in container.dataOrder():
            model = getattr(container, attr_name)
            if isinstance(model, CCP4Container.CContainer):
                result.extend(self.extract_parameters(model))
            elif model.isSet():
                name = model.objectName().replace('__', '.')
                # Ensure comma-separated lists become whitespace-separated
                # lists by dropping a trailing comma from each token. Only
                # whitespace appears to work correctly with PHIL multiple
                # choice definitions.
                tokens = str(model.get()).split()
                val = ' '.join(t[:-1] if t.endswith(',') else t
                               for t in tokens)
                result.append((name, val))
        return result

    def makeCommandAndScript(self):
        """Build the xia2 command line and start watching xia2.txt.

        Writes every set control parameter to 'xia2_dials.phil' in the work
        directory, passes that file and the image directory to xia2,
        creates the root of the program XML and registers
        handleXia2DotTxtChanged as the watcher for xia2.txt.

        Returns CPluginScript.SUCCEEDED.
        """
        par = self.container.controlParameters
        inp = self.container.inputData

        # Set xia2 switches
        self.appendCommandLine(['pipeline=dials'])

        # PHIL parameters set by the gui
        phil_file = self._work_path('xia2_dials.phil')
        with open(phil_file, 'w') as f:
            for (name, val) in self.extract_parameters(par):
                f.write(name + '={0}\n'.format(val))
        self.appendCommandLine([phil_file])

        # Finally the data location
        self.appendCommandLine([str(inp.IMAGE_DIRECTORY)])

        self.xmlroot = etree.Element('Xia2Dials')

        self.watchFile(self._work_path('xia2.txt'),
                       self.handleXia2DotTxtChanged)

        return CPluginScript.SUCCEEDED

    @staticmethod
    def _get_annotation(prefix, suffix):
        '''Form suitable annotation strings'''
        return prefix + " from DIALS integration of " + suffix

    @staticmethod
    def _extract_wavelength_names(json_txt):
        '''Get a list of wavelength names from a xia2.json string.

        Raises ValueError unless the JSON describes exactly one crystal.
        '''
        crystals = json.loads(json_txt)['_crystals']
        # We don't cope with > 1 crystal here. Raise explicitly rather than
        # assert, so the check survives running Python with -O.
        if len(crystals) != 1:
            raise ValueError('Expected exactly one crystal in xia2.json, '
                             'found {0}'.format(len(crystals)))
        crystal = next(iter(crystals.values()))
        # extract names, converting from unicode
        return [str(name) for name in crystal['_wavelengths']]

    def _embed_text_file(self, filename, tag):
        """Pack a work-directory text file into the program XML as CDATA.

        Returns the file's text, or None (adding nothing to the XML) when
        the file does not exist.
        """
        path = self._work_path(filename)
        if not os.path.isfile(path):
            return None
        with open(path, 'r') as handle:
            text = handle.read()
        element = etree.SubElement(self.xmlroot, tag)
        element.text = etree.CDATA(text)
        return text

    def _harvest_integrated(self):
        """Copy each DataFiles/*INTEGRATE.mtz into the work directory and
        register the copy in the UNMERGEDOUT output list."""
        unmergedOut = self.container.outputData.UNMERGEDOUT
        pattern = self._work_path('DataFiles', '*INTEGRATE.mtz')
        for srcPath in glob.glob(pattern):
            srcFilename = os.path.basename(srcPath)
            destPath = self._work_path(srcFilename)
            shutil.copyfile(srcPath, destPath)
            unmergedOut.append(unmergedOut.makeItem())
            unmergedOut[-1].fullPath = destPath
            # [:-13] drops the trailing 'INTEGRATE.mtz'
            unmergedOut[-1].annotation = self._get_annotation(
                'Unmerged reflections', srcFilename[:-13])

    def _harvest_merged(self, srcPath, wavelength_names, anomalous):
        """Split one merged '*free.mtz' file into observation file(s) and a
        FreeR file, and register them in HKLOUT / FREEROUT.

        One observation file is written per wavelength when more than one
        wavelength name is given, otherwise a single file is written.
        Returns CPluginScript.SUCCEEDED, or CPluginScript.FAILED (after
        appending error 200) if splitMtz does not succeed.
        """
        import CCP4XtalData

        obsOut = self.container.outputData.HKLOUT
        freerOut = self.container.outputData.FREEROUT
        workDir = self.getWorkDirectory()
        srcFilename = os.path.basename(srcPath)
        stem = srcFilename[:-9]  # drops the trailing '_free.mtz'

        # Wavelength names end up appended to the column names (and are
        # used here in the output file names) only if there is > 1
        # wavelength.
        if len(wavelength_names) > 1:
            labels = list(wavelength_names)
            suffixes = ['_' + w for w in labels]
        else:
            # With no xia2.json there are no names at all; still handle the
            # single observation file, with an empty label.
            labels = list(wavelength_names) or ['']
            suffixes = ['']

        if anomalous:
            colin_base = 'I(+){0},SIGI(+){0},I(-){0},SIGI(-){0}'
            colout = 'Iplus,SIGIplus,Iminus,SIGIminus'
            flag = CCP4XtalData.CObsDataFile.CONTENT_FLAG_IPAIR
        else:
            # The suffix is applied here too, so that every wavelength gets
            # its own column selection and the number of files actually
            # split matches the number registered below.
            colin_base = 'IMEAN{0},SIGIMEAN{0}'
            colout = 'I,SIGI'
            flag = CCP4XtalData.CObsDataFile.CONTENT_FLAG_IMEAN
        colfree = 'FreeR_flag'

        obsPath_list = [os.path.join(workDir, stem + '_obs' + sfx + '.mtz')
                        for sfx in suffixes]
        freerPath = os.path.join(workDir, stem + '_freer.mtz')
        logFile = os.path.join(workDir, 'cmtzsplit.log')

        out_triplets = [[obsPath, colin_base.format(sfx), colout]
                        for obsPath, sfx in zip(obsPath_list, suffixes)]
        out_triplets.append([freerPath, colfree, colfree])

        status = self.splitMtz(srcPath, out_triplets, logFile)
        if status != CPluginScript.SUCCEEDED:
            self.appendErrorReport(200)
            return CPluginScript.FAILED

        # [:-8] drops the trailing 'free.mtz' for the annotations
        dataset = srcFilename[:-8]
        for label, obsPath in zip(labels, obsPath_list):
            if label:
                prefix = 'Reflections: {0}'.format(label)
            else:
                prefix = 'Reflections'
            obsOut.append(obsOut.makeItem())
            obsOut[-1].fullPath = obsPath
            obsOut[-1].annotation = self._get_annotation(prefix, dataset)
            obsOut[-1].contentFlag = flag
        freerOut.append(freerOut.makeItem())
        freerOut[-1].fullPath = freerPath
        freerOut[-1].annotation = self._get_annotation('FreeR', dataset)
        return CPluginScript.SUCCEEDED

    def _harvest_program_xml(self):
        """Append the Pointless, Aimless and truncate XML found in the work
        directory to the program XML.

        Returns CPluginScript.SUCCEEDED, or CPluginScript.FAILED after
        appending error 202, 203 or 204 for whichever stage raised.
        """
        try:
            pointlessXMLs = glob.glob(self._work_path(
                'DEFAULT', 'scale', '*pointless.xml'))
            # Written so the same text is printed under Python 2 and 3
            print('pointlessXMLs ' + str(pointlessXMLs))
            if len(pointlessXMLs) > 0:
                # Order by the integer prefix of the file name
                sortedPointlessXMLs = sorted(
                    pointlessXMLs,
                    key=lambda p: int(os.path.basename(p).split('_')[0]))
                trees = [CCP4Utils.openFileToEtree(f)
                         for f in sortedPointlessXMLs]
                self.xmlroot.append(trees[-1])
                # symmetry element scores are from the first Pointless run.
                # Snapshot the matches first: appending moves each element,
                # which must not disturb the traversal that finds them.
                scores = list(trees[0].iter('LatticeSymmetry',
                                            'ElementScores',
                                            'LaueGroupScoreList'))
                for element in scores:
                    self.xmlroot.append(element)
        except Exception:
            self.appendErrorReport(202)
            return CPluginScript.FAILED

        try:
            aimlessXMLs = glob.glob(self._work_path(
                'LogFiles', '*aimless_xml.xml'))
            for aimlessXML in aimlessXMLs:
                self.xmlroot.append(CCP4Utils.openFileToEtree(aimlessXML))
        except Exception:
            self.appendErrorReport(203)
            return CPluginScript.FAILED

        try:
            truncateXMLs = glob.glob(self._work_path(
                'LogFiles', '*truncate.xml'))
            ctruncates = etree.Element('CTRUNCATES')
            for truncateXML in truncateXMLs:
                ctruncates.append(CCP4Utils.openFileToEtree(truncateXML))
            self.xmlroot.append(ctruncates)
        except Exception:
            self.appendErrorReport(204)
            return CPluginScript.FAILED

        return CPluginScript.SUCCEEDED

    def _copy_log_files(self):
        """Copy every LogFiles/*.log into the top of the work directory."""
        for logFilePath in glob.glob(self._work_path('LogFiles', '*.log')):
            destLogPath = self._work_path(os.path.basename(logFilePath))
            shutil.copyfile(logFilePath, destLogPath)

    def _first_xml_text(self, path):
        """Return the stripped text of the first program XML element
        matching the XPath *path*, or None if nothing matches."""
        matches = self.xmlroot.xpath(path)
        if len(matches) > 0:
            return str(matches[0].text).strip()
        return None

    def processOutputFiles(self):
        """Harvest the results of the xia2 run.

        Returns the process exit status if it is not
        CPluginScript.SUCCEEDED, CPluginScript.FAILED if splitting a merged
        file or harvesting XML fails, and CPluginScript.SUCCEEDED otherwise
        (including when xia2 left a xia2.error file, whose text is then put
        in the program XML).
        """
        # Check for exit status of the program
        from CCP4Modules import PROCESSMANAGER
        exitStatus = PROCESSMANAGER().getJobData(pid=self.getProcessId(),
                                                 attribute='exitStatus')
        if exitStatus != CPluginScript.SUCCEEDED:
            element = etree.SubElement(self.xmlroot, 'Xia2Error')
            element.text = 'Failed to locate XIA2'
            # Write the message out, as is done for xia2.error below;
            # otherwise it never reaches the program XML file.
            self.flushXML()
            return exitStatus

        # Read xia2.txt
        self._embed_text_file('xia2.txt', 'Xia2Txt')

        # Infer if xia2 gave an error by virtue of xia2.error existing
        if self._embed_text_file('xia2.error', 'Xia2Error') is not None:
            self.flushXML()
            return CPluginScript.SUCCEEDED

        # Read xia2.json, pack it in the XML and also extract wavelength
        # names. The whole thing is kept in the xml. Currently just used for
        # the xia2 run summary table, but may have other uses in future.
        wavelength_names = []
        json_txt = self._embed_text_file('xia2.json', 'Xia2Json')
        if json_txt is not None:
            # Will need the names for splitMtz.
            wavelength_names.extend(
                self._extract_wavelength_names(json_txt))

        par = self.container.controlParameters
        tmp = par.xia2.xia2__settings.xia2__settings__input
        anomalous = (tmp.xia2__settings__input__atom.isSet() or
                     str(tmp.xia2__settings__input__anomalous) == 'True')

        # Grab integrated data
        self._harvest_integrated()

        # Grab merged files
        # NOTE(review): XML and log harvesting runs once per merged file, as
        # it always has; with more than one '*free.mtz' the XML would be
        # appended repeatedly - confirm whether that can occur.
        pattern = self._work_path('DataFiles', '*free.mtz')
        for srcPath in glob.glob(pattern):
            status = self._harvest_merged(srcPath, wavelength_names,
                                          anomalous)
            if status != CPluginScript.SUCCEEDED:
                return status

            # Grab XMLs
            status = self._harvest_program_xml()
            if status != CPluginScript.SUCCEEDED:
                return status

            self.flushXML()

            # Grab all log files
            self._copy_log_files()

        # Populate the performance indicator
        try:
            # Best effort only: a space group that cannot be read or set
            # must not fail the job.
            spGp = self.container.outputData.PERFORMANCE.spaceGroup
            name = self._first_xml_text(
                'AIMLESS/ReflectionFile/SpacegroupName')
            if name is not None:
                spGp.set(spGp.fix(name))
        except Exception:
            pass
        rMeas = self._first_xml_text(
            'AIMLESS/Result/Dataset/RmeasOverall/Overall')
        if rMeas is not None:
            self.container.outputData.PERFORMANCE.rMeas.set(float(rMeas))
        highRes = self._first_xml_text(
            'AIMLESS/Result/Dataset/ResolutionHigh/Overall')
        if highRes is not None:
            self.container.outputData.PERFORMANCE.highResLimit.set(
                float(highRes))

        return CPluginScript.SUCCEEDED

    def handleXia2DotTxtChanged(self, filename):
        """Replace the Xia2Txt node of the program XML with the current
        contents of *filename* and write the XML out."""
        # remove xia2.txt nodes
        for staleNode in self.xmlroot.xpath('Xia2Txt'):
            self.xmlroot.remove(staleNode)
        xia2TxtNode = etree.SubElement(self.xmlroot, 'Xia2Txt')
        with open(filename, 'r') as xia2DotTxtFile:
            xia2TxtNode.text = etree.CDATA(xia2DotTxtFile.read())
        self.flushXML()

    def flushXML(self):
        """Write the program XML to the PROGRAMXML file.

        The XML is written to a '_tmp' file first and then renamed into
        place, removing any existing file beforehand.
        """
        xmlFilename = self.makeFileName('PROGRAMXML')
        tmpFilename = xmlFilename + '_tmp'
        with open(tmpFilename, 'w') as xmlFile:
            xmlFile.write(etree.tostring(self.xmlroot, pretty_print=True))
        if os.path.exists(xmlFilename):
            os.remove(xmlFilename)
        os.rename(tmpFilename, xmlFilename)


