#!/opt/sharcnet/python/2.7.1/bin/python

# http://molmod.ugent.be/code/static/doc/sphinx/molmod/latest/
# http://www.scipy.org/Tentative_NumPy_Tutorial

import shutil, sys
import os, numpy, getopt,re


from molmod import *
from molmod.periodic import periodic
from molmod.io import FCHKFile

from numpy import *
from numpy.linalg import *
from amber_mole import *


class InternalCoordinate(object):
    """Base class for all internal coordinates.

    An instance only stores the atom indexes and the molmod.ic function that
    evaluates the coordinate. The Cartesian Jacobian and the second-derivative
    terms are assembled by the module-level helpers (``calc_jacobian``,
    ``compute_K``), which call ``icfn`` directly.
    """

    def __init__(self, indexes, icfn, scale=1.0):
        """
           Arguments:
            | ``indexes`` -- The indexes of the atoms in the internal
                             coordinate. The order must be the same as the
                             order in which ``icfn`` expects the atom
                             positions.
            | ``icfn`` -- a function from molmod.ic that can compute the
                          internal coordinate and its derivatives.
            | ``scale`` -- Optional factor (default 1.0), stored as
                           ``self.scaling``. Angular subclasses pass a factor
                           intended to convert the coordinate to a length
                           unit; this class only stores it and never
                           applies it.
        """
        self.indexes = indexes
        self.icfn = icfn
        self.scaling = scale

    def __call__(self, coordinates):
        """Return the value of this internal coordinate.

        ``coordinates`` must support numpy fancy indexing with a list of atom
        indexes (e.g. an ``(natoms, 3)`` array). The rows selected by
        ``self.indexes`` are passed to ``icfn`` and the first element of its
        result (the coordinate value) is returned.
        """
        return self.icfn(coordinates[list(self.indexes)])[0]


class BondLength(InternalCoordinate):
    """Distance between atoms ``i`` and ``j``, evaluated with ``bond_length``."""

    def __init__(self, i, j):
        atoms = (i, j)
        super(BondLength, self).__init__(atoms, bond_length)


class BendingAngle(InternalCoordinate):
    """Bending angle over atoms ``i``, ``j``, ``k``, evaluated with ``bend_angle``.

    The stored scale factor is ``angstrom/(5*deg)``.
    """

    def __init__(self, i, j, k):
        atoms = (i, j, k)
        super(BendingAngle, self).__init__(atoms, bend_angle, scale=angstrom/(5*deg))


class DihedralAngle(InternalCoordinate):
    """Dihedral over atoms ``i``, ``j``, ``k``, ``l``, evaluated with ``dihed_angle``.

    The stored scale factor is ``angstrom/(5*deg)``.
    """

    def __init__(self, i, j, k, l):
        atoms = (i, j, k, l)
        super(DihedralAngle, self).__init__(atoms, dihed_angle, scale=angstrom/(5*deg))

def calc_jacobian(ics, xyz):
    """Assemble the Cartesian Jacobian of a list of internal coordinates.

    For each ``ic`` in ``ics``, the first derivatives returned by
    ``ic.icfn(positions, 1)`` are scattered into a row of length
    ``3*len(xyz)``, at the x/y/z slots of the atoms listed in ``ic.indexes``.
    All other entries stay zero. The rows are stacked into a ``numpy.matrix``
    and transposed, so there is one column per internal coordinate.
    """
    ncart = 3 * len(xyz)
    rows = []
    for ic in ics:
        atoms = list(ic.indexes)
        _, derivs = ic.icfn(xyz[atoms], 1)
        row = [0] * ncart
        for slot, atom in enumerate(atoms):
            start = 3 * atom
            row[start:start + 3] = derivs[slot]
        rows.append(row)
    return matrix(rows).transpose()

