module skb
  use math_head
  use MatrixOps
  use NEB_toolbox
  use mvlsq

  IMPLICIT NONE

  PRIVATE
  public :: RKstep,quad_gE,odesolve,update_surface,Shepard,shepardtr,shepardtr2,updatetr_local,calc_shep_tr

  contains

    !> Integrate the steepest-descent flow of every image over the local model
    !> surface (see quad_gE) using adaptive Cash-Karp steps from RKstep.
    !>
    !> Integration stops when any of the following happens:
    !>   - any image moves farther than its trust radius trs from the starting path;
    !>   - every image has had its switch flag cleared by RKstep ("passed over minima");
    !>   - the largest model-force norm drops below ftol/10;
    !>   - maxiter steps have been taken.
    !>
    !>   path         images 0..nimages+1; overwritten with the integrated path
    !>   Hs,g,energy  per-image Hessians, gradients and energies at the starting
    !>                path, forwarded to quad_gE / RKstep
    !>   trs          per-image, per-coordinate trust radii
    !>   const        per-image force scale; rescaled on exit for the next call
    !>   dftot        on exit, the largest per-image model-force norm
    !>   ftol         force tolerance; integration counts as converged below ftol/10
    subroutine ODESolve(path,Hs,g,energy,nimages,ndim,trs, const, dftot, ftol)
      integer :: ndim, nimages, maxiter, iter, i
      real(DOUBLE), dimension(0:nimages+1,ndim) :: path, path0, path_new
      real(DOUBLE), dimension(nimages,ndim) :: g,force,gquad,trs
      real(DOUBLE), dimension(nimages,ndim,ndim) :: Hs
      real(DOUBLE), dimension(ndim) :: dx
      real(DOUBLE), dimension(nimages) :: energy, equad, moved, df, htot, const
      real(DOUBLE) :: h,error0,error1,hnew, dftol, dftot, ftol, htot_max
      logical, dimension(nimages) :: switch,switch_old
      logical :: done, allswitch
      switch = .true.

      h=0.001d0            ! initial h
      hnew = h
      error0 = 1.e-5       ! local error tolerance handed to RKstep
      maxiter = 2000       ! Maximum number of integration steps allowed
      path0 = path         ! save the initial path
      htot = 0             ! total time of integration
      dftol = ftol/10      ! assumed that L<100 where |x - y| < L|y(x) - y(y)|

      do iter=1,maxiter
         error1 = 1

         ! ### Move 1 RK step forward such that the error < error0 ###
         switch_old = switch

         call quad_gE(path,path0,Hs,g,energy,nimages,ndim,equad,gquad)
         call funupwind(path,gquad,equad,nimages,ndim,force)

         ! Cap the trial step so that, to first order, no image travels more
         ! than roughly 1/50 of its trust radius in a single step.
         do i=1,nimages
            if (all_less(trs(i,:)/10d0,hnew * abs(force(i,:))* const(i),ndim)) then
               hnew = norm(trs(i,:),ndim)/(50d0*sqrt(sum(force(i,:)**2))*const(i))
            end if
         end do

         ! Retry from the same state (and switch flags) until RKstep accepts the step.
         do while (error1 > error0)
            h = hnew
            switch = switch_old
            call RKstep(path,path0,Hs,g,energy,nimages,ndim,h,error0,error1,path_new,hnew,switch,const)
            if (h< 1e-20) then
               print*, "Step size has become too small h = ",h
               stop
            endif
         end do

         ! ### Check status of all points ###
         path = path_new
         call quad_gE(path,path0,Hs,g,energy,nimages,ndim,equad,gquad)
         call funupwind(path,gquad,equad,nimages,ndim,force)

         do i=1,nimages
            moved(i) = sqrt(sum((path0(i,:) - path(i,:))**2))
            df(i) = sqrt(sum(force(i,:)**2))
            if (switch(i)) then
               htot(i) = htot(i) + h
            endif
         end do

         ! ### Break out if a point moves beyond its tr, the calculation is
         ! ### converged, or all points moved past the valley
         done = .false.
         allswitch = .true.
         do i=1,nimages
            dx = path0(i,:) - path(i,:)
            if (all_less(trs(i,:),abs(dx),ndim)) then
               done = .true.
            end if
            allswitch = (allswitch .and. .not. switch(i))
         end do

         dftot = maxval(df)
         if (done .or. allswitch .or. (dftot < dftol)) then
            if (allswitch) print*, "All points passed over minima"
            if (done) print*, "Point moved passed TR"
            exit
         end if

         ! Periodic diagnostics (every 10th step); kept switched off by the
         ! trailing .false.  (mod(...)==10 could never be true.)
         if (mod(iter,10)==0 .and. iter < 50 .and. .false.) then
            print*, "------ ITER =",iter
            print*, "error ", error1
            print*, "h hnew: ",h,hnew
            print*, "total movement: ", moved
            print*, "switched: ", switch
            print*, "df: ", df
            print*, "dftot: ", dftot
         end if

      end do

      ! A DO loop that runs to completion leaves iter = maxiter + 1, so the
      ! limit test must be ">" (an "== maxiter" test never fires on exhaustion
      ! and fires spuriously when convergence happens on the last step).
      if (iter > maxiter) then
         print*, "Exceed number of integeration steps"
         iter = maxiter
      endif

      print*, "Finished at ITER =",iter
      print*, "total movement: ", moved

      ! ### Modify const for the next run ###
      ! htot_max can be zero when no image ever accumulated integration time;
      ! in that case htot(i)/htot_max would be 0/0, so use the 0.5 floor directly.
      htot_max = maxval(htot)
      do i=1,nimages
         dx = path0(i,:) - path(i,:)
         if (.not. switch(i)) then
            if (htot_max > 0d0) then
               const(i) = const(i) * max(0.5d0,htot(i)/htot_max)
            else
               const(i) = const(i) * 0.5d0
            end if
         else
            if (all_less(trs(i,:)/0.5,abs(dx),ndim)) then
               const(i) = const(i) * 1/(moved(i)/norm(trs(i,:),ndim))
            else
               const(i) = const(i) * 2
            end if
         end if
      end do

    end subroutine ODESolve

    ! #### Quadratic approximate gradient and energy ####
    !> Model energy equad(i) and gradient gquad(i,:) of every image at path(i,:).
    !> With order > 0 the Shepard surface is evaluated. Otherwise a second-order
    !> Taylor expansion about epath(i,:) is used:
    !>   gquad = glast + H s
    !>   equad = energy + glast.s + (s.H s)/2
    !> where s = path(i,:) - epath(i,:) and H = Hs(i,:,:).
    subroutine quad_gE(path,epath,Hs,glast,energy,nimages,ndim,equad,gquad)
      integer :: nimages, ndim, img
      real(DOUBLE), dimension(0:nimages+1,ndim) :: path,epath
      real(DOUBLE), dimension(nimages,ndim,ndim) :: Hs
      real(DOUBLE), dimension(nimages,ndim) :: glast,gquad
      real(DOUBLE), dimension(nimages) :: energy,equad
      real(DOUBLE), dimension(ndim) :: step, hstep
      real(DOUBLE), dimension(ndim,ndim) :: hess
      real(DOUBLE) :: shep_rms

      do img = 1, nimages
         if (order > 0) then
            ! Shepard surface; its error estimate is not needed by callers here.
            call Shepard(path(img,:), equad(img), gquad(img,:), shep_rms)
            cycle
         end if

         step  = path(img,:) - epath(img,:)
         hess  = Hs(img,:,:)
         hstep = matmul(hess, step)
         gquad(img,:) = glast(img,:) + hstep
         equad(img) = energy(img) + dot_product(glast(img,:), step) + 0.5d0 * dot_product(step, hstep)
      end do
    end subroutine quad_gE

    !> Recompute the per-coordinate Shepard radii shep_tr(k,1:ndim) for every
    !> stored point k = 1..nroot.
    !>
    !> A neighbour i qualifies if |dx| > 1e-5 and it passes all_less against
    !> 10*tr0. Each qualifying neighbour contributes
    !>   ((grad_i(j) - grad_k(j)) * dx(j) / (5*eps_shep))**2 / |dx|**(2*(order+1))
    !> to component j. The radius is then (mean contribution)**(-1/(order+2)).
    !> The program stops if a point has no qualifying neighbours.
    subroutine calc_shep_tr
      INTEGER   :: i, j, index
      REAL(DOUBLE) :: R(ndim), dx(ndim), dx2
      REAL(DOUBLE) :: M

      shep_tr = 0d0

      do index=1,nroot
         R = root(index)%p%R
         ! Number of qualifying neighbours of this point. It is counted once per
         ! point; the old code kept counting across the coordinate loop, so
         ! component j was averaged over j times the real neighbour count.
         M = 0
         DO i = 1,nroot
            if (index==i) cycle
            dX = root(i)%p%R - R
            if (norm(dx,ndim) > 0.00001 .and. all_less(abs(dx),10d0*tr0,ndim)) then
               M = M + 1
               dx2 = dot_product(dx,dx)
               do j=1,ndim
                  shep_tr(index,j) = shep_tr(index,j) + &
                       ((root(i)%p%grad(j)-root(index)%p%grad(j))*dX(j)/(5*eps_shep))**2/dx2**(order+1)
               end do
            end if
         end do
         if (M==0) then
            print*, "No points for index: ", index
            stop
         end if
         shep_tr(index,1:ndim) = (shep_tr(index,1:ndim)/M)**(-1d0/(order + 2))
      end do

    end subroutine calc_shep_tr

    !> Copy the Shepard radii of the newest nimages stored points
    !> (rows nroot-nimages+1 .. nroot of shep_tr) into trs, printing each norm.
    subroutine Shepardtr(trs)
      REAL(DOUBLE) :: trs(nimages,ndim)
      INTEGER :: img, first

      first = nroot - nimages
      do img = 1, nimages
         trs(img,:) = shep_tr(first + img,:)
         print*, "|tr|: ", norm(trs(img,:),ndim)
      end do

    end subroutine Shepardtr


    !> Diagnostic: estimate trust radii for the newest nimages stored points.
    !> The estimate measures how well FourthOrder about each point reproduces
    !> the energies (TT2(0)) and gradient components (TT2(1:ndim)) of its
    !> neighbours. The results are printed and the program then STOPs.
    !> NOTE(review): the unconditional stop makes this routine debug-only.
    subroutine Shepardtr2(trs)
      INTEGER   :: i, j, index, base
      REAL(DOUBLE) :: R(ndim), dx(ndim), gTotal(ndim), TT2(0:ndim)
      REAL(DOUBLE) :: TT, M, denom, expo, trs(nimages,ndim)

      base = nroot-nimages
      ! tr solves (TT2/M) * tr**(2*(order+1)) = 1. The previous expression
      ! -1d0/2d0*(order+1) parsed as -(order+1)/2 instead of -1/(2*(order+1)).
      expo = -1d0/(2d0*(order+1))

      do index=1,nimages
         R  = root(index+base)%p%R
         M = 0d0
         TT2(0:ndim) = 0d0
         DO i = 1,nroot
            if (base+index==i) cycle
            dX = root(i)%p%R - R
            if (norm(dx,ndim) > 0.00001 .and. all_less(abs(dx),10d0*tr0,ndim)) then
               M = M + 1
               call FourthOrder(index+base,root(i)%p%R,TT,gTotal)
               denom = eps_shep**2 * norm(dX,ndim)**(2*(order+1))
               TT2(0) = TT2(0) + (TT - root(i)%p%V)**2/denom
               print*, i, norm(dX,ndim), denom
               do j=1,ndim
                  TT2(j) = TT2(j) + ((root(i)%p%grad(j)-gTotal(j))*dX(j))**2/denom
               end do
            end if
         ENDDO

         print*, "M: ", M,order

         ! With no neighbours the mean is undefined (1/0); leave trs untouched.
         if (M == 0d0) then
            print*, "No points for index: ", index+base
            cycle
         end if

         do j=1,ndim
            trs(index,j) = ( 1d0/M * TT2(j) )**expo
         end do
         print*, "|tr|: ", ( 1d0/M * TT2(0) )**expo
         print*, "|tr|: ", norm(trs(index,:),ndim)
      end do

      stop
    end subroutine Shepardtr2

    !> Adapt each image's trust radii trs(i,:) from the agreement ratio
    !>   rho = (actual energy change) / (Shepard-predicted energy change)
    !> between the previous point root(nroot-2*nimages+i) and the newest point
    !> root(nroot-nimages+i). The newest nimages points are hidden from the
    !> Shepard surface while predicting and are restored afterwards. Finally,
    !> radii are clamped to mintr / maxtr based on their first component.
    subroutine updatetr_local(trs)
      REAL(DOUBLE) :: trs(nimages,ndim)
      REAL(DOUBLE) :: e_model, ratio, model_rms, step(ndim), g_model(ndim)
      integer :: img, n_old, cur, prev

      n_old = nroot - nimages
      nroot = n_old          ! hide the newest points from Shepard

      do img = 1, nimages
         cur  = n_old + img
         prev = cur - nimages
         step = root(cur)%p%R - root(prev)%p%R
         call Shepard(root(cur)%p%R, e_model, g_model, model_rms, .true.)
         ratio = (root(cur)%p%V - root(prev)%p%V)/(e_model - root(prev)%p%V)

         ! Only adapt when the step is non-trivial both in energy and in geometry.
         if (abs(root(cur)%p%V - root(prev)%p%V) > 1e-9 .and. norm(step,ndim) > 1e-6) then
            if (ratio < 0.2 .or. ratio > 5) then
               ! Poor prediction: shrink to a quarter of the step just taken.
               trs(img,:) = 0.25d0*abs(step)
            else if (ratio > 0.9d0 .and. ratio < 1.5d0 .and. all_less(trs(img,:),abs(step)*1.4d0,ndim)) then
               trs(img,:) = 2d0*trs(img,:)
            else if (ratio < 0.85d0 .and. all_less(abs(step),trs(img,:)*0.8d0,ndim)) then
               trs(img,:) = trs(img,:)/2d0
            else if (ratio < 0.95d0 .and. all_less(abs(step),trs(img,:)*0.5d0,ndim)) then
               trs(img,:) = trs(img,:)/2d0
            end if
         end if

         if (trs(img,1) < mintr(1)) trs(img,:) = mintr
         if (trs(img,1) > maxtr(1)) trs(img,:) = maxtr

         print*, "rho: ", img, norm(trs(img,:),ndim), ratio
      end do

      nroot = nroot + nimages   ! restore the hidden points
    end subroutine updatetr_local

    !> Shepard interpolation of the stored points root(1:nroot) at R.
    !>
    !>   Vp, grad  interpolated energy and gradient at R
    !>   RMS       weighted mean squared deviation of the individual expansions
    !>             from Vp; multiplied by 10000 when the effective number of
    !>             contributing points, 1/sum(w**2), is <= 1.2
    !>   debug     accepted for interface compatibility; not used
    !>
    !> If R lies within 1e-5 of a stored point, that point's V and grad are
    !> returned with RMS = 0. If no stored point carries weight, or the result
    !> is NaN, Vp, grad and RMS are set to the sentinel 1d9.
    Subroutine Shepard(R,Vp,grad,RMS,debug)
      use, intrinsic :: ieee_arithmetic, only: ieee_is_nan
      INTEGER   :: i, k
      REAL(DOUBLE) :: W(nroot), Wv(nroot)
      REAL(DOUBLE) :: R(ndim), dx(ndim)
      REAL(DOUBLE) :: SumW, SumNor, no_eff
      REAL(DOUBLE) :: T(1:nroot), Ti(1:nroot,ndim), Vp, RMS
      REAL(DOUBLE) :: grad(ndim)
      logical,optional :: debug
      grad = 0d0

      ! Each stored point's expansion evaluated at R (FourthOrder presumably
      ! builds that point's local Taylor model - confirm). Exact hits short-circuit.
      DO i = 1,nroot
         dX = R - root(i)%p%R
         if (norm(dx,ndim)<0.00001) then
            Vp = root(i)%p%V; grad = root(i)%p%grad;            RMS = 0d0
            return
         end if
         call FourthOrder(i,R,T(i),Ti(i,1:ndim))
      ENDDO

      ! Unnormalised weights. Only points passing all_less against 10*tr0
      ! contribute; the scaled distance uses the per-point radii shep_tr.
      SumW = 0.0d0
      W = 0.0d0
      Wv = 0d0
      DO i = 1, nroot
         SumNor = 0d0
         dX = R - root(i)%p%R
         if (all_less(abs(dx),10d0*tr0,ndim)) then
            DO k = 1, ndim
               SumNor = SumNor + (dX(k)/shep_tr(i,k))**2
            ENDDO
            Wv(i) = EXP(-0.5d0*SumNor)/SumNor
            SumW = SumW + Wv(i)
         end if
      ENDDO

      ! Nothing in range (or every weight underflowed / is NaN): normalising
      ! would be 0/0, so return the sentinel explicitly.
      if (.not. (SumW > 0d0)) then
         RMS=1d9
         Vp=1d9
         grad=1d9
         print*, "Moved too far NAN"
         return
      end if

      DO i = 1, nroot
         w(i) = Wv(i)/SumW
      ENDDO

      Vp = DOT_PRODUCT(W,T)

      do i = 1,nroot
         grad = grad + Ti(i,:)*w(i)
      end do

      ! Error estimation: weighted spread of the individual expansions.
      RMS = 0d0
      DO i = 1, nroot
         RMS = RMS + W(i)*(Vp-T(i))**2
      ENDDO

      no_eff = (DOT_PRODUCT(W,W))**(-1.0d0)
      if (ieee_is_nan(Vp)) then
         RMS=1d9
         Vp=1d9
         grad=1d9
         print*, "Moved too far NAN"
      end if

      IF(no_eff <= 1.2) RMS = 10000*RMS
    end Subroutine Shepard

    ! #### Runge-Kutta explicit 4/5 integration scheme with Cash-Karp parameters -same as mathematica ####
    !> One embedded Cash-Karp 4(5) step of length h for the scaled flow
    !>   dx_i/dt = -const(i) * force_i(x)
    !> where force comes from funupwind applied to the quad_gE model about epath.
    !>
    !>   x5      5th-order solution. Images 0 and nimages+1 are unchanged, because
    !>           their stage increments stay zero.
    !>   error1  max |x5 - x4| over images 1..nimages; huge() if a NaN appears
    !>   hnew    suggested next step size (0.1*h after a NaN step)
    !>   switch  cleared for image j in either of two cases:
    !>             - its model force at x5 has a negative projection on its
    !>               force at epath;
    !>             - any stage increment has a positive projection on that force.
    !>           Left untouched after a NaN step.
    subroutine RKstep(path,epath,Hs,glast,energy,nimages,ndim,h,error0,error1,x5,hnew,switch,const)
      use, intrinsic :: ieee_arithmetic, only: ieee_is_nan

      integer :: ndim, nimages, i, j
      real(DOUBLE) :: error0, h,error1,hnew
      real(DOUBLE), dimension(nimages,ndim,ndim) :: Hs
      real(DOUBLE), dimension(0:nimages+1,ndim) :: path,x5,x4,epath,path1
      real(DOUBLE), dimension(nimages,ndim) :: glast,force,gquad,force1
      real(DOUBLE), dimension(nimages) :: energy,equad,const
      real(DOUBLE), dimension(6,0:nimages+1,ndim) :: ks
      logical, dimension(nimages) :: switch

      if (h==0) then
         x5 = path;         error1 = 0
         return
      end if

      ! Stage increments. Rows 0 and nimages+1 are never written, so the end
      ! images do not move.
      ks = 0

      ! ### Cash-Karp 4/5 order stages ###
      path1 = path
      call eval_stage(1)
      path1 = path + 0.2d0*ks(1,:,:)
      call eval_stage(2)
      path1 = path + 0.075d0*ks(1,:,:) + 0.225d0*ks(2,:,:)
      call eval_stage(3)
      path1 = path + 3d0/10d0*ks(1,:,:) - 9d0/10d0*ks(2,:,:) + 6d0/5d0*ks(3,:,:)
      call eval_stage(4)
      path1 = path - 11d0/54d0*ks(1,:,:) + 5d0/2d0*ks(2,:,:) - 70d0/27d0*ks(3,:,:) + 35d0/27d0*ks(4,:,:)
      call eval_stage(5)
      path1 = path + 1631d0/55296d0*ks(1,:,:) + 175d0/512d0*ks(2,:,:) + 575d0/13824d0*ks(3,:,:) + &
           44275d0/110592d0*ks(4,:,:) + 253d0/4096d0*ks(5,:,:)
      call eval_stage(6)

      x5 = path + ks(1,:,:)*37d0/378d0 + ks(3,:,:)*250d0/621d0 + ks(4,:,:)*125d0/594d0 + &
           ks(6,:,:) * 512d0/1771d0
      x4 = path + ks(1,:,:)*2825d0/27648d0 + ks(3,:,:)*18575d0/48384d0 + ks(4,:,:)*13525d0/55296d0 + &
           ks(5,:,:)*277d0/14336d0 + ks(6,:,:)*1d0/4d0

      ! ### Get error and adjust step size ###
      ! A NaN never compares greater than anything, so a plain max-search would
      ! report zero error and accept (and double) a corrupted step. Reject it
      ! instead and ask the caller to retry with a much smaller h.
      if (any(ieee_is_nan(x5(1:nimages,:))) .or. any(ieee_is_nan(x4(1:nimages,:)))) then
         error1 = huge(error1)
         hnew = 0.1d0*h
         return
      end if

      ! max(0,...) keeps error1 = 0 for an empty image set (maxval of an empty array is -huge).
      error1 = max(0d0, maxval(abs(x5(1:nimages,:) - x4(1:nimages,:))))

      if (error1 == 0) then
         hnew = 2*h
      else
         hnew = 0.9*h*abs(error0/error1)**(0.2)
      end if

      ! ### Check to see if moved over valley ###
      call quad_gE(x5,epath,Hs,glast,energy,nimages,ndim,equad,gquad)
      call funupwind(x5,gquad,equad,nimages,ndim,force)
      call funupwind(epath,glast,energy,nimages,ndim,force1)

      do j=1,nimages
         switch(j) = switch(j) .and. (dot_product(force(j,:),force1(j,:)) .ge. 0)
      end do

      do j=1,nimages
         do i=1,6
            switch(j) = switch(j) .and. (dot_product(ks(i,j,:),force1(j,:)) .le. 0)
         end do
      end do

    contains

      !> Evaluate the const-scaled model force at the host's path1 and store
      !> -h*force as stage increment k for images 1..nimages.
      subroutine eval_stage(k)
        integer, intent(in) :: k
        integer :: m

        call quad_gE(path1,epath,Hs,glast,energy,nimages,ndim,equad,gquad)
        call funupwind(path1,gquad,equad,nimages,ndim,force)
        do m = 1, nimages
           force(m,:) = force(m,:) * const(m)
        end do
        ks(k,1:nimages,:) = -h * force
      end subroutine eval_stage

    end subroutine RKstep

    !> Call Deriv4 on every stored Shepard point 1..nroot (presumably to
    !> refresh each point's local derivative data - confirm in Deriv4).
    subroutine update_surface
      integer :: ipt, npts

      npts = nroot          ! same trip count a DO over 1..nroot fixes at entry
      do ipt = 1, npts
         call Deriv4(ipt)
      end do

    end subroutine update_surface

end module skb
