"""
     validate_protein.py: CCP4 GUI 2 Project
     Copyright (C) 2014 The University of York

     This library is free software: you can redistribute it and/or
     modify it under the terms of the GNU Lesser General Public License
     version 3, modified in accordance with the provisions of the 
     license to address the requirements of UK law.
 
     You should have received a copy of the modified GNU Lesser General 
     Public License along with this library.  If not, copies may be 
     downloaded from http://www.ccp4.ac.uk/ccp4license.php
 
     This program is distributed in the hope that it will be useful,
     but WITHOUT ANY WARRANTY; without even the implied warranty of
     MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
     GNU Lesser General Public License for more details.
"""


## @package validate_protein
# This package computes various metrics used in validation and deposition,
# for example average B-factors by chain, ligands, etc and Ramachandran statistics.


import clipper

# the following lines are required for running clipper-python subprograms outside of i2
import sys,os
# Locate the CCP4 installation; "not_set" is the sentinel for a missing variable.
ccp4_home = os.environ.get ( "CCP4", "not_set" )

if ccp4_home != "not_set" :
    # Make the ccp4i2 core modules importable when running outside of i2.
    sys.path.append ( "%s/share/ccp4i2/core" % ccp4_home )
else :
    # Nothing below can be imported without CCP4, so stop with a message.
    sys.exit("\nError: CCP4 environment variable must be set\n\n")

from CCP4PluginScript import CPluginScript
from CCP4ClipperUtils import read_pdb, is_aminoacid, is_mainchain


class validate_protein(CPluginScript):
    """Plugin script that writes B-factor averages and Ramachandran data as program XML."""

    TASKNAME = 'validate_protein'
    WHATNEXT = [ 'coot_rebuild' ]
    MAINTAINER = 'jon.agirre@york.ac.uk'


    def process ( self ) :
        """Run both analyses on XYZIN, merge their XML and write the PROGRAMXML file.

        The Ramachandran XML tree is appended as a child of the B-averages
        tree, and the combined tree is written pretty-printed.  The status
        is reported as SUCCEEDED, which is also the return value.
        """

        model = self.container.inputData.XYZIN
        print ( model )
        from lxml import etree

        model_path = str ( model )

        log_string, xml_root = b_averages ( model_path )
        rama_log, rama_xml   = ramachandran_maps ( model_path )

        #create_rama_backgrounds ()

        log_string = log_string + rama_log
        xml_root.append ( rama_xml )

        xml_path = self.makeFileName ( 'PROGRAMXML' )
        with open ( xml_path, 'w' ) as xml_file :
            xml_file.write ( etree.tostring ( xml_root, pretty_print=True ) )

        outcome = CPluginScript.SUCCEEDED
        self.reportStatus ( outcome )

        return outcome


## A subprogram that creates PNG backgrounds for Rama plots - should only be called once clipper's Rama data get updated
## Do not call this function in a production environment, it will probably crash due to permissions!

def _rama_region_points ( rama, step ) :
    """Sample a Ramachandran object on a regular (phi, psi) grid.

    Both angles run from -pi up to (but excluding) pi in increments of
    step (radians).

    @param rama An object offering favored ( phi, psi ) and allowed ( phi, psi )
    @param step Grid spacing in radians
    @return A tuple (favoured, allowed); each is a pair of lists [phis, psis].
            A point that is favoured is not repeated in the allowed lists.
    """

    favoured = [ [], [] ]
    allowed  = [ [], [] ]

    phi = - clipper.Util.pi()
    while phi < clipper.Util.pi() :
        # psi restarts from -pi for every phi so that all columns share one grid
        psi = - clipper.Util.pi()
        while psi < clipper.Util.pi() :
            if rama.favored ( phi, psi ) :
                favoured[0].append ( phi )
                favoured[1].append ( psi )
            elif rama.allowed ( phi, psi ) :
                allowed[0].append ( phi )
                allowed[1].append ( psi )
            psi += step
        phi += step

    return favoured, allowed


