LCOV - code coverage report
Current view: top level - src - iterate_matrix.F (source / functions) Coverage Total Hit
Test: CP2K Regtests (git:936074a) Lines: 94.3 % 874 824
Test Date: 2025-12-04 06:27:48 Functions: 89.5 % 19 17

            Line data    Source code
       1              : !--------------------------------------------------------------------------------------------------!
       2              : !   CP2K: A general program to perform molecular dynamics simulations                              !
       3              : !   Copyright 2000-2025 CP2K developers group <https://cp2k.org>                                   !
       4              : !                                                                                                  !
       5              : !   SPDX-License-Identifier: GPL-2.0-or-later                                                      !
       6              : !--------------------------------------------------------------------------------------------------!
       7              : ! **************************************************************************************************
       8              : !> \brief Routines useful for iterative matrix calculations
       9              : !> \par History
      10              : !>       2010.10 created [Joost VandeVondele]
      11              : !> \author Joost VandeVondele
      12              : ! **************************************************************************************************
      13              : MODULE iterate_matrix
      14              :    USE arnoldi_api,                     ONLY: arnoldi_env_type,&
      15              :                                               arnoldi_extremal
      16              :    USE bibliography,                    ONLY: Richters2018,&
      17              :                                               cite_reference
      18              :    USE cp_dbcsr_api,                    ONLY: &
      19              :         dbcsr_add, dbcsr_copy, dbcsr_create, dbcsr_desymmetrize, dbcsr_distribution_get, &
      20              :         dbcsr_distribution_type, dbcsr_filter, dbcsr_get_info, dbcsr_get_matrix_type, &
      21              :         dbcsr_get_occupation, dbcsr_multiply, dbcsr_p_type, dbcsr_release, dbcsr_scale, dbcsr_set, &
      22              :         dbcsr_transposed, dbcsr_type, dbcsr_type_no_symmetry
      23              :    USE cp_dbcsr_contrib,                ONLY: dbcsr_add_on_diag,&
      24              :                                               dbcsr_frobenius_norm,&
      25              :                                               dbcsr_gershgorin_norm,&
      26              :                                               dbcsr_get_diag,&
      27              :                                               dbcsr_maxabs,&
      28              :                                               dbcsr_set_diag,&
      29              :                                               dbcsr_trace
      30              :    USE cp_log_handling,                 ONLY: cp_get_default_logger,&
      31              :                                               cp_logger_get_default_unit_nr,&
      32              :                                               cp_logger_type
      33              :    USE input_constants,                 ONLY: ls_scf_submatrix_sign_direct,&
      34              :                                               ls_scf_submatrix_sign_direct_muadj,&
      35              :                                               ls_scf_submatrix_sign_direct_muadj_lowmem,&
      36              :                                               ls_scf_submatrix_sign_ns
      37              :    USE kinds,                           ONLY: dp,&
      38              :                                               int_8
      39              :    USE machine,                         ONLY: m_flush,&
      40              :                                               m_walltime
      41              :    USE mathconstants,                   ONLY: ifac
      42              :    USE mathlib,                         ONLY: abnormal_value
      43              :    USE message_passing,                 ONLY: mp_comm_type
      44              :    USE submatrix_dissection,            ONLY: submatrix_dissection_type
      45              : #include "./base/base_uses.f90"
      46              : 
      47              :    IMPLICIT NONE
      48              : 
      49              :    PRIVATE
      50              : 
      51              :    CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'iterate_matrix'
      52              : 
      53              :    TYPE :: eigbuf
      54              :       REAL(KIND=dp), DIMENSION(:), ALLOCATABLE    :: eigvals
      55              :       REAL(KIND=dp), DIMENSION(:, :), ALLOCATABLE :: eigvecs
      56              :    END TYPE eigbuf
      57              : 
      58              :    INTERFACE purify_mcweeny
      59              :       MODULE PROCEDURE purify_mcweeny_orth, purify_mcweeny_nonorth
      60              :    END INTERFACE
      61              : 
      62              :    PUBLIC :: invert_Hotelling, matrix_sign_Newton_Schulz, matrix_sqrt_Newton_Schulz, &
      63              :              matrix_sqrt_proot, matrix_sign_proot, matrix_sign_submatrix, matrix_exponential, &
      64              :              matrix_sign_submatrix_mu_adjust, purify_mcweeny, invert_Taylor, determinant
      65              : 
      66              : CONTAINS
      67              : 
      68              : ! *****************************************************************************
      69              : !> \brief Computes the determinant of a symmetric positive definite matrix
      70              : !>        using the trace of the matrix logarithm via Mercator series:
      71              : !>         det(A) = det(S)det(I+X)det(S), where S=diag(sqrt(Aii),..,sqrt(Ann))
      72              : !>         det(I+X) = Exp(Trace(Ln(I+X)))
      73              : !>         Ln(I+X) = X - X^2/2 + X^3/3 - X^4/4 + ..
      74              : !>        The series converges only if the Frobenius norm of X is less than 1.
      75              : !>        If it is more than one we compute (recursevily) the determinant of
      76              : !>        the square root of (I+X).
      77              : !> \param matrix ...
      78              : !> \param det - determinant
      79              : !> \param threshold ...
      80              : !> \par History
      81              : !>       2015.04 created [Rustam Z Khaliullin]
      82              : !> \author Rustam Z. Khaliullin
      83              : ! **************************************************************************************************
      84          132 :    RECURSIVE SUBROUTINE determinant(matrix, det, threshold)
      85              : 
      86              :       TYPE(dbcsr_type), INTENT(INOUT)                    :: matrix
      87              :       REAL(KIND=dp), INTENT(INOUT)                       :: det
      88              :       REAL(KIND=dp), INTENT(IN)                          :: threshold
      89              : 
      90              :       CHARACTER(LEN=*), PARAMETER                        :: routineN = 'determinant'
      91              : 
      92              :       INTEGER                                            :: handle, i, max_iter_lanczos, nsize, &
      93              :                                                             order_lanczos, sign_iter, unit_nr
      94              :       INTEGER(KIND=int_8)                                :: flop1
      95              :       INTEGER, SAVE                                      :: recursion_depth = 0
      96              :       REAL(KIND=dp)                                      :: det0, eps_lanczos, frobnorm, maxnorm, &
      97              :                                                             occ_matrix, t1, t2, trace
      98          132 :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:)           :: diagonal
      99              :       TYPE(cp_logger_type), POINTER                      :: logger
     100              :       TYPE(dbcsr_type)                                   :: tmp1, tmp2, tmp3
     101              : 
     102          132 :       CALL timeset(routineN, handle)
     103              : 
     104              :       ! get a useful output_unit
     105          132 :       logger => cp_get_default_logger()
     106          132 :       IF (logger%para_env%is_source()) THEN
     107           66 :          unit_nr = cp_logger_get_default_unit_nr(logger, local=.TRUE.)
     108              :       ELSE
     109           66 :          unit_nr = -1
     110              :       END IF
     111              : 
     112              :       ! Note: tmp1 and tmp2 have the same matrix type as the
     113              :       ! initial matrix (tmp3 does not have symmetry constraints)
     114              :       ! this might lead to uninteded results with anti-symmetric
     115              :       ! matrices
     116              :       CALL dbcsr_create(tmp1, template=matrix, &
     117          132 :                         matrix_type=dbcsr_type_no_symmetry)
     118              :       CALL dbcsr_create(tmp2, template=matrix, &
     119          132 :                         matrix_type=dbcsr_type_no_symmetry)
     120              :       CALL dbcsr_create(tmp3, template=matrix, &
     121          132 :                         matrix_type=dbcsr_type_no_symmetry)
     122              : 
     123              :       ! compute the product of the diagonal elements
     124              :       BLOCK
     125              :          TYPE(mp_comm_type) :: group
     126          132 :          CALL dbcsr_get_info(matrix, nfullrows_total=nsize, group=group)
     127          396 :          ALLOCATE (diagonal(nsize))
     128          132 :          CALL dbcsr_get_diag(matrix, diagonal)
     129          132 :          CALL group%sum(diagonal)
     130         2308 :          det = PRODUCT(diagonal)
     131              :       END BLOCK
     132              : 
     133              :       ! create diagonal SQRTI matrix
     134         2176 :       diagonal(:) = 1.0_dp/(SQRT(diagonal(:)))
     135              :       !ROLL CALL dbcsr_copy(tmp1,matrix)
     136          132 :       CALL dbcsr_desymmetrize(matrix, tmp1)
     137          132 :       CALL dbcsr_set(tmp1, 0.0_dp)
     138          132 :       CALL dbcsr_set_diag(tmp1, diagonal)
     139          132 :       CALL dbcsr_filter(tmp1, threshold)
     140          132 :       DEALLOCATE (diagonal)
     141              : 
     142              :       ! normalize the main diagonal, off-diagonal elements are scaled to
     143              :       ! make the norm of the matrix less than 1
     144              :       CALL dbcsr_multiply("N", "N", 1.0_dp, &
     145              :                           matrix, &
     146              :                           tmp1, &
     147              :                           0.0_dp, tmp3, &
     148          132 :                           filter_eps=threshold)
     149              :       CALL dbcsr_multiply("N", "N", 1.0_dp, &
     150              :                           tmp1, &
     151              :                           tmp3, &
     152              :                           0.0_dp, tmp2, &
     153          132 :                           filter_eps=threshold)
     154              : 
     155              :       ! subtract the main diagonal to create matrix X
     156          132 :       CALL dbcsr_add_on_diag(tmp2, -1.0_dp)
     157          132 :       frobnorm = dbcsr_frobenius_norm(tmp2)
     158          132 :       IF (unit_nr > 0) THEN
     159           66 :          IF (recursion_depth == 0) THEN
     160           41 :             WRITE (unit_nr, '()')
     161              :          ELSE
     162              :             WRITE (unit_nr, '(T6,A28,1X,I15)') &
     163           25 :                "Recursive iteration:", recursion_depth
     164              :          END IF
     165              :          WRITE (unit_nr, '(T6,A28,1X,F15.10)') &
     166           66 :             "Frobenius norm:", frobnorm
     167           66 :          CALL m_flush(unit_nr)
     168              :       END IF
     169              : 
     170          132 :       IF (frobnorm >= 1.0_dp) THEN
     171              : 
     172           50 :          CALL dbcsr_add_on_diag(tmp2, 1.0_dp)
     173              :          ! these controls should be provided as input
     174           50 :          order_lanczos = 3
     175           50 :          eps_lanczos = 1.0E-4_dp
     176           50 :          max_iter_lanczos = 40
     177              :          CALL matrix_sqrt_Newton_Schulz( &
     178              :             tmp3, & ! output sqrt
     179              :             tmp1, & ! output sqrti
     180              :             tmp2, & ! input original
     181              :             threshold=threshold, &
     182              :             order=order_lanczos, &
     183              :             eps_lanczos=eps_lanczos, &
     184           50 :             max_iter_lanczos=max_iter_lanczos)
     185           50 :          recursion_depth = recursion_depth + 1
     186           50 :          CALL determinant(tmp3, det0, threshold)
     187           50 :          recursion_depth = recursion_depth - 1
     188           50 :          det = det*det0*det0
     189              : 
     190              :       ELSE
     191              : 
     192              :          ! create accumulator
     193           82 :          CALL dbcsr_copy(tmp1, tmp2)
     194              :          ! re-create to make use of symmetry
     195              :          !ROLL CALL dbcsr_create(tmp3,template=matrix)
     196              : 
     197           82 :          IF (unit_nr > 0) WRITE (unit_nr, *)
     198              : 
     199              :          ! initialize the sign of the term
     200           82 :          sign_iter = -1
     201         1078 :          DO i = 1, 100
     202              : 
     203         1078 :             t1 = m_walltime()
     204              : 
     205              :             ! multiply X^i by X
     206              :             ! note that the first iteration evaluates X^2
     207              :             ! because the trace of X^1 is zero by construction
     208              :             CALL dbcsr_multiply("N", "N", 1.0_dp, tmp1, tmp2, &
     209              :                                 0.0_dp, tmp3, &
     210              :                                 filter_eps=threshold, &
     211         1078 :                                 flop=flop1)
     212         1078 :             CALL dbcsr_copy(tmp1, tmp3)
     213              : 
     214              :             ! get trace
     215         1078 :             CALL dbcsr_trace(tmp1, trace)
     216         1078 :             trace = trace*sign_iter/(1.0_dp*(i + 1))
     217         1078 :             sign_iter = -sign_iter
     218              : 
     219              :             ! update the determinant
     220         1078 :             det = det*EXP(trace)
     221              : 
     222         1078 :             occ_matrix = dbcsr_get_occupation(tmp1)
     223         1078 :             maxnorm = dbcsr_maxabs(tmp1)
     224              : 
     225         1078 :             t2 = m_walltime()
     226              : 
     227         1078 :             IF (unit_nr > 0) THEN
     228              :                WRITE (unit_nr, '(T6,A,1X,I3,1X,F7.5,F16.10,F10.3,F11.3)') &
     229          539 :                   "Determinant iter", i, occ_matrix, &
     230          539 :                   det, t2 - t1, &
     231         1078 :                   flop1/(1.0E6_dp*MAX(0.001_dp, t2 - t1))
     232          539 :                CALL m_flush(unit_nr)
     233              :             END IF
     234              : 
     235              :             ! exit if the trace is close to zero
     236         2156 :             IF (maxnorm < threshold) EXIT
     237              : 
     238              :          END DO ! end iterations
     239              : 
     240           82 :          IF (unit_nr > 0) THEN
     241           41 :             WRITE (unit_nr, '()')
     242           41 :             CALL m_flush(unit_nr)
     243              :          END IF
     244              : 
     245              :       END IF ! decide to do sqrt or not
     246              : 
     247          132 :       IF (unit_nr > 0) THEN
     248           66 :          IF (recursion_depth == 0) THEN
     249              :             WRITE (unit_nr, '(T6,A28,1X,F15.10)') &
     250           41 :                "Final determinant:", det
     251           41 :             WRITE (unit_nr, '()')
     252              :          ELSE
     253              :             WRITE (unit_nr, '(T6,A28,1X,F15.10)') &
     254           25 :                "Recursive determinant:", det
     255              :          END IF
     256           66 :          CALL m_flush(unit_nr)
     257              :       END IF
     258              : 
     259          132 :       CALL dbcsr_release(tmp1)
     260          132 :       CALL dbcsr_release(tmp2)
     261          132 :       CALL dbcsr_release(tmp3)
     262              : 
     263          132 :       CALL timestop(handle)
     264              : 
     265          132 :    END SUBROUTINE determinant
     266              : 
     267              : ! **************************************************************************************************
     268              : !> \brief invert a symmetric positive definite diagonally dominant matrix
     269              : !> \param matrix_inverse ...
     270              : !> \param matrix ...
     271              : !> \param threshold convergence threshold nased on the max abs
     272              : !> \param use_inv_as_guess logical whether input can be used as guess for inverse
     273              : !> \param norm_convergence convergence threshold for the 2-norm, useful for approximate solutions
     274              : !> \param filter_eps filter_eps for matrix multiplications, if not passed nothing is filteres
     275              : !> \param accelerator_order ...
     276              : !> \param max_iter_lanczos ...
     277              : !> \param eps_lanczos ...
     278              : !> \param silent ...
     279              : !> \par History
     280              : !>       2010.10 created [Joost VandeVondele]
     281              : !>       2011.10 guess option added [Rustam Z Khaliullin]
     282              : !> \author Joost VandeVondele
     283              : ! **************************************************************************************************
     284           26 :    SUBROUTINE invert_Taylor(matrix_inverse, matrix, threshold, use_inv_as_guess, &
     285              :                             norm_convergence, filter_eps, accelerator_order, &
     286              :                             max_iter_lanczos, eps_lanczos, silent)
     287              : 
     288              :       TYPE(dbcsr_type), INTENT(INOUT), TARGET            :: matrix_inverse, matrix
     289              :       REAL(KIND=dp), INTENT(IN)                          :: threshold
     290              :       LOGICAL, INTENT(IN), OPTIONAL                      :: use_inv_as_guess
     291              :       REAL(KIND=dp), INTENT(IN), OPTIONAL                :: norm_convergence, filter_eps
     292              :       INTEGER, INTENT(IN), OPTIONAL                      :: accelerator_order, max_iter_lanczos
     293              :       REAL(KIND=dp), INTENT(IN), OPTIONAL                :: eps_lanczos
     294              :       LOGICAL, INTENT(IN), OPTIONAL                      :: silent
     295              : 
     296              :       CHARACTER(LEN=*), PARAMETER                        :: routineN = 'invert_Taylor'
     297              : 
     298              :       INTEGER                                            :: accelerator_type, handle, i, &
     299              :                                                             my_max_iter_lanczos, nrows, unit_nr
     300              :       INTEGER(KIND=int_8)                                :: flop2
     301              :       LOGICAL                                            :: converged, use_inv_guess
     302              :       REAL(KIND=dp)                                      :: coeff, convergence, maxnorm_matrix, &
     303              :                                                             my_eps_lanczos, occ_matrix, t1, t2
     304           26 :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:)           :: p_diagonal
     305              :       TYPE(cp_logger_type), POINTER                      :: logger
     306              :       TYPE(dbcsr_type), TARGET                           :: tmp1, tmp2, tmp3_sym
     307              : 
     308           26 :       CALL timeset(routineN, handle)
     309              : 
     310           26 :       logger => cp_get_default_logger()
     311           26 :       IF (logger%para_env%is_source()) THEN
     312           13 :          unit_nr = cp_logger_get_default_unit_nr(logger, local=.TRUE.)
     313              :       ELSE
     314           13 :          unit_nr = -1
     315              :       END IF
     316           26 :       IF (PRESENT(silent)) THEN
     317           26 :          IF (silent) unit_nr = -1
     318              :       END IF
     319              : 
     320           26 :       convergence = threshold
     321           26 :       IF (PRESENT(norm_convergence)) convergence = norm_convergence
     322              : 
     323           26 :       accelerator_type = 0
     324           26 :       IF (PRESENT(accelerator_order)) accelerator_type = accelerator_order
     325            0 :       IF (accelerator_type > 1) accelerator_type = 1
     326              : 
     327           26 :       use_inv_guess = .FALSE.
     328           26 :       IF (PRESENT(use_inv_as_guess)) use_inv_guess = use_inv_as_guess
     329              : 
     330           26 :       my_max_iter_lanczos = 64
     331           26 :       my_eps_lanczos = 1.0E-3_dp
     332           26 :       IF (PRESENT(max_iter_lanczos)) my_max_iter_lanczos = max_iter_lanczos
     333           26 :       IF (PRESENT(eps_lanczos)) my_eps_lanczos = eps_lanczos
     334              : 
     335           26 :       CALL dbcsr_create(tmp1, template=matrix_inverse, matrix_type=dbcsr_type_no_symmetry)
     336           26 :       CALL dbcsr_create(tmp2, template=matrix_inverse, matrix_type=dbcsr_type_no_symmetry)
     337           26 :       CALL dbcsr_create(tmp3_sym, template=matrix_inverse)
     338              : 
     339           26 :       CALL dbcsr_get_info(matrix, nfullrows_total=nrows)
     340           78 :       ALLOCATE (p_diagonal(nrows))
     341              : 
     342              :       ! generate the initial guess
     343           26 :       IF (.NOT. use_inv_guess) THEN
     344              : 
     345           26 :          SELECT CASE (accelerator_type)
     346              :          CASE (0)
     347              :             ! use tmp1 to hold off-diagonal elements
     348           26 :             CALL dbcsr_desymmetrize(matrix, tmp1)
     349          858 :             p_diagonal(:) = 0.0_dp
     350           26 :             CALL dbcsr_set_diag(tmp1, p_diagonal)
     351              :             !CALL dbcsr_print(tmp1)
     352              :             ! invert the main diagonal
     353           26 :             CALL dbcsr_get_diag(matrix, p_diagonal)
     354          858 :             DO i = 1, nrows
     355          858 :                IF (p_diagonal(i) /= 0.0_dp) THEN
     356          416 :                   p_diagonal(i) = 1.0_dp/p_diagonal(i)
     357              :                END IF
     358              :             END DO
     359           26 :             CALL dbcsr_set(matrix_inverse, 0.0_dp)
     360           26 :             CALL dbcsr_add_on_diag(matrix_inverse, 1.0_dp)
     361           26 :             CALL dbcsr_set_diag(matrix_inverse, p_diagonal)
     362              :          CASE DEFAULT
     363           26 :             CPABORT("Illegal accelerator order")
     364              :          END SELECT
     365              : 
     366              :       ELSE
     367              : 
     368            0 :          CPABORT("Guess is NYI")
     369              : 
     370              :       END IF
     371              : 
     372              :       CALL dbcsr_multiply("N", "N", 1.0_dp, tmp1, matrix_inverse, &
     373           26 :                           0.0_dp, tmp2, filter_eps=filter_eps)
     374              : 
     375           26 :       IF (unit_nr > 0) WRITE (unit_nr, *)
     376              : 
     377              :       ! scale the approximate inverse to be within the convergence radius
     378           26 :       t1 = m_walltime()
     379              : 
     380              :       ! done with the initial guess, start iterations
     381           26 :       converged = .FALSE.
     382           26 :       CALL dbcsr_desymmetrize(matrix_inverse, tmp1)
     383           26 :       coeff = 1.0_dp
     384          284 :       DO i = 1, 100
     385              : 
     386              :          ! coeff = +/- 1
     387          284 :          coeff = -1.0_dp*coeff
     388              :          CALL dbcsr_multiply("N", "N", 1.0_dp, tmp1, tmp2, 0.0_dp, &
     389              :                              tmp3_sym, &
     390          284 :                              flop=flop2, filter_eps=filter_eps)
     391              :          !flop=flop2)
     392          284 :          CALL dbcsr_add(matrix_inverse, tmp3_sym, 1.0_dp, coeff)
     393          284 :          CALL dbcsr_release(tmp1)
     394          284 :          CALL dbcsr_create(tmp1, template=matrix_inverse, matrix_type=dbcsr_type_no_symmetry)
     395          284 :          CALL dbcsr_desymmetrize(tmp3_sym, tmp1)
     396              : 
     397              :          ! for the convergence check
     398          284 :          maxnorm_matrix = dbcsr_maxabs(tmp3_sym)
     399              : 
     400          284 :          t2 = m_walltime()
     401          284 :          occ_matrix = dbcsr_get_occupation(matrix_inverse)
     402              : 
     403          284 :          IF (unit_nr > 0) THEN
     404          142 :             WRITE (unit_nr, '(T6,A,1X,I3,1X,F10.8,E12.3,F12.3,F13.3)') "Taylor iter", i, occ_matrix, &
     405          142 :                maxnorm_matrix, t2 - t1, &
     406          284 :                flop2/(1.0E6_dp*MAX(0.001_dp, t2 - t1))
     407          142 :             CALL m_flush(unit_nr)
     408              :          END IF
     409              : 
     410          284 :          IF (maxnorm_matrix < convergence) THEN
     411              :             converged = .TRUE.
     412              :             EXIT
     413              :          END IF
     414              : 
     415          258 :          t1 = m_walltime()
     416              : 
     417              :       END DO
     418              : 
     419              :       !last convergence check
     420              :       CALL dbcsr_multiply("N", "N", 1.0_dp, matrix, matrix_inverse, 0.0_dp, tmp1, &
     421           26 :                           filter_eps=filter_eps)
     422           26 :       CALL dbcsr_add_on_diag(tmp1, -1.0_dp)
     423              :       !frob_matrix =  dbcsr_frobenius_norm(tmp1)
     424           26 :       maxnorm_matrix = dbcsr_maxabs(tmp1)
     425           26 :       IF (unit_nr > 0) THEN
     426           13 :          WRITE (unit_nr, '(T6,A,E12.5)') "Final Taylor error", maxnorm_matrix
     427           13 :          WRITE (unit_nr, '()')
     428           13 :          CALL m_flush(unit_nr)
     429              :       END IF
     430           26 :       IF (maxnorm_matrix > convergence) THEN
     431            0 :          converged = .FALSE.
     432            0 :          IF (unit_nr > 0) THEN
     433            0 :             WRITE (unit_nr, *) 'Final convergence check failed'
     434              :          END IF
     435              :       END IF
     436              : 
     437           26 :       IF (.NOT. converged) THEN
     438            0 :          CPABORT("Taylor inversion did not converge")
     439              :       END IF
     440              : 
     441           26 :       CALL dbcsr_release(tmp1)
     442           26 :       CALL dbcsr_release(tmp2)
     443           26 :       CALL dbcsr_release(tmp3_sym)
     444              : 
     445           26 :       DEALLOCATE (p_diagonal)
     446              : 
     447           26 :       CALL timestop(handle)
     448              : 
     449           52 :    END SUBROUTINE invert_Taylor
     450              : 
     451              : ! **************************************************************************************************
     452              : !> \brief invert a symmetric positive definite matrix by Hotelling's method
     453              : !>        explicit symmetrization makes this code not suitable for other matrix types
     454              : !>        Currently a bit messy with the options, to to be cleaned soon
     455              : !> \param matrix_inverse ...
     456              : !> \param matrix ...
     457              : !> \param threshold convergence threshold nased on the max abs
     458              : !> \param use_inv_as_guess logical whether input can be used as guess for inverse
     459              : !> \param norm_convergence convergence threshold for the 2-norm, useful for approximate solutions
     460              : !> \param filter_eps filter_eps for matrix multiplications, if not passed nothing is filteres
     461              : !> \param accelerator_order ...
     462              : !> \param max_iter_lanczos ...
     463              : !> \param eps_lanczos ...
     464              : !> \param silent ...
     465              : !> \par History
     466              : !>       2010.10 created [Joost VandeVondele]
     467              : !>       2011.10 guess option added [Rustam Z Khaliullin]
     468              : !> \author Joost VandeVondele
     469              : ! **************************************************************************************************
     470         2032 :    SUBROUTINE invert_Hotelling(matrix_inverse, matrix, threshold, use_inv_as_guess, &
     471              :                                norm_convergence, filter_eps, accelerator_order, &
     472              :                                max_iter_lanczos, eps_lanczos, silent)
     473              : 
     474              :       TYPE(dbcsr_type), INTENT(INOUT), TARGET            :: matrix_inverse, matrix
     475              :       REAL(KIND=dp), INTENT(IN)                          :: threshold
     476              :       LOGICAL, INTENT(IN), OPTIONAL                      :: use_inv_as_guess
     477              :       REAL(KIND=dp), INTENT(IN), OPTIONAL                :: norm_convergence, filter_eps
     478              :       INTEGER, INTENT(IN), OPTIONAL                      :: accelerator_order, max_iter_lanczos
     479              :       REAL(KIND=dp), INTENT(IN), OPTIONAL                :: eps_lanczos
     480              :       LOGICAL, INTENT(IN), OPTIONAL                      :: silent
     481              : 
     482              :       CHARACTER(LEN=*), PARAMETER                        :: routineN = 'invert_Hotelling'
     483              : 
     484              :       INTEGER                                            :: accelerator_type, handle, i, &
     485              :                                                             my_max_iter_lanczos, unit_nr
     486              :       INTEGER(KIND=int_8)                                :: flop1, flop2
     487              :       LOGICAL                                            :: arnoldi_converged, converged, &
     488              :                                                             use_inv_guess
     489              :       REAL(KIND=dp) :: convergence, frob_matrix, gershgorin_norm, max_ev, maxnorm_matrix, min_ev, &
     490              :          my_eps_lanczos, my_filter_eps, occ_matrix, scalingf, t1, t2
     491              :       TYPE(cp_logger_type), POINTER                      :: logger
     492              :       TYPE(dbcsr_type), TARGET                           :: tmp1, tmp2
     493              : 
     494              :       !TYPE(arnoldi_env_type)                            :: arnoldi_env
     495              :       !TYPE(dbcsr_p_type), DIMENSION(1)                   :: mymat
     496              : 
     497         2032 :       CALL timeset(routineN, handle)
     498              : 
     499         2032 :       logger => cp_get_default_logger()
     500         2032 :       IF (logger%para_env%is_source()) THEN
     501         1016 :          unit_nr = cp_logger_get_default_unit_nr(logger, local=.TRUE.)
     502              :       ELSE
     503         1016 :          unit_nr = -1
     504              :       END IF
     505         2032 :       IF (PRESENT(silent)) THEN
     506         2014 :          IF (silent) unit_nr = -1
     507              :       END IF
     508              : 
     509         2032 :       convergence = threshold
     510         2032 :       IF (PRESENT(norm_convergence)) convergence = norm_convergence
     511              : 
     512         2032 :       accelerator_type = 1
     513         2032 :       IF (PRESENT(accelerator_order)) accelerator_type = accelerator_order
     514         1436 :       IF (accelerator_type > 1) accelerator_type = 1
     515              : 
     516         2032 :       use_inv_guess = .FALSE.
     517         2032 :       IF (PRESENT(use_inv_as_guess)) use_inv_guess = use_inv_as_guess
     518              : 
     519         2032 :       my_max_iter_lanczos = 64
     520         2032 :       my_eps_lanczos = 1.0E-3_dp
     521         2032 :       IF (PRESENT(max_iter_lanczos)) my_max_iter_lanczos = max_iter_lanczos
     522         2032 :       IF (PRESENT(eps_lanczos)) my_eps_lanczos = eps_lanczos
     523              : 
     524         2032 :       my_filter_eps = threshold
     525         2032 :       IF (PRESENT(filter_eps)) my_filter_eps = filter_eps
     526              : 
     527              :       ! generate the initial guess
     528         2032 :       IF (.NOT. use_inv_guess) THEN
     529              : 
     530            0 :          SELECT CASE (accelerator_type)
     531              :          CASE (0)
     532            0 :             gershgorin_norm = dbcsr_gershgorin_norm(matrix)
     533            0 :             frob_matrix = dbcsr_frobenius_norm(matrix)
     534            0 :             CALL dbcsr_set(matrix_inverse, 0.0_dp)
     535            0 :             CALL dbcsr_add_on_diag(matrix_inverse, 1/MIN(gershgorin_norm, frob_matrix))
     536              :          CASE (1)
     537              :             ! initialize matrix to unity and use arnoldi (below) to scale it into the convergence range
     538         1558 :             CALL dbcsr_set(matrix_inverse, 0.0_dp)
     539         1558 :             CALL dbcsr_add_on_diag(matrix_inverse, 1.0_dp)
     540              :          CASE DEFAULT
     541         1558 :             CPABORT("Illegal accelerator order")
     542              :          END SELECT
     543              : 
     544              :          ! everything commutes, therefore our all products will be symmetric
     545         1558 :          CALL dbcsr_create(tmp1, template=matrix_inverse)
     546              : 
     547              :       ELSE
     548              : 
     549              :          ! It is unlikely that our guess will commute with the matrix, therefore the first product will
     550              :          ! be non symmetric
     551          474 :          CALL dbcsr_create(tmp1, template=matrix_inverse, matrix_type=dbcsr_type_no_symmetry)
     552              : 
     553              :       END IF
     554              : 
     555         2032 :       CALL dbcsr_create(tmp2, template=matrix_inverse)
     556              : 
     557         2032 :       IF (unit_nr > 0) WRITE (unit_nr, *)
     558              : 
     559              :       ! scale the approximate inverse to be within the convergence radius
     560         2032 :       t1 = m_walltime()
     561              : 
     562              :       CALL dbcsr_multiply("N", "N", 1.0_dp, matrix_inverse, matrix, &
     563         2032 :                           0.0_dp, tmp1, flop=flop1, filter_eps=my_filter_eps)
     564              : 
     565         2032 :       IF (accelerator_type == 1) THEN
     566              : 
     567              :          ! scale the matrix to get into the convergence range
     568              :          CALL arnoldi_extremal(tmp1, max_eV, min_eV, threshold=my_eps_lanczos, &
     569         2032 :                                max_iter=my_max_iter_lanczos, converged=arnoldi_converged)
     570              :          !mymat(1)%matrix => tmp1
     571              :          !CALL setup_arnoldi_env(arnoldi_env, mymat, max_iter=30, threshold=1.0E-3_dp, selection_crit=1, &
     572              :          !                        nval_request=2, nrestarts=2, generalized_ev=.FALSE., iram=.TRUE.)
     573              :          !CALL arnoldi_ev(mymat, arnoldi_env)
     574              :          !max_eV = REAL(get_selected_ritz_val(arnoldi_env, 2), dp)
     575              :          !min_eV = REAL(get_selected_ritz_val(arnoldi_env, 1), dp)
     576              :          !CALL deallocate_arnoldi_env(arnoldi_env)
     577              : 
     578         2032 :          IF (unit_nr > 0) THEN
     579          768 :             WRITE (unit_nr, *)
     580          768 :             WRITE (unit_nr, '(T6,A,1X,L1,A,E12.3)') "Lanczos converged: ", arnoldi_converged, " threshold:", my_eps_lanczos
     581          768 :             WRITE (unit_nr, '(T6,A,1X,E12.3,E12.3)') "Est. extremal eigenvalues:", max_eV, min_eV
     582          768 :             WRITE (unit_nr, '(T6,A,1X,E12.3)') "Est. condition number :", max_eV/MAX(min_eV, EPSILON(min_eV))
     583              :          END IF
     584              : 
     585              :          ! 2.0 would be the correct scaling however, we should make sure here, that we are in the convergence radius
     586         2032 :          scalingf = 1.9_dp/(max_eV + min_eV)
     587         2032 :          CALL dbcsr_scale(tmp1, scalingf)
     588         2032 :          CALL dbcsr_scale(matrix_inverse, scalingf)
     589         2032 :          min_ev = min_ev*scalingf
     590              : 
     591              :       END IF
     592              : 
     593              :       ! done with the initial guess, start iterations
     594         2032 :       converged = .FALSE.
     595         8998 :       DO i = 1, 100
     596              : 
     597              :          ! tmp1 = S^-1 S
     598              : 
     599              :          ! for the convergence check
     600         8998 :          CALL dbcsr_add_on_diag(tmp1, -1.0_dp)
     601         8998 :          maxnorm_matrix = dbcsr_maxabs(tmp1)
     602         8998 :          CALL dbcsr_add_on_diag(tmp1, +1.0_dp)
     603              : 
     604              :          ! tmp2 = S^-1 S S^-1
     605              :          CALL dbcsr_multiply("N", "N", 1.0_dp, tmp1, matrix_inverse, 0.0_dp, tmp2, &
     606         8998 :                              flop=flop2, filter_eps=my_filter_eps)
     607              :          ! S^-1_{n+1} = 2 S^-1 - S^-1 S S^-1
     608         8998 :          CALL dbcsr_add(matrix_inverse, tmp2, 2.0_dp, -1.0_dp)
     609              : 
     610         8998 :          CALL dbcsr_filter(matrix_inverse, my_filter_eps)
     611         8998 :          t2 = m_walltime()
     612         8998 :          occ_matrix = dbcsr_get_occupation(matrix_inverse)
     613              : 
     614              :          ! use the scalar form of the algorithm to trace the EV
     615         8998 :          IF (accelerator_type == 1) THEN
     616         8998 :             min_ev = min_ev*(2.0_dp - min_ev)
     617         8998 :             IF (PRESENT(norm_convergence)) maxnorm_matrix = ABS(min_eV - 1.0_dp)
     618              :          END IF
     619              : 
     620         8998 :          IF (unit_nr > 0) THEN
     621         3718 :             WRITE (unit_nr, '(T6,A,1X,I3,1X,F10.8,E12.3,F12.3,F13.3)') "Hotelling iter", i, occ_matrix, &
     622         3718 :                maxnorm_matrix, t2 - t1, &
     623         7436 :                (flop1 + flop2)/(1.0E6_dp*MAX(0.001_dp, t2 - t1))
     624         3718 :             CALL m_flush(unit_nr)
     625              :          END IF
     626              : 
     627         8998 :          IF (maxnorm_matrix < convergence) THEN
     628              :             converged = .TRUE.
     629              :             EXIT
     630              :          END IF
     631              : 
     632              :          ! scale the matrix for improved convergence
     633         6966 :          IF (accelerator_type == 1) THEN
     634         6966 :             min_ev = min_ev*2.0_dp/(min_ev + 1.0_dp)
     635         6966 :             CALL dbcsr_scale(matrix_inverse, 2.0_dp/(min_ev + 1.0_dp))
     636              :          END IF
     637              : 
     638         6966 :          t1 = m_walltime()
     639              :          CALL dbcsr_multiply("N", "N", 1.0_dp, matrix_inverse, matrix, &
     640         6966 :                              0.0_dp, tmp1, flop=flop1, filter_eps=my_filter_eps)
     641              : 
     642              :       END DO
     643              : 
     644         2032 :       IF (.NOT. converged) THEN
     645            0 :          CPABORT("Hotelling inversion did not converge")
     646              :       END IF
     647              : 
     648              :       ! try to symmetrize the output matrix
     649         2032 :       IF (dbcsr_get_matrix_type(matrix_inverse) == dbcsr_type_no_symmetry) THEN
     650          100 :          CALL dbcsr_transposed(tmp2, matrix_inverse)
     651         2132 :          CALL dbcsr_add(matrix_inverse, tmp2, 0.5_dp, 0.5_dp)
     652              :       END IF
     653              : 
     654         2032 :       IF (unit_nr > 0) THEN
     655              : !           WRITE(unit_nr,'(T6,A,1X,I3,1X,F10.8,E12.3)') "Final Hotelling ",i,occ_matrix,&
     656              : !              !frob_matrix/frob_matrix_base
     657              : !              maxnorm_matrix
     658          768 :          WRITE (unit_nr, '()')
     659          768 :          CALL m_flush(unit_nr)
     660              :       END IF
     661              : 
     662         2032 :       CALL dbcsr_release(tmp1)
     663         2032 :       CALL dbcsr_release(tmp2)
     664              : 
     665         2032 :       CALL timestop(handle)
     666              : 
     667         2032 :    END SUBROUTINE invert_Hotelling
     668              : 
     669              : ! **************************************************************************************************
     670              : !> \brief compute the sign a matrix using Newton-Schulz iterations
     671              : !> \param matrix_sign ...
     672              : !> \param matrix ...
     673              : !> \param threshold ...
     674              : !> \param sign_order ...
     675              : !> \param iounit ...
     676              : !> \par History
     677              : !>       2010.10 created [Joost VandeVondele]
     678              : !>       2019.05 extended to order byxond 2 [Robert Schade]
     679              : !> \author Joost VandeVondele, Robert Schade
     680              : ! **************************************************************************************************
     681         1058 :    SUBROUTINE matrix_sign_Newton_Schulz(matrix_sign, matrix, threshold, sign_order, iounit)
     682              : 
     683              :       TYPE(dbcsr_type), INTENT(INOUT)                    :: matrix_sign, matrix
     684              :       REAL(KIND=dp), INTENT(IN)                          :: threshold
     685              :       INTEGER, INTENT(IN), OPTIONAL                      :: sign_order, iounit
     686              : 
     687              :       CHARACTER(LEN=*), PARAMETER :: routineN = 'matrix_sign_Newton_Schulz'
     688              : 
     689              :       INTEGER                                            :: count, handle, i, order, unit_nr
     690              :       INTEGER(KIND=int_8)                                :: flops
     691              :       REAL(KIND=dp)                                      :: a0, a1, a2, a3, a4, a5, floptot, &
     692              :                                                             frob_matrix, frob_matrix_base, &
     693              :                                                             gersh_matrix, occ_matrix, prefactor, &
     694              :                                                             t1, t2
     695              :       TYPE(cp_logger_type), POINTER                      :: logger
     696              :       TYPE(dbcsr_type)                                   :: tmp1, tmp2, tmp3, tmp4
     697              : 
     698         1058 :       CALL timeset(routineN, handle)
     699              : 
     700         1058 :       IF (PRESENT(iounit)) THEN
     701         1058 :          unit_nr = iounit
     702              :       ELSE
     703            0 :          logger => cp_get_default_logger()
     704            0 :          IF (logger%para_env%is_source()) THEN
     705            0 :             unit_nr = cp_logger_get_default_unit_nr(logger, local=.TRUE.)
     706              :          ELSE
     707            0 :             unit_nr = -1
     708              :          END IF
     709              :       END IF
     710              : 
     711         1058 :       IF (PRESENT(sign_order)) THEN
     712         1058 :          order = sign_order
     713              :       ELSE
     714              :          order = 2
     715              :       END IF
     716              : 
     717         1058 :       CALL dbcsr_create(tmp1, template=matrix_sign)
     718              : 
     719         1058 :       CALL dbcsr_create(tmp2, template=matrix_sign)
     720         1058 :       IF (ABS(order) >= 4) THEN
     721            8 :          CALL dbcsr_create(tmp3, template=matrix_sign)
     722              :       END IF
     723            8 :       IF (ABS(order) > 4) THEN
     724            6 :          CALL dbcsr_create(tmp4, template=matrix_sign)
     725              :       END IF
     726              : 
     727         1058 :       CALL dbcsr_copy(matrix_sign, matrix)
     728         1058 :       CALL dbcsr_filter(matrix_sign, threshold)
     729              : 
     730              :       ! scale the matrix to get into the convergence range
     731         1058 :       frob_matrix = dbcsr_frobenius_norm(matrix_sign)
     732         1058 :       gersh_matrix = dbcsr_gershgorin_norm(matrix_sign)
     733         1058 :       CALL dbcsr_scale(matrix_sign, 1/MIN(frob_matrix, gersh_matrix))
     734              : 
     735         1058 :       IF (unit_nr > 0) WRITE (unit_nr, *)
     736              : 
     737         1058 :       count = 0
     738        13202 :       DO i = 1, 100
     739        13202 :          floptot = 0_dp
     740        13202 :          t1 = m_walltime()
     741              :          ! tmp1 = X * X
     742              :          CALL dbcsr_multiply("N", "N", -1.0_dp, matrix_sign, matrix_sign, 0.0_dp, tmp1, &
     743        13202 :                              filter_eps=threshold, flop=flops)
     744        13202 :          floptot = floptot + flops
     745              : 
     746              :          ! check convergence (frob norm of what should be the identity matrix minus identity matrix)
     747        13202 :          frob_matrix_base = dbcsr_frobenius_norm(tmp1)
     748        13202 :          CALL dbcsr_add_on_diag(tmp1, +1.0_dp)
     749        13202 :          frob_matrix = dbcsr_frobenius_norm(tmp1)
     750              : 
     751              :          ! f(y) approx 1/sqrt(1-y)
     752              :          ! f(y)=1+y/2+3/8*y^2+5/16*y^3+35/128*y^4+63/256*y^5+231/1024*y^6
     753              :          ! f2(y)=1+y/2=1/2*(2+y)
     754              :          ! f3(y)=1+y/2+3/8*y^2=3/8*(8/3+4/3*y+y^2)
     755              :          ! f4(y)=1+y/2+3/8*y^2+5/16*y^3=5/16*(16/5+8/5*y+6/5*y^2+y^3)
     756              :          ! f5(y)=1+y/2+3/8*y^2+5/16*y^3+35/128*y^4=35/128*(128/35+128/70*y+48/35*y^2+8/7*y^3+y^4)
     757              :          !      z(y)=(y+a_0)*y+a_1
     758              :          ! f5(y)=35/128*((z(y)+y+a_2)*z(y)+a_3)
     759              :          !      =35/128*((a_1^2+a_1a_2+a_3)+(2*a_0a_1+a_1+a_0a_2)y+(a_0^2+a_0+2a_1+a_2)y^2+(2a_0+1)y^3+y^4)
     760              :          !    a_0=1/14
     761              :          !    a_1=23819/13720
     762              :          !    a_2=1269/980-2a_1=-3734/1715
     763              :          !    a_3=832591127/188238400
     764              :          ! f6(y)=1+y/2+3/8*y^2+5/16*y^3+35/128*y^4+63/256*y^5
     765              :          !      =63/256*(256/63 + (128 y)/63 + (32 y^2)/21 + (80 y^3)/63 + (10 y^4)/9 + y^5)
     766              :          ! f7(y)=1+y/2+3/8*y^2+5/16*y^3+35/128*y^4+63/256*y^5+231/1024*y^6
     767              :          !      =231/1024*(1024/231+512/231*y+128/77*y^2+320/231*y^3+40/33*y^4+12/11*y^5+y^6)
     768              :          ! z(y)=(y+a_0)*y+a_1
     769              :          ! w(y)=(y+a_2)*z(y)+a_3
     770              :          ! f7(y)=(w(y)+z(y)+a_4)*w(y)+a_5
     771              :          ! a_0= 1.3686502058092053653287666647611728507211996691324048468010382350359929055186612505791532871573242422
     772              :          ! a_1= 1.7089671854477436685850554669524985556296280184497503489303331821456795715195510972774979091893741568
     773              :          ! a_2=-1.3231956603546599107833121193066273961757451236778593922555836895814474509732067051246078326118696968
     774              :          ! a_3= 3.9876642330847931291749479958277754186675336169578593000744380254770411483327581042259415937710270453
     775              :          ! a_4=-3.7273299006476825027065704937541279833880400042556351139273912137942678919776364526511485025132991667
     776              :          ! a_5= 4.9369932474103023792021351907971943220607580694533770325967170245194362399287150565595441897740173578
     777              :          !
     778              :          ! y=1-X*X
     779              : 
     780              :          ! tmp1 = I-x*x
     781              :          IF (order == 2) THEN
     782        13156 :             prefactor = 0.5_dp
     783              : 
     784              :             ! update the above to 3*I-X*X
     785        13156 :             CALL dbcsr_add_on_diag(tmp1, +2.0_dp)
     786        13156 :             occ_matrix = dbcsr_get_occupation(matrix_sign)
     787              :          ELSE IF (order == 3) THEN
     788              :             ! with one multiplication
     789              :             ! tmp1=y
     790           12 :             CALL dbcsr_copy(tmp2, tmp1)
     791           12 :             CALL dbcsr_scale(tmp1, 4.0_dp/3.0_dp)
     792           12 :             CALL dbcsr_add_on_diag(tmp1, 8.0_dp/3.0_dp)
     793              : 
     794              :             ! tmp2=y^2
     795              :             CALL dbcsr_multiply("N", "N", 1.0_dp, tmp2, tmp2, 1.0_dp, tmp1, &
     796           12 :                                 filter_eps=threshold, flop=flops)
     797           12 :             floptot = floptot + flops
     798           12 :             prefactor = 3.0_dp/8.0_dp
     799              : 
     800              :          ELSE IF (order == 4) THEN
     801              :             ! with two multiplications
     802              :             ! tmp1=y
     803           10 :             CALL dbcsr_copy(tmp3, tmp1)
     804           10 :             CALL dbcsr_scale(tmp1, 8.0_dp/5.0_dp)
     805           10 :             CALL dbcsr_add_on_diag(tmp1, 16.0_dp/5.0_dp)
     806              : 
     807              :             !
     808              :             CALL dbcsr_multiply("N", "N", 1.0_dp, tmp3, tmp3, 0.0_dp, tmp2, &
     809           10 :                                 filter_eps=threshold, flop=flops)
     810           10 :             floptot = floptot + flops
     811              : 
     812           10 :             CALL dbcsr_add(tmp1, tmp2, 1.0_dp, 6.0_dp/5.0_dp)
     813              : 
     814              :             CALL dbcsr_multiply("N", "N", 1.0_dp, tmp2, tmp3, 1.0_dp, tmp1, &
     815           10 :                                 filter_eps=threshold, flop=flops)
     816           10 :             floptot = floptot + flops
     817              : 
     818           10 :             prefactor = 5.0_dp/16.0_dp
     819              :          ELSE IF (order == -5) THEN
     820              :             ! with three multiplications
     821              :             ! tmp1=y
     822            0 :             CALL dbcsr_copy(tmp3, tmp1)
     823            0 :             CALL dbcsr_scale(tmp1, 128.0_dp/70.0_dp)
     824            0 :             CALL dbcsr_add_on_diag(tmp1, 128.0_dp/35.0_dp)
     825              : 
     826              :             !
     827              :             CALL dbcsr_multiply("N", "N", 1.0_dp, tmp3, tmp3, 0.0_dp, tmp2, &
     828            0 :                                 filter_eps=threshold, flop=flops)
     829            0 :             floptot = floptot + flops
     830              : 
     831            0 :             CALL dbcsr_add(tmp1, tmp2, 1.0_dp, 48.0_dp/35.0_dp)
     832              : 
     833              :             CALL dbcsr_multiply("N", "N", 1.0_dp, tmp2, tmp3, 0.0_dp, tmp4, &
     834            0 :                                 filter_eps=threshold, flop=flops)
     835            0 :             floptot = floptot + flops
     836              : 
     837            0 :             CALL dbcsr_add(tmp1, tmp4, 1.0_dp, 8.0_dp/7.0_dp)
     838              : 
     839              :             CALL dbcsr_multiply("N", "N", 1.0_dp, tmp4, tmp3, 1.0_dp, tmp1, &
     840            0 :                                 filter_eps=threshold, flop=flops)
     841            0 :             floptot = floptot + flops
     842              : 
     843            0 :             prefactor = 35.0_dp/128.0_dp
     844              :          ELSE IF (order == 5) THEN
     845              :             ! with two multiplications
     846              :             !      z(y)=(y+a_0)*y+a_1
     847              :             ! f5(y)=35/128*((z(y)+y+a_2)*z(y)+a_3)
     848              :             !      =35/128*((a_1^2+a_1a_2+a_3)+(2*a_0a_1+a_1+a_0a_2)y+(a_0^2+a_0+2a_1+a_2)y^2+(2a_0+1)y^3+y^4)
     849              :             !    a_0=1/14
     850              :             !    a_1=23819/13720
     851              :             !    a_2=1269/980-2a_1=-3734/1715
     852              :             !    a_3=832591127/188238400
     853            8 :             a0 = 1.0_dp/14.0_dp
     854            8 :             a1 = 23819.0_dp/13720.0_dp
     855            8 :             a2 = -3734_dp/1715.0_dp
     856            8 :             a3 = 832591127_dp/188238400.0_dp
     857              : 
     858              :             ! tmp1=y
     859              :             ! tmp3=z
     860            8 :             CALL dbcsr_copy(tmp3, tmp1)
     861            8 :             CALL dbcsr_add_on_diag(tmp3, a0)
     862              :             CALL dbcsr_multiply("N", "N", 1.0_dp, tmp3, tmp1, 0.0_dp, tmp2, &
     863            8 :                                 filter_eps=threshold, flop=flops)
     864            8 :             floptot = floptot + flops
     865            8 :             CALL dbcsr_add_on_diag(tmp2, a1)
     866              : 
     867            8 :             CALL dbcsr_add_on_diag(tmp1, a2)
     868            8 :             CALL dbcsr_add(tmp1, tmp2, 1.0_dp, 1.0_dp)
     869              :             CALL dbcsr_multiply("N", "N", 1.0_dp, tmp1, tmp2, 0.0_dp, tmp3, &
     870            8 :                                 filter_eps=threshold, flop=flops)
     871            8 :             floptot = floptot + flops
     872            8 :             CALL dbcsr_add_on_diag(tmp3, a3)
     873            8 :             CALL dbcsr_copy(tmp1, tmp3)
     874              : 
     875            8 :             prefactor = 35.0_dp/128.0_dp
     876              :          ELSE IF (order == 6) THEN
     877              :             ! with four multiplications
     878              :             ! f6(y)=63/256*(256/63 + (128 y)/63 + (32 y^2)/21 + (80 y^3)/63 + (10 y^4)/9 + y^5)
     879              :             ! tmp1=y
     880            8 :             CALL dbcsr_copy(tmp3, tmp1)
     881            8 :             CALL dbcsr_scale(tmp1, 128.0_dp/63.0_dp)
     882            8 :             CALL dbcsr_add_on_diag(tmp1, 256.0_dp/63.0_dp)
     883              : 
     884              :             !
     885              :             CALL dbcsr_multiply("N", "N", 1.0_dp, tmp3, tmp3, 0.0_dp, tmp2, &
     886            8 :                                 filter_eps=threshold, flop=flops)
     887            8 :             floptot = floptot + flops
     888              : 
     889            8 :             CALL dbcsr_add(tmp1, tmp2, 1.0_dp, 32.0_dp/21.0_dp)
     890              : 
     891              :             CALL dbcsr_multiply("N", "N", 1.0_dp, tmp2, tmp3, 0.0_dp, tmp4, &
     892            8 :                                 filter_eps=threshold, flop=flops)
     893            8 :             floptot = floptot + flops
     894              : 
     895            8 :             CALL dbcsr_add(tmp1, tmp4, 1.0_dp, 80.0_dp/63.0_dp)
     896              : 
     897              :             CALL dbcsr_multiply("N", "N", 1.0_dp, tmp4, tmp3, 0.0_dp, tmp2, &
     898            8 :                                 filter_eps=threshold, flop=flops)
     899            8 :             floptot = floptot + flops
     900              : 
     901            8 :             CALL dbcsr_add(tmp1, tmp2, 1.0_dp, 10.0_dp/9.0_dp)
     902              : 
     903              :             CALL dbcsr_multiply("N", "N", 1.0_dp, tmp2, tmp3, 1.0_dp, tmp1, &
     904            8 :                                 filter_eps=threshold, flop=flops)
     905            8 :             floptot = floptot + flops
     906              : 
     907            8 :             prefactor = 63.0_dp/256.0_dp
     908              :          ELSE IF (order == 7) THEN
     909              :             ! with three multiplications
     910              : 
     911            8 :             a0 = 1.3686502058092053653287666647611728507211996691324048468010382350359929055186612505791532871573242422_dp
     912            8 :             a1 = 1.7089671854477436685850554669524985556296280184497503489303331821456795715195510972774979091893741568_dp
     913            8 :             a2 = -1.3231956603546599107833121193066273961757451236778593922555836895814474509732067051246078326118696968_dp
     914            8 :             a3 = 3.9876642330847931291749479958277754186675336169578593000744380254770411483327581042259415937710270453_dp
     915            8 :             a4 = -3.7273299006476825027065704937541279833880400042556351139273912137942678919776364526511485025132991667_dp
     916            8 :             a5 = 4.9369932474103023792021351907971943220607580694533770325967170245194362399287150565595441897740173578_dp
     917              :             !      =231/1024*(1024/231+512/231*y+128/77*y^2+320/231*y^3+40/33*y^4+12/11*y^5+y^6)
     918              :             ! z(y)=(y+a_0)*y+a_1
     919              :             ! w(y)=(y+a_2)*z(y)+a_3
     920              :             ! f7(y)=(w(y)+z(y)+a_4)*w(y)+a_5
     921              : 
     922              :             ! tmp1=y
     923              :             ! tmp3=z
     924            8 :             CALL dbcsr_copy(tmp3, tmp1)
     925            8 :             CALL dbcsr_add_on_diag(tmp3, a0)
     926              :             CALL dbcsr_multiply("N", "N", 1.0_dp, tmp3, tmp1, 0.0_dp, tmp2, &
     927            8 :                                 filter_eps=threshold, flop=flops)
     928            8 :             floptot = floptot + flops
     929            8 :             CALL dbcsr_add_on_diag(tmp2, a1)
     930              : 
     931              :             ! tmp4=w
     932            8 :             CALL dbcsr_copy(tmp4, tmp1)
     933            8 :             CALL dbcsr_add_on_diag(tmp4, a2)
     934              :             CALL dbcsr_multiply("N", "N", 1.0_dp, tmp4, tmp2, 0.0_dp, tmp3, &
     935            8 :                                 filter_eps=threshold, flop=flops)
     936            8 :             floptot = floptot + flops
     937            8 :             CALL dbcsr_add_on_diag(tmp3, a3)
     938              : 
     939            8 :             CALL dbcsr_add(tmp2, tmp3, 1.0_dp, 1.0_dp)
     940            8 :             CALL dbcsr_add_on_diag(tmp2, a4)
     941              :             CALL dbcsr_multiply("N", "N", 1.0_dp, tmp2, tmp3, 0.0_dp, tmp1, &
     942            8 :                                 filter_eps=threshold, flop=flops)
     943            8 :             floptot = floptot + flops
     944            8 :             CALL dbcsr_add_on_diag(tmp1, a5)
     945              : 
     946            8 :             prefactor = 231.0_dp/1024.0_dp
     947              :          ELSE
     948            0 :             CPABORT("requested order is not implemented.")
     949              :          END IF
     950              : 
     951              :          ! tmp2 = X * prefactor *
     952              :          CALL dbcsr_multiply("N", "N", prefactor, matrix_sign, tmp1, 0.0_dp, tmp2, &
     953        13202 :                              filter_eps=threshold, flop=flops)
     954        13202 :          floptot = floptot + flops
     955              : 
     956              :          ! done iterating
     957              :          ! CALL dbcsr_filter(tmp2,threshold)
     958        13202 :          CALL dbcsr_copy(matrix_sign, tmp2)
     959        13202 :          t2 = m_walltime()
     960              : 
     961        13202 :          occ_matrix = dbcsr_get_occupation(matrix_sign)
     962              : 
     963        13202 :          IF (unit_nr > 0) THEN
     964            0 :             WRITE (unit_nr, '(T6,A,1X,I3,1X,F10.8,E12.3,F12.3,F13.3)') "NS sign iter ", i, occ_matrix, &
     965            0 :                frob_matrix/frob_matrix_base, t2 - t1, &
     966            0 :                floptot/(1.0E6_dp*MAX(0.001_dp, t2 - t1))
     967            0 :             CALL m_flush(unit_nr)
     968              :          END IF
     969              : 
     970              :          ! frob_matrix/frob_matrix_base < SQRT(threshold)
     971        13202 :          IF (frob_matrix*frob_matrix < (threshold*frob_matrix_base*frob_matrix_base)) EXIT
     972              : 
     973              :       END DO
     974              : 
     975              :       ! this check is not really needed
     976              :       CALL dbcsr_multiply("N", "N", +1.0_dp, matrix_sign, matrix_sign, 0.0_dp, tmp1, &
     977         1058 :                           filter_eps=threshold)
     978         1058 :       frob_matrix_base = dbcsr_frobenius_norm(tmp1)
     979         1058 :       CALL dbcsr_add_on_diag(tmp1, -1.0_dp)
     980         1058 :       frob_matrix = dbcsr_frobenius_norm(tmp1)
     981         1058 :       occ_matrix = dbcsr_get_occupation(matrix_sign)
     982         1058 :       IF (unit_nr > 0) THEN
     983            0 :          WRITE (unit_nr, '(T6,A,1X,I3,1X,F10.8,E12.3)') "Final NS sign iter", i, occ_matrix, &
     984            0 :             frob_matrix/frob_matrix_base
     985            0 :          WRITE (unit_nr, '()')
     986            0 :          CALL m_flush(unit_nr)
     987              :       END IF
     988              : 
     989         1058 :       CALL dbcsr_release(tmp1)
     990         1058 :       CALL dbcsr_release(tmp2)
     991         1058 :       IF (ABS(order) >= 4) THEN
     992            8 :          CALL dbcsr_release(tmp3)
     993              :       END IF
     994            8 :       IF (ABS(order) > 4) THEN
     995            6 :          CALL dbcsr_release(tmp4)
     996              :       END IF
     997              : 
     998         1058 :       CALL timestop(handle)
     999              : 
    1000         1058 :    END SUBROUTINE matrix_sign_Newton_Schulz
    1001              : 
    1002              :    ! **************************************************************************************************
    1003              : !> \brief compute the sign a matrix using the general algorithm for the p-th root of Richters et al.
    1004              : !>                   Commun. Comput. Phys., 25 (2019), pp. 564-585.
    1005              : !> \param matrix_sign ...
    1006              : !> \param matrix ...
    1007              : !> \param threshold ...
    1008              : !> \param sign_order ...
    1009              : !> \par History
    1010              : !>       2019.03 created [Robert Schade]
    1011              : !> \author Robert Schade
    1012              : ! **************************************************************************************************
    1013           16 :    SUBROUTINE matrix_sign_proot(matrix_sign, matrix, threshold, sign_order)
    1014              : 
    1015              :       TYPE(dbcsr_type), INTENT(INOUT)                    :: matrix_sign, matrix
    1016              :       REAL(KIND=dp), INTENT(IN)                          :: threshold
    1017              :       INTEGER, INTENT(IN), OPTIONAL                      :: sign_order
    1018              : 
    1019              :       CHARACTER(LEN=*), PARAMETER                        :: routineN = 'matrix_sign_proot'
    1020              : 
    1021              :       INTEGER                                            :: handle, order, unit_nr
    1022              :       INTEGER(KIND=int_8)                                :: flop0, flop1, flop2
    1023              :       LOGICAL                                            :: converged, symmetrize
    1024              :       REAL(KIND=dp)                                      :: frob_matrix, frob_matrix_base, occ_matrix
    1025              :       TYPE(cp_logger_type), POINTER                      :: logger
    1026              :       TYPE(dbcsr_type)                                   :: matrix2, matrix_sqrt, matrix_sqrt_inv, &
    1027              :                                                             tmp1, tmp2
    1028              : 
    1029            8 :       CALL cite_reference(Richters2018)
    1030              : 
    1031            8 :       CALL timeset(routineN, handle)
    1032              : 
    1033            8 :       logger => cp_get_default_logger()
    1034            8 :       IF (logger%para_env%is_source()) THEN
    1035            4 :          unit_nr = cp_logger_get_default_unit_nr(logger, local=.TRUE.)
    1036              :       ELSE
    1037            4 :          unit_nr = -1
    1038              :       END IF
    1039              : 
    1040            8 :       IF (PRESENT(sign_order)) THEN
    1041            8 :          order = sign_order
    1042              :       ELSE
    1043            0 :          order = 2
    1044              :       END IF
    1045              : 
    1046            8 :       CALL dbcsr_create(tmp1, template=matrix_sign)
    1047              : 
    1048            8 :       CALL dbcsr_create(tmp2, template=matrix_sign)
    1049              : 
    1050            8 :       CALL dbcsr_create(matrix2, template=matrix, matrix_type=dbcsr_type_no_symmetry)
    1051              :       CALL dbcsr_multiply("N", "N", 1.0_dp, matrix, matrix, 0.0_dp, matrix2, &
    1052            8 :                           filter_eps=threshold, flop=flop0)
    1053              :       !CALL dbcsr_filter(matrix2, threshold)
    1054              : 
    1055              :       !CALL dbcsr_copy(matrix_sign, matrix)
    1056              :       !CALL dbcsr_filter(matrix_sign, threshold)
    1057              : 
    1058            8 :       IF (unit_nr > 0) WRITE (unit_nr, *)
    1059              : 
    1060            8 :       CALL dbcsr_create(matrix_sqrt, template=matrix2)
    1061            8 :       CALL dbcsr_create(matrix_sqrt_inv, template=matrix2)
    1062            8 :       IF (unit_nr > 0) WRITE (unit_nr, *) "Threshold=", threshold
    1063              : 
    1064            8 :       symmetrize = .FALSE.
    1065              :       CALL matrix_sqrt_proot(matrix_sqrt, matrix_sqrt_inv, matrix2, threshold, order, &
    1066            8 :                              0.01_dp, 100, symmetrize, converged)
    1067              : 
    1068              :       CALL dbcsr_multiply("N", "N", 1.0_dp, matrix, matrix_sqrt_inv, 0.0_dp, matrix_sign, &
    1069            8 :                           filter_eps=threshold, flop=flop1)
    1070              : 
    1071              :       ! this check is not really needed
    1072              :       CALL dbcsr_multiply("N", "N", +1.0_dp, matrix_sign, matrix_sign, 0.0_dp, tmp1, &
    1073            8 :                           filter_eps=threshold, flop=flop2)
    1074            8 :       frob_matrix_base = dbcsr_frobenius_norm(tmp1)
    1075            8 :       CALL dbcsr_add_on_diag(tmp1, -1.0_dp)
    1076            8 :       frob_matrix = dbcsr_frobenius_norm(tmp1)
    1077            8 :       occ_matrix = dbcsr_get_occupation(matrix_sign)
    1078            8 :       IF (unit_nr > 0) THEN
    1079            4 :          WRITE (unit_nr, '(T6,A,F10.8,E12.3)') "Final proot sign iter", occ_matrix, &
    1080            8 :             frob_matrix/frob_matrix_base
    1081            4 :          WRITE (unit_nr, '()')
    1082            4 :          CALL m_flush(unit_nr)
    1083              :       END IF
    1084              : 
    1085            8 :       CALL dbcsr_release(tmp1)
    1086            8 :       CALL dbcsr_release(tmp2)
    1087            8 :       CALL dbcsr_release(matrix2)
    1088            8 :       CALL dbcsr_release(matrix_sqrt)
    1089            8 :       CALL dbcsr_release(matrix_sqrt_inv)
    1090              : 
    1091            8 :       CALL timestop(handle)
    1092              : 
    1093            8 :    END SUBROUTINE matrix_sign_proot
    1094              : 
    1095              : ! **************************************************************************************************
    1096              : !> \brief compute the sign of a dense matrix using Newton-Schulz iterations
    1097              : !> \param matrix_sign ...
    1098              : !> \param matrix ...
    1099              : !> \param matrix_id ...
    1100              : !> \param threshold ...
    1101              : !> \param sign_order ...
    1102              : !> \author Michael Lass, Robert Schade
    1103              : ! **************************************************************************************************
    1104            2 :    SUBROUTINE dense_matrix_sign_Newton_Schulz(matrix_sign, matrix, matrix_id, threshold, sign_order)
    1105              : 
    1106              :       REAL(KIND=dp), DIMENSION(:, :), INTENT(INOUT)      :: matrix_sign
    1107              :       REAL(KIND=dp), DIMENSION(:, :), INTENT(IN)         :: matrix
    1108              :       INTEGER, INTENT(IN), OPTIONAL                      :: matrix_id
    1109              :       REAL(KIND=dp), INTENT(IN)                          :: threshold
    1110              :       INTEGER, INTENT(IN), OPTIONAL                      :: sign_order
    1111              : 
    1112              :       CHARACTER(LEN=*), PARAMETER :: routineN = 'dense_matrix_sign_Newton_Schulz'
    1113              : 
    1114              :       INTEGER                                            :: handle, i, j, sz, unit_nr
    1115              :       LOGICAL                                            :: converged
    1116              :       REAL(KIND=dp)                                      :: frob_matrix, frob_matrix_base, &
    1117              :                                                             gersh_matrix, prefactor, scaling_factor
    1118            2 :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :)        :: tmp1, tmp2
    1119              :       REAL(KIND=dp), DIMENSION(1)                        :: work
    1120              :       REAL(KIND=dp), EXTERNAL                            :: dlange
    1121              :       TYPE(cp_logger_type), POINTER                      :: logger
    1122              : 
    1123            2 :       CALL timeset(routineN, handle)
    1124              : 
    1125              :       ! print output on all ranks
    1126            2 :       logger => cp_get_default_logger()
    1127            2 :       unit_nr = cp_logger_get_default_unit_nr(logger, local=.TRUE.)
    1128              : 
    1129              :       ! scale the matrix to get into the convergence range
    1130            2 :       sz = SIZE(matrix, 1)
    1131            2 :       frob_matrix = dlange('F', sz, sz, matrix, sz, work) !dbcsr_frobenius_norm(matrix_sign)
    1132            2 :       gersh_matrix = dlange('1', sz, sz, matrix, sz, work) !dbcsr_gershgorin_norm(matrix_sign)
    1133            2 :       scaling_factor = 1/MIN(frob_matrix, gersh_matrix)
    1134           86 :       matrix_sign = matrix*scaling_factor
    1135            8 :       ALLOCATE (tmp1(sz, sz))
    1136            6 :       ALLOCATE (tmp2(sz, sz))
    1137              : 
    1138            2 :       converged = .FALSE.
    1139           14 :       DO i = 1, 100
    1140           14 :          CALL dgemm('N', 'N', sz, sz, sz, -1.0_dp, matrix_sign, sz, matrix_sign, sz, 0.0_dp, tmp1, sz)
    1141              : 
    1142              :          ! check convergence (frob norm of what should be the identity matrix minus identity matrix)
    1143           14 :          frob_matrix_base = dlange('F', sz, sz, tmp1, sz, work)
    1144           98 :          DO j = 1, sz
    1145           98 :             tmp1(j, j) = tmp1(j, j) + 1.0_dp
    1146              :          END DO
    1147           14 :          frob_matrix = dlange('F', sz, sz, tmp1, sz, work)
    1148              : 
    1149           14 :          IF (sign_order == 2) THEN
    1150            8 :             prefactor = 0.5_dp
    1151              :             ! update the above to 3*I-X*X
    1152           56 :             DO j = 1, sz
    1153           56 :                tmp1(j, j) = tmp1(j, j) + 2.0_dp
    1154              :             END DO
    1155            6 :          ELSE IF (sign_order == 3) THEN
    1156          258 :             tmp2(:, :) = tmp1
    1157          258 :             tmp1 = tmp1*4.0_dp/3.0_dp
    1158           42 :             DO j = 1, sz
    1159           42 :                tmp1(j, j) = tmp1(j, j) + 8.0_dp/3.0_dp
    1160              :             END DO
    1161            6 :             CALL dgemm('N', 'N', sz, sz, sz, 1.0_dp, tmp2, sz, tmp2, sz, 1.0_dp, tmp1, sz)
    1162            6 :             prefactor = 3.0_dp/8.0_dp
    1163              :          ELSE
    1164            0 :             CPABORT("requested order is not implemented.")
    1165              :          END IF
    1166              : 
    1167           14 :          CALL dgemm('N', 'N', sz, sz, sz, prefactor, matrix_sign, sz, tmp1, sz, 0.0_dp, tmp2, sz)
    1168          602 :          matrix_sign = tmp2
    1169              : 
    1170              :          ! frob_matrix/frob_matrix_base < SQRT(threshold)
    1171           14 :          IF (frob_matrix*frob_matrix < (threshold*frob_matrix_base*frob_matrix_base)) THEN
    1172              :             WRITE (unit_nr, '(T6,A,1X,I6,1X,A,1X,I3,E12.3)') &
    1173            2 :                "Submatrix", matrix_id, "final NS sign iter", i, frob_matrix/frob_matrix_base
    1174            2 :             CALL m_flush(unit_nr)
    1175              :             converged = .TRUE.
    1176              :             EXIT
    1177              :          END IF
    1178              :       END DO
    1179              : 
    1180              :       IF (.NOT. converged) &
    1181            0 :          CPABORT("dense_matrix_sign_Newton_Schulz did not converge within 100 iterations")
    1182              : 
    1183            2 :       DEALLOCATE (tmp1)
    1184            2 :       DEALLOCATE (tmp2)
    1185              : 
    1186            2 :       CALL timestop(handle)
    1187              : 
    1188            2 :    END SUBROUTINE dense_matrix_sign_Newton_Schulz
    1189              : 
    1190              : ! **************************************************************************************************
    1191              : !> \brief Perform eigendecomposition of a dense matrix
    1192              : !> \param sm ...
    1193              : !> \param N ...
    1194              : !> \param eigvals ...
    1195              : !> \param eigvecs ...
    1196              : !> \par History
    1197              : !>       2020.05 Extracted from dense_matrix_sign_direct [Michael Lass]
    1198              : !> \author Michael Lass, Robert Schade
    1199              : ! **************************************************************************************************
    1200            4 :    SUBROUTINE eigdecomp(sm, N, eigvals, eigvecs)
    1201              :       INTEGER, INTENT(IN)                                :: N
    1202              :       REAL(KIND=dp), INTENT(IN)                          :: sm(N, N)
    1203              :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:), &
    1204              :          INTENT(OUT)                                     :: eigvals
    1205              :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :), &
    1206              :          INTENT(OUT)                                     :: eigvecs
    1207              : 
    1208              :       INTEGER                                            :: info, liwork, lwork
    1209            4 :       INTEGER, ALLOCATABLE, DIMENSION(:)                 :: iwork
    1210            4 :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:)           :: work
    1211            4 :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :)        :: tmp
    1212              : 
    1213           24 :       ALLOCATE (eigvecs(N, N), tmp(N, N))
    1214           12 :       ALLOCATE (eigvals(N))
    1215              : 
    1216              :       ! symmetrize sm
    1217          172 :       eigvecs(:, :) = 0.5*(sm + TRANSPOSE(sm))
    1218              : 
    1219              :       ! probe optimal sizes for WORK and IWORK
    1220            4 :       LWORK = -1
    1221            4 :       LIWORK = -1
    1222            4 :       ALLOCATE (WORK(1))
    1223            4 :       ALLOCATE (IWORK(1))
    1224            4 :       CALL dsyevd('V', 'U', N, eigvecs, N, eigvals, WORK, LWORK, IWORK, LIWORK, INFO)
    1225            4 :       LWORK = INT(WORK(1))
    1226            4 :       LIWORK = INT(IWORK(1))
    1227            4 :       DEALLOCATE (IWORK)
    1228            4 :       DEALLOCATE (WORK)
    1229              : 
    1230              :       ! calculate eigenvalues and eigenvectors
    1231           12 :       ALLOCATE (WORK(LWORK))
    1232           12 :       ALLOCATE (IWORK(LIWORK))
    1233            4 :       CALL dsyevd('V', 'U', N, eigvecs, N, eigvals, WORK, LWORK, IWORK, LIWORK, INFO)
    1234            4 :       DEALLOCATE (IWORK)
    1235            4 :       DEALLOCATE (WORK)
    1236            4 :       IF (INFO /= 0) CPABORT("dsyevd did not succeed")
    1237              : 
    1238            4 :       DEALLOCATE (tmp)
    1239            4 :    END SUBROUTINE eigdecomp
    1240              : 
    1241              : ! **************************************************************************************************
    1242              : !> \brief Calculate the sign matrix from eigenvalues and eigenvectors of a matrix
    1243              : !> \param sm_sign ...
    1244              : !> \param eigvals ...
    1245              : !> \param eigvecs ...
    1246              : !> \param N ...
    1247              : !> \param mu_correction ...
    1248              : !> \par History
    1249              : !>       2020.05 Extracted from dense_matrix_sign_direct [Michael Lass]
    1250              : !> \author Michael Lass, Robert Schade
    1251              : ! **************************************************************************************************
    1252            4 :    SUBROUTINE sign_from_eigdecomp(sm_sign, eigvals, eigvecs, N, mu_correction)
    1253              :       INTEGER                                            :: N
    1254              :       REAL(KIND=dp), INTENT(IN)                          :: eigvecs(N, N), eigvals(N)
    1255              :       REAL(KIND=dp), INTENT(INOUT)                       :: sm_sign(N, N)
    1256              :       REAL(KIND=dp), INTENT(IN)                          :: mu_correction
    1257              : 
    1258              :       INTEGER                                            :: i
    1259            4 :       REAL(KIND=dp)                                      :: modified_eigval, tmp(N, N)
    1260              : 
    1261          172 :       sm_sign = 0
    1262           28 :       DO i = 1, N
    1263           24 :          modified_eigval = eigvals(i) - mu_correction
    1264           28 :          IF (modified_eigval > 0) THEN
    1265            6 :             sm_sign(i, i) = 1.0
    1266           18 :          ELSE IF (modified_eigval < 0) THEN
    1267           18 :             sm_sign(i, i) = -1.0
    1268              :          ELSE
    1269            0 :             sm_sign(i, i) = 0.0
    1270              :          END IF
    1271              :       END DO
    1272              : 
    1273              :       ! Create matrix with eigenvalues in {-1,0,1} and eigenvectors of sm:
    1274              :       ! sm_sign = eigvecs * sm_sign * eigvecs.T
    1275            4 :       CALL dgemm('N', 'N', N, N, N, 1.0_dp, eigvecs, N, sm_sign, N, 0.0_dp, tmp, N)
    1276            4 :       CALL dgemm('N', 'T', N, N, N, 1.0_dp, tmp, N, eigvecs, N, 0.0_dp, sm_sign, N)
    1277            4 :    END SUBROUTINE sign_from_eigdecomp
    1278              : 
    1279              : ! **************************************************************************************************
    1280              : !> \brief Compute partial trace of a matrix from its eigenvalues and eigenvectors
    1281              : !> \param eigvals ...
    1282              : !> \param eigvecs ...
    1283              : !> \param firstcol ...
    1284              : !> \param lastcol ...
    1285              : !> \param mu_correction ...
    1286              : !> \return ...
    1287              : !> \par History
    1288              : !>       2020.05 Created [Michael Lass]
    1289              : !> \author Michael Lass
    1290              : ! **************************************************************************************************
    1291           36 :    FUNCTION trace_from_eigdecomp(eigvals, eigvecs, firstcol, lastcol, mu_correction) RESULT(trace)
    1292              :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:), &
    1293              :          INTENT(IN)                                      :: eigvals
    1294              :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :), &
    1295              :          INTENT(IN)                                      :: eigvecs
    1296              :       INTEGER, INTENT(IN)                                :: firstcol, lastcol
    1297              :       REAL(KIND=dp), INTENT(IN)                          :: mu_correction
    1298              :       REAL(KIND=dp)                                      :: trace
    1299              : 
    1300              :       INTEGER                                            :: i, j, sm_size
    1301              :       REAL(KIND=dp)                                      :: modified_eigval, tmpsum
    1302           36 :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:)           :: mapped_eigvals
    1303              : 
    1304           36 :       sm_size = SIZE(eigvals)
    1305          108 :       ALLOCATE (mapped_eigvals(sm_size))
    1306              : 
    1307          252 :       DO i = 1, sm_size
    1308          216 :          modified_eigval = eigvals(i) - mu_correction
    1309          252 :          IF (modified_eigval > 0) THEN
    1310           26 :             mapped_eigvals(i) = 1.0
    1311          190 :          ELSE IF (modified_eigval < 0) THEN
    1312          190 :             mapped_eigvals(i) = -1.0
    1313              :          ELSE
    1314            0 :             mapped_eigvals(i) = 0.0
    1315              :          END IF
    1316              :       END DO
    1317              : 
    1318           36 :       trace = 0.0_dp
    1319          252 :       DO i = firstcol, lastcol
    1320              :          tmpsum = 0.0_dp
    1321         1512 :          DO j = 1, sm_size
    1322         1512 :             tmpsum = tmpsum + (eigvecs(i, j)*mapped_eigvals(j)*eigvecs(i, j))
    1323              :          END DO
    1324          252 :          trace = trace - 0.5_dp*tmpsum + 0.5_dp
    1325              :       END DO
    1326           36 :    END FUNCTION trace_from_eigdecomp
    1327              : 
    1328              : ! **************************************************************************************************
    1329              : !> \brief Calculate the sign matrix by direct calculation of all eigenvalues and eigenvectors
    1330              : !> \param sm_sign ...
    1331              : !> \param sm ...
    1332              : !> \param N ...
    1333              : !> \par History
    1334              : !>       2020.02 Created [Michael Lass, Robert Schade]
    1335              : !>       2020.05 Extracted eigdecomp and sign_from_eigdecomp [Michael Lass]
    1336              : !> \author Michael Lass, Robert Schade
    1337              : ! **************************************************************************************************
    1338            2 :    SUBROUTINE dense_matrix_sign_direct(sm_sign, sm, N)
    1339              :       INTEGER, INTENT(IN)                                :: N
    1340              :       REAL(KIND=dp), INTENT(IN)                          :: sm(N, N)
    1341              :       REAL(KIND=dp), INTENT(INOUT)                       :: sm_sign(N, N)
    1342              : 
    1343            2 :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:)           :: eigvals
    1344            2 :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :)        :: eigvecs
    1345              : 
    1346              :       CALL eigdecomp(sm, N, eigvals, eigvecs)
    1347            2 :       CALL sign_from_eigdecomp(sm_sign, eigvals, eigvecs, N, 0.0_dp)
    1348              : 
    1349            2 :       DEALLOCATE (eigvals, eigvecs)
    1350            2 :    END SUBROUTINE dense_matrix_sign_direct
    1351              : 
    1352              : ! **************************************************************************************************
    1353              : !> \brief Submatrix method
    1354              : !> \param matrix_sign ...
    1355              : !> \param matrix ...
    1356              : !> \param threshold ...
    1357              : !> \param sign_order ...
    1358              : !> \param submatrix_sign_method ...
    1359              : !> \par History
    1360              : !>       2019.03 created [Robert Schade]
    1361              : !>       2019.06 impl. submatrix method [Michael Lass]
    1362              : !> \author Robert Schade, Michael Lass
    1363              : ! **************************************************************************************************
    1364            6 :    SUBROUTINE matrix_sign_submatrix(matrix_sign, matrix, threshold, sign_order, submatrix_sign_method)
    1365              : 
    1366              :       TYPE(dbcsr_type), INTENT(INOUT)                    :: matrix_sign, matrix
    1367              :       REAL(KIND=dp), INTENT(IN)                          :: threshold
    1368              :       INTEGER, INTENT(IN), OPTIONAL                      :: sign_order
    1369              :       INTEGER, INTENT(IN)                                :: submatrix_sign_method
    1370              : 
    1371              :       CHARACTER(LEN=*), PARAMETER :: routineN = 'matrix_sign_submatrix'
    1372              : 
    1373              :       INTEGER                                            :: handle, i, myrank, nblkcols, order, &
    1374              :                                                             sm_size, unit_nr
    1375            6 :       INTEGER, ALLOCATABLE, DIMENSION(:)                 :: my_sms
    1376            6 :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :)        :: sm, sm_sign
    1377              :       TYPE(cp_logger_type), POINTER                      :: logger
    1378              :       TYPE(dbcsr_distribution_type)                      :: dist
    1379            6 :       TYPE(submatrix_dissection_type)                    :: dissection
    1380              : 
    1381            6 :       CALL timeset(routineN, handle)
    1382              : 
    1383              :       ! print output on all ranks
    1384            6 :       logger => cp_get_default_logger()
    1385            6 :       unit_nr = cp_logger_get_default_unit_nr(logger, local=.TRUE.)
    1386              : 
    1387            6 :       IF (PRESENT(sign_order)) THEN
    1388            6 :          order = sign_order
    1389              :       ELSE
    1390            0 :          order = 2
    1391              :       END IF
    1392              : 
    1393            6 :       CALL dbcsr_get_info(matrix=matrix, nblkcols_total=nblkcols, distribution=dist)
    1394            6 :       CALL dbcsr_distribution_get(dist=dist, mynode=myrank)
    1395              : 
    1396            6 :       CALL dissection%init(matrix)
    1397            6 :       CALL dissection%get_sm_ids_for_rank(myrank, my_sms)
    1398              : 
    1399              :       !$OMP PARALLEL DEFAULT(OMP_DEFAULT_NONE_WITH_OOP) &
    1400              :       !$OMP          PRIVATE(sm, sm_sign, sm_size) &
    1401            6 :       !$OMP          SHARED(dissection, myrank, my_sms, order, submatrix_sign_method, threshold, unit_nr)
    1402              :       !$OMP DO SCHEDULE(GUIDED)
    1403              :       DO i = 1, SIZE(my_sms)
    1404              :          WRITE (unit_nr, '(T3,A,1X,I4,1X,A,1X,I6)') "Rank", myrank, "processing submatrix", my_sms(i)
    1405              :          CALL dissection%generate_submatrix(my_sms(i), sm)
    1406              :          sm_size = SIZE(sm, 1)
    1407              :          ALLOCATE (sm_sign(sm_size, sm_size))
    1408              :          SELECT CASE (submatrix_sign_method)
    1409              :          CASE (ls_scf_submatrix_sign_ns)
    1410              :             CALL dense_matrix_sign_Newton_Schulz(sm_sign, sm, my_sms(i), threshold, order)
    1411              :          CASE (ls_scf_submatrix_sign_direct, ls_scf_submatrix_sign_direct_muadj, ls_scf_submatrix_sign_direct_muadj_lowmem)
    1412              :             CALL dense_matrix_sign_direct(sm_sign, sm, sm_size)
    1413              :          CASE DEFAULT
    1414              :             CPABORT("Unkown submatrix sign method.")
    1415              :          END SELECT
    1416              :          CALL dissection%copy_resultcol(my_sms(i), sm_sign)
    1417              :          DEALLOCATE (sm, sm_sign)
    1418              :       END DO
    1419              :       !$OMP END DO
    1420              :       !$OMP END PARALLEL
    1421              : 
    1422            6 :       CALL dissection%communicate_results(matrix_sign)
    1423            6 :       CALL dissection%final
    1424              : 
    1425            6 :       CALL timestop(handle)
    1426              : 
    1427           12 :    END SUBROUTINE matrix_sign_submatrix
    1428              : 
    1429              : ! **************************************************************************************************
    1430              : !> \brief Submatrix method with internal adjustment of chemical potential
    1431              : !> \param matrix_sign ...
    1432              : !> \param matrix ...
    1433              : !> \param mu ...
    1434              : !> \param nelectron ...
    1435              : !> \param threshold ...
    1436              : !> \param variant ...
    1437              : !> \par History
    1438              : !>       2020.05 Created [Michael Lass]
    1439              : !> \author Robert Schade, Michael Lass
    1440              : ! **************************************************************************************************
    1441            4 :    SUBROUTINE matrix_sign_submatrix_mu_adjust(matrix_sign, matrix, mu, nelectron, threshold, variant)
    1442              : 
    1443              :       TYPE(dbcsr_type), INTENT(INOUT)                    :: matrix_sign, matrix
    1444              :       REAL(KIND=dp), INTENT(INOUT)                       :: mu
    1445              :       INTEGER, INTENT(IN)                                :: nelectron
    1446              :       REAL(KIND=dp), INTENT(IN)                          :: threshold
    1447              :       INTEGER, INTENT(IN)                                :: variant
    1448              : 
    1449              :       CHARACTER(LEN=*), PARAMETER :: routineN = 'matrix_sign_submatrix_mu_adjust'
    1450              :       REAL(KIND=dp), PARAMETER                           :: initial_increment = 0.01_dp
    1451              : 
    1452              :       INTEGER                                            :: handle, i, j, myrank, nblkcols, &
    1453              :                                                             sm_firstcol, sm_lastcol, sm_size, &
    1454              :                                                             unit_nr
    1455            4 :       INTEGER, ALLOCATABLE, DIMENSION(:)                 :: my_sms
    1456              :       LOGICAL                                            :: has_mu_high, has_mu_low
    1457              :       REAL(KIND=dp)                                      :: increment, mu_high, mu_low, new_mu, trace
    1458            4 :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :)        :: sm, sm_sign, tmp
    1459              :       TYPE(cp_logger_type), POINTER                      :: logger
    1460              :       TYPE(dbcsr_distribution_type)                      :: dist
    1461            4 :       TYPE(eigbuf), ALLOCATABLE, DIMENSION(:)            :: eigbufs
    1462              :       TYPE(mp_comm_type)                                 :: group
    1463            4 :       TYPE(submatrix_dissection_type)                    :: dissection
    1464              : 
    1465            4 :       CALL timeset(routineN, handle)
    1466              : 
    1467              :       ! print output on all ranks
    1468            4 :       logger => cp_get_default_logger()
    1469            4 :       unit_nr = cp_logger_get_default_unit_nr(logger, local=.TRUE.)
    1470              : 
    1471            4 :       CALL dbcsr_get_info(matrix=matrix, nblkcols_total=nblkcols, distribution=dist, group=group)
    1472            4 :       CALL dbcsr_distribution_get(dist=dist, mynode=myrank)
    1473              : 
    1474            4 :       CALL dissection%init(matrix)
    1475            4 :       CALL dissection%get_sm_ids_for_rank(myrank, my_sms)
    1476              : 
    1477           12 :       ALLOCATE (eigbufs(SIZE(my_sms)))
    1478              : 
    1479              :       trace = 0.0_dp
    1480              : 
    1481              :       !$OMP PARALLEL DEFAULT(OMP_DEFAULT_NONE_WITH_OOP) &
    1482              :       !$OMP          PRIVATE(sm, sm_sign, sm_size, sm_firstcol, sm_lastcol, j, tmp) &
    1483              :       !$OMP          SHARED(dissection, myrank, my_sms, unit_nr, eigbufs, threshold, variant) &
    1484            4 :       !$OMP          REDUCTION(+:trace)
    1485              :       !$OMP DO SCHEDULE(GUIDED)
    1486              :       DO i = 1, SIZE(my_sms)
    1487              :          CALL dissection%generate_submatrix(my_sms(i), sm)
    1488              :          sm_size = SIZE(sm, 1)
    1489              :          WRITE (unit_nr, *) "Rank", myrank, "processing submatrix", my_sms(i), "size", sm_size
    1490              : 
    1491              :          CALL dissection%get_relevant_sm_columns(my_sms(i), sm_firstcol, sm_lastcol)
    1492              : 
    1493              :          IF (variant == ls_scf_submatrix_sign_direct_muadj) THEN
    1494              :             ! Store all eigenvectors in buffer. We will use it to compute sm_sign at the end.
    1495              :             CALL eigdecomp(sm, sm_size, eigvals=eigbufs(i)%eigvals, eigvecs=eigbufs(i)%eigvecs)
    1496              :          ELSE
    1497              :             ! Only store eigenvectors that are required for mu adjustment.
    1498              :             ! Calculate sm_sign right away in the hope that mu is already correct.
    1499              :             CALL eigdecomp(sm, sm_size, eigvals=eigbufs(i)%eigvals, eigvecs=tmp)
    1500              :             ALLOCATE (eigbufs(i)%eigvecs(sm_firstcol:sm_lastcol, 1:sm_size))
    1501              :             eigbufs(i)%eigvecs(:, :) = tmp(sm_firstcol:sm_lastcol, 1:sm_size)
    1502              : 
    1503              :             ALLOCATE (sm_sign(sm_size, sm_size))
    1504              :             CALL sign_from_eigdecomp(sm_sign, eigbufs(i)%eigvals, tmp, sm_size, 0.0_dp)
    1505              :             CALL dissection%copy_resultcol(my_sms(i), sm_sign)
    1506              :             DEALLOCATE (sm_sign, tmp)
    1507              :          END IF
    1508              : 
    1509              :          DEALLOCATE (sm)
    1510              :          trace = trace + trace_from_eigdecomp(eigbufs(i)%eigvals, eigbufs(i)%eigvecs, sm_firstcol, sm_lastcol, 0.0_dp)
    1511              :       END DO
    1512              :       !$OMP END DO
    1513              :       !$OMP END PARALLEL
    1514              : 
    1515            4 :       has_mu_low = .FALSE.
    1516            4 :       has_mu_high = .FALSE.
    1517            4 :       increment = initial_increment
    1518            4 :       new_mu = mu
    1519           72 :       DO i = 1, 30
    1520           72 :          CALL group%sum(trace)
    1521           72 :          IF (unit_nr > 0) WRITE (unit_nr, '(T2,A,1X,F13.9,1X,F15.9)') &
    1522           72 :             "Density matrix:  mu, trace error: ", new_mu, trace - nelectron
    1523           72 :          IF (ABS(trace - nelectron) < 0.5_dp) EXIT
    1524           68 :          IF (trace < nelectron) THEN
    1525            8 :             mu_low = new_mu
    1526            8 :             new_mu = new_mu + increment
    1527            8 :             has_mu_low = .TRUE.
    1528            8 :             increment = increment*2
    1529              :          ELSE
    1530           60 :             mu_high = new_mu
    1531           60 :             new_mu = new_mu - increment
    1532           60 :             has_mu_high = .TRUE.
    1533           60 :             increment = increment*2
    1534              :          END IF
    1535              : 
    1536           68 :          IF (has_mu_low .AND. has_mu_high) THEN
    1537           20 :             new_mu = (mu_low + mu_high)/2
    1538           20 :             IF (ABS(mu_high - mu_low) < threshold) EXIT
    1539              :          END IF
    1540              : 
    1541              :          trace = 0
    1542              :          !$OMP PARALLEL DEFAULT(OMP_DEFAULT_NONE_WITH_OOP) &
    1543              :          !$OMP          PRIVATE(i, sm_sign, tmp, sm_size, sm_firstcol, sm_lastcol) &
    1544              :          !$OMP          SHARED(dissection, my_sms, unit_nr, eigbufs, mu, new_mu, nelectron) &
    1545           72 :          !$OMP          REDUCTION(+:trace)
    1546              :          !$OMP DO SCHEDULE(GUIDED)
    1547              :          DO j = 1, SIZE(my_sms)
    1548              :             sm_size = SIZE(eigbufs(j)%eigvals)
    1549              :             CALL dissection%get_relevant_sm_columns(my_sms(j), sm_firstcol, sm_lastcol)
    1550              :             trace = trace + trace_from_eigdecomp(eigbufs(j)%eigvals, eigbufs(j)%eigvecs, sm_firstcol, sm_lastcol, new_mu - mu)
    1551              :          END DO
    1552              :          !$OMP END DO
    1553              :          !$OMP END PARALLEL
    1554              :       END DO
    1555              : 
    1556              :       ! Finalize sign matrix from eigendecompositions if we kept all eigenvectors
    1557            4 :       IF (variant == ls_scf_submatrix_sign_direct_muadj) THEN
    1558              :          !$OMP PARALLEL DEFAULT(OMP_DEFAULT_NONE_WITH_OOP) &
    1559              :          !$OMP          PRIVATE(sm, sm_sign, sm_size, sm_firstcol, sm_lastcol, j) &
    1560            2 :          !$OMP          SHARED(dissection, myrank, my_sms, unit_nr, eigbufs, mu, new_mu)
    1561              :          !$OMP DO SCHEDULE(GUIDED)
    1562              :          DO i = 1, SIZE(my_sms)
    1563              :             WRITE (unit_nr, '(T3,A,1X,I4,1X,A,1X,I6)') "Rank", myrank, "finalizing submatrix", my_sms(i)
    1564              :             sm_size = SIZE(eigbufs(i)%eigvals)
    1565              :             ALLOCATE (sm_sign(sm_size, sm_size))
    1566              :             CALL sign_from_eigdecomp(sm_sign, eigbufs(i)%eigvals, eigbufs(i)%eigvecs, sm_size, new_mu - mu)
    1567              :             CALL dissection%copy_resultcol(my_sms(i), sm_sign)
    1568              :             DEALLOCATE (sm_sign)
    1569              :          END DO
    1570              :          !$OMP END DO
    1571              :          !$OMP END PARALLEL
    1572              :       END IF
    1573              : 
    1574            6 :       DEALLOCATE (eigbufs)
    1575              : 
    1576              :       ! If we only stored parts of the eigenvectors and mu has changed, we need to recompute sm_sign
    1577            4 :       IF ((variant == ls_scf_submatrix_sign_direct_muadj_lowmem) .AND. (mu /= new_mu)) THEN
    1578              :          !$OMP PARALLEL DEFAULT(OMP_DEFAULT_NONE_WITH_OOP) &
    1579              :          !$OMP          PRIVATE(sm, sm_sign, sm_size, sm_firstcol, sm_lastcol, j) &
    1580            2 :          !$OMP          SHARED(dissection, myrank, my_sms, unit_nr, eigbufs, mu, new_mu)
    1581              :          !$OMP DO SCHEDULE(GUIDED)
    1582              :          DO i = 1, SIZE(my_sms)
    1583              :             WRITE (unit_nr, '(T3,A,1X,I4,1X,A,1X,I6)') "Rank", myrank, "reprocessing submatrix", my_sms(i)
    1584              :             CALL dissection%generate_submatrix(my_sms(i), sm)
    1585              :             sm_size = SIZE(sm, 1)
    1586              :             DO j = 1, sm_size
    1587              :                sm(j, j) = sm(j, j) + mu - new_mu
    1588              :             END DO
    1589              :             ALLOCATE (sm_sign(sm_size, sm_size))
    1590              :             CALL dense_matrix_sign_direct(sm_sign, sm, sm_size)
    1591              :             CALL dissection%copy_resultcol(my_sms(i), sm_sign)
    1592              :             DEALLOCATE (sm, sm_sign)
    1593              :          END DO
    1594              :          !$OMP END DO
    1595              :          !$OMP END PARALLEL
    1596              :       END IF
    1597              : 
    1598            4 :       mu = new_mu
    1599              : 
    1600            4 :       CALL dissection%communicate_results(matrix_sign)
    1601            4 :       CALL dissection%final
    1602              : 
    1603            4 :       CALL timestop(handle)
    1604              : 
    1605           12 :    END SUBROUTINE matrix_sign_submatrix_mu_adjust
    1606              : 
    1607              : ! **************************************************************************************************
    1608              : !> \brief compute the sqrt of a matrix via the sign function and the corresponding Newton-Schulz iterations
    1609              : !>        the order of the algorithm should be 2..5, 3 or 5 is recommended
    1610              : !> \param matrix_sqrt ...
    1611              : !> \param matrix_sqrt_inv ...
    1612              : !> \param matrix ...
    1613              : !> \param threshold ...
    1614              : !> \param order ...
    1615              : !> \param eps_lanczos ...
    1616              : !> \param max_iter_lanczos ...
    1617              : !> \param symmetrize ...
    1618              : !> \param converged ...
    1619              : !> \param iounit ...
    1620              : !> \par History
    1621              : !>       2010.10 created [Joost VandeVondele]
    1622              : !> \author Joost VandeVondele
    1623              : ! **************************************************************************************************
    1624        14410 :    SUBROUTINE matrix_sqrt_Newton_Schulz(matrix_sqrt, matrix_sqrt_inv, matrix, threshold, order, &
    1625              :                                         eps_lanczos, max_iter_lanczos, symmetrize, converged, iounit)
    1626              :       TYPE(dbcsr_type), INTENT(INOUT)                    :: matrix_sqrt, matrix_sqrt_inv, matrix
    1627              :       REAL(KIND=dp), INTENT(IN)                          :: threshold
    1628              :       INTEGER, INTENT(IN)                                :: order
    1629              :       REAL(KIND=dp), INTENT(IN)                          :: eps_lanczos
    1630              :       INTEGER, INTENT(IN)                                :: max_iter_lanczos
    1631              :       LOGICAL, OPTIONAL                                  :: symmetrize, converged
    1632              :       INTEGER, INTENT(IN), OPTIONAL                      :: iounit
    1633              : 
    1634              :       CHARACTER(LEN=*), PARAMETER :: routineN = 'matrix_sqrt_Newton_Schulz'
    1635              : 
    1636              :       INTEGER                                            :: handle, i, unit_nr
    1637              :       INTEGER(KIND=int_8)                                :: flop1, flop2, flop3, flop4, flop5
    1638              :       LOGICAL                                            :: arnoldi_converged, tsym
    1639              :       REAL(KIND=dp)                                      :: a, b, c, conv, d, frob_matrix, &
    1640              :                                                             frob_matrix_base, gershgorin_norm, &
    1641              :                                                             max_ev, min_ev, oa, ob, oc, &
    1642              :                                                             occ_matrix, od, scaling, t1, t2
    1643              :       TYPE(cp_logger_type), POINTER                      :: logger
    1644              :       TYPE(dbcsr_type)                                   :: tmp1, tmp2, tmp3
    1645              : 
    1646        14410 :       CALL timeset(routineN, handle)
    1647              : 
    1648        14410 :       IF (PRESENT(iounit)) THEN
    1649        13058 :          unit_nr = iounit
    1650              :       ELSE
    1651         1352 :          logger => cp_get_default_logger()
    1652         1352 :          IF (logger%para_env%is_source()) THEN
    1653          676 :             unit_nr = cp_logger_get_default_unit_nr(logger, local=.TRUE.)
    1654              :          ELSE
    1655          676 :             unit_nr = -1
    1656              :          END IF
    1657              :       END IF
    1658              : 
    1659        14410 :       IF (PRESENT(converged)) converged = .FALSE.
    1660        14410 :       IF (PRESENT(symmetrize)) THEN
    1661            0 :          tsym = symmetrize
    1662              :       ELSE
    1663              :          tsym = .TRUE.
    1664              :       END IF
    1665              : 
    1666              :       ! for stability symmetry can not be assumed
    1667        14410 :       CALL dbcsr_create(tmp1, template=matrix, matrix_type=dbcsr_type_no_symmetry)
    1668        14410 :       CALL dbcsr_create(tmp2, template=matrix, matrix_type=dbcsr_type_no_symmetry)
    1669        14410 :       IF (order >= 4) THEN
    1670           20 :          CALL dbcsr_create(tmp3, template=matrix, matrix_type=dbcsr_type_no_symmetry)
    1671              :       END IF
    1672              : 
    1673        14410 :       CALL dbcsr_set(matrix_sqrt_inv, 0.0_dp)
    1674        14410 :       CALL dbcsr_add_on_diag(matrix_sqrt_inv, 1.0_dp)
    1675        14410 :       CALL dbcsr_filter(matrix_sqrt_inv, threshold)
    1676        14410 :       CALL dbcsr_copy(matrix_sqrt, matrix)
    1677              : 
    1678              :       ! scale the matrix to get into the convergence range
    1679        14410 :       IF (order == 0) THEN
    1680              : 
    1681            0 :          gershgorin_norm = dbcsr_gershgorin_norm(matrix_sqrt)
    1682            0 :          frob_matrix = dbcsr_frobenius_norm(matrix_sqrt)
    1683            0 :          scaling = 1.0_dp/MIN(frob_matrix, gershgorin_norm)
    1684              : 
    1685              :       ELSE
    1686              : 
    1687              :          ! scale the matrix to get into the convergence range
    1688              :          CALL arnoldi_extremal(matrix_sqrt, max_ev, min_ev, threshold=eps_lanczos, &
    1689        14410 :                                max_iter=max_iter_lanczos, converged=arnoldi_converged)
    1690        14410 :          IF (unit_nr > 0) THEN
    1691          676 :             WRITE (unit_nr, *)
    1692          676 :             WRITE (unit_nr, '(T6,A,1X,L1,A,E12.3)') "Lanczos converged: ", arnoldi_converged, " threshold:", eps_lanczos
    1693          676 :             WRITE (unit_nr, '(T6,A,1X,E12.3,E12.3)') "Est. extremal eigenvalues:", max_ev, min_ev
    1694          676 :             WRITE (unit_nr, '(T6,A,1X,E12.3)') "Est. condition number :", max_ev/MAX(min_ev, EPSILON(min_ev))
    1695              :          END IF
    1696              :          ! conservatively assume we get a relatively large error (100*threshold_lanczos) in the estimates
    1697              :          ! and adjust the scaling to be on the safe side
    1698        14410 :          scaling = 2.0_dp/(max_ev + min_ev + 100*eps_lanczos)
    1699              : 
    1700              :       END IF
    1701              : 
    1702        14410 :       CALL dbcsr_scale(matrix_sqrt, scaling)
    1703        14410 :       CALL dbcsr_filter(matrix_sqrt, threshold)
    1704        14410 :       IF (unit_nr > 0) THEN
    1705          676 :          WRITE (unit_nr, *)
    1706          676 :          WRITE (unit_nr, *) "Order=", order
    1707              :       END IF
    1708              : 
    1709        67914 :       DO i = 1, 100
    1710              : 
    1711        67914 :          t1 = m_walltime()
    1712              : 
    1713              :          ! tmp1 = Zk * Yk - I
    1714              :          CALL dbcsr_multiply("N", "N", 1.0_dp, matrix_sqrt_inv, matrix_sqrt, 0.0_dp, tmp1, &
    1715        67914 :                              filter_eps=threshold, flop=flop1)
    1716        67914 :          frob_matrix_base = dbcsr_frobenius_norm(tmp1)
    1717        67914 :          CALL dbcsr_add_on_diag(tmp1, -1.0_dp)
    1718              : 
    1719              :          ! check convergence (frob norm of what should be the identity matrix minus identity matrix)
    1720        67914 :          frob_matrix = dbcsr_frobenius_norm(tmp1)
    1721              : 
    1722        67914 :          flop4 = 0; flop5 = 0
    1723           36 :          SELECT CASE (order)
    1724              :          CASE (0, 2)
    1725              :             ! update the above to 0.5*(3*I-Zk*Yk)
    1726           36 :             CALL dbcsr_add_on_diag(tmp1, -2.0_dp)
    1727           36 :             CALL dbcsr_scale(tmp1, -0.5_dp)
    1728              :          CASE (3)
    1729              :             ! tmp2 = tmp1 ** 2
    1730              :             CALL dbcsr_multiply("N", "N", 1.0_dp, tmp1, tmp1, 0.0_dp, tmp2, &
    1731        67818 :                                 filter_eps=threshold, flop=flop4)
    1732              :             ! tmp1 = 1/16 * (16*I-8*tmp1+6*tmp1**2-5*tmp1**3)
    1733        67818 :             CALL dbcsr_add(tmp1, tmp2, -4.0_dp, 3.0_dp)
    1734        67818 :             CALL dbcsr_add_on_diag(tmp1, 8.0_dp)
    1735        67818 :             CALL dbcsr_scale(tmp1, 0.125_dp)
    1736              :          CASE (4) ! as expensive as case(5), so little need to use it
    1737              :             ! tmp2 = tmp1 ** 2
    1738              :             CALL dbcsr_multiply("N", "N", 1.0_dp, tmp1, tmp1, 0.0_dp, tmp2, &
    1739           32 :                                 filter_eps=threshold, flop=flop4)
    1740              :             ! tmp3 = tmp2 * tmp1
    1741              :             CALL dbcsr_multiply("N", "N", 1.0_dp, tmp2, tmp1, 0.0_dp, tmp3, &
    1742           32 :                                 filter_eps=threshold, flop=flop5)
    1743           32 :             CALL dbcsr_scale(tmp1, -8.0_dp)
    1744           32 :             CALL dbcsr_add_on_diag(tmp1, 16.0_dp)
    1745           32 :             CALL dbcsr_add(tmp1, tmp2, 1.0_dp, 6.0_dp)
    1746           32 :             CALL dbcsr_add(tmp1, tmp3, 1.0_dp, -5.0_dp)
    1747           32 :             CALL dbcsr_scale(tmp1, 1/16.0_dp)
    1748              :          CASE (5)
    1749              :             ! Knuth's reformulation to evaluate the polynomial of 4th degree in 2 multiplications
    1750              :             ! p = y4+A*y3+B*y2+C*y+D
    1751              :             ! z := y * (y+a); P := (z+y+b) * (z+c) + d.
    1752              :             ! a=(A-1)/2 ; b=B*(a+1)-C-a*(a+1)*(a+1)
    1753              :             ! c=B-b-a*(a+1)
    1754              :             ! d=D-bc
    1755           28 :             oa = -40.0_dp/35.0_dp
    1756           28 :             ob = 48.0_dp/35.0_dp
    1757           28 :             oc = -64.0_dp/35.0_dp
    1758           28 :             od = 128.0_dp/35.0_dp
    1759           28 :             a = (oa - 1)/2
    1760           28 :             b = ob*(a + 1) - oc - a*(a + 1)**2
    1761           28 :             c = ob - b - a*(a + 1)
    1762           28 :             d = od - b*c
    1763              :             ! tmp2 = tmp1 ** 2 + a * tmp1
    1764              :             CALL dbcsr_multiply("N", "N", 1.0_dp, tmp1, tmp1, 0.0_dp, tmp2, &
    1765           28 :                                 filter_eps=threshold, flop=flop4)
    1766           28 :             CALL dbcsr_add(tmp2, tmp1, 1.0_dp, a)
    1767              :             ! tmp3 = tmp2 + tmp1 + b
    1768           28 :             CALL dbcsr_copy(tmp3, tmp2)
    1769           28 :             CALL dbcsr_add(tmp3, tmp1, 1.0_dp, 1.0_dp)
    1770           28 :             CALL dbcsr_add_on_diag(tmp3, b)
    1771              :             ! tmp2 = tmp2 + c
    1772           28 :             CALL dbcsr_add_on_diag(tmp2, c)
    1773              :             ! tmp1 = tmp2 * tmp3
    1774              :             CALL dbcsr_multiply("N", "N", 1.0_dp, tmp2, tmp3, 0.0_dp, tmp1, &
    1775           28 :                                 filter_eps=threshold, flop=flop5)
    1776              :             ! tmp1 = tmp1 + d
    1777           28 :             CALL dbcsr_add_on_diag(tmp1, d)
    1778              :             ! final scale
    1779           28 :             CALL dbcsr_scale(tmp1, 35.0_dp/128.0_dp)
    1780              :          CASE DEFAULT
    1781        67914 :             CPABORT("Illegal order value")
    1782              :          END SELECT
    1783              : 
    1784              :          ! tmp2 = Yk * tmp1 = Y(k+1)
    1785              :          CALL dbcsr_multiply("N", "N", 1.0_dp, matrix_sqrt, tmp1, 0.0_dp, tmp2, &
    1786        67914 :                              filter_eps=threshold, flop=flop2)
    1787              :          ! CALL dbcsr_filter(tmp2,threshold)
    1788        67914 :          CALL dbcsr_copy(matrix_sqrt, tmp2)
    1789              : 
    1790              :          ! tmp2 = tmp1 * Zk = Z(k+1)
    1791              :          CALL dbcsr_multiply("N", "N", 1.0_dp, tmp1, matrix_sqrt_inv, 0.0_dp, tmp2, &
    1792        67914 :                              filter_eps=threshold, flop=flop3)
    1793              :          ! CALL dbcsr_filter(tmp2,threshold)
    1794        67914 :          CALL dbcsr_copy(matrix_sqrt_inv, tmp2)
    1795              : 
    1796        67914 :          occ_matrix = dbcsr_get_occupation(matrix_sqrt_inv)
    1797              : 
    1798              :          ! done iterating
    1799        67914 :          t2 = m_walltime()
    1800              : 
    1801        67914 :          conv = frob_matrix/frob_matrix_base
    1802              : 
    1803        67914 :          IF (unit_nr > 0) THEN
    1804         3818 :             WRITE (unit_nr, '(T6,A,1X,I3,1X,F10.8,E12.3,F12.3,F13.3)') "NS sqrt iter ", i, occ_matrix, &
    1805         3818 :                conv, t2 - t1, &
    1806         7636 :                (flop1 + flop2 + flop3 + flop4 + flop5)/(1.0E6_dp*MAX(0.001_dp, t2 - t1))
    1807         3818 :             CALL m_flush(unit_nr)
    1808              :          END IF
    1809              : 
    1810        67914 :          IF (abnormal_value(conv)) &
    1811            0 :             CPABORT("conv is an abnormal value (NaN/Inf).")
    1812              : 
    1813              :          ! conv < SQRT(threshold)
    1814        67914 :          IF ((conv*conv) < threshold) THEN
    1815        14410 :             IF (PRESENT(converged)) converged = .TRUE.
    1816              :             EXIT
    1817              :          END IF
    1818              : 
    1819              :       END DO
    1820              : 
    1821              :       ! symmetrize the matrices as this is not guaranteed by the algorithm
    1822        14410 :       IF (tsym) THEN
    1823        14410 :          IF (unit_nr > 0) THEN
    1824          676 :             WRITE (unit_nr, '(T6,A20)') "Symmetrizing Results"
    1825              :          END IF
    1826        14410 :          CALL dbcsr_transposed(tmp1, matrix_sqrt_inv)
    1827        14410 :          CALL dbcsr_add(matrix_sqrt_inv, tmp1, 0.5_dp, 0.5_dp)
    1828        14410 :          CALL dbcsr_transposed(tmp1, matrix_sqrt)
    1829        14410 :          CALL dbcsr_add(matrix_sqrt, tmp1, 0.5_dp, 0.5_dp)
    1830              :       END IF
    1831              : 
    1832              :       ! this check is not really needed
    1833              :       CALL dbcsr_multiply("N", "N", +1.0_dp, matrix_sqrt_inv, matrix_sqrt, 0.0_dp, tmp1, &
    1834        14410 :                           filter_eps=threshold)
    1835        14410 :       frob_matrix_base = dbcsr_frobenius_norm(tmp1)
    1836        14410 :       CALL dbcsr_add_on_diag(tmp1, -1.0_dp)
    1837        14410 :       frob_matrix = dbcsr_frobenius_norm(tmp1)
    1838        14410 :       occ_matrix = dbcsr_get_occupation(matrix_sqrt_inv)
    1839        14410 :       IF (unit_nr > 0) THEN
    1840          676 :          WRITE (unit_nr, '(T6,A,1X,I3,1X,F10.8,E12.3)') "Final NS sqrt iter ", i, occ_matrix, &
    1841         1352 :             frob_matrix/frob_matrix_base
    1842          676 :          WRITE (unit_nr, '()')
    1843          676 :          CALL m_flush(unit_nr)
    1844              :       END IF
    1845              : 
    1846              :       ! scale to proper end results
    1847        14410 :       CALL dbcsr_scale(matrix_sqrt, 1/SQRT(scaling))
    1848        14410 :       CALL dbcsr_scale(matrix_sqrt_inv, SQRT(scaling))
    1849              : 
    1850        14410 :       CALL dbcsr_release(tmp1)
    1851        14410 :       CALL dbcsr_release(tmp2)
    1852        14410 :       IF (order >= 4) THEN
    1853           20 :          CALL dbcsr_release(tmp3)
    1854              :       END IF
    1855              : 
    1856        14410 :       CALL timestop(handle)
    1857              : 
    1858        14410 :    END SUBROUTINE matrix_sqrt_Newton_Schulz
    1859              : 
    1860              : ! **************************************************************************************************
    1861              : !> \brief compute the sqrt of a matrix via the general algorithm for the p-th root of Richters et al.
    1862              : !>                   Commun. Comput. Phys., 25 (2019), pp. 564-585.
    1863              : !> \param matrix_sqrt ...
    1864              : !> \param matrix_sqrt_inv ...
    1865              : !> \param matrix ...
    1866              : !> \param threshold ...
    1867              : !> \param order ...
    1868              : !> \param eps_lanczos ...
    1869              : !> \param max_iter_lanczos ...
    1870              : !> \param symmetrize ...
    1871              : !> \param converged ...
    1872              : !> \par History
    1873              : !>       2019.04 created [Robert Schade]
    1874              : !> \author Robert Schade
    1875              : ! **************************************************************************************************
    1876           48 :    SUBROUTINE matrix_sqrt_proot(matrix_sqrt, matrix_sqrt_inv, matrix, threshold, order, &
    1877              :                                 eps_lanczos, max_iter_lanczos, symmetrize, converged)
    1878              :       TYPE(dbcsr_type), INTENT(INOUT)                    :: matrix_sqrt, matrix_sqrt_inv, matrix
    1879              :       REAL(KIND=dp), INTENT(IN)                          :: threshold
    1880              :       INTEGER, INTENT(IN)                                :: order
    1881              :       REAL(KIND=dp), INTENT(IN)                          :: eps_lanczos
    1882              :       INTEGER, INTENT(IN)                                :: max_iter_lanczos
    1883              :       LOGICAL, OPTIONAL                                  :: symmetrize, converged
    1884              : 
    1885              :       CHARACTER(LEN=*), PARAMETER                        :: routineN = 'matrix_sqrt_proot'
    1886              : 
    1887              :       INTEGER                                            :: choose, handle, i, ii, j, unit_nr
    1888              :       INTEGER(KIND=int_8)                                :: f, flop1, flop2, flop3, flop4, flop5
    1889              :       LOGICAL                                            :: arnoldi_converged, test, tsym
    1890              :       REAL(KIND=dp)                                      :: conv, frob_matrix, frob_matrix_base, &
    1891              :                                                             max_ev, min_ev, occ_matrix, scaling, &
    1892              :                                                             t1, t2
    1893              :       TYPE(cp_logger_type), POINTER                      :: logger
    1894              :       TYPE(dbcsr_type)                                   :: BK2A, matrixS, Rmat, tmp1, tmp2, tmp3
    1895              : 
    1896           16 :       CALL cite_reference(Richters2018)
    1897              : 
    1898           16 :       test = .FALSE.
    1899              : 
    1900           16 :       CALL timeset(routineN, handle)
    1901              : 
    1902           16 :       logger => cp_get_default_logger()
    1903           16 :       IF (logger%para_env%is_source()) THEN
    1904            8 :          unit_nr = cp_logger_get_default_unit_nr(logger, local=.TRUE.)
    1905              :       ELSE
    1906            8 :          unit_nr = -1
    1907              :       END IF
    1908              : 
    1909           16 :       IF (PRESENT(converged)) converged = .FALSE.
    1910           16 :       IF (PRESENT(symmetrize)) THEN
    1911           16 :          tsym = symmetrize
    1912              :       ELSE
    1913              :          tsym = .TRUE.
    1914              :       END IF
    1915              : 
    1916              :       ! for stability symmetry can not be assumed
    1917           16 :       CALL dbcsr_create(tmp1, template=matrix, matrix_type=dbcsr_type_no_symmetry)
    1918           16 :       CALL dbcsr_create(tmp2, template=matrix, matrix_type=dbcsr_type_no_symmetry)
    1919           16 :       CALL dbcsr_create(tmp3, template=matrix, matrix_type=dbcsr_type_no_symmetry)
    1920           16 :       CALL dbcsr_create(Rmat, template=matrix, matrix_type=dbcsr_type_no_symmetry)
    1921           16 :       CALL dbcsr_create(matrixS, template=matrix, matrix_type=dbcsr_type_no_symmetry)
    1922              : 
    1923           16 :       CALL dbcsr_copy(matrixS, matrix)
    1924              :       IF (1 == 1) THEN
    1925              :          ! scale the matrix to get into the convergence range
    1926              :          CALL arnoldi_extremal(matrixS, max_ev, min_ev, threshold=eps_lanczos, &
    1927           16 :                                max_iter=max_iter_lanczos, converged=arnoldi_converged)
    1928           16 :          IF (unit_nr > 0) THEN
    1929            8 :             WRITE (unit_nr, *)
    1930            8 :             WRITE (unit_nr, '(T6,A,1X,L1,A,E12.3)') "Lanczos converged: ", arnoldi_converged, " threshold:", eps_lanczos
    1931            8 :             WRITE (unit_nr, '(T6,A,1X,E12.3,E12.3)') "Est. extremal eigenvalues:", max_ev, min_ev
    1932            8 :             WRITE (unit_nr, '(T6,A,1X,E12.3)') "Est. condition number :", max_ev/MAX(min_ev, EPSILON(min_ev))
    1933              :          END IF
    1934              :          ! conservatively assume we get a relatively large error (100*threshold_lanczos) in the estimates
    1935              :          ! and adjust the scaling to be on the safe side
    1936           16 :          scaling = 2.0_dp/(max_ev + min_ev + 100*eps_lanczos)
    1937           16 :          CALL dbcsr_scale(matrixS, scaling)
    1938           16 :          CALL dbcsr_filter(matrixS, threshold)
    1939              :       ELSE
    1940              :          scaling = 1.0_dp
    1941              :       END IF
    1942              : 
    1943           16 :       CALL dbcsr_set(matrix_sqrt_inv, 0.0_dp)
    1944           16 :       CALL dbcsr_add_on_diag(matrix_sqrt_inv, 1.0_dp)
    1945              :       !CALL dbcsr_filter(matrix_sqrt_inv, threshold)
    1946              : 
    1947           16 :       IF (unit_nr > 0) THEN
    1948            8 :          WRITE (unit_nr, *)
    1949            8 :          WRITE (unit_nr, *) "Order=", order
    1950              :       END IF
    1951              : 
    1952           86 :       DO i = 1, 100
    1953              : 
    1954           86 :          t1 = m_walltime()
    1955              :          IF (1 == 1) THEN
    1956              :             !build R=1-A B_K^2
    1957              :             CALL dbcsr_multiply("N", "N", 1.0_dp, matrix_sqrt_inv, matrix_sqrt_inv, 0.0_dp, tmp1, &
    1958           86 :                                 filter_eps=threshold, flop=flop1)
    1959              :             CALL dbcsr_multiply("N", "N", 1.0_dp, matrixS, tmp1, 0.0_dp, Rmat, &
    1960           86 :                                 filter_eps=threshold, flop=flop2)
    1961           86 :             CALL dbcsr_scale(Rmat, -1.0_dp)
    1962           86 :             CALL dbcsr_add_on_diag(Rmat, 1.0_dp)
    1963              : 
    1964           86 :             flop4 = 0; flop5 = 0
    1965           86 :             CALL dbcsr_set(tmp1, 0.0_dp)
    1966           86 :             CALL dbcsr_add_on_diag(tmp1, 2.0_dp)
    1967              : 
    1968           86 :             flop3 = 0
    1969              : 
    1970          274 :             DO j = 2, order
    1971          188 :                IF (j == 2) THEN
    1972           86 :                   CALL dbcsr_copy(tmp2, Rmat)
    1973              :                ELSE
    1974              :                   f = 0
    1975          102 :                   CALL dbcsr_copy(tmp3, tmp2)
    1976              :                   CALL dbcsr_multiply("N", "N", 1.0_dp, tmp3, Rmat, 0.0_dp, tmp2, &
    1977          102 :                                       filter_eps=threshold, flop=f)
    1978          102 :                   flop3 = flop3 + f
    1979              :                END IF
    1980          274 :                CALL dbcsr_add(tmp1, tmp2, 1.0_dp, 1.0_dp)
    1981              :             END DO
    1982              :          ELSE
    1983              :             CALL dbcsr_create(BK2A, template=matrix, matrix_type=dbcsr_type_no_symmetry)
    1984              :             CALL dbcsr_multiply("N", "N", 1.0_dp, matrix_sqrt_inv, matrixS, 0.0_dp, tmp3, &
    1985              :                                 filter_eps=threshold, flop=flop1)
    1986              :             CALL dbcsr_multiply("N", "N", 1.0_dp, matrix_sqrt_inv, tmp3, 0.0_dp, BK2A, &
    1987              :                                 filter_eps=threshold, flop=flop2)
    1988              :             CALL dbcsr_copy(Rmat, BK2A)
    1989              :             CALL dbcsr_add_on_diag(Rmat, -1.0_dp)
    1990              : 
    1991              :             CALL dbcsr_set(tmp1, 0.0_dp)
    1992              :             CALL dbcsr_add_on_diag(tmp1, 1.0_dp)
    1993              : 
    1994              :             CALL dbcsr_set(tmp2, 0.0_dp)
    1995              :             CALL dbcsr_add_on_diag(tmp2, 1.0_dp)
    1996              : 
    1997              :             flop3 = 0
    1998              :             DO j = 1, order
    1999              :                !choose=factorial(order)/(factorial(j)*factorial(order-j))
    2000              :                choose = PRODUCT([(ii, ii=1, order)])/(PRODUCT([(ii, ii=1, j)])*PRODUCT([(ii, ii=1, order - j)]))
    2001              :                CALL dbcsr_add(tmp1, tmp2, 1.0_dp, -1.0_dp*(-1)**j*choose)
    2002              :                IF (j < order) THEN
    2003              :                   f = 0
    2004              :                   CALL dbcsr_copy(tmp3, tmp2)
    2005              :                   CALL dbcsr_multiply("N", "N", 1.0_dp, tmp3, BK2A, 0.0_dp, tmp2, &
    2006              :                                       filter_eps=threshold, flop=f)
    2007              :                   flop3 = flop3 + f
    2008              :                END IF
    2009              :             END DO
    2010              :             CALL dbcsr_release(BK2A)
    2011              :          END IF
    2012              : 
    2013           86 :          CALL dbcsr_copy(tmp3, matrix_sqrt_inv)
    2014              :          CALL dbcsr_multiply("N", "N", 0.5_dp, tmp3, tmp1, 0.0_dp, matrix_sqrt_inv, &
    2015           86 :                              filter_eps=threshold, flop=flop4)
    2016              : 
    2017           86 :          occ_matrix = dbcsr_get_occupation(matrix_sqrt_inv)
    2018              : 
    2019              :          ! done iterating
    2020           86 :          t2 = m_walltime()
    2021              : 
    2022           86 :          conv = dbcsr_frobenius_norm(Rmat)
    2023              : 
    2024           86 :          IF (unit_nr > 0) THEN
    2025           43 :             WRITE (unit_nr, '(T6,A,1X,I3,1X,F10.8,E12.3,F12.3,F13.3)') "PROOT sqrt iter ", i, occ_matrix, &
    2026           43 :                conv, t2 - t1, &
    2027           86 :                (flop1 + flop2 + flop3 + flop4 + flop5)/(1.0E6_dp*MAX(0.001_dp, t2 - t1))
    2028           43 :             CALL m_flush(unit_nr)
    2029              :          END IF
    2030              : 
    2031           86 :          IF (abnormal_value(conv)) &
    2032            0 :             CPABORT("conv is an abnormal value (NaN/Inf).")
    2033              : 
    2034              :          ! conv < SQRT(threshold)
    2035           86 :          IF ((conv*conv) < threshold) THEN
    2036           16 :             IF (PRESENT(converged)) converged = .TRUE.
    2037              :             EXIT
    2038              :          END IF
    2039              : 
    2040              :       END DO
    2041              : 
    2042              :       ! scale to proper end results
    2043           16 :       CALL dbcsr_scale(matrix_sqrt_inv, SQRT(scaling))
    2044              :       CALL dbcsr_multiply("N", "N", 1.0_dp, matrix_sqrt_inv, matrix, 0.0_dp, matrix_sqrt, &
    2045           16 :                           filter_eps=threshold, flop=flop5)
    2046              : 
    2047              :       ! symmetrize the matrices as this is not guaranteed by the algorithm
    2048           16 :       IF (tsym) THEN
    2049            8 :          IF (unit_nr > 0) THEN
    2050            4 :             WRITE (unit_nr, '(A20)') "SYMMETRIZING RESULTS"
    2051              :          END IF
    2052            8 :          CALL dbcsr_transposed(tmp1, matrix_sqrt_inv)
    2053            8 :          CALL dbcsr_add(matrix_sqrt_inv, tmp1, 0.5_dp, 0.5_dp)
    2054            8 :          CALL dbcsr_transposed(tmp1, matrix_sqrt)
    2055            8 :          CALL dbcsr_add(matrix_sqrt, tmp1, 0.5_dp, 0.5_dp)
    2056              :       END IF
    2057              : 
    2058              :       ! this check is not really needed
    2059              :       IF (test) THEN
    2060              :          CALL dbcsr_multiply("N", "N", +1.0_dp, matrix_sqrt_inv, matrix_sqrt, 0.0_dp, tmp1, &
    2061              :                              filter_eps=threshold)
    2062              :          frob_matrix_base = dbcsr_frobenius_norm(tmp1)
    2063              :          CALL dbcsr_add_on_diag(tmp1, -1.0_dp)
    2064              :          frob_matrix = dbcsr_frobenius_norm(tmp1)
    2065              :          occ_matrix = dbcsr_get_occupation(matrix_sqrt_inv)
    2066              :          IF (unit_nr > 0) THEN
    2067              :             WRITE (unit_nr, '(T6,A,1X,I3,1X,F10.8,E12.3)') "Final PROOT S^{-1/2} S^{1/2}-Eins error ", i, occ_matrix, &
    2068              :                frob_matrix/frob_matrix_base
    2069              :             WRITE (unit_nr, '()')
    2070              :             CALL m_flush(unit_nr)
    2071              :          END IF
    2072              : 
    2073              :          ! this check is not really needed
    2074              :          CALL dbcsr_multiply("N", "N", +1.0_dp, matrix_sqrt_inv, matrix_sqrt_inv, 0.0_dp, tmp2, &
    2075              :                              filter_eps=threshold)
    2076              :          CALL dbcsr_multiply("N", "N", +1.0_dp, tmp2, matrix, 0.0_dp, tmp1, &
    2077              :                              filter_eps=threshold)
    2078              :          frob_matrix_base = dbcsr_frobenius_norm(tmp1)
    2079              :          CALL dbcsr_add_on_diag(tmp1, -1.0_dp)
    2080              :          frob_matrix = dbcsr_frobenius_norm(tmp1)
    2081              :          occ_matrix = dbcsr_get_occupation(matrix_sqrt_inv)
    2082              :          IF (unit_nr > 0) THEN
    2083              :             WRITE (unit_nr, '(T6,A,1X,I3,1X,F10.8,E12.3)') "Final PROOT S^{-1/2} S^{-1/2} S-Eins error ", i, occ_matrix, &
    2084              :                frob_matrix/frob_matrix_base
    2085              :             WRITE (unit_nr, '()')
    2086              :             CALL m_flush(unit_nr)
    2087              :          END IF
    2088              :       END IF
    2089              : 
    2090           16 :       CALL dbcsr_release(tmp1)
    2091           16 :       CALL dbcsr_release(tmp2)
    2092           16 :       CALL dbcsr_release(tmp3)
    2093           16 :       CALL dbcsr_release(Rmat)
    2094           16 :       CALL dbcsr_release(matrixS)
    2095              : 
    2096           16 :       CALL timestop(handle)
    2097           16 :    END SUBROUTINE matrix_sqrt_proot
    2098              : 
    2099              : ! **************************************************************************************************
    2100              : !> \brief ...
    2101              : !> \param matrix_exp ...
    2102              : !> \param matrix ...
    2103              : !> \param omega ...
    2104              : !> \param alpha ...
    2105              : !> \param threshold ...
    2106              : ! **************************************************************************************************
    2107         1146 :    SUBROUTINE matrix_exponential(matrix_exp, matrix, omega, alpha, threshold)
    2108              :       ! compute matrix_exp=omega*exp(alpha*matrix)
    2109              :       TYPE(dbcsr_type), INTENT(INOUT)                    :: matrix_exp, matrix
    2110              :       REAL(KIND=dp), INTENT(IN)                          :: omega, alpha, threshold
    2111              : 
    2112              :       CHARACTER(LEN=*), PARAMETER :: routineN = 'matrix_exponential'
    2113              :       REAL(dp), PARAMETER                                :: one = 1.0_dp, toll = 1.E-17_dp, &
    2114              :                                                             zero = 0.0_dp
    2115              : 
    2116              :       INTEGER                                            :: handle, i, k, unit_nr
    2117              :       REAL(dp)                                           :: factorial, norm_C, norm_D, norm_scalar
    2118              :       TYPE(cp_logger_type), POINTER                      :: logger
    2119              :       TYPE(dbcsr_type)                                   :: B, B_square, C, D, D_product
    2120              : 
    2121         1146 :       CALL timeset(routineN, handle)
    2122              : 
    2123         1146 :       logger => cp_get_default_logger()
    2124         1146 :       IF (logger%para_env%is_source()) THEN
    2125         1058 :          unit_nr = cp_logger_get_default_unit_nr(logger, local=.TRUE.)
    2126              :       ELSE
    2127              :          unit_nr = -1
    2128              :       END IF
    2129              : 
    2130              :       ! Calculate the norm of the matrix alpha*matrix, and scale it until it is less than 1.0
    2131         1146 :       norm_scalar = ABS(alpha)*dbcsr_frobenius_norm(matrix)
    2132              : 
    2133              :       ! k=scaling parameter
    2134         1146 :       k = 1
    2135         1008 :       DO
    2136         2154 :          IF ((norm_scalar/2.0_dp**k) <= one) EXIT
    2137         1008 :          k = k + 1
    2138              :       END DO
    2139              : 
    2140              :       ! copy and scale the input matrix in matrix C and in matrix D
    2141         1146 :       CALL dbcsr_create(C, template=matrix, matrix_type=dbcsr_type_no_symmetry)
    2142         1146 :       CALL dbcsr_copy(C, matrix)
    2143         1146 :       CALL dbcsr_scale(C, alpha_scalar=alpha/2.0_dp**k)
    2144              : 
    2145         1146 :       CALL dbcsr_create(D, template=matrix, matrix_type=dbcsr_type_no_symmetry)
    2146         1146 :       CALL dbcsr_copy(D, C)
    2147              : 
    2148              :       !   write(*,*)
    2149              :       !   write(*,*)
    2150              :       !   CALL dbcsr_print(D, variable_name="D")
    2151              : 
    2152              :       ! set the B matrix as B=Identity+D
    2153         1146 :       CALL dbcsr_create(B, template=matrix, matrix_type=dbcsr_type_no_symmetry)
    2154         1146 :       CALL dbcsr_copy(B, D)
    2155         1146 :       CALL dbcsr_add_on_diag(B, alpha=one)
    2156              : 
    2157              :       !   CALL dbcsr_print(B, variable_name="B")
    2158              : 
    2159              :       ! Calculate the norm of C and moltiply by toll to be used as a threshold
    2160         1146 :       norm_C = toll*dbcsr_frobenius_norm(matrix)
    2161              : 
    2162              :       ! iteration for the truncated taylor series expansion
    2163         1146 :       CALL dbcsr_create(D_product, template=matrix, matrix_type=dbcsr_type_no_symmetry)
    2164         1146 :       i = 1
    2165              :       DO
    2166        12676 :          i = i + 1
    2167              :          ! compute D_product=D*C
    2168              :          CALL dbcsr_multiply("N", "N", one, D, C, &
    2169        12676 :                              zero, D_product, filter_eps=threshold)
    2170              : 
    2171              :          ! copy D_product in D
    2172        12676 :          CALL dbcsr_copy(D, D_product)
    2173              : 
    2174              :          ! calculate B=B+D_product/fat(i)
    2175        12676 :          factorial = ifac(i)
    2176        12676 :          CALL dbcsr_add(B, D_product, one, factorial)
    2177              : 
    2178              :          ! check for convergence using the norm of D (copy of the matrix D_product) and C
    2179        12676 :          norm_D = factorial*dbcsr_frobenius_norm(D)
    2180        12676 :          IF (norm_D < norm_C) EXIT
    2181              :       END DO
    2182              : 
    2183              :       ! start the k iteration for the squaring of the matrix
    2184         1146 :       CALL dbcsr_create(B_square, template=matrix, matrix_type=dbcsr_type_no_symmetry)
    2185         3300 :       DO i = 1, k
    2186              :          !compute B_square=B*B
    2187              :          CALL dbcsr_multiply("N", "N", one, B, B, &
    2188         2154 :                              zero, B_square, filter_eps=threshold)
    2189              :          ! copy Bsquare in B to iterate
    2190         3300 :          CALL dbcsr_copy(B, B_square)
    2191              :       END DO
    2192              : 
    2193              :       ! copy B_square in matrix_exp and
    2194         1146 :       CALL dbcsr_copy(matrix_exp, B_square)
    2195              : 
    2196              :       ! scale matrix_exp by omega, matrix_exp=omega*B_square
    2197         1146 :       CALL dbcsr_scale(matrix_exp, alpha_scalar=omega)
    2198              :       ! write(*,*) alpha,omega
    2199              : 
    2200         1146 :       CALL dbcsr_release(B)
    2201         1146 :       CALL dbcsr_release(C)
    2202         1146 :       CALL dbcsr_release(D)
    2203         1146 :       CALL dbcsr_release(D_product)
    2204         1146 :       CALL dbcsr_release(B_square)
    2205              : 
    2206         1146 :       CALL timestop(handle)
    2207              : 
    2208         1146 :    END SUBROUTINE matrix_exponential
    2209              : 
    2210              : ! **************************************************************************************************
    2211              : !> \brief McWeeny purification of a matrix in the orthonormal basis
    2212              : !> \param matrix_p Matrix to purify (needs to be almost idempotent already)
    2213              : !> \param threshold Threshold used as filter_eps and convergence criteria
    2214              : !> \param max_steps Max number of iterations
    2215              : !> \par History
    2216              : !>       2013.01 created [Florian Schiffmann]
    2217              : !>       2014.07 slightly refactored [Ole Schuett]
    2218              : !> \author Florian Schiffmann
    2219              : ! **************************************************************************************************
    2220          234 :    SUBROUTINE purify_mcweeny_orth(matrix_p, threshold, max_steps)
    2221              :       TYPE(dbcsr_type), DIMENSION(:)                     :: matrix_p
    2222              :       REAL(KIND=dp)                                      :: threshold
    2223              :       INTEGER                                            :: max_steps
    2224              : 
    2225              :       CHARACTER(LEN=*), PARAMETER :: routineN = 'purify_mcweeny_orth'
    2226              : 
    2227              :       INTEGER                                            :: handle, i, ispin, unit_nr
    2228              :       REAL(KIND=dp)                                      :: frob_norm, trace
    2229              :       TYPE(cp_logger_type), POINTER                      :: logger
    2230              :       TYPE(dbcsr_type)                                   :: matrix_pp, matrix_tmp
    2231              : 
    2232          234 :       CALL timeset(routineN, handle)
    2233          234 :       logger => cp_get_default_logger()
    2234          234 :       IF (logger%para_env%is_source()) THEN
    2235          117 :          unit_nr = cp_logger_get_default_unit_nr(logger, local=.TRUE.)
    2236              :       ELSE
    2237          117 :          unit_nr = -1
    2238              :       END IF
    2239              : 
    2240          234 :       CALL dbcsr_create(matrix_pp, template=matrix_p(1), matrix_type=dbcsr_type_no_symmetry)
    2241          234 :       CALL dbcsr_create(matrix_tmp, template=matrix_p(1), matrix_type=dbcsr_type_no_symmetry)
    2242          234 :       CALL dbcsr_trace(matrix_p(1), trace)
    2243              : 
    2244          476 :       DO ispin = 1, SIZE(matrix_p)
    2245          476 :          DO i = 1, max_steps
    2246              :             CALL dbcsr_multiply("N", "N", 1.0_dp, matrix_p(ispin), matrix_p(ispin), &
    2247          242 :                                 0.0_dp, matrix_pp, filter_eps=threshold)
    2248              : 
    2249              :             ! test convergence
    2250          242 :             CALL dbcsr_copy(matrix_tmp, matrix_pp)
    2251          242 :             CALL dbcsr_add(matrix_tmp, matrix_p(ispin), 1.0_dp, -1.0_dp)
    2252          242 :             frob_norm = dbcsr_frobenius_norm(matrix_tmp) ! tmp = PP - P
    2253          242 :             IF (unit_nr > 0) WRITE (unit_nr, '(t3,a,f16.8)') "McWeeny: Deviation of idempotency", frob_norm
    2254          242 :             IF (unit_nr > 0) CALL m_flush(unit_nr)
    2255              : 
    2256              :             ! construct new P
    2257          242 :             CALL dbcsr_copy(matrix_tmp, matrix_pp)
    2258              :             CALL dbcsr_multiply("N", "N", -2.0_dp, matrix_pp, matrix_p(ispin), &
    2259          242 :                                 3.0_dp, matrix_tmp, filter_eps=threshold)
    2260          242 :             CALL dbcsr_copy(matrix_p(ispin), matrix_tmp) ! tmp = 3PP - 2PPP
    2261              : 
    2262              :             ! frob_norm < SQRT(trace*threshold)
    2263          242 :             IF (frob_norm*frob_norm < trace*threshold) EXIT
    2264              :          END DO
    2265              :       END DO
    2266              : 
    2267          234 :       CALL dbcsr_release(matrix_pp)
    2268          234 :       CALL dbcsr_release(matrix_tmp)
    2269          234 :       CALL timestop(handle)
    2270          234 :    END SUBROUTINE purify_mcweeny_orth
    2271              : 
    2272              : ! **************************************************************************************************
    2273              : !> \brief McWeeny purification of a matrix in the non-orthonormal basis
    2274              : !> \param matrix_p Matrix to purify (needs to be almost idempotent already)
    2275              : !> \param matrix_s Overlap-Matrix
    2276              : !> \param threshold Threshold used as filter_eps and convergence criteria
    2277              : !> \param max_steps Max number of iterations
    2278              : !> \par History
    2279              : !>       2013.01 created [Florian Schiffmann]
    2280              : !>       2014.07 slightly refactored [Ole Schuett]
    2281              : !> \author Florian Schiffmann
    2282              : ! **************************************************************************************************
    2283          196 :    SUBROUTINE purify_mcweeny_nonorth(matrix_p, matrix_s, threshold, max_steps)
    2284              :       TYPE(dbcsr_type), DIMENSION(:)                     :: matrix_p
    2285              :       TYPE(dbcsr_type)                                   :: matrix_s
    2286              :       REAL(KIND=dp)                                      :: threshold
    2287              :       INTEGER                                            :: max_steps
    2288              : 
    2289              :       CHARACTER(LEN=*), PARAMETER :: routineN = 'purify_mcweeny_nonorth'
    2290              : 
    2291              :       INTEGER                                            :: handle, i, ispin, unit_nr
    2292              :       REAL(KIND=dp)                                      :: frob_norm, trace
    2293              :       TYPE(cp_logger_type), POINTER                      :: logger
    2294              :       TYPE(dbcsr_type)                                   :: matrix_ps, matrix_psp, matrix_test
    2295              : 
    2296          196 :       CALL timeset(routineN, handle)
    2297              : 
    2298          196 :       logger => cp_get_default_logger()
    2299          196 :       IF (logger%para_env%is_source()) THEN
    2300           98 :          unit_nr = cp_logger_get_default_unit_nr(logger, local=.TRUE.)
    2301              :       ELSE
    2302           98 :          unit_nr = -1
    2303              :       END IF
    2304              : 
    2305          196 :       CALL dbcsr_create(matrix_ps, template=matrix_p(1), matrix_type=dbcsr_type_no_symmetry)
    2306          196 :       CALL dbcsr_create(matrix_psp, template=matrix_p(1), matrix_type=dbcsr_type_no_symmetry)
    2307          196 :       CALL dbcsr_create(matrix_test, template=matrix_p(1), matrix_type=dbcsr_type_no_symmetry)
    2308              : 
    2309          392 :       DO ispin = 1, SIZE(matrix_p)
    2310          404 :          DO i = 1, max_steps
    2311              :             CALL dbcsr_multiply("N", "N", 1.0_dp, matrix_p(ispin), matrix_s, &
    2312          208 :                                 0.0_dp, matrix_ps, filter_eps=threshold)
    2313              :             CALL dbcsr_multiply("N", "N", 1.0_dp, matrix_ps, matrix_p(ispin), &
    2314          208 :                                 0.0_dp, matrix_psp, filter_eps=threshold)
    2315          208 :             IF (i == 1) CALL dbcsr_trace(matrix_ps, trace)
    2316              : 
    2317              :             ! test convergence
    2318          208 :             CALL dbcsr_copy(matrix_test, matrix_psp)
    2319          208 :             CALL dbcsr_add(matrix_test, matrix_p(ispin), 1.0_dp, -1.0_dp)
    2320          208 :             frob_norm = dbcsr_frobenius_norm(matrix_test) ! test = PSP - P
    2321          208 :             IF (unit_nr > 0) WRITE (unit_nr, '(t3,a,2f16.8)') "McWeeny: Deviation of idempotency", frob_norm
    2322          208 :             IF (unit_nr > 0) CALL m_flush(unit_nr)
    2323              : 
    2324              :             ! construct new P
    2325          208 :             CALL dbcsr_copy(matrix_p(ispin), matrix_psp)
    2326              :             CALL dbcsr_multiply("N", "N", -2.0_dp, matrix_ps, matrix_psp, &
    2327          208 :                                 3.0_dp, matrix_p(ispin), filter_eps=threshold)
    2328              : 
    2329              :             ! frob_norm < SQRT(trace*threshold)
    2330          208 :             IF (frob_norm*frob_norm < trace*threshold) EXIT
    2331              :          END DO
    2332              :       END DO
    2333              : 
    2334          196 :       CALL dbcsr_release(matrix_ps)
    2335          196 :       CALL dbcsr_release(matrix_psp)
    2336          196 :       CALL dbcsr_release(matrix_test)
    2337          196 :       CALL timestop(handle)
    2338          196 :    END SUBROUTINE purify_mcweeny_nonorth
    2339              : 
    2340            0 : END MODULE iterate_matrix
        

Generated by: LCOV version 2.0-1