module MatrixOps
  ! Small dense linear-algebra helpers plus quasi-Newton (SR1, damped BFGS)
  ! and trust-region utilities. Array dummies are explicit-shape and sized by
  ! an integer argument; the real kind DOUBLE is taken from math_head.
  use math_head
  implicit none

  contains

    ! Euclidean (2-)norm of a(1:n).
    pure real(DOUBLE) function norm(a,n)
      integer, intent(in) :: n
      real(DOUBLE), dimension(n), intent(in) :: a

      norm = sqrt(sum(a**2))
    end function norm

    ! Scale vect(1:n) in place to unit Euclidean length.
    ! A zero vector is left unchanged: dividing by its zero norm would
    ! otherwise fill it with NaN.
    pure subroutine normalize(vect,n)
      integer, intent(in) :: n
      real(DOUBLE), dimension(n), intent(inout) :: vect
      real(DOUBLE) :: length

      length = norm(vect,n)
      if (length > 0d0) vect = vect/length
    end subroutine normalize

    ! Outer product out = a * transpose(b) of two length-n column vectors.
    ! a and b are declared (n,1); rank-1 actual arguments of length n are
    ! accepted through sequence association (this module calls it that way).
    pure subroutine outer(a,b,out,n)
      integer, intent(in) :: n
      real(DOUBLE), dimension(n,1), intent(in) :: a, b
      real(DOUBLE), dimension(n,n), intent(out) :: out
      integer :: j

      ! Fill one contiguous column at a time: out(:,j) = a * b(j).
      do j = 1, n
         out(:,j) = a(:,1)*b(j,1)
      end do
    end subroutine outer

    ! Set matrix(1:n,1:n) to the n-by-n identity.
    pure subroutine make_I(matrix,n)
      integer, intent(in) :: n
      real(DOUBLE), dimension(n,n), intent(out) :: matrix
      integer :: i

      matrix = 0d0
      do i = 1, n
         matrix(i,i) = 1d0
      end do
    end subroutine make_I

    ! *** SR1 Hessian update ***
    ! Symmetric rank-one update  H <- H + v v' / (v'q),  v = s - H q.
    ! With s the step and q the gradient change this is the SR1 update of an
    ! inverse-Hessian approximation (presumably how callers use it - confirm).
    ! Standard SR1 safeguard (Nocedal & Wright, Sec. 6.2): the update is
    ! skipped when |v'q| <= 1e-8 * ||q|| * ||v||, because a near-zero
    ! denominator makes the correction arbitrarily large (or Inf/NaN). This
    ! also skips the no-op case v = 0.
    pure subroutine SR1update(H,s,q,n)
      integer, intent(in) :: n
      real(DOUBLE), dimension(n,n), intent(inout) :: H
      real(DOUBLE), dimension(n), intent(in) :: q, s
      real(DOUBLE), parameter :: skip_tol = 1d-8
      real(DOUBLE), dimension(n,n) :: vv
      real(DOUBLE), dimension(n) :: v
      real(DOUBLE) :: denom

      v = s - matmul(H,q)
      denom = dot_product(v,q)
      if (abs(denom) <= skip_tol*norm(q,n)*norm(v,n)) return

      call outer(v,v,vv,n)
      H = H + vv/denom
    end subroutine SR1update

    ! Powell-damped BFGS update of the Hessian approximation H
    ! (Nocedal & Wright, Procedure 18.2); s is the step, q the gradient change:
    !   theta = 1                          if s'q >= 0.2 s'Hs
    !         = 0.8 s'Hs / (s'Hs - s'q)    otherwise
    !   r     = theta*q + (1 - theta)*H s
    !   H    <- H - (H s)(s'H)/(s'Hs) + r r'/(s'r)
    ! Nothing is done when ||s|| <= 1e-6, or when s'Hs or s'r is zero (or
    ! subnormal), which would otherwise fill H with Inf/NaN.
    ! NOTE(review): the damping only guarantees s'r > 0 when H is positive
    ! definite; that is assumed here, not checked.
    pure subroutine DampedBFGSupdate(H,s,q,n)
      integer, intent(in) :: n
      real(DOUBLE), dimension(n,n), intent(inout) :: H
      real(DOUBLE), dimension(n), intent(in) :: q, s
      real(DOUBLE), parameter :: min_step = 1d-6
      real(DOUBLE), dimension(n,n) :: rr, hssh
      real(DOUBLE), dimension(n) :: r, hs_vec, sh_vec
      real(DOUBLE) :: theta, shs, sq, sr

      ! Written as .not.(>) so a NaN step norm is skipped, not processed.
      if (.not. (norm(s,n) > min_step)) return

      ! Form H*s and s'*H once. H s s' H = (H s)(s' H), so no O(n^3)
      ! product of two n-by-n matrices is needed.
      hs_vec = matmul(H,s)
      sh_vec = matmul(s,H)
      shs = dot_product(s,hs_vec)
      if (abs(shs) <= tiny(shs)) return

      sq = dot_product(s,q)
      if (sq >= 0.2d0*shs) then
         theta = 1d0
      else
         theta = 0.8d0*shs/(shs - sq)
      end if

      r = theta*q + (1d0 - theta)*hs_vec
      sr = dot_product(s,r)
      if (abs(sr) <= tiny(sr)) return

      call outer(r,r,rr,n)
      call outer(hs_vec,sh_vec,hssh,n)
      H = H + rr/sr - hssh/shs
    end subroutine DampedBFGSupdate

    ! True when ar1(i) <= ar2(i) for every i in 1..dim (true when dim = 0).
    ! NOTE: despite the name the test is NOT strict - equal elements count
    ! as "less" - and a pair involving NaN never makes the result false.
    pure logical function all_less(ar1,ar2,dim)
      integer, intent(in) :: dim
      real(DOUBLE), dimension(dim), intent(in) :: ar1, ar2

      all_less = .not. any(ar1 > ar2)
    end function all_less

    ! Trust-region update after the step x0 -> x1 (dx = x1 - x0).
    !   rho = (f1 - f0) / (0.5*dx'H dx + g0'dx)
    ! is the actual change in f over the change predicted by the quadratic
    ! model (H, g0) at x0. The per-component radii trs are then adjusted:
    !   rho < 0.3 or rho > 3                      -> trs = 0.25*|dx|
    !   0.9 < rho < 1.5 and trs <= 1.4*|dx| (all) -> trs doubled
    !   rho < 0.85 and |dx| <= 0.8*trs (all)      -> trs halved
    !   rho < 0.95 and |dx| <= 0.5*trs (all)      -> trs halved
    ! If f barely changed or the step is tiny, rho = 1 and trs is not
    ! rescaled. Each component is finally clamped to [min_tr(i), max_tr(i)]
    ! (max_tr wins if the bounds cross). A one-line diagnostic is printed.
    subroutine updateTR(H, g0, x0, x1, f1, f0, trs, min_tr,max_tr,ndim,rho)
      integer, intent(in) :: ndim
      real(DOUBLE), intent(in) :: f1, f0
      real(DOUBLE), intent(out) :: rho
      real(DOUBLE), dimension(ndim), intent(in) :: x0, x1, g0, min_tr, max_tr
      real(DOUBLE), dimension(ndim), intent(inout) :: trs
      real(DOUBLE), dimension(ndim,ndim), intent(in) :: H
      real(DOUBLE), parameter :: df_tol = 1d-9, dx_tol = 1d-6
      real(DOUBLE), dimension(ndim) :: dx

      dx = x1 - x0

      if (abs(f1 - f0) > df_tol .and. norm(dx,ndim) > dx_tol) then
         rho = (f1 - f0)/(0.5d0*dot_product(dx,matmul(H,dx)) + dot_product(dx,g0))

         if (rho < 0.3d0 .or. rho > 3d0) then
            trs = 0.25d0*abs(dx)
         else if (rho > 0.9d0 .and. rho < 1.5d0 .and. all_less(trs,abs(dx)*1.4d0,ndim)) then
            trs = 2d0*trs
         else if (rho < 0.85d0 .and. all_less(abs(dx),trs*0.8d0,ndim)) then
            trs = trs/2d0
         else if (rho < 0.95d0 .and. all_less(abs(dx),trs*0.5d0,ndim)) then
            trs = trs/2d0
         end if
      else
         rho = 1d0
      end if

      ! Clamp every component independently. Testing only trs(1) and then
      ! replacing the whole vector would reset in-range components and leave
      ! out-of-range ones unfixed (e.g. a zero radius from a zero dx entry).
      trs = min(max(trs,min_tr),max_tr)

      print*, "rho:",rho,norm(dx,ndim),trs(1)
    end subroutine updateTR

    ! Hinv = transpose(Hvects) * diag(1/d) * Hvects, where
    ! d(i) = Hvals(i) - lambda if lambda is present, else d(i) = Hvals(i).
    ! This is the (shifted) inverse of H when the eigenvectors are stored as
    ! ROWS of Hvects (H = V' diag(Hvals) V).
    ! NOTE(review): LAPACK dsyev returns eigenvectors as COLUMNS; confirm the
    ! caller's storage convention.
    ! Stops the program if any d(i) is exactly zero.
    subroutine get_inverse(n,Hinv,Hvals,Hvects,lambda)
      integer, intent(in) :: n
      real(DOUBLE), dimension(n,n), intent(in) :: Hvects
      real(DOUBLE), dimension(n,n), intent(out) :: Hinv
      real(DOUBLE), dimension(n), intent(in) :: Hvals
      real(DOUBLE), intent(in), optional :: lambda
      real(DOUBLE), dimension(n) :: recip
      real(DOUBLE), dimension(n,n) :: scaled
      integer :: j

      ! Check the denominator actually used: with a shift, Hvals(i) == lambda
      ! is singular while Hvals(i) == 0 is not.
      if (present(lambda)) then
         recip = Hvals - lambda
      else
         recip = Hvals
      end if

      if (any(recip == 0d0)) then
         print*, "Zero eigenvalue ... can't take inverse"
         stop
      end if
      recip = 1d0/recip

      ! diag(recip) * Hvects as a row scaling, done column by column,
      ! instead of a dense diagonal-matrix product.
      do j = 1, n
         scaled(:,j) = recip*Hvects(:,j)
      end do

      Hinv = matmul(transpose(Hvects),scaled)
    end subroutine get_inverse

end module MatrixOps
