module Master

  use math_head
  use NEB_toolbox
  use Funcs
  use MatrixOps
  use skb

  implicit none

  PRIVATE
  public :: Master__run_skb, run_path, terminate_nodes
contains

  ! Evaluate energy and gradient for the last `images` entries of `root`
  ! (indices nroot-images+1 .. nroot) and save each of them to disk.
  !
  ! Work distribution, as implemented below:
  !   * image i (1-based within the batch) with mod(i,size) /= 0 has its
  !     position sent to rank mod(i,size) with tag 1;
  !   * image i with mod(i,size) == 0 is evaluated on the calling rank, which
  !     then collects the results of the preceding size-1 images from ranks
  !     1 .. size-1;
  !   * after the loop, the mod(images,size) images that did not complete a
  !     full round are collected from ranks 1 .. mod(images,size).
  !
  ! Arguments:
  !   images - number of trailing entries of `root` to evaluate
  !   size   - modulus used for the round-robin above; presumably the number
  !            of MPI ranks including the caller (TODO confirm with caller)
  !
  ! If the g03 run reports failure, or returns a gradient whose norm is
  ! exactly zero, the whole MPI job is aborted.  A bare `stop` on this rank
  ! would leave the other ranks without any shutdown message.
  subroutine run_path(images,size)

    use fileio

    include "mpif.h"

    integer, intent(in) :: images,size
    integer :: i,j,n,index,rem,mpierr
    logical :: failed

    do i=1,images
       index = i + nroot-images
       rem = mod(i,size)

       if (rem == 0) then
          ! This image is evaluated locally.
          if (natoms==3) then
             ! Three-atom case: energy from MB, gradient from dMB.
             root(index)%p%V = MB(root(index)%p%pos)
             call dMB(root(index)%p%pos,root(index)%p%grad)
          else
             call run_g03(root(index)%p%pos,root(index)%p%V,root(index)%p%grad,failed,0)
             if (failed.or.norm(root(index)%p%grad,ndim)==0d0 ) then
                print*, "g03 calculation failed, exiting..."
                ! Bring down every rank, not just this one.
                call MPI_ABORT(MPI_COMM_WORLD,1,mpierr)
                stop
             end if
          end if

          ! Collect the size-1 images that were handed out just before this one.
          do j=1,size-1
             call Master__recv_image(index-size+j,j)
          end do
       else
          ! Hand the position of this image to rank `rem`.
          call MPI_SEND(root(index)%p%pos,fulldim, MPI_DOUBLE_PRECISION, rem,1, MPI_COMM_WORLD,mpierr)
       end if
    end do

    ! Collect the leftover images of the last, incomplete round.
    n = mod(images,size)
    do j=1,n
       call Master__recv_image(nroot-n+j,j)
    end do

    ! Entire string saved at once.
    do i=nroot-images+1,nroot
       call save_node_to_disk(root(i)%p)
    end do

  end subroutine run_path

  ! Receive the result for root(node) from MPI rank `source`.  Three messages
  ! are expected, in this order: the energy V (1 value), the gradient (ndim
  ! values) and the position (fulldim values).  Any tag is accepted.
  subroutine Master__recv_image(node,source)

    include "mpif.h"

    integer, intent(in) :: node,source
    integer :: status(MPI_STATUS_SIZE),mpierr

    call MPI_RECV(root(node)%p%V,1, MPI_DOUBLE_PRECISION, source, &
         MPI_ANY_TAG, MPI_COMM_WORLD,status,mpierr)
    call MPI_RECV(root(node)%p%grad,ndim, MPI_DOUBLE_PRECISION, source, &
         MPI_ANY_TAG, MPI_COMM_WORLD,status,mpierr)
    call MPI_RECV(root(node)%p%pos,fulldim, MPI_DOUBLE_PRECISION, source, &
         MPI_ANY_TAG, MPI_COMM_WORLD,status,mpierr)

  end subroutine Master__recv_image

  ! Send an all-zero buffer of fulldim double-precision values, with tag 1,
  ! to every rank 1 .. size-1.
  ! NOTE(review): judging by the routine name, the receiving ranks presumably
  ! treat an all-zero position as the signal to shut down - confirm against
  ! the worker loop.
  subroutine terminate_nodes(size)
    include "mpif.h"

    integer :: size
    integer :: dest,ierr
    real(DOUBLE),dimension(fulldim) :: zeros

    ! Whole-array assignment clears every element.
    zeros = 0d0

    do dest=1,size-1
       call MPI_SEND(zeros,fulldim, MPI_DOUBLE_PRECISION, dest,1, MPI_COMM_WORLD,ierr)
    end do

  end subroutine terminate_nodes

  ! Perform one outer step of the path optimisation.
  !
  ! Sequence implemented below:
  !   1. report the current spacing of the path;
  !   2. initialise (first call) or update (later calls) the per-image Hessian
  !      approximations;
  !   3. compute the force with funupwind and return early if its largest
  !      per-image norm is below ftol;
  !   4. store the current path/gradient/energy in oldpath/glast/lastenergy
  !      and integrate the path with ODESolve (at most 4 attempts);
  !   5. report predicted energies/forces from quad_gE and, if the spacing
  !      measure exceeds a threshold, redistribute the images with
  !      spaceoutcubic.
  !
  ! Arguments (shapes as declared below):
  !   path       (0:nimages+1,ndim) path including both end points; modified
  !   Efull      (0:nimages+1)      energies at all path points
  !   grad       (0:nimages+1,ndim) gradients at all path points
  !   glast      (nimages,ndim)     gradients of the previous step; read on
  !                                 non-first calls, overwritten in step 4
  !   oldpath    (0:nimages+1,ndim) path of the previous step; same usage
  !   lastenergy (nimages)          energies of the previous step; same usage
  !   Hs         (nimages,ndim,ndim) Hessian approximations (damped BFGS)
  !   hssr1      (nimages,ndim,ndim) Hessian approximations updated with SR1
  !                                 (only on non-first calls)
  !   trs        (nimages,ndim)     passed to updateTR and ODESolve; presumably
  !                                 per-image trust regions (TODO confirm)
  !   ftol       convergence threshold on the largest force norm
  !   first_time .true. on the first call; set to .false. when the step runs
  !              to completion (left unchanged on the early converged return,
  !              because oldpath/glast/lastenergy have not been stored then)
  !   max_dfval  on return, the largest per-image force norm found in step 3
  !
  ! The module variable `order` is temporarily set to 0 during a first call
  ! and restored on every exit path.
  subroutine Master__run_skb(path,Efull,grad,glast,oldpath,lastenergy,Hs,hssr1,trs,  &
       ftol,first_time, max_dfval)
    integer :: i,iter

    real(DOUBLE), dimension(nimages) :: energy,equad,lastenergy,const,dfvals,aligned
    real(DOUBLE) :: dftol,spaceout_dist,ftol,dftot,max_dfval,rho,spacing,old_order
    real(DOUBLE),dimension(0:nimages+1,ndim) :: path,oldpath,grad
    real(DOUBLE),dimension(0:nimages+1) :: Efull
    real(DOUBLE),dimension(nimages,ndim) :: g,glast,force,gquad,trs
    real(DOUBLE),dimension(nimages,ndim,ndim) :: Hs,HsSR1
    real(DOUBLE),dimension(ndim) :: dx

    logical :: first_time
    
    call eqconst(path,spacing)
    print*, "Points spaced out?: ", spacing

    ! Threshold for re-spacing: end-to-end distance divided by 5*nimages.
    spaceout_dist = norm(path(0,:)-path(nimages+1,:),ndim)/(5*nimages)
    print*, "space out distance: ", spaceout_dist

    ! Work on the interior images only.
    g = grad(1:nimages,:)
    energy = Efull(1:nimages)

    ! #### Set up Hessians ####
    if (first_time) then
       do i=1,nimages
          call make_I(Hs(i,:,:),ndim)