def create_rama_backgrounds ( output_dir = "/tmp" ) :
    """Write PNG backgrounds for the Gly, Pro and non-Gly/Pro Ramachandran plots.

    The files rama_gly.png, rama_pro.png and rama_rst.png are created in
    output_dir.  The grid is sampled every 1.5 degrees, starting at -pi for
    both angles; the earlier start value of -6.28 sampled points outside the
    plotted range and put the first psi column on a different grid.

    @param output_dir Directory receiving the PNG files (defaults to /tmp)
    """

    import matplotlib.pyplot as plt

    step = clipper.Util.d2rad ( 1.5 )

    plots = ( ( "rama_gly.png", clipper.Ramachandran.Gly ),
              ( "rama_pro.png", clipper.Ramachandran.Pro ),
              ( "rama_rst.png", clipper.Ramachandran.NonGlyPro ) )

    for file_name, rama_type in plots :
        favoured, allowed = _rama_region_points ( clipper.Ramachandran ( rama_type ), step )

        plt.figure()
        # allowed regions are drawn fainter and smaller than favoured ones
        plt.scatter ( allowed[0], allowed[1], s=0.1, color="lightblue", alpha=0.2 )
        plt.scatter ( favoured[0], favoured[1], s=3.0, color="lightblue", alpha=0.5 )
        plt.xlim (-3.14, 3.14)
        plt.ylim (-3.14, 3.14)
        plt.axis( 'off' )
        plt.savefig( os.path.join ( output_dir, file_name ), bbox_inches = 'tight', pad_inches = 0 )
        plt.close()


## A Clipper subprogram that calculates B-averages and returns a log string and an XML tree
#  @param pdbin A string containing the path to a PDB

def _mean_b ( u_values ) :
    """Return the mean B-factor for a list of U-iso values, or None when the list is empty."""

    if not u_values :
        return None
    return clipper.Util.u2b ( sum ( u_values, 0.0 ) / len ( u_values ) )


def _b_factor_stddev ( u_values ) :
    """Return the sample standard deviation (n - 1 denominator) of the B-factors.

    Each U-iso value is converted to a B-factor first.  With fewer than two
    values the deviation is undefined and 0.0 is returned.
    """

    import math

    if len ( u_values ) < 2 :
        return 0.0

    b_values = [ clipper.Util.u2b ( u_value ) for u_value in u_values ]
    mean_b = sum ( b_values, 0.0 ) / len ( b_values )
    squared_deviations = sum ( ( ( b_value - mean_b ) ** 2 for b_value in b_values ), 0.0 )

    return math.sqrt ( squared_deviations / ( len ( b_values ) - 1 ) )


def _xml_text ( value ) :
    """Format an optional number for the XML output; None becomes "-"."""

    return "-" if value is None else str ( value )


def _log_text ( value ) :
    """Format an optional number for the log with two decimals; None becomes "-"."""

    return "-" if value is None else "%.2f" % value


