import re
import sys,os
import math,copy
from numpy import *

class math_funs:
    """Numeric helpers for geometry measurement and force-constant fitting.

    Groups digit-counter increments, arg-max/arg-min selection, bond length /
    bond angle / dihedral measurement, and least-squares fits of harmonic and
    cosine-series (dihedral) constants.  Angles taken and returned are in
    degrees.
    """

    # Degrees per radian.  Exact value: the former truncated literal
    # (57.2957795) disagreed with the 180/pi that dihed_angle used.
    RadtoDeg = 180.0 / math.pi

    # Scan coordinates closer than this are treated as the same point when
    # pairing single-point (SP) and optimised (OPT) energies.
    SamePosTol = 0.00001

    def augment_num(self, num, base):
        """Increment the digit list ``num`` in place as a radix-``base`` counter.

        ``num[0]`` is the least-significant digit.  A digit that reaches
        ``base`` wraps to 0 and carries into the next one; if every digit
        wraps, the counter rolls over to all zeros.
        """
        for i in range(len(num)):
            num[i] += 1
            if num[i] == base:
                num[i] = 0  # carry into the next digit
            else:
                break

    def max_val(self, ar):
        """Return ``(m, idx)``: the maximum of ``ar`` and the index of its LAST occurrence.

        ``ar`` may be any finite iterable, including a one-shot iterator.
        Raises ValueError if ``ar`` is empty.
        """
        vals = list(ar)  # materialise: the values are scanned twice
        m = max(vals)
        return m, max(i for i, v in enumerate(vals) if v == m)

    def max_absvals(self, ar, num):
        """Return ``(indices, values)`` of the ``num`` entries of ``ar`` largest in magnitude.

        Values keep their original sign.  The lists are in replacement order,
        not sorted; a new entry displaces a kept one only if strictly larger
        in magnitude.  Returns two empty lists when ``num <= 0`` or ``ar`` is
        empty.
        """
        return self._top_n(ar, num, abs)

    def min_val(self, ar):
        """Return ``(m, idx)``: the minimum of ``ar`` and the index of its LAST occurrence.

        ``ar`` may be any finite iterable, including a one-shot iterator.
        Raises ValueError if ``ar`` is empty.
        """
        vals = list(ar)  # materialise: the values are scanned twice
        m = min(vals)
        return m, max(i for i, v in enumerate(vals) if v == m)

    def max_vals(self, ar, num):
        """Return ``(indices, values)`` of the ``num`` largest entries of ``ar``.

        The lists are in replacement order, not sorted; a new entry displaces
        a kept one only if strictly larger.  Returns two empty lists when
        ``num <= 0`` or ``ar`` is empty.
        """
        return self._top_n(ar, num, lambda v: v)

    def augment_array(self, ar, n):
        """Increment the digit list ``ar`` in place; each digit runs 0..``n`` inclusive.

        ``ar[0]`` is the least-significant digit.  Unlike augment_num, a digit
        wraps to 0 only when it already equals ``n`` (radix ``n + 1``).
        """
        for i in range(len(ar)):
            if ar[i] == n:
                ar[i] = 0  # carry into the next digit
            else:
                ar[i] += 1
                break

    def bond_length(self, ixyz, jxyz):
        """Return the Euclidean distance between the points ``ixyz`` and ``jxyz``."""
        return linalg.norm(array(jxyz) - array(ixyz))

    def bond_angle(self, ixyz, jxyz, kxyz):
        """Return the angle i-j-k in degrees, in [0, 180], with ``jxyz`` at the vertex."""
        x2 = array(jxyz)
        a = array(ixyz) - x2
        b = array(kxyz) - x2
        a = a / linalg.norm(a)
        b = b / linalg.norm(b)
        # clip keeps rounding from pushing the cosine outside arccos's domain.
        return arccos(clip(dot(a, b), -1, 1)) * self.RadtoDeg

    def dihed_angle(self, ixyz, jxyz, kxyz, lxyz):
        """Return the dihedral angle i-j-k-l in degrees, in [-180, 180].

        The bond vectors a = i - j and c = l - k are projected onto the plane
        perpendicular to the unit central-bond vector b = (k - j)/|k - j|, and
        the angle between the projections is returned.  The sign is negative
        when det([a, b, c]) > 0 and positive otherwise.
        """
        x2 = array(jxyz)
        x3 = array(kxyz)
        b = x3 - x2
        b = b / linalg.norm(b)
        # Remove the components along the central bond.
        a = array(ixyz) - x2
        a = a - dot(a, b) * b
        c = array(lxyz) - x3
        c = c - dot(c, b) * b
        a = a / linalg.norm(a)
        c = c / linalg.norm(c)

        sign = -1 if linalg.det([a, b, c]) > 0 else 1
        return sign * arccos(clip(dot(a, c), -1, 1)) * self.RadtoDeg

    def fit_constant(self, data, R_eq, angle=False, dweight=0.0):
        """Least-squares fit of ``k`` in the harmonic model ``E = k * (x - R_eq)**2``.

        Parameters
        ----------
        data : sequence of [x, E - E0] or [x, E - E0, dE/dx]
            Scan points.  The entries are not modified.
        R_eq : float
            Reference (equilibrium) value of the coordinate.
        angle : bool
            If True, ``x`` and ``R_eq`` are in degrees.  The displacement is
            converted to radians and dE/dx from per-degree to per-radian, so
            ``k`` is per radian squared.
        dweight : float
            Weight of the gradient rows (model ``dE/dx = 2 k dx``) relative to
            the energy rows.  Points without a gradient add only an energy row.

        Returns
        -------
        float
            The fitted constant ``k``.  Progress is printed to stdout.
        """
        dE = []
        dX = []
        print("R_eq:  %s" % (R_eq,))
        for dat in data:
            dx = dat[0] - R_eq
            # Work on a local copy: the caller's gradient must not be rescaled
            # in place (repeat calls would otherwise scale it again).
            grad = dat[2] if len(dat) == 3 else None
            if angle:
                dx /= self.RadtoDeg
                if grad is not None:
                    grad *= self.RadtoDeg
            dX.append(dx * dx)
            dE.append(dat[1])
            print("dE:  %s %s" % (dat[1], dx))
            if grad is not None:
                print("actual:  %s" % (grad,))
                dX.append(dweight * 2 * dx)
                dE.append(dweight * grad)
        # Single-parameter least squares: dE ~= k * dX.
        kk = dot(dE, linalg.pinv([dX]))
        return kk[0]

    def fit_dihed_constant(self, data, terms, dweight=0.0):
        """Least-squares fit of dihedral barrier heights plus a constant offset.

        Model, with x in degrees and R = RadtoDeg::

            E(x) = c0 + sum_t K_t / tt[1] * (1 + cos((tt[4] * x - tt[3]) / R))

        Only ``c0`` and the ``K_t`` are fitted.  Each term ``tt`` supplies a
        divisor tt[1], a phase tt[3] (degrees) and a multiplier tt[4].
        NOTE(review): presumably AMBER-style [.., IDIVF, PK, PHASE, PN];
        confirm against the caller.

        Parameters
        ----------
        data : sequence of [x, E - E0] or [x, E - E0, dE/dx]
            Scan points; dE/dx is per degree of x.
        terms : sequence of indexable dihedral terms (see above).
        dweight : float
            Weight of the gradient rows relative to the energy rows.  With the
            default 0.0, gradients do not influence the fit.

        Returns
        -------
        numpy.ndarray
            The fitted ``K_t``, one per term (``c0`` is dropped).

        Side effects: prints the fitted constants and overwrites ``data.dat``
        (x, E) and ``est.dat`` (x, fitted E) in the current directory.
        """
        dE = []
        dX = []
        energy_rows = []  # index into dX of each point's energy equation
        for dat in data:
            energy_rows.append(len(dX))
            dE.append(dat[1])
            dX.append([1] + [1.0 / tt[1] * (1 + cos(self._dihed_phase(dat[0], tt)))
                             for tt in terms])
            if len(dat) == 3:
                # Gradient rows: R * dE/dx = sum_t K_t / tt[1] * (-tt[4] * sin(u_t)).
                # The offset c0 has zero derivative.
                dE.append(dweight * dat[2] * self.RadtoDeg)
                dX.append([0] + [dweight * (1.0 / tt[1] *
                                            (-tt[4] * sin(self._dihed_phase(dat[0], tt))))
                                 for tt in terms])

        # Fit the constants.
        kk = dot(linalg.pinv(dX), dE)
        with open("data.dat", 'w') as fdat:
            with open("est.dat", 'w') as fest:
                # Use each point's own energy row; gradient rows may be interleaved.
                for dat, row in zip(data, energy_rows):
                    fdat.write("{0} {1}\n".format(dat[0], dat[1]))
                    fest.write("{0} {1}\n".format(dat[0], dot(dX[row], kk)))
        print("Dihedral constants:  %s" % (kk,))
        return kk[1:]

    def estimate_dihed_error(self, tofit_opt, tofit_SP, terms):
        """Estimate the SP - OPT energy difference at every SP scan point.

        SP and OPT points closer than ``SamePosTol`` are paired.  A single
        scale ``kk`` is then fitted in ``SP - OPT ~= kk * V(x)``, where
        ``V(x) = sum_t tt[2] / tt[1] * (1 + cos((tt[4] * x - tt[3]) / R))``.
        ``kk`` is 0.0 when no points pair up.

        Returns
        -------
        list
            One entry per point of ``tofit_SP``: the actual difference where
            an OPT point exists, ``kk * V(x)`` otherwise.
        """
        dE = []
        dX = []
        # Points are [x, E, ...].
        for ffSP in tofit_SP:
            for ffOPT in tofit_opt:
                if self._same_pos(ffSP[0], ffOPT[0]):
                    dE.append(ffSP[1] - ffOPT[1])
                    dX.append(self._dihed_energy(ffSP[0], terms))

        # Find the error constant.
        kk = self._scale_fit(dX, dE)
        print("dE: %s" % (dE,))
        print("dX:  %s" % (dX,))
        print("kk:  %s" % (kk,))

        return self._pointwise_error(tofit_opt, tofit_SP, kk,
                                     lambda x: self._dihed_energy(x, terms))

    def estimate_quad_error(self, tofit_opt, tofit_SP, R_eq):
        """Estimate the SP - OPT energy difference at every SP scan point.

        SP and OPT points closer than ``SamePosTol`` are paired.  A single
        constant ``k`` is then fitted in ``SP - OPT ~= k * (x - R_eq)**2``.
        ``k`` is 0.0 when no points pair up.

        Returns
        -------
        list
            One entry per point of ``tofit_SP``: the actual difference where
            an OPT point exists, ``k * (x - R_eq)**2`` otherwise.
        """
        dX = []
        dE = []
        for ffSP in tofit_SP:
            for ffOPT in tofit_opt:
                if self._same_pos(ffSP[0], ffOPT[0]):
                    dX.append((ffSP[0] - R_eq) ** 2)
                    dE.append(ffSP[1] - ffOPT[1])

        k = self._scale_fit(dX, dE)
        print("error constant:  %s" % (k,))

        return self._pointwise_error(tofit_opt, tofit_SP, k,
                                     lambda x: (x - R_eq) ** 2)

    # ---- private helpers -------------------------------------------------

    def _top_n(self, ar, num, key):
        """Keep the ``num`` entries of ``ar`` with the largest ``key(value)``; shared by max_vals/max_absvals."""
        maxels = []
        maxval = []
        if num <= 0:
            return maxels, maxval  # nothing to keep (the eviction threshold is never set)
        for i, v in enumerate(ar):
            if len(maxels) < num:
                maxels.append(i)
                maxval.append(v)
            elif key(v) > m:
                # Evict the currently smallest kept entry.
                maxels[el] = i
                maxval[el] = v
            else:
                continue
            # Re-locate the smallest kept entry (by key): it is evicted next.
            m, el = self.min_val([key(x) for x in maxval])
        return maxels, maxval

    def _same_pos(self, x, y):
        """Return True when two scan coordinates are within SamePosTol."""
        return abs(x - y) < self.SamePosTol

    def _find_pos(self, x, points):
        """Return the first entry of ``points`` whose ``[0]`` matches ``x`` within SamePosTol, else None."""
        for pt in points:
            if self._same_pos(x, pt[0]):
                return pt
        return None

    def _scale_fit(self, dX, dE):
        """Least-squares slope ``k`` of ``dE ~= k * dX`` (through the origin); 0.0 when ``dX`` is empty."""
        if len(dX) == 0:
            return 0.0
        return dot(dE, linalg.pinv([dX]))[0]

    def _pointwise_error(self, tofit_opt, tofit_SP, k, model):
        """Return SP - OPT where an OPT point matches each SP point, else ``k * model(x)``."""
        error = []
        for ffSP in tofit_SP:
            ffOPT = self._find_pos(ffSP[0], tofit_opt)
            if ffOPT is not None:
                error.append(ffSP[1] - ffOPT[1])
            else:
                error.append(k * model(ffSP[0]))
        return error

    def _dihed_phase(self, x, tt):
        """Return the cosine argument (tt[4] * x - tt[3]) converted from degrees to radians."""
        return (tt[4] * x - tt[3]) / self.RadtoDeg

    def _dihed_energy(self, x, terms):
        """Return sum_t tt[2] / tt[1] * (1 + cos(phase)) at ``x`` (true division even for int fields)."""
        total = 0
        for tt in terms:
            total += 1.0 * tt[2] / tt[1] * (1 + cos(self._dihed_phase(x, tt)))
        return total
    