!          Hs(i,:,:) = 1.0/norm(g(i,:))*Hs(i,:,:)

          ! Scale the matrix set up by make_I by |y.y / y.s|, where y is the
          ! gradient difference and s the position difference to one
          ! neighbouring image.  The previous image is used when the next
          ! image is lower in energy, otherwise the next image is used.
          if (Efull(i+1) < Efull(i)) then
             Hs(i,:,:) = Hs(i,:,:) * abs(dot_product(grad(i,:)-grad(i-1,:),grad(i,:)-grad(i-1,:)) &
                  /dot_product(grad(i,:)-grad(i-1,:),path(i,:)-path(i-1,:)))
          else
             Hs(i,:,:) = Hs(i,:,:) * abs( dot_product(grad(i,:)-grad(i+1,:),grad(i,:)-grad(i+1,:)) &
                  /dot_product(grad(i,:)-grad(i+1,:),path(i,:) -path(i+1,:)))
          end if
       end do
       
       ! Feed the differences to both neighbours into the damped BFGS update.
       do i=1,nimages
          dx = path(i,:) - path(i-1,:)
          call DampedBFGSupdate(Hs(i,:,:), dx,grad(i,:)-grad(i-1,:), ndim)
          dx = path(i,:) - path(i+1,:)
          call DampedBFGSupdate(Hs(i,:,:), dx,grad(i,:)-grad(i+1,:), ndim)
       end do

       ! Run the first step with order = 0; the saved value is put back
       ! before leaving the routine.
       old_order = order
       order = 0d0
       max_dfval = 1d9
    else 
       print*, "sum E: ", sum(energy)
       print*, "Change in E: ", energy - lastenergy

       do i=1,nimages
          if (order<1) then
             call updateTR(Hs(i,:,:), glast(i,:), oldpath(i,:), path(i,:), energy(i), lastenergy(i), &
                  trs(i,:),mintr,maxtr, ndim,rho)
          end if
          call DampedBFGSupdate(Hs(i,:,:), path(i,:) - oldpath(i,:),g(i,:)-glast(i,:), ndim)
          
          call SR1update(HsSR1(i,:,:),g(i,:)-glast(i,:), path(i,:) - oldpath(i,:),ndim)
       end do
    end if