def b_averages ( pdbin = "undefined" ) :
    """Compute per-residue and overall B-factor statistics for a model.

    Monomers of type HOH are reported as Water, those accepted by
    is_aminoacid as Aminoacid, and everything else as Ligand.  Amino acid
    atoms are further split into main chain and side chain with is_mainchain.

    Corrections with respect to the previous implementation:
      - standard deviations use residuals / (n - 1); the old expression
        evaluated residuals / n - 1
      - the water standard deviation no longer starts from the residuals
        left over by the ligand or side chain calculation
      - All_StdDev is taken about the mean of all amino acid atoms
      - empty groups give "-" (means) or "0.0" (deviations) instead of
        raising ZeroDivisionError or ValueError

    @param pdbin A string containing the path to a PDB
    @return A tuple (log_string, xml_root); xml_root is an lxml element
            named B_averages.  The log string is also printed.
    """

    from lxml import etree

    log_string = "\n\n######### clipper-python util: b_averages #########\n\n"
    log_string += "pdbin:\t%s\n\n" % pdbin

    xml_root = etree.Element('B_averages')

    log_buffer, xml_buffer, mmol = read_pdb ( pdbin )

    log_string += log_buffer
    xml_root.append ( xml_buffer )

    # to do: radially averaged B factor by chain
    # to do: per residue correlation VS mean B factor, plot by colour

    by_monomer = etree.SubElement ( xml_root, 'By_Residue' )

    # U-iso values per category, kept in file order
    aminoacid_us = [ ]
    mainchain_us = [ ]
    sidechain_us = [ ]
    water_us = [ ]
    ligand_us = [ ]

    for poly in mmol :
        chain = etree.SubElement ( by_monomer, "Chain", id=poly.id().trim() )
        for mono in poly :
            mono_type = mono.type().trim()
            if mono_type == "HOH" :
                kind = "Water"
            elif is_aminoacid ( mono_type ) :
                kind = "Aminoacid"
            else :
                kind = "Ligand"

            res = etree.SubElement ( chain, kind )
            etree.SubElement ( res, 'Number' ).text = str ( mono.id().trim() )

            entity_us = [ ]
            entity_mainchain_us = [ ]

            for atom in mono :
                u_iso = atom.u_iso()
                entity_us.append ( u_iso )

                if kind == "Aminoacid" :
                    if is_mainchain ( atom.id().trim() ) :
                        entity_mainchain_us.append ( u_iso )
                        mainchain_us.append ( u_iso )
                    else :
                        sidechain_us.append ( u_iso )
                elif kind == "Water" :
                    water_us.append ( u_iso )
                else :
                    ligand_us.append ( u_iso )

            atom_counter = len ( entity_us )
            etree.SubElement ( res, "Atom_count" ).text = str ( atom_counter )
            etree.SubElement ( res, "Mean_B" ).text = _xml_text ( _mean_b ( entity_us ) )

            if kind == "Aminoacid" :
                aminoacid_us.extend ( entity_us )

                etree.SubElement ( res, "Main_Chain_Mean_B" ).text = _xml_text ( _mean_b ( entity_mainchain_us ) )

                # Side chains are only reported for residues with more than
                # four atoms; the extra count check guards the division.
                side_chain_count = atom_counter - len ( entity_mainchain_us )
                if atom_counter > 4 and side_chain_count > 0 :
                    side_chain_u = sum ( entity_us, 0.0 ) - sum ( entity_mainchain_us, 0.0 )
                    etree.SubElement ( res, "Side_Chain_Mean_B" ).text = str ( clipper.Util.u2b ( side_chain_u / side_chain_count ) )
                else :
                    etree.SubElement ( res, "Side_Chain_Mean_B" ).text = "-"

    total = etree.SubElement ( xml_root, "Totals" )

    log_string += "Totals calculated: \n\n"

    if aminoacid_us :
        mean_all = _mean_b ( aminoacid_us )
        mean_mainchain = _mean_b ( mainchain_us )
        mean_sidechain = _mean_b ( sidechain_us )

        aminoacids = etree.SubElement ( total, "Aminoacids" )
        etree.SubElement ( aminoacids, "Atom_count" ).text = str ( len ( aminoacid_us ) )
        etree.SubElement ( aminoacids, "Mean_B" ).text = _xml_text ( mean_all )
        etree.SubElement ( aminoacids, "Main_Chain_Mean_B" ).text = _xml_text ( mean_mainchain )
        etree.SubElement ( aminoacids, "Side_Chain_Mean_B" ).text = _xml_text ( mean_sidechain )
        etree.SubElement ( aminoacids, "Main_Chain_StdDev" ).text = str ( _b_factor_stddev ( mainchain_us ) )
        etree.SubElement ( aminoacids, "All_StdDev" ).text = str ( _b_factor_stddev ( aminoacid_us ) )
        etree.SubElement ( aminoacids, "Side_Chain_StdDev" ).text = str ( _b_factor_stddev ( sidechain_us ) )

        log_string += "Atoms as part of aminoacids: %d\n" % len ( aminoacid_us )
        log_string += "Mean B-factor: %s\n" % _log_text ( mean_all )
        log_string += "Main chain mean B-factor: %s\n" % _log_text ( mean_mainchain )
        log_string += "Side chain mean B-factor: %s\n" % _log_text ( mean_sidechain )

    if ligand_us :
        mean_ligands = _mean_b ( ligand_us )

        ligands = etree.SubElement ( total, "Ligands" )
        etree.SubElement ( ligands, "Atom_count" ).text = str ( len ( ligand_us ) )
        etree.SubElement ( ligands, "Mean_B" ).text = _xml_text ( mean_ligands )
        etree.SubElement ( ligands, "Ligands_StdDev" ).text = str ( _b_factor_stddev ( ligand_us ) )

        log_string += "Atoms as part of (non-water) ligands: %d\n" % len ( ligand_us )
        log_string += "Mean B-factor: %s\n" % _log_text ( mean_ligands )

    if water_us :
        mean_waters = _mean_b ( water_us )

        waters = etree.SubElement ( total, "Waters" )
        etree.SubElement ( waters, "Atom_count" ).text = str ( len ( water_us ) )
        etree.SubElement ( waters, "Mean_B" ).text = _xml_text ( mean_waters )
        etree.SubElement ( waters, "Waters_StdDev" ).text = str ( _b_factor_stddev ( water_us ) )

        log_string += "Number of waters: %d\n" % len ( water_us )
        log_string += "Mean B-factor: %s\n" % _log_text ( mean_waters )

    log_string += "\n###################################################\n\n"

    print ( log_string )

    return log_string, xml_root



