LCOV - code coverage report
Current view: top level - src - iterate_matrix.F (source / functions) Hit Total Coverage
Test: CP2K Regtests (git:1425fcd) Lines: 834 872 95.6 %
Date: 2024-05-08 07:14:22 Functions: 17 19 89.5 %

          Line data    Source code
       1             : !--------------------------------------------------------------------------------------------------!
       2             : !   CP2K: A general program to perform molecular dynamics simulations                              !
       3             : !   Copyright 2000-2024 CP2K developers group <https://cp2k.org>                                   !
       4             : !                                                                                                  !
       5             : !   SPDX-License-Identifier: GPL-2.0-or-later                                                      !
       6             : !--------------------------------------------------------------------------------------------------!
       7             : ! **************************************************************************************************
       8             : !> \brief Routines useful for iterative matrix calculations
       9             : !> \par History
      10             : !>       2010.10 created [Joost VandeVondele]
      11             : !> \author Joost VandeVondele
      12             : ! **************************************************************************************************
      13             : MODULE iterate_matrix
      14             :    USE arnoldi_api,                     ONLY: arnoldi_data_type,&
      15             :                                               arnoldi_extremal
      16             :    USE bibliography,                    ONLY: Richters2018,&
      17             :                                               cite_reference
      18             :    USE cp_log_handling,                 ONLY: cp_get_default_logger,&
      19             :                                               cp_logger_get_default_unit_nr,&
      20             :                                               cp_logger_type
      21             :    USE dbcsr_api,                       ONLY: &
      22             :         dbcsr_add, dbcsr_add_on_diag, dbcsr_copy, dbcsr_create, dbcsr_desymmetrize, &
      23             :         dbcsr_distribution_get, dbcsr_distribution_type, dbcsr_filter, dbcsr_frobenius_norm, &
      24             :         dbcsr_gershgorin_norm, dbcsr_get_diag, dbcsr_get_info, dbcsr_get_matrix_type, &
      25             :         dbcsr_get_occupation, dbcsr_multiply, dbcsr_norm, dbcsr_norm_maxabsnorm, dbcsr_p_type, &
      26             :         dbcsr_release, dbcsr_scale, dbcsr_set, dbcsr_set_diag, dbcsr_trace, dbcsr_transposed, &
      27             :         dbcsr_type, dbcsr_type_no_symmetry
      28             :    USE input_constants,                 ONLY: ls_scf_submatrix_sign_direct,&
      29             :                                               ls_scf_submatrix_sign_direct_muadj,&
      30             :                                               ls_scf_submatrix_sign_direct_muadj_lowmem,&
      31             :                                               ls_scf_submatrix_sign_ns
      32             :    USE kinds,                           ONLY: dp,&
      33             :                                               int_8
      34             :    USE machine,                         ONLY: m_flush,&
      35             :                                               m_walltime
      36             :    USE mathconstants,                   ONLY: ifac
      37             :    USE mathlib,                         ONLY: abnormal_value
      38             :    USE message_passing,                 ONLY: mp_comm_type
      39             :    USE submatrix_dissection,            ONLY: submatrix_dissection_type
      40             : #include "./base/base_uses.f90"
      41             : 
      42             :    IMPLICIT NONE
      43             : 
      44             :    PRIVATE
      45             : 
      46             :    CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'iterate_matrix'
      47             : 
      48             :    TYPE :: eigbuf
      49             :       REAL(KIND=dp), DIMENSION(:), ALLOCATABLE    :: eigvals
      50             :       REAL(KIND=dp), DIMENSION(:, :), ALLOCATABLE :: eigvecs
      51             :    END TYPE eigbuf
      52             : 
      53             :    INTERFACE purify_mcweeny
      54             :       MODULE PROCEDURE purify_mcweeny_orth, purify_mcweeny_nonorth
      55             :    END INTERFACE
      56             : 
      57             :    PUBLIC :: invert_Hotelling, matrix_sign_Newton_Schulz, matrix_sqrt_Newton_Schulz, &
      58             :              matrix_sqrt_proot, matrix_sign_proot, matrix_sign_submatrix, matrix_exponential, &
      59             :              matrix_sign_submatrix_mu_adjust, purify_mcweeny, invert_Taylor, determinant
      60             : 
      61             : CONTAINS
      62             : 
      63             : ! *****************************************************************************
      64             : !> \brief Computes the determinant of a symmetric positive definite matrix
      65             : !>        using the trace of the matrix logarithm via Mercator series:
      66             : !>         det(A) = det(S)det(I+X)det(S), where S=diag(sqrt(Aii),..,sqrt(Ann))
      67             : !>         det(I+X) = Exp(Trace(Ln(I+X)))
      68             : !>         Ln(I+X) = X - X^2/2 + X^3/3 - X^4/4 + ..
      69             : !>        The series converges only if the Frobenius norm of X is less than 1.
      70             : !>        If it is more than one we compute (recursevily) the determinant of
      71             : !>        the square root of (I+X).
      72             : !> \param matrix ...
      73             : !> \param det - determinant
      74             : !> \param threshold ...
      75             : !> \par History
      76             : !>       2015.04 created [Rustam Z Khaliullin]
      77             : !> \author Rustam Z. Khaliullin
      78             : ! **************************************************************************************************
      79         132 :    RECURSIVE SUBROUTINE determinant(matrix, det, threshold)
      80             : 
      81             :       TYPE(dbcsr_type), INTENT(INOUT)                    :: matrix
      82             :       REAL(KIND=dp), INTENT(INOUT)                       :: det
      83             :       REAL(KIND=dp), INTENT(IN)                          :: threshold
      84             : 
      85             :       CHARACTER(LEN=*), PARAMETER                        :: routineN = 'determinant'
      86             : 
      87             :       INTEGER                                            :: handle, i, max_iter_lanczos, nsize, &
      88             :                                                             order_lanczos, sign_iter, unit_nr
      89             :       INTEGER(KIND=int_8)                                :: flop1
      90             :       INTEGER, SAVE                                      :: recursion_depth = 0
      91             :       REAL(KIND=dp)                                      :: det0, eps_lanczos, frobnorm, maxnorm, &
      92             :                                                             occ_matrix, t1, t2, trace
      93         132 :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:)           :: diagonal
      94             :       TYPE(cp_logger_type), POINTER                      :: logger
      95             :       TYPE(dbcsr_type)                                   :: tmp1, tmp2, tmp3
      96             : 
      97         132 :       CALL timeset(routineN, handle)
      98             : 
      99             :       ! get a useful output_unit
     100         132 :       logger => cp_get_default_logger()
     101         132 :       IF (logger%para_env%is_source()) THEN
     102          66 :          unit_nr = cp_logger_get_default_unit_nr(logger, local=.TRUE.)
     103             :       ELSE
     104          66 :          unit_nr = -1
     105             :       END IF
     106             : 
     107             :       ! Note: tmp1 and tmp2 have the same matrix type as the
     108             :       ! initial matrix (tmp3 does not have symmetry constraints)
     109             :       ! this might lead to uninteded results with anti-symmetric
     110             :       ! matrices
     111             :       CALL dbcsr_create(tmp1, template=matrix, &
     112         132 :                         matrix_type=dbcsr_type_no_symmetry)
     113             :       CALL dbcsr_create(tmp2, template=matrix, &
     114         132 :                         matrix_type=dbcsr_type_no_symmetry)
     115             :       CALL dbcsr_create(tmp3, template=matrix, &
     116         132 :                         matrix_type=dbcsr_type_no_symmetry)
     117             : 
     118             :       ! compute the product of the diagonal elements
     119             :       BLOCK
     120             :          TYPE(mp_comm_type) :: group
     121             :          INTEGER :: group_handle
     122         132 :          CALL dbcsr_get_info(matrix, nfullrows_total=nsize, group=group_handle)
     123         132 :          CALL group%set_handle(group_handle)
     124         396 :          ALLOCATE (diagonal(nsize))
     125         132 :          CALL dbcsr_get_diag(matrix, diagonal)
     126         132 :          CALL group%sum(diagonal)
     127        2308 :          det = PRODUCT(diagonal)
     128             :       END BLOCK
     129             : 
     130             :       ! create diagonal SQRTI matrix
     131        2176 :       diagonal(:) = 1.0_dp/(SQRT(diagonal(:)))
     132             :       !ROLL CALL dbcsr_copy(tmp1,matrix)
     133         132 :       CALL dbcsr_desymmetrize(matrix, tmp1)
     134         132 :       CALL dbcsr_set(tmp1, 0.0_dp)
     135         132 :       CALL dbcsr_set_diag(tmp1, diagonal)
     136         132 :       CALL dbcsr_filter(tmp1, threshold)
     137         132 :       DEALLOCATE (diagonal)
     138             : 
     139             :       ! normalize the main diagonal, off-diagonal elements are scaled to
     140             :       ! make the norm of the matrix less than 1
     141             :       CALL dbcsr_multiply("N", "N", 1.0_dp, &
     142             :                           matrix, &
     143             :                           tmp1, &
     144             :                           0.0_dp, tmp3, &
     145         132 :                           filter_eps=threshold)
     146             :       CALL dbcsr_multiply("N", "N", 1.0_dp, &
     147             :                           tmp1, &
     148             :                           tmp3, &
     149             :                           0.0_dp, tmp2, &
     150         132 :                           filter_eps=threshold)
     151             : 
     152             :       ! subtract the main diagonal to create matrix X
     153         132 :       CALL dbcsr_add_on_diag(tmp2, -1.0_dp)
     154         132 :       frobnorm = dbcsr_frobenius_norm(tmp2)
     155         132 :       IF (unit_nr > 0) THEN
     156          66 :          IF (recursion_depth .EQ. 0) THEN
     157          41 :             WRITE (unit_nr, '()')
     158             :          ELSE
     159             :             WRITE (unit_nr, '(T6,A28,1X,I15)') &
     160          25 :                "Recursive iteration:", recursion_depth
     161             :          END IF
     162             :          WRITE (unit_nr, '(T6,A28,1X,F15.10)') &
     163          66 :             "Frobenius norm:", frobnorm
     164          66 :          CALL m_flush(unit_nr)
     165             :       END IF
     166             : 
     167         132 :       IF (frobnorm .GE. 1.0_dp) THEN
     168             : 
     169          50 :          CALL dbcsr_add_on_diag(tmp2, 1.0_dp)
     170             :          ! these controls should be provided as input
     171          50 :          order_lanczos = 3
     172          50 :          eps_lanczos = 1.0E-4_dp
     173          50 :          max_iter_lanczos = 40
     174             :          CALL matrix_sqrt_Newton_Schulz( &
     175             :             tmp3, & ! output sqrt
     176             :             tmp1, & ! output sqrti
     177             :             tmp2, & ! input original
     178             :             threshold=threshold, &
     179             :             order=order_lanczos, &
     180             :             eps_lanczos=eps_lanczos, &
     181          50 :             max_iter_lanczos=max_iter_lanczos)
     182          50 :          recursion_depth = recursion_depth + 1
     183          50 :          CALL determinant(tmp3, det0, threshold)
     184          50 :          recursion_depth = recursion_depth - 1
     185          50 :          det = det*det0*det0
     186             : 
     187             :       ELSE
     188             : 
     189             :          ! create accumulator
     190          82 :          CALL dbcsr_copy(tmp1, tmp2)
     191             :          ! re-create to make use of symmetry
     192             :          !ROLL CALL dbcsr_create(tmp3,template=matrix)
     193             : 
     194          82 :          IF (unit_nr > 0) WRITE (unit_nr, *)
     195             : 
     196             :          ! initialize the sign of the term
     197          82 :          sign_iter = -1
     198        1078 :          DO i = 1, 100
     199             : 
     200        1078 :             t1 = m_walltime()
     201             : 
     202             :             ! multiply X^i by X
     203             :             ! note that the first iteration evaluates X^2
     204             :             ! because the trace of X^1 is zero by construction
     205             :             CALL dbcsr_multiply("N", "N", 1.0_dp, tmp1, tmp2, &
     206             :                                 0.0_dp, tmp3, &
     207             :                                 filter_eps=threshold, &
     208        1078 :                                 flop=flop1)
     209        1078 :             CALL dbcsr_copy(tmp1, tmp3)
     210             : 
     211             :             ! get trace
     212        1078 :             CALL dbcsr_trace(tmp1, trace)
     213        1078 :             trace = trace*sign_iter/(1.0_dp*(i + 1))
     214        1078 :             sign_iter = -sign_iter
     215             : 
     216             :             ! update the determinant
     217        1078 :             det = det*EXP(trace)
     218             : 
     219        1078 :             occ_matrix = dbcsr_get_occupation(tmp1)
     220             :             CALL dbcsr_norm(tmp1, &
     221        1078 :                             dbcsr_norm_maxabsnorm, norm_scalar=maxnorm)
     222             : 
     223        1078 :             t2 = m_walltime()
     224             : 
     225        1078 :             IF (unit_nr > 0) THEN
     226             :                WRITE (unit_nr, '(T6,A,1X,I3,1X,F7.5,F16.10,F10.3,F11.3)') &
     227         539 :                   "Determinant iter", i, occ_matrix, &
     228         539 :                   det, t2 - t1, &
     229        1078 :                   flop1/(1.0E6_dp*MAX(0.001_dp, t2 - t1))
     230         539 :                CALL m_flush(unit_nr)
     231             :             END IF
     232             : 
     233             :             ! exit if the trace is close to zero
     234        2156 :             IF (maxnorm < threshold) EXIT
     235             : 
     236             :          END DO ! end iterations
     237             : 
     238          82 :          IF (unit_nr > 0) THEN
     239          41 :             WRITE (unit_nr, '()')
     240          41 :             CALL m_flush(unit_nr)
     241             :          END IF
     242             : 
     243             :       END IF ! decide to do sqrt or not
     244             : 
     245         132 :       IF (unit_nr > 0) THEN
     246          66 :          IF (recursion_depth .EQ. 0) THEN
     247             :             WRITE (unit_nr, '(T6,A28,1X,F15.10)') &
     248          41 :                "Final determinant:", det
     249          41 :             WRITE (unit_nr, '()')
     250             :          ELSE
     251             :             WRITE (unit_nr, '(T6,A28,1X,F15.10)') &
     252          25 :                "Recursive determinant:", det
     253             :          END IF
     254          66 :          CALL m_flush(unit_nr)
     255             :       END IF
     256             : 
     257         132 :       CALL dbcsr_release(tmp1)
     258         132 :       CALL dbcsr_release(tmp2)
     259         132 :       CALL dbcsr_release(tmp3)
     260             : 
     261         132 :       CALL timestop(handle)
     262             : 
     263         132 :    END SUBROUTINE determinant
     264             : 
     265             : ! **************************************************************************************************
     266             : !> \brief invert a symmetric positive definite diagonally dominant matrix
     267             : !> \param matrix_inverse ...
     268             : !> \param matrix ...
     269             : !> \param threshold convergence threshold nased on the max abs
     270             : !> \param use_inv_as_guess logical whether input can be used as guess for inverse
     271             : !> \param norm_convergence convergence threshold for the 2-norm, useful for approximate solutions
     272             : !> \param filter_eps filter_eps for matrix multiplications, if not passed nothing is filteres
     273             : !> \param accelerator_order ...
     274             : !> \param max_iter_lanczos ...
     275             : !> \param eps_lanczos ...
     276             : !> \param silent ...
     277             : !> \par History
     278             : !>       2010.10 created [Joost VandeVondele]
     279             : !>       2011.10 guess option added [Rustam Z Khaliullin]
     280             : !> \author Joost VandeVondele
     281             : ! **************************************************************************************************
     282          26 :    SUBROUTINE invert_Taylor(matrix_inverse, matrix, threshold, use_inv_as_guess, &
     283             :                             norm_convergence, filter_eps, accelerator_order, &
     284             :                             max_iter_lanczos, eps_lanczos, silent)
     285             : 
     286             :       TYPE(dbcsr_type), INTENT(INOUT), TARGET            :: matrix_inverse, matrix
     287             :       REAL(KIND=dp), INTENT(IN)                          :: threshold
     288             :       LOGICAL, INTENT(IN), OPTIONAL                      :: use_inv_as_guess
     289             :       REAL(KIND=dp), INTENT(IN), OPTIONAL                :: norm_convergence, filter_eps
     290             :       INTEGER, INTENT(IN), OPTIONAL                      :: accelerator_order, max_iter_lanczos
     291             :       REAL(KIND=dp), INTENT(IN), OPTIONAL                :: eps_lanczos
     292             :       LOGICAL, INTENT(IN), OPTIONAL                      :: silent
     293             : 
     294             :       CHARACTER(LEN=*), PARAMETER                        :: routineN = 'invert_Taylor'
     295             : 
     296             :       INTEGER                                            :: accelerator_type, handle, i, &
     297             :                                                             my_max_iter_lanczos, nrows, unit_nr
     298             :       INTEGER(KIND=int_8)                                :: flop2
     299             :       LOGICAL                                            :: converged, use_inv_guess
     300             :       REAL(KIND=dp)                                      :: coeff, convergence, maxnorm_matrix, &
     301             :                                                             my_eps_lanczos, occ_matrix, t1, t2
     302          26 :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:)           :: p_diagonal
     303             :       TYPE(cp_logger_type), POINTER                      :: logger
     304             :       TYPE(dbcsr_type), TARGET                           :: tmp1, tmp2, tmp3_sym
     305             : 
     306          26 :       CALL timeset(routineN, handle)
     307             : 
     308          26 :       logger => cp_get_default_logger()
     309          26 :       IF (logger%para_env%is_source()) THEN
     310          13 :          unit_nr = cp_logger_get_default_unit_nr(logger, local=.TRUE.)
     311             :       ELSE
     312          13 :          unit_nr = -1
     313             :       END IF
     314          26 :       IF (PRESENT(silent)) THEN
     315          26 :          IF (silent) unit_nr = -1
     316             :       END IF
     317             : 
     318          26 :       convergence = threshold
     319          26 :       IF (PRESENT(norm_convergence)) convergence = norm_convergence
     320             : 
     321          26 :       accelerator_type = 0
     322          26 :       IF (PRESENT(accelerator_order)) accelerator_type = accelerator_order
     323           0 :       IF (accelerator_type .GT. 1) accelerator_type = 1
     324             : 
     325          26 :       use_inv_guess = .FALSE.
     326          26 :       IF (PRESENT(use_inv_as_guess)) use_inv_guess = use_inv_as_guess
     327             : 
     328          26 :       my_max_iter_lanczos = 64
     329          26 :       my_eps_lanczos = 1.0E-3_dp
     330          26 :       IF (PRESENT(max_iter_lanczos)) my_max_iter_lanczos = max_iter_lanczos
     331          26 :       IF (PRESENT(eps_lanczos)) my_eps_lanczos = eps_lanczos
     332             : 
     333          26 :       CALL dbcsr_create(tmp1, template=matrix_inverse, matrix_type=dbcsr_type_no_symmetry)
     334          26 :       CALL dbcsr_create(tmp2, template=matrix_inverse, matrix_type=dbcsr_type_no_symmetry)
     335          26 :       CALL dbcsr_create(tmp3_sym, template=matrix_inverse)
     336             : 
     337          26 :       CALL dbcsr_get_info(matrix, nfullrows_total=nrows)
     338          78 :       ALLOCATE (p_diagonal(nrows))
     339             : 
     340             :       ! generate the initial guess
     341          26 :       IF (.NOT. use_inv_guess) THEN
     342             : 
     343          26 :          SELECT CASE (accelerator_type)
     344             :          CASE (0)
     345             :             ! use tmp1 to hold off-diagonal elements
     346          26 :             CALL dbcsr_desymmetrize(matrix, tmp1)
     347         858 :             p_diagonal(:) = 0.0_dp
     348          26 :             CALL dbcsr_set_diag(tmp1, p_diagonal)
     349             :             !CALL dbcsr_print(tmp1)
     350             :             ! invert the main diagonal
     351          26 :             CALL dbcsr_get_diag(matrix, p_diagonal)
     352         858 :             DO i = 1, nrows
     353         858 :                IF (p_diagonal(i) .NE. 0.0_dp) THEN
     354         416 :                   p_diagonal(i) = 1.0_dp/p_diagonal(i)
     355             :                END IF
     356             :             END DO
     357          26 :             CALL dbcsr_set(matrix_inverse, 0.0_dp)
     358          26 :             CALL dbcsr_add_on_diag(matrix_inverse, 1.0_dp)
     359          26 :             CALL dbcsr_set_diag(matrix_inverse, p_diagonal)
     360             :          CASE DEFAULT
     361          26 :             CPABORT("Illegal accelerator order")
     362             :          END SELECT
     363             : 
     364             :       ELSE
     365             : 
     366           0 :          CPABORT("Guess is NYI")
     367             : 
     368             :       END IF
     369             : 
     370             :       CALL dbcsr_multiply("N", "N", 1.0_dp, tmp1, matrix_inverse, &
     371          26 :                           0.0_dp, tmp2, filter_eps=filter_eps)
     372             : 
     373          26 :       IF (unit_nr > 0) WRITE (unit_nr, *)
     374             : 
     375             :       ! scale the approximate inverse to be within the convergence radius
     376          26 :       t1 = m_walltime()
     377             : 
     378             :       ! done with the initial guess, start iterations
     379          26 :       converged = .FALSE.
     380          26 :       CALL dbcsr_desymmetrize(matrix_inverse, tmp1)
     381          26 :       coeff = 1.0_dp
     382         284 :       DO i = 1, 100
     383             : 
     384             :          ! coeff = +/- 1
     385         284 :          coeff = -1.0_dp*coeff
     386             :          CALL dbcsr_multiply("N", "N", 1.0_dp, tmp1, tmp2, 0.0_dp, &
     387             :                              tmp3_sym, &
     388         284 :                              flop=flop2, filter_eps=filter_eps)
     389             :          !flop=flop2)
     390         284 :          CALL dbcsr_add(matrix_inverse, tmp3_sym, 1.0_dp, coeff)
     391         284 :          CALL dbcsr_release(tmp1)
     392         284 :          CALL dbcsr_create(tmp1, template=matrix_inverse, matrix_type=dbcsr_type_no_symmetry)
     393         284 :          CALL dbcsr_desymmetrize(tmp3_sym, tmp1)
     394             : 
     395             :          ! for the convergence check
     396             :          CALL dbcsr_norm(tmp3_sym, &
     397         284 :                          dbcsr_norm_maxabsnorm, norm_scalar=maxnorm_matrix)
     398             : 
     399         284 :          t2 = m_walltime()
     400         284 :          occ_matrix = dbcsr_get_occupation(matrix_inverse)
     401             : 
     402         284 :          IF (unit_nr > 0) THEN
     403         142 :             WRITE (unit_nr, '(T6,A,1X,I3,1X,F10.8,E12.3,F12.3,F13.3)') "Taylor iter", i, occ_matrix, &
     404         142 :                maxnorm_matrix, t2 - t1, &
     405         284 :                flop2/(1.0E6_dp*MAX(0.001_dp, t2 - t1))
     406         142 :             CALL m_flush(unit_nr)
     407             :          END IF
     408             : 
     409         284 :          IF (maxnorm_matrix < convergence) THEN
     410             :             converged = .TRUE.
     411             :             EXIT
     412             :          END IF
     413             : 
     414         258 :          t1 = m_walltime()
     415             : 
     416             :       END DO
     417             : 
     418             :       !last convergence check
     419             :       CALL dbcsr_multiply("N", "N", 1.0_dp, matrix, matrix_inverse, 0.0_dp, tmp1, &
     420          26 :                           filter_eps=filter_eps)
     421          26 :       CALL dbcsr_add_on_diag(tmp1, -1.0_dp)
     422             :       !frob_matrix =  dbcsr_frobenius_norm(tmp1)
     423          26 :       CALL dbcsr_norm(tmp1, dbcsr_norm_maxabsnorm, norm_scalar=maxnorm_matrix)
     424          26 :       IF (unit_nr > 0) THEN
     425          13 :          WRITE (unit_nr, '(T6,A,E12.5)') "Final Taylor error", maxnorm_matrix
     426          13 :          WRITE (unit_nr, '()')
     427          13 :          CALL m_flush(unit_nr)
     428             :       END IF
     429          26 :       IF (maxnorm_matrix > convergence) THEN
     430           0 :          converged = .FALSE.
     431           0 :          IF (unit_nr > 0) THEN
     432           0 :             WRITE (unit_nr, *) 'Final convergence check failed'
     433             :          END IF
     434             :       END IF
     435             : 
     436          26 :       IF (.NOT. converged) THEN
     437           0 :          CPABORT("Taylor inversion did not converge")
     438             :       END IF
     439             : 
     440          26 :       CALL dbcsr_release(tmp1)
     441          26 :       CALL dbcsr_release(tmp2)
     442          26 :       CALL dbcsr_release(tmp3_sym)
     443             : 
     444          26 :       DEALLOCATE (p_diagonal)
     445             : 
     446          26 :       CALL timestop(handle)
     447             : 
     448          52 :    END SUBROUTINE invert_Taylor
     449             : 
     450             : ! **************************************************************************************************
     451             : !> \brief invert a symmetric positive definite matrix by Hotelling's method
     452             : !>        explicit symmetrization makes this code not suitable for other matrix types
     453             : !>        Currently a bit messy with the options, to to be cleaned soon
     454             : !> \param matrix_inverse ...
     455             : !> \param matrix ...
     456             : !> \param threshold convergence threshold nased on the max abs
     457             : !> \param use_inv_as_guess logical whether input can be used as guess for inverse
     458             : !> \param norm_convergence convergence threshold for the 2-norm, useful for approximate solutions
     459             : !> \param filter_eps filter_eps for matrix multiplications, if not passed nothing is filteres
     460             : !> \param accelerator_order ...
     461             : !> \param max_iter_lanczos ...
     462             : !> \param eps_lanczos ...
     463             : !> \param silent ...
     464             : !> \par History
     465             : !>       2010.10 created [Joost VandeVondele]
     466             : !>       2011.10 guess option added [Rustam Z Khaliullin]
     467             : !> \author Joost VandeVondele
     468             : ! **************************************************************************************************
     469        2032 :    SUBROUTINE invert_Hotelling(matrix_inverse, matrix, threshold, use_inv_as_guess, &
     470             :                                norm_convergence, filter_eps, accelerator_order, &
     471             :                                max_iter_lanczos, eps_lanczos, silent)
     472             : 
     473             :       TYPE(dbcsr_type), INTENT(INOUT), TARGET            :: matrix_inverse, matrix
     474             :       REAL(KIND=dp), INTENT(IN)                          :: threshold
     475             :       LOGICAL, INTENT(IN), OPTIONAL                      :: use_inv_as_guess
     476             :       REAL(KIND=dp), INTENT(IN), OPTIONAL                :: norm_convergence, filter_eps
     477             :       INTEGER, INTENT(IN), OPTIONAL                      :: accelerator_order, max_iter_lanczos
     478             :       REAL(KIND=dp), INTENT(IN), OPTIONAL                :: eps_lanczos
     479             :       LOGICAL, INTENT(IN), OPTIONAL                      :: silent
     480             : 
     481             :       CHARACTER(LEN=*), PARAMETER                        :: routineN = 'invert_Hotelling'
     482             : 
     483             :       INTEGER                                            :: accelerator_type, handle, i, &
     484             :                                                             my_max_iter_lanczos, unit_nr
     485             :       INTEGER(KIND=int_8)                                :: flop1, flop2
     486             :       LOGICAL                                            :: arnoldi_converged, converged, &
     487             :                                                             use_inv_guess
     488             :       REAL(KIND=dp) :: convergence, frob_matrix, gershgorin_norm, max_ev, maxnorm_matrix, min_ev, &
     489             :          my_eps_lanczos, my_filter_eps, occ_matrix, scalingf, t1, t2
     490             :       TYPE(cp_logger_type), POINTER                      :: logger
     491             :       TYPE(dbcsr_type), TARGET                           :: tmp1, tmp2
     492             : 
     493             :       !TYPE(arnoldi_data_type)                            :: my_arnoldi
     494             :       !TYPE(dbcsr_p_type), DIMENSION(1)                   :: mymat
     495             : 
     496        2032 :       CALL timeset(routineN, handle)
     497             : 
     498        2032 :       logger => cp_get_default_logger()
     499        2032 :       IF (logger%para_env%is_source()) THEN
     500        1016 :          unit_nr = cp_logger_get_default_unit_nr(logger, local=.TRUE.)
     501             :       ELSE
     502        1016 :          unit_nr = -1
     503             :       END IF
     504        2032 :       IF (PRESENT(silent)) THEN
     505        2014 :          IF (silent) unit_nr = -1
     506             :       END IF
     507             : 
     508        2032 :       convergence = threshold
     509        2032 :       IF (PRESENT(norm_convergence)) convergence = norm_convergence
     510             : 
     511        2032 :       accelerator_type = 1
     512        2032 :       IF (PRESENT(accelerator_order)) accelerator_type = accelerator_order
     513        1436 :       IF (accelerator_type .GT. 1) accelerator_type = 1
     514             : 
     515        2032 :       use_inv_guess = .FALSE.
     516        2032 :       IF (PRESENT(use_inv_as_guess)) use_inv_guess = use_inv_as_guess
     517             : 
     518        2032 :       my_max_iter_lanczos = 64
     519        2032 :       my_eps_lanczos = 1.0E-3_dp
     520        2032 :       IF (PRESENT(max_iter_lanczos)) my_max_iter_lanczos = max_iter_lanczos
     521        2032 :       IF (PRESENT(eps_lanczos)) my_eps_lanczos = eps_lanczos
     522             : 
     523        2032 :       my_filter_eps = threshold
     524        2032 :       IF (PRESENT(filter_eps)) my_filter_eps = filter_eps
     525             : 
     526             :       ! generate the initial guess
     527        2032 :       IF (.NOT. use_inv_guess) THEN
     528             : 
     529           0 :          SELECT CASE (accelerator_type)
     530             :          CASE (0)
     531           0 :             gershgorin_norm = dbcsr_gershgorin_norm(matrix)
     532           0 :             frob_matrix = dbcsr_frobenius_norm(matrix)
     533           0 :             CALL dbcsr_set(matrix_inverse, 0.0_dp)
     534           0 :             CALL dbcsr_add_on_diag(matrix_inverse, 1/MIN(gershgorin_norm, frob_matrix))
     535             :          CASE (1)
     536             :             ! initialize matrix to unity and use arnoldi (below) to scale it into the convergence range
     537        1558 :             CALL dbcsr_set(matrix_inverse, 0.0_dp)
     538        1558 :             CALL dbcsr_add_on_diag(matrix_inverse, 1.0_dp)
     539             :          CASE DEFAULT
     540        1558 :             CPABORT("Illegal accelerator order")
     541             :          END SELECT
     542             : 
     543             :          ! everything commutes, therefore our all products will be symmetric
     544        1558 :          CALL dbcsr_create(tmp1, template=matrix_inverse)
     545             : 
     546             :       ELSE
     547             : 
     548             :          ! It is unlikely that our guess will commute with the matrix, therefore the first product will
     549             :          ! be non symmetric
     550         474 :          CALL dbcsr_create(tmp1, template=matrix_inverse, matrix_type=dbcsr_type_no_symmetry)
     551             : 
     552             :       END IF
     553             : 
     554        2032 :       CALL dbcsr_create(tmp2, template=matrix_inverse)
     555             : 
     556        2032 :       IF (unit_nr > 0) WRITE (unit_nr, *)
     557             : 
     558             :       ! scale the approximate inverse to be within the convergence radius
     559        2032 :       t1 = m_walltime()
     560             : 
     561             :       CALL dbcsr_multiply("N", "N", 1.0_dp, matrix_inverse, matrix, &
     562        2032 :                           0.0_dp, tmp1, flop=flop1, filter_eps=my_filter_eps)
     563             : 
     564        2032 :       IF (accelerator_type == 1) THEN
     565             : 
     566             :          ! scale the matrix to get into the convergence range
     567             :          CALL arnoldi_extremal(tmp1, max_eV, min_eV, threshold=my_eps_lanczos, &
     568        2032 :                                max_iter=my_max_iter_lanczos, converged=arnoldi_converged)
     569             :          !mymat(1)%matrix => tmp1
     570             :          !CALL setup_arnoldi_data(my_arnoldi, mymat, max_iter=30, threshold=1.0E-3_dp, selection_crit=1, &
     571             :          !                        nval_request=2, nrestarts=2, generalized_ev=.FALSE., iram=.TRUE.)
     572             :          !CALL arnoldi_ev(mymat, my_arnoldi)
     573             :          !max_eV = REAL(get_selected_ritz_val(my_arnoldi, 2), dp)
     574             :          !min_eV = REAL(get_selected_ritz_val(my_arnoldi, 1), dp)
     575             :          !CALL deallocate_arnoldi_data(my_arnoldi)
     576             : 
     577        2032 :          IF (unit_nr > 0) THEN
     578         768 :             WRITE (unit_nr, *)
     579         768 :             WRITE (unit_nr, '(T6,A,1X,L1,A,E12.3)') "Lanczos converged: ", arnoldi_converged, " threshold:", my_eps_lanczos
     580         768 :             WRITE (unit_nr, '(T6,A,1X,E12.3,E12.3)') "Est. extremal eigenvalues:", max_eV, min_eV
     581         768 :             WRITE (unit_nr, '(T6,A,1X,E12.3)') "Est. condition number :", max_eV/MAX(min_eV, EPSILON(min_eV))
     582             :          END IF
     583             : 
     584             :          ! 2.0 would be the correct scaling however, we should make sure here, that we are in the convergence radius
     585        2032 :          scalingf = 1.9_dp/(max_eV + min_eV)
     586        2032 :          CALL dbcsr_scale(tmp1, scalingf)
     587        2032 :          CALL dbcsr_scale(matrix_inverse, scalingf)
     588        2032 :          min_ev = min_ev*scalingf
     589             : 
     590             :       END IF
     591             : 
     592             :       ! done with the initial guess, start iterations
     593        2032 :       converged = .FALSE.
     594        9000 :       DO i = 1, 100
     595             : 
     596             :          ! tmp1 = S^-1 S
     597             : 
     598             :          ! for the convergence check
     599        9000 :          CALL dbcsr_add_on_diag(tmp1, -1.0_dp)
     600             :          CALL dbcsr_norm(tmp1, &
     601        9000 :                          dbcsr_norm_maxabsnorm, norm_scalar=maxnorm_matrix)
     602        9000 :          CALL dbcsr_add_on_diag(tmp1, +1.0_dp)
     603             : 
     604             :          ! tmp2 = S^-1 S S^-1
     605             :          CALL dbcsr_multiply("N", "N", 1.0_dp, tmp1, matrix_inverse, 0.0_dp, tmp2, &
     606        9000 :                              flop=flop2, filter_eps=my_filter_eps)
     607             :          ! S^-1_{n+1} = 2 S^-1 - S^-1 S S^-1
     608        9000 :          CALL dbcsr_add(matrix_inverse, tmp2, 2.0_dp, -1.0_dp)
     609             : 
     610        9000 :          CALL dbcsr_filter(matrix_inverse, my_filter_eps)
     611        9000 :          t2 = m_walltime()
     612        9000 :          occ_matrix = dbcsr_get_occupation(matrix_inverse)
     613             : 
     614             :          ! use the scalar form of the algorithm to trace the EV
     615        9000 :          IF (accelerator_type == 1) THEN
     616        9000 :             min_ev = min_ev*(2.0_dp - min_ev)
     617        9000 :             IF (PRESENT(norm_convergence)) maxnorm_matrix = ABS(min_eV - 1.0_dp)
     618             :          END IF
     619             : 
     620        9000 :          IF (unit_nr > 0) THEN
     621        3718 :             WRITE (unit_nr, '(T6,A,1X,I3,1X,F10.8,E12.3,F12.3,F13.3)') "Hotelling iter", i, occ_matrix, &
     622        3718 :                maxnorm_matrix, t2 - t1, &
     623        7436 :                (flop1 + flop2)/(1.0E6_dp*MAX(0.001_dp, t2 - t1))
     624        3718 :             CALL m_flush(unit_nr)
     625             :          END IF
     626             : 
     627        9000 :          IF (maxnorm_matrix < convergence) THEN
     628             :             converged = .TRUE.
     629             :             EXIT
     630             :          END IF
     631             : 
     632             :          ! scale the matrix for improved convergence
     633        6968 :          IF (accelerator_type == 1) THEN
     634        6968 :             min_ev = min_ev*2.0_dp/(min_ev + 1.0_dp)
     635        6968 :             CALL dbcsr_scale(matrix_inverse, 2.0_dp/(min_ev + 1.0_dp))
     636             :          END IF
     637             : 
     638        6968 :          t1 = m_walltime()
     639             :          CALL dbcsr_multiply("N", "N", 1.0_dp, matrix_inverse, matrix, &
     640        6968 :                              0.0_dp, tmp1, flop=flop1, filter_eps=my_filter_eps)
     641             : 
     642             :       END DO
     643             : 
     644        2032 :       IF (.NOT. converged) THEN
     645           0 :          CPABORT("Hotelling inversion did not converge")
     646             :       END IF
     647             : 
     648             :       ! try to symmetrize the output matrix
     649        2032 :       IF (dbcsr_get_matrix_type(matrix_inverse) == dbcsr_type_no_symmetry) THEN
     650         100 :          CALL dbcsr_transposed(tmp2, matrix_inverse)
     651        2132 :          CALL dbcsr_add(matrix_inverse, tmp2, 0.5_dp, 0.5_dp)
     652             :       END IF
     653             : 
     654        2032 :       IF (unit_nr > 0) THEN
     655             : !           WRITE(unit_nr,'(T6,A,1X,I3,1X,F10.8,E12.3)') "Final Hotelling ",i,occ_matrix,&
     656             : !              !frob_matrix/frob_matrix_base
     657             : !              maxnorm_matrix
     658         768 :          WRITE (unit_nr, '()')
     659         768 :          CALL m_flush(unit_nr)
     660             :       END IF
     661             : 
     662        2032 :       CALL dbcsr_release(tmp1)
     663        2032 :       CALL dbcsr_release(tmp2)
     664             : 
     665        2032 :       CALL timestop(handle)
     666             : 
     667        2032 :    END SUBROUTINE invert_Hotelling
     668             : 
     669             : ! **************************************************************************************************
     670             : !> \brief compute the sign a matrix using Newton-Schulz iterations
     671             : !> \param matrix_sign ...
     672             : !> \param matrix ...
     673             : !> \param threshold ...
     674             : !> \param sign_order ...
     675             : !> \par History
     676             : !>       2010.10 created [Joost VandeVondele]
     677             : !>       2019.05 extended to order byxond 2 [Robert Schade]
     678             : !> \author Joost VandeVondele, Robert Schade
     679             : ! **************************************************************************************************
     680        1058 :    SUBROUTINE matrix_sign_Newton_Schulz(matrix_sign, matrix, threshold, sign_order)
     681             : 
     682             :       TYPE(dbcsr_type), INTENT(INOUT)                    :: matrix_sign, matrix
     683             :       REAL(KIND=dp), INTENT(IN)                          :: threshold
     684             :       INTEGER, INTENT(IN), OPTIONAL                      :: sign_order
     685             : 
     686             :       CHARACTER(LEN=*), PARAMETER :: routineN = 'matrix_sign_Newton_Schulz'
     687             : 
     688             :       INTEGER                                            :: count, handle, i, order, unit_nr
     689             :       INTEGER(KIND=int_8)                                :: flops
     690             :       REAL(KIND=dp)                                      :: a0, a1, a2, a3, a4, a5, floptot, &
     691             :                                                             frob_matrix, frob_matrix_base, &
     692             :                                                             gersh_matrix, occ_matrix, prefactor, &
     693             :                                                             t1, t2
     694             :       TYPE(cp_logger_type), POINTER                      :: logger
     695             :       TYPE(dbcsr_type)                                   :: tmp1, tmp2, tmp3, tmp4
     696             : 
     697        1058 :       CALL timeset(routineN, handle)
     698             : 
     699        1058 :       logger => cp_get_default_logger()
     700        1058 :       IF (logger%para_env%is_source()) THEN
     701         529 :          unit_nr = cp_logger_get_default_unit_nr(logger, local=.TRUE.)
     702             :       ELSE
     703         529 :          unit_nr = -1
     704             :       END IF
     705             : 
     706        1058 :       IF (PRESENT(sign_order)) THEN
     707        1058 :          order = sign_order
     708             :       ELSE
     709             :          order = 2
     710             :       END IF
     711             : 
     712        1058 :       CALL dbcsr_create(tmp1, template=matrix_sign)
     713             : 
     714        1058 :       CALL dbcsr_create(tmp2, template=matrix_sign)
     715        1058 :       IF (ABS(order) .GE. 4) THEN
     716           8 :          CALL dbcsr_create(tmp3, template=matrix_sign)
     717             :       END IF
     718           8 :       IF (ABS(order) .GT. 4) THEN
     719           6 :          CALL dbcsr_create(tmp4, template=matrix_sign)
     720             :       END IF
     721             : 
     722        1058 :       CALL dbcsr_copy(matrix_sign, matrix)
     723        1058 :       CALL dbcsr_filter(matrix_sign, threshold)
     724             : 
     725             :       ! scale the matrix to get into the convergence range
     726        1058 :       frob_matrix = dbcsr_frobenius_norm(matrix_sign)
     727        1058 :       gersh_matrix = dbcsr_gershgorin_norm(matrix_sign)
     728        1058 :       CALL dbcsr_scale(matrix_sign, 1/MIN(frob_matrix, gersh_matrix))
     729             : 
     730        1058 :       IF (unit_nr > 0) WRITE (unit_nr, *)
     731             : 
     732        1058 :       count = 0
     733       13202 :       DO i = 1, 100
     734       13202 :          floptot = 0_dp
     735       13202 :          t1 = m_walltime()
     736             :          ! tmp1 = X * X
     737             :          CALL dbcsr_multiply("N", "N", -1.0_dp, matrix_sign, matrix_sign, 0.0_dp, tmp1, &
     738       13202 :                              filter_eps=threshold, flop=flops)
     739       13202 :          floptot = floptot + flops
     740             : 
     741             :          ! check convergence (frob norm of what should be the identity matrix minus identity matrix)
     742       13202 :          frob_matrix_base = dbcsr_frobenius_norm(tmp1)
     743       13202 :          CALL dbcsr_add_on_diag(tmp1, +1.0_dp)
     744       13202 :          frob_matrix = dbcsr_frobenius_norm(tmp1)
     745             : 
     746             :          ! f(y) approx 1/sqrt(1-y)
     747             :          ! f(y)=1+y/2+3/8*y^2+5/16*y^3+35/128*y^4+63/256*y^5+231/1024*y^6
     748             :          ! f2(y)=1+y/2=1/2*(2+y)
     749             :          ! f3(y)=1+y/2+3/8*y^2=3/8*(8/3+4/3*y+y^2)
     750             :          ! f4(y)=1+y/2+3/8*y^2+5/16*y^3=5/16*(16/5+8/5*y+6/5*y^2+y^3)
     751             :          ! f5(y)=1+y/2+3/8*y^2+5/16*y^3+35/128*y^4=35/128*(128/35+128/70*y+48/35*y^2+8/7*y^3+y^4)
     752             :          !      z(y)=(y+a_0)*y+a_1
     753             :          ! f5(y)=35/128*((z(y)+y+a_2)*z(y)+a_3)
     754             :          !      =35/128*((a_1^2+a_1a_2+a_3)+(2*a_0a_1+a_1+a_0a_2)y+(a_0^2+a_0+2a_1+a_2)y^2+(2a_0+1)y^3+y^4)
     755             :          !    a_0=1/14
     756             :          !    a_1=23819/13720
     757             :          !    a_2=1269/980-2a_1=-3734/1715
     758             :          !    a_3=832591127/188238400
     759             :          ! f6(y)=1+y/2+3/8*y^2+5/16*y^3+35/128*y^4+63/256*y^5
     760             :          !      =63/256*(256/63 + (128 y)/63 + (32 y^2)/21 + (80 y^3)/63 + (10 y^4)/9 + y^5)
     761             :          ! f7(y)=1+y/2+3/8*y^2+5/16*y^3+35/128*y^4+63/256*y^5+231/1024*y^6
     762             :          !      =231/1024*(1024/231+512/231*y+128/77*y^2+320/231*y^3+40/33*y^4+12/11*y^5+y^6)
     763             :          ! z(y)=(y+a_0)*y+a_1
     764             :          ! w(y)=(y+a_2)*z(y)+a_3
     765             :          ! f7(y)=(w(y)+z(y)+a_4)*w(y)+a_5
     766             :          ! a_0= 1.3686502058092053653287666647611728507211996691324048468010382350359929055186612505791532871573242422
     767             :          ! a_1= 1.7089671854477436685850554669524985556296280184497503489303331821456795715195510972774979091893741568
     768             :          ! a_2=-1.3231956603546599107833121193066273961757451236778593922555836895814474509732067051246078326118696968
     769             :          ! a_3= 3.9876642330847931291749479958277754186675336169578593000744380254770411483327581042259415937710270453
     770             :          ! a_4=-3.7273299006476825027065704937541279833880400042556351139273912137942678919776364526511485025132991667
     771             :          ! a_5= 4.9369932474103023792021351907971943220607580694533770325967170245194362399287150565595441897740173578
     772             :          !
     773             :          ! y=1-X*X
     774             : 
     775             :          ! tmp1 = I-x*x
     776             :          IF (order .EQ. 2) THEN
     777       13156 :             prefactor = 0.5_dp
     778             : 
     779             :             ! update the above to 3*I-X*X
     780       13156 :             CALL dbcsr_add_on_diag(tmp1, +2.0_dp)
     781       13156 :             occ_matrix = dbcsr_get_occupation(matrix_sign)
     782             :          ELSE IF (order .EQ. 3) THEN
     783             :             ! with one multiplication
     784             :             ! tmp1=y
     785          12 :             CALL dbcsr_copy(tmp2, tmp1)
     786          12 :             CALL dbcsr_scale(tmp1, 4.0_dp/3.0_dp)
     787          12 :             CALL dbcsr_add_on_diag(tmp1, 8.0_dp/3.0_dp)
     788             : 
     789             :             ! tmp2=y^2
     790             :             CALL dbcsr_multiply("N", "N", 1.0_dp, tmp2, tmp2, 1.0_dp, tmp1, &
     791          12 :                                 filter_eps=threshold, flop=flops)
     792          12 :             floptot = floptot + flops
     793          12 :             prefactor = 3.0_dp/8.0_dp
     794             : 
     795             :          ELSE IF (order .EQ. 4) THEN
     796             :             ! with two multiplications
     797             :             ! tmp1=y
     798          10 :             CALL dbcsr_copy(tmp3, tmp1)
     799          10 :             CALL dbcsr_scale(tmp1, 8.0_dp/5.0_dp)
     800          10 :             CALL dbcsr_add_on_diag(tmp1, 16.0_dp/5.0_dp)
     801             : 
     802             :             !
     803             :             CALL dbcsr_multiply("N", "N", 1.0_dp, tmp3, tmp3, 0.0_dp, tmp2, &
     804          10 :                                 filter_eps=threshold, flop=flops)
     805          10 :             floptot = floptot + flops
     806             : 
     807          10 :             CALL dbcsr_add(tmp1, tmp2, 1.0_dp, 6.0_dp/5.0_dp)
     808             : 
     809             :             CALL dbcsr_multiply("N", "N", 1.0_dp, tmp2, tmp3, 1.0_dp, tmp1, &
     810          10 :                                 filter_eps=threshold, flop=flops)
     811          10 :             floptot = floptot + flops
     812             : 
     813          10 :             prefactor = 5.0_dp/16.0_dp
     814             :          ELSE IF (order .EQ. -5) THEN
     815             :             ! with three multiplications
     816             :             ! tmp1=y
     817           0 :             CALL dbcsr_copy(tmp3, tmp1)
     818           0 :             CALL dbcsr_scale(tmp1, 128.0_dp/70.0_dp)
     819           0 :             CALL dbcsr_add_on_diag(tmp1, 128.0_dp/35.0_dp)
     820             : 
     821             :             !
     822             :             CALL dbcsr_multiply("N", "N", 1.0_dp, tmp3, tmp3, 0.0_dp, tmp2, &
     823           0 :                                 filter_eps=threshold, flop=flops)
     824           0 :             floptot = floptot + flops
     825             : 
     826           0 :             CALL dbcsr_add(tmp1, tmp2, 1.0_dp, 48.0_dp/35.0_dp)
     827             : 
     828             :             CALL dbcsr_multiply("N", "N", 1.0_dp, tmp2, tmp3, 0.0_dp, tmp4, &
     829           0 :                                 filter_eps=threshold, flop=flops)
     830           0 :             floptot = floptot + flops
     831             : 
     832           0 :             CALL dbcsr_add(tmp1, tmp4, 1.0_dp, 8.0_dp/7.0_dp)
     833             : 
     834             :             CALL dbcsr_multiply("N", "N", 1.0_dp, tmp4, tmp3, 1.0_dp, tmp1, &
     835           0 :                                 filter_eps=threshold, flop=flops)
     836           0 :             floptot = floptot + flops
     837             : 
     838           0 :             prefactor = 35.0_dp/128.0_dp
     839             :          ELSE IF (order .EQ. 5) THEN
     840             :             ! with two multiplications
     841             :             !      z(y)=(y+a_0)*y+a_1
     842             :             ! f5(y)=35/128*((z(y)+y+a_2)*z(y)+a_3)
     843             :             !      =35/128*((a_1^2+a_1a_2+a_3)+(2*a_0a_1+a_1+a_0a_2)y+(a_0^2+a_0+2a_1+a_2)y^2+(2a_0+1)y^3+y^4)
     844             :             !    a_0=1/14
     845             :             !    a_1=23819/13720
     846             :             !    a_2=1269/980-2a_1=-3734/1715
     847             :             !    a_3=832591127/188238400
     848           8 :             a0 = 1.0_dp/14.0_dp
     849           8 :             a1 = 23819.0_dp/13720.0_dp
     850           8 :             a2 = -3734_dp/1715.0_dp
     851           8 :             a3 = 832591127_dp/188238400.0_dp
     852             : 
     853             :             ! tmp1=y
     854             :             ! tmp3=z
     855           8 :             CALL dbcsr_copy(tmp3, tmp1)
     856           8 :             CALL dbcsr_add_on_diag(tmp3, a0)
     857             :             CALL dbcsr_multiply("N", "N", 1.0_dp, tmp3, tmp1, 0.0_dp, tmp2, &
     858           8 :                                 filter_eps=threshold, flop=flops)
     859           8 :             floptot = floptot + flops
     860           8 :             CALL dbcsr_add_on_diag(tmp2, a1)
     861             : 
     862           8 :             CALL dbcsr_add_on_diag(tmp1, a2)
     863           8 :             CALL dbcsr_add(tmp1, tmp2, 1.0_dp, 1.0_dp)
     864             :             CALL dbcsr_multiply("N", "N", 1.0_dp, tmp1, tmp2, 0.0_dp, tmp3, &
     865           8 :                                 filter_eps=threshold, flop=flops)
     866           8 :             floptot = floptot + flops
     867           8 :             CALL dbcsr_add_on_diag(tmp3, a3)
     868           8 :             CALL dbcsr_copy(tmp1, tmp3)
     869             : 
     870           8 :             prefactor = 35.0_dp/128.0_dp
     871             :          ELSE IF (order .EQ. 6) THEN
     872             :             ! with four multiplications
     873             :             ! f6(y)=63/256*(256/63 + (128 y)/63 + (32 y^2)/21 + (80 y^3)/63 + (10 y^4)/9 + y^5)
     874             :             ! tmp1=y
     875           8 :             CALL dbcsr_copy(tmp3, tmp1)
     876           8 :             CALL dbcsr_scale(tmp1, 128.0_dp/63.0_dp)
     877           8 :             CALL dbcsr_add_on_diag(tmp1, 256.0_dp/63.0_dp)
     878             : 
     879             :             !
     880             :             CALL dbcsr_multiply("N", "N", 1.0_dp, tmp3, tmp3, 0.0_dp, tmp2, &
     881           8 :                                 filter_eps=threshold, flop=flops)
     882           8 :             floptot = floptot + flops
     883             : 
     884           8 :             CALL dbcsr_add(tmp1, tmp2, 1.0_dp, 32.0_dp/21.0_dp)
     885             : 
     886             :             CALL dbcsr_multiply("N", "N", 1.0_dp, tmp2, tmp3, 0.0_dp, tmp4, &
     887           8 :                                 filter_eps=threshold, flop=flops)
     888           8 :             floptot = floptot + flops
     889             : 
     890           8 :             CALL dbcsr_add(tmp1, tmp4, 1.0_dp, 80.0_dp/63.0_dp)
     891             : 
     892             :             CALL dbcsr_multiply("N", "N", 1.0_dp, tmp4, tmp3, 0.0_dp, tmp2, &
     893           8 :                                 filter_eps=threshold, flop=flops)
     894           8 :             floptot = floptot + flops
     895             : 
     896           8 :             CALL dbcsr_add(tmp1, tmp2, 1.0_dp, 10.0_dp/9.0_dp)
     897             : 
     898             :             CALL dbcsr_multiply("N", "N", 1.0_dp, tmp2, tmp3, 1.0_dp, tmp1, &
     899           8 :                                 filter_eps=threshold, flop=flops)
     900           8 :             floptot = floptot + flops
     901             : 
     902           8 :             prefactor = 63.0_dp/256.0_dp
     903             :          ELSE IF (order .EQ. 7) THEN
     904             :             ! with three multiplications
     905             : 
     906           8 :             a0 = 1.3686502058092053653287666647611728507211996691324048468010382350359929055186612505791532871573242422_dp
     907           8 :             a1 = 1.7089671854477436685850554669524985556296280184497503489303331821456795715195510972774979091893741568_dp
     908           8 :             a2 = -1.3231956603546599107833121193066273961757451236778593922555836895814474509732067051246078326118696968_dp
     909           8 :             a3 = 3.9876642330847931291749479958277754186675336169578593000744380254770411483327581042259415937710270453_dp
     910           8 :             a4 = -3.7273299006476825027065704937541279833880400042556351139273912137942678919776364526511485025132991667_dp
     911           8 :             a5 = 4.9369932474103023792021351907971943220607580694533770325967170245194362399287150565595441897740173578_dp
     912             :             !      =231/1024*(1024/231+512/231*y+128/77*y^2+320/231*y^3+40/33*y^4+12/11*y^5+y^6)
     913             :             ! z(y)=(y+a_0)*y+a_1
     914             :             ! w(y)=(y+a_2)*z(y)+a_3
     915             :             ! f7(y)=(w(y)+z(y)+a_4)*w(y)+a_5
     916             : 
     917             :             ! tmp1=y
     918             :             ! tmp3=z
     919           8 :             CALL dbcsr_copy(tmp3, tmp1)
     920           8 :             CALL dbcsr_add_on_diag(tmp3, a0)
     921             :             CALL dbcsr_multiply("N", "N", 1.0_dp, tmp3, tmp1, 0.0_dp, tmp2, &
     922           8 :                                 filter_eps=threshold, flop=flops)
     923           8 :             floptot = floptot + flops
     924           8 :             CALL dbcsr_add_on_diag(tmp2, a1)
     925             : 
     926             :             ! tmp4=w
     927           8 :             CALL dbcsr_copy(tmp4, tmp1)
     928           8 :             CALL dbcsr_add_on_diag(tmp4, a2)
     929             :             CALL dbcsr_multiply("N", "N", 1.0_dp, tmp4, tmp2, 0.0_dp, tmp3, &
     930           8 :                                 filter_eps=threshold, flop=flops)
     931           8 :             floptot = floptot + flops
     932           8 :             CALL dbcsr_add_on_diag(tmp3, a3)
     933             : 
     934           8 :             CALL dbcsr_add(tmp2, tmp3, 1.0_dp, 1.0_dp)
     935           8 :             CALL dbcsr_add_on_diag(tmp2, a4)
     936             :             CALL dbcsr_multiply("N", "N", 1.0_dp, tmp2, tmp3, 0.0_dp, tmp1, &
     937           8 :                                 filter_eps=threshold, flop=flops)
     938           8 :             floptot = floptot + flops
     939           8 :             CALL dbcsr_add_on_diag(tmp1, a5)
     940             : 
     941           8 :             prefactor = 231.0_dp/1024.0_dp
     942             :          ELSE
     943           0 :             CPABORT("requested order is not implemented.")
     944             :          END IF
     945             : 
     946             :          ! tmp2 = X * prefactor *
     947             :          CALL dbcsr_multiply("N", "N", prefactor, matrix_sign, tmp1, 0.0_dp, tmp2, &
     948       13202 :                              filter_eps=threshold, flop=flops)
     949       13202 :          floptot = floptot + flops
     950             : 
     951             :          ! done iterating
     952             :          ! CALL dbcsr_filter(tmp2,threshold)
     953       13202 :          CALL dbcsr_copy(matrix_sign, tmp2)
     954       13202 :          t2 = m_walltime()
     955             : 
     956       13202 :          occ_matrix = dbcsr_get_occupation(matrix_sign)
     957             : 
     958       13202 :          IF (unit_nr > 0) THEN
     959        6601 :             WRITE (unit_nr, '(T6,A,1X,I3,1X,F10.8,E12.3,F12.3,F13.3)') "NS sign iter ", i, occ_matrix, &
     960        6601 :                frob_matrix/frob_matrix_base, t2 - t1, &
     961       13202 :                floptot/(1.0E6_dp*MAX(0.001_dp, t2 - t1))
     962        6601 :             CALL m_flush(unit_nr)
     963             :          END IF
     964             : 
     965             :          ! frob_matrix/frob_matrix_base < SQRT(threshold)
     966       13202 :          IF (frob_matrix*frob_matrix < (threshold*frob_matrix_base*frob_matrix_base)) EXIT
     967             : 
     968             :       END DO
     969             : 
     970             :       ! this check is not really needed
     971             :       CALL dbcsr_multiply("N", "N", +1.0_dp, matrix_sign, matrix_sign, 0.0_dp, tmp1, &
     972        1058 :                           filter_eps=threshold)
     973        1058 :       frob_matrix_base = dbcsr_frobenius_norm(tmp1)
     974        1058 :       CALL dbcsr_add_on_diag(tmp1, -1.0_dp)
     975        1058 :       frob_matrix = dbcsr_frobenius_norm(tmp1)
     976        1058 :       occ_matrix = dbcsr_get_occupation(matrix_sign)
     977        1058 :       IF (unit_nr > 0) THEN
     978         529 :          WRITE (unit_nr, '(T6,A,1X,I3,1X,F10.8,E12.3)') "Final NS sign iter", i, occ_matrix, &
     979        1058 :             frob_matrix/frob_matrix_base
     980         529 :          WRITE (unit_nr, '()')
     981         529 :          CALL m_flush(unit_nr)
     982             :       END IF
     983             : 
     984        1058 :       CALL dbcsr_release(tmp1)
     985        1058 :       CALL dbcsr_release(tmp2)
     986        1058 :       IF (ABS(order) .GE. 4) THEN
     987           8 :          CALL dbcsr_release(tmp3)
     988             :       END IF
     989           8 :       IF (ABS(order) .GT. 4) THEN
     990           6 :          CALL dbcsr_release(tmp4)
     991             :       END IF
     992             : 
     993        1058 :       CALL timestop(handle)
     994             : 
     995        1058 :    END SUBROUTINE matrix_sign_Newton_Schulz
     996             : 
     997             :    ! **************************************************************************************************
     998             : !> \brief compute the sign a matrix using the general algorithm for the p-th root of Richters et al.
     999             : !>                   Commun. Comput. Phys., 25 (2019), pp. 564-585.
    1000             : !> \param matrix_sign ...
    1001             : !> \param matrix ...
    1002             : !> \param threshold ...
    1003             : !> \param sign_order ...
    1004             : !> \par History
    1005             : !>       2019.03 created [Robert Schade]
    1006             : !> \author Robert Schade
    1007             : ! **************************************************************************************************
    1008          16 :    SUBROUTINE matrix_sign_proot(matrix_sign, matrix, threshold, sign_order)
    1009             : 
    1010             :       TYPE(dbcsr_type), INTENT(INOUT)                    :: matrix_sign, matrix
    1011             :       REAL(KIND=dp), INTENT(IN)                          :: threshold
    1012             :       INTEGER, INTENT(IN), OPTIONAL                      :: sign_order
    1013             : 
    1014             :       CHARACTER(LEN=*), PARAMETER                        :: routineN = 'matrix_sign_proot'
    1015             : 
    1016             :       INTEGER                                            :: handle, order, unit_nr
    1017             :       INTEGER(KIND=int_8)                                :: flop0, flop1, flop2
    1018             :       LOGICAL                                            :: converged, symmetrize
    1019             :       REAL(KIND=dp)                                      :: frob_matrix, frob_matrix_base, occ_matrix
    1020             :       TYPE(cp_logger_type), POINTER                      :: logger
    1021             :       TYPE(dbcsr_type)                                   :: matrix2, matrix_sqrt, matrix_sqrt_inv, &
    1022             :                                                             tmp1, tmp2
    1023             : 
    1024           8 :       CALL cite_reference(Richters2018)
    1025             : 
    1026           8 :       CALL timeset(routineN, handle)
    1027             : 
    1028           8 :       logger => cp_get_default_logger()
    1029           8 :       IF (logger%para_env%is_source()) THEN
    1030           4 :          unit_nr = cp_logger_get_default_unit_nr(logger, local=.TRUE.)
    1031             :       ELSE
    1032           4 :          unit_nr = -1
    1033             :       END IF
    1034             : 
    1035           8 :       IF (PRESENT(sign_order)) THEN
    1036           8 :          order = sign_order
    1037             :       ELSE
    1038           0 :          order = 2
    1039             :       END IF
    1040             : 
    1041           8 :       CALL dbcsr_create(tmp1, template=matrix_sign)
    1042             : 
    1043           8 :       CALL dbcsr_create(tmp2, template=matrix_sign)
    1044             : 
    1045           8 :       CALL dbcsr_create(matrix2, template=matrix, matrix_type=dbcsr_type_no_symmetry)
    1046             :       CALL dbcsr_multiply("N", "N", 1.0_dp, matrix, matrix, 0.0_dp, matrix2, &
    1047           8 :                           filter_eps=threshold, flop=flop0)
    1048             :       !CALL dbcsr_filter(matrix2, threshold)
    1049             : 
    1050             :       !CALL dbcsr_copy(matrix_sign, matrix)
    1051             :       !CALL dbcsr_filter(matrix_sign, threshold)
    1052             : 
    1053           8 :       IF (unit_nr > 0) WRITE (unit_nr, *)
    1054             : 
    1055           8 :       CALL dbcsr_create(matrix_sqrt, template=matrix2)
    1056           8 :       CALL dbcsr_create(matrix_sqrt_inv, template=matrix2)
    1057           8 :       IF (unit_nr > 0) WRITE (unit_nr, *) "Threshold=", threshold
    1058             : 
    1059           8 :       symmetrize = .FALSE.
    1060             :       CALL matrix_sqrt_proot(matrix_sqrt, matrix_sqrt_inv, matrix2, threshold, order, &
    1061           8 :                              0.01_dp, 100, symmetrize, converged)
    1062             : !      call matrix_sqrt_Newton_Schulz(matrix_sqrt, matrix_sqrt_inv, matrix2, threshold, order, &
    1063             : !                                        0.01_dp,100, symmetrize,converged)
    1064             : 
    1065             :       CALL dbcsr_multiply("N", "N", 1.0_dp, matrix, matrix_sqrt_inv, 0.0_dp, matrix_sign, &
    1066           8 :                           filter_eps=threshold, flop=flop1)
    1067             : 
    1068             :       ! this check is not really needed
    1069             :       CALL dbcsr_multiply("N", "N", +1.0_dp, matrix_sign, matrix_sign, 0.0_dp, tmp1, &
    1070           8 :                           filter_eps=threshold, flop=flop2)
    1071           8 :       frob_matrix_base = dbcsr_frobenius_norm(tmp1)
    1072           8 :       CALL dbcsr_add_on_diag(tmp1, -1.0_dp)
    1073           8 :       frob_matrix = dbcsr_frobenius_norm(tmp1)
    1074           8 :       occ_matrix = dbcsr_get_occupation(matrix_sign)
    1075           8 :       IF (unit_nr > 0) THEN
    1076           4 :          WRITE (unit_nr, '(T6,A,F10.8,E12.3)') "Final proot sign iter", occ_matrix, &
    1077           8 :             frob_matrix/frob_matrix_base
    1078           4 :          WRITE (unit_nr, '()')
    1079           4 :          CALL m_flush(unit_nr)
    1080             :       END IF
    1081             : 
    1082           8 :       CALL dbcsr_release(tmp1)
    1083           8 :       CALL dbcsr_release(tmp2)
    1084           8 :       CALL dbcsr_release(matrix2)
    1085           8 :       CALL dbcsr_release(matrix_sqrt)
    1086           8 :       CALL dbcsr_release(matrix_sqrt_inv)
    1087             : 
    1088           8 :       CALL timestop(handle)
    1089             : 
    1090           8 :    END SUBROUTINE matrix_sign_proot
    1091             : 
    1092             : ! **************************************************************************************************
    1093             : !> \brief compute the sign of a dense matrix using Newton-Schulz iterations
    1094             : !> \param matrix_sign ...
    1095             : !> \param matrix ...
    1096             : !> \param matrix_id ...
    1097             : !> \param threshold ...
    1098             : !> \param sign_order ...
    1099             : !> \author Michael Lass, Robert Schade
    1100             : ! **************************************************************************************************
    1101           2 :    SUBROUTINE dense_matrix_sign_Newton_Schulz(matrix_sign, matrix, matrix_id, threshold, sign_order)
    1102             : 
    1103             :       REAL(KIND=dp), DIMENSION(:, :), INTENT(INOUT)      :: matrix_sign
    1104             :       REAL(KIND=dp), DIMENSION(:, :), INTENT(IN)         :: matrix
    1105             :       INTEGER, INTENT(IN), OPTIONAL                      :: matrix_id
    1106             :       REAL(KIND=dp), INTENT(IN)                          :: threshold
    1107             :       INTEGER, INTENT(IN), OPTIONAL                      :: sign_order
    1108             : 
    1109             :       CHARACTER(LEN=*), PARAMETER :: routineN = 'dense_matrix_sign_Newton_Schulz'
    1110             : 
    1111             :       INTEGER                                            :: handle, i, j, sz, unit_nr
    1112             :       LOGICAL                                            :: converged
    1113             :       REAL(KIND=dp)                                      :: frob_matrix, frob_matrix_base, &
    1114             :                                                             gersh_matrix, prefactor, scaling_factor
    1115           2 :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :)        :: tmp1, tmp2
    1116             :       REAL(KIND=dp), DIMENSION(1)                        :: work
    1117             :       REAL(KIND=dp), EXTERNAL                            :: dlange
    1118             :       TYPE(cp_logger_type), POINTER                      :: logger
    1119             : 
    1120           2 :       CALL timeset(routineN, handle)
    1121             : 
    1122             :       ! print output on all ranks
    1123           2 :       logger => cp_get_default_logger()
    1124           2 :       unit_nr = cp_logger_get_default_unit_nr(logger, local=.TRUE.)
    1125             : 
    1126             :       ! scale the matrix to get into the convergence range
    1127           2 :       sz = SIZE(matrix, 1)
    1128           2 :       frob_matrix = dlange('F', sz, sz, matrix, sz, work) !dbcsr_frobenius_norm(matrix_sign)
    1129           2 :       gersh_matrix = dlange('1', sz, sz, matrix, sz, work) !dbcsr_gershgorin_norm(matrix_sign)
    1130           2 :       scaling_factor = 1/MIN(frob_matrix, gersh_matrix)
    1131          86 :       matrix_sign = matrix*scaling_factor
    1132           8 :       ALLOCATE (tmp1(sz, sz))
    1133           6 :       ALLOCATE (tmp2(sz, sz))
    1134             : 
    1135           2 :       converged = .FALSE.
    1136          14 :       DO i = 1, 100
    1137          14 :          CALL dgemm('N', 'N', sz, sz, sz, -1.0_dp, matrix_sign, sz, matrix_sign, sz, 0.0_dp, tmp1, sz)
    1138             : 
    1139             :          ! check convergence (frob norm of what should be the identity matrix minus identity matrix)
    1140          14 :          frob_matrix_base = dlange('F', sz, sz, tmp1, sz, work)
    1141          98 :          DO j = 1, sz
    1142          98 :             tmp1(j, j) = tmp1(j, j) + 1.0_dp
    1143             :          END DO
    1144          14 :          frob_matrix = dlange('F', sz, sz, tmp1, sz, work)
    1145             : 
    1146          14 :          IF (sign_order .EQ. 2) THEN
    1147           8 :             prefactor = 0.5_dp
    1148             :             ! update the above to 3*I-X*X
    1149          56 :             DO j = 1, sz
    1150          56 :                tmp1(j, j) = tmp1(j, j) + 2.0_dp
    1151             :             END DO
    1152           6 :          ELSE IF (sign_order .EQ. 3) THEN
    1153         258 :             tmp2(:, :) = tmp1
    1154         258 :             tmp1 = tmp1*4.0_dp/3.0_dp
    1155          42 :             DO j = 1, sz
    1156          42 :                tmp1(j, j) = tmp1(j, j) + 8.0_dp/3.0_dp
    1157             :             END DO
    1158           6 :             CALL dgemm('N', 'N', sz, sz, sz, 1.0_dp, tmp2, sz, tmp2, sz, 1.0_dp, tmp1, sz)
    1159           6 :             prefactor = 3.0_dp/8.0_dp
    1160             :          ELSE
    1161           0 :             CPABORT("requested order is not implemented.")
    1162             :          END IF
    1163             : 
    1164          14 :          CALL dgemm('N', 'N', sz, sz, sz, prefactor, matrix_sign, sz, tmp1, sz, 0.0_dp, tmp2, sz)
    1165         602 :          matrix_sign = tmp2
    1166             : 
    1167             :          ! frob_matrix/frob_matrix_base < SQRT(threshold)
    1168          14 :          IF (frob_matrix*frob_matrix < (threshold*frob_matrix_base*frob_matrix_base)) THEN
    1169             :             WRITE (unit_nr, '(T6,A,1X,I6,1X,A,1X,I3,E12.3)') &
    1170           2 :                "Submatrix", matrix_id, "final NS sign iter", i, frob_matrix/frob_matrix_base
    1171           2 :             CALL m_flush(unit_nr)
    1172             :             converged = .TRUE.
    1173             :             EXIT
    1174             :          END IF
    1175             :       END DO
    1176             : 
    1177             :       IF (.NOT. converged) &
    1178           0 :          CPABORT("dense_matrix_sign_Newton_Schulz did not converge within 100 iterations")
    1179             : 
    1180           2 :       DEALLOCATE (tmp1)
    1181           2 :       DEALLOCATE (tmp2)
    1182             : 
    1183           2 :       CALL timestop(handle)
    1184             : 
    1185           2 :    END SUBROUTINE dense_matrix_sign_Newton_Schulz
    1186             : 
    1187             : ! **************************************************************************************************
    1188             : !> \brief Perform eigendecomposition of a dense matrix
    1189             : !> \param sm ...
    1190             : !> \param N ...
    1191             : !> \param eigvals ...
    1192             : !> \param eigvecs ...
    1193             : !> \par History
    1194             : !>       2020.05 Extracted from dense_matrix_sign_direct [Michael Lass]
    1195             : !> \author Michael Lass, Robert Schade
    1196             : ! **************************************************************************************************
    1197           4 :    SUBROUTINE eigdecomp(sm, N, eigvals, eigvecs)
    1198             :       INTEGER, INTENT(IN)                                :: N
    1199             :       REAL(KIND=dp), INTENT(IN)                          :: sm(N, N)
    1200             :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:), &
    1201             :          INTENT(OUT)                                     :: eigvals
    1202             :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :), &
    1203             :          INTENT(OUT)                                     :: eigvecs
    1204             : 
    1205             :       INTEGER                                            :: info, liwork, lwork
    1206           4 :       INTEGER, ALLOCATABLE, DIMENSION(:)                 :: iwork
    1207           4 :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:)           :: work
    1208           4 :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :)        :: tmp
    1209             : 
    1210          24 :       ALLOCATE (eigvecs(N, N), tmp(N, N))
    1211          12 :       ALLOCATE (eigvals(N))
    1212             : 
    1213             :       ! symmetrize sm
    1214         172 :       eigvecs(:, :) = 0.5*(sm + TRANSPOSE(sm))
    1215             : 
    1216             :       ! probe optimal sizes for WORK and IWORK
    1217           4 :       LWORK = -1
    1218           4 :       LIWORK = -1
    1219           4 :       ALLOCATE (WORK(1))
    1220           4 :       ALLOCATE (IWORK(1))
    1221           4 :       CALL dsyevd('V', 'U', N, eigvecs, N, eigvals, WORK, LWORK, IWORK, LIWORK, INFO)
    1222           4 :       LWORK = INT(WORK(1))
    1223           4 :       LIWORK = INT(IWORK(1))
    1224           4 :       DEALLOCATE (IWORK)
    1225           4 :       DEALLOCATE (WORK)
    1226             : 
    1227             :       ! calculate eigenvalues and eigenvectors
    1228          12 :       ALLOCATE (WORK(LWORK))
    1229          12 :       ALLOCATE (IWORK(LIWORK))
    1230           4 :       CALL dsyevd('V', 'U', N, eigvecs, N, eigvals, WORK, LWORK, IWORK, LIWORK, INFO)
    1231           4 :       DEALLOCATE (IWORK)
    1232           4 :       DEALLOCATE (WORK)
    1233           4 :       IF (INFO .NE. 0) CPABORT("dsyevd did not succeed")
    1234             : 
    1235           4 :       DEALLOCATE (tmp)
    1236           4 :    END SUBROUTINE eigdecomp
    1237             : 
    1238             : ! **************************************************************************************************
    1239             : !> \brief Calculate the sign matrix from eigenvalues and eigenvectors of a matrix
    1240             : !> \param sm_sign ...
    1241             : !> \param eigvals ...
    1242             : !> \param eigvecs ...
    1243             : !> \param N ...
    1244             : !> \param mu_correction ...
    1245             : !> \par History
    1246             : !>       2020.05 Extracted from dense_matrix_sign_direct [Michael Lass]
    1247             : !> \author Michael Lass, Robert Schade
    1248             : ! **************************************************************************************************
    1249           4 :    SUBROUTINE sign_from_eigdecomp(sm_sign, eigvals, eigvecs, N, mu_correction)
    1250             :       INTEGER                                            :: N
    1251             :       REAL(KIND=dp), INTENT(IN)                          :: eigvecs(N, N), eigvals(N)
    1252             :       REAL(KIND=dp), INTENT(INOUT)                       :: sm_sign(N, N)
    1253             :       REAL(KIND=dp), INTENT(IN)                          :: mu_correction
    1254             : 
    1255             :       INTEGER                                            :: i
    1256           4 :       REAL(KIND=dp)                                      :: modified_eigval, tmp(N, N)
    1257             : 
    1258         172 :       sm_sign = 0
    1259          28 :       DO i = 1, N
    1260          24 :          modified_eigval = eigvals(i) - mu_correction
    1261          28 :          IF (modified_eigval > 0) THEN
    1262           6 :             sm_sign(i, i) = 1.0
    1263          18 :          ELSE IF (modified_eigval < 0) THEN
    1264          18 :             sm_sign(i, i) = -1.0
    1265             :          ELSE
    1266           0 :             sm_sign(i, i) = 0.0
    1267             :          END IF
    1268             :       END DO
    1269             : 
    1270             :       ! Create matrix with eigenvalues in {-1,0,1} and eigenvectors of sm:
    1271             :       ! sm_sign = eigvecs * sm_sign * eigvecs.T
    1272           4 :       CALL dgemm('N', 'N', N, N, N, 1.0_dp, eigvecs, N, sm_sign, N, 0.0_dp, tmp, N)
    1273           4 :       CALL dgemm('N', 'T', N, N, N, 1.0_dp, tmp, N, eigvecs, N, 0.0_dp, sm_sign, N)
    1274           4 :    END SUBROUTINE sign_from_eigdecomp
    1275             : 
    1276             : ! **************************************************************************************************
    1277             : !> \brief Compute partial trace of a matrix from its eigenvalues and eigenvectors
    1278             : !> \param eigvals ...
    1279             : !> \param eigvecs ...
    1280             : !> \param firstcol ...
    1281             : !> \param lastcol ...
    1282             : !> \param mu_correction ...
    1283             : !> \return ...
    1284             : !> \par History
    1285             : !>       2020.05 Created [Michael Lass]
    1286             : !> \author Michael Lass
    1287             : ! **************************************************************************************************
    1288          36 :    FUNCTION trace_from_eigdecomp(eigvals, eigvecs, firstcol, lastcol, mu_correction) RESULT(trace)
    1289             :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:), &
    1290             :          INTENT(IN)                                      :: eigvals
    1291             :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :), &
    1292             :          INTENT(IN)                                      :: eigvecs
    1293             :       INTEGER, INTENT(IN)                                :: firstcol, lastcol
    1294             :       REAL(KIND=dp), INTENT(IN)                          :: mu_correction
    1295             :       REAL(KIND=dp)                                      :: trace
    1296             : 
    1297             :       INTEGER                                            :: i, j, sm_size
    1298             :       REAL(KIND=dp)                                      :: modified_eigval, tmpsum
    1299          36 :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:)           :: mapped_eigvals
    1300             : 
    1301          36 :       sm_size = SIZE(eigvals)
    1302         108 :       ALLOCATE (mapped_eigvals(sm_size))
    1303             : 
    1304         252 :       DO i = 1, sm_size
    1305         216 :          modified_eigval = eigvals(i) - mu_correction
    1306         252 :          IF (modified_eigval > 0) THEN
    1307          26 :             mapped_eigvals(i) = 1.0
    1308         190 :          ELSE IF (modified_eigval < 0) THEN
    1309         190 :             mapped_eigvals(i) = -1.0
    1310             :          ELSE
    1311           0 :             mapped_eigvals(i) = 0.0
    1312             :          END IF
    1313             :       END DO
    1314             : 
    1315          36 :       trace = 0.0_dp
    1316         252 :       DO i = firstcol, lastcol
    1317             :          tmpsum = 0.0_dp
    1318        1512 :          DO j = 1, sm_size
    1319        1512 :             tmpsum = tmpsum + (eigvecs(i, j)*mapped_eigvals(j)*eigvecs(i, j))
    1320             :          END DO
    1321         252 :          trace = trace - 0.5_dp*tmpsum + 0.5_dp
    1322             :       END DO
    1323          36 :    END FUNCTION trace_from_eigdecomp
    1324             : 
    1325             : ! **************************************************************************************************
    1326             : !> \brief Calculate the sign matrix by direct calculation of all eigenvalues and eigenvectors
    1327             : !> \param sm_sign ...
    1328             : !> \param sm ...
    1329             : !> \param N ...
    1330             : !> \par History
    1331             : !>       2020.02 Created [Michael Lass, Robert Schade]
    1332             : !>       2020.05 Extracted eigdecomp and sign_from_eigdecomp [Michael Lass]
    1333             : !> \author Michael Lass, Robert Schade
    1334             : ! **************************************************************************************************
    1335           2 :    SUBROUTINE dense_matrix_sign_direct(sm_sign, sm, N)
    1336             :       INTEGER, INTENT(IN)                                :: N
    1337             :       REAL(KIND=dp), INTENT(IN)                          :: sm(N, N)
    1338             :       REAL(KIND=dp), INTENT(INOUT)                       :: sm_sign(N, N)
    1339             : 
    1340           2 :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:)           :: eigvals
    1341           2 :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :)        :: eigvecs
    1342             : 
    1343             :       CALL eigdecomp(sm, N, eigvals, eigvecs)
    1344           2 :       CALL sign_from_eigdecomp(sm_sign, eigvals, eigvecs, N, 0.0_dp)
    1345             : 
    1346           2 :       DEALLOCATE (eigvals, eigvecs)
    1347           2 :    END SUBROUTINE dense_matrix_sign_direct
    1348             : 
    1349             : ! **************************************************************************************************
    1350             : !> \brief Submatrix method
    1351             : !> \param matrix_sign ...
    1352             : !> \param matrix ...
    1353             : !> \param threshold ...
    1354             : !> \param sign_order ...
    1355             : !> \param submatrix_sign_method ...
    1356             : !> \par History
    1357             : !>       2019.03 created [Robert Schade]
    1358             : !>       2019.06 impl. submatrix method [Michael Lass]
    1359             : !> \author Robert Schade, Michael Lass
    1360             : ! **************************************************************************************************
    1361           6 :    SUBROUTINE matrix_sign_submatrix(matrix_sign, matrix, threshold, sign_order, submatrix_sign_method)
    1362             : 
    1363             :       TYPE(dbcsr_type), INTENT(INOUT)                    :: matrix_sign, matrix
    1364             :       REAL(KIND=dp), INTENT(IN)                          :: threshold
    1365             :       INTEGER, INTENT(IN), OPTIONAL                      :: sign_order
    1366             :       INTEGER, INTENT(IN)                                :: submatrix_sign_method
    1367             : 
    1368             :       CHARACTER(LEN=*), PARAMETER :: routineN = 'matrix_sign_submatrix'
    1369             : 
    1370             :       INTEGER                                            :: group, handle, i, myrank, nblkcols, &
    1371             :                                                             order, sm_size, unit_nr
    1372           6 :       INTEGER, ALLOCATABLE, DIMENSION(:)                 :: my_sms
    1373           6 :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :)        :: sm, sm_sign
    1374             :       TYPE(cp_logger_type), POINTER                      :: logger
    1375             :       TYPE(dbcsr_distribution_type)                      :: dist
    1376           6 :       TYPE(submatrix_dissection_type)                    :: dissection
    1377             : 
    1378           6 :       CALL timeset(routineN, handle)
    1379             : 
    1380             :       ! print output on all ranks
    1381           6 :       logger => cp_get_default_logger()
    1382           6 :       unit_nr = cp_logger_get_default_unit_nr(logger, local=.TRUE.)
    1383             : 
    1384           6 :       IF (PRESENT(sign_order)) THEN
    1385           6 :          order = sign_order
    1386             :       ELSE
    1387           0 :          order = 2
    1388             :       END IF
    1389             : 
    1390           6 :       CALL dbcsr_get_info(matrix=matrix, nblkcols_total=nblkcols, distribution=dist, group=group)
    1391           6 :       CALL dbcsr_distribution_get(dist=dist, mynode=myrank)
    1392             : 
    1393           6 :       CALL dissection%init(matrix)
    1394           6 :       CALL dissection%get_sm_ids_for_rank(myrank, my_sms)
    1395             : 
    1396             :       !$OMP PARALLEL DEFAULT(OMP_DEFAULT_NONE_WITH_OOP) &
    1397             :       !$OMP          PRIVATE(sm, sm_sign, sm_size) &
    1398           6 :       !$OMP          SHARED(dissection, myrank, my_sms, order, submatrix_sign_method, threshold, unit_nr)
    1399             :       !$OMP DO SCHEDULE(GUIDED)
    1400             :       DO i = 1, SIZE(my_sms)
    1401             :          WRITE (unit_nr, '(T3,A,1X,I4,1X,A,1X,I6)') "Rank", myrank, "processing submatrix", my_sms(i)
    1402             :          CALL dissection%generate_submatrix(my_sms(i), sm)
    1403             :          sm_size = SIZE(sm, 1)
    1404             :          ALLOCATE (sm_sign(sm_size, sm_size))
    1405             :          SELECT CASE (submatrix_sign_method)
    1406             :          CASE (ls_scf_submatrix_sign_ns)
    1407             :             CALL dense_matrix_sign_Newton_Schulz(sm_sign, sm, my_sms(i), threshold, order)
    1408             :          CASE (ls_scf_submatrix_sign_direct, ls_scf_submatrix_sign_direct_muadj, ls_scf_submatrix_sign_direct_muadj_lowmem)
    1409             :             CALL dense_matrix_sign_direct(sm_sign, sm, sm_size)
    1410             :          CASE DEFAULT
    1411             :             CPABORT("Unkown submatrix sign method.")
    1412             :          END SELECT
    1413             :          CALL dissection%copy_resultcol(my_sms(i), sm_sign)
    1414             :          DEALLOCATE (sm, sm_sign)
    1415             :       END DO
    1416             :       !$OMP END DO
    1417             :       !$OMP END PARALLEL
    1418             : 
    1419           6 :       CALL dissection%communicate_results(matrix_sign)
    1420           6 :       CALL dissection%final
    1421             : 
    1422           6 :       CALL timestop(handle)
    1423             : 
    1424          12 :    END SUBROUTINE matrix_sign_submatrix
    1425             : 
    1426             : ! **************************************************************************************************
    1427             : !> \brief Submatrix method with internal adjustment of chemical potential
    1428             : !> \param matrix_sign ...
    1429             : !> \param matrix ...
    1430             : !> \param mu ...
    1431             : !> \param nelectron ...
    1432             : !> \param threshold ...
    1433             : !> \param variant ...
    1434             : !> \par History
    1435             : !>       2020.05 Created [Michael Lass]
    1436             : !> \author Robert Schade, Michael Lass
    1437             : ! **************************************************************************************************
    1438           4 :    SUBROUTINE matrix_sign_submatrix_mu_adjust(matrix_sign, matrix, mu, nelectron, threshold, variant)
    1439             : 
    1440             :       TYPE(dbcsr_type), INTENT(INOUT)                    :: matrix_sign, matrix
    1441             :       REAL(KIND=dp), INTENT(INOUT)                       :: mu
    1442             :       INTEGER, INTENT(IN)                                :: nelectron
    1443             :       REAL(KIND=dp), INTENT(IN)                          :: threshold
    1444             :       INTEGER, INTENT(IN)                                :: variant
    1445             : 
    1446             :       CHARACTER(LEN=*), PARAMETER :: routineN = 'matrix_sign_submatrix_mu_adjust'
    1447             :       REAL(KIND=dp), PARAMETER                           :: initial_increment = 0.01_dp
    1448             : 
    1449             :       INTEGER                                            :: group_handle, handle, i, j, myrank, &
    1450             :                                                             nblkcols, sm_firstcol, sm_lastcol, &
    1451             :                                                             sm_size, unit_nr
    1452           4 :       INTEGER, ALLOCATABLE, DIMENSION(:)                 :: my_sms
    1453             :       LOGICAL                                            :: has_mu_high, has_mu_low
    1454             :       REAL(KIND=dp)                                      :: increment, mu_high, mu_low, new_mu, trace
    1455           4 :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :)        :: sm, sm_sign, tmp
    1456             :       TYPE(cp_logger_type), POINTER                      :: logger
    1457             :       TYPE(dbcsr_distribution_type)                      :: dist
    1458           4 :       TYPE(eigbuf), ALLOCATABLE, DIMENSION(:)            :: eigbufs
    1459             :       TYPE(mp_comm_type)                                 :: group
    1460           4 :       TYPE(submatrix_dissection_type)                    :: dissection
    1461             : 
    1462           4 :       CALL timeset(routineN, handle)
    1463             : 
    1464             :       ! print output on all ranks
    1465           4 :       logger => cp_get_default_logger()
    1466           4 :       unit_nr = cp_logger_get_default_unit_nr(logger, local=.TRUE.)
    1467             : 
    1468           4 :       CALL dbcsr_get_info(matrix=matrix, nblkcols_total=nblkcols, distribution=dist, group=group_handle)
    1469           4 :       CALL dbcsr_distribution_get(dist=dist, mynode=myrank)
    1470             : 
    1471           4 :       CALL group%set_handle(group_handle)
    1472             : 
    1473           4 :       CALL dissection%init(matrix)
    1474           4 :       CALL dissection%get_sm_ids_for_rank(myrank, my_sms)
    1475             : 
    1476          12 :       ALLOCATE (eigbufs(SIZE(my_sms)))
    1477             : 
    1478             :       trace = 0.0_dp
    1479             : 
    1480             :       !$OMP PARALLEL DEFAULT(OMP_DEFAULT_NONE_WITH_OOP) &
    1481             :       !$OMP          PRIVATE(sm, sm_sign, sm_size, sm_firstcol, sm_lastcol, j, tmp) &
    1482             :       !$OMP          SHARED(dissection, myrank, my_sms, unit_nr, eigbufs, threshold, variant) &
    1483           4 :       !$OMP          REDUCTION(+:trace)
    1484             :       !$OMP DO SCHEDULE(GUIDED)
    1485             :       DO i = 1, SIZE(my_sms)
    1486             :          CALL dissection%generate_submatrix(my_sms(i), sm)
    1487             :          sm_size = SIZE(sm, 1)
    1488             :          WRITE (unit_nr, *) "Rank", myrank, "processing submatrix", my_sms(i), "size", sm_size
    1489             : 
    1490             :          CALL dissection%get_relevant_sm_columns(my_sms(i), sm_firstcol, sm_lastcol)
    1491             : 
    1492             :          IF (variant .EQ. ls_scf_submatrix_sign_direct_muadj) THEN
    1493             :             ! Store all eigenvectors in buffer. We will use it to compute sm_sign at the end.
    1494             :             CALL eigdecomp(sm, sm_size, eigvals=eigbufs(i)%eigvals, eigvecs=eigbufs(i)%eigvecs)
    1495             :          ELSE
    1496             :             ! Only store eigenvectors that are required for mu adjustment.
    1497             :             ! Calculate sm_sign right away in the hope that mu is already correct.
    1498             :             CALL eigdecomp(sm, sm_size, eigvals=eigbufs(i)%eigvals, eigvecs=tmp)
    1499             :             ALLOCATE (eigbufs(i)%eigvecs(sm_firstcol:sm_lastcol, 1:sm_size))
    1500             :             eigbufs(i)%eigvecs(:, :) = tmp(sm_firstcol:sm_lastcol, 1:sm_size)
    1501             : 
    1502             :             ALLOCATE (sm_sign(sm_size, sm_size))
    1503             :             CALL sign_from_eigdecomp(sm_sign, eigbufs(i)%eigvals, tmp, sm_size, 0.0_dp)
    1504             :             CALL dissection%copy_resultcol(my_sms(i), sm_sign)
    1505             :             DEALLOCATE (sm_sign, tmp)
    1506             :          END IF
    1507             : 
    1508             :          DEALLOCATE (sm)
    1509             :          trace = trace + trace_from_eigdecomp(eigbufs(i)%eigvals, eigbufs(i)%eigvecs, sm_firstcol, sm_lastcol, 0.0_dp)
    1510             :       END DO
    1511             :       !$OMP END DO
    1512             :       !$OMP END PARALLEL
    1513             : 
    1514           4 :       has_mu_low = .FALSE.
    1515           4 :       has_mu_high = .FALSE.
    1516           4 :       increment = initial_increment
    1517           4 :       new_mu = mu
    1518          72 :       DO i = 1, 30
    1519          72 :          CALL group%sum(trace)
    1520          72 :          IF (unit_nr > 0) WRITE (unit_nr, '(T2,A,1X,F13.9,1X,F15.9)') &
    1521          72 :             "Density matrix:  mu, trace error: ", new_mu, trace - nelectron
    1522          72 :          IF (ABS(trace - nelectron) < 0.5_dp) EXIT
    1523          68 :          IF (trace < nelectron) THEN
    1524           8 :             mu_low = new_mu
    1525           8 :             new_mu = new_mu + increment
    1526           8 :             has_mu_low = .TRUE.
    1527           8 :             increment = increment*2
    1528             :          ELSE
    1529          60 :             mu_high = new_mu
    1530          60 :             new_mu = new_mu - increment
    1531          60 :             has_mu_high = .TRUE.
    1532          60 :             increment = increment*2
    1533             :          END IF
    1534             : 
    1535          68 :          IF (has_mu_low .AND. has_mu_high) THEN
    1536          20 :             new_mu = (mu_low + mu_high)/2
    1537          20 :             IF (ABS(mu_high - mu_low) < threshold) EXIT
    1538             :          END IF
    1539             : 
    1540             :          trace = 0
    1541             :          !$OMP PARALLEL DEFAULT(OMP_DEFAULT_NONE_WITH_OOP) &
    1542             :          !$OMP          PRIVATE(i, sm_sign, tmp, sm_size, sm_firstcol, sm_lastcol) &
    1543             :          !$OMP          SHARED(dissection, my_sms, unit_nr, eigbufs, mu, new_mu, nelectron) &
    1544          72 :          !$OMP          REDUCTION(+:trace)
    1545             :          !$OMP DO SCHEDULE(GUIDED)
    1546             :          DO j = 1, SIZE(my_sms)
    1547             :             sm_size = SIZE(eigbufs(j)%eigvals)
    1548             :             CALL dissection%get_relevant_sm_columns(my_sms(j), sm_firstcol, sm_lastcol)
    1549             :             trace = trace + trace_from_eigdecomp(eigbufs(j)%eigvals, eigbufs(j)%eigvecs, sm_firstcol, sm_lastcol, new_mu - mu)
    1550             :          END DO
    1551             :          !$OMP END DO
    1552             :          !$OMP END PARALLEL
    1553             :       END DO
    1554             : 
    1555             :       ! Finalize sign matrix from eigendecompositions if we kept all eigenvectors
    1556           4 :       IF (variant .EQ. ls_scf_submatrix_sign_direct_muadj) THEN
    1557             :          !$OMP PARALLEL DEFAULT(OMP_DEFAULT_NONE_WITH_OOP) &
    1558             :          !$OMP          PRIVATE(sm, sm_sign, sm_size, sm_firstcol, sm_lastcol, j) &
    1559           2 :          !$OMP          SHARED(dissection, myrank, my_sms, unit_nr, eigbufs, mu, new_mu)
    1560             :          !$OMP DO SCHEDULE(GUIDED)
    1561             :          DO i = 1, SIZE(my_sms)
    1562             :             WRITE (unit_nr, '(T3,A,1X,I4,1X,A,1X,I6)') "Rank", myrank, "finalizing submatrix", my_sms(i)
    1563             :             sm_size = SIZE(eigbufs(i)%eigvals)
    1564             :             ALLOCATE (sm_sign(sm_size, sm_size))
    1565             :             CALL sign_from_eigdecomp(sm_sign, eigbufs(i)%eigvals, eigbufs(i)%eigvecs, sm_size, new_mu - mu)
    1566             :             CALL dissection%copy_resultcol(my_sms(i), sm_sign)
    1567             :             DEALLOCATE (sm_sign)
    1568             :          END DO
    1569             :          !$OMP END DO
    1570             :          !$OMP END PARALLEL
    1571             :       END IF
    1572             : 
    1573           6 :       DEALLOCATE (eigbufs)
    1574             : 
    1575             :       ! If we only stored parts of the eigenvectors and mu has changed, we need to recompute sm_sign
    1576           4 :       IF ((variant .EQ. ls_scf_submatrix_sign_direct_muadj_lowmem) .AND. (mu .NE. new_mu)) THEN
    1577             :          !$OMP PARALLEL DEFAULT(OMP_DEFAULT_NONE_WITH_OOP) &
    1578             :          !$OMP          PRIVATE(sm, sm_sign, sm_size, sm_firstcol, sm_lastcol, j) &
    1579           2 :          !$OMP          SHARED(dissection, myrank, my_sms, unit_nr, eigbufs, mu, new_mu)
    1580             :          !$OMP DO SCHEDULE(GUIDED)
    1581             :          DO i = 1, SIZE(my_sms)
    1582             :             WRITE (unit_nr, '(T3,A,1X,I4,1X,A,1X,I6)') "Rank", myrank, "reprocessing submatrix", my_sms(i)
    1583             :             CALL dissection%generate_submatrix(my_sms(i), sm)
    1584             :             sm_size = SIZE(sm, 1)
    1585             :             DO j = 1, sm_size
    1586             :                sm(j, j) = sm(j, j) + mu - new_mu
    1587             :             END DO
    1588             :             ALLOCATE (sm_sign(sm_size, sm_size))
    1589             :             CALL dense_matrix_sign_direct(sm_sign, sm, sm_size)
    1590             :             CALL dissection%copy_resultcol(my_sms(i), sm_sign)
    1591             :             DEALLOCATE (sm, sm_sign)
    1592             :          END DO
    1593             :          !$OMP END DO
    1594             :          !$OMP END PARALLEL
    1595             :       END IF
    1596             : 
    1597           4 :       mu = new_mu
    1598             : 
    1599           4 :       CALL dissection%communicate_results(matrix_sign)
    1600           4 :       CALL dissection%final
    1601             : 
    1602           4 :       CALL timestop(handle)
    1603             : 
    1604          12 :    END SUBROUTINE matrix_sign_submatrix_mu_adjust
    1605             : 
    1606             : ! **************************************************************************************************
    1607             : !> \brief compute the sqrt of a matrix via the sign function and the corresponding Newton-Schulz iterations
    1608             : !>        the order of the algorithm should be 2..5, 3 or 5 is recommended
    1609             : !> \param matrix_sqrt ...
    1610             : !> \param matrix_sqrt_inv ...
    1611             : !> \param matrix ...
    1612             : !> \param threshold ...
    1613             : !> \param order ...
    1614             : !> \param eps_lanczos ...
    1615             : !> \param max_iter_lanczos ...
    1616             : !> \param symmetrize ...
    1617             : !> \param converged ...
    1618             : !> \par History
    1619             : !>       2010.10 created [Joost VandeVondele]
    1620             : !> \author Joost VandeVondele
    1621             : ! **************************************************************************************************
    1622       14278 :    SUBROUTINE matrix_sqrt_Newton_Schulz(matrix_sqrt, matrix_sqrt_inv, matrix, threshold, order, &
    1623             :                                         eps_lanczos, max_iter_lanczos, symmetrize, converged)
    1624             :       TYPE(dbcsr_type), INTENT(INOUT)                    :: matrix_sqrt, matrix_sqrt_inv, matrix
    1625             :       REAL(KIND=dp), INTENT(IN)                          :: threshold
    1626             :       INTEGER, INTENT(IN)                                :: order
    1627             :       REAL(KIND=dp), INTENT(IN)                          :: eps_lanczos
    1628             :       INTEGER, INTENT(IN)                                :: max_iter_lanczos
    1629             :       LOGICAL, OPTIONAL                                  :: symmetrize, converged
    1630             : 
    1631             :       CHARACTER(LEN=*), PARAMETER :: routineN = 'matrix_sqrt_Newton_Schulz'
    1632             : 
    1633             :       INTEGER                                            :: handle, i, unit_nr
    1634             :       INTEGER(KIND=int_8)                                :: flop1, flop2, flop3, flop4, flop5
    1635             :       LOGICAL                                            :: arnoldi_converged, tsym
    1636             :       REAL(KIND=dp)                                      :: a, b, c, conv, d, frob_matrix, &
    1637             :                                                             frob_matrix_base, gershgorin_norm, &
    1638             :                                                             max_ev, min_ev, oa, ob, oc, &
    1639             :                                                             occ_matrix, od, scaling, t1, t2
    1640             :       TYPE(cp_logger_type), POINTER                      :: logger
    1641             :       TYPE(dbcsr_type)                                   :: tmp1, tmp2, tmp3
    1642             : 
    1643       14278 :       CALL timeset(routineN, handle)
    1644             : 
    1645       14278 :       logger => cp_get_default_logger()
    1646       14278 :       IF (logger%para_env%is_source()) THEN
    1647        7139 :          unit_nr = cp_logger_get_default_unit_nr(logger, local=.TRUE.)
    1648             :       ELSE
    1649        7139 :          unit_nr = -1
    1650             :       END IF
    1651             : 
    1652       14278 :       IF (PRESENT(converged)) converged = .FALSE.
    1653       14278 :       IF (PRESENT(symmetrize)) THEN
    1654           0 :          tsym = symmetrize
    1655             :       ELSE
    1656             :          tsym = .TRUE.
    1657             :       END IF
    1658             : 
    1659             :       ! for stability symmetry can not be assumed
    1660       14278 :       CALL dbcsr_create(tmp1, template=matrix, matrix_type=dbcsr_type_no_symmetry)
    1661       14278 :       CALL dbcsr_create(tmp2, template=matrix, matrix_type=dbcsr_type_no_symmetry)
    1662       14278 :       IF (order .GE. 4) THEN
    1663          20 :          CALL dbcsr_create(tmp3, template=matrix, matrix_type=dbcsr_type_no_symmetry)
    1664             :       END IF
    1665             : 
    1666       14278 :       CALL dbcsr_set(matrix_sqrt_inv, 0.0_dp)
    1667       14278 :       CALL dbcsr_add_on_diag(matrix_sqrt_inv, 1.0_dp)
    1668       14278 :       CALL dbcsr_filter(matrix_sqrt_inv, threshold)
    1669       14278 :       CALL dbcsr_copy(matrix_sqrt, matrix)
    1670             : 
    1671             :       ! scale the matrix to get into the convergence range
    1672       14278 :       IF (order == 0) THEN
    1673             : 
    1674           0 :          gershgorin_norm = dbcsr_gershgorin_norm(matrix_sqrt)
    1675           0 :          frob_matrix = dbcsr_frobenius_norm(matrix_sqrt)
    1676           0 :          scaling = 1.0_dp/MIN(frob_matrix, gershgorin_norm)
    1677             : 
    1678             :       ELSE
    1679             : 
    1680             :          ! scale the matrix to get into the convergence range
    1681             :          CALL arnoldi_extremal(matrix_sqrt, max_ev, min_ev, threshold=eps_lanczos, &
    1682       14278 :                                max_iter=max_iter_lanczos, converged=arnoldi_converged)
    1683       14278 :          IF (unit_nr > 0) THEN
    1684        7139 :             WRITE (unit_nr, *)
    1685        7139 :             WRITE (unit_nr, '(T6,A,1X,L1,A,E12.3)') "Lanczos converged: ", arnoldi_converged, " threshold:", eps_lanczos
    1686        7139 :             WRITE (unit_nr, '(T6,A,1X,E12.3,E12.3)') "Est. extremal eigenvalues:", max_ev, min_ev
    1687        7139 :             WRITE (unit_nr, '(T6,A,1X,E12.3)') "Est. condition number :", max_ev/MAX(min_ev, EPSILON(min_ev))
    1688             :          END IF
    1689             :          ! conservatively assume we get a relatively large error (100*threshold_lanczos) in the estimates
    1690             :          ! and adjust the scaling to be on the safe side
    1691       14278 :          scaling = 2.0_dp/(max_ev + min_ev + 100*eps_lanczos)
    1692             : 
    1693             :       END IF
    1694             : 
    1695       14278 :       CALL dbcsr_scale(matrix_sqrt, scaling)
    1696       14278 :       CALL dbcsr_filter(matrix_sqrt, threshold)
    1697       14278 :       IF (unit_nr > 0) THEN
    1698        7139 :          WRITE (unit_nr, *)
    1699        7139 :          WRITE (unit_nr, *) "Order=", order
    1700             :       END IF
    1701             : 
    1702       67374 :       DO i = 1, 100
    1703             : 
    1704       67374 :          t1 = m_walltime()
    1705             : 
    1706             :          ! tmp1 = Zk * Yk - I
    1707             :          CALL dbcsr_multiply("N", "N", 1.0_dp, matrix_sqrt_inv, matrix_sqrt, 0.0_dp, tmp1, &
    1708       67374 :                              filter_eps=threshold, flop=flop1)
    1709       67374 :          frob_matrix_base = dbcsr_frobenius_norm(tmp1)
    1710       67374 :          CALL dbcsr_add_on_diag(tmp1, -1.0_dp)
    1711             : 
    1712             :          ! check convergence (frob norm of what should be the identity matrix minus identity matrix)
    1713       67374 :          frob_matrix = dbcsr_frobenius_norm(tmp1)
    1714             : 
    1715       67374 :          flop4 = 0; flop5 = 0
    1716          36 :          SELECT CASE (order)
    1717             :          CASE (0, 2)
    1718             :             ! update the above to 0.5*(3*I-Zk*Yk)
    1719          36 :             CALL dbcsr_add_on_diag(tmp1, -2.0_dp)
    1720          36 :             CALL dbcsr_scale(tmp1, -0.5_dp)
    1721             :          CASE (3)
    1722             :             ! tmp2 = tmp1 ** 2
    1723             :             CALL dbcsr_multiply("N", "N", 1.0_dp, tmp1, tmp1, 0.0_dp, tmp2, &
    1724       67278 :                                 filter_eps=threshold, flop=flop4)
    1725             :             ! tmp1 = 1/16 * (16*I-8*tmp1+6*tmp1**2-5*tmp1**3)
    1726       67278 :             CALL dbcsr_add(tmp1, tmp2, -4.0_dp, 3.0_dp)
    1727       67278 :             CALL dbcsr_add_on_diag(tmp1, 8.0_dp)
    1728       67278 :             CALL dbcsr_scale(tmp1, 0.125_dp)
    1729             :          CASE (4) ! as expensive as case(5), so little need to use it
    1730             :             ! tmp2 = tmp1 ** 2
    1731             :             CALL dbcsr_multiply("N", "N", 1.0_dp, tmp1, tmp1, 0.0_dp, tmp2, &
    1732          32 :                                 filter_eps=threshold, flop=flop4)
    1733             :             ! tmp3 = tmp2 * tmp1
    1734             :             CALL dbcsr_multiply("N", "N", 1.0_dp, tmp2, tmp1, 0.0_dp, tmp3, &
    1735          32 :                                 filter_eps=threshold, flop=flop5)
    1736          32 :             CALL dbcsr_scale(tmp1, -8.0_dp)
    1737          32 :             CALL dbcsr_add_on_diag(tmp1, 16.0_dp)
    1738          32 :             CALL dbcsr_add(tmp1, tmp2, 1.0_dp, 6.0_dp)
    1739          32 :             CALL dbcsr_add(tmp1, tmp3, 1.0_dp, -5.0_dp)
    1740          32 :             CALL dbcsr_scale(tmp1, 1/16.0_dp)
    1741             :          CASE (5)
    1742             :             ! Knuth's reformulation to evaluate the polynomial of 4th degree in 2 multiplications
    1743             :             ! p = y4+A*y3+B*y2+C*y+D
    1744             :             ! z := y * (y+a); P := (z+y+b) * (z+c) + d.
    1745             :             ! a=(A-1)/2 ; b=B*(a+1)-C-a*(a+1)*(a+1)
    1746             :             ! c=B-b-a*(a+1)
    1747             :             ! d=D-bc
    1748          28 :             oa = -40.0_dp/35.0_dp
    1749          28 :             ob = 48.0_dp/35.0_dp
    1750          28 :             oc = -64.0_dp/35.0_dp
    1751          28 :             od = 128.0_dp/35.0_dp
    1752          28 :             a = (oa - 1)/2
    1753          28 :             b = ob*(a + 1) - oc - a*(a + 1)**2
    1754          28 :             c = ob - b - a*(a + 1)
    1755          28 :             d = od - b*c
    1756             :             ! tmp2 = tmp1 ** 2 + a * tmp1
    1757             :             CALL dbcsr_multiply("N", "N", 1.0_dp, tmp1, tmp1, 0.0_dp, tmp2, &
    1758          28 :                                 filter_eps=threshold, flop=flop4)
    1759          28 :             CALL dbcsr_add(tmp2, tmp1, 1.0_dp, a)
    1760             :             ! tmp3 = tmp2 + tmp1 + b
    1761          28 :             CALL dbcsr_copy(tmp3, tmp2)
    1762          28 :             CALL dbcsr_add(tmp3, tmp1, 1.0_dp, 1.0_dp)
    1763          28 :             CALL dbcsr_add_on_diag(tmp3, b)
    1764             :             ! tmp2 = tmp2 + c
    1765          28 :             CALL dbcsr_add_on_diag(tmp2, c)
    1766             :             ! tmp1 = tmp2 * tmp3
    1767             :             CALL dbcsr_multiply("N", "N", 1.0_dp, tmp2, tmp3, 0.0_dp, tmp1, &
    1768          28 :                                 filter_eps=threshold, flop=flop5)
    1769             :             ! tmp1 = tmp1 + d
    1770          28 :             CALL dbcsr_add_on_diag(tmp1, d)
    1771             :             ! final scale
    1772          28 :             CALL dbcsr_scale(tmp1, 35.0_dp/128.0_dp)
    1773             :          CASE DEFAULT
    1774       67374 :             CPABORT("Illegal order value")
    1775             :          END SELECT
    1776             : 
    1777             :          ! tmp2 = Yk * tmp1 = Y(k+1)
    1778             :          CALL dbcsr_multiply("N", "N", 1.0_dp, matrix_sqrt, tmp1, 0.0_dp, tmp2, &
    1779       67374 :                              filter_eps=threshold, flop=flop2)
    1780             :          ! CALL dbcsr_filter(tmp2,threshold)
    1781       67374 :          CALL dbcsr_copy(matrix_sqrt, tmp2)
    1782             : 
    1783             :          ! tmp2 = tmp1 * Zk = Z(k+1)
    1784             :          CALL dbcsr_multiply("N", "N", 1.0_dp, tmp1, matrix_sqrt_inv, 0.0_dp, tmp2, &
    1785       67374 :                              filter_eps=threshold, flop=flop3)
    1786             :          ! CALL dbcsr_filter(tmp2,threshold)
    1787       67374 :          CALL dbcsr_copy(matrix_sqrt_inv, tmp2)
    1788             : 
    1789       67374 :          occ_matrix = dbcsr_get_occupation(matrix_sqrt_inv)
    1790             : 
    1791             :          ! done iterating
    1792       67374 :          t2 = m_walltime()
    1793             : 
    1794       67374 :          conv = frob_matrix/frob_matrix_base
    1795             : 
    1796       67374 :          IF (unit_nr > 0) THEN
    1797       33687 :             WRITE (unit_nr, '(T6,A,1X,I3,1X,F10.8,E12.3,F12.3,F13.3)') "NS sqrt iter ", i, occ_matrix, &
    1798       33687 :                conv, t2 - t1, &
    1799       67374 :                (flop1 + flop2 + flop3 + flop4 + flop5)/(1.0E6_dp*MAX(0.001_dp, t2 - t1))
    1800       33687 :             CALL m_flush(unit_nr)
    1801             :          END IF
    1802             : 
    1803       67374 :          IF (abnormal_value(conv)) &
    1804           0 :             CPABORT("conv is an abnormal value (NaN/Inf).")
    1805             : 
    1806             :          ! conv < SQRT(threshold)
    1807       67374 :          IF ((conv*conv) < threshold) THEN
    1808       14278 :             IF (PRESENT(converged)) converged = .TRUE.
    1809             :             EXIT
    1810             :          END IF
    1811             : 
    1812             :       END DO
    1813             : 
    1814             :       ! symmetrize the matrices as this is not guaranteed by the algorithm
    1815       14278 :       IF (tsym) THEN
    1816       14278 :          IF (unit_nr > 0) THEN
    1817        7139 :             WRITE (unit_nr, '(T6,A20)') "Symmetrizing Results"
    1818             :          END IF
    1819       14278 :          CALL dbcsr_transposed(tmp1, matrix_sqrt_inv)
    1820       14278 :          CALL dbcsr_add(matrix_sqrt_inv, tmp1, 0.5_dp, 0.5_dp)
    1821       14278 :          CALL dbcsr_transposed(tmp1, matrix_sqrt)
    1822       14278 :          CALL dbcsr_add(matrix_sqrt, tmp1, 0.5_dp, 0.5_dp)
    1823             :       END IF
    1824             : 
    1825             :       ! this check is not really needed
    1826             :       CALL dbcsr_multiply("N", "N", +1.0_dp, matrix_sqrt_inv, matrix_sqrt, 0.0_dp, tmp1, &
    1827       14278 :                           filter_eps=threshold)
    1828       14278 :       frob_matrix_base = dbcsr_frobenius_norm(tmp1)
    1829       14278 :       CALL dbcsr_add_on_diag(tmp1, -1.0_dp)
    1830       14278 :       frob_matrix = dbcsr_frobenius_norm(tmp1)
    1831       14278 :       occ_matrix = dbcsr_get_occupation(matrix_sqrt_inv)
    1832       14278 :       IF (unit_nr > 0) THEN
    1833        7139 :          WRITE (unit_nr, '(T6,A,1X,I3,1X,F10.8,E12.3)') "Final NS sqrt iter ", i, occ_matrix, &
    1834       14278 :             frob_matrix/frob_matrix_base
    1835        7139 :          WRITE (unit_nr, '()')
    1836        7139 :          CALL m_flush(unit_nr)
    1837             :       END IF
    1838             : 
    1839             :       ! scale to proper end results
    1840       14278 :       CALL dbcsr_scale(matrix_sqrt, 1/SQRT(scaling))
    1841       14278 :       CALL dbcsr_scale(matrix_sqrt_inv, SQRT(scaling))
    1842             : 
    1843       14278 :       CALL dbcsr_release(tmp1)
    1844       14278 :       CALL dbcsr_release(tmp2)
    1845       14278 :       IF (order .GE. 4) THEN
    1846          20 :          CALL dbcsr_release(tmp3)
    1847             :       END IF
    1848             : 
    1849       14278 :       CALL timestop(handle)
    1850             : 
    1851       14278 :    END SUBROUTINE matrix_sqrt_Newton_Schulz
    1852             : 
    1853             : ! **************************************************************************************************
    1854             : !> \brief compute the sqrt of a matrix via the general algorithm for the p-th root of Richters et al.
    1855             : !>                   Commun. Comput. Phys., 25 (2019), pp. 564-585.
    1856             : !> \param matrix_sqrt ...
    1857             : !> \param matrix_sqrt_inv ...
    1858             : !> \param matrix ...
    1859             : !> \param threshold ...
    1860             : !> \param order ...
    1861             : !> \param eps_lanczos ...
    1862             : !> \param max_iter_lanczos ...
    1863             : !> \param symmetrize ...
    1864             : !> \param converged ...
    1865             : !> \par History
    1866             : !>       2019.04 created [Robert Schade]
    1867             : !> \author Robert Schade
    1868             : ! **************************************************************************************************
    1869          48 :    SUBROUTINE matrix_sqrt_proot(matrix_sqrt, matrix_sqrt_inv, matrix, threshold, order, &
    1870             :                                 eps_lanczos, max_iter_lanczos, symmetrize, converged)
    1871             :       TYPE(dbcsr_type), INTENT(INOUT)                    :: matrix_sqrt, matrix_sqrt_inv, matrix
    1872             :       REAL(KIND=dp), INTENT(IN)                          :: threshold
    1873             :       INTEGER, INTENT(IN)                                :: order
    1874             :       REAL(KIND=dp), INTENT(IN)                          :: eps_lanczos
    1875             :       INTEGER, INTENT(IN)                                :: max_iter_lanczos
    1876             :       LOGICAL, OPTIONAL                                  :: symmetrize, converged
    1877             : 
    1878             :       CHARACTER(LEN=*), PARAMETER                        :: routineN = 'matrix_sqrt_proot'
    1879             : 
    1880             :       INTEGER                                            :: choose, handle, i, ii, j, unit_nr
    1881             :       INTEGER(KIND=int_8)                                :: f, flop1, flop2, flop3, flop4, flop5
    1882             :       LOGICAL                                            :: arnoldi_converged, test, tsym
    1883             :       REAL(KIND=dp)                                      :: conv, frob_matrix, frob_matrix_base, &
    1884             :                                                             max_ev, min_ev, occ_matrix, scaling, &
    1885             :                                                             t1, t2
    1886             :       TYPE(cp_logger_type), POINTER                      :: logger
    1887             :       TYPE(dbcsr_type)                                   :: BK2A, matrixS, Rmat, tmp1, tmp2, tmp3
    1888             : 
    1889          16 :       CALL cite_reference(Richters2018)
    1890             : 
    1891          16 :       test = .FALSE.
    1892             : 
    1893          16 :       CALL timeset(routineN, handle)
    1894             : 
    1895          16 :       logger => cp_get_default_logger()
    1896          16 :       IF (logger%para_env%is_source()) THEN
    1897           8 :          unit_nr = cp_logger_get_default_unit_nr(logger, local=.TRUE.)
    1898             :       ELSE
    1899           8 :          unit_nr = -1
    1900             :       END IF
    1901             : 
    1902          16 :       IF (PRESENT(converged)) converged = .FALSE.
    1903          16 :       IF (PRESENT(symmetrize)) THEN
    1904          16 :          tsym = symmetrize
    1905             :       ELSE
    1906             :          tsym = .TRUE.
    1907             :       END IF
    1908             : 
    1909             :       ! for stability symmetry can not be assumed
    1910          16 :       CALL dbcsr_create(tmp1, template=matrix, matrix_type=dbcsr_type_no_symmetry)
    1911          16 :       CALL dbcsr_create(tmp2, template=matrix, matrix_type=dbcsr_type_no_symmetry)
    1912          16 :       CALL dbcsr_create(tmp3, template=matrix, matrix_type=dbcsr_type_no_symmetry)
    1913          16 :       CALL dbcsr_create(Rmat, template=matrix, matrix_type=dbcsr_type_no_symmetry)
    1914          16 :       CALL dbcsr_create(matrixS, template=matrix, matrix_type=dbcsr_type_no_symmetry)
    1915             : 
    1916          16 :       CALL dbcsr_copy(matrixS, matrix)
    1917             :       IF (1 .EQ. 1) THEN
    1918             :          ! scale the matrix to get into the convergence range
    1919             :          CALL arnoldi_extremal(matrixS, max_ev, min_ev, threshold=eps_lanczos, &
    1920          16 :                                max_iter=max_iter_lanczos, converged=arnoldi_converged)
    1921          16 :          IF (unit_nr > 0) THEN
    1922           8 :             WRITE (unit_nr, *)
    1923           8 :             WRITE (unit_nr, '(T6,A,1X,L1,A,E12.3)') "Lanczos converged: ", arnoldi_converged, " threshold:", eps_lanczos
    1924           8 :             WRITE (unit_nr, '(T6,A,1X,E12.3,E12.3)') "Est. extremal eigenvalues:", max_ev, min_ev
    1925           8 :             WRITE (unit_nr, '(T6,A,1X,E12.3)') "Est. condition number :", max_ev/MAX(min_ev, EPSILON(min_ev))
    1926             :          END IF
    1927             :          ! conservatively assume we get a relatively large error (100*threshold_lanczos) in the estimates
    1928             :          ! and adjust the scaling to be on the safe side
    1929          16 :          scaling = 2.0_dp/(max_ev + min_ev + 100*eps_lanczos)
    1930          16 :          CALL dbcsr_scale(matrixS, scaling)
    1931          16 :          CALL dbcsr_filter(matrixS, threshold)
    1932             :       ELSE
    1933             :          scaling = 1.0_dp
    1934             :       END IF
    1935             : 
    1936          16 :       CALL dbcsr_set(matrix_sqrt_inv, 0.0_dp)
    1937          16 :       CALL dbcsr_add_on_diag(matrix_sqrt_inv, 1.0_dp)
    1938             :       !CALL dbcsr_filter(matrix_sqrt_inv, threshold)
    1939             : 
    1940          16 :       IF (unit_nr > 0) THEN
    1941           8 :          WRITE (unit_nr, *)
    1942           8 :          WRITE (unit_nr, *) "Order=", order
    1943             :       END IF
    1944             : 
    1945          86 :       DO i = 1, 100
    1946             : 
    1947          86 :          t1 = m_walltime()
    1948             :          IF (1 .EQ. 1) THEN
    1949             :             !build R=1-A B_K^2
    1950             :             CALL dbcsr_multiply("N", "N", 1.0_dp, matrix_sqrt_inv, matrix_sqrt_inv, 0.0_dp, tmp1, &
    1951          86 :                                 filter_eps=threshold, flop=flop1)
    1952             :             CALL dbcsr_multiply("N", "N", 1.0_dp, matrixS, tmp1, 0.0_dp, Rmat, &
    1953          86 :                                 filter_eps=threshold, flop=flop2)
    1954          86 :             CALL dbcsr_scale(Rmat, -1.0_dp)
    1955          86 :             CALL dbcsr_add_on_diag(Rmat, 1.0_dp)
    1956             : 
    1957          86 :             flop4 = 0; flop5 = 0
    1958          86 :             CALL dbcsr_set(tmp1, 0.0_dp)
    1959          86 :             CALL dbcsr_add_on_diag(tmp1, 2.0_dp)
    1960             : 
    1961          86 :             flop3 = 0
    1962             : 
    1963         274 :             DO j = 2, order
    1964         188 :                IF (j .EQ. 2) THEN
    1965          86 :                   CALL dbcsr_copy(tmp2, Rmat)
    1966             :                ELSE
    1967             :                   f = 0
    1968         102 :                   CALL dbcsr_copy(tmp3, tmp2)
    1969             :                   CALL dbcsr_multiply("N", "N", 1.0_dp, tmp3, Rmat, 0.0_dp, tmp2, &
    1970         102 :                                       filter_eps=threshold, flop=f)
    1971         102 :                   flop3 = flop3 + f
    1972             :                END IF
    1973         274 :                CALL dbcsr_add(tmp1, tmp2, 1.0_dp, 1.0_dp)
    1974             :             END DO
    1975             :          ELSE
    1976             :             CALL dbcsr_create(BK2A, template=matrix, matrix_type=dbcsr_type_no_symmetry)
    1977             :             CALL dbcsr_multiply("N", "N", 1.0_dp, matrix_sqrt_inv, matrixS, 0.0_dp, tmp3, &
    1978             :                                 filter_eps=threshold, flop=flop1)
    1979             :             CALL dbcsr_multiply("N", "N", 1.0_dp, matrix_sqrt_inv, tmp3, 0.0_dp, BK2A, &
    1980             :                                 filter_eps=threshold, flop=flop2)
    1981             :             CALL dbcsr_copy(Rmat, BK2A)
    1982             :             CALL dbcsr_add_on_diag(Rmat, -1.0_dp)
    1983             : 
    1984             :             CALL dbcsr_set(tmp1, 0.0_dp)
    1985             :             CALL dbcsr_add_on_diag(tmp1, 1.0_dp)
    1986             : 
    1987             :             CALL dbcsr_set(tmp2, 0.0_dp)
    1988             :             CALL dbcsr_add_on_diag(tmp2, 1.0_dp)
    1989             : 
    1990             :             flop3 = 0
    1991             :             DO j = 1, order
    1992             :                !choose=factorial(order)/(factorial(j)*factorial(order-j))
    1993             :                choose = PRODUCT((/(ii, ii=1, order)/))/(PRODUCT((/(ii, ii=1, j)/))*PRODUCT((/(ii, ii=1, order - j)/)))
    1994             :                CALL dbcsr_add(tmp1, tmp2, 1.0_dp, -1.0_dp*(-1)**j*choose)
    1995             :                IF (j .LT. order) THEN
    1996             :                   f = 0
    1997             :                   CALL dbcsr_copy(tmp3, tmp2)
    1998             :                   CALL dbcsr_multiply("N", "N", 1.0_dp, tmp3, BK2A, 0.0_dp, tmp2, &
    1999             :                                       filter_eps=threshold, flop=f)
    2000             :                   flop3 = flop3 + f
    2001             :                END IF
    2002             :             END DO
    2003             :             CALL dbcsr_release(BK2A)
    2004             :          END IF
    2005             : 
    2006          86 :          CALL dbcsr_copy(tmp3, matrix_sqrt_inv)
    2007             :          CALL dbcsr_multiply("N", "N", 0.5_dp, tmp3, tmp1, 0.0_dp, matrix_sqrt_inv, &
    2008          86 :                              filter_eps=threshold, flop=flop4)
    2009             : 
    2010          86 :          occ_matrix = dbcsr_get_occupation(matrix_sqrt_inv)
    2011             : 
    2012             :          ! done iterating
    2013          86 :          t2 = m_walltime()
    2014             : 
    2015          86 :          conv = dbcsr_frobenius_norm(Rmat)
    2016             : 
    2017          86 :          IF (unit_nr > 0) THEN
    2018          43 :             WRITE (unit_nr, '(T6,A,1X,I3,1X,F10.8,E12.3,F12.3,F13.3)') "PROOT sqrt iter ", i, occ_matrix, &
    2019          43 :                conv, t2 - t1, &
    2020          86 :                (flop1 + flop2 + flop3 + flop4 + flop5)/(1.0E6_dp*MAX(0.001_dp, t2 - t1))
    2021          43 :             CALL m_flush(unit_nr)
    2022             :          END IF
    2023             : 
    2024          86 :          IF (abnormal_value(conv)) &
    2025           0 :             CPABORT("conv is an abnormal value (NaN/Inf).")
    2026             : 
    2027             :          ! conv < SQRT(threshold)
    2028          86 :          IF ((conv*conv) < threshold) THEN
    2029          16 :             IF (PRESENT(converged)) converged = .TRUE.
    2030             :             EXIT
    2031             :          END IF
    2032             : 
    2033             :       END DO
    2034             : 
    2035             :       ! scale to proper end results
    2036          16 :       CALL dbcsr_scale(matrix_sqrt_inv, SQRT(scaling))
    2037             :       CALL dbcsr_multiply("N", "N", 1.0_dp, matrix_sqrt_inv, matrix, 0.0_dp, matrix_sqrt, &
    2038          16 :                           filter_eps=threshold, flop=flop5)
    2039             : 
    2040             :       ! symmetrize the matrices as this is not guaranteed by the algorithm
    2041          16 :       IF (tsym) THEN
    2042           8 :          IF (unit_nr > 0) THEN
    2043           4 :             WRITE (unit_nr, '(A20)') "SYMMETRIZING RESULTS"
    2044             :          END IF
    2045           8 :          CALL dbcsr_transposed(tmp1, matrix_sqrt_inv)
    2046           8 :          CALL dbcsr_add(matrix_sqrt_inv, tmp1, 0.5_dp, 0.5_dp)
    2047           8 :          CALL dbcsr_transposed(tmp1, matrix_sqrt)
    2048           8 :          CALL dbcsr_add(matrix_sqrt, tmp1, 0.5_dp, 0.5_dp)
    2049             :       END IF
    2050             : 
    2051             :       ! this check is not really needed
    2052             :       IF (test) THEN
    2053             :          CALL dbcsr_multiply("N", "N", +1.0_dp, matrix_sqrt_inv, matrix_sqrt, 0.0_dp, tmp1, &
    2054             :                              filter_eps=threshold)
    2055             :          frob_matrix_base = dbcsr_frobenius_norm(tmp1)
    2056             :          CALL dbcsr_add_on_diag(tmp1, -1.0_dp)
    2057             :          frob_matrix = dbcsr_frobenius_norm(tmp1)
    2058             :          occ_matrix = dbcsr_get_occupation(matrix_sqrt_inv)
    2059             :          IF (unit_nr > 0) THEN
    2060             :             WRITE (unit_nr, '(T6,A,1X,I3,1X,F10.8,E12.3)') "Final PROOT S^{-1/2} S^{1/2}-Eins error ", i, occ_matrix, &
    2061             :                frob_matrix/frob_matrix_base
    2062             :             WRITE (unit_nr, '()')
    2063             :             CALL m_flush(unit_nr)
    2064             :          END IF
    2065             : 
    2066             :          ! this check is not really needed
    2067             :          CALL dbcsr_multiply("N", "N", +1.0_dp, matrix_sqrt_inv, matrix_sqrt_inv, 0.0_dp, tmp2, &
    2068             :                              filter_eps=threshold)
    2069             :          CALL dbcsr_multiply("N", "N", +1.0_dp, tmp2, matrix, 0.0_dp, tmp1, &
    2070             :                              filter_eps=threshold)
    2071             :          frob_matrix_base = dbcsr_frobenius_norm(tmp1)
    2072             :          CALL dbcsr_add_on_diag(tmp1, -1.0_dp)
    2073             :          frob_matrix = dbcsr_frobenius_norm(tmp1)
    2074             :          occ_matrix = dbcsr_get_occupation(matrix_sqrt_inv)
    2075             :          IF (unit_nr > 0) THEN
    2076             :             WRITE (unit_nr, '(T6,A,1X,I3,1X,F10.8,E12.3)') "Final PROOT S^{-1/2} S^{-1/2} S-Eins error ", i, occ_matrix, &
    2077             :                frob_matrix/frob_matrix_base
    2078             :             WRITE (unit_nr, '()')
    2079             :             CALL m_flush(unit_nr)
    2080             :          END IF
    2081             :       END IF
    2082             : 
    2083          16 :       CALL dbcsr_release(tmp1)
    2084          16 :       CALL dbcsr_release(tmp2)
    2085          16 :       CALL dbcsr_release(tmp3)
    2086          16 :       CALL dbcsr_release(Rmat)
    2087          16 :       CALL dbcsr_release(matrixS)
    2088             : 
    2089          16 :       CALL timestop(handle)
    2090          16 :    END SUBROUTINE matrix_sqrt_proot
    2091             : 
    ! **************************************************************************************************
    !> \brief Computes matrix_exp = omega*exp(alpha*matrix) by scaling and squaring.
    !>
    !>        alpha*matrix is divided by 2**k (k >= 1) until its Frobenius norm is at most one. The
    !>        exponential of this scaled matrix C is summed as a truncated Taylor series, and the result
    !>        is squared k times to recover exp(alpha*matrix).
    !> \param matrix_exp Output, overwritten with omega*exp(alpha*matrix)
    !> \param matrix Input matrix (its contents are only read here)
    !> \param omega Prefactor applied to the exponential
    !> \param alpha Scaling of matrix inside the exponential
    !> \param threshold filter_eps used in all matrix multiplications
    ! **************************************************************************************************
    SUBROUTINE matrix_exponential(matrix_exp, matrix, omega, alpha, threshold)
       TYPE(dbcsr_type), INTENT(INOUT)                    :: matrix_exp, matrix
       REAL(KIND=dp), INTENT(IN)                          :: omega, alpha, threshold

       CHARACTER(LEN=*), PARAMETER :: routineN = 'matrix_exponential'
       REAL(dp), PARAMETER                                :: one = 1.0_dp, toll = 1.E-17_dp, &
                                                             zero = 0.0_dp

       INTEGER                                            :: handle, i, k, unit_nr
       REAL(dp)                                           :: factorial, norm_C, norm_D, norm_scalar
       TYPE(cp_logger_type), POINTER                      :: logger
       TYPE(dbcsr_type)                                   :: B, B_square, C, D, D_product

       CALL timeset(routineN, handle)

       logger => cp_get_default_logger()
       IF (logger%para_env%is_source()) THEN
          unit_nr = cp_logger_get_default_unit_nr(logger, local=.TRUE.)
       ELSE
          unit_nr = -1
       END IF

       ! Frobenius norm of alpha*matrix, which decides how often the argument has to be halved
       norm_scalar = ABS(alpha)*dbcsr_frobenius_norm(matrix)
       ! A NaN/Inf norm would never pass the scaling test below and the loop would never end
       IF (abnormal_value(norm_scalar)) &
          CPABORT("Norm of the input matrix is an abnormal value (NaN/Inf).")

       ! k = smallest integer >= 1 such that norm_scalar/2**k <= 1
       k = 1
       DO WHILE ((norm_scalar/2.0_dp**k) > one)
          k = k + 1
       END DO

       ! C = alpha*matrix/2**k, the scaled argument of the exponential
       CALL dbcsr_create(C, template=matrix, matrix_type=dbcsr_type_no_symmetry)
       CALL dbcsr_copy(C, matrix)
       CALL dbcsr_scale(C, alpha_scalar=alpha/2.0_dp**k)

       ! D holds the current power C**i of the Taylor series, starting with C**1
       CALL dbcsr_create(D, template=matrix, matrix_type=dbcsr_type_no_symmetry)
       CALL dbcsr_copy(D, C)

       ! B accumulates the truncated Taylor series, starting from B = Identity + C
       CALL dbcsr_create(B, template=matrix, matrix_type=dbcsr_type_no_symmetry)
       CALL dbcsr_copy(B, D)
       CALL dbcsr_add_on_diag(B, alpha_scalar=one)

       ! Convergence threshold relative to the norm of the scaled matrix C. With ||C||_F <= 1 the
       ! terms ||C**i||/i! fall below toll*||C||_F after a bounded number of steps, independently
       ! of how large or small the unscaled input matrix is.
       norm_C = toll*dbcsr_frobenius_norm(C)

       ! Truncated Taylor series: B = B + C**i * ifac(i), i = 2, 3, ...
       CALL dbcsr_create(D_product, template=matrix, matrix_type=dbcsr_type_no_symmetry)
       i = 1
       DO
          i = i + 1
          ! D_product = D*C = C**i
          CALL dbcsr_multiply("N", "N", one, D, C, &
                              zero, D_product, filter_eps=threshold)
          CALL dbcsr_copy(D, D_product)

          ! ifac(i) is the Taylor coefficient of the i-th term (the inverse factorial, i.e. /i!)
          factorial = ifac(i)
          CALL dbcsr_add(B, D_product, one, factorial)

          ! Stop once the latest term is negligible. '<=' (rather than '<') also terminates when
          ! C is exactly zero, where both norms are 0 and the series is already exact.
          norm_D = factorial*dbcsr_frobenius_norm(D)
          IF (norm_D <= norm_C) EXIT
       END DO

       ! Undo the scaling: exp(alpha*matrix) = exp(C)**(2**k), obtained by k successive squarings
       CALL dbcsr_create(B_square, template=matrix, matrix_type=dbcsr_type_no_symmetry)
       DO i = 1, k
          CALL dbcsr_multiply("N", "N", one, B, B, &
                              zero, B_square, filter_eps=threshold)
          CALL dbcsr_copy(B, B_square)
       END DO

       ! matrix_exp = omega*B_square (k >= 1, so B_square always holds the final product)
       CALL dbcsr_copy(matrix_exp, B_square)
       CALL dbcsr_scale(matrix_exp, alpha_scalar=omega)

       CALL dbcsr_release(B)
       CALL dbcsr_release(C)
       CALL dbcsr_release(D)
       CALL dbcsr_release(D_product)
       CALL dbcsr_release(B_square)

       CALL timestop(handle)

    END SUBROUTINE matrix_exponential
    2202             : 
    ! **************************************************************************************************
    !> \brief McWeeny purification of a matrix in the orthonormal basis
    !>
    !>        For every spin, iterates P <- 3*P*P - 2*P*P*P until ||P*P - P||_F**2 < trace(P)*threshold
    !>        (with the deviation and trace measured on that spin's matrix before the update),
    !>        or until max_steps iterations have been done.
    !> \param matrix_p Matrices to purify, one per spin (need to be almost idempotent already);
    !>                 overwritten with the purified matrices
    !> \param threshold Threshold used as filter_eps and convergence criteria
    !> \param max_steps Max number of iterations per spin
    !> \par History
    !>       2013.01 created [Florian Schiffmann]
    !>       2014.07 slightly refactored [Ole Schuett]
    !> \author Florian Schiffmann
    ! **************************************************************************************************
    SUBROUTINE purify_mcweeny_orth(matrix_p, threshold, max_steps)
       TYPE(dbcsr_type), DIMENSION(:), INTENT(INOUT)      :: matrix_p
       REAL(KIND=dp), INTENT(IN)                          :: threshold
       INTEGER, INTENT(IN)                                :: max_steps

       CHARACTER(LEN=*), PARAMETER :: routineN = 'purify_mcweeny_orth'

       INTEGER                                            :: handle, i, ispin, unit_nr
       REAL(KIND=dp)                                      :: frob_norm, trace
       TYPE(cp_logger_type), POINTER                      :: logger
       TYPE(dbcsr_type)                                   :: matrix_pp, matrix_tmp

       CALL timeset(routineN, handle)
       logger => cp_get_default_logger()
       IF (logger%para_env%is_source()) THEN
          unit_nr = cp_logger_get_default_unit_nr(logger, local=.TRUE.)
       ELSE
          unit_nr = -1
       END IF

       CALL dbcsr_create(matrix_pp, template=matrix_p(1), matrix_type=dbcsr_type_no_symmetry)
       CALL dbcsr_create(matrix_tmp, template=matrix_p(1), matrix_type=dbcsr_type_no_symmetry)

       DO ispin = 1, SIZE(matrix_p)
          ! Scale the convergence criterion with the trace of the spin channel being purified
          ! (previously the trace of matrix_p(1) was reused for all spins)
          CALL dbcsr_trace(matrix_p(ispin), trace)

          DO i = 1, max_steps
             ! matrix_pp = P*P
             CALL dbcsr_multiply("N", "N", 1.0_dp, matrix_p(ispin), matrix_p(ispin), &
                                 0.0_dp, matrix_pp, filter_eps=threshold)

             ! test convergence: deviation from idempotency, tmp = PP - P
             CALL dbcsr_copy(matrix_tmp, matrix_pp)
             CALL dbcsr_add(matrix_tmp, matrix_p(ispin), 1.0_dp, -1.0_dp)
             frob_norm = dbcsr_frobenius_norm(matrix_tmp)
             IF (unit_nr > 0) WRITE (unit_nr, '(t3,a,f16.8)') "McWeeny: Deviation of idempotency", frob_norm
             IF (unit_nr > 0) CALL m_flush(unit_nr)

             ! construct new P: tmp = 3PP - 2PPP
             CALL dbcsr_copy(matrix_tmp, matrix_pp)
             CALL dbcsr_multiply("N", "N", -2.0_dp, matrix_pp, matrix_p(ispin), &
                                 3.0_dp, matrix_tmp, filter_eps=threshold)
             CALL dbcsr_copy(matrix_p(ispin), matrix_tmp)

             ! frob_norm < SQRT(trace*threshold)
             IF (frob_norm*frob_norm < trace*threshold) EXIT
          END DO
       END DO

       CALL dbcsr_release(matrix_pp)
       CALL dbcsr_release(matrix_tmp)
       CALL timestop(handle)
    END SUBROUTINE purify_mcweeny_orth
    2264             : 
    ! **************************************************************************************************
    !> \brief McWeeny purification of a matrix in the non-orthonormal basis
    !>
    !>        For every spin, iterates P <- 3*P*S*P - 2*P*S*P*S*P until ||P*S*P - P||_F**2 <
    !>        trace(P*S)*threshold, or until max_steps iterations have been done. trace(P*S) is
    !>        taken from the first iteration of each spin.
    !> \param matrix_p Matrices to purify, one per spin (need to be almost idempotent already);
    !>                 overwritten with the purified matrices
    !> \param matrix_s Overlap-Matrix
    !> \param threshold Threshold used as filter_eps and convergence criteria
    !> \param max_steps Max number of iterations per spin
    !> \par History
    !>       2013.01 created [Florian Schiffmann]
    !>       2014.07 slightly refactored [Ole Schuett]
    !> \author Florian Schiffmann
    ! **************************************************************************************************
    SUBROUTINE purify_mcweeny_nonorth(matrix_p, matrix_s, threshold, max_steps)
       TYPE(dbcsr_type), DIMENSION(:), INTENT(INOUT)      :: matrix_p
       TYPE(dbcsr_type)                                   :: matrix_s
       REAL(KIND=dp), INTENT(IN)                          :: threshold
       INTEGER, INTENT(IN)                                :: max_steps

       CHARACTER(LEN=*), PARAMETER :: routineN = 'purify_mcweeny_nonorth'

       INTEGER                                            :: handle, i, ispin, unit_nr
       REAL(KIND=dp)                                      :: frob_norm, trace
       TYPE(cp_logger_type), POINTER                      :: logger
       TYPE(dbcsr_type)                                   :: matrix_ps, matrix_psp, matrix_test

       CALL timeset(routineN, handle)

       logger => cp_get_default_logger()
       IF (logger%para_env%is_source()) THEN
          unit_nr = cp_logger_get_default_unit_nr(logger, local=.TRUE.)
       ELSE
          unit_nr = -1
       END IF

       CALL dbcsr_create(matrix_ps, template=matrix_p(1), matrix_type=dbcsr_type_no_symmetry)
       CALL dbcsr_create(matrix_psp, template=matrix_p(1), matrix_type=dbcsr_type_no_symmetry)
       CALL dbcsr_create(matrix_test, template=matrix_p(1), matrix_type=dbcsr_type_no_symmetry)

       DO ispin = 1, SIZE(matrix_p)
          DO i = 1, max_steps
             ! matrix_ps = P*S and matrix_psp = P*S*P
             CALL dbcsr_multiply("N", "N", 1.0_dp, matrix_p(ispin), matrix_s, &
                                 0.0_dp, matrix_ps, filter_eps=threshold)
             CALL dbcsr_multiply("N", "N", 1.0_dp, matrix_ps, matrix_p(ispin), &
                                 0.0_dp, matrix_psp, filter_eps=threshold)
             ! trace(P*S) of the unpurified matrix scales the convergence criterion for this spin
             IF (i == 1) CALL dbcsr_trace(matrix_ps, trace)

             ! test convergence: deviation from idempotency, test = PSP - P
             CALL dbcsr_copy(matrix_test, matrix_psp)
             CALL dbcsr_add(matrix_test, matrix_p(ispin), 1.0_dp, -1.0_dp)
             frob_norm = dbcsr_frobenius_norm(matrix_test)
             IF (unit_nr > 0) WRITE (unit_nr, '(t3,a,f16.8)') "McWeeny: Deviation of idempotency", frob_norm
             IF (unit_nr > 0) CALL m_flush(unit_nr)

             ! construct new P = 3 PSP - 2 PS*PSP
             CALL dbcsr_copy(matrix_p(ispin), matrix_psp)
             CALL dbcsr_multiply("N", "N", -2.0_dp, matrix_ps, matrix_psp, &
                                 3.0_dp, matrix_p(ispin), filter_eps=threshold)

             ! frob_norm < SQRT(trace*threshold)
             IF (frob_norm*frob_norm < trace*threshold) EXIT
          END DO
       END DO

       CALL dbcsr_release(matrix_ps)
       CALL dbcsr_release(matrix_psp)
       CALL dbcsr_release(matrix_test)
       CALL timestop(handle)
    END SUBROUTINE purify_mcweeny_nonorth
    2332             : 
    2333           0 : END MODULE iterate_matrix

Generated by: LCOV version 1.15