!!$    do i=1,nimages
!!$       call ddMB(path(i,:),Hs(i,:,:))
!!$    end do

    ! ### Check if calculation has converged ###
    call funupwind(path,g,energy,nimages,ndim,force)
    do i=1,nimages
       dfvals(i) = sqrt(sum(force(i,:)**2))
       print*, "tr:", i, norm(trs(i,:),ndim)
    end do
    max_dfval = maxval(dfvals)    
    print*, "Max dfval: ", max_dfval
    print*, "Initial dfval: ", dfvals
    
    if (maxval(dfvals)<ftol) then
       print*, "Calculation converged to dftol =",ftol
       ! Do not leak the temporary order = 0 of a first call to the caller.
       if (first_time) order = old_order
       return
    end if

    ! ### Print out alignment of path ###
    call alignment(path,aligned) 
    print*, "aligned? ", aligned

    ! ### Integrate to TRs or until finished ###
    ! Remember the state this step starts from.
    oldpath = path
    glast = g

    lastenergy = energy

    const = 1

    ! Up to four integration attempts; every retry restarts from the saved
    ! path.  dftol is not assigned in this routine, so it is presumably
    ! returned by ODESolve (TODO confirm).
    do iter=1,4
       print*, "const: ", iter, const
       if (iter>1) path = oldpath
       call ODESolve(path,Hs,g,energy,nimages,ndim,trs,const,dftol,ftol)
!!$       do i=1,nimages
!!$          print*, i, norm(trs(i,:),ndim),norm(path(i,:) - oldpath(i,:),ndim)
!!$       end do

       ! Rescale so that the largest entry of const is 1.
       const = const /maxval(const)
       if (dftol < ftol/10) then
          exit
       end if
    end do

    print*, "done integration. Expected dftol: ",dftol

    ! Predicted energies/gradients at the new path and the resulting force.
    call quad_gE(path,oldpath,Hs,g,energy,nimages,ndim,equad,gquad)
    call funupwind(path,gquad,equad,nimages,ndim,force)
    print*, "Expected change in E before potential spaceout: ", equad - lastenergy
    
    dftot = 0d0
    do i=1,nimages
       dftot = max(dftot,sqrt(sum(force(i,:)**2)))
       print*, "after dftot: ", i, sqrt(sum(force(i,:)**2))
    end do

    call eqconst(path,spacing)
    print*, "Spacing after ODEsolve: ", spacing
    do i=1,nimages
       print*, "||p_i-p_i+1||:", i, sqrt(sum((path(i,:)-path(i-1,:))**2)),sqrt(sum((path(i,:)-path(i+1,:))**2))
    end do

    ! ### Space out if necessary ###
    if (spacing > spaceout_dist) then
       ! spaceoutcubic is applied four times in a row.
       do i=1,4
          call spaceoutcubic(path,nimages+2)
       end do
       call eqconst(path,spacing)
       print*, "Spacing after cubic space out: ", spacing
       call quad_gE(path,oldpath,Hs,g,energy,nimages,ndim,equad,gquad)
       call funupwind(path,gquad,equad,nimages,ndim,force)
       print*, "Expected change in E after spaceout: ", equad - lastenergy
          
       dftot = 0d0
       do i=1,nimages
          dftot = max(dftot,sqrt(sum(force(i,:)**2)))
       end do
       print*, "dftot after spaceout: ", dftot
    end if

    ! Put back the order that was overridden for the first step.
    if (first_time) order = old_order
    first_time=.FALSE.

  end subroutine Master__run_skb

end module Master