def ramachandran_maps ( pdbin = "undefined" ) :
    """Classify residues as Ramachandran favoured, allowed or outliers.

    A residue is evaluated when it is neither the first nor the last
    monomer of its chain and it, its predecessor and its successor are all
    accepted by is_aminoacid.  GLY and PRO are checked against their own
    distributions, every other type against the NonGlyPro distribution.

    Corrections with respect to the previous implementation:
      - PRO residues outside the favoured region are tested with the Pro
        distribution (the Gly one was used)
      - the previous residue is tracked on every iteration; it used to be
        updated only after a successful evaluation, so one non-amino-acid
        monomer stopped evaluation for the rest of the chain

    @param pdbin A string containing the path to a PDB
    @return A tuple (log_string, xml_root); xml_root is an lxml element
            named Ramachandran_maps.  The log string is also printed.
    """

    log_string = "\n\n######### clipper-python util: ramachandran_maps #########\n\n"
    log_string += "pdbin:\t%s\n\n" % pdbin

    from lxml import etree

    xml_root = etree.Element('Ramachandran_maps')

    log_buffer, xml_buffer, mmol = read_pdb ( pdbin )

    log_string += log_buffer
    xml_root.append ( xml_buffer )

    allowed   = etree.SubElement ( xml_root, "Allowed" )
    favoured  = etree.SubElement ( xml_root, "Favoured"  )
    outliers  = etree.SubElement ( xml_root, "Outliers"  )

    rama_gly = clipper.Ramachandran ( clipper.Ramachandran.Gly )
    rama_pro = clipper.Ramachandran ( clipper.Ramachandran.Pro )
    rama_rest= clipper.Ramachandran ( clipper.Ramachandran.NonGlyPro )

    n_residues = n_allowed = n_favoured = n_outliers = 0

    for chain in mmol :
        chain_id = chain.id().trim()
        last_index = len ( chain ) - 1
        prev_residue = None

        for index, residue in enumerate ( chain ) :
            # phi needs the previous residue and psi the next one, so the
            # two terminal monomers of a chain cannot be evaluated
            if 0 < index < last_index :
                next_residue = chain [ index + 1 ]

                if is_aminoacid ( prev_residue.type().trim() ) and is_aminoacid ( residue.type().trim() ) and is_aminoacid ( next_residue.type().trim() ) :
                    phi = clipper.MMonomer.protein_ramachandran_phi ( prev_residue, residue )
                    psi = clipper.MMonomer.protein_ramachandran_psi ( residue, next_residue )
                    n_residues += 1

                    residue_type = residue.type().trim()
                    if residue_type == "GLY" :
                        rama = rama_gly
                    elif residue_type == "PRO" :
                        rama = rama_pro
                    else :
                        rama = rama_rest

                    if rama.favored ( phi, psi ) :
                        region = favoured
                        n_favoured += 1
                    elif rama.allowed ( phi, psi ) :
                        region = allowed
                        n_allowed += 1
                    else :
                        region = outliers
                        n_outliers += 1

                    residue_rama = etree.SubElement ( region, "Residue", chain=chain_id, residue=residue.id().trim(), type=residue_type )
                    etree.SubElement ( residue_rama, "Phi" ).text = str ( clipper.Util.rad2d(phi) )
                    etree.SubElement ( residue_rama, "Psi" ).text = str ( clipper.Util.rad2d(psi) )

            prev_residue = residue

    total = etree.SubElement ( xml_root, "Totals" )
    etree.SubElement ( total, "Residues" ).text = str (n_residues)
    etree.SubElement ( total, "Favoured" ).text = str (n_favoured)
    etree.SubElement ( total, "Allowed"  ).text = str (n_allowed)
    etree.SubElement ( total, "Outliers" ).text = str (n_outliers)

    log_string += "Found %d residues," % n_residues
    log_string += " with %d in favoured regions," % n_favoured
    log_string += " %d in allowed regions" % n_allowed
    log_string += " and %d being outliers\n\n" % n_outliers

    log_string += "\n###################################################\n\n"

    print ( log_string )
    return log_string, xml_root