def compute_K(ics, xyz, g_q):
    """Sum the second derivatives of ``ics``, weighted by internal gradients.

    For the ``q``-th internal coordinate, the Cartesian second-derivative
    tensor returned as the third element of ``ic.icfn(positions, 2)`` is
    multiplied by ``g_q[0, q]``. Each 3x3 atom-pair sub-block is then added
    into a ``(3*len(xyz), 3*len(xyz))`` array at the rows/columns of the
    corresponding atoms.
    """
    ncart = 3 * len(xyz)
    total = zeros((ncart, ncart))
    for q, ic in enumerate(ics):
        atoms = list(ic.indexes)
        _, _, second = ic.icfn(xyz[atoms], 2)
        weight = g_q[0, q]
        for a, row_atom in enumerate(atoms):
            r = 3 * row_atom
            for b, col_atom in enumerate(atoms):
                c = 3 * col_atom
                total[r:r + 3, c:c + 3] += second[a, 0:3, b, :] * weight
    return total

def read_MMhess(filename="hess.dat"):
    """Read a whitespace-separated matrix of floats into a ``numpy.matrix``.

    Each non-blank line of ``filename`` becomes one row. Blank lines (e.g. a
    trailing empty line) are skipped, because they would otherwise add an
    empty, ragged row.

    Arguments:
        filename -- file to read (default ``"hess.dat"``).
    """
    rows = []
    with open(filename, 'r') as f:
        for line in f:
            fields = line.split()
            if fields:
                rows.append([float(x) for x in fields])
    # asmatrix is what numpy.mat aliased; mat itself was removed in NumPy 2.
    return numpy.asmatrix(rows)

def read_none_MMhess(filename="hess_missing.dat"):
    """Read a whitespace-separated matrix of floats into a ``numpy.matrix``.

    Each non-blank line of ``filename`` becomes one row. Blank lines are
    skipped, because they would otherwise add an empty, ragged row.

    Arguments:
        filename -- file to read (default ``"hess_missing.dat"``).
    """
    rows = []
    with open(filename, 'r') as f:
        for line in f:
            fields = line.split()
            if fields:
                rows.append([float(x) for x in fields])
    # asmatrix is what numpy.mat aliased; mat itself was removed in NumPy 2.
    return numpy.asmatrix(rows)

def read_extra_dihed(filename="extra_dihedrals.dat"):
    """Read optional extra dihedral definitions.

    Each non-blank line has the form ``<label> : <python expression>``. The
    text after the first ``" : "`` separator (up to any further separator) is
    evaluated and appended to the result. Returns an empty list when
    ``filename`` does not exist.

    Arguments:
        filename -- file to read (default ``"extra_dihedrals.dat"``).
    """
    extra_dihed = []
    if not os.path.exists(filename):
        return extra_dihed
    with open(filename, 'r') as f:
        for line in f:
            if not line.strip():
                continue  # a blank line has no " : " part to evaluate
            m = line.split(" : ")
            # NOTE(review): eval() executes arbitrary code from the file; only
            # use with trusted input (ast.literal_eval would suffice if the
            # file only ever holds literal tuples/lists).
            extra_dihed.append(eval(m[1]))
    return extra_dihed

def write_QMhess(QMhess, filename="QMhess.dat"):
    """Write a square matrix as space-separated text, one row per line.

    Entries with absolute value below 1e-6 are written as ``0.0000``; all
    others use ``str()``. Every entry, including the last one in a row, is
    followed by a single space.

    Arguments:
        QMhess -- square matrix indexable as ``QMhess[i, j]``.
        filename -- output file (default ``"QMhess.dat"``).
    """
    n = len(QMhess)
    with open(filename, 'w') as f:
        for i in range(n):
            cells = []
            for j in range(n):
                value = QMhess[i, j]
                cells.append("0.0000" if abs(value) < 0.000001 else str(value))
            f.write("".join(c + " " for c in cells) + "\n")

