LCOV - code coverage report
Current view: top level - src/fm - cp_fm_basic_linalg.F (source / functions) Hit Total Coverage
Test: CP2K Regtests (git:ccc2433) Lines: 591 825 71.6 %
Date: 2024-04-25 07:09:54 Functions: 29 42 69.0 %

          Line data    Source code
       1             : !--------------------------------------------------------------------------------------------------!
       2             : !   CP2K: A general program to perform molecular dynamics simulations                              !
       3             : !   Copyright 2000-2024 CP2K developers group <https://cp2k.org>                                   !
       4             : !                                                                                                  !
       5             : !   SPDX-License-Identifier: GPL-2.0-or-later                                                      !
       6             : !--------------------------------------------------------------------------------------------------!
       7             : 
       8             : ! **************************************************************************************************
       9             : !> \brief basic linear algebra operations for full matrices
      10             : !> \par History
      11             : !>      08.2002 split out of qs_blacs [fawzi]
      12             : !> \author Fawzi Mohamed
      13             : ! **************************************************************************************************
      14             : MODULE cp_fm_basic_linalg
      15             :    USE cp_blacs_env, ONLY: cp_blacs_env_type
      16             :    USE cp_fm_struct, ONLY: cp_fm_struct_equivalent
      17             :    USE cp_fm_types, ONLY: &
      18             :       cp_fm_create, cp_fm_get_diag, cp_fm_get_info, cp_fm_get_submatrix, cp_fm_p_type, &
      19             :       cp_fm_release, cp_fm_set_all, cp_fm_set_element, cp_fm_set_submatrix, cp_fm_to_fm, &
      20             :       cp_fm_type
      21             :    USE cp_log_handling, ONLY: cp_logger_get_default_unit_nr, &
      22             :                               cp_to_string
      23             :    USE kahan_sum, ONLY: accurate_dot_product, &
      24             :                         accurate_sum
      25             :    USE kinds, ONLY: dp, &
      26             :                     int_8, &
      27             :                     sp
      28             :    USE machine, ONLY: m_memory
      29             :    USE mathlib, ONLY: get_pseudo_inverse_svd, &
      30             :                       invert_matrix
      31             :    USE message_passing, ONLY: mp_comm_type
      32             : #include "../base/base_uses.f90"
      33             : 
      34             :    IMPLICIT NONE
      35             :    PRIVATE
      36             : 
      37             :    LOGICAL, PRIVATE, PARAMETER :: debug_this_module = .TRUE.
      38             :    CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'cp_fm_basic_linalg'
      39             : 
      40             :    PUBLIC :: cp_fm_scale, & ! scale a matrix
      41             :              cp_fm_scale_and_add, & ! scale and add two matrices
      42             :              cp_fm_geadd, & ! general addition
      43             :              cp_fm_column_scale, & ! scale columns of a matrix
      44             :              cp_fm_row_scale, & ! scale rows of a matrix
      45             :              cp_fm_trace, & ! trace of the transpose(A)*B
      46             :              cp_fm_contracted_trace, & ! sum_{i,...,k} Tr [A(i,...,k)^T * B(i,...,k)]
      47             :              cp_fm_norm, & ! different norms of A
      48             :              cp_fm_schur_product, & ! schur product
      49             :              cp_fm_transpose, & ! transpose a matrix
      50             :              cp_fm_upper_to_full, & ! symmetrise an upper symmetric matrix
      51             :              cp_fm_syrk, & ! rank k update
      52             :              cp_fm_triangular_multiply, & ! triangular matrix multiply / solve
      53             :              cp_fm_symm, & ! multiply a symmetric with a non-symmetric matrix
      54             :              cp_fm_gemm, & ! multiply two matrices
      55             :              cp_complex_fm_gemm, & ! multiply two complex matrices, represented by non_complex fm matrices
      56             :              cp_fm_invert, & ! computes the inverse and determinant
      57             :              cp_fm_frobenius_norm, & ! frobenius norm
      58             :              cp_fm_triangular_invert, & ! compute the reciprocal of a triangular matrix
      59             :              cp_fm_qr_factorization, & ! compute the QR factorization of a rectangular matrix
      60             :              cp_fm_solve, & ! solves the equation  A*B=C A and C are input
      61             :              cp_fm_pdgeqpf, & ! compute a QR factorization with column pivoting of a M-by-N distributed matrix
      62             :              cp_fm_pdorgqr, & ! generates an M-by-N as first N columns of a product of K elementary reflectors
      63             :              cp_fm_potrf, & ! Cholesky decomposition
      64             :              cp_fm_potri, & ! Invert triangular matrix
      65             :              cp_fm_rot_rows, & ! rotates two rows
      66             :              cp_fm_rot_cols, & ! rotates two columns
      67             :              cp_fm_cholesky_restore, & ! apply Cholesky decomposition
      68             :              cp_fm_Gram_Schmidt_orthonorm, & ! Gram-Schmidt orthonormalization of columns of a full matrix, &
      69             :              cp_fm_det ! determinant of a real matrix with correct sign
      70             : 
      71             :    REAL(KIND=dp), EXTERNAL :: dlange, pdlange, pdlatra
      72             :    REAL(KIND=sp), EXTERNAL :: slange, pslange, pslatra
      73             : 
      74             :    INTERFACE cp_fm_trace
      75             :       MODULE PROCEDURE cp_fm_trace_a0b0t0
      76             :       MODULE PROCEDURE cp_fm_trace_a1b0t1_a
      77             :       MODULE PROCEDURE cp_fm_trace_a1b0t1_p
      78             :       MODULE PROCEDURE cp_fm_trace_a1b1t1_aa
      79             :       MODULE PROCEDURE cp_fm_trace_a1b1t1_ap
      80             :       MODULE PROCEDURE cp_fm_trace_a1b1t1_pa
      81             :       MODULE PROCEDURE cp_fm_trace_a1b1t1_pp
      82             :    END INTERFACE cp_fm_trace
      83             : 
      84             :    INTERFACE cp_fm_contracted_trace
      85             :       MODULE PROCEDURE cp_fm_contracted_trace_a2b2t2_aa
      86             :       MODULE PROCEDURE cp_fm_contracted_trace_a2b2t2_ap
      87             :       MODULE PROCEDURE cp_fm_contracted_trace_a2b2t2_pa
      88             :       MODULE PROCEDURE cp_fm_contracted_trace_a2b2t2_pp
      89             :    END INTERFACE cp_fm_contracted_trace
      90             : CONTAINS
      91             : 
      92             : ! **************************************************************************************************
      93             : !> \brief Computes the determinant (with a correct sign even in parallel environment!) of a real square matrix
      94             : !> \author A. Sinyavskiy (andrey.sinyavskiy@chem.uzh.ch)
      95             : ! **************************************************************************************************
      96           0 :    SUBROUTINE cp_fm_det(matrix_a, det_a)
      97             : 
      98             :       TYPE(cp_fm_type), INTENT(IN)             :: matrix_a
      99             :       REAL(KIND=dp), INTENT(OUT)               :: det_a
     100             :       REAL(KIND=dp)                            :: determinant
     101             :       TYPE(cp_fm_type)                         :: matrix_lu
     102             :       REAL(KIND=dp), DIMENSION(:, :), POINTER  :: a
     103             :       INTEGER                                  :: n, i, info, P
     104           0 :       INTEGER, ALLOCATABLE, DIMENSION(:)       :: ipivot
     105             :       REAL(KIND=dp), DIMENSION(:), POINTER     :: diag
     106             : 
     107             : #if defined(__SCALAPACK)
     108             :       INTEGER, EXTERNAL                        :: indxl2g
     109             :       INTEGER                                  :: myprow, nprow, npcol, nrow_local, nrow_block, irow_local
     110             :       INTEGER, DIMENSION(9)                    :: desca
     111             : #endif
     112             : 
     113             :       CALL cp_fm_create(matrix=matrix_lu, &
     114             :                         matrix_struct=matrix_a%matrix_struct, &
     115           0 :                         name="A_lu"//TRIM(ADJUSTL(cp_to_string(1)))//"MATRIX")
     116           0 :       CALL cp_fm_to_fm(matrix_a, matrix_lu)
     117             : 
     118           0 :       a => matrix_lu%local_data
     119           0 :       n = matrix_lu%matrix_struct%nrow_global
     120           0 :       ALLOCATE (ipivot(n))
     121           0 :       ipivot(:) = 0
     122           0 :       P = 0
     123           0 :       ALLOCATE (diag(n))
     124           0 :       diag(:) = 0.0_dp
     125             : #if defined(__SCALAPACK)
     126             :       ! Use LU decomposition
     127           0 :       desca(:) = matrix_lu%matrix_struct%descriptor(:)
     128           0 :       CALL pdgetrf(n, n, a, 1, 1, desca, ipivot, info)
     129           0 :       CALL cp_fm_get_diag(matrix_lu, diag)
     130           0 :       determinant = PRODUCT(diag)
     131           0 :       myprow = matrix_lu%matrix_struct%context%mepos(1)
     132           0 :       nprow = matrix_lu%matrix_struct%context%num_pe(1)
     133           0 :       npcol = matrix_lu%matrix_struct%context%num_pe(2)
     134           0 :       nrow_local = matrix_lu%matrix_struct%nrow_locals(myprow)
     135           0 :       nrow_block = matrix_lu%matrix_struct%nrow_block
     136           0 :       DO irow_local = 1, nrow_local
     137           0 :          i = indxl2g(irow_local, nrow_block, myprow, matrix_lu%matrix_struct%first_p_pos(1), nprow)
     138           0 :          IF (ipivot(irow_local) /= i) P = P + 1
     139             :       END DO
     140           0 :       CALL matrix_lu%matrix_struct%para_env%sum(P)
     141             :       ! very important fix
     142           0 :       P = P/npcol
     143             : #else
     144             :       CALL dgetrf(n, n, a, n, ipivot, info)
     145             :       CALL cp_fm_get_diag(matrix_lu, diag)
     146             :       determinant = PRODUCT(diag)
     147             :       DO i = 1, n
     148             :          IF (ipivot(i) /= i) P = P + 1
     149             :       END DO
     150             : #endif
     151           0 :       DEALLOCATE (ipivot)
     152           0 :       DEALLOCATE (diag)
     153           0 :       CALL cp_fm_release(matrix_lu)
     154           0 :       det_a = determinant*(-2*MOD(P, 2) + 1.0_dp)
     155           0 :    END SUBROUTINE cp_fm_det
     156             : 
     157             : ! **************************************************************************************************
     158             : !> \brief calc A <- alpha*A + beta*B
     159             : !>      optimized for alpha == 1.0 (just add beta*B) and beta == 0.0 (just
     160             : !>      scale A)
     161             : !> \param alpha ...
     162             : !> \param matrix_a ...
     163             : !> \param beta ...
     164             : !> \param matrix_b ...
     165             : ! **************************************************************************************************
     166     1068494 :    SUBROUTINE cp_fm_scale_and_add(alpha, matrix_a, beta, matrix_b)
     167             : 
     168             :       REAL(KIND=dp), INTENT(IN)                          :: alpha
     169             :       TYPE(cp_fm_type), INTENT(IN)                       :: matrix_a
     170             :       REAL(KIND=dp), INTENT(in), OPTIONAL                :: beta
     171             :       TYPE(cp_fm_type), INTENT(IN), OPTIONAL             :: matrix_b
     172             : 
     173             :       CHARACTER(len=*), PARAMETER :: routineN = 'cp_fm_scale_and_add'
     174             : 
     175             :       INTEGER                                            :: handle, size_a, size_b
     176             :       REAL(KIND=dp)                                      :: my_beta
     177     1068494 :       REAL(KIND=dp), DIMENSION(:, :), POINTER            :: a, b
     178     1068494 :       REAL(KIND=sp), DIMENSION(:, :), POINTER            :: a_sp, b_sp
     179             : 
     180     1068494 :       CALL timeset(routineN, handle)
     181             : 
     182     1068494 :       IF (PRESENT(matrix_b)) THEN
     183     1068494 :          my_beta = 1.0_dp
     184             :       ELSE
     185           0 :          my_beta = 0.0_dp
     186             :       END IF
     187     1068494 :       IF (PRESENT(beta)) my_beta = beta
     188     1068494 :       NULLIFY (a, b)
     189             : 
     190     1068494 :       IF (PRESENT(beta)) THEN
     191     1068494 :          CPASSERT(PRESENT(matrix_b))
     192     1068494 :          IF (ASSOCIATED(matrix_a%local_data, matrix_b%local_data)) THEN
     193           0 :             CPWARN("Bad use of routine. Call cp_fm_scale instead")
     194           0 :             CALL cp_fm_scale(alpha + beta, matrix_a)
     195           0 :             CALL timestop(handle)
     196           0 :             RETURN
     197             :          END IF
     198             :       END IF
     199             : 
     200     1068494 :       a => matrix_a%local_data
     201     1068494 :       a_sp => matrix_a%local_data_sp
     202             : 
     203     1068494 :       IF (matrix_a%use_sp) THEN
     204           0 :          size_a = SIZE(a_sp, 1)*SIZE(a_sp, 2)
     205             :       ELSE
     206     1068494 :          size_a = SIZE(a, 1)*SIZE(a, 2)
     207             :       END IF
     208             : 
     209     1068494 :       IF (alpha /= 1.0_dp) THEN
     210       78502 :          IF (matrix_a%use_sp) THEN
     211           0 :             CALL sscal(size_a, REAL(alpha, sp), a_sp, 1)
     212             :          ELSE
     213       78502 :             CALL dscal(size_a, alpha, a, 1)
     214             :          END IF
     215             :       END IF
     216     1068494 :       IF (my_beta .NE. 0.0_dp) THEN
     217     1059500 :          IF (matrix_a%matrix_struct%context /= matrix_b%matrix_struct%context) &
     218           0 :             CPABORT("matrixes must be in the same blacs context")
     219             : 
     220     1059500 :          IF (cp_fm_struct_equivalent(matrix_a%matrix_struct, &
     221             :                                      matrix_b%matrix_struct)) THEN
     222             : 
     223     1059500 :             b => matrix_b%local_data
     224     1059500 :             b_sp => matrix_b%local_data_sp
     225     1059500 :             IF (matrix_b%use_sp) THEN
     226           0 :                size_b = SIZE(b_sp, 1)*SIZE(b_sp, 2)
     227             :             ELSE
     228     1059500 :                size_b = SIZE(b, 1)*SIZE(b, 2)
     229             :             END IF
     230     1059500 :             IF (size_a /= size_b) &
     231           0 :                CPABORT("Matrixes must have same locale sizes")
     232             : 
     233     1059500 :             IF (matrix_a%use_sp .AND. matrix_b%use_sp) THEN
     234           0 :                CALL saxpy(size_a, REAL(my_beta, sp), b_sp, 1, a_sp, 1)
     235     1059500 :             ELSEIF (matrix_a%use_sp .AND. .NOT. matrix_b%use_sp) THEN
     236           0 :                CALL saxpy(size_a, REAL(my_beta, sp), REAL(b, sp), 1, a_sp, 1)
     237     1059500 :             ELSEIF (.NOT. matrix_a%use_sp .AND. matrix_b%use_sp) THEN
     238           0 :                CALL daxpy(size_a, my_beta, REAL(b_sp, dp), 1, a, 1)
     239             :             ELSE
     240     1059500 :                CALL daxpy(size_a, my_beta, b, 1, a, 1)
     241             :             END IF
     242             : 
     243             :          ELSE
     244             : #ifdef __SCALAPACK
     245           0 :             CPABORT("to do (pdscal,pdcopy,pdaxpy)")
     246             : #else
     247             :             CPABORT("")
     248             : #endif
     249             :          END IF
     250             : 
     251             :       END IF
     252             : 
     253     1068494 :       CALL timestop(handle)
     254             : 
     255     1068494 :    END SUBROUTINE cp_fm_scale_and_add
     256             : 
     257             : ! **************************************************************************************************
     258             : !> \brief interface to BLACS geadd:
     259             : !>                matrix_b = beta*matrix_b + alpha*opt(matrix_a)
     260             : !>        where opt(matrix_a) can be either:
     261             : !>              'N':  matrix_a
     262             : !>              'T':  matrix_a^T
     263             : !>              'C':  matrix_a^H (Hermitian conjugate)
     264             : !>        note that this is a level three routine, use cp_fm_scale_and_add if that
     265             : !>        is sufficient for your needs
     266             : !> \param alpha  : complex scalar
     267             : !> \param trans  : 'N' normal, 'T' transposed
     268             : !> \param matrix_a : input matrix_a
     269             : !> \param beta   : complex scalar
     270             : !> \param matrix_b : input matrix_b, upon out put the updated matrix_b
     271             : !> \author  Lianheng Tong
     272             : ! **************************************************************************************************
     273          96 :    SUBROUTINE cp_fm_geadd(alpha, trans, matrix_a, beta, matrix_b)
     274             :       REAL(KIND=dp), INTENT(IN) :: alpha, beta
     275             :       CHARACTER, INTENT(IN) :: trans
     276             :       TYPE(cp_fm_type), INTENT(IN) :: matrix_a, matrix_b
     277             : 
     278             :       CHARACTER(len=*), PARAMETER :: routineN = 'cp_fm_geadd'
     279             : 
     280             :       INTEGER :: nrow_global, ncol_global, handle
     281             :       REAL(KIND=dp), DIMENSION(:, :), POINTER :: aa, bb
     282             : #if defined(__SCALAPACK)
     283             :       INTEGER, DIMENSION(9) :: desca, descb
     284             : #else
     285             :       INTEGER :: ii, jj
     286             : #endif
     287             : 
     288          96 :       CALL timeset(routineN, handle)
     289             : 
     290          96 :       nrow_global = matrix_a%matrix_struct%nrow_global
     291          96 :       ncol_global = matrix_a%matrix_struct%ncol_global
     292          96 :       CPASSERT(nrow_global .EQ. matrix_b%matrix_struct%nrow_global)
     293          96 :       CPASSERT(ncol_global .EQ. matrix_b%matrix_struct%ncol_global)
     294             : 
     295          96 :       aa => matrix_a%local_data
     296          96 :       bb => matrix_b%local_data
     297             : 
     298             : #if defined(__SCALAPACK)
     299         960 :       desca = matrix_a%matrix_struct%descriptor
     300         960 :       descb = matrix_b%matrix_struct%descriptor
     301             :       CALL pdgeadd(trans, &
     302             :                    nrow_global, &
     303             :                    ncol_global, &
     304             :                    alpha, &
     305             :                    aa, &
     306             :                    1, 1, &
     307             :                    desca, &
     308             :                    beta, &
     309             :                    bb, &
     310             :                    1, 1, &
     311          96 :                    descb)
     312             : #else
     313             :       ! dgeadd is not a standard BLAS function, although is implemented
     314             :       ! in some libraries like OpenBLAS, so not going to use it here
     315             :       SELECT CASE (trans)
     316             :       CASE ('T')
     317             :          DO jj = 1, ncol_global
     318             :             DO ii = 1, nrow_global
     319             :                bb(ii, jj) = beta*bb(ii, jj) + alpha*aa(jj, ii)
     320             :             END DO
     321             :          END DO
     322             :       CASE DEFAULT
     323             :          DO jj = 1, ncol_global
     324             :             DO ii = 1, nrow_global
     325             :                bb(ii, jj) = beta*bb(ii, jj) + alpha*aa(ii, jj)
     326             :             END DO
     327             :          END DO
     328             :       END SELECT
     329             : #endif
     330             : 
     331          96 :       CALL timestop(handle)
     332             : 
     333          96 :    END SUBROUTINE cp_fm_geadd
     334             : 
     335             : ! **************************************************************************************************
     336             : !> \brief Computes the LU-decomposition of the matrix, and the determinant of the matrix
     337             : !>      IMPORTANT : the sign of the determinant is not defined correctly yet ....
     338             : !> \param matrix_a ...
     339             : !> \param almost_determinant ...
     340             : !> \param correct_sign ...
     341             : !> \par History
     342             : !>      added correct_sign 02.07 (fschiff)
     343             : !> \author Joost VandeVondele
     344             : !> \note
     345             : !>      - matrix_a is overwritten
     346             : !>      - the sign of the determinant might be wrong
     347             : !>      - SERIOUS WARNING (KNOWN BUG) : the sign of the determinant depends on ipivot
     348             : !>      - one should be able to find out if ipivot is an even or an odd permutation...
     349             : !>        if you need the correct sign, just add correct_sign==.TRUE. (fschiff)
     350             : !>      - Use cp_fm_get_diag instead of n times cp_fm_get_element (A. Bussy)
     351             : ! **************************************************************************************************
     352           0 :    SUBROUTINE cp_fm_lu_decompose(matrix_a, almost_determinant, correct_sign)
     353             :       TYPE(cp_fm_type), INTENT(IN)          :: matrix_a
     354             :       REAL(KIND=dp), INTENT(OUT)               :: almost_determinant
     355             :       LOGICAL, INTENT(IN), OPTIONAL            :: correct_sign
     356             : 
     357             :       CHARACTER(len=*), PARAMETER :: routineN = 'cp_fm_lu_decompose'
     358             : 
     359             :       INTEGER                                  :: handle, i, info, n
     360           0 :       INTEGER, ALLOCATABLE, DIMENSION(:)       :: ipivot
     361             :       REAL(KIND=dp)                            :: determinant
     362             :       REAL(KIND=dp), DIMENSION(:, :), POINTER  :: a
     363             : #if defined(__SCALAPACK)
     364             :       INTEGER, DIMENSION(9)                    :: desca
     365           0 :       REAL(KIND=dp), DIMENSION(:), POINTER     :: diag
     366             : #else
     367             :       INTEGER                                  :: lda
     368             : #endif
     369             : 
     370           0 :       CALL timeset(routineN, handle)
     371             : 
     372           0 :       a => matrix_a%local_data
     373           0 :       n = matrix_a%matrix_struct%nrow_global
     374           0 :       ALLOCATE (ipivot(n + matrix_a%matrix_struct%nrow_block))
     375             : 
     376             : #if defined(__SCALAPACK)
     377             :       MARK_USED(correct_sign)
     378           0 :       desca(:) = matrix_a%matrix_struct%descriptor(:)
     379           0 :       CALL pdgetrf(n, n, a, 1, 1, desca, ipivot, info)
     380             : 
     381           0 :       ALLOCATE (diag(n))
     382           0 :       diag(:) = 0.0_dp
     383           0 :       CALL cp_fm_get_diag(matrix_a, diag)
     384           0 :       determinant = 1.0_dp
     385           0 :       DO i = 1, n
     386           0 :          determinant = determinant*diag(i)
     387             :       END DO
     388           0 :       DEALLOCATE (diag)
     389             : #else
     390             :       lda = SIZE(a, 1)
     391             :       CALL dgetrf(n, n, a, lda, ipivot, info)
     392             :       determinant = 1.0_dp
     393             :       IF (correct_sign) THEN
     394             :          DO i = 1, n
     395             :             IF (ipivot(i) .NE. i) THEN
     396             :                determinant = -determinant*a(i, i)
     397             :             ELSE
     398             :                determinant = determinant*a(i, i)
     399             :             END IF
     400             :          END DO
     401             :       ELSE
     402             :          DO i = 1, n
     403             :             determinant = determinant*a(i, i)
     404             :          END DO
     405             :       END IF
     406             : #endif
     407             :       ! info is allowed to be zero
     408             :       ! this does just signal a zero diagonal element
     409           0 :       DEALLOCATE (ipivot)
     410           0 :       almost_determinant = determinant ! notice that the sign is random
     411           0 :       CALL timestop(handle)
     412           0 :    END SUBROUTINE
     413             : 
     414             : ! **************************************************************************************************
     415             : !> \brief computes matrix_c = beta * matrix_c + alpha * ( matrix_a  ** transa ) * ( matrix_b ** transb )
     416             : !> \param transa : 'N' -> normal   'T' -> transpose
     417             : !>      alpha,beta :: can be 0.0_dp and 1.0_dp
     418             : !> \param transb ...
     419             : !> \param m ...
     420             : !> \param n ...
     421             : !> \param k ...
     422             : !> \param alpha ...
     423             : !> \param matrix_a : m x k matrix ( ! for transa = 'N')
     424             : !> \param matrix_b : k x n matrix ( ! for transb = 'N')
     425             : !> \param beta ...
     426             : !> \param matrix_c : m x n matrix
     427             : !> \param a_first_col ...
     428             : !> \param a_first_row ...
     429             : !> \param b_first_col : the k x n matrix starts at col b_first_col of matrix_b (avoid usage)
     430             : !> \param b_first_row ...
     431             : !> \param c_first_col ...
     432             : !> \param c_first_row ...
     433             : !> \author Matthias Krack
     434             : !> \note
     435             : !>      matrix_c should have no overlap with matrix_a, matrix_b
     436             : ! **************************************************************************************************
     437         514 :    SUBROUTINE cp_fm_gemm(transa, transb, m, n, k, alpha, matrix_a, matrix_b, beta, &
     438             :                          matrix_c, a_first_col, a_first_row, b_first_col, b_first_row, &
     439             :                          c_first_col, c_first_row)
     440             : 
     441             :       CHARACTER(LEN=1), INTENT(IN)             :: transa, transb
     442             :       INTEGER, INTENT(IN)                      :: m, n, k
     443             :       REAL(KIND=dp), INTENT(IN)                :: alpha
     444             :       TYPE(cp_fm_type), INTENT(IN)             :: matrix_a, matrix_b
     445             :       REAL(KIND=dp), INTENT(IN)                :: beta
     446             :       TYPE(cp_fm_type), INTENT(IN)          :: matrix_c
     447             :       INTEGER, INTENT(IN), OPTIONAL            :: a_first_col, a_first_row, &
     448             :                                                   b_first_col, b_first_row, &
     449             :                                                   c_first_col, c_first_row
     450             : 
     451             :       CHARACTER(len=*), PARAMETER :: routineN = 'cp_fm_gemm'
     452             : 
     453             :       INTEGER                                  :: handle, i_a, i_b, i_c, j_a, &
     454             :                                                   j_b, j_c
     455             :       REAL(KIND=dp), DIMENSION(:, :), POINTER  :: a, b, c
     456         514 :       REAL(KIND=sp), DIMENSION(:, :), POINTER  :: a_sp, b_sp, c_sp
     457             : #if defined(__SCALAPACK)
     458             :       INTEGER, DIMENSION(9)                    :: desca, descb, descc
     459             : #else
     460             :       INTEGER                                  :: lda, ldb, ldc
     461             : #endif
     462             : 
     463         514 :       CALL timeset(routineN, handle)
     464             : 
     465             :       !sample peak memory
     466         514 :       CALL m_memory()
     467             : 
     468         514 :       a => matrix_a%local_data
     469         514 :       b => matrix_b%local_data
     470         514 :       c => matrix_c%local_data
     471             : 
     472         514 :       a_sp => matrix_a%local_data_sp
     473         514 :       b_sp => matrix_b%local_data_sp
     474         514 :       c_sp => matrix_c%local_data_sp
     475             : 
     476         514 :       IF (PRESENT(a_first_row)) THEN
     477           0 :          i_a = a_first_row
     478             :       ELSE
     479         514 :          i_a = 1
     480             :       END IF
     481         514 :       IF (PRESENT(a_first_col)) THEN
     482           0 :          j_a = a_first_col
     483             :       ELSE
     484         514 :          j_a = 1
     485             :       END IF
     486         514 :       IF (PRESENT(b_first_row)) THEN
     487           0 :          i_b = b_first_row
     488             :       ELSE
     489         514 :          i_b = 1
     490             :       END IF
     491         514 :       IF (PRESENT(b_first_col)) THEN
     492           0 :          j_b = b_first_col
     493             :       ELSE
     494         514 :          j_b = 1
     495             :       END IF
     496         514 :       IF (PRESENT(c_first_row)) THEN
     497           0 :          i_c = c_first_row
     498             :       ELSE
     499         514 :          i_c = 1
     500             :       END IF
     501         514 :       IF (PRESENT(c_first_col)) THEN
     502           0 :          j_c = c_first_col
     503             :       ELSE
     504         514 :          j_c = 1
     505             :       END IF
     506             : 
     507             : #if defined(__SCALAPACK)
     508             : 
     509        5140 :       desca(:) = matrix_a%matrix_struct%descriptor(:)
     510        5140 :       descb(:) = matrix_b%matrix_struct%descriptor(:)
     511        5140 :       descc(:) = matrix_c%matrix_struct%descriptor(:)
     512             : 
     513         514 :       IF (matrix_a%use_sp .AND. matrix_b%use_sp .AND. matrix_c%use_sp) THEN
     514             : 
     515             :          CALL psgemm(transa, transb, m, n, k, REAL(alpha, sp), a_sp(1, 1), i_a, j_a, desca, b_sp(1, 1), i_b, j_b, &
     516           0 :                      descb, REAL(beta, sp), c_sp(1, 1), i_c, j_c, descc)
     517             : 
     518         514 :       ELSEIF ((.NOT. matrix_a%use_sp) .AND. (.NOT. matrix_b%use_sp) .AND. (.NOT. matrix_c%use_sp)) THEN
     519             : 
     520             :          CALL pdgemm(transa, transb, m, n, k, alpha, a, i_a, j_a, desca, b, i_b, j_b, &
     521         514 :                      descb, beta, c, i_c, j_c, descc)
     522             : 
     523             :       ELSE
     524           0 :          CPABORT("Mixed precision gemm NYI")
     525             :       END IF
     526             : #else
     527             : 
     528             :       IF (matrix_a%use_sp .AND. matrix_b%use_sp .AND. matrix_c%use_sp) THEN
     529             : 
     530             :          lda = SIZE(a_sp, 1)
     531             :          ldb = SIZE(b_sp, 1)
     532             :          ldc = SIZE(c_sp, 1)
     533             : 
     534             :          CALL sgemm(transa, transb, m, n, k, REAL(alpha, sp), a_sp(i_a, j_a), lda, b_sp(i_b, j_b), ldb, &
     535             :                     REAL(beta, sp), c_sp(i_c, j_c), ldc)
     536             : 
     537             :       ELSEIF ((.NOT. matrix_a%use_sp) .AND. (.NOT. matrix_b%use_sp) .AND. (.NOT. matrix_c%use_sp)) THEN
     538             : 
     539             :          lda = SIZE(a, 1)
     540             :          ldb = SIZE(b, 1)
     541             :          ldc = SIZE(c, 1)
     542             : 
     543             :          CALL dgemm(transa, transb, m, n, k, alpha, a(i_a, j_a), lda, b(i_b, j_b), ldb, beta, c(i_c, j_c), ldc)
     544             : 
     545             :       ELSE
     546             :          CPABORT("Mixed precision gemm NYI")
     547             :       END IF
     548             : 
     549             : #endif
     550         514 :       CALL timestop(handle)
     551             : 
     552         514 :    END SUBROUTINE cp_fm_gemm
     553             : 
     554             : ! **************************************************************************************************
     555             : !> \brief computes matrix_c = beta * matrix_c + alpha *  matrix_a  *  matrix_b
     556             : !>      computes matrix_c = beta * matrix_c + alpha *  matrix_b  *  matrix_a
     557             : !>      where matrix_a is symmetric
     558             : !> \param side : 'L' -> matrix_a is on the left 'R' -> matrix_a is on the right
     559             : !>      alpha,beta :: can be 0.0_dp and 1.0_dp
     560             : !> \param uplo ...
     561             : !> \param m ...
     562             : !> \param n ...
     563             : !> \param alpha ...
     564             : !> \param matrix_a : m x m matrix
     565             : !> \param matrix_b : m x n matrix
     566             : !> \param beta ...
     567             : !> \param matrix_c : m x n matrix
     568             : !> \author Matthias Krack
     569             : !> \note
     570             : !>      matrix_c should have no overlap with matrix_a, matrix_b
     571             : !>      all matrices in QS are upper triangular, so uplo should be 'U' always
     572             : !>      matrix_a is always an m x m matrix
     573             : !>      it is typically slower to do cp_fm_symm than cp_fm_gemm (especially in parallel easily 50 percent !)
     574             : ! **************************************************************************************************
     575      143848 :    SUBROUTINE cp_fm_symm(side, uplo, m, n, alpha, matrix_a, matrix_b, beta, matrix_c)
     576             : 
     577             :       CHARACTER(LEN=1), INTENT(IN)             :: side, uplo
     578             :       INTEGER, INTENT(IN)                      :: m, n
     579             :       REAL(KIND=dp), INTENT(IN)                :: alpha
     580             :       TYPE(cp_fm_type), INTENT(IN)                :: matrix_a, matrix_b
     581             :       REAL(KIND=dp), INTENT(IN)                :: beta
     582             :       TYPE(cp_fm_type), INTENT(IN)          :: matrix_c
     583             : 
     584             :       CHARACTER(len=*), PARAMETER :: routineN = 'cp_fm_symm'
     585             : 
     586             :       INTEGER                                  :: handle
     587      143848 :       REAL(KIND=dp), DIMENSION(:, :), POINTER  :: a, b, c
     588             : #if defined(__SCALAPACK)
     589             :       INTEGER, DIMENSION(9)                    :: desca, descb, descc
     590             : #else
     591             :       INTEGER                                  :: lda, ldb, ldc
     592             : #endif
     593             : 
     594      143848 :       CALL timeset(routineN, handle)
     595             : 
     596      143848 :       a => matrix_a%local_data
     597      143848 :       b => matrix_b%local_data
     598      143848 :       c => matrix_c%local_data
     599             : 
     600             : #if defined(__SCALAPACK)
     601             : 
     602     1438480 :       desca(:) = matrix_a%matrix_struct%descriptor(:)
     603     1438480 :       descb(:) = matrix_b%matrix_struct%descriptor(:)
     604     1438480 :       descc(:) = matrix_c%matrix_struct%descriptor(:)
     605             : 
     606      143848 :       CALL pdsymm(side, uplo, m, n, alpha, a(1, 1), 1, 1, desca, b(1, 1), 1, 1, descb, beta, c(1, 1), 1, 1, descc)
     607             : 
     608             : #else
     609             : 
     610             :       lda = matrix_a%matrix_struct%local_leading_dimension
     611             :       ldb = matrix_b%matrix_struct%local_leading_dimension
     612             :       ldc = matrix_c%matrix_struct%local_leading_dimension
     613             : 
     614             :       CALL dsymm(side, uplo, m, n, alpha, a(1, 1), lda, b(1, 1), ldb, beta, c(1, 1), ldc)
     615             : 
     616             : #endif
     617      143848 :       CALL timestop(handle)
     618             : 
     619      143848 :    END SUBROUTINE cp_fm_symm
     620             : 
     621             : ! **************************************************************************************************
     622             : !> \brief computes the Frobenius norm of matrix_a
     623             : !> \brief computes the Frobenius norm of matrix_a
     624             : !> \param matrix_a : m x n matrix
     625             : !> \return ...
     626             : !> \author VW
     627             : ! **************************************************************************************************
     628        8030 :    FUNCTION cp_fm_frobenius_norm(matrix_a) RESULT(norm)
     629             :       TYPE(cp_fm_type), INTENT(IN)             :: matrix_a
     630             :       REAL(KIND=dp)                            :: norm
     631             : 
     632             :       CHARACTER(len=*), PARAMETER :: routineN = 'cp_fm_frobenius_norm'
     633             : 
     634             :       INTEGER                                  :: handle, size_a
     635        8030 :       REAL(KIND=dp), DIMENSION(:, :), POINTER  :: a
     636             :       REAL(KIND=dp), EXTERNAL                  :: DDOT
     637             : #if defined(__SCALAPACK)
     638             :       TYPE(mp_comm_type)                       :: group
     639             : #endif
     640             : 
     641        8030 :       CALL timeset(routineN, handle)
     642             : 
     643             :       norm = 0.0_dp
     644        8030 :       a => matrix_a%local_data
     645        8030 :       size_a = SIZE(a, 1)*SIZE(a, 2)
     646        8030 :       norm = DDOT(size_a, a(1, 1), 1, a(1, 1), 1)
     647             : #if defined(__SCALAPACK)
     648        8030 :       group = matrix_a%matrix_struct%para_env
     649        8030 :       CALL group%sum(norm)
     650             : #endif
     651        8030 :       norm = SQRT(norm)
     652             : 
     653        8030 :       CALL timestop(handle)
     654             : 
     655        8030 :    END FUNCTION cp_fm_frobenius_norm
     656             : 
     657             : ! **************************************************************************************************
     658             : !> \brief performs a rank-k update of a symmetric matrix_c
     659             : !>         matrix_c = beta * matrix_c + alpha * matrix_a * transpose ( matrix_a )
     660             : !> \param uplo : 'U'   ('L')
     661             : !> \param trans : 'N'  ('T')
     662             : !> \param k : number of cols to use in matrix_a
     663             : !>      ia,ja ::  1,1 (could be used for selecting subblock of a)
     664             : !> \param alpha ...
     665             : !> \param matrix_a ...
     666             : !> \param ia ...
     667             : !> \param ja ...
     668             : !> \param beta ...
     669             : !> \param matrix_c ...
     670             : !> \author Matthias Krack
     671             : !> \note
     672             : !>      In QS uplo should 'U' (upper part updated)
     673             : ! **************************************************************************************************
     674        6294 :    SUBROUTINE cp_fm_syrk(uplo, trans, k, alpha, matrix_a, ia, ja, beta, matrix_c)
     675             :       CHARACTER(LEN=1), INTENT(IN)             :: uplo, trans
     676             :       INTEGER, INTENT(IN)                      :: k
     677             :       REAL(KIND=dp), INTENT(IN)                :: alpha
     678             :       TYPE(cp_fm_type), INTENT(IN)             :: matrix_a
     679             :       INTEGER, INTENT(IN)                      :: ia, ja
     680             :       REAL(KIND=dp), INTENT(IN)                :: beta
     681             :       TYPE(cp_fm_type), INTENT(IN)          :: matrix_c
     682             : 
     683             :       CHARACTER(len=*), PARAMETER :: routineN = 'cp_fm_syrk'
     684             : 
     685             :       INTEGER                                  :: handle, n
     686        6294 :       REAL(KIND=dp), DIMENSION(:, :), POINTER  :: a, c
     687             : #if defined(__SCALAPACK)
     688             :       INTEGER, DIMENSION(9)                    :: desca, descc
     689             : #else
     690             :       INTEGER                                  :: lda, ldc
     691             : #endif
     692             : 
     693        6294 :       CALL timeset(routineN, handle)
     694             : 
     695        6294 :       n = matrix_c%matrix_struct%nrow_global
     696             : 
     697        6294 :       a => matrix_a%local_data
     698        6294 :       c => matrix_c%local_data
     699             : 
     700             : #if defined(__SCALAPACK)
     701             : 
     702       62940 :       desca(:) = matrix_a%matrix_struct%descriptor(:)
     703       62940 :       descc(:) = matrix_c%matrix_struct%descriptor(:)
     704             : 
     705        6294 :       CALL pdsyrk(uplo, trans, n, k, alpha, a(1, 1), ia, ja, desca, beta, c(1, 1), 1, 1, descc)
     706             : 
     707             : #else
     708             : 
     709             :       lda = SIZE(a, 1)
     710             :       ldc = SIZE(c, 1)
     711             : 
     712             :       CALL dsyrk(uplo, trans, n, k, alpha, a(ia, ja), lda, beta, c(1, 1), ldc)
     713             : 
     714             : #endif
     715        6294 :       CALL timestop(handle)
     716             : 
     717        6294 :    END SUBROUTINE cp_fm_syrk
     718             : 
     719             : ! **************************************************************************************************
     720             : !> \brief computes the schur product of two matrices
     721             : !>       c_ij = a_ij * b_ij
     722             : !> \param matrix_a ...
     723             : !> \param matrix_b ...
     724             : !> \param matrix_c ...
     725             : !> \author Joost VandeVondele
     726             : ! **************************************************************************************************
     727        9190 :    SUBROUTINE cp_fm_schur_product(matrix_a, matrix_b, matrix_c)
     728             : 
     729             :       TYPE(cp_fm_type), INTENT(IN)                       :: matrix_a, matrix_b, matrix_c
     730             : 
     731             :       CHARACTER(len=*), PARAMETER :: routineN = 'cp_fm_schur_product'
     732             : 
     733             :       INTEGER                                            :: handle, icol_local, irow_local, mypcol, &
     734             :                                                             myprow, ncol_local, npcol, nprow, &
     735             :                                                             nrow_local
     736        9190 :       REAL(KIND=dp), DIMENSION(:, :), POINTER            :: a, b, c
     737             :       TYPE(cp_blacs_env_type), POINTER                   :: context
     738             : 
     739        9190 :       CALL timeset(routineN, handle)
     740             : 
     741        9190 :       context => matrix_a%matrix_struct%context
     742        9190 :       myprow = context%mepos(1)
     743        9190 :       mypcol = context%mepos(2)
     744        9190 :       nprow = context%num_pe(1)
     745        9190 :       npcol = context%num_pe(2)
     746             : 
     747        9190 :       a => matrix_a%local_data
     748        9190 :       b => matrix_b%local_data
     749        9190 :       c => matrix_c%local_data
     750             : 
     751        9190 :       nrow_local = matrix_a%matrix_struct%nrow_locals(myprow)
     752        9190 :       ncol_local = matrix_a%matrix_struct%ncol_locals(mypcol)
     753             : 
     754       99952 :       DO icol_local = 1, ncol_local
     755     6860227 :          DO irow_local = 1, nrow_local
     756     6851037 :             c(irow_local, icol_local) = a(irow_local, icol_local)*b(irow_local, icol_local)
     757             :          END DO
     758             :       END DO
     759             : 
     760        9190 :       CALL timestop(handle)
     761             : 
     762        9190 :    END SUBROUTINE cp_fm_schur_product
     763             : 
     764             : ! **************************************************************************************************
     765             : !> \brief returns the trace of matrix_a^T matrix_b, i.e
     766             : !>      sum_{i,j}(matrix_a(i,j)*matrix_b(i,j))
     767             : !> \param matrix_a a matrix
     768             : !> \param matrix_b another matrix
     769             : !> \param trace ...
     770             : !> \par History
     771             : !>      11.06.2001 Creation (Matthias Krack)
     772             : !>      12.2002 added doc [fawzi]
     773             : !> \author Matthias Krack
     774             : !> \note
     775             : !>      note the transposition of matrix_a!
     776             : ! **************************************************************************************************
     777      685048 :    SUBROUTINE cp_fm_trace_a0b0t0(matrix_a, matrix_b, trace)
     778             : 
     779             :       TYPE(cp_fm_type), INTENT(IN)                       :: matrix_a, matrix_b
     780             :       REAL(KIND=dp), INTENT(OUT)                         :: trace
     781             : 
     782             :       CHARACTER(len=*), PARAMETER :: routineN = 'cp_fm_trace_a0b0t0'
     783             : 
     784             :       INTEGER                                            :: handle, mypcol, myprow, ncol_local, &
     785             :                                                             npcol, nprow, nrow_local
     786      685048 :       REAL(KIND=dp), DIMENSION(:, :), POINTER            :: a, b
     787      685048 :       REAL(KIND=sp), DIMENSION(:, :), POINTER            :: a_sp, b_sp
     788             :       TYPE(cp_blacs_env_type), POINTER                   :: context
     789             :       TYPE(mp_comm_type)                                 :: group
     790             : 
     791      685048 :       CALL timeset(routineN, handle)
     792             : 
     793      685048 :       context => matrix_a%matrix_struct%context
     794      685048 :       myprow = context%mepos(1)
     795      685048 :       mypcol = context%mepos(2)
     796      685048 :       nprow = context%num_pe(1)
     797      685048 :       npcol = context%num_pe(2)
     798             : 
     799      685048 :       group = matrix_a%matrix_struct%para_env
     800             : 
     801      685048 :       a => matrix_a%local_data
     802      685048 :       b => matrix_b%local_data
     803             : 
     804      685048 :       a_sp => matrix_a%local_data_sp
     805      685048 :       b_sp => matrix_b%local_data_sp
     806             : 
     807      685048 :       nrow_local = MIN(matrix_a%matrix_struct%nrow_locals(myprow), matrix_b%matrix_struct%nrow_locals(myprow))
     808      685048 :       ncol_local = MIN(matrix_a%matrix_struct%ncol_locals(mypcol), matrix_b%matrix_struct%ncol_locals(mypcol))
     809             : 
     810             :       ! cries for an accurate_dot_product
     811      685048 :       IF (matrix_a%use_sp .AND. matrix_b%use_sp) THEN
     812             :          trace = accurate_sum(REAL(a_sp(1:nrow_local, 1:ncol_local)* &
     813           0 :                                    b_sp(1:nrow_local, 1:ncol_local), dp))
     814      685048 :       ELSEIF (matrix_a%use_sp .AND. .NOT. matrix_b%use_sp) THEN
     815             :          trace = accurate_sum(REAL(a_sp(1:nrow_local, 1:ncol_local), dp)* &
     816           0 :                               b(1:nrow_local, 1:ncol_local))
     817      685048 :       ELSEIF (.NOT. matrix_a%use_sp .AND. matrix_b%use_sp) THEN
     818             :          trace = accurate_sum(a(1:nrow_local, 1:ncol_local)* &
     819           0 :                               REAL(b_sp(1:nrow_local, 1:ncol_local), dp))
     820             :       ELSE
     821             :          trace = accurate_dot_product(a(1:nrow_local, 1:ncol_local), &
     822      685048 :                                       b(1:nrow_local, 1:ncol_local))
     823             :       END IF
     824             : 
     825      685048 :       CALL group%sum(trace)
     826             : 
     827      685048 :       CALL timestop(handle)
     828             : 
     829      685048 :    END SUBROUTINE cp_fm_trace_a0b0t0
     830             : 
     831             :    #:mute
     832             :       #:set types = [("cp_fm_type", "a", ""), ("cp_fm_p_type", "p","%matrix")]
     833             :    #:endmute
     834             : 
     835             : ! **************************************************************************************************
     836             : !> \brief Compute trace(k) = Tr (matrix_a(k)^T matrix_b) for each pair of matrices A_k and B.
     837             : !> \param matrix_a list of A matrices
     838             : !> \param matrix_b B matrix
     839             : !> \param trace    computed traces
     840             : !> \par History
     841             : !>    * 08.2018 forked from cp_fm_trace() [Sergey Chulkov]
     842             : !> \note \parblock
     843             : !>      Computing the trace requires collective communication between involved MPI processes
     844             : !>      that implies a synchronisation point between them. The aim of this subroutine is to reduce
     845             : !>      the amount of time wasted in such synchronisation by performing one large collective
     846             : !>      operation which involves all the matrices in question.
     847             : !>
     848             : !>      The subroutine's suffix reflects dimensionality of dummy arrays; 'a1b0t1' means that
     849             : !>      the dummy variables 'matrix_a' and 'trace' are 1-dimensional arrays, while the variable
     850             : !>      'matrix_b' is a single matrix.
     851             : !>      \endparblock
     852             : ! **************************************************************************************************
     853             :    #:for longname, shortname, appendix in types
     854        3030 :       SUBROUTINE cp_fm_trace_a1b0t1_${shortname}$ (matrix_a, matrix_b, trace)
     855             :          TYPE(${longname}$), DIMENSION(:), INTENT(in)       :: matrix_a
     856             :          TYPE(cp_fm_type), INTENT(IN)                       :: matrix_b
     857             :          REAL(kind=dp), DIMENSION(:), INTENT(out)           :: trace
     858             : 
     859             :          CHARACTER(len=*), PARAMETER :: routineN = 'cp_fm_trace_a1b0t1_${shortname}$'
     860             : 
     861             :          INTEGER                                            :: handle, imatrix, n_matrices, &
     862             :                                                                ncols_local, nrows_local
     863             :          LOGICAL                                            :: use_sp_a, use_sp_b
     864        3030 :          REAL(kind=dp), DIMENSION(:, :), POINTER            :: ldata_a, ldata_b
     865        3030 :          REAL(kind=sp), DIMENSION(:, :), POINTER            :: ldata_a_sp, ldata_b_sp
     866             :          TYPE(mp_comm_type)                                 :: group
     867             : 
     868        3030 :          CALL timeset(routineN, handle)
     869             : 
     870        3030 :          n_matrices = SIZE(trace)
     871        3030 :          CPASSERT(SIZE(matrix_a) == n_matrices)
     872             : 
     873        3030 :          CALL cp_fm_get_info(matrix_b, nrow_local=nrows_local, ncol_local=ncols_local)
     874        3030 :          use_sp_b = matrix_b%use_sp
     875             : 
     876        3030 :          IF (use_sp_b) THEN
     877           0 :             ldata_b_sp => matrix_b%local_data_sp(1:nrows_local, 1:ncols_local)
     878             :          ELSE
     879        3030 :             ldata_b => matrix_b%local_data(1:nrows_local, 1:ncols_local)
     880             :          END IF
     881             : 
     882             : !$OMP PARALLEL DO DEFAULT(NONE), &
     883             : !$OMP             PRIVATE(imatrix, ldata_a, ldata_a_sp, use_sp_a), &
     884             : !$OMP             SHARED(ldata_b, ldata_b_sp, matrix_a, matrix_b), &
     885        3030 : !$OMP             SHARED(ncols_local, nrows_local, n_matrices, trace, use_sp_b)
     886             : 
     887             :          DO imatrix = 1, n_matrices
     888             : 
     889             :             use_sp_a = matrix_a(imatrix) ${appendix}$%use_sp
     890             : 
     891             :             ! assume that the matrices A(i) and B have identical shapes and distribution schemes
     892             :             IF (use_sp_a .AND. use_sp_b) THEN
     893             :                ldata_a_sp => matrix_a(imatrix) ${appendix}$%local_data_sp(1:nrows_local, 1:ncols_local)
     894             :                trace(imatrix) = accurate_dot_product(ldata_a_sp, ldata_b_sp)
     895             :             ELSE IF (.NOT. use_sp_a .AND. .NOT. use_sp_b) THEN
     896             :                ldata_a => matrix_a(imatrix) ${appendix}$%local_data(1:nrows_local, 1:ncols_local)
     897             :                trace(imatrix) = accurate_dot_product(ldata_a, ldata_b)
     898             :             ELSE
     899             :                CPABORT("Matrices A and B are of different types")
     900             :             END IF
     901             :          END DO
     902             : !$OMP END PARALLEL DO
     903             : 
     904        3030 :          group = matrix_b%matrix_struct%para_env
     905       18882 :          CALL group%sum(trace)
     906             : 
     907        3030 :          CALL timestop(handle)
     908        3030 :       END SUBROUTINE cp_fm_trace_a1b0t1_${shortname}$
     909             :    #:endfor
     910             : 
     911             : ! **************************************************************************************************
     912             : !> \brief Compute trace(k) = Tr (matrix_a(k)^T matrix_b(k)) for each pair of matrices A_k and B_k.
     913             : !> \param matrix_a list of A matrices
     914             : !> \param matrix_b list of B matrices
     915             : !> \param trace    computed traces
     916             : !> \param accurate ...
     917             : !> \par History
     918             : !>    * 11.2016 forked from cp_fm_trace() [Sergey Chulkov]
     919             : !> \note \parblock
     920             : !>      Computing the trace requires collective communication between involved MPI processes
     921             : !>      that implies a synchronisation point between them. The aim of this subroutine is to reduce
     922             : !>      the amount of time wasted in such synchronisation by performing one large collective
     923             : !>      operation which involves all the matrices in question.
     924             : !>
     925             : !>      The subroutine's suffix reflects dimensionality of dummy arrays; 'a1b1t1' means that
     926             : !>      all dummy variables (matrix_a, matrix_b, and trace) are 1-dimensional arrays.
     927             : !>      \endparblock
     928             : ! **************************************************************************************************
     929             :    #:for longname1, shortname1, appendix1 in types
     930             :       #:for longname2, shortname2, appendix2 in types
     931      138548 :          SUBROUTINE cp_fm_trace_a1b1t1_${shortname1}$${shortname2}$ (matrix_a, matrix_b, trace, accurate)
     932             :             TYPE(${longname1}$), DIMENSION(:), INTENT(in)       :: matrix_a
     933             :             TYPE(${longname2}$), DIMENSION(:), INTENT(in)       :: matrix_b
     934             :             REAL(kind=dp), DIMENSION(:), INTENT(out)           :: trace
     935             :             LOGICAL, INTENT(IN), OPTIONAL                      :: accurate
     936             : 
     937             :             CHARACTER(len=*), PARAMETER :: routineN = 'cp_fm_trace_a1b1t1_${shortname1}$${shortname2}$'
     938             : 
     939             :             INTEGER                                            :: handle, imatrix, n_matrices, &
     940             :                                                                   ncols_local, nrows_local
     941             :             LOGICAL                                            :: use_accurate_sum, use_sp_a, use_sp_b
     942      138548 :             REAL(kind=dp), DIMENSION(:, :), POINTER            :: ldata_a, ldata_b
     943      138548 :             REAL(kind=sp), DIMENSION(:, :), POINTER            :: ldata_a_sp, ldata_b_sp
     944             :             TYPE(mp_comm_type)                                 :: group
     945             : 
     946      138548 :             CALL timeset(routineN, handle)
     947             : 
     948      138548 :             n_matrices = SIZE(trace)
     949      138548 :             CPASSERT(SIZE(matrix_a) == n_matrices)
     950      138548 :             CPASSERT(SIZE(matrix_b) == n_matrices)
     951             : 
     952      138548 :             use_accurate_sum = .TRUE.
     953      138548 :             IF (PRESENT(accurate)) use_accurate_sum = accurate
     954             : 
     955             : !$OMP PARALLEL DO DEFAULT(NONE), &
     956             : !$OMP             PRIVATE(imatrix, ldata_a, ldata_a_sp, ldata_b, ldata_b_sp, ncols_local), &
     957             : !$OMP             PRIVATE(nrows_local, use_sp_a, use_sp_b), &
     958      138548 : !$OMP             SHARED(matrix_a, matrix_b, n_matrices, trace, use_accurate_sum)
     959             :             DO imatrix = 1, n_matrices
     960             :                CALL cp_fm_get_info(matrix_a(imatrix) ${appendix1}$, nrow_local=nrows_local, ncol_local=ncols_local)
     961             : 
     962             :                use_sp_a = matrix_a(imatrix) ${appendix1}$%use_sp
     963             :                use_sp_b = matrix_b(imatrix) ${appendix2}$%use_sp
     964             : 
     965             :                ! assume that the matrices A(i) and B(i) have identical shapes and distribution schemes
     966             :                IF (use_sp_a .AND. use_sp_b) THEN
     967             :                   ldata_a_sp => matrix_a(imatrix) ${appendix1}$%local_data_sp(1:nrows_local, 1:ncols_local)
     968             :                   ldata_b_sp => matrix_b(imatrix) ${appendix2}$%local_data_sp(1:nrows_local, 1:ncols_local)
     969             :                   IF (use_accurate_sum) THEN
     970             :                      trace(imatrix) = accurate_dot_product(ldata_a_sp, ldata_b_sp)
     971             :                   ELSE
     972             :                      trace(imatrix) = SUM(ldata_a_sp*ldata_b_sp)
     973             :                   END IF
     974             :                ELSE IF (.NOT. use_sp_a .AND. .NOT. use_sp_b) THEN
     975             :                   ldata_a => matrix_a(imatrix) ${appendix1}$%local_data(1:nrows_local, 1:ncols_local)
     976             :                   ldata_b => matrix_b(imatrix) ${appendix2}$%local_data(1:nrows_local, 1:ncols_local)
     977             :                   IF (use_accurate_sum) THEN
     978             :                      trace(imatrix) = accurate_dot_product(ldata_a, ldata_b)
     979             :                   ELSE
     980             :                      trace(imatrix) = SUM(ldata_a*ldata_b)
     981             :                   END IF
     982             :                ELSE
     983             :                   CPABORT("Matrices A and B are of different types")
     984             :                END IF
     985             :             END DO
     986             : !$OMP END PARALLEL DO
     987             : 
     988      138548 :             group = matrix_a(1) ${appendix1}$%matrix_struct%para_env
     989      460124 :             CALL group%sum(trace)
     990             : 
     991      138548 :             CALL timestop(handle)
     992      138548 :          END SUBROUTINE cp_fm_trace_a1b1t1_${shortname1}$${shortname2}$
     993             :       #:endfor
     994             :    #:endfor
     995             : 
     996             : ! **************************************************************************************************
     997             : !> \brief Compute trace(i,j) = \sum_k Tr (matrix_a(k,i)^T matrix_b(k,j)).
     998             : !> \param matrix_a list of A matrices
     999             : !> \param matrix_b list of B matrices
    1000             : !> \param trace    computed traces
    1001             : !> \param accurate ...
    1002             : ! **************************************************************************************************
    1003             :    #:for longname1, shortname1, appendix1 in types
    1004             :       #:for longname2, shortname2, appendix2 in types
    1005       13804 :          SUBROUTINE cp_fm_contracted_trace_a2b2t2_${shortname1}$${shortname2}$ (matrix_a, matrix_b, trace, accurate)
    1006             :             TYPE(${longname1}$), DIMENSION(:, :), INTENT(in)       :: matrix_a
    1007             :             TYPE(${longname2}$), DIMENSION(:, :), INTENT(in)       :: matrix_b
    1008             :             REAL(kind=dp), DIMENSION(:, :), INTENT(out)        :: trace
    1009             :             LOGICAL, INTENT(IN), OPTIONAL                      :: accurate
    1010             : 
    1011             :             CHARACTER(len=*), PARAMETER :: routineN = 'cp_fm_contracted_trace_a2b2t2_${shortname1}$${shortname2}$'
    1012             : 
    1013             :             INTEGER                                            :: handle, ia, ib, iz, na, nb, ncols_local, &
    1014             :                                                                   nrows_local, nz
    1015             :             INTEGER(kind=int_8)                                :: ib8, itrace, na8, ntraces
    1016             :             LOGICAL                                            :: use_accurate_sum, use_sp_a, use_sp_b
    1017             :             REAL(kind=dp)                                      :: t
    1018       13804 :             REAL(kind=dp), DIMENSION(:, :), POINTER            :: ldata_a, ldata_b
    1019       13804 :             REAL(kind=sp), DIMENSION(:, :), POINTER            :: ldata_a_sp, ldata_b_sp
    1020             :             TYPE(mp_comm_type)                                 :: group
    1021             : 
    1022       13804 :             CALL timeset(routineN, handle)
    1023             : 
    1024       13804 :             nz = SIZE(matrix_a, 1)
    1025       13804 :             CPASSERT(SIZE(matrix_b, 1) == nz)
    1026             : 
    1027       13804 :             na = SIZE(matrix_a, 2)
    1028       13804 :             nb = SIZE(matrix_b, 2)
    1029       13804 :             CPASSERT(SIZE(trace, 1) == na)
    1030       13804 :             CPASSERT(SIZE(trace, 2) == nb)
    1031             : 
    1032       13804 :             use_accurate_sum = .TRUE.
    1033       13804 :             IF (PRESENT(accurate)) use_accurate_sum = accurate
    1034             : 
    1035             :             ! here we use one running index (itrace) instead of two (ia, ib) in order to
    1036             :             ! improve load balance between shared-memory threads
    1037       13804 :             ntraces = na*nb
    1038       13804 :             na8 = INT(na, kind=int_8)
    1039             : 
    1040             : !$OMP PARALLEL DO DEFAULT(NONE), &
    1041             : !$OMP             PRIVATE(ia, ib, ib8, itrace, iz, ldata_a, ldata_a_sp, ldata_b, ldata_b_sp, ncols_local), &
    1042             : !$OMP             PRIVATE(nrows_local, t, use_sp_a, use_sp_b), &
    1043       13804 : !$OMP             SHARED(matrix_a, matrix_b, na, na8, nb, ntraces, nz, trace, use_accurate_sum)
    1044             :             DO itrace = 1, ntraces
    1045             :                ib8 = (itrace - 1)/na8
    1046             :                ia = INT(itrace - ib8*na8)
    1047             :                ib = INT(ib8) + 1
    1048             : 
    1049             :                t = 0.0_dp
    1050             :                DO iz = 1, nz
    1051             :                   CALL cp_fm_get_info(matrix_a(iz, ia) ${appendix1}$, nrow_local=nrows_local, ncol_local=ncols_local)
    1052             :                   use_sp_a = matrix_a(iz, ia) ${appendix1}$%use_sp
    1053             :                   use_sp_b = matrix_b(iz, ib) ${appendix2}$%use_sp
    1054             : 
    1055             :                   ! assume that the matrices A(iz, ia) and B(iz, ib) have identical shapes and distribution schemes
    1056             :                   IF (.NOT. use_sp_a .AND. .NOT. use_sp_b) THEN
    1057             :                      ldata_a => matrix_a(iz, ia) ${appendix1}$%local_data(1:nrows_local, 1:ncols_local)
    1058             :                      ldata_b => matrix_b(iz, ib) ${appendix2}$%local_data(1:nrows_local, 1:ncols_local)
    1059             :                      IF (use_accurate_sum) THEN
    1060             :                         t = t + accurate_dot_product(ldata_a, ldata_b)
    1061             :                      ELSE
    1062             :                         t = t + SUM(ldata_a*ldata_b)
    1063             :                      END IF
    1064             :                   ELSE IF (use_sp_a .AND. use_sp_b) THEN
    1065             :                      ldata_a_sp => matrix_a(iz, ia) ${appendix1}$%local_data_sp(1:nrows_local, 1:ncols_local)
    1066             :                      ldata_b_sp => matrix_b(iz, ib) ${appendix2}$%local_data_sp(1:nrows_local, 1:ncols_local)
    1067             :                      IF (use_accurate_sum) THEN
    1068             :                         t = t + accurate_dot_product(ldata_a_sp, ldata_b_sp)
    1069             :                      ELSE
    1070             :                         t = t + SUM(ldata_a_sp*ldata_b_sp)
    1071             :                      END IF
    1072             :                   ELSE
    1073             :                      CPABORT("Matrices A and B are of different types")
    1074             :                   END IF
    1075             :                END DO
    1076             :                trace(ia, ib) = t
    1077             :             END DO
    1078             : !$OMP END PARALLEL DO
    1079             : 
    1080       13804 :             group = matrix_a(1, 1) ${appendix1}$%matrix_struct%para_env
    1081      616984 :             CALL group%sum(trace)
    1082             : 
    1083       13804 :             CALL timestop(handle)
    1084       13804 :          END SUBROUTINE cp_fm_contracted_trace_a2b2t2_${shortname1}$${shortname2}$
    1085             :       #:endfor
    1086             :    #:endfor
    1087             : 
    1088             : ! **************************************************************************************************
    1089             : !> \brief multiplies in place by a triangular matrix:
    1090             : !>       matrix_b = alpha op(triangular_matrix) matrix_b
    1091             : !>      or (if side='R')
    1092             : !>       matrix_b = alpha matrix_b op(triangular_matrix)
    1093             : !>      op(triangular_matrix) is:
    1094             : !>       triangular_matrix (if transpose_tr=.false. and invert_tr=.false.)
    1095             : !>       triangular_matrix^T (if transpose_tr=.true. and invert_tr=.false.)
    1096             : !>       triangular_matrix^(-1) (if transpose_tr=.false. and invert_tr=.true.)
    1097             : !>       triangular_matrix^(-T) (if transpose_tr=.true. and invert_tr=.true.)
    1098             : !> \param triangular_matrix the triangular matrix that multiplies the other
    1099             : !> \param matrix_b the matrix that gets multiplied and stores the result
    1100             : !> \param side on which side of matrix_b stays op(triangular_matrix)
    1101             : !>        (defaults to 'L')
    1102             : !> \param transpose_tr if the triangular matrix should be transposed
    1103             : !>        (defaults to false)
    1104             : !> \param invert_tr if the triangular matrix should be inverted
    1105             : !>        (defaults to false)
    1106             : !> \param uplo_tr if triangular_matrix is stored in the upper ('U') or
    1107             : !>        lower ('L') triangle (defaults to 'U')
    1108             : !> \param unit_diag_tr if the diagonal elements of triangular_matrix should
    1109             : !>        be assumed to be 1 (defaults to false)
    1110             : !> \param n_rows the number of rows of the result (defaults to
    1111             : !>        size(matrix_b,1))
    1112             : !> \param n_cols the number of columns of the result (defaults to
    1113             : !>        size(matrix_b,2))
    1114             : !> \param alpha ...
    1115             : !> \par History
    1116             : !>      08.2002 created [fawzi]
    1117             : !> \author Fawzi Mohamed
    1118             : !> \note
    1119             : !>      needs an mpi env
    1120             : ! **************************************************************************************************
    1121      101426 :    SUBROUTINE cp_fm_triangular_multiply(triangular_matrix, matrix_b, side, &
    1122             :                                         transpose_tr, invert_tr, uplo_tr, unit_diag_tr, n_rows, n_cols, &
    1123             :                                         alpha)
    1124             :       TYPE(cp_fm_type), INTENT(IN)                       :: triangular_matrix, matrix_b
    1125             :       CHARACTER, INTENT(in), OPTIONAL                    :: side
    1126             :       LOGICAL, INTENT(in), OPTIONAL                      :: transpose_tr, invert_tr
    1127             :       CHARACTER, INTENT(in), OPTIONAL                    :: uplo_tr
    1128             :       LOGICAL, INTENT(in), OPTIONAL                      :: unit_diag_tr
    1129             :       INTEGER, INTENT(in), OPTIONAL                      :: n_rows, n_cols
    1130             :       REAL(KIND=dp), INTENT(in), OPTIONAL                :: alpha
    1131             : 
    1132             :       CHARACTER(len=*), PARAMETER :: routineN = 'cp_fm_triangular_multiply'
    1133             : 
    1134             :       CHARACTER                                          :: side_char, transa, unit_diag, uplo
    1135             :       INTEGER                                            :: handle, m, n
    1136             :       LOGICAL                                            :: invert
    1137             :       REAL(KIND=dp)                                      :: al
    1138             : 
    1139       50713 :       CALL timeset(routineN, handle)
    1140       50713 :       side_char = 'L'
    1141       50713 :       unit_diag = 'N'
    1142       50713 :       uplo = 'U'
    1143       50713 :       transa = 'N'
    1144       50713 :       invert = .FALSE.
    1145       50713 :       al = 1.0_dp
    1146       50713 :       CALL cp_fm_get_info(matrix_b, nrow_global=m, ncol_global=n)
    1147       50713 :       IF (PRESENT(side)) side_char = side
    1148       50713 :       IF (PRESENT(invert_tr)) invert = invert_tr
    1149       50713 :       IF (PRESENT(uplo_tr)) uplo = uplo_tr
    1150       50713 :       IF (PRESENT(unit_diag_tr)) THEN
    1151           0 :          IF (unit_diag_tr) THEN
    1152           0 :             unit_diag = 'U'
    1153             :          ELSE
    1154             :             unit_diag = 'N'
    1155             :          END IF
    1156             :       END IF
    1157       50713 :       IF (PRESENT(transpose_tr)) THEN
    1158        3346 :          IF (transpose_tr) THEN
    1159        1218 :             transa = 'T'
    1160             :          ELSE
    1161             :             transa = 'N'
    1162             :          END IF
    1163             :       END IF
    1164       50713 :       IF (PRESENT(alpha)) al = alpha
    1165       50713 :       IF (PRESENT(n_rows)) m = n_rows
    1166       50713 :       IF (PRESENT(n_cols)) n = n_cols
    1167             : 
    1168       50713 :       IF (invert) THEN
    1169             : 
    1170             : #if defined(__SCALAPACK)
    1171             :          CALL pdtrsm(side_char, uplo, transa, unit_diag, m, n, al, &
    1172             :                      triangular_matrix%local_data(1, 1), 1, 1, &
    1173             :                      triangular_matrix%matrix_struct%descriptor, &
    1174             :                      matrix_b%local_data(1, 1), 1, 1, &
    1175       41659 :                      matrix_b%matrix_struct%descriptor(1))
    1176             : #else
    1177             :          CALL dtrsm(side_char, uplo, transa, unit_diag, m, n, al, &
    1178             :                     triangular_matrix%local_data(1, 1), &
    1179             :                     SIZE(triangular_matrix%local_data, 1), &
    1180             :                     matrix_b%local_data(1, 1), SIZE(matrix_b%local_data, 1))
    1181             : #endif
    1182             : 
    1183             :       ELSE
    1184             : 
    1185             : #if defined(__SCALAPACK)
    1186             :          CALL pdtrmm(side_char, uplo, transa, unit_diag, m, n, al, &
    1187             :                      triangular_matrix%local_data(1, 1), 1, 1, &
    1188             :                      triangular_matrix%matrix_struct%descriptor, &
    1189             :                      matrix_b%local_data(1, 1), 1, 1, &
    1190        9054 :                      matrix_b%matrix_struct%descriptor(1))
    1191             : #else
    1192             :          CALL dtrmm(side_char, uplo, transa, unit_diag, m, n, al, &
    1193             :                     triangular_matrix%local_data(1, 1), &
    1194             :                     SIZE(triangular_matrix%local_data, 1), &
    1195             :                     matrix_b%local_data(1, 1), SIZE(matrix_b%local_data, 1))
    1196             : #endif
    1197             : 
    1198             :       END IF
    1199             : 
    1200       50713 :       CALL timestop(handle)
    1201       50713 :    END SUBROUTINE cp_fm_triangular_multiply
    1202             : 
    1203             : ! **************************************************************************************************
    1204             : !> \brief scales a matrix
    1205             : !>      matrix_a = alpha * matrix_b
    1206             : !> \param alpha ...
    1207             : !> \param matrix_a ...
    1208             : !> \note
    1209             : !>      use cp_fm_set_all to zero (avoids problems with nan)
    1210             : ! **************************************************************************************************
    1211       82793 :    SUBROUTINE cp_fm_scale(alpha, matrix_a)
    1212             :       REAL(KIND=dp), INTENT(IN)                          :: alpha
    1213             :       TYPE(cp_fm_type), INTENT(IN)                       :: matrix_a
    1214             : 
    1215             :       CHARACTER(len=*), PARAMETER                        :: routineN = 'cp_fm_scale'
    1216             : 
    1217             :       INTEGER                                            :: handle, size_a
    1218             :       REAL(KIND=dp), DIMENSION(:, :), POINTER            :: a
    1219             : 
    1220       82793 :       CALL timeset(routineN, handle)
    1221             : 
    1222             :       NULLIFY (a)
    1223             : 
    1224       82793 :       a => matrix_a%local_data
    1225       82793 :       size_a = SIZE(a, 1)*SIZE(a, 2)
    1226             : 
    1227       82793 :       CALL DSCAL(size_a, alpha, a, 1)
    1228             : 
    1229       82793 :       CALL timestop(handle)
    1230             : 
    1231       82793 :    END SUBROUTINE cp_fm_scale
    1232             : 
    1233             : ! **************************************************************************************************
    1234             : !> \brief transposes a matrix
    1235             : !>      matrixt = matrix ^ T
    1236             : !> \param matrix ...
    1237             : !> \param matrixt ...
    1238             : !> \note
    1239             : !>      all matrix elements are transposed (see cp_fm_upper_to_half to symmetrise a matrix)
    1240             : ! **************************************************************************************************
    1241       19814 :    SUBROUTINE cp_fm_transpose(matrix, matrixt)
    1242             :       TYPE(cp_fm_type), INTENT(IN)          :: matrix, matrixt
    1243             : 
    1244             :       CHARACTER(len=*), PARAMETER :: routineN = 'cp_fm_transpose'
    1245             : 
    1246             :       INTEGER                                  :: handle, ncol_global, &
    1247             :                                                   nrow_global, ncol_globalt, nrow_globalt
    1248        9907 :       REAL(KIND=dp), DIMENSION(:, :), POINTER  :: a, c
    1249             : #if defined(__SCALAPACK)
    1250             :       INTEGER, DIMENSION(9)                    :: desca, descc
    1251             : #else
    1252             :       INTEGER                                  :: i, j
    1253             : #endif
    1254             : 
    1255        9907 :       nrow_global = matrix%matrix_struct%nrow_global
    1256        9907 :       ncol_global = matrix%matrix_struct%ncol_global
    1257        9907 :       nrow_globalt = matrixt%matrix_struct%nrow_global
    1258        9907 :       ncol_globalt = matrixt%matrix_struct%ncol_global
    1259           0 :       CPASSERT(nrow_global == ncol_globalt)
    1260        9907 :       CPASSERT(nrow_globalt == ncol_global)
    1261             : 
    1262        9907 :       CALL timeset(routineN, handle)
    1263             : 
    1264        9907 :       a => matrix%local_data
    1265        9907 :       c => matrixt%local_data
    1266             : 
    1267             : #if defined(__SCALAPACK)
    1268       99070 :       desca(:) = matrix%matrix_struct%descriptor(:)
    1269       99070 :       descc(:) = matrixt%matrix_struct%descriptor(:)
    1270        9907 :       CALL pdtran(ncol_global, nrow_global, 1.0_dp, a(1, 1), 1, 1, desca, 0.0_dp, c(1, 1), 1, 1, descc)
    1271             : #else
    1272             :       DO j = 1, ncol_global
    1273             :          DO i = 1, nrow_global
    1274             :             c(j, i) = a(i, j)
    1275             :          END DO
    1276             :       END DO
    1277             : #endif
    1278        9907 :       CALL timestop(handle)
    1279             : 
    1280        9907 :    END SUBROUTINE cp_fm_transpose
    1281             : 
    1282             : ! **************************************************************************************************
    1283             : !> \brief given an upper triangular matrix computes the corresponding full matrix
    1284             : !> \param matrix the upper triangular matrix as input, the full matrix as output
    1285             : !> \param work a matrix of the same size as matrix
    1286             : !> \author Matthias Krack
    1287             : !> \note
    1288             : !>       the lower triangular part is irrelevant
    1289             : ! **************************************************************************************************
    1290      291230 :    SUBROUTINE cp_fm_upper_to_full(matrix, work)
    1291             : 
    1292             :       TYPE(cp_fm_type), INTENT(IN)          :: matrix, work
    1293             : 
    1294             :       CHARACTER(len=*), PARAMETER :: routineN = 'cp_fm_upper_to_full'
    1295             : 
    1296             :       INTEGER                                  :: handle, icol_global, irow_global, &
    1297             :                                                   mypcol, myprow, ncol_global, &
    1298             :                                                   npcol, nprow, nrow_global
    1299      145615 :       REAL(KIND=dp), DIMENSION(:, :), POINTER   :: a
    1300      145615 :       REAL(KIND=sp), DIMENSION(:, :), POINTER   :: a_sp
    1301             :       TYPE(cp_blacs_env_type), POINTER         :: context
    1302             : 
    1303             : #if defined(__SCALAPACK)
    1304             :       INTEGER                                  :: icol_local, irow_local, &
    1305             :                                                   ncol_block, ncol_local, &
    1306             :                                                   nrow_block, nrow_local
    1307             :       INTEGER, DIMENSION(9)                    :: desca, descc
    1308             :       INTEGER, EXTERNAL                        :: indxl2g
    1309      145615 :       REAL(KIND=dp), DIMENSION(:, :), POINTER   :: c
    1310      145615 :       REAL(KIND=sp), DIMENSION(:, :), POINTER   :: c_sp
    1311             : #endif
    1312             : 
    1313      145615 :       nrow_global = matrix%matrix_struct%nrow_global
    1314      145615 :       ncol_global = matrix%matrix_struct%ncol_global
    1315           0 :       CPASSERT(nrow_global == ncol_global)
    1316      145615 :       nrow_global = work%matrix_struct%nrow_global
    1317      145615 :       ncol_global = work%matrix_struct%ncol_global
    1318      145615 :       CPASSERT(nrow_global == ncol_global)
    1319      145615 :       CPASSERT(matrix%use_sp .EQV. work%use_sp)
    1320             : 
    1321      145615 :       CALL timeset(routineN, handle)
    1322             : 
    1323      145615 :       context => matrix%matrix_struct%context
    1324      145615 :       myprow = context%mepos(1)
    1325      145615 :       mypcol = context%mepos(2)
    1326      145615 :       nprow = context%num_pe(1)
    1327      145615 :       npcol = context%num_pe(2)
    1328             : 
    1329             : #if defined(__SCALAPACK)
    1330             : 
    1331      145615 :       nrow_block = matrix%matrix_struct%nrow_block
    1332      145615 :       ncol_block = matrix%matrix_struct%ncol_block
    1333             : 
    1334      145615 :       nrow_local = matrix%matrix_struct%nrow_locals(myprow)
    1335      145615 :       ncol_local = matrix%matrix_struct%ncol_locals(mypcol)
    1336             : 
    1337      145615 :       a => work%local_data
    1338      145615 :       a_sp => work%local_data_sp
    1339     1456150 :       desca(:) = work%matrix_struct%descriptor(:)
    1340      145615 :       c => matrix%local_data
    1341      145615 :       c_sp => matrix%local_data_sp
    1342     1456150 :       descc(:) = matrix%matrix_struct%descriptor(:)
    1343             : 
    1344     4418676 :       DO icol_local = 1, ncol_local
    1345             :          icol_global = indxl2g(icol_local, ncol_block, mypcol, &
    1346     4273061 :                                matrix%matrix_struct%first_p_pos(2), npcol)
    1347   205788304 :          DO irow_local = 1, nrow_local
    1348             :             irow_global = indxl2g(irow_local, nrow_block, myprow, &
    1349   201369628 :                                   matrix%matrix_struct%first_p_pos(1), nprow)
    1350   205642689 :             IF (irow_global > icol_global) THEN
    1351    99369413 :                IF (matrix%use_sp) THEN
    1352           0 :                   c_sp(irow_local, icol_local) = 0.0_sp
    1353             :                ELSE
    1354    99369413 :                   c(irow_local, icol_local) = 0.0_dp
    1355             :                END IF
    1356   102000215 :             ELSE IF (irow_global == icol_global) THEN
    1357     2630802 :                IF (matrix%use_sp) THEN
    1358           0 :                   c_sp(irow_local, icol_local) = 0.5_sp*c_sp(irow_local, icol_local)
    1359             :                ELSE
    1360     2630802 :                   c(irow_local, icol_local) = 0.5_dp*c(irow_local, icol_local)
    1361             :                END IF
    1362             :             END IF
    1363             :          END DO
    1364             :       END DO
    1365             : 
    1366     4418676 :       DO icol_local = 1, ncol_local
    1367   205788304 :       DO irow_local = 1, nrow_local
    1368   205642689 :          IF (matrix%use_sp) THEN
    1369           0 :             a_sp(irow_local, icol_local) = c_sp(irow_local, icol_local)
    1370             :          ELSE
    1371   201369628 :             a(irow_local, icol_local) = c(irow_local, icol_local)
    1372             :          END IF
    1373             :       END DO
    1374             :       END DO
    1375             : 
    1376      145615 :       IF (matrix%use_sp) THEN
    1377           0 :          CALL pstran(nrow_global, ncol_global, 1.0_sp, a_sp(1, 1), 1, 1, desca, 1.0_sp, c_sp(1, 1), 1, 1, descc)
    1378             :       ELSE
    1379      145615 :          CALL pdtran(nrow_global, ncol_global, 1.0_dp, a(1, 1), 1, 1, desca, 1.0_dp, c(1, 1), 1, 1, descc)
    1380             :       END IF
    1381             : 
    1382             : #else
    1383             : 
    1384             :       a => matrix%local_data
    1385             :       a_sp => matrix%local_data_sp
    1386             :       DO irow_global = 1, nrow_global
    1387             :          DO icol_global = irow_global + 1, ncol_global
    1388             :             IF (matrix%use_sp) THEN
    1389             :                a_sp(icol_global, irow_global) = a_sp(irow_global, icol_global)
    1390             :             ELSE
    1391             :                a(icol_global, irow_global) = a(irow_global, icol_global)
    1392             :             END IF
    1393             :          END DO
    1394             :       END DO
    1395             : 
    1396             : #endif
    1397      145615 :       CALL timestop(handle)
    1398             : 
    1399      145615 :    END SUBROUTINE cp_fm_upper_to_full
    1400             : 
    1401             : ! **************************************************************************************************
    1402             : !> \brief scales column i of matrix a with scaling(i)
    1403             : !> \param matrixa ...
    1404             : !> \param scaling : an array used for scaling the columns,
    1405             : !>                  SIZE(scaling) determines the number of columns to be scaled
    1406             : !> \author Joost VandeVondele
    1407             : !> \note
    1408             : !>      this is very useful as a first step in the computation of C = sum_i alpha_i A_i transpose (A_i)
    1409             : !>      that is a rank-k update (cp_fm_syrk , cp_sm_plus_fm_fm_t)
    1410             : !>      this procedure can be up to 20 times faster than calling cp_fm_syrk n times
    1411             : !>      where every vector has a different prefactor
    1412             : ! **************************************************************************************************
    1413      125904 :    SUBROUTINE cp_fm_column_scale(matrixa, scaling)
    1414             :       TYPE(cp_fm_type), INTENT(IN)          :: matrixa
    1415             :       REAL(KIND=dp), DIMENSION(:), INTENT(in)  :: scaling
    1416             : 
    1417             :       INTEGER                                  :: k, mypcol, myprow, n, ncol_global, &
    1418             :                                                   npcol, nprow
    1419      125904 :       REAL(KIND=dp), DIMENSION(:, :), POINTER  :: a
    1420      125904 :       REAL(KIND=sp), DIMENSION(:, :), POINTER  :: a_sp
    1421             : #if defined(__SCALAPACK)
    1422             :       INTEGER                                  :: icol_global, icol_local, &
    1423             :                                                   ipcol, iprow, irow_local
    1424             : #else
    1425             :       INTEGER                                  :: i
    1426             : #endif
    1427             : 
    1428      125904 :       myprow = matrixa%matrix_struct%context%mepos(1)
    1429      125904 :       mypcol = matrixa%matrix_struct%context%mepos(2)
    1430      125904 :       nprow = matrixa%matrix_struct%context%num_pe(1)
    1431      125904 :       npcol = matrixa%matrix_struct%context%num_pe(2)
    1432             : 
    1433      125904 :       ncol_global = matrixa%matrix_struct%ncol_global
    1434             : 
    1435      125904 :       a => matrixa%local_data
    1436      125904 :       a_sp => matrixa%local_data_sp
    1437      125904 :       IF (matrixa%use_sp) THEN
    1438           0 :          n = SIZE(a_sp, 1)
    1439             :       ELSE
    1440      125904 :          n = SIZE(a, 1)
    1441             :       END IF
    1442      125904 :       k = MIN(SIZE(scaling), ncol_global)
    1443             : 
    1444             : #if defined(__SCALAPACK)
    1445             : 
    1446     1893213 :       DO icol_global = 1, k
    1447             :          CALL infog2l(1, icol_global, matrixa%matrix_struct%descriptor, &
    1448             :                       nprow, npcol, myprow, mypcol, &
    1449     1767309 :                       irow_local, icol_local, iprow, ipcol)
    1450     1893213 :          IF ((ipcol == mypcol)) THEN
    1451     1767309 :             IF (matrixa%use_sp) THEN
    1452           0 :                CALL SSCAL(n, REAL(scaling(icol_global), sp), a_sp(:, icol_local), 1)
    1453             :             ELSE
    1454     1767309 :                CALL DSCAL(n, scaling(icol_global), a(:, icol_local), 1)
    1455             :             END IF
    1456             :          END IF
    1457             :       END DO
    1458             : #else
    1459             :       DO i = 1, k
    1460             :          IF (matrixa%use_sp) THEN
    1461             :             CALL SSCAL(n, REAL(scaling(i), sp), a_sp(:, i), 1)
    1462             :          ELSE
    1463             :             CALL DSCAL(n, scaling(i), a(:, i), 1)
    1464             :          END IF
    1465             :       END DO
    1466             : #endif
    1467      125904 :    END SUBROUTINE cp_fm_column_scale
    1468             : 
    1469             : ! **************************************************************************************************
    1470             : !> \brief scales row i of matrix a with scaling(i)
    1471             : !> \param matrixa ...
    1472             : !> \param scaling : an array used for scaling the columns,
    1473             : !> \author JGH
    1474             : !> \note
    1475             : ! **************************************************************************************************
    1476        6564 :    SUBROUTINE cp_fm_row_scale(matrixa, scaling)
    1477             :       TYPE(cp_fm_type), INTENT(IN)          :: matrixa
    1478             :       REAL(KIND=dp), DIMENSION(:), INTENT(in)  :: scaling
    1479             : 
    1480             :       INTEGER                                  :: n, m, nrow_global, nrow_local, ncol_local
    1481        6564 :       INTEGER, DIMENSION(:), POINTER           :: row_indices
    1482        6564 :       REAL(KIND=dp), DIMENSION(:, :), POINTER  :: a
    1483        6564 :       REAL(KIND=sp), DIMENSION(:, :), POINTER  :: a_sp
    1484             : #if defined(__SCALAPACK)
    1485             :       INTEGER                                  :: irow_global, icol, irow
    1486             : #else
    1487             :       INTEGER                                  :: j
    1488             : #endif
    1489             : 
    1490             :       CALL cp_fm_get_info(matrixa, row_indices=row_indices, nrow_global=nrow_global, &
    1491        6564 :                           nrow_local=nrow_local, ncol_local=ncol_local)
    1492        6564 :       CPASSERT(SIZE(scaling) == nrow_global)
    1493             : 
    1494        6564 :       a => matrixa%local_data
    1495        6564 :       a_sp => matrixa%local_data_sp
    1496        6564 :       IF (matrixa%use_sp) THEN
    1497        6564 :          n = SIZE(a_sp, 1)
    1498        6564 :          m = SIZE(a_sp, 2)
    1499             :       ELSE
    1500        6564 :          n = SIZE(a, 1)
    1501        6564 :          m = SIZE(a, 2)
    1502             :       END IF
    1503             : 
    1504             : #if defined(__SCALAPACK)
    1505       81426 :       DO icol = 1, ncol_local
    1506       81426 :          IF (matrixa%use_sp) THEN
    1507           0 :             DO irow = 1, nrow_local
    1508           0 :                irow_global = row_indices(irow)
    1509           0 :                a(irow, icol) = REAL(scaling(irow_global), dp)*a(irow, icol)
    1510             :             END DO
    1511             :          ELSE
    1512     6667421 :             DO irow = 1, nrow_local
    1513     6592559 :                irow_global = row_indices(irow)
    1514     6667421 :                a(irow, icol) = scaling(irow_global)*a(irow, icol)
    1515             :             END DO
    1516             :          END IF
    1517             :       END DO
    1518             : #else
    1519             :       IF (matrixa%use_sp) THEN
    1520             :          DO j = 1, m
    1521             :             a_sp(1:n, j) = REAL(scaling(1:n), sp)*a_sp(1:n, j)
    1522             :          END DO
    1523             :       ELSE
    1524             :          DO j = 1, m
    1525             :             a(1:n, j) = scaling(1:n)*a(1:n, j)
    1526             :          END DO
    1527             :       END IF
    1528             : #endif
    1529        6564 :    END SUBROUTINE cp_fm_row_scale
    1530             : ! **************************************************************************************************
    1531             : !> \brief Inverts a cp_fm_type matrix, optionally returning the determinant of the input matrix
    1532             : !> \param matrix_a the matrix to invert
    1533             : !> \param matrix_inverse the inverse of matrix_a
    1534             : !> \param det_a the determinant of matrix_a
    1535             : !> \param eps_svd optional parameter to active SVD based inversion, singular values below eps_svd
    1536             : !>                are screened
    1537             : !> \param eigval optionally return matrix eigenvalues/singular values
    1538             : !> \par History
    1539             : !>      note of Jan Wilhelm (12.2015)
    1540             : !>      - computation of determinant corrected
    1541             : !>      - determinant only computed if det_a is present
    1542             : !>      12.2016 added option to use SVD instead of LU [Nico Holmberg]
    1543             : !>      - Use cp_fm_get diag instead of n times cp_fm_get_element (A. Bussy)
    1544             : !> \author Florian Schiffmann(02.2007)
    1545             : ! **************************************************************************************************
    1546         702 :    SUBROUTINE cp_fm_invert(matrix_a, matrix_inverse, det_a, eps_svd, eigval)
    1547             : 
    1548             :       TYPE(cp_fm_type), INTENT(IN)          :: matrix_a, matrix_inverse
    1549             :       REAL(KIND=dp), INTENT(OUT), OPTIONAL     :: det_a
    1550             :       REAL(KIND=dp), INTENT(IN), OPTIONAL      :: eps_svd
    1551             :       REAL(KIND=dp), DIMENSION(:), POINTER, &
    1552             :          INTENT(INOUT), OPTIONAL               :: eigval
    1553             : 
    1554             :       INTEGER                                  :: n
    1555         702 :       INTEGER, ALLOCATABLE, DIMENSION(:)       :: ipivot
    1556             :       REAL(KIND=dp)                            :: determinant, my_eps_svd
    1557             :       REAL(KIND=dp), DIMENSION(:, :), POINTER  :: a
    1558             :       TYPE(cp_fm_type)                :: matrix_lu
    1559             : 
    1560             : #if defined(__SCALAPACK)
    1561             :       TYPE(cp_fm_type)                :: u, vt, sigma, inv_sigma_ut
    1562             :       TYPE(mp_comm_type) :: group
    1563             :       INTEGER                                  :: i, info, liwork, lwork, exponent_of_minus_one
    1564             :       INTEGER, DIMENSION(9)                    :: desca
    1565             :       LOGICAL                                  :: quenched
    1566             :       REAL(KIND=dp)                            :: alpha, beta
    1567         702 :       REAL(KIND=dp), DIMENSION(:), POINTER     :: diag
    1568         702 :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:) :: work
    1569             : #else
    1570             :       LOGICAL                                  :: sign
    1571             :       REAL(KIND=dp)                            :: eps1
    1572             : #endif
    1573             : 
    1574         702 :       my_eps_svd = 0.0_dp
    1575         468 :       IF (PRESENT(eps_svd)) my_eps_svd = eps_svd
    1576             : 
    1577             :       CALL cp_fm_create(matrix=matrix_lu, &
    1578             :                         matrix_struct=matrix_a%matrix_struct, &
    1579         702 :                         name="A_lu"//TRIM(ADJUSTL(cp_to_string(1)))//"MATRIX")
    1580         702 :       CALL cp_fm_to_fm(matrix_a, matrix_lu)
    1581             : 
    1582         702 :       a => matrix_lu%local_data
    1583         702 :       n = matrix_lu%matrix_struct%nrow_global
    1584        2106 :       ALLOCATE (ipivot(n + matrix_a%matrix_struct%nrow_block))
    1585       10934 :       ipivot(:) = 0
    1586             : #if defined(__SCALAPACK)
    1587         702 :       IF (my_eps_svd .EQ. 0.0_dp) THEN
    1588             :          ! Use LU decomposition
    1589         674 :          lwork = 3*n
    1590         674 :          liwork = 3*n
    1591        6740 :          desca(:) = matrix_lu%matrix_struct%descriptor(:)
    1592         674 :          CALL pdgetrf(n, n, a, 1, 1, desca, ipivot, info)
    1593             : 
    1594         674 :          IF (PRESENT(det_a) .OR. PRESENT(eigval)) THEN
    1595             : 
    1596        1116 :             ALLOCATE (diag(n))
    1597         916 :             diag(:) = 0.0_dp
    1598         372 :             CALL cp_fm_get_diag(matrix_lu, diag)
    1599             : 
    1600         372 :             exponent_of_minus_one = 0
    1601         372 :             determinant = 1.0_dp
    1602         916 :             DO i = 1, n
    1603         544 :                determinant = determinant*diag(i)
    1604         916 :                IF (ipivot(i) .NE. i) THEN
    1605         224 :                   exponent_of_minus_one = exponent_of_minus_one + 1
    1606             :                END IF
    1607             :             END DO
    1608         372 :             IF (PRESENT(eigval)) THEN
    1609           0 :                CPASSERT(.NOT. ASSOCIATED(eigval))
    1610           0 :                ALLOCATE (eigval(n))
    1611           0 :                eigval(:) = diag
    1612             :             END IF
    1613         372 :             DEALLOCATE (diag)
    1614             : 
    1615         372 :             group = matrix_lu%matrix_struct%para_env
    1616         372 :             CALL group%sum(exponent_of_minus_one)
    1617             : 
    1618         372 :             determinant = determinant*(-1.0_dp)**exponent_of_minus_one
    1619             : 
    1620             :          END IF
    1621             : 
    1622         674 :          alpha = 0.0_dp
    1623         674 :          beta = 1.0_dp
    1624         674 :          CALL cp_fm_set_all(matrix_inverse, alpha, beta)
    1625         674 :          CALL pdgetrs('N', n, n, matrix_lu%local_data, 1, 1, desca, ipivot, matrix_inverse%local_data, 1, 1, desca, info)
    1626             :       ELSE
    1627             :          ! Use singular value decomposition
    1628             :          CALL cp_fm_create(matrix=u, &
    1629             :                            matrix_struct=matrix_a%matrix_struct, &
    1630          28 :                            name="LEFT_SINGULAR_MATRIX")
    1631          28 :          CALL cp_fm_set_all(u, alpha=0.0_dp)
    1632             :          CALL cp_fm_create(matrix=vt, &
    1633             :                            matrix_struct=matrix_a%matrix_struct, &
    1634          28 :                            name="RIGHT_SINGULAR_MATRIX")
    1635          28 :          CALL cp_fm_set_all(vt, alpha=0.0_dp)
    1636          84 :          ALLOCATE (diag(n))
    1637          92 :          diag(:) = 0.0_dp
    1638         280 :          desca(:) = matrix_lu%matrix_struct%descriptor(:)
    1639          28 :          ALLOCATE (work(1))
    1640             :          ! Workspace query
    1641          28 :          lwork = -1
    1642             :          CALL pdgesvd('V', 'V', n, n, matrix_lu%local_data, 1, 1, desca, diag, u%local_data, &
    1643          28 :                       1, 1, desca, vt%local_data, 1, 1, desca, work, lwork, info)
    1644          28 :          lwork = INT(work(1))
    1645          28 :          DEALLOCATE (work)
    1646          84 :          ALLOCATE (work(lwork))
    1647             :          ! SVD
    1648             :          CALL pdgesvd('V', 'V', n, n, matrix_lu%local_data, 1, 1, desca, diag, u%local_data, &
    1649          28 :                       1, 1, desca, vt%local_data, 1, 1, desca, work, lwork, info)
    1650             :          ! info == n+1 implies homogeneity error when the number of procs is large
    1651             :          ! this likely isnt a problem, but maybe we should handle it separately
    1652          28 :          IF (info /= 0 .AND. info /= n + 1) &
    1653           0 :             CPABORT("Singular value decomposition of matrix failed.")
    1654             :          ! (Pseudo)inverse and (pseudo)determinant
    1655             :          CALL cp_fm_create(matrix=sigma, &
    1656             :                            matrix_struct=matrix_a%matrix_struct, &
    1657          28 :                            name="SINGULAR_VALUE_MATRIX")
    1658          28 :          CALL cp_fm_set_all(sigma, alpha=0.0_dp)
    1659          28 :          determinant = 1.0_dp
    1660          28 :          quenched = .FALSE.
    1661          28 :          IF (PRESENT(eigval)) THEN
    1662          28 :             CPASSERT(.NOT. ASSOCIATED(eigval))
    1663          84 :             ALLOCATE (eigval(n))
    1664         156 :             eigval(:) = diag
    1665             :          END IF
    1666          92 :          DO i = 1, n
    1667          64 :             IF (diag(i) < my_eps_svd) THEN
    1668          18 :                diag(i) = 0.0_dp
    1669          18 :                quenched = .TRUE.
    1670             :             ELSE
    1671          46 :                determinant = determinant*diag(i)
    1672          46 :                diag(i) = 1.0_dp/diag(i)
    1673             :             END IF
    1674          92 :             CALL cp_fm_set_element(sigma, i, i, diag(i))
    1675             :          END DO
    1676          28 :          DEALLOCATE (diag)
    1677          28 :          IF (quenched) &
    1678             :             CALL cp_warn(__LOCATION__, &
    1679             :                          "Linear dependencies were detected in the SVD inversion of matrix "//TRIM(ADJUSTL(matrix_a%name))// &
    1680          12 :                          ". At least one singular value has been quenched.")
    1681             :          ! Sigma^-1 * U^T
    1682             :          CALL cp_fm_create(matrix=inv_sigma_ut, &
    1683             :                            matrix_struct=matrix_a%matrix_struct, &
    1684          28 :                            name="SINGULAR_VALUE_MATRIX")
    1685          28 :          CALL cp_fm_set_all(inv_sigma_ut, alpha=0.0_dp)
    1686             :          CALL pdgemm('N', 'T', n, n, n, 1.0_dp, sigma%local_data, 1, 1, desca, &
    1687          28 :                      u%local_data, 1, 1, desca, 0.0_dp, inv_sigma_ut%local_data, 1, 1, desca)
    1688             :          ! A^-1 = V * (Sigma^-1 * U^T)
    1689          28 :          CALL cp_fm_set_all(matrix_inverse, alpha=0.0_dp)
    1690             :          CALL pdgemm('T', 'N', n, n, n, 1.0_dp, vt%local_data, 1, 1, desca, &
    1691          28 :                      inv_sigma_ut%local_data, 1, 1, desca, 0.0_dp, matrix_inverse%local_data, 1, 1, desca)
    1692             :          ! Clean up
    1693          28 :          DEALLOCATE (work)
    1694          28 :          CALL cp_fm_release(u)
    1695          28 :          CALL cp_fm_release(vt)
    1696          28 :          CALL cp_fm_release(sigma)
    1697          28 :          CALL cp_fm_release(inv_sigma_ut)
    1698             :       END IF
    1699             : #else
    1700             :       IF (my_eps_svd .EQ. 0.0_dp) THEN
    1701             :          sign = .TRUE.
    1702             :          CALL invert_matrix(matrix_a%local_data, matrix_inverse%local_data, &
    1703             :                             eval_error=eps1)
    1704             :          CALL cp_fm_lu_decompose(matrix_lu, determinant, correct_sign=sign)
    1705             :          IF (PRESENT(eigval)) &
    1706             :             CALL cp_abort(__LOCATION__, &
    1707             :                           "NYI. Eigenvalues not available for return without SCALAPACK.")
    1708             :       ELSE
    1709             :          CALL get_pseudo_inverse_svd(matrix_a%local_data, matrix_inverse%local_data, eps_svd, &
    1710             :                                      determinant, eigval)
    1711             :       END IF
    1712             : #endif
    1713         702 :       CALL cp_fm_release(matrix_lu)
    1714         702 :       DEALLOCATE (ipivot)
    1715         702 :       IF (PRESENT(det_a)) det_a = determinant
    1716         702 :    END SUBROUTINE cp_fm_invert
    1717             : 
    1718             : ! **************************************************************************************************
    1719             : !> \brief inverts a triangular matrix
    1720             : !> \param matrix_a ...
    1721             : !> \param uplo_tr ...
    1722             : !> \author MI
    1723             : ! **************************************************************************************************
    1724        4926 :    SUBROUTINE cp_fm_triangular_invert(matrix_a, uplo_tr)
    1725             : 
    1726             :       TYPE(cp_fm_type), INTENT(IN)          :: matrix_a
    1727             :       CHARACTER, INTENT(IN), OPTIONAL          :: uplo_tr
    1728             : 
    1729             :       CHARACTER(LEN=*), PARAMETER :: routineN = 'cp_fm_triangular_invert'
    1730             : 
    1731             :       CHARACTER                                :: unit_diag, uplo
    1732             :       INTEGER                                  :: handle, info, ncol_global
    1733        4926 :       REAL(KIND=dp), DIMENSION(:, :), POINTER  :: a
    1734             : #if defined(__SCALAPACK)
    1735             :       INTEGER, DIMENSION(9)                    :: desca
    1736             : #endif
    1737             : 
    1738        4926 :       CALL timeset(routineN, handle)
    1739             : 
    1740        4926 :       unit_diag = 'N'
    1741        4926 :       uplo = 'U'
    1742        4926 :       IF (PRESENT(uplo_tr)) uplo = uplo_tr
    1743             : 
    1744        4926 :       ncol_global = matrix_a%matrix_struct%ncol_global
    1745             : 
    1746        4926 :       a => matrix_a%local_data
    1747             : 
    1748             : #if defined(__SCALAPACK)
    1749             : 
    1750       49260 :       desca(:) = matrix_a%matrix_struct%descriptor(:)
    1751             : 
    1752        4926 :       CALL pdtrtri(uplo, unit_diag, ncol_global, a(1, 1), 1, 1, desca, info)
    1753             : 
    1754             : #else
    1755             :       CALL dtrtri(uplo, unit_diag, ncol_global, a(1, 1), ncol_global, info)
    1756             : #endif
    1757             : 
    1758        4926 :       CALL timestop(handle)
    1759        4926 :    END SUBROUTINE cp_fm_triangular_invert
    1760             : 
    1761             : ! **************************************************************************************************
    1762             : !> \brief  performs a QR factorization of the input rectangular matrix A or of a submatrix of A
    1763             : !>         the computed upper triangular matrix R is in output in the submatrix sub(A) of size NxN
    1764             : !>         M and M give the dimension of the submatrix that has to be factorized (MxN) with M>N
    1765             : !> \param matrix_a ...
    1766             : !> \param matrix_r ...
    1767             : !> \param nrow_fact ...
    1768             : !> \param ncol_fact ...
    1769             : !> \param first_row ...
    1770             : !> \param first_col ...
    1771             : !> \author MI
    1772             : ! **************************************************************************************************
    1773       19320 :    SUBROUTINE cp_fm_qr_factorization(matrix_a, matrix_r, nrow_fact, ncol_fact, first_row, first_col)
    1774             :       TYPE(cp_fm_type), INTENT(IN)          :: matrix_a, matrix_r
    1775             :       INTEGER, INTENT(IN), OPTIONAL            :: nrow_fact, ncol_fact, &
    1776             :                                                   first_row, first_col
    1777             : 
    1778             :       CHARACTER(LEN=*), PARAMETER :: routineN = 'cp_fm_qr_factorization'
    1779             : 
    1780             :       INTEGER                                  :: handle, i, icol, info, irow, &
    1781             :                                                   j, lda, lwork, ncol, &
    1782             :                                                   ndim, nrow
    1783       19320 :       REAL(dp), ALLOCATABLE, DIMENSION(:)      :: tau, work
    1784       19320 :       REAL(dp), ALLOCATABLE, DIMENSION(:, :)   :: r_mat
    1785             :       REAL(KIND=dp), DIMENSION(:, :), POINTER  :: a
    1786             : #if defined(__SCALAPACK)
    1787             :       INTEGER, DIMENSION(9)                    :: desca
    1788             : #endif
    1789             : 
    1790       19320 :       CALL timeset(routineN, handle)
    1791             : 
    1792       19320 :       ncol = matrix_a%matrix_struct%ncol_global
    1793       19320 :       nrow = matrix_a%matrix_struct%nrow_global
    1794       19320 :       lda = nrow
    1795             : 
    1796       19320 :       a => matrix_a%local_data
    1797             : 
    1798       19320 :       IF (PRESENT(nrow_fact)) nrow = nrow_fact
    1799       19320 :       IF (PRESENT(ncol_fact)) ncol = ncol_fact
    1800       19320 :       irow = 1
    1801       19320 :       IF (PRESENT(first_row)) irow = first_row
    1802       19320 :       icol = 1
    1803       19320 :       IF (PRESENT(first_col)) icol = first_col
    1804             : 
    1805       19320 :       CPASSERT(nrow >= ncol)
    1806       19320 :       ndim = SIZE(a, 2)
    1807             : !    ALLOCATE(ipiv(ndim))
    1808       57960 :       ALLOCATE (tau(ndim))
    1809             : 
    1810             : #if defined(__SCALAPACK)
    1811             : 
    1812      193200 :       desca(:) = matrix_a%matrix_struct%descriptor(:)
    1813             : 
    1814       19320 :       lwork = -1
    1815       57960 :       ALLOCATE (work(2*ndim))
    1816       19320 :       CALL pdgeqrf(nrow, ncol, a, irow, icol, desca, tau, work, lwork, info)
    1817       19320 :       lwork = INT(work(1))
    1818       19320 :       DEALLOCATE (work)
    1819       57960 :       ALLOCATE (work(lwork))
    1820       19320 :       CALL pdgeqrf(nrow, ncol, a, irow, icol, desca, tau, work, lwork, info)
    1821             : 
    1822             : #else
    1823             :       lwork = -1
    1824             :       ALLOCATE (work(2*ndim))
    1825             :       CALL dgeqrf(nrow, ncol, a, lda, tau, work, lwork, info)
    1826             :       lwork = INT(work(1))
    1827             :       DEALLOCATE (work)
    1828             :       ALLOCATE (work(lwork))
    1829             :       CALL dgeqrf(nrow, ncol, a, lda, tau, work, lwork, info)
    1830             : 
    1831             : #endif
    1832             : 
    1833       77280 :       ALLOCATE (r_mat(ncol, ncol))
    1834       19320 :       CALL cp_fm_get_submatrix(matrix_a, r_mat, 1, 1, ncol, ncol)
    1835       38640 :       DO i = 1, ncol
    1836       38640 :          DO j = i + 1, ncol
    1837       19320 :             r_mat(j, i) = 0.0_dp
    1838             :          END DO
    1839             :       END DO
    1840       19320 :       CALL cp_fm_set_submatrix(matrix_r, r_mat, 1, 1, ncol, ncol)
    1841             : 
    1842       19320 :       DEALLOCATE (tau, work, r_mat)
    1843             : 
    1844       19320 :       CALL timestop(handle)
    1845             : 
    1846       19320 :    END SUBROUTINE cp_fm_qr_factorization
    1847             : 
    1848             : ! **************************************************************************************************
    1849             : !> \brief computes the the solution to A*b=A_general using lu decomposition
    1850             : !>        pay attention, both matrices are overwritten, a_general contais the result
    1851             : !> \param matrix_a ...
    1852             : !> \param general_a ...
    1853             : !> \author Florian Schiffmann
    1854             : ! **************************************************************************************************
    1855        4294 :    SUBROUTINE cp_fm_solve(matrix_a, general_a)
    1856             :       TYPE(cp_fm_type), INTENT(IN)          :: matrix_a, general_a
    1857             : 
    1858             :       CHARACTER(len=*), PARAMETER :: routineN = 'cp_fm_solve'
    1859             : 
    1860             :       INTEGER                                  :: handle, info, n
    1861        4294 :       INTEGER, ALLOCATABLE, DIMENSION(:)       :: ipivot
    1862             :       REAL(KIND=dp), DIMENSION(:, :), POINTER  :: a, a_general
    1863             : #if defined(__SCALAPACK)
    1864             :       INTEGER, DIMENSION(9)                    :: desca, descb
    1865             : #else
    1866             :       INTEGER                                  :: lda, ldb
    1867             : #endif
    1868             : 
    1869        4294 :       CALL timeset(routineN, handle)
    1870             : 
    1871        4294 :       a => matrix_a%local_data
    1872        4294 :       a_general => general_a%local_data
    1873        4294 :       n = matrix_a%matrix_struct%nrow_global
    1874       12882 :       ALLOCATE (ipivot(n + matrix_a%matrix_struct%nrow_block))
    1875             : 
    1876             : #if defined(__SCALAPACK)
    1877       42940 :       desca(:) = matrix_a%matrix_struct%descriptor(:)
    1878       42940 :       descb(:) = general_a%matrix_struct%descriptor(:)
    1879        4294 :       CALL pdgetrf(n, n, a, 1, 1, desca, ipivot, info)
    1880             :       CALL pdgetrs("N", n, n, a, 1, 1, desca, ipivot, a_general, &
    1881        4294 :                    1, 1, descb, info)
    1882             : 
    1883             : #else
    1884             :       lda = SIZE(a, 1)
    1885             :       ldb = SIZE(a_general, 1)
    1886             :       CALL dgetrf(n, n, a, lda, ipivot, info)
    1887             :       CALL dgetrs("N", n, n, a, lda, ipivot, a_general, ldb, info)
    1888             : 
    1889             : #endif
    1890             :       ! info is allowed to be zero
    1891             :       ! this does just signal a zero diagonal element
    1892        4294 :       DEALLOCATE (ipivot)
    1893        4294 :       CALL timestop(handle)
    1894        4294 :    END SUBROUTINE
    1895             : 
    1896             : ! **************************************************************************************************
    1897             : !> \brief Convenience function. Computes the matrix multiplications needed
    1898             : !>        for the multiplication of complex matrices.
    1899             : !>        C = beta * C + alpha * ( A  ** transa ) * ( B ** transb )
    1900             : !> \param transa : 'N' -> normal   'T' -> transpose
    1901             : !>      alpha,beta :: can be 0.0_dp and 1.0_dp
    1902             : !> \param transb ...
    1903             : !> \param m ...
    1904             : !> \param n ...
    1905             : !> \param k ...
    1906             : !> \param alpha ...
    1907             : !> \param A_re m x k matrix ( ! for transa = 'N'), real part
    1908             : !> \param A_im m x k matrix ( ! for transa = 'N'), imaginary part
    1909             : !> \param B_re k x n matrix ( ! for transa = 'N'), real part
    1910             : !> \param B_im k x n matrix ( ! for transa = 'N'), imaginary part
    1911             : !> \param beta ...
    1912             : !> \param C_re m x n matrix, real part
    1913             : !> \param C_im m x n matrix, imaginary part
    1914             : !> \param a_first_col ...
    1915             : !> \param a_first_row ...
    1916             : !> \param b_first_col : the k x n matrix starts at col b_first_col of matrix_b (avoid usage)
    1917             : !> \param b_first_row ...
    1918             : !> \param c_first_col ...
    1919             : !> \param c_first_row ...
    1920             : !> \author Samuel Andermatt
    1921             : !> \note
    1922             : !>      C should have no overlap with A, B
    1923             : ! **************************************************************************************************
    1924           0 :    SUBROUTINE cp_complex_fm_gemm(transa, transb, m, n, k, alpha, A_re, A_im, B_re, B_im, beta, &
    1925             :                                  C_re, C_im, a_first_col, a_first_row, b_first_col, b_first_row, c_first_col, &
    1926             :                                  c_first_row)
    1927             :       CHARACTER(LEN=1), INTENT(IN)                       :: transa, transb
    1928             :       INTEGER, INTENT(IN)                                :: m, n, k
    1929             :       REAL(KIND=dp), INTENT(IN)                          :: alpha
    1930             :       TYPE(cp_fm_type), INTENT(IN)                       :: A_re, A_im, B_re, B_im
    1931             :       REAL(KIND=dp), INTENT(IN)                          :: beta
    1932             :       TYPE(cp_fm_type), INTENT(IN)                       :: C_re, C_im
    1933             :       INTEGER, INTENT(IN), OPTIONAL                      :: a_first_col, a_first_row, b_first_col, &
    1934             :                                                             b_first_row, c_first_col, c_first_row
    1935             : 
    1936             :       CHARACTER(len=*), PARAMETER :: routineN = 'cp_complex_fm_gemm'
    1937             : 
    1938             :       INTEGER                                            :: handle
    1939             : 
    1940           0 :       CALL timeset(routineN, handle)
    1941             : 
    1942             :       CALL cp_fm_gemm(transa, transb, m, n, k, alpha, A_re, B_re, beta, C_re, &
    1943             :                       a_first_col=a_first_col, &
    1944             :                       a_first_row=a_first_row, &
    1945             :                       b_first_col=b_first_col, &
    1946             :                       b_first_row=b_first_row, &
    1947             :                       c_first_col=c_first_col, &
    1948           0 :                       c_first_row=c_first_row)
    1949             :       CALL cp_fm_gemm(transa, transb, m, n, k, -alpha, A_im, B_im, 1.0_dp, C_re, &
    1950             :                       a_first_col=a_first_col, &
    1951             :                       a_first_row=a_first_row, &
    1952             :                       b_first_col=b_first_col, &
    1953             :                       b_first_row=b_first_row, &
    1954             :                       c_first_col=c_first_col, &
    1955           0 :                       c_first_row=c_first_row)
    1956             :       CALL cp_fm_gemm(transa, transb, m, n, k, alpha, A_re, B_im, beta, C_im, &
    1957             :                       a_first_col=a_first_col, &
    1958             :                       a_first_row=a_first_row, &
    1959             :                       b_first_col=b_first_col, &
    1960             :                       b_first_row=b_first_row, &
    1961             :                       c_first_col=c_first_col, &
    1962           0 :                       c_first_row=c_first_row)
    1963             :       CALL cp_fm_gemm(transa, transb, m, n, k, alpha, A_im, B_re, 1.0_dp, C_im, &
    1964             :                       a_first_col=a_first_col, &
    1965             :                       a_first_row=a_first_row, &
    1966             :                       b_first_col=b_first_col, &
    1967             :                       b_first_row=b_first_row, &
    1968             :                       c_first_col=c_first_col, &
    1969           0 :                       c_first_row=c_first_row)
    1970             : 
    1971           0 :       CALL timestop(handle)
    1972             : 
    1973           0 :    END SUBROUTINE cp_complex_fm_gemm
    1974             : 
    ! **************************************************************************************************
!> \brief inverts a matrix using LU decomposition
!>        the input matrix will be overwritten
!> \param matrix   : input a general square non-singular matrix, outputs its inverse
!> \param info_out : optional, if present outputs the info from (p)dgetri / (p)sgetri
!> \note  A failure of the LU factorisation step ((p)dgetrf / (p)sgetrf) always aborts,
!>        whether or not info_out is present; only the inversion status is returned.
!> \author Lianheng Tong
! **************************************************************************************************
   SUBROUTINE cp_fm_lu_invert(matrix, info_out)
      TYPE(cp_fm_type), INTENT(IN)          :: matrix
      INTEGER, INTENT(OUT), OPTIONAL           :: info_out

      CHARACTER(len=*), PARAMETER :: routineN = 'cp_fm_lu_invert'

      INTEGER :: nrows_global, handle, info, lwork
      INTEGER, DIMENSION(:), ALLOCATABLE       :: ipivot
      REAL(KIND=dp), DIMENSION(:, :), POINTER  :: mat
      REAL(KIND=sp), DIMENSION(:, :), POINTER  :: mat_sp
      REAL(KIND=dp), DIMENSION(:), ALLOCATABLE :: work
      REAL(KIND=sp), DIMENSION(:), ALLOCATABLE :: work_sp
#if defined(__SCALAPACK)
      INTEGER                                  :: liwork
      INTEGER, DIMENSION(9)                    :: desca
      INTEGER, DIMENSION(:), ALLOCATABLE       :: iwork
#else
      INTEGER                                  :: lda
#endif

      CALL timeset(routineN, handle)

      mat => matrix%local_data
      mat_sp => matrix%local_data_sp
      nrows_global = matrix%matrix_struct%nrow_global
      CPASSERT(nrows_global .EQ. matrix%matrix_struct%ncol_global)
      ALLOCATE (ipivot(nrows_global))

      ! do LU decomposition
#if defined(__SCALAPACK)
      desca = matrix%matrix_struct%descriptor
      IF (matrix%use_sp) THEN
         CALL psgetrf(nrows_global, nrows_global, &
                      mat_sp, 1, 1, desca, ipivot, info)
      ELSE
         CALL pdgetrf(nrows_global, nrows_global, &
                      mat, 1, 1, desca, ipivot, info)
      END IF
#else
      ! the leading dimension must be taken from the array that is actually factorised
      IF (matrix%use_sp) THEN
         lda = SIZE(mat_sp, 1)
         CALL sgetrf(nrows_global, nrows_global, &
                     mat_sp, lda, ipivot, info)
      ELSE
         lda = SIZE(mat, 1)
         CALL dgetrf(nrows_global, nrows_global, &
                     mat, lda, ipivot, info)
      END IF
#endif
      IF (info /= 0) THEN
         CALL cp_abort(__LOCATION__, "LU decomposition has failed")
      END IF

      ! do inversion
      ! The one-element buffers receive the result of the workspace query (lwork = -1),
      ! so their precision must match the (p)?getri routine that is called below.
      IF (matrix%use_sp) THEN
         ALLOCATE (work_sp(1))
      ELSE
         ALLOCATE (work(1))
      END IF
#if defined(__SCALAPACK)
      ALLOCATE (iwork(1))
      IF (matrix%use_sp) THEN
         CALL psgetri(nrows_global, mat_sp, 1, 1, desca, &
                      ipivot, work_sp, -1, iwork, -1, info)
         lwork = INT(work_sp(1))
         DEALLOCATE (work_sp)
         ALLOCATE (work_sp(lwork))
      ELSE
         CALL pdgetri(nrows_global, mat, 1, 1, desca, &
                      ipivot, work, -1, iwork, -1, info)
         lwork = INT(work(1))
         DEALLOCATE (work)
         ALLOCATE (work(lwork))
      END IF
      liwork = INT(iwork(1))
      DEALLOCATE (iwork)
      ALLOCATE (iwork(liwork))
      IF (matrix%use_sp) THEN
         CALL psgetri(nrows_global, mat_sp, 1, 1, desca, &
                      ipivot, work_sp, lwork, iwork, liwork, info)
      ELSE
         CALL pdgetri(nrows_global, mat, 1, 1, desca, &
                      ipivot, work, lwork, iwork, liwork, info)
      END IF
      DEALLOCATE (iwork)
#else
      IF (matrix%use_sp) THEN
         CALL sgetri(nrows_global, mat_sp, lda, &
                     ipivot, work_sp, -1, info)
         lwork = INT(work_sp(1))
         DEALLOCATE (work_sp)
         ALLOCATE (work_sp(lwork))
         CALL sgetri(nrows_global, mat_sp, lda, &
                     ipivot, work_sp, lwork, info)
      ELSE
         CALL dgetri(nrows_global, mat, lda, &
                     ipivot, work, -1, info)
         lwork = INT(work(1))
         DEALLOCATE (work)
         ALLOCATE (work(lwork))
         CALL dgetri(nrows_global, mat, lda, &
                     ipivot, work, lwork, info)
      END IF
#endif
      IF (matrix%use_sp) THEN
         DEALLOCATE (work_sp)
      ELSE
         DEALLOCATE (work)
      END IF
      DEALLOCATE (ipivot)

      ! report the inversion status to the caller if requested, otherwise abort on failure
      IF (PRESENT(info_out)) THEN
         info_out = info
      ELSE
         IF (info /= 0) &
            CALL cp_abort(__LOCATION__, "LU inversion has failed")
      END IF

      CALL timestop(handle)

   END SUBROUTINE cp_fm_lu_invert
    2100             : 
    ! **************************************************************************************************
!> \brief norm of matrix using (p)dlange (or (p)slange for single-precision matrices)
!> \param matrix   : input a general matrix
!> \param mode     : 'M' max abs element value,
!>                   '1' or 'O' one norm, i.e. maximum column sum
!>                   'I' infinity norm, i.e. maximum row sum
!>                   'F' or 'E' Frobenius norm, i.e. sqrt of sum of all squares of elements
!>                   lower-case letters are accepted too; any other value aborts
!> \return : the norm according to mode; in the single-precision case the sp result
!>           is converted to dp before being returned
!> \author Lianheng Tong
! **************************************************************************************************
   FUNCTION cp_fm_norm(matrix, mode) RESULT(res)
      TYPE(cp_fm_type), INTENT(IN) :: matrix
      CHARACTER, INTENT(IN) :: mode
      REAL(KIND=dp) :: res

      CHARACTER(len=*), PARAMETER :: routineN = 'cp_fm_norm'

      INTEGER :: handle, lwork, ncols, ncols_local, nrows, nrows_local
      REAL(KIND=sp) :: res_sp
      REAL(KIND=dp), DIMENSION(:, :), POINTER :: aa
      REAL(KIND=sp), DIMENSION(:, :), POINTER :: aa_sp
      REAL(KIND=dp), DIMENSION(:), ALLOCATABLE :: work
      REAL(KIND=sp), DIMENSION(:), ALLOCATABLE :: work_sp
#if defined(__SCALAPACK)
      INTEGER, DIMENSION(9) :: desca
#else
      INTEGER :: lda
#endif

      CALL timeset(routineN, handle)

      CALL cp_fm_get_info(matrix=matrix, nrow_global=nrows, ncol_global=ncols, &
                          nrow_local=nrows_local, ncol_local=ncols_local)
      aa => matrix%local_data
      aa_sp => matrix%local_data_sp
#if defined(__SCALAPACK)
      desca = matrix%matrix_struct%descriptor
#endif

      ! length of the scratch array handed to (p)?lange for the requested norm
      SELECT CASE (mode)
      CASE ('M', 'm', 'F', 'f', 'E', 'e')
         lwork = 1
      CASE ('1', 'O', 'o')
#if defined(__SCALAPACK)
         lwork = ncols_local
#else
         lwork = 1
#endif
      CASE ('I', 'i')
#if defined(__SCALAPACK)
         lwork = nrows_local
#else
         lwork = nrows
#endif
      CASE DEFAULT
         CPABORT("mode input is not valid")
      END SELECT

      IF (.NOT. matrix%use_sp) THEN
         ALLOCATE (work(lwork))
#if defined(__SCALAPACK)
         res = pdlange(mode, nrows, ncols, aa, 1, 1, desca, work)
#else
         lda = SIZE(aa, 1)
         res = dlange(mode, nrows, ncols, aa, lda, work)
#endif
         DEALLOCATE (work)
      ELSE
         ALLOCATE (work_sp(lwork))
#if defined(__SCALAPACK)
         res_sp = pslange(mode, nrows, ncols, aa_sp, 1, 1, desca, work_sp)
#else
         lda = SIZE(aa_sp, 1)
         res_sp = slange(mode, nrows, ncols, aa_sp, lda, work_sp)
#endif
         DEALLOCATE (work_sp)
         res = REAL(res_sp, KIND=dp)
      END IF

      CALL timestop(handle)

   END FUNCTION cp_fm_norm
    2194             : 
    ! **************************************************************************************************
!> \brief trace of a square matrix; uses (p)dlatra / (p)slatra with ScaLAPACK and a plain
!>        sum over the diagonal of the local data otherwise
!> \param matrix   : input a square matrix
!> \return : the trace; for single-precision matrices the sum is formed in sp and then
!>           converted to dp
!> \author Lianheng Tong
! **************************************************************************************************
   FUNCTION cp_fm_latra(matrix) RESULT(res)
      TYPE(cp_fm_type), INTENT(IN) :: matrix
      REAL(KIND=dp) :: res

      CHARACTER(len=*), PARAMETER :: routineN = 'cp_fm_latra'

      INTEGER :: handle, ncols, nrows
      REAL(KIND=sp) :: res_sp
      REAL(KIND=dp), DIMENSION(:, :), POINTER :: aa
      REAL(KIND=sp), DIMENSION(:, :), POINTER :: aa_sp
#if defined(__SCALAPACK)
      INTEGER, DIMENSION(9) :: desca
#else
      INTEGER :: idiag
#endif

      CALL timeset(routineN, handle)

      nrows = matrix%matrix_struct%nrow_global
      ncols = matrix%matrix_struct%ncol_global
      CPASSERT(nrows .EQ. ncols)
      aa => matrix%local_data
      aa_sp => matrix%local_data_sp

#if defined(__SCALAPACK)
      desca = matrix%matrix_struct%descriptor
      IF (.NOT. matrix%use_sp) THEN
         res = pdlatra(nrows, aa, 1, 1, desca)
      ELSE
         res_sp = pslatra(nrows, aa_sp, 1, 1, desca)
         res = REAL(res_sp, KIND=dp)
      END IF
#else
      IF (.NOT. matrix%use_sp) THEN
         res = 0.0_dp
         DO idiag = 1, nrows
            res = res + aa(idiag, idiag)
         END DO
      ELSE
         ! accumulate in single precision, convert only the final value
         res_sp = 0.0_sp
         DO idiag = 1, nrows
            res_sp = res_sp + aa_sp(idiag, idiag)
         END DO
         res = REAL(res_sp, KIND=dp)
      END IF
#endif

      CALL timestop(handle)

   END FUNCTION cp_fm_latra
    2251             : 
    ! **************************************************************************************************
!> \brief compute a QR factorization with column pivoting of a M-by-N distributed matrix
!>        sub( A ) = A(IA:IA+M-1,JA:JA+N-1)
!> \param matrix    : input M-by-N distributed matrix sub( A ) which is to be factored;
!>                    overwritten in place by the factorization routine
!> \param tau       : scalar factors TAU of the elementary reflectors. TAU is tied to the distributed matrix A
!> \param nrow      : number of rows M of sub( A )
!> \param ncol      : number of columns N of sub( A )
!> \param first_row : global row index IA where sub( A ) starts (must be 1 without ScaLAPACK)
!> \param first_col : global column index JA where sub( A ) starts (must be 1 without ScaLAPACK)
!> \note  Uses pdgeqpf with ScaLAPACK and dgeqp3 otherwise. The column permutation produced
!>        by the factorization is computed into a local array and is NOT returned.
!> \author MI
! **************************************************************************************************
   SUBROUTINE cp_fm_pdgeqpf(matrix, tau, nrow, ncol, first_row, first_col)

      TYPE(cp_fm_type), INTENT(IN)                       :: matrix
      REAL(KIND=dp), DIMENSION(:), POINTER               :: tau
      INTEGER, INTENT(IN)                                :: nrow, ncol, first_row, first_col

      CHARACTER(len=*), PARAMETER :: routineN = 'cp_fm_pdgeqpf'

      INTEGER                                            :: handle
      INTEGER                                            :: info, lwork
      INTEGER, ALLOCATABLE, DIMENSION(:)                 :: ipiv
      REAL(KIND=dp), DIMENSION(:, :), POINTER            :: a
      REAL(KIND=dp), ALLOCATABLE, DIMENSION(:)           :: work
#if defined(__SCALAPACK)
      INTEGER, DIMENSION(9) :: descc
#else
      INTEGER :: lda
#endif

      CALL timeset(routineN, handle)

      a => matrix%local_data
      ! lwork = -1 makes the first call a workspace query that returns the optimal size in work(1);
      ! keep at least one element so that work(1) is always valid
      lwork = -1
      ALLOCATE (work(MAX(1, 2*nrow)))
      ALLOCATE (ipiv(ncol))
      info = 0

#if defined(__SCALAPACK)
      descc(:) = matrix%matrix_struct%descriptor(:)
      ! Call SCALAPACK routine to get optimal work dimension
      CALL pdgeqpf(nrow, ncol, a, first_row, first_col, descc, ipiv, tau, work, lwork, info)
      CPASSERT(info == 0)
      lwork = INT(work(1))
      DEALLOCATE (work)
      ALLOCATE (work(lwork))
      tau = 0.0_dp
      ipiv = 0

      ! Call SCALAPACK routine to get QR decomposition of CTs
      CALL pdgeqpf(nrow, ncol, a, first_row, first_col, descc, ipiv, tau, work, lwork, info)
#else
      CPASSERT(first_row == 1 .AND. first_col == 1)
      lda = SIZE(a, 1)
      CALL dgeqp3(nrow, ncol, a, lda, ipiv, tau, work, lwork, info)
      CPASSERT(info == 0)
      lwork = INT(work(1))
      DEALLOCATE (work)
      ALLOCATE (work(lwork))
      tau = 0.0_dp
      ! a zero entry marks the column as free to be pivoted by dgeqp3
      ipiv = 0
      CALL dgeqp3(nrow, ncol, a, lda, ipiv, tau, work, lwork, info)
#endif
      CPASSERT(info == 0)

      DEALLOCATE (work)
      DEALLOCATE (ipiv)

      CALL timestop(handle)

   END SUBROUTINE cp_fm_pdgeqpf
    2321             : 
    ! **************************************************************************************************
!> \brief forms the orthogonal matrix Q from elementary reflectors stored in a distributed
!>        matrix, as produced by a QR factorization
!> \param matrix    : on entry, column j holds the vector defining the j-th elementary reflector;
!>                    on exit it holds Q
!> \param tau       : scalar factors TAU of the elementary reflectors from the QR factorization
!> \param nrow      : used for all of M, N and K, i.e. the square nrow-by-nrow Q is formed
!>                    from nrow reflectors
!> \param first_row : global row index IA of the sub-matrix (must be 1 without ScaLAPACK)
!> \param first_col : global column index JA of the sub-matrix (must be 1 without ScaLAPACK)
!> \note  Uses pdorgqr with ScaLAPACK and dorgqr otherwise.
!> \author MI
! **************************************************************************************************
   SUBROUTINE cp_fm_pdorgqr(matrix, tau, nrow, first_row, first_col)

      TYPE(cp_fm_type), INTENT(IN)                       :: matrix
      REAL(KIND=dp), DIMENSION(:), POINTER               :: tau
      INTEGER, INTENT(IN)                                :: nrow, first_row, first_col

      CHARACTER(len=*), PARAMETER :: routineN = 'cp_fm_pdorgqr'

      INTEGER                                            :: handle, info, lwork
      REAL(KIND=dp), DIMENSION(:, :), POINTER            :: q
      REAL(KIND=dp), DIMENSION(:), POINTER               :: wrk
#if defined(__SCALAPACK)
      INTEGER, DIMENSION(9) :: descq
#else
      INTEGER :: ldq
#endif

      CALL timeset(routineN, handle)

      q => matrix%local_data
      info = 0
      ! workspace query: with lwork = -1 only the optimal size is written to wrk(1)
      lwork = -1
      ALLOCATE (wrk(2*nrow))

#if defined(__SCALAPACK)
      descq(:) = matrix%matrix_struct%descriptor(:)
      CALL pdorgqr(nrow, nrow, nrow, q, first_row, first_col, descq, tau, wrk, lwork, info)
      CPASSERT(info == 0)
#else
      CPASSERT(first_row == 1 .AND. first_col == 1)
      ldq = SIZE(q, 1)
      CALL dorgqr(nrow, nrow, nrow, q, ldq, tau, wrk, lwork, info)
#endif
      lwork = INT(wrk(1))
      DEALLOCATE (wrk)
      ALLOCATE (wrk(lwork))

      ! actual construction of Q
#if defined(__SCALAPACK)
      CALL pdorgqr(nrow, nrow, nrow, q, first_row, first_col, descq, tau, wrk, lwork, info)
#else
      CALL dorgqr(nrow, nrow, nrow, q, ldq, tau, wrk, lwork, info)
#endif
      CPASSERT(INFO == 0)

      DEALLOCATE (wrk)
      CALL timestop(handle)

   END SUBROUTINE cp_fm_pdorgqr
    2386             : 
    ! **************************************************************************************************
!> \brief Applies a planar rotation defined by cs and sn to the i'th and j'th rows.
!> \param matrix full matrix whose rows irow and jrow are rotated in place
!> \param irow global index of the first row
!> \param jrow global index of the second row
!> \param cs cosine of the rotation angle
!> \param sn sinus of the rotation angle
!> \note  In column-major storage consecutive elements of a row are one leading dimension apart.
!>        That stride is SIZE(local_data,1) for drot and the global row count (M_X) for pdrot,
!>        so non-square matrices are handled correctly as well.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE cp_fm_rot_rows(matrix, irow, jrow, cs, sn)
      TYPE(cp_fm_type), INTENT(IN)             :: matrix
      INTEGER, INTENT(IN)                      :: irow, jrow
      REAL(dp), INTENT(IN)                     :: cs, sn

      CHARACTER(len=*), PARAMETER :: routineN = 'cp_fm_rot_rows'
      INTEGER                                  :: handle, nrow, ncol

#if defined(__SCALAPACK)
      INTEGER                                  :: info, lwork
      INTEGER, DIMENSION(9)                    :: desc
      REAL(dp), DIMENSION(:), ALLOCATABLE      :: work
#else
      INTEGER                                  :: lda
#endif

      CALL timeset(routineN, handle)
      CALL cp_fm_get_info(matrix, nrow_global=nrow, ncol_global=ncol)

#if defined(__SCALAPACK)
      lwork = 2*ncol + 1
      ALLOCATE (work(lwork))
      desc(:) = matrix%matrix_struct%descriptor(:)
      ! a row vector is addressed with the global increment INCX = M_X (number of global rows)
      CALL pdrot(ncol, &
                 matrix%local_data(1, 1), irow, 1, desc, nrow, &
                 matrix%local_data(1, 1), jrow, 1, desc, nrow, &
                 cs, sn, work, lwork, info)
      CPASSERT(info == 0)
      DEALLOCATE (work)
#else
      ! stride between consecutive row elements is the leading dimension of the local array
      lda = SIZE(matrix%local_data, 1)
      CALL drot(ncol, matrix%local_data(irow, 1), lda, matrix%local_data(jrow, 1), lda, cs, sn)
#endif

      CALL timestop(handle)
   END SUBROUTINE cp_fm_rot_rows
    2428             : 
    ! **************************************************************************************************
!> \brief Applies a planar rotation defined by cs and sn to the i'th and j'th columns.
!> \param matrix full matrix whose columns icol and jcol are rotated in place
!> \param icol global index of the first column
!> \param jcol global index of the second column
!> \param cs cosine of the rotation angle
!> \param sn sinus of the rotation angle
!> \note  Columns are contiguous in column-major storage, hence the unit increments.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE cp_fm_rot_cols(matrix, icol, jcol, cs, sn)
      TYPE(cp_fm_type), INTENT(IN)             :: matrix
      INTEGER, INTENT(IN)                      :: icol, jcol
      REAL(dp), INTENT(IN)                     :: cs, sn

      CHARACTER(len=*), PARAMETER :: routineN = 'cp_fm_rot_cols'
      INTEGER                                  :: handle, nrow

#if defined(__SCALAPACK)
      INTEGER                                  :: info, lwork
      INTEGER, DIMENSION(9)                    :: desc
      REAL(dp), DIMENSION(:), ALLOCATABLE      :: work
#endif

      CALL timeset(routineN, handle)
      ! only the column length (number of global rows) is needed
      CALL cp_fm_get_info(matrix, nrow_global=nrow)

#if defined(__SCALAPACK)
      desc(:) = matrix%matrix_struct%descriptor(:)
      lwork = 2*nrow + 1
      ALLOCATE (work(lwork))
      ! columns are addressed as (1, icol) and (1, jcol) with unit increment
      CALL pdrot(nrow, matrix%local_data(1, 1), 1, icol, desc, 1, &
                 matrix%local_data(1, 1), 1, jcol, desc, 1, &
                 cs, sn, work, lwork, info)
      CPASSERT(info == 0)
      DEALLOCATE (work)
#else
      CALL drot(nrow, &
                matrix%local_data(1, icol), 1, &
                matrix%local_data(1, jcol), 1, &
                cs, sn)
#endif

      CALL timestop(handle)
   END SUBROUTINE cp_fm_rot_cols
    2470             : 
    2471             : ! **************************************************************************************************
    2472             : !> \brief Orthonormalizes selected rows and columns of a full matrix, matrix_a
    2473             : !> \param matrix_a ...
    2474             : !> \param B ...
    2475             : !> \param nrows number of rows of matrix_a, optional, defaults to size(matrix_a,1)
    2476             : !> \param ncols number of columns of matrix_a, optional, defaults to size(matrix_a, 2)
    2477             : !> \param start_row starting index of rows, optional, defaults to 1
    2478             : !> \param start_col starting index of columns, optional, defaults to 1
    2479             : !> \param do_norm ...
    2480             : !> \param do_print ...
    2481             : ! **************************************************************************************************
    2482           0 :    SUBROUTINE cp_fm_Gram_Schmidt_orthonorm(matrix_a, B, nrows, ncols, start_row, start_col, &
    2483             :                                            do_norm, do_print)
    2484             : 
    2485             :       TYPE(cp_fm_type), INTENT(IN)                       :: matrix_a
    2486             :       REAL(kind=dp), DIMENSION(:, :), INTENT(OUT)        :: B
    2487             :       INTEGER, INTENT(IN), OPTIONAL                      :: nrows, ncols, start_row, start_col
    2488             :       LOGICAL, INTENT(IN), OPTIONAL                      :: do_norm, do_print
    2489             : 
    2490             :       CHARACTER(len=*), PARAMETER :: routineN = 'cp_fm_Gram_Schmidt_orthonorm'
    2491             : 
    2492             :       INTEGER :: end_col_global, end_col_local, end_row_global, end_row_local, handle, i, j, &
    2493             :                  j_col, ncol_global, ncol_local, nrow_global, nrow_local, start_col_global, &
    2494             :                  start_col_local, start_row_global, start_row_local, this_col, unit_nr
    2495           0 :       INTEGER, DIMENSION(:), POINTER                     :: col_indices, row_indices
    2496             :       LOGICAL                                            :: my_do_norm, my_do_print
    2497             :       REAL(KIND=dp)                                      :: norm
    2498           0 :       REAL(kind=dp), DIMENSION(:, :), POINTER            :: a
    2499             : 
    2500           0 :       CALL timeset(routineN, handle)
    2501             : 
    2502           0 :       my_do_norm = .TRUE.
    2503           0 :       IF (PRESENT(do_norm)) my_do_norm = do_norm
    2504             : 
    2505           0 :       my_do_print = .FALSE.
    2506           0 :       IF (PRESENT(do_print) .AND. (my_do_norm)) my_do_print = do_print
    2507             : 
    2508           0 :       unit_nr = -1
    2509           0 :       IF (my_do_print) THEN
    2510           0 :          unit_nr = cp_logger_get_default_unit_nr()
    2511           0 :          IF (unit_nr < 1) my_do_print = .FALSE.
    2512             :       END IF
    2513             : 
    2514           0 :       IF (SIZE(B) /= 0) THEN
    2515           0 :          IF (PRESENT(nrows)) THEN
    2516           0 :             nrow_global = nrows
    2517             :          ELSE
    2518           0 :             nrow_global = SIZE(B, 1)
    2519             :          END IF
    2520             : 
    2521           0 :          IF (PRESENT(ncols)) THEN
    2522           0 :             ncol_global = ncols
    2523             :          ELSE
    2524           0 :             ncol_global = SIZE(B, 2)
    2525             :          END IF
    2526             : 
    2527           0 :          IF (PRESENT(start_row)) THEN
    2528           0 :             start_row_global = start_row
    2529             :          ELSE
    2530             :             start_row_global = 1
    2531             :          END IF
    2532             : 
    2533           0 :          IF (PRESENT(start_col)) THEN
    2534           0 :             start_col_global = start_col
    2535             :          ELSE
    2536             :             start_col_global = 1
    2537             :          END IF
    2538             : 
    2539           0 :          end_row_global = start_row_global + nrow_global - 1
    2540           0 :          end_col_global = start_col_global + ncol_global - 1
    2541             : 
    2542             :          CALL cp_fm_get_info(matrix=matrix_a, &
    2543             :                              nrow_global=nrow_global, ncol_global=ncol_global, &
    2544             :                              nrow_local=nrow_local, ncol_local=ncol_local, &
    2545           0 :                              row_indices=row_indices, col_indices=col_indices)
    2546           0 :          IF (end_row_global > nrow_global) THEN
    2547             :             end_row_global = nrow_global
    2548             :          END IF
    2549           0 :          IF (end_col_global > ncol_global) THEN
    2550             :             end_col_global = ncol_global
    2551             :          END IF
    2552             : 
    2553             :          ! find out row/column indices of locally stored matrix elements that
    2554             :          ! needs to be copied.
    2555             :          ! Arrays row_indices and col_indices are assumed to be sorted in
    2556             :          ! ascending order
    2557           0 :          DO start_row_local = 1, nrow_local
    2558           0 :             IF (row_indices(start_row_local) >= start_row_global) EXIT
    2559             :          END DO
    2560             : 
    2561           0 :          DO end_row_local = start_row_local, nrow_local
    2562           0 :             IF (row_indices(end_row_local) > end_row_global) EXIT
    2563             :          END DO
    2564           0 :          end_row_local = end_row_local - 1
    2565             : 
    2566           0 :          DO start_col_local = 1, ncol_local
    2567           0 :             IF (col_indices(start_col_local) >= start_col_global) EXIT
    2568             :          END DO
    2569             : 
    2570           0 :          DO end_col_local = start_col_local, ncol_local
    2571           0 :             IF (col_indices(end_col_local) > end_col_global) EXIT
    2572             :          END DO
    2573           0 :          end_col_local = end_col_local - 1
    2574             : 
    2575           0 :          a => matrix_a%local_data
    2576             : 
    2577           0 :          this_col = col_indices(start_col_local) - start_col_global + 1
    2578             : 
    2579           0 :          B(:, this_col) = a(:, start_col_local)
    2580             : 
    2581           0 :          IF (my_do_norm) THEN
    2582           0 :             norm = SQRT(accurate_dot_product(B(:, this_col), B(:, this_col)))
    2583           0 :             B(:, this_col) = B(:, this_col)/norm
    2584           0 :             IF (my_do_print) WRITE (unit_nr, '(I3,F8.3)') this_col, norm
    2585             :          END IF
    2586             : 
    2587           0 :          DO i = start_col_local + 1, end_col_local
    2588           0 :             this_col = col_indices(i) - start_col_global + 1
    2589           0 :             B(:, this_col) = a(:, i)
    2590           0 :             DO j = start_col_local, i - 1
    2591           0 :                j_col = col_indices(j) - start_col_global + 1
    2592             :                B(:, this_col) = B(:, this_col) - &
    2593             :                                 accurate_dot_product(B(:, j_col), B(:, this_col))* &
    2594           0 :                                 B(:, j_col)/accurate_dot_product(B(:, j_col), B(:, j_col))
    2595             :             END DO
    2596             : 
    2597           0 :             IF (my_do_norm) THEN
    2598           0 :                norm = SQRT(accurate_dot_product(B(:, this_col), B(:, this_col)))
    2599           0 :                B(:, this_col) = B(:, this_col)/norm
    2600           0 :                IF (my_do_print) WRITE (unit_nr, '(I3,F8.3)') this_col, norm
    2601             :             END IF
    2602             : 
    2603             :          END DO
    2604           0 :          CALL matrix_a%matrix_struct%para_env%sum(B)
    2605             :       END IF
    2606             : 
    2607           0 :       CALL timestop(handle)
    2608             : 
    2609           0 :    END SUBROUTINE cp_fm_Gram_Schmidt_orthonorm
    2610             : 
    2611             : ! **************************************************************************************************
    2612             : !> \brief Cholesky decomposition
    2613             : !> \param fm_matrix ...
    2614             : !> \param n ...
    2615             : ! **************************************************************************************************
    2616       10089 :    SUBROUTINE cp_fm_potrf(fm_matrix, n)
    2617             :       TYPE(cp_fm_type)                         :: fm_matrix
    2618             :       INTEGER, INTENT(in)                      :: n
    2619             : 
    2620             :       INTEGER                                  :: info
    2621       10089 :       REAL(KIND=dp), DIMENSION(:, :), POINTER  :: a
    2622       10089 :       REAL(KIND=sp), DIMENSION(:, :), POINTER  :: a_sp
    2623             : #if defined(__SCALAPACK)
    2624             :       INTEGER, DIMENSION(9)                    :: desca
    2625             : #endif
    2626             : 
    2627       10089 :       a => fm_matrix%local_data
    2628       10089 :       a_sp => fm_matrix%local_data_sp
    2629             : #if defined(__SCALAPACK)
    2630      100890 :       desca(:) = fm_matrix%matrix_struct%descriptor(:)
    2631       10089 :       IF (fm_matrix%use_sp) THEN
    2632           0 :          CALL pspotrf('U', n, a_sp(1, 1), 1, 1, desca, info)
    2633             :       ELSE
    2634       10089 :          CALL pdpotrf('U', n, a(1, 1), 1, 1, desca, info)
    2635             :       END IF
    2636             : #else
    2637             :       IF (fm_matrix%use_sp) THEN
    2638             :          CALL spotrf('U', n, a_sp(1, 1), SIZE(a_sp, 1), info)
    2639             :       ELSE
    2640             :          CALL dpotrf('U', n, a(1, 1), SIZE(a, 1), info)
    2641             :       END IF
    2642             : #endif
    2643       10089 :       IF (info /= 0) &
    2644           0 :          CPABORT("Cholesky decomposition failed. Matrix ill conditioned ?")
    2645             : 
    2646       10089 :    END SUBROUTINE cp_fm_potrf
    2647             : 
    2648             : ! **************************************************************************************************
    2649             : !> \brief Invert trianguar matrix
    2650             : !> \param fm_matrix the matrix to invert (must be an upper triangular matrix)
    2651             : !> \param n size of the matrix to invert
    2652             : ! **************************************************************************************************
    2653        9313 :    SUBROUTINE cp_fm_potri(fm_matrix, n)
    2654             :       TYPE(cp_fm_type)                          :: fm_matrix
    2655             :       INTEGER, INTENT(in)                       :: n
    2656             : 
    2657        9313 :       REAL(KIND=dp), DIMENSION(:, :), POINTER   :: a
    2658        9313 :       REAL(KIND=sp), DIMENSION(:, :), POINTER   :: a_sp
    2659             :       INTEGER                                   :: info
    2660             : #if defined(__SCALAPACK)
    2661             :       INTEGER, DIMENSION(9)                     :: desca
    2662             : #endif
    2663             : 
    2664        9313 :       a => fm_matrix%local_data
    2665        9313 :       a_sp => fm_matrix%local_data_sp
    2666             : #if defined(__SCALAPACK)
    2667       93130 :       desca(:) = fm_matrix%matrix_struct%descriptor(:)
    2668        9313 :       IF (fm_matrix%use_sp) THEN
    2669           0 :          CALL pspotri('U', n, a_sp(1, 1), 1, 1, desca, info)
    2670             :       ELSE
    2671        9313 :          CALL pdpotri('U', n, a(1, 1), 1, 1, desca, info)
    2672             :       END IF
    2673             : #else
    2674             :       IF (fm_matrix%use_sp) THEN
    2675             :          CALL spotri('U', n, a_sp(1, 1), SIZE(a_sp, 1), info)
    2676             :       ELSE
    2677             :          CALL dpotri('U', n, a(1, 1), SIZE(a, 1), info)
    2678             :       END IF
    2679             : #endif
    2680        9313 :       CPASSERT(info == 0)
    2681        9313 :    END SUBROUTINE cp_fm_potri
    2682             : 
    2683             : ! **************************************************************************************************
    2684             : !> \brief ...
    2685             : !> \param fm_matrix ...
    2686             : !> \param neig ...
    2687             : !> \param fm_matrixb ...
    2688             : !> \param fm_matrixout ...
    2689             : !> \param op ...
    2690             : !> \param pos ...
    2691             : !> \param transa ...
    2692             : ! **************************************************************************************************
    2693        1184 :    SUBROUTINE cp_fm_cholesky_restore(fm_matrix, neig, fm_matrixb, fm_matrixout, op, pos, transa)
    2694             :       TYPE(cp_fm_type)                               :: fm_matrix
    2695             :       TYPE(cp_fm_type)                               :: fm_matrixb
    2696             :       TYPE(cp_fm_type)                               :: fm_matrixout
    2697             :       INTEGER, INTENT(IN)                            :: neig
    2698             :       CHARACTER(LEN=*), INTENT(IN)                   :: op
    2699             :       CHARACTER(LEN=*), INTENT(IN)                   :: pos
    2700             :       CHARACTER(LEN=*), INTENT(IN)                   :: transa
    2701             : 
    2702        1184 :       REAL(KIND=dp), DIMENSION(:, :), POINTER        :: a, b, outm
    2703        1184 :       REAL(KIND=sp), DIMENSION(:, :), POINTER        :: a_sp, b_sp, outm_sp
    2704             :       INTEGER                                        :: n, itype
    2705             :       REAL(KIND=dp)                                  :: alpha
    2706             : #if defined(__SCALAPACK)
    2707             :       INTEGER                                        :: i
    2708             :       INTEGER, DIMENSION(9)                          :: desca, descb, descout
    2709             : #endif
    2710             : 
    2711             :       ! notice b is the cholesky guy
    2712        1184 :       a => fm_matrix%local_data
    2713        1184 :       b => fm_matrixb%local_data
    2714        1184 :       outm => fm_matrixout%local_data
    2715        1184 :       a_sp => fm_matrix%local_data_sp
    2716        1184 :       b_sp => fm_matrixb%local_data_sp
    2717        1184 :       outm_sp => fm_matrixout%local_data_sp
    2718             : 
    2719        1184 :       n = fm_matrix%matrix_struct%nrow_global
    2720        1184 :       itype = 1
    2721             : 
    2722             : #if defined(__SCALAPACK)
    2723       11840 :       desca(:) = fm_matrix%matrix_struct%descriptor(:)
    2724       11840 :       descb(:) = fm_matrixb%matrix_struct%descriptor(:)
    2725       11840 :       descout(:) = fm_matrixout%matrix_struct%descriptor(:)
    2726        1184 :       alpha = 1.0_dp
    2727        5316 :       DO i = 1, neig
    2728        5316 :          IF (fm_matrix%use_sp) THEN
    2729           0 :             CALL pscopy(n, a_sp(1, 1), 1, i, desca, 1, outm_sp(1, 1), 1, i, descout, 1)
    2730             :          ELSE
    2731        4132 :             CALL pdcopy(n, a(1, 1), 1, i, desca, 1, outm(1, 1), 1, i, descout, 1)
    2732             :          END IF
    2733             :       END DO
    2734        1184 :       IF (op .EQ. "SOLVE") THEN
    2735        1184 :          IF (fm_matrix%use_sp) THEN
    2736             :             CALL pstrsm(pos, 'U', transa, 'N', n, neig, REAL(alpha, sp), b_sp(1, 1), 1, 1, descb, &
    2737           0 :                         outm_sp(1, 1), 1, 1, descout)
    2738             :          ELSE
    2739        1184 :             CALL pdtrsm(pos, 'U', transa, 'N', n, neig, alpha, b(1, 1), 1, 1, descb, outm(1, 1), 1, 1, descout)
    2740             :          END IF
    2741             :       ELSE
    2742           0 :          IF (fm_matrix%use_sp) THEN
    2743             :             CALL pstrmm(pos, 'U', transa, 'N', n, neig, REAL(alpha, sp), b_sp(1, 1), 1, 1, descb, &
    2744           0 :                         outm_sp(1, 1), 1, 1, descout)
    2745             :          ELSE
    2746           0 :             CALL pdtrmm(pos, 'U', transa, 'N', n, neig, alpha, b(1, 1), 1, 1, descb, outm(1, 1), 1, 1, descout)
    2747             :          END IF
    2748             :       END IF
    2749             : #else
    2750             :       alpha = 1.0_dp
    2751             :       IF (fm_matrix%use_sp) THEN
    2752             :          CALL scopy(neig*n, a_sp(1, 1), 1, outm_sp(1, 1), 1)
    2753             :       ELSE
    2754             :          CALL dcopy(neig*n, a(1, 1), 1, outm(1, 1), 1)
    2755             :       END IF
    2756             :       IF (op .EQ. "SOLVE") THEN
    2757             :          IF (fm_matrix%use_sp) THEN
    2758             :             CALL strsm(pos, 'U', transa, 'N', n, neig, REAL(alpha, sp), b_sp(1, 1), SIZE(b_sp, 1), outm_sp(1, 1), n)
    2759             :          ELSE
    2760             :             CALL dtrsm(pos, 'U', transa, 'N', n, neig, alpha, b(1, 1), SIZE(b, 1), outm(1, 1), n)
    2761             :          END IF
    2762             :       ELSE
    2763             :          IF (fm_matrix%use_sp) THEN
    2764             :             CALL strmm(pos, 'U', transa, 'N', n, neig, REAL(alpha, sp), b_sp(1, 1), n, outm_sp(1, 1), n)
    2765             :          ELSE
    2766             :             CALL dtrmm(pos, 'U', transa, 'N', n, neig, alpha, b(1, 1), n, outm(1, 1), n)
    2767             :          END IF
    2768             :       END IF
    2769             : #endif
    2770             : 
    2771        1184 :    END SUBROUTINE cp_fm_cholesky_restore
    2772             : 
    2773             : END MODULE cp_fm_basic_linalg

Generated by: LCOV version 1.15