
from CCP4PluginScript import CInternalPlugin, CPluginScript
from PyQt4 import QtCore
import os,glob,re,time,sys,shutil
import CCP4XtalData
from lxml import etree
import math
import CCP4Modules,CCP4Utils
import CCP4ErrorHandling

class PrepareDeposit(CPluginScript):
    """Pipeline plugin that assembles coordinate and reflection mmCIF files.

    The work is done in process(), which chains clustalw, refmac, hklin2cif
    and pdb_extract_wrapper sub-jobs and copies the resulting CIF files into
    the user-selected output directory.
    """
    # Task identity as declared to the plugin framework.
    TASKNAME = 'PrepareDeposit'
    TASKVERSION= 0.0
    # NOTE(review): the attributes below are read by the CPluginScript
    # framework, not by this file; their exact semantics (e.g. the unit of
    # TIMEOUT_PERIOD) should be confirmed against the framework docs.
    ASYNCHRONOUS = False
    TIMEOUT_PERIOD = 240
    MAXNJOBS = 4
    SUBTASKS=['refmac']
    MAINTAINER = 'martin.noble@newcastle.ac.uk'
    
    # Error codes passed to appendErrorReport() from process():
    #   200 - the clustalw PROGRAMXML did not contain exactly one BestPair node
    #   201 - copying the final CIF files to OUTPUT_DIRECTORY failed
    ERROR_CODES = {  200 : { 'description' : 'Alignments yielded nBestPairs != 1' },
                    201 : { 'description' : 'Failed to copy files to destination directory...do you have write access ?' },}

    def process(self):
        """Run the deposition-preparation pipeline.

        Steps, in order:
          1. Validate the input data and set up the output data.
          2. If PROVIDESEQUENCES is set, align every peptide chain found in
             XYZIN against the provided sequences with clustalw, record which
             provided sequence each chain matches best, and store the pairwise
             alignment XML under an 'AlignChain' node of the program XML.
          3. Run refmac with NCYCLES set to 0 on the input data and append its
             program XML to this job's XML.
          4. Convert refmac's 'hklin.mtz' to mmCIF with hklin2cif.
          5. Run pdb_extract_wrapper once to obtain an entry-data template,
             rewrite the molecule section of that template from SEQUENCE_LIST,
             and run pdb_extract_wrapper again with the modified template.
          6. Copy the coordinate and reflection CIF files into
             OUTPUT_DIRECTORY as 'Coordinates.cif' and 'Reflections.cif'.

        The outcome is communicated through reportStatus(); the method itself
        returns None. As soon as a step fails, the failure is reported and the
        method returns, so that exactly one status is reported per run.
        """
        self.xmlroot = etree.Element('PrepareDeposit')

        invalidFiles = self.checkInputData()
        if len(invalidFiles)>0:
            self.reportStatus(CPluginScript.FAILED)
            return
        self.checkOutputData()

        #Do sequence alignments if needed
        #chainsOfSequence[i] collects the chain ids that best match SEQUENCE_LIST[i]
        chainsOfSequence = []
        if self.container.inputData.PROVIDESEQUENCES:
            #Keep the copies in SEQUENCE_LIST order: the partner numbers that
            #come back from clustalw are used below as indices into this list
            #(and into chainsOfSequence), so the order must be deterministic.
            providedFiles = []
            for iSequence, sequenceFile in enumerate(self.container.inputData.SEQUENCE_LIST):
                providedPath = os.path.normpath(os.path.join(self.getWorkDirectory(),'ProvidedSequence_'+str(iSequence)+'.fasta'))
                shutil.copyfile(sequenceFile.fullPath.__str__(), providedPath)
                providedFiles.append(providedPath)
                chainsOfSequence.append([])

            sequences = self.sequencesFromPDB(self.container.inputData.XYZIN.__str__())
            catenatedPath = os.path.normpath(os.path.join(self.getWorkDirectory(),'catenated.fasta'))

            for chainId in sequences:
                #Align this chain against all of the provided sequences using clustalw
                self._writeAlignmentInput(catenatedPath, chainId, sequences[chainId], providedFiles)
                clustalPlugin, rv = self._runClustalw(catenatedPath)
                if rv != CPluginScript.SUCCEEDED:
                    self.reportStatus(rv)
                    return

                #Identify best pair from the above, and rerun using only these sequences
                programXML = CCP4Utils.openFileToEtree(clustalPlugin.makeFileName('PROGRAMXML'))
                bestPairNodes = programXML.xpath('BestPair')
                if len(bestPairNodes) != 1:
                    self.appendErrorReport(200,str(len(bestPairNodes)))
                    self.reportStatus(CPluginScript.FAILED)
                    return

                #Partner '1' is the chain itself (first record in the input);
                #the other partner N refers to provided sequence N-2.
                partnerNodes = bestPairNodes[0].xpath('Partner')
                otherPartner = [int(partnerNode.text) for partnerNode in partnerNodes if partnerNode.text != '1'][0]
                chainsOfSequence[otherPartner-2].append(chainId)

                #Rerun using only the best partner
                self._writeAlignmentInput(catenatedPath, chainId, sequences[chainId], [providedFiles[otherPartner-2]])
                clustalPlugin, rv = self._runClustalw(catenatedPath)
                if rv != CPluginScript.SUCCEEDED:
                    self.reportStatus(rv)
                    return

                alignChainNode = etree.SubElement(self.xmlroot,'AlignChain')
                chainIdNode = etree.SubElement(alignChainNode,'ChainId')
                chainIdNode.text = chainId
                clustalwNode = CCP4Utils.openFileToEtree(clustalPlugin.makeFileName('PROGRAMXML'))
                alignChainNode.append(clustalwNode)

        self.flushXML()

        refmacPlugin = self.makePluginObject('refmac')
        #Forward whichever of these inputs have been set on this job
        for propertyName in ['F_SIGF','FREERFLAG','XYZIN','TLSIN','DICT']:
            if hasattr(self.container.inputData,propertyName) and getattr(self.container.inputData,propertyName).isSet():
                setattr(refmacPlugin.container.inputData, propertyName, getattr(self.container.inputData,propertyName))

        #Here spoof forcing anomalous.  Really refmac should have this as explicit settable stuff
        #NOTE(review): the parameters actually set are USE_TWIN/TWIN_TYPE - confirm this is the intended switch
        if self.container.inputData.USINGIORF.__str__() == 'I':
            refmacPlugin.container.controlParameters.USE_TWIN.set(True)
            refmacPlugin.container.controlParameters.TWIN_TYPE.set('I')
        refmacPlugin.container.controlParameters.NCYCLES.set(0)

        #Fold TLS into the B-factors of the output PDB
        if refmacPlugin.container.inputData.TLSIN.isSet() or refmacPlugin.container.inputData.get('AUTOTLS', False):
            refmacPlugin.container.controlParameters.TLSOUT_ADDU.set(True)
        refmacPlugin.container.controlParameters.NTLSCYCLES.set(0)
        refmacPlugin.container.controlParameters.NTLSCYCLES_AUTO.set(0)
        refmacPlugin.container.controlParameters.TLSBFACSETUSE.set(False)

        #Don't dip out just because DICT not available
        refmacPlugin.container.controlParameters.MAKE_NEW_LIGAND_EXIT.set(False)
        rv = refmacPlugin.process()
        if rv != CPluginScript.SUCCEEDED:
            self.reportStatus(rv)
            return

        refmacRootNode = CCP4Utils.openFileToEtree(refmacPlugin.makeFileName('PROGRAMXML'))
        self.xmlroot.append(refmacRootNode)
        self.flushXML()

        #Now convert refmac input to mmcif
        self.hklin2cifPlugin = self.makePluginObject('hklin2cif')
        hklinPath = os.path.normpath(os.path.join(refmacPlugin.getWorkDirectory(),'hklin.mtz'))
        self.hklin2cifPlugin.container.inputData.HKLIN.setFullPath(hklinPath)
        rv = self.hklin2cifPlugin.process()
        if rv != CPluginScript.SUCCEEDED:
            self.reportStatus(rv)
            return

        #First pass of pdb_extract
        pdbExtract1Plugin = self.makePluginObject('pdb_extract_wrapper')
        pdbExtract1Plugin.container.inputData.XYZIN = refmacPlugin.container.outputData.XYZOUT
        pdbExtract1Plugin.container.outputData.ENTRYDATA.setFullPath(os.path.normpath(os.path.join(pdbExtract1Plugin.getWorkDirectory(),'data_template.text')))

        rv = pdbExtract1Plugin.process()
        if rv != CPluginScript.SUCCEEDED:
            self.reportStatus(rv)
            return

        modifiedTemplatePath = os.path.normpath(os.path.join(self.getWorkDirectory(),'modifiedTemplate.text'))
        templatePath = pdbExtract1Plugin.container.outputData.ENTRYDATA.__str__()
        #Single-argument form gives the same output under Python 2 and 3
        print('templatePath, modifiedTemplatePath %s %s' % (templatePath, modifiedTemplatePath))
        with open(templatePath,'r') as blankTemplate:
            lines = blankTemplate.readlines()
        with open(modifiedTemplatePath,'w') as modifiedTemplate:
            #Lines from the first "molecule_entity_id" up to (not including)
            #the "CATEGORY 3:" line are dropped and replaced by one entry per
            #provided sequence; everything else is copied through unchanged.
            inMolecules = False
            for line in lines:
                if "molecule_entity_id" in line:
                    inMolecules = True
                elif "CATEGORY 3:" in line:
                    inMolecules = False
                    iEntity = 1
                    for sequence in self.container.inputData.get('SEQUENCE_LIST', []):
                        modifiedTemplate.write('<molecule_entity_id="%d">\n'%(iEntity))
                        modifiedTemplate.write('<molecule_entity_type="polypeptide(L)" >\n')
                        modifiedTemplate.write('<molecule_one_letter_sequence="\n%s">\n'%(sequence.fileContent.sequence.__str__()))
                        if len(chainsOfSequence) > iEntity-1:
                            modifiedTemplate.write('<molecule_chain_id="%s">\n'%(','.join(chainsOfSequence[iEntity-1])))
                        modifiedTemplate.write('< target_DB_id=" " > (if known) \n\n\n')
                        iEntity += 1
                if not inMolecules: modifiedTemplate.write(line)

        #Second pass of pdb_extract
        pdbExtract2Plugin = self.makePluginObject('pdb_extract_wrapper')
        pdbExtract2Plugin.container.inputData.XYZIN = refmacPlugin.container.outputData.XYZOUT
        pdbExtract2Plugin.container.inputData.ENTRYDATAIN.set(modifiedTemplatePath)

        rv = pdbExtract2Plugin.process()
        if rv != CPluginScript.SUCCEEDED:
            self.reportStatus(rv)
            return

        structureCifPath = os.path.normpath(os.path.join(self.container.inputData.OUTPUT_DIRECTORY.__str__(),'Coordinates.cif'))
        reflectionCifPath = os.path.normpath(os.path.join(self.container.inputData.OUTPUT_DIRECTORY.__str__(),'Reflections.cif'))
        try:
            shutil.copyfile(pdbExtract2Plugin.container.outputData.CIFFILE.__str__(), structureCifPath)
            shutil.copyfile(self.hklin2cifPlugin.container.outputData.CIFFILE.__str__(), reflectionCifPath)
        except EnvironmentError:
            #IOError, OSError and shutil.Error all derive from EnvironmentError
            self.appendErrorReport(201)
            self.reportStatus(CPluginScript.FAILED)
            return
        self.reportStatus(CPluginScript.SUCCEEDED)

    def _writeAlignmentInput(self, alignmentPath, chainId, chainSequence, providedPaths):
        """Write a FASTA file holding one model chain followed by the given provided-sequence files."""
        with open(alignmentPath,'w') as alignmentFile:
            alignmentFile.write('>CHAIN_'+chainId+'\n')
            alignmentFile.write(chainSequence)
            for providedPath in providedPaths:
                with open(providedPath,'r') as providedFile:
                    alignmentFile.write(providedFile.read())
                    alignmentFile.write('\n')

    def _runClustalw(self, alignmentPath):
        """Run a clustalw sub-job on alignmentPath and return (plugin, status)."""
        clustalPlugin = self.makePluginObject('clustalw')
        clustalPlugin.container.inputData.SEQUENCELISTORALIGNMENT.set('ALIGNMENT')
        clustalPlugin.container.inputData.ALIGNMENTIN.setFullPath(alignmentPath)
        clustalPlugin.container.outputData.ALIGNMENTOUT.setFullPath(alignmentPath+'_out')
        return clustalPlugin, clustalPlugin.process()

    def flushXML(self, xml=None):
        """Write an XML tree to this job's PROGRAMXML file.

        xml -- the lxml element to write; when None, self.xmlroot is used if
               that attribute exists.

        The text is written to '<PROGRAMXML>_tmp' first and then moved into
        place with self.renameFile(). If no tree is available, the exception
        raised by etree.tostring() propagates to the caller.
        """
        if xml is None:
            xml = getattr(self, 'xmlroot', None)
        # Serialise before opening the temporary file, so that a failure here
        # does not leave an empty '_tmp' file behind.
        xmlText = etree.tostring(xml, pretty_print=True)
        programXmlPath = self.makeFileName('PROGRAMXML')
        tmpFilename = programXmlPath+'_tmp'
        with open(tmpFilename,'w') as tmpFile:
            tmpFile.write(xmlText)
        self.renameFile(tmpFilename, programXmlPath)

    def sequencesFromPDB(self, filePath):
        """Extract one-letter sequences for the peptide chains of a coordinate file.

        filePath -- path of the coordinate file to read.

        Returns a dict mapping chain id to its sequence string. Residues whose
        name is not one of the 20 standard three-letter codes are written as
        '-'. A newline is inserted after every 60 residues and each sequence
        always ends with a newline (a chain with no residues yields '\\n').
        An empty dict is returned when filePath is not an existing file.
        """
        tlcOlcMap = {'ALA':'A','CYS':'C','ASP':'D','GLU':'E','PHE':'F','GLY':'G','HIS':'H','ILE':'I','LYS':'K','LEU':'L','MET':'M','ASN':'N','PRO':'P','GLN':'Q','ARG':'R','SER':'S','THR':'T','VAL':'V','TRP':'W','TYR':'Y'}
        sequences = {}
        # NOTE(review): mmut is not referenced by name below; the import is kept
        # in case loading it is required for the mmdb objects used here - confirm.
        import mmut
        if not os.path.isfile(filePath):
            return sequences

        from CCP4ModelData import CPdbData
        aCPdbData = CPdbData()
        aCPdbData.loadFile(filePath)
        mmdbManager = aCPdbData.mmdbManager
        for peptideChainId in aCPdbData.composition.peptides:
            # Chains are looked up in model number 1
            peptideChain = mmdbManager.GetChain(1,peptideChainId)
            letters = []
            for i in range(peptideChain.GetNumberOfResidues()):
                residueName = peptideChain.GetResidue(i).GetResName()
                letters.append(tlcOlcMap.get(residueName, '-'))
                if (i+1)%60 == 0: letters.append('\n')
            sequence = ''.join(letters)
            # endswith() compares by value and is safe on an empty sequence
            if not sequence.endswith('\n'): sequence += '\n'
            sequences[peptideChainId] = sequence
        return sequences
    