def write_hess_fchk(QMhess, filename="hess_fchk.dat"):
    """Write the lower triangle of a square matrix in FCHK-style layout.

    Values are written row by row (``j <= i``), each as a space followed by
    ``'{0: .8E}'``, five values per line. When the last line holds fewer
    than five values, a final newline is added so the file always ends
    with one.

    Arguments:
        QMhess -- square matrix indexable as ``QMhess[i, j]``.
        filename -- output file (default ``"hess_fchk.dat"``).
    """
    count = 0
    with open(filename, 'w') as f:
        for i in range(len(QMhess)):
            for j in range(i + 1):
                f.write(" " + '{0: .8E}'.format(QMhess[i, j]))
                count += 1
                if count % 5 == 0:
                    f.write(" \n")
        if count % 5 != 0:
            f.write("\n")


def read_in_coordinates(filename="coordinates.dat"):
    """Read internal-coordinate definitions and split them by size.

    Each non-blank line of ``filename`` is evaluated to an iterable of atom
    index tuples. Tuples of length 2, 3 and 4 are collected as bonds, angles
    and dihedrals respectively; tuples of any other length are ignored.

    Arguments:
        filename -- file to read (default ``"coordinates.dat"``).

    Returns:
        ``(bonds, angles, dihed)`` lists.

    Exits the program with status 1 (message on stderr) when the file does
    not exist.
    """
    bonds = []
    angles = []
    dihed = []
    if not os.path.exists(filename):
        sys.stderr.write("Could not find {0} file\n".format(filename))
        sys.exit(1)  # non-zero: a missing input file is an error
    with open(filename, 'r') as f:
        for line in f:
            if not line.strip():
                continue  # eval("") would raise SyntaxError
            # NOTE(review): eval() executes arbitrary code from the file; only
            # use with trusted input.
            for el in eval(line):
                if len(el) == 2:
                    bonds.append(el)
                elif len(el) == 3:
                    angles.append(el)
                elif len(el) == 4:
                    dihed.append(el)
    return bonds, angles, dihed
                
def main():
    set_printoptions(threshold=nan)
    if (len(sys.argv)==1):
        print "usage: <prefix> -s <scaling factor> -x[use only extra dihedrals] -c[only use coordinates.dat] -d[use the diagonal terms of the Hessian matrix]"
        sys.exit(0)

    prefix = sys.argv[1]
    fchk = FCHKFile("{0}.fchk".format(prefix))
    energy = fchk.fields["Total Energy"]
    xyz = fchk.fields["Current cartesian coordinates"] # Read in coordinates in bohr.
    mass = fchk.fields["Real atomic weights"]
    QMhess = fchk.get_hessian()
    fn_xyz = "{0}.xyz".format(prefix)

    # [('--scaling', '0.9751'), ('-x', '')]
    opts, args = getopt.getopt(sys.argv[2:], 'xcdf',['scaling='])
    scaling = 0.8953 # HF value is default

    svd_para = 0.0001
    extra_only = 0
    coord_only = 0
    diag_only = 0
    freq_only = 0 
    for i, j in opts:
        if re.match("--scaling",i):
            scaling = float(j)
        if re.match("-x",i):
            extra_only = 1
        if re.match("-c",i):
            coord_only = 1
        if re.match("-d",i):
            diag_only = 1
        if re.match("-f",i):
            freq_only = 1
        if re.match("--svd",i):
            svd_para = float(j)
            
    mol0 = Molecule.from_file(fn_xyz)
    N = mol0.size

    am = amber(prefix)
    am.read_parm_file()
    bonds = am.build_bond_list()
    angles = am.build_angle_list()
    diheds = am.build_dihed_list()
    
    # Get the gradient and energy
    fchk = FCHKFile("{0}.fchk".format(prefix))
    energy = fchk.fields["Total Energy"]
    grad = fchk.fields["Cartesian Gradient"]

    MMhess = read_MMhess()
    write_hess_fchk(MMhess*0.529177249*0.529177249/627.509/scaling)
    write_QMhess(QMhess*scaling/0.529177249/0.529177249*627.509)
            