if __name__ == '__main__' :

    # Command line use: <script> <subprogram> <pdb file>
    # The argument COUNT has to be tested; comparing the sys.argv list itself
    # with an integer never rejects a short command line.
    subprograms = { "b_averages" : b_averages,
                    "ramachandran_maps" : ramachandran_maps }

    if len ( sys.argv ) > 2 and sys.argv[1] in subprograms :
        subprograms [ sys.argv[1] ] ( pdbin = sys.argv[2] )
    else :
        sys.stderr.write ( "Usage: %s b_averages|ramachandran_maps <pdb file>\n" % sys.argv[0] )


#=====================================================================================================
#=================================test suite=========================================================
#=====================================================================================================

import unittest
from CCP4Utils import getCCP4I2Dir,getTMP

# Unit testing asynchronous processes is potentially tricky, but QProcess has an option to wait until finished
 
class test_validate_protein ( unittest.TestCase ) :
  """Test case that loads the test definition XML and runs the wrapper once."""

  # NOTE(review): PROCESSMANAGER, CScriptDataContainer and CException are not
  # imported in the visible part of this file - confirm they are in scope.

  def setUp(self):
    # background jobs are made to wait for completion while a test runs
    PROCESSMANAGER().setWaitForFinished(10000)

  def tearDown(self):
    # hand -1 back to the process manager once the test is over
    PROCESSMANAGER().setWaitForFinished(-1)


  def test_validate_protein(self):
    """Import the same def.xml into both containers, then call process()."""
    import os
    inputData =  CScriptDataContainer(name='validate_protein_test',containerType='inputData',initialise=validate_protein.INPUTDATA)
    outputData =  CScriptDataContainer(name='validate_protein_test',containerType='outputData',initialise=validate_protein.OUTPUTDATA)

    # an import problem in either container fails the test with its error type
    for data_container in ( inputData, outputData ) :
      try:
        data_container.importXML(os.path.join(getCCP4I2Dir(),'wrappers','validate_protein','test_data','validate_protein_test_1.def.xml'))
      except CException as e:
        self.fail(e.errorType)

    wrapper = validate_protein()
    wrapper.process()


def testSuite():
  """Return a suite holding every test of test_validate_protein."""
  loader = unittest.TestLoader()
  return loader.loadTestsFromTestCase(test_validate_protein)

def runAllTests():
  """Build the suite and run it with a verbose text runner."""
  collected = testSuite()
  runner = unittest.TextTestRunner(verbosity=2)
  runner.run(collected)