#    MM_none_hess = read_none_MMhess()
    extra_dihed = read_extra_dihed()

    hess = QMhess*scaling - MMhess*0.529177249*0.529177249/627.509

    if (freq_only==1):
        freq1 = freq.freq_hess(xyz,mass,QMhess*scaling)
        freq2 = freq.freq_hess(xyz,mass,MMhess*0.529177249*0.529177249/627.509)
        print freq1
        print freq2
        
#    hess = freq.ortho_vec(xyz,mass,hess)
    f = open("hessnorm.dat",'w')
    f.write(str(norm(hess)))
    f.close()

    damp_factor = 1

    # Scale the Zn radius for additional bonds.
    periodic[30].covalent_radius = periodic[30].covalent_radius * 1.2

    if (diag_only==1):
        # The diagonal terms
        ics = []
        for i, j in bonds:
            ics.append(BondLength(i, j))
        for i, j, k in angles:
            ics.append(BendingAngle(i, j, k))
        for i, j, k, l in diheds:
            ics.append(DihedralAngle(i, j, k, l))

        J = calc_jacobian(ics, mol0.coordinates)
        
        Jinv = linalg.pinv(J, svd_para)
        g_q = dot(Jinv,grad)
        K = compute_K(ics, mol0.coordinates, g_q)

        HK = hess - K
        
        hess_q = dot(Jinv,numpy.dot(HK,Jinv.transpose()))

        index = 0
        for i, j in bonds:
            vect_x = mol0.coordinates[i, 0] - mol0.coordinates[j, 0]
            vect_y = mol0.coordinates[i, 1] - mol0.coordinates[j, 1]
            vect_z = mol0.coordinates[i, 2] - mol0.coordinates[j, 2]
            vector = [vect_x, vect_y, vect_z]
            dist = numpy.sqrt(numpy.vdot(vector,vector))/angstrom

            k_q = hess_q[index,index]*(1.0/0.529177249)*(1.0/0.529177249)*627.509/2.0
            index = index + 1
            print '%3d   %3d : %8.3f %6.3f' % (i,j,k_q,dist)

        for i, j, k in angles:
            angle_val = bend_angle(mol0.coordinates[[i,j,k]])[0]/deg
            k_q = hess_q[index,index]*627.509/2.0
            index = index + 1
            print '%3d  %3d  %3d : %8.3f %8.3f' % (i,j,k,k_q,angle_val)

        sys.exit(0)
        
        for i, j, k, l in diheds:
            Dihedral_value = dihed_angle(mol0.coordinates[i], mol0.coordinates[j], mol0.coordinates[k], mol0.coordinates[l])[0]/deg
            if (Dihedral_value<-1):    Dihedral_value = Dihedral_value + 360
#            Dihedral_value = Dihedral_value - 180        # for PN = 1
#            if (abs(Dihedral_value)<1 or abs(Dihedral_value-180)<1. or abs(Dihedral_value+180)<1):
#                Dihedral_value=0
            k_q = hess_q[index,index]*627.509
            index = index + 1            
            print '%3d %3d %3d %3d: 1 %8.3f %8.3f 1.0' % (i,j,k,l,k_q,Dihedral_value)
    
    # A) Collect all bonds.
    for i, j in bonds:
        ics = []
        ics.append(BondLength(i, j))
        vect_x = mol0.coordinates[i, 0] - mol0.coordinates[j, 0]
        vect_y = mol0.coordinates[i, 1] - mol0.coordinates[j, 1]
        vect_z = mol0.coordinates[i, 2] - mol0.coordinates[j, 2]
        vector = [vect_x, vect_y, vect_z]
        dist = numpy.sqrt(numpy.vdot(vector,vector))/angstrom

        J = calc_jacobian(ics, mol0.coordinates)

        Jinv = numpy.linalg.pinv(J, svd_para)
        g_q = numpy.dot(Jinv,grad)
        K = compute_K(ics, mol0.coordinates, g_q)
        HK = hess - K
        hess_q = numpy.dot(Jinv,numpy.dot(HK,Jinv.transpose()))
        
        k_q = hess_q[0,0]*(1.0/0.529177249)*(1.0/0.529177249)*627.509/2.0/damp_factor
        print '%3d   %3d : %8.3f %6.3f' % (i,j,k_q,dist)

    # B) Collect all angles.
    for i, j, k in angles:
        ics = []
        ics.append(BendingAngle(i, j, k))
        angle_val = bend_angle(mol0.coordinates[[i,j,k]])[0]/deg

        J = calc_jacobian(ics, mol0.coordinates)
        Jinv = numpy.linalg.pinv(J, svd_para)
        g_q = numpy.dot(Jinv,grad)
        K = compute_K(ics, mol0.coordinates, g_q)        
        HK = hess - K
        hess_q = numpy.dot(Jinv,numpy.dot(HK,Jinv.transpose()))

        k_q = hess_q[0,0]*627.509/2.0/damp_factor
        print '%3d  %3d  %3d : %8.3f %8.3f' % (i,j,k,k_q,angle_val)

    sys.exit(0)
    
    # C) Collect all dihedrals.
    ics = []
    for i, j in bonds:
        ics.append(BondLength(i, j))
    for i, j, k in angles:
        ics.append(BendingAngle(i, j, k))
        
    nn = len(ics)

    # Check for extra dihedrals to add.
    if (extra_only==1):
        Dihedral = []
        for ii in range(len(extra_dihed)):
            dihed = extra_dihed[ii][0]
            Dihedral.append((dihed[0], dihed[1], dihed[2], dihed[3]))
    else:
        for ii in range(len(extra_dihed)):
            found=False
            for dihed in extra_dihed[ii]:
                for i, j, k, l in Dihedral:
                    if (i==dihed[0] and j==dihed[1] and k==dihed[2] and l==dihed[3]) or (l==dihed[0] and k==dihed[1] and j==dihed[2] and i==dihed[3]):
                        found=True
                        break
                    if found: break
                    if (not found):
                        dihed = extra_dihed[ii][0]
                        Dihedral.append((dihed[0], dihed[1], dihed[2], dihed[3]))

    for i, j, k, l in diheds:
        ics.append(DihedralAngle(i, j, k, l))

#    newxyz = check_dihedral(Dihedral,mol0)
            
    hess = QMhess*scaling  # - MMhess*0.529177249*0.529177249/627.509

    J = calc_jacobian(ics, mol0.coordinates)

    wilson = J.transpose()
    Jinv = numpy.linalg.pinv(J, 1e-10)
    wilsoninv = numpy.linalg.pinv(wilson, 1e-10)
    g_q = numpy.dot(Jinv,grad)
    hess_q = numpy.dot(Jinv,numpy.dot(hess,wilsoninv))
    hess_q = 0.5*(hess_q + hess_q.transpose())
    index = 0
#    print hess_q
    
    for i, j, k, l in diheds:
        Dihedral_value = dihed_angle(mol0.coordinates[i], mol0.coordinates[j], mol0.coordinates[k], mol0.coordinates[l])[0]/deg

#        k_q = hess_q[index+nn,index+nn]*627.509/damp_factor
        k_q = hess_q[nn+index,nn+index]*627.509

        if (Dihedral_value<0):    Dihedral_value = Dihedral_value + 360
        Dihedral_value = Dihedral_value - 180        # for PN = 1
        index = index + 1
        if (abs(Dihedral_value)<1 or abs(Dihedral_value-180)<1. or abs(Dihedral_value+180)<1):
            Dihedral_value=0

        print '%3d %3d %3d %3d: 1 %8.3f %8.3f 1.0' % (i,j,k,l,k_q,Dihedral_value)
    
# Run the command-line driver only when executed as a script, not on import.
if __name__ == "__main__":
    main()

 
                   
            



            
            
    
