LCOV - code coverage report
Current view: top level - src/fm - cp_cfm_basic_linalg.F (source / functions) Hit Total Coverage
Test: CP2K Regtests (git:1f285aa) Lines: 324 408 79.4 %
Date: 2024-04-23 06:49:27 Functions: 17 20 85.0 %

          Line data    Source code
       1             : !--------------------------------------------------------------------------------------------------!
       2             : !   CP2K: A general program to perform molecular dynamics simulations                              !
       3             : !   Copyright 2000-2024 CP2K developers group <https://cp2k.org>                                   !
       4             : !                                                                                                  !
       5             : !   SPDX-License-Identifier: GPL-2.0-or-later                                                      !
       6             : !--------------------------------------------------------------------------------------------------!
       7             : 
       8             : ! **************************************************************************************************
       9             : !> \brief Basic linear algebra operations for complex full matrices.
      10             : !> \note
      11             : !>      - not all functionality implemented
      12             : !> \par History
      13             : !>      Nearly literal copy of Fawzi's routines
      14             : !> \author Joost VandeVondele
      15             : ! **************************************************************************************************
      16             : MODULE cp_cfm_basic_linalg
      17             :    USE cp_blacs_env,                    ONLY: cp_blacs_env_type
      18             :    USE cp_cfm_types,                    ONLY: cp_cfm_get_info,&
      19             :                                               cp_cfm_type
      20             :    USE cp_fm_struct,                    ONLY: cp_fm_struct_equivalent
      21             :    USE cp_fm_types,                     ONLY: cp_fm_type
      22             :    USE kahan_sum,                       ONLY: accurate_dot_product
      23             :    USE kinds,                           ONLY: dp
      24             :    USE mathconstants,                   ONLY: z_one,&
      25             :                                               z_zero
      26             :    USE message_passing,                 ONLY: mp_comm_type
      27             : #include "../base/base_uses.f90"
      28             : 
      29             :    IMPLICIT NONE
      30             :    PRIVATE
      31             :    ! Module-wide named constants (both are private to this module).
      32             :    LOGICAL, PRIVATE, PARAMETER :: debug_this_module = .TRUE.
      33             :    CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'cp_cfm_basic_linalg'
      34             :    ! Entities exported by this module; anything not listed here stays private (default PRIVATE above).
      35             :    PUBLIC :: cp_cfm_cholesky_decompose, &
      36             :              cp_cfm_column_scale, &
      37             :              cp_cfm_gemm, &
      38             :              cp_cfm_lu_decompose, &
      39             :              cp_cfm_lu_invert, &
      40             :              cp_cfm_norm, &
      41             :              cp_cfm_scale, &
      42             :              cp_cfm_scale_and_add, &
      43             :              cp_cfm_scale_and_add_fm, &
      44             :              cp_cfm_schur_product, &
      45             :              cp_cfm_solve, &
      46             :              cp_cfm_trace, &
      47             :              cp_cfm_transpose, &
      48             :              cp_cfm_triangular_invert, &
      49             :              cp_cfm_triangular_multiply, &
      50             :              cp_cfm_rot_rows, &
      51             :              cp_cfm_rot_cols, &
      52             :              cp_cfm_cholesky_invert
      53             :    ! External functions with implicit interface: zlange (LAPACK) and pzlange (ScaLAPACK) matrix norms.
      54             :    REAL(kind=dp), EXTERNAL :: zlange, pzlange
      55             :    ! Generic interface: cp_cfm_scale resolves to cp_cfm_dscale or cp_cfm_zscale by argument types.
      56             :    INTERFACE cp_cfm_scale
      57             :       MODULE PROCEDURE cp_cfm_dscale, cp_cfm_zscale
      58             :    END INTERFACE cp_cfm_scale
      59             : 
      60             : ! **************************************************************************************************
      61             : 
      62             : CONTAINS
      63             : 
      64             : ! **************************************************************************************************
      65             : !> \brief Computes the element-wise (Schur) product of two matrices: C = A \circ B .
      66             : !> \param matrix_a the first input matrix; its structure also provides the local loop bounds
      67             : !> \param matrix_b the second input matrix
      68             : !> \param matrix_c matrix to store the result (its local data is overwritten through a pointer)
      69             : ! **************************************************************************************************
      70         204 :    SUBROUTINE cp_cfm_schur_product(matrix_a, matrix_b, matrix_c)
      71             : 
      72             :       TYPE(cp_cfm_type), INTENT(IN)                      :: matrix_a, matrix_b, matrix_c
      73             : 
      74             :       CHARACTER(len=*), PARAMETER :: routineN = 'cp_cfm_schur_product'
      75             : 
      76         204 :       COMPLEX(kind=dp), DIMENSION(:, :), POINTER         :: a, b, c
      77             :       INTEGER                                            :: handle, icol_local, irow_local, mypcol, &
      78             :                                                             myprow, ncol_local, nrow_local
      79             : 
      80         204 :       CALL timeset(routineN, handle)
      81             :       ! row/column position (mepos) of this process, read from the context of matrix_a
      82         204 :       myprow = matrix_a%matrix_struct%context%mepos(1)
      83         204 :       mypcol = matrix_a%matrix_struct%context%mepos(2)
      84             :       ! work directly on the local data of the three matrices (no copies are made)
      85         204 :       a => matrix_a%local_data
      86         204 :       b => matrix_b%local_data
      87         204 :       c => matrix_c%local_data
      88             :       ! local extents are taken from matrix_a only; b and c are assumed to match (not checked here)
      89         204 :       nrow_local = matrix_a%matrix_struct%nrow_locals(myprow)
      90         204 :       ncol_local = matrix_a%matrix_struct%ncol_locals(mypcol)
      91             :       ! the inner loop runs over rows, i.e. over the first (fastest-varying) array index
      92        1020 :       DO icol_local = 1, ncol_local
      93        2652 :          DO irow_local = 1, nrow_local
      94        2448 :             c(irow_local, icol_local) = a(irow_local, icol_local)*b(irow_local, icol_local)
      95             :          END DO
      96             :       END DO
      97             : 
      98         204 :       CALL timestop(handle)
      99             : 
     100         204 :    END SUBROUTINE cp_cfm_schur_product
     101             : 
     102             : ! **************************************************************************************************
     103             : !> \brief Computes the element-wise (Schur) product of two matrices: C = A \circ conjg(B) .
     104             : !> \param matrix_a the first input matrix; its structure also provides the local loop bounds
     105             : !> \param matrix_b the second input matrix; its elements enter complex-conjugated
     106             : !> \param matrix_c matrix to store the result (its local data is overwritten through a pointer)
     107             : ! **************************************************************************************************
     108           0 :    SUBROUTINE cp_cfm_schur_product_cc(matrix_a, matrix_b, matrix_c)
     109             :       ! Not listed in the PUBLIC statement of the module, so this routine is private to the module.
     110             :       TYPE(cp_cfm_type), INTENT(IN)                      :: matrix_a, matrix_b, matrix_c
     111             : 
     112             :       CHARACTER(len=*), PARAMETER :: routineN = 'cp_cfm_schur_product_cc'
     113             : 
     114           0 :       COMPLEX(kind=dp), DIMENSION(:, :), POINTER         :: a, b, c
     115             :       INTEGER                                            :: handle, icol_local, irow_local, mypcol, &
     116             :                                                             myprow, ncol_local, nrow_local
     117             : 
     118           0 :       CALL timeset(routineN, handle)
     119             :       ! row/column position (mepos) of this process, read from the context of matrix_a
     120           0 :       myprow = matrix_a%matrix_struct%context%mepos(1)
     121           0 :       mypcol = matrix_a%matrix_struct%context%mepos(2)
     122             :       ! work directly on the local data of the three matrices (no copies are made)
     123           0 :       a => matrix_a%local_data
     124           0 :       b => matrix_b%local_data
     125           0 :       c => matrix_c%local_data
     126             :       ! local extents are taken from matrix_a only; b and c are assumed to match (not checked here)
     127           0 :       nrow_local = matrix_a%matrix_struct%nrow_locals(myprow)
     128           0 :       ncol_local = matrix_a%matrix_struct%ncol_locals(mypcol)
     129             :       ! the inner loop runs over rows, i.e. over the first (fastest-varying) array index
     130           0 :       DO icol_local = 1, ncol_local
     131           0 :          DO irow_local = 1, nrow_local
     132           0 :             c(irow_local, icol_local) = a(irow_local, icol_local)*CONJG(b(irow_local, icol_local))
     133             :          END DO
     134             :       END DO
     135             : 
     136           0 :       CALL timestop(handle)
     137             : 
     138           0 :    END SUBROUTINE cp_cfm_schur_product_cc
     139             : 
     140             : ! **************************************************************************************************
     141             : !> \brief Scale and add two BLACS matrices (a = alpha*a + beta*b).
     142             : !> \param alpha scale factor applied to matrix_a
     143             : !> \param matrix_a matrix that is updated in place
     144             : !> \param beta (optional) scale factor applied to matrix_b; treated as zero when absent
     145             : !> \param matrix_b (optional) matrix to be added; asserted to be present when beta is non-zero
     146             : !> \date    11.06.2001
     147             : !> \author  Matthias Krack
     148             : !> \version 1.0
     149             : !> \note
     150             : !>    Use explicit loops to avoid temporary arrays, as a compiler reasonably assumes that arrays
     151             : !>    matrix_a%local_data and matrix_b%local_data may overlap (they are referenced by pointers).
     152             : !>    In the general case (alpha*a + beta*b) explicit loops appear to be up to two times more
     153             : !>    efficient than the equivalent BLAS calls (zscal, zaxpy). This is because using BLAS calls
     154             : !>    implies two passes through each array, so data need to be retrieved twice if arrays are
     155             : !>    large enough to not fit into the processor's cache.
     156             : ! **************************************************************************************************
     157      187202 :    SUBROUTINE cp_cfm_scale_and_add(alpha, matrix_a, beta, matrix_b)
     158             :       COMPLEX(kind=dp), INTENT(in)                       :: alpha
     159             :       TYPE(cp_cfm_type), INTENT(IN)                      :: matrix_a
     160             :       COMPLEX(kind=dp), INTENT(in), OPTIONAL             :: beta
     161             :       TYPE(cp_cfm_type), INTENT(IN), OPTIONAL            :: matrix_b
     162             : 
     163             :       CHARACTER(len=*), PARAMETER :: routineN = 'cp_cfm_scale_and_add'
     164             : 
     165             :       COMPLEX(kind=dp)                                   :: my_beta
     166      187202 :       COMPLEX(kind=dp), DIMENSION(:, :), POINTER         :: a, b
     167             :       INTEGER                                            :: handle, icol_local, irow_local, mypcol, &
     168             :                                                             myprow, ncol_local, nrow_local
     169             : 
     170      187202 :       CALL timeset(routineN, handle)
     171             :       ! an absent beta counts as zero, so the call then reduces to the scaling a = alpha*a
     172      187202 :       my_beta = z_zero
     173      187202 :       IF (PRESENT(beta)) my_beta = beta
     174      187202 :       NULLIFY (a, b)
     175             : 
     176             :       ! to do: use zscal, zcopy, zaxpy (but see the performance note in the routine header)
     177      187202 :       myprow = matrix_a%matrix_struct%context%mepos(1)
     178      187202 :       mypcol = matrix_a%matrix_struct%context%mepos(2)
     179             :       ! local extents used as bounds of the explicit loops further down
     180      187202 :       nrow_local = matrix_a%matrix_struct%nrow_locals(myprow)
     181      187202 :       ncol_local = matrix_a%matrix_struct%ncol_locals(mypcol)
     182             : 
     183      187202 :       a => matrix_a%local_data
     184             :       ! beta == 0: only matrix_a is scaled and matrix_b is never referenced
     185      187202 :       IF (my_beta == z_zero) THEN
     186             :          ! alpha == 1 leaves a untouched, so the routine returns early in that case
     187        3778 :          IF (alpha == z_zero) THEN
     188           0 :             a(:, :) = z_zero
     189        3778 :          ELSE IF (alpha == z_one) THEN
     190        3778 :             CALL timestop(handle)
     191        3778 :             RETURN
     192             :          ELSE
     193           0 :             a(:, :) = alpha*a(:, :)
     194             :          END IF
     195             : 
     196             :       ELSE
     197      183424 :          CPASSERT(PRESENT(matrix_b))
     198      183424 :          IF (matrix_a%matrix_struct%context /= matrix_b%matrix_struct%context) &
     199           0 :             CPABORT("matrixes must be in the same blacs context")
     200             :          ! the element-wise update below requires structurally equivalent matrices
     201      183424 :          IF (cp_fm_struct_equivalent(matrix_a%matrix_struct, &
     202             :                                      matrix_b%matrix_struct)) THEN
     203             : 
     204      183424 :             b => matrix_b%local_data
     205             :             ! factors equal to 0 or 1 are special-cased to skip redundant multiplications
     206      183424 :             IF (alpha == z_zero) THEN
     207           0 :                IF (my_beta == z_one) THEN
     208             :                   !a(:, :) = b(:, :)
     209           0 :                   DO icol_local = 1, ncol_local
     210           0 :                      DO irow_local = 1, nrow_local
     211           0 :                         a(irow_local, icol_local) = b(irow_local, icol_local)
     212             :                      END DO
     213             :                   END DO
     214             :                ELSE
     215             :                   !a(:, :) = my_beta*b(:, :)
     216           0 :                   DO icol_local = 1, ncol_local
     217           0 :                      DO irow_local = 1, nrow_local
     218           0 :                         a(irow_local, icol_local) = my_beta*b(irow_local, icol_local)
     219             :                      END DO
     220             :                   END DO
     221             :                END IF
     222      183424 :             ELSE IF (alpha == z_one) THEN
     223      182040 :                IF (my_beta == z_one) THEN
     224             :                   !a(:, :) = a(:, :)+b(:, :)
     225     1895469 :                   DO icol_local = 1, ncol_local
     226    28873444 :                      DO irow_local = 1, nrow_local
     227    28728311 :                         a(irow_local, icol_local) = a(irow_local, icol_local) + b(irow_local, icol_local)
     228             :                      END DO
     229             :                   END DO
     230             :                ELSE
     231             :                   !a(:, :) = a(:, :)+my_beta*b(:, :)
     232      695213 :                   DO icol_local = 1, ncol_local
     233    11039662 :                      DO irow_local = 1, nrow_local
     234    11002755 :                         a(irow_local, icol_local) = a(irow_local, icol_local) + my_beta*b(irow_local, icol_local)
     235             :                      END DO
     236             :                   END DO
     237             :                END IF
     238             :             ELSE
     239             :                !a(:, :) = alpha*a(:, :)+my_beta*b(:, :)
     240       33368 :                DO icol_local = 1, ncol_local
     241      781272 :                   DO irow_local = 1, nrow_local
     242      779888 :                      a(irow_local, icol_local) = alpha*a(irow_local, icol_local) + my_beta*b(irow_local, icol_local)
     243             :                   END DO
     244             :                END DO
     245             :             END IF
     246             :          ELSE
     247             : #if defined(__SCALAPACK)
     248           0 :             CPABORT("to do (pdscal,pdcopy,pdaxpy)")
     249             : #else
     250             :             CPABORT("")
     251             : #endif
     252             :          END IF
     253             :       END IF
     254      183424 :       CALL timestop(handle)
     255      187202 :    END SUBROUTINE cp_cfm_scale_and_add
     256             : 
     257             : ! **************************************************************************************************
     258             : !> \brief Scale and add two BLACS matrices (a = alpha*a + beta*b),
     259             : !>        where b is a real matrix (adapted from cp_cfm_scale_and_add).
     260             : !> \param alpha scale factor applied to matrix_a
     261             : !> \param matrix_a complex matrix that is updated in place
     262             : !> \param beta scale factor applied to matrix_b
     263             : !> \param matrix_b real matrix to be added; not referenced when beta is zero
     264             : !> \date    01.08.2014
     265             : !> \author  JGH
     266             : !> \version 1.0
     267             : ! **************************************************************************************************
     268      125754 :    SUBROUTINE cp_cfm_scale_and_add_fm(alpha, matrix_a, beta, matrix_b)
     269             :       COMPLEX(kind=dp), INTENT(in)                       :: alpha
     270             :       TYPE(cp_cfm_type), INTENT(IN)                      :: matrix_a
     271             :       COMPLEX(kind=dp), INTENT(in)                       :: beta
     272             :       TYPE(cp_fm_type), INTENT(IN)                       :: matrix_b
     273             : 
     274             :       CHARACTER(len=*), PARAMETER :: routineN = 'cp_cfm_scale_and_add_fm'
     275             : 
     276      125754 :       COMPLEX(kind=dp), DIMENSION(:, :), POINTER         :: a
     277             :       INTEGER                                            :: handle, icol_local, irow_local, mypcol, &
     278             :                                                             myprow, ncol_local, nrow_local
     279      125754 :       REAL(kind=dp), DIMENSION(:, :), POINTER            :: b
     280             : 
     281      125754 :       CALL timeset(routineN, handle)
     282             : 
     283      125754 :       NULLIFY (a, b)
     284             :       ! row/column position (mepos) of this process, read from the context of matrix_a
     285      125754 :       myprow = matrix_a%matrix_struct%context%mepos(1)
     286      125754 :       mypcol = matrix_a%matrix_struct%context%mepos(2)
     287             :       ! local extents used as bounds of the explicit loops further down
     288      125754 :       nrow_local = matrix_a%matrix_struct%nrow_locals(myprow)
     289      125754 :       ncol_local = matrix_a%matrix_struct%ncol_locals(mypcol)
     290             : 
     291      125754 :       a => matrix_a%local_data
     292             :       ! beta == 0: only matrix_a is scaled and matrix_b is never referenced
     293      125754 :       IF (beta == z_zero) THEN
     294             :          ! alpha == 1 leaves a untouched, so the routine returns early in that case
     295           0 :          IF (alpha == z_zero) THEN
     296           0 :             a(:, :) = z_zero
     297           0 :          ELSE IF (alpha == z_one) THEN
     298           0 :             CALL timestop(handle)
     299           0 :             RETURN
     300             :          ELSE
     301           0 :             a(:, :) = alpha*a(:, :)
     302             :          END IF
     303             : 
     304             :       ELSE
     305      125754 :          IF (matrix_a%matrix_struct%context /= matrix_b%matrix_struct%context) &
     306           0 :             CPABORT("matrices must be in the same blacs context")
     307             :          ! the element-wise update below requires structurally equivalent matrices
     308      125754 :          IF (cp_fm_struct_equivalent(matrix_a%matrix_struct, &
     309             :                                      matrix_b%matrix_struct)) THEN
     310             :             ! b is real-valued; its elements are combined with complex a and beta in the loops below
     311      125754 :             b => matrix_b%local_data
     312             :             ! factors equal to 0 or 1 are special-cased to skip redundant multiplications
     313      125754 :             IF (alpha == z_zero) THEN
     314       46162 :                IF (beta == z_one) THEN
     315             :                   !a(:, :) = b(:, :)
     316     1563126 :                   DO icol_local = 1, ncol_local
     317    57242652 :                      DO irow_local = 1, nrow_local
     318    57196554 :                         a(irow_local, icol_local) = b(irow_local, icol_local)
     319             :                      END DO
     320             :                   END DO
     321             :                ELSE
     322             :                   !a(:, :) = beta*b(:, :)
     323        2512 :                   DO icol_local = 1, ncol_local
     324       55648 :                      DO irow_local = 1, nrow_local
     325       55584 :                         a(irow_local, icol_local) = beta*b(irow_local, icol_local)
     326             :                      END DO
     327             :                   END DO
     328             :                END IF
     329       79592 :             ELSE IF (alpha == z_one) THEN
     330       49071 :                IF (beta == z_one) THEN
     331             :                   !a(:, :) = a(:, :)+b(:, :)
     332       58421 :                   DO icol_local = 1, ncol_local
     333     1358021 :                      DO irow_local = 1, nrow_local
     334     1355928 :                         a(irow_local, icol_local) = a(irow_local, icol_local) + b(irow_local, icol_local)
     335             :                      END DO
     336             :                   END DO
     337             :                ELSE
     338             :                   !a(:, :) = a(:, :)+beta*b(:, :)
     339     1597990 :                   DO icol_local = 1, ncol_local
     340    58018828 :                      DO irow_local = 1, nrow_local
     341    57971850 :                         a(irow_local, icol_local) = a(irow_local, icol_local) + beta*b(irow_local, icol_local)
     342             :                      END DO
     343             :                   END DO
     344             :                END IF
     345             :             ELSE
     346             :                !a(:, :) = alpha*a(:, :)+beta*b(:, :)
     347      346801 :                DO icol_local = 1, ncol_local
     348     5761521 :                   DO irow_local = 1, nrow_local
     349     5731000 :                      a(irow_local, icol_local) = alpha*a(irow_local, icol_local) + beta*b(irow_local, icol_local)
     350             :                   END DO
     351             :                END DO
     352             :             END IF
     353             :          ELSE
     354             : #if defined(__SCALAPACK)
     355           0 :             CPABORT("to do (pdscal,pdcopy,pdaxpy)")
     356             : #else
     357             :             CPABORT("")
     358             : #endif
     359             :          END IF
     360             :       END IF
     361      125754 :       CALL timestop(handle)
     362      125754 :    END SUBROUTINE cp_cfm_scale_and_add_fm
     363             : 
     364             : ! **************************************************************************************************
     365             : !> \brief Computes LU decomposition of a given matrix.
     366             : !> \param matrix_a     full matrix (treated as square); overwritten by the factorization
     367             : !> \param determinant  determinant: signed product of the diagonal of the factorized matrix
     368             : !> \date    11.06.2001
     369             : !> \author  Matthias Krack
     370             : !> \version 1.0
     371             : !> \note
     372             : !>    The actual purpose right now is to efficiently compute the determinant of a given matrix.
     373             : !>    The original content of the matrix is destroyed. The status flag info is not checked.
     374             : ! **************************************************************************************************
     375        1086 :    SUBROUTINE cp_cfm_lu_decompose(matrix_a, determinant)
     376             :       TYPE(cp_cfm_type), INTENT(IN)                   :: matrix_a
     377             :       COMPLEX(kind=dp), INTENT(out)                      :: determinant
     378             : 
     379             :       CHARACTER(len=*), PARAMETER :: routineN = 'cp_cfm_lu_decompose'
     380             : 
     381        1086 :       COMPLEX(kind=dp), DIMENSION(:, :), POINTER         :: a
     382             :       INTEGER                                            :: counter, handle, info, irow, nrow_global
     383        1086 :       INTEGER, ALLOCATABLE, DIMENSION(:)                 :: ipivot
     384             : 
     385             : #if defined(__SCALAPACK)
     386             :       INTEGER                                            :: icol, ncol_local, nrow_local
     387             :       INTEGER, DIMENSION(9)                              :: desca
     388        1086 :       INTEGER, DIMENSION(:), POINTER                     :: col_indices, row_indices
     389             : #else
     390             :       INTEGER                                            :: lda
     391             : #endif
     392             : 
     393        1086 :       CALL timeset(routineN, handle)
     394             :       ! only nrow_global is queried: it is used as both dimensions of the matrix to factorize
     395        1086 :       nrow_global = matrix_a%matrix_struct%nrow_global
     396        1086 :       a => matrix_a%local_data
     397             :       ! pivot indices filled in by the factorization routine
     398        3258 :       ALLOCATE (ipivot(nrow_global))
     399             : #if defined(__SCALAPACK)
     400             :       CALL cp_cfm_get_info(matrix_a, nrow_local=nrow_local, ncol_local=ncol_local, &
     401        1086 :                            row_indices=row_indices, col_indices=col_indices)
     402             :       ! in-place LU factorization with row pivoting (ScaLAPACK); a is overwritten by the factors
     403       10860 :       desca(:) = matrix_a%matrix_struct%descriptor(:)
     404        1086 :       CALL pzgetrf(nrow_global, nrow_global, a(1, 1), 1, 1, desca, ipivot, info)
     405             :       ! count the locally stored rows whose pivot index differs from their own row index
     406        1086 :       counter = 0
     407        3789 :       DO irow = 1, nrow_local
     408        3789 :          IF (ipivot(irow) .NE. row_indices(irow)) counter = counter + 1
     409             :       END DO
     410             :       ! an odd number of counted row interchanges flips the sign of the determinant
     411        1086 :       IF (MOD(counter, 2) == 0) THEN
     412        1073 :          determinant = z_one
     413             :       ELSE
     414          13 :          determinant = -z_one
     415             :       END IF
     416             :       ! NOTE(review): sign uses the local swap parity of each process -- confirm for general process grids
     417             :       ! compute product of the locally stored diagonal elements (entries with equal row and column
     418             :       irow = 1
     419             :       icol = 1
     420        5013 :       DO WHILE (irow <= nrow_local .AND. icol <= ncol_local)
     421        5013 :          IF (row_indices(irow) < col_indices(icol)) THEN
     422           0 :             irow = irow + 1
     423        3927 :          ELSE IF (row_indices(irow) > col_indices(icol)) THEN
     424        1224 :             icol = icol + 1
     425             :          ELSE ! diagonal element
     426        2703 :             determinant = determinant*a(irow, icol)
     427        2703 :             irow = irow + 1
     428        2703 :             icol = icol + 1
     429             :          END IF
     430             :       END DO
     431        1086 :       CALL matrix_a%matrix_struct%para_env%prod(determinant)
     432             : #else
     433             :       lda = SIZE(a, 1)
     434             :       CALL zgetrf(nrow_global, nrow_global, a(1, 1), lda, ipivot, info)
     435             :       counter = 0
     436             :       determinant = z_one
     437             :       DO irow = 1, nrow_global
     438             :          IF (ipivot(irow) .NE. irow) counter = counter + 1
     439             :          determinant = determinant*a(irow, irow)
     440             :       END DO
     441             :       IF (MOD(counter, 2) == 1) determinant = -1.0_dp*determinant
     442             : #endif
     443             : 
     444             :       ! info is allowed to be non-zero (it is not checked here):
     445             :       ! a positive value just signals an exactly zero diagonal element of the factorized matrix
     446        1086 :       DEALLOCATE (ipivot)
     447             : 
     448        1086 :       CALL timestop(handle)
     449        2172 :    END SUBROUTINE
     450             : 
     451             : ! **************************************************************************************************
     452             : !> \brief Performs one of the matrix-matrix operations:
     453             : !>        matrix_c = alpha * op1( matrix_a ) * op2( matrix_b ) + beta*matrix_c.
     454             : !> \param transa       form of op1( matrix_a ):
     455             : !>                     op1( matrix_a ) = matrix_a,   when transa == 'N' ,
     456             : !>                     op1( matrix_a ) = matrix_a^T, when transa == 'T' ,
     457             : !>                     op1( matrix_a ) = matrix_a^H, when transa == 'C' ,
     458             : !> \param transb       form of op2( matrix_b )
     459             : !> \param m            number of rows of the matrix op1( matrix_a )
     460             : !> \param n            number of columns of the matrix op2( matrix_b )
     461             : !> \param k            number of columns of the matrix op1( matrix_a ) as well as
     462             : !>                     number of rows of the matrix op2( matrix_b )
     463             : !> \param alpha        scale factor
     464             : !> \param matrix_a     matrix A
     465             : !> \param matrix_b     matrix B
     466             : !> \param beta         scale factor
     467             : !> \param matrix_c     matrix C
     468             : !> \param a_first_col  (optional) the first column of the matrix_a to multiply
     469             : !> \param a_first_row  (optional) the first row of the matrix_a to multiply
     470             : !> \param b_first_col  (optional) the first column of the matrix_b to multiply
     471             : !> \param b_first_row  (optional) the first row of the matrix_b to multiply
     472             : !> \param c_first_col  (optional) the first column of the matrix_c
     473             : !> \param c_first_row  (optional) the first row of the matrix_c
     474             : !> \date    07.06.2001
     475             : !> \author  Matthias Krack
     476             : !> \version 1.0
     477             : ! **************************************************************************************************
     478       75654 :    SUBROUTINE cp_cfm_gemm(transa, transb, m, n, k, alpha, matrix_a, matrix_b, beta, &
     479             :                           matrix_c, a_first_col, a_first_row, b_first_col, b_first_row, c_first_col, &
     480             :                           c_first_row)
     481             :       CHARACTER(len=1), INTENT(in)                       :: transa, transb
     482             :       INTEGER, INTENT(in)                                :: m, n, k
     483             :       COMPLEX(kind=dp), INTENT(in)                       :: alpha
     484             :       TYPE(cp_cfm_type), INTENT(IN)                      :: matrix_a, matrix_b
     485             :       COMPLEX(kind=dp), INTENT(in)                       :: beta
     486             :       TYPE(cp_cfm_type), INTENT(IN)                   :: matrix_c ! data reached via matrix_c%local_data
     487             :       INTEGER, INTENT(in), OPTIONAL                      :: a_first_col, a_first_row, b_first_col, &
     488             :                                                             b_first_row, c_first_col, c_first_row
     489             :       ! Thin wrapper: forwards to PZGEMM when built with __SCALAPACK and to ZGEMM otherwise.
     490             :       CHARACTER(len=*), PARAMETER :: routineN = 'cp_cfm_gemm'
     491             : 
     492       75654 :       COMPLEX(kind=dp), DIMENSION(:, :), POINTER         :: a, b, c
     493             :       INTEGER                                            :: handle, i_a, i_b, i_c, j_a, j_b, j_c
     494             : #if defined(__SCALAPACK)
     495             :       INTEGER, DIMENSION(9)                              :: desca, descb, descc
     496             : #else
     497             :       INTEGER                                            :: lda, ldb, ldc
     498             : #endif
     499             : 
     500       75654 :       CALL timeset(routineN, handle)
     501       75654 :       a => matrix_a%local_data
     502       75654 :       b => matrix_b%local_data
     503       75654 :       c => matrix_c%local_data
     504             :       ! First row (i_*) and first column (j_*) of each operand: the optional argument if present, else 1.
     505       75654 :       IF (PRESENT(a_first_row)) THEN
     506           0 :          i_a = a_first_row
     507             :       ELSE
     508       75654 :          i_a = 1
     509             :       END IF
     510       75654 :       IF (PRESENT(a_first_col)) THEN
     511           0 :          j_a = a_first_col
     512             :       ELSE
     513       75654 :          j_a = 1
     514             :       END IF
     515       75654 :       IF (PRESENT(b_first_row)) THEN
     516           0 :          i_b = b_first_row
     517             :       ELSE
     518       75654 :          i_b = 1
     519             :       END IF
     520       75654 :       IF (PRESENT(b_first_col)) THEN
     521           0 :          j_b = b_first_col
     522             :       ELSE
     523       75654 :          j_b = 1
     524             :       END IF
     525       75654 :       IF (PRESENT(c_first_row)) THEN
     526           0 :          i_c = c_first_row
     527             :       ELSE
     528       75654 :          i_c = 1
     529             :       END IF
     530       75654 :       IF (PRESENT(c_first_col)) THEN
     531           0 :          j_c = c_first_col
     532             :       ELSE
     533       75654 :          j_c = 1
     534             :       END IF
     535             : 
     536             : #if defined(__SCALAPACK)
     537      756540 :       desca(:) = matrix_a%matrix_struct%descriptor(:)
     538      756540 :       descb(:) = matrix_b%matrix_struct%descriptor(:)
     539      756540 :       descc(:) = matrix_c%matrix_struct%descriptor(:)
     540             :       ! Parallel call: the local arrays are passed from element (1, 1); the offsets go in as separate indices.
     541             :       CALL pzgemm(transa, transb, m, n, k, alpha, a(1, 1), i_a, j_a, desca, &
     542       75654 :                   b(1, 1), i_b, j_b, descb, beta, c(1, 1), i_c, j_c, descc)
     543             : #else
     544             :       lda = SIZE(a, 1)
     545             :       ldb = SIZE(b, 1)
     546             :       ldc = SIZE(c, 1)
     547             :       ! Serial call: the offsets select the first element handed to ZGEMM; leading dimensions are the local sizes.
     548             :       CALL zgemm(transa, transb, m, n, k, alpha, a(i_a, j_a), &
     549             :                  lda, b(i_b, j_b), ldb, beta, c(i_c, j_c), ldc)
     550             : #endif
     551       75654 :       CALL timestop(handle)
     552       75654 :    END SUBROUTINE cp_cfm_gemm
     553             : 
     554             : ! **************************************************************************************************
     555             : !> \brief Scales columns of the full matrix by corresponding factors.
     556             : !> \param matrix_a matrix to scale (its local data are modified in place)
     557             : !> \param scaling  scale factors for every column, addressed by the global column index. Columns
     558             : !>                 whose global index exceeds the number of scale factors given are left unscaled,
     559             : !>                 so at most MIN(SIZE(scaling), number of columns) columns are scaled.
     560             : !> \author Joost VandeVondele
     561             : ! **************************************************************************************************
     562        7682 :    SUBROUTINE cp_cfm_column_scale(matrix_a, scaling)
     563             :       TYPE(cp_cfm_type), INTENT(IN)                   :: matrix_a
     564             :       COMPLEX(kind=dp), DIMENSION(:), INTENT(in)         :: scaling
     565             : 
     566             :       CHARACTER(len=*), PARAMETER :: routineN = 'cp_cfm_column_scale'
     567             : 
     568        7682 :       COMPLEX(kind=dp), DIMENSION(:, :), POINTER         :: a
     569             :       INTEGER                                            :: handle, icol_local, ncol_local, &
     570             :                                                             nrow_local
     571             : #if defined(__SCALAPACK)
     572        7682 :       INTEGER, DIMENSION(:), POINTER                     :: col_indices
     573             : #endif
     574             : 
     575        7682 :       CALL timeset(routineN, handle)
     576             : 
     577        7682 :       a => matrix_a%local_data
     578             : 
     579             : #if defined(__SCALAPACK)
     580        7682 :       CALL cp_cfm_get_info(matrix_a, nrow_local=nrow_local, ncol_local=ncol_local, col_indices=col_indices)
     581        7682 :       ncol_local = COUNT(col_indices(1:ncol_local) <= SIZE(scaling))
     582             :       ! col_indices (local -> global column) ascends, so the columns that own a scale factor come first;
     583      321344 :       DO icol_local = 1, ncol_local ! bounding by global index keeps scaling(col_indices(.)) in range
     584      321344 :          CALL zscal(nrow_local, scaling(col_indices(icol_local)), a(1, icol_local), 1)
     585             :       END DO
     586             : #else
     587             :       nrow_local = SIZE(a, 1)
     588             :       ncol_local = MIN(SIZE(a, 2), SIZE(scaling)) ! serial build: local and global column indices coincide
     589             : 
     590             :       DO icol_local = 1, ncol_local
     591             :          CALL zscal(nrow_local, scaling(icol_local), a(1, icol_local), 1)
     592             :       END DO
     593             : #endif
     594             : 
     595        7682 :       CALL timestop(handle)
     596        7682 :    END SUBROUTINE cp_cfm_column_scale
     597             : 
     598             : ! **************************************************************************************************
     599             : !> \brief Scales a complex matrix in place by a real number.
     600             : !>      matrix_a = alpha * matrix_a
     601             : !> \param alpha    real scale factor
     602             : !> \param matrix_a complex matrix to scale; the elements of its local_data array are overwritten
     603             : ! **************************************************************************************************
     604        7238 :    SUBROUTINE cp_cfm_dscale(alpha, matrix_a)
     605             :       REAL(kind=dp), INTENT(in)                          :: alpha
     606             :       TYPE(cp_cfm_type), INTENT(IN)                      :: matrix_a
     607             : 
     608             :       CHARACTER(len=*), PARAMETER                        :: routineN = 'cp_cfm_dscale'
     609             : 
     610             :       COMPLEX(kind=dp), DIMENSION(:, :), POINTER         :: a
     611             :       INTEGER                                            :: handle
     612             : 
     613        7238 :       CALL timeset(routineN, handle)
     614             : 
     615             :       NULLIFY (a) ! give the local pointer a defined state before it is associated below
     616             : 
     617        7238 :       a => matrix_a%local_data
     618             :       ! one stride-1 BLAS call over all SIZE(a) elements of the local array, starting at a(1, 1)
     619        7238 :       CALL zdscal(SIZE(a), alpha, a(1, 1), 1)
     620             : 
     621        7238 :       CALL timestop(handle)
     622        7238 :    END SUBROUTINE cp_cfm_dscale
     623             : 
     624             : ! **************************************************************************************************
     625             : !> \brief Scales a complex matrix in place by a complex number.
     626             : !>      matrix_a = alpha * matrix_a
     627             : !> \param alpha    complex scale factor
     628             : !> \param matrix_a complex matrix to scale; the elements of its local_data array are overwritten
     629             : !> \note
     630             : !>      use cp_fm_set_all to zero (avoids problems with nan); TODO confirm the complex-matrix routine name
     631             : ! **************************************************************************************************
     632       22888 :    SUBROUTINE cp_cfm_zscale(alpha, matrix_a)
     633             :       COMPLEX(kind=dp), INTENT(in)                       :: alpha
     634             :       TYPE(cp_cfm_type), INTENT(IN)                      :: matrix_a
     635             : 
     636             :       CHARACTER(len=*), PARAMETER                        :: routineN = 'cp_cfm_zscale'
     637             : 
     638             :       COMPLEX(kind=dp), DIMENSION(:, :), POINTER         :: a
     639             :       INTEGER                                            :: handle
     640             : 
     641       22888 :       CALL timeset(routineN, handle)
     642             : 
     643             :       NULLIFY (a) ! give the local pointer a defined state before it is associated below
     644             : 
     645       22888 :       a => matrix_a%local_data
     646             :       ! one stride-1 BLAS call over all SIZE(a) elements of the local array, starting at a(1, 1)
     647       22888 :       CALL zscal(SIZE(a), alpha, a(1, 1), 1)
     648             : 
     649       22888 :       CALL timestop(handle)
     650       22888 :    END SUBROUTINE cp_cfm_zscale
     651             : 
     652             : ! **************************************************************************************************
     653             : !> \brief Solve the system of linear equations A*B=A_general using LU decomposition.
     654             : !>        Pay attention that both matrices are overwritten on exit and that
     655             : !>        the result is stored into the matrix 'general_a'.
     656             : !> \param matrix_a     matrix A (overwritten on exit by the output of (p)zgetrf)
     657             : !> \param general_a    (input) matrix A_general, (output) matrix B
     658             : !> \param determinant  (optional) determinant, built from the diagonal of the factorised A and the pivots
     659             : !> \author Florian Schiffmann
     660             : ! **************************************************************************************************
     661        6526 :    SUBROUTINE cp_cfm_solve(matrix_a, general_a, determinant)
     662             :       TYPE(cp_cfm_type), INTENT(IN)                   :: matrix_a, general_a
     663             :       COMPLEX(kind=dp), OPTIONAL                         :: determinant ! written only when present
     664             : 
     665             :       CHARACTER(len=*), PARAMETER :: routineN = 'cp_cfm_solve'
     666             : 
     667        6526 :       COMPLEX(kind=dp), DIMENSION(:, :), POINTER         :: a, a_general
     668             :       INTEGER                                            :: counter, handle, info, irow, nrow_global
     669        6526 :       INTEGER, ALLOCATABLE, DIMENSION(:)                 :: ipivot
     670             : 
     671             : #if defined(__SCALAPACK)
     672             :       INTEGER                                            :: icol, ncol_local, nrow_local
     673             :       INTEGER, DIMENSION(9)                              :: desca, descb
     674        6526 :       INTEGER, DIMENSION(:), POINTER                     :: col_indices, row_indices
     675             : #else
     676             :       INTEGER                                            :: lda, ldb
     677             : #endif
     678             : 
     679        6526 :       CALL timeset(routineN, handle)
     680             : 
     681        6526 :       a => matrix_a%local_data
     682        6526 :       a_general => general_a%local_data
     683        6526 :       nrow_global = matrix_a%matrix_struct%nrow_global
     684       19578 :       ALLOCATE (ipivot(nrow_global))
     685             :       ! the system is treated as square: nrow_global is used for both dimensions in every call below
     686             : #if defined(__SCALAPACK)
     687       65260 :       desca(:) = matrix_a%matrix_struct%descriptor(:)
     688       65260 :       descb(:) = general_a%matrix_struct%descriptor(:)
     689        6526 :       CALL pzgetrf(nrow_global, nrow_global, a(1, 1), 1, 1, desca, ipivot, info)
     690        6526 :       IF (PRESENT(determinant)) THEN
     691             :          CALL cp_cfm_get_info(matrix_a, nrow_local=nrow_local, ncol_local=ncol_local, &
     692        5238 :                               row_indices=row_indices, col_indices=col_indices)
     693             :          ! count the local rows whose pivot entry differs from the global index of that row
     694        5238 :          counter = 0
     695       15714 :          DO irow = 1, nrow_local
     696       15714 :             IF (ipivot(irow) .NE. row_indices(irow)) counter = counter + 1
     697             :          END DO
     698             :          ! NOTE(review): this sign is set on every process before the global product; confirm for grids with >1 column
     699        5238 :          IF (MOD(counter, 2) == 0) THEN
     700        5236 :             determinant = z_one
     701             :          ELSE
     702           2 :             determinant = -z_one
     703             :          END IF
     704             :          ! step through row_indices and col_indices together; equal global indices mark a local diagonal element
     705             :          ! compute product of diagonal elements
     706             :          irow = 1
     707             :          icol = 1
     708       20952 :          DO WHILE (irow <= nrow_local .AND. icol <= ncol_local)
     709       20952 :             IF (row_indices(irow) < col_indices(icol)) THEN
     710           0 :                irow = irow + 1
     711       15714 :             ELSE IF (row_indices(irow) > col_indices(icol)) THEN
     712        5238 :                icol = icol + 1
     713             :             ELSE ! diagonal element
     714       10476 :                determinant = determinant*a(irow, icol)
     715       10476 :                irow = irow + 1
     716       10476 :                icol = icol + 1
     717             :             END IF
     718             :          END DO
     719        5238 :          CALL matrix_a%matrix_struct%para_env%prod(determinant)
     720             :       END IF
     721             :       ! solve with the factorised matrix; the solution is written into a_general
     722             :       CALL pzgetrs("N", nrow_global, nrow_global, a(1, 1), 1, 1, desca, &
     723        6526 :                    ipivot, a_general(1, 1), 1, 1, descb, info)
     724             : #else
     725             :       lda = SIZE(a, 1)
     726             :       ldb = SIZE(a_general, 1)
     727             :       CALL zgetrf(nrow_global, nrow_global, a(1, 1), lda, ipivot, info)
     728             :       IF (PRESENT(determinant)) THEN
     729             :          counter = 0
     730             :          determinant = z_one
     731             :          DO irow = 1, nrow_global
     732             :             IF (ipivot(irow) .NE. irow) counter = counter + 1 ! row irow was exchanged with another row
     733             :             determinant = determinant*a(irow, irow)
     734             :          END DO
     735             :          IF (MOD(counter, 2) == 1) determinant = -1.0_dp*determinant ! odd number of exchanges flips the sign
     736             :       END IF
     737             :       CALL zgetrs("N", nrow_global, nrow_global, a(1, 1), lda, ipivot, a_general(1, 1), ldb, info)
     738             : #endif
     739             : 
     740             :       ! info is allowed to be nonzero (it is not checked in this routine)
     741             :       ! this does just signal a zero diagonal element
     742        6526 :       DEALLOCATE (ipivot)
     743        6526 :       CALL timestop(handle)
     744             : 
     745        6526 :    END SUBROUTINE cp_cfm_solve
     746             : 
     747             : ! **************************************************************************************************
     748             : !> \brief Inverts a matrix using LU decomposition. The input matrix will be overwritten.
     749             : !> \param matrix     input a general square non-singular matrix, outputs its inverse
     750             : !> \param info_out   optional, if present outputs the info from (p)zgetri; if absent, info /= 0 calls cp_abort
     751             : !> \author Lianheng Tong
     752             : ! **************************************************************************************************
     753       49793 :    SUBROUTINE cp_cfm_lu_invert(matrix, info_out)
     754             :       TYPE(cp_cfm_type), INTENT(IN)                   :: matrix
     755             :       INTEGER, INTENT(out), OPTIONAL                     :: info_out
     756             : 
     757             :       CHARACTER(len=*), PARAMETER :: routineN = 'cp_cfm_lu_invert'
     758             : 
     759       49793 :       COMPLEX(kind=dp), ALLOCATABLE, DIMENSION(:)        :: work
     760             :       COMPLEX(kind=dp), DIMENSION(1)                     :: work1 ! receives the workspace size from the query call
     761       49793 :       COMPLEX(kind=dp), DIMENSION(:, :), POINTER         :: mat
     762             :       INTEGER                                            :: handle, info, lwork, nrows_global
     763       49793 :       INTEGER, ALLOCATABLE, DIMENSION(:)                 :: ipivot
     764             : 
     765             : #if defined(__SCALAPACK)
     766             :       INTEGER                                            :: liwork
     767       49793 :       INTEGER, ALLOCATABLE, DIMENSION(:)                 :: iwork
     768             :       INTEGER, DIMENSION(1)                              :: iwork1
     769             :       INTEGER, DIMENSION(9)                              :: desca
     770             : #else
     771             :       INTEGER                                            :: lda
     772             : #endif
     773             : 
     774       49793 :       CALL timeset(routineN, handle)
     775             : 
     776       49793 :       mat => matrix%local_data
     777       49793 :       nrows_global = matrix%matrix_struct%nrow_global
     778       49793 :       CPASSERT(nrows_global .EQ. matrix%matrix_struct%ncol_global)
     779      149379 :       ALLOCATE (ipivot(nrows_global))
     780             : 
     781             :       ! do LU decomposition (in place; the pivot indices are kept in ipivot for the inversion step)
     782             : #if defined(__SCALAPACK)
     783      497930 :       desca = matrix%matrix_struct%descriptor
     784             :       CALL pzgetrf(nrows_global, nrows_global, &
     785       49793 :                    mat(1, 1), 1, 1, desca, ipivot, info)
     786             : #else
     787             :       lda = SIZE(mat, 1)
     788             :       CALL zgetrf(nrows_global, nrows_global, &
     789             :                   mat(1, 1), lda, ipivot, info)
     790             : #endif
     791       49793 :       IF (info /= 0) THEN
     792           0 :          CALL cp_abort(__LOCATION__, "LU decomposition has failed")
     793             :       END IF ! this check does not depend on info_out: a failed factorisation always calls cp_abort
     794             : 
     795             :       ! do inversion: the first call passes -1 as work size, its result sizes the buffers for the second call
     796             : #if defined(__SCALAPACK)
     797             :       CALL pzgetri(nrows_global, mat(1, 1), 1, 1, desca, &
     798       49793 :                    ipivot, work1, -1, iwork1, -1, info)
     799       49793 :       lwork = INT(work1(1))
     800       49793 :       liwork = INT(iwork1(1))
     801      149379 :       ALLOCATE (work(lwork))
     802      149379 :       ALLOCATE (iwork(liwork))
     803             :       CALL pzgetri(nrows_global, mat(1, 1), 1, 1, desca, &
     804       49793 :                    ipivot, work, lwork, iwork, liwork, info)
     805       49793 :       DEALLOCATE (iwork)
     806             : #else
     807             :       CALL zgetri(nrows_global, mat(1, 1), lda, ipivot, work1, -1, info)
     808             :       lwork = INT(work1(1))
     809             :       ALLOCATE (work(lwork))
     810             :       CALL zgetri(nrows_global, mat(1, 1), lda, ipivot, work, lwork, info)
     811             : #endif
     812       49793 :       DEALLOCATE (work)
     813       49793 :       DEALLOCATE (ipivot)
     814             :       ! hand the status of the inversion call to the caller if requested, otherwise treat info /= 0 as fatal
     815       49793 :       IF (PRESENT(info_out)) THEN
     816           0 :          info_out = info
     817             :       ELSE
     818       49793 :          IF (info /= 0) &
     819           0 :             CALL cp_abort(__LOCATION__, "LU inversion has failed")
     820             :       END IF
     821             : 
     822       49793 :       CALL timestop(handle)
     823             : 
     824       49793 :    END SUBROUTINE cp_cfm_lu_invert
     825             : 
     826             : ! **************************************************************************************************
     827             : !> \brief Used to replace a Hermitian positive definite matrix M with its Cholesky
     828             : !>      decomposition U: M = U^H * U, with U upper triangular.
     829             : !> \param matrix   the matrix to replace with its Cholesky decomposition
     830             : !> \param n        the number of rows (and columns) of the matrix to decompose
     831             : !>                 (defaults to min(nrow_global, ncol_global); a larger n fails the CPASSERT)
     832             : !> \param info_out if present, outputs info from (p)zpotrf; if absent, info /= 0 calls cp_abort
     833             : !> \par History
     834             : !>      05.2002 created [JVdV]
     835             : !>      12.2002 updated, added n optional parm [fawzi]
     836             : !>      09.2021 removed CPASSERT(info == 0) since there is already check of info [Jan Wilhelm]
     837             : !> \author Joost
     838             : ! **************************************************************************************************
     839       22776 :    SUBROUTINE cp_cfm_cholesky_decompose(matrix, n, info_out)
     840             :       TYPE(cp_cfm_type), INTENT(IN)                   :: matrix
     841             :       INTEGER, INTENT(in), OPTIONAL                      :: n
     842             :       INTEGER, INTENT(out), OPTIONAL                     :: info_out
     843             : 
     844             :       CHARACTER(len=*), PARAMETER :: routineN = 'cp_cfm_cholesky_decompose'
     845             : 
     846       22776 :       COMPLEX(kind=dp), DIMENSION(:, :), POINTER         :: a
     847             :       INTEGER                                            :: handle, info, my_n
     848             : #if defined(__SCALAPACK)
     849             :       INTEGER, DIMENSION(9)                              :: desca
     850             : #else
     851             :       INTEGER                                            :: lda
     852             : #endif
     853             : 
     854       22776 :       CALL timeset(routineN, handle)
     855             :       ! default order: the smaller global dimension; an explicit n may only reduce it
     856             :       my_n = MIN(matrix%matrix_struct%nrow_global, &
     857       22776 :                  matrix%matrix_struct%ncol_global)
     858       22776 :       IF (PRESENT(n)) THEN
     859        6724 :          CPASSERT(n <= my_n)
     860        6724 :          my_n = n
     861             :       END IF
     862             : 
     863       22776 :       a => matrix%local_data
     864             :       ! factorise in place, requesting the upper triangle ('U')
     865             : #if defined(__SCALAPACK)
     866      227760 :       desca(:) = matrix%matrix_struct%descriptor(:)
     867       22776 :       CALL pzpotrf('U', my_n, a(1, 1), 1, 1, desca, info)
     868             : #else
     869             :       lda = SIZE(a, 1)
     870             :       CALL zpotrf('U', my_n, a(1, 1), lda, info)
     871             : #endif
     872             :       ! hand the status to the caller if requested, otherwise treat info /= 0 as fatal
     873       22776 :       IF (PRESENT(info_out)) THEN
     874        6724 :          info_out = info
     875             :       ELSE
     876       16052 :          IF (info /= 0) &
     877             :             CALL cp_abort(__LOCATION__, &
     878           0 :                           "Cholesky decompose failed: matrix is not positive definite  or ill-conditioned")
     879             :       END IF
     880             : 
     881       22776 :       CALL timestop(handle)
     882             : 
     883       22776 :    END SUBROUTINE cp_cfm_cholesky_decompose
     884             : 
     885             : ! **************************************************************************************************
     886             : !> \brief Used to replace Cholesky decomposition by the inverse.
     887             : !> \param matrix : the matrix to invert (must be an upper triangular matrix),
     888             : !>                 and is the output of Cholesky decomposition
     889             : !> \param n : size of the matrix to invert (defaults to min(nrow_global, ncol_global))
     890             : !> \param info_out : if present, outputs info of (p)zpotri; if absent, info /= 0 calls cp_abort
     891             : !> \par History
     892             : !>      05.2002 created Lianheng Tong, based on cp_fm_cholesky_invert
     893             : !> \author Lianheng Tong
     894             : ! **************************************************************************************************
     895        6670 :    SUBROUTINE cp_cfm_cholesky_invert(matrix, n, info_out)
     896             :       TYPE(cp_cfm_type), INTENT(IN)           :: matrix
     897             :       INTEGER, INTENT(in), OPTIONAL              :: n
     898             :       INTEGER, INTENT(out), OPTIONAL             :: info_out
     899             : 
     900             :       CHARACTER(len=*), PARAMETER :: routineN = 'cp_cfm_cholesky_invert'
     901        6670 :       COMPLEX(kind=dp), DIMENSION(:, :), POINTER  :: aa
     902             :       INTEGER                                    :: info, handle
     903             :       INTEGER                                    :: my_n
     904             : #if defined(__SCALAPACK)
     905             :       INTEGER, DIMENSION(9)                      :: desca
     906             : #endif
     907             : 
     908        6670 :       CALL timeset(routineN, handle)
     909             :       ! default order: the smaller global dimension; an explicit n may only reduce it
     910             :       my_n = MIN(matrix%matrix_struct%nrow_global, &
     911        6670 :                  matrix%matrix_struct%ncol_global)
     912        6670 :       IF (PRESENT(n)) THEN
     913           0 :          CPASSERT(n <= my_n)
     914           0 :          my_n = n
     915             :       END IF
     916             : 
     917        6670 :       aa => matrix%local_data
     918             :       ! invert in place from the upper-triangle ('U') factor; the serial call uses SIZE(aa, 1) as leading dimension
     919             : #if defined(__SCALAPACK)
     920       66700 :       desca = matrix%matrix_struct%descriptor
     921        6670 :       CALL pzpotri('U', my_n, aa(1, 1), 1, 1, desca, info)
     922             : #else
     923             :       CALL zpotri('U', my_n, aa(1, 1), SIZE(aa, 1), info)
     924             : #endif
     925             :       ! hand the status to the caller if requested, otherwise treat info /= 0 as fatal
     926        6670 :       IF (PRESENT(info_out)) THEN
     927           0 :          info_out = info
     928             :       ELSE
     929        6670 :          IF (info /= 0) &
     930             :             CALL cp_abort(__LOCATION__, &
     931           0 :                           "Cholesky invert failed: the matrix is not positive definite or ill-conditioned.")
     932             :       END IF
     933             : 
     934        6670 :       CALL timestop(handle)
     935             : 
     936        6670 :    END SUBROUTINE cp_cfm_cholesky_invert
     937             : 
     938             : ! **************************************************************************************************
     939             : !> \brief Returns the trace of matrix_a^T matrix_b, i.e
     940             : !>      sum_{i,j}(matrix_a(i,j)*matrix_b(i,j)) .
     941             : !> \param matrix_a a complex matrix
     942             : !> \param matrix_b another complex matrix
     943             : !> \param trace    value of the trace operator, summed over the parallel group of matrix_a
     944             : !> \par History
     945             : !>    * 09.2017 created [Sergey Chulkov]
     946             : !> \author Sergey Chulkov
     947             : !> \note
     948             : !>      Based on the subroutine cp_fm_trace(). Note the transposition of matrix_a!
     949             : ! **************************************************************************************************
     950       27253 :    SUBROUTINE cp_cfm_trace(matrix_a, matrix_b, trace)
     951             :       TYPE(cp_cfm_type), INTENT(IN)                      :: matrix_a, matrix_b
     952             :       COMPLEX(kind=dp), INTENT(out)                      :: trace
     953             : 
     954             :       CHARACTER(len=*), PARAMETER                        :: routineN = 'cp_cfm_trace'
     955             : 
     956             :       INTEGER                                            :: handle, mypcol, myprow, ncol_local, &
     957             :                                                             npcol, nprow, nrow_local
     958             :       TYPE(cp_blacs_env_type), POINTER                   :: context
     959             :       TYPE(mp_comm_type)                                 :: group
     960             : 
     961       27253 :       CALL timeset(routineN, handle)
     962             :       ! position of this process is taken from the context of matrix_a; it indexes the *_locals arrays below
     963       27253 :       context => matrix_a%matrix_struct%context
     964       27253 :       myprow = context%mepos(1)
     965       27253 :       mypcol = context%mepos(2)
     966       27253 :       nprow = context%num_pe(1)
     967       27253 :       npcol = context%num_pe(2)
     968             :       ! nprow and npcol are assigned above but not referenced again in this routine
     969       27253 :       group = matrix_a%matrix_struct%para_env
     970             :       ! restrict to the local block common to both matrices: the smaller local row and column counts
     971       27253 :       nrow_local = MIN(matrix_a%matrix_struct%nrow_locals(myprow), matrix_b%matrix_struct%nrow_locals(myprow))
     972       27253 :       ncol_local = MIN(matrix_a%matrix_struct%ncol_locals(mypcol), matrix_b%matrix_struct%ncol_locals(mypcol))
     973             : 
     974             :       ! compute an accurate dot-product
     975             :       trace = accurate_dot_product(matrix_a%local_data(1:nrow_local, 1:ncol_local), &
     976       27253 :                                    matrix_b%local_data(1:nrow_local, 1:ncol_local))
     977             :       ! add up the per-process partial results
     978       27253 :       CALL group%sum(trace)
     979             : 
     980       27253 :       CALL timestop(handle)
     981             : 
     982       27253 :    END SUBROUTINE cp_cfm_trace
     983             : 
     984             : ! **************************************************************************************************
     985             : !> \brief Multiplies in place by a triangular matrix:
     986             : !>       matrix_b = alpha op(triangular_matrix) matrix_b
     987             : !>      or (if side='R')
     988             : !>       matrix_b = alpha matrix_b op(triangular_matrix)
     989             : !>      op(triangular_matrix) is:
     990             : !>       triangular_matrix (if transa="N" and invert_tr=.false.)
     991             : !>       triangular_matrix^T (if transa="T" and invert_tr=.false.)
     992             : !>       triangular_matrix^H (if transa="C" and invert_tr=.false.)
     993             : !>       triangular_matrix^(-1) (if transa="N" and invert_tr=.true.)
     994             : !>       triangular_matrix^(-T) (if transa="T" and invert_tr=.true.)
     995             : !>       triangular_matrix^(-H) (if transa="C" and invert_tr=.true.)
     996             : !> \param triangular_matrix the triangular matrix that multiplies the other
     997             : !> \param matrix_b the matrix that gets multiplied and stores the result
     998             : !> \param side on which side of matrix_b stays op(triangular_matrix)
     999             : !>        (defaults to 'L')
    1000             : !> \param transa_tr ...
    1001             : !> \param invert_tr if the triangular matrix should be inverted
    1002             : !>        (defaults to false)
    1003             : !> \param uplo_tr if triangular_matrix is stored in the upper ('U') or
    1004             : !>        lower ('L') triangle (defaults to 'U')
    1005             : !> \param unit_diag_tr if the diagonal elements of triangular_matrix should
    1006             : !>        be assumed to be 1 (defaults to false)
    1007             : !> \param n_rows the number of rows of the result (defaults to
    1008             : !>        size(matrix_b,1))
    1009             : !> \param n_cols the number of columns of the result (defaults to
    1010             : !>        size(matrix_b,2))
    1011             : !> \param alpha ...
    1012             : !> \par History
    1013             : !>      08.2002 created [fawzi]
    1014             : !> \author Fawzi Mohamed
    1015             : !> \note
    1016             : !>      needs an mpi env
    1017             : ! **************************************************************************************************
    1018       91836 :    SUBROUTINE cp_cfm_triangular_multiply(triangular_matrix, matrix_b, side, &
    1019             :                                          transa_tr, invert_tr, uplo_tr, unit_diag_tr, n_rows, n_cols, &
    1020             :                                          alpha)
    1021             :       TYPE(cp_cfm_type), INTENT(IN)                      :: triangular_matrix, matrix_b
    1022             :       CHARACTER, INTENT(in), OPTIONAL                    :: side, transa_tr
    1023             :       LOGICAL, INTENT(in), OPTIONAL                      :: invert_tr
    1024             :       CHARACTER, INTENT(in), OPTIONAL                    :: uplo_tr
    1025             :       LOGICAL, INTENT(in), OPTIONAL                      :: unit_diag_tr
    1026             :       INTEGER, INTENT(in), OPTIONAL                      :: n_rows, n_cols
    1027             :       COMPLEX(kind=dp), INTENT(in), OPTIONAL             :: alpha
    1028             : 
    1029             :       CHARACTER(len=*), PARAMETER :: routineN = 'cp_cfm_triangular_multiply'
    1030             : 
    1031             :       CHARACTER                                          :: side_char, transa, unit_diag, uplo
    1032             :       COMPLEX(kind=dp)                                   :: al
    1033             :       INTEGER                                            :: handle, m, n
    1034             :       LOGICAL                                            :: invert
    1035             : 
    1036       45918 :       CALL timeset(routineN, handle)
    1037       45918 :       side_char = 'L'
    1038       45918 :       unit_diag = 'N'
    1039       45918 :       uplo = 'U'
    1040       45918 :       transa = 'N'
    1041       45918 :       invert = .FALSE.
    1042       45918 :       al = CMPLX(1.0_dp, 0.0_dp, dp)
    1043       45918 :       CALL cp_cfm_get_info(matrix_b, nrow_global=m, ncol_global=n)
    1044       45918 :       IF (PRESENT(side)) side_char = side
    1045       45918 :       IF (PRESENT(invert_tr)) invert = invert_tr
    1046       45918 :       IF (PRESENT(uplo_tr)) uplo = uplo_tr
    1047       45918 :       IF (PRESENT(unit_diag_tr)) THEN
    1048           0 :          IF (unit_diag_tr) THEN
    1049           0 :             unit_diag = 'U'
    1050             :          ELSE
    1051             :             unit_diag = 'N'
    1052             :          END IF
    1053             :       END IF
    1054       45918 :       IF (PRESENT(transa_tr)) transa = transa_tr
    1055       45918 :       IF (PRESENT(alpha)) al = alpha
    1056       45918 :       IF (PRESENT(n_rows)) m = n_rows
    1057       45918 :       IF (PRESENT(n_cols)) n = n_cols
    1058             : 
    1059       45918 :       IF (invert) THEN
    1060             : 
    1061             : #if defined(__SCALAPACK)
    1062             :          CALL pztrsm(side_char, uplo, transa, unit_diag, m, n, al, &
    1063             :                      triangular_matrix%local_data(1, 1), 1, 1, &
    1064             :                      triangular_matrix%matrix_struct%descriptor, &
    1065             :                      matrix_b%local_data(1, 1), 1, 1, &
    1066         534 :                      matrix_b%matrix_struct%descriptor(1))
    1067             : #else
    1068             :          CALL ztrsm(side_char, uplo, transa, unit_diag, m, n, al, &
    1069             :                     triangular_matrix%local_data(1, 1), &
    1070             :                     SIZE(triangular_matrix%local_data, 1), &
    1071             :                     matrix_b%local_data(1, 1), SIZE(matrix_b%local_data, 1))
    1072             : #endif
    1073             : 
    1074             :       ELSE
    1075             : 
    1076             : #if defined(__SCALAPACK)
    1077             :          CALL pztrmm(side_char, uplo, transa, unit_diag, m, n, al, &
    1078             :                      triangular_matrix%local_data(1, 1), 1, 1, &
    1079             :                      triangular_matrix%matrix_struct%descriptor, &
    1080             :                      matrix_b%local_data(1, 1), 1, 1, &
    1081       45384 :                      matrix_b%matrix_struct%descriptor(1))
    1082             : #else
    1083             :          CALL ztrmm(side_char, uplo, transa, unit_diag, m, n, al, &
    1084             :                     triangular_matrix%local_data(1, 1), &
    1085             :                     SIZE(triangular_matrix%local_data, 1), &
    1086             :                     matrix_b%local_data(1, 1), SIZE(matrix_b%local_data, 1))
    1087             : #endif
    1088             : 
    1089             :       END IF
    1090             : 
    1091       45918 :       CALL timestop(handle)
    1092             : 
    1093       45918 :    END SUBROUTINE cp_cfm_triangular_multiply
    1094             : 
    1095             : ! **************************************************************************************************
    1096             : !> \brief Inverts a triangular matrix.
    1097             : !> \param matrix_a ...
    1098             : !> \param uplo ...
    1099             : !> \param info_out ...
    1100             : !> \author MI
    1101             : ! **************************************************************************************************
    1102       15128 :    SUBROUTINE cp_cfm_triangular_invert(matrix_a, uplo, info_out)
    1103             :       TYPE(cp_cfm_type), INTENT(IN)         :: matrix_a
    1104             :       CHARACTER, INTENT(in), OPTIONAL          :: uplo
    1105             :       INTEGER, INTENT(out), OPTIONAL           :: info_out
    1106             : 
    1107             :       CHARACTER(len=*), PARAMETER :: routineN = 'cp_cfm_triangular_invert'
    1108             : 
    1109             :       CHARACTER                                :: unit_diag, my_uplo
    1110             :       INTEGER                                  :: handle, info, ncol_global
    1111             :       COMPLEX(kind=dp), DIMENSION(:, :), &
    1112       15128 :          POINTER                                :: a
    1113             : #if defined(__SCALAPACK)
    1114             :       INTEGER, DIMENSION(9)                    :: desca
    1115             : #endif
    1116             : 
    1117       15128 :       CALL timeset(routineN, handle)
    1118             : 
    1119       15128 :       unit_diag = 'N'
    1120       15128 :       my_uplo = 'U'
    1121       15128 :       IF (PRESENT(uplo)) my_uplo = uplo
    1122             : 
    1123       15128 :       ncol_global = matrix_a%matrix_struct%ncol_global
    1124             : 
    1125       15128 :       a => matrix_a%local_data
    1126             : 
    1127             : #if defined(__SCALAPACK)
    1128      151280 :       desca(:) = matrix_a%matrix_struct%descriptor(:)
    1129       15128 :       CALL pztrtri(my_uplo, unit_diag, ncol_global, a(1, 1), 1, 1, desca, info)
    1130             : #else
    1131             :       CALL ztrtri(my_uplo, unit_diag, ncol_global, a(1, 1), ncol_global, info)
    1132             : #endif
    1133             : 
    1134       15128 :       IF (PRESENT(info_out)) THEN
    1135           0 :          info_out = info
    1136             :       ELSE
    1137       15128 :          IF (info /= 0) &
    1138             :             CALL cp_abort(__LOCATION__, &
    1139           0 :                           "triangular invert failed: matrix is not positive definite  or ill-conditioned")
    1140             :       END IF
    1141             : 
    1142       15128 :       CALL timestop(handle)
    1143       15128 :    END SUBROUTINE cp_cfm_triangular_invert
    1144             : 
    1145             : ! **************************************************************************************************
    1146             : !> \brief Transposes a BLACS distributed complex matrix.
    1147             : !> \param matrix    input matrix
    1148             : !> \param trans     'T' for transpose, 'C' for Hermitian conjugate
    1149             : !> \param matrixt   output matrix
    1150             : !> \author Lianheng Tong
    1151             : ! **************************************************************************************************
    1152       14574 :    SUBROUTINE cp_cfm_transpose(matrix, trans, matrixt)
    1153             :       TYPE(cp_cfm_type), INTENT(IN)                      :: matrix
    1154             :       CHARACTER, INTENT(in)                              :: trans
    1155             :       TYPE(cp_cfm_type), INTENT(IN)                   :: matrixt
    1156             : 
    1157             :       CHARACTER(len=*), PARAMETER :: routineN = 'cp_cfm_transpose'
    1158             : 
    1159       14574 :       COMPLEX(kind=dp), DIMENSION(:, :), POINTER         :: aa, cc
    1160             :       INTEGER                                            :: handle, ncol_global, nrow_global
    1161             : #if defined(__SCALAPACK)
    1162             :       INTEGER, DIMENSION(9)                              :: desca, descc
    1163             : #else
    1164             :       INTEGER                                            :: ii, jj
    1165             : #endif
    1166             : 
    1167       14574 :       CALL timeset(routineN, handle)
    1168             : 
    1169       14574 :       nrow_global = matrix%matrix_struct%nrow_global
    1170       14574 :       ncol_global = matrix%matrix_struct%ncol_global
    1171             : 
    1172       14574 :       CPASSERT(matrixt%matrix_struct%nrow_global == ncol_global)
    1173       14574 :       CPASSERT(matrixt%matrix_struct%ncol_global == nrow_global)
    1174             : 
    1175       14574 :       aa => matrix%local_data
    1176       14574 :       cc => matrixt%local_data
    1177             : 
    1178             : #if defined(__SCALAPACK)
    1179      145740 :       desca = matrix%matrix_struct%descriptor
    1180      145740 :       descc = matrixt%matrix_struct%descriptor
    1181        6610 :       SELECT CASE (trans)
    1182             :       CASE ('T')
    1183             :          CALL pztranu(nrow_global, ncol_global, &
    1184             :                       z_one, aa(1, 1), 1, 1, desca, &
    1185        6610 :                       z_zero, cc(1, 1), 1, 1, descc)
    1186             :       CASE ('C')
    1187             :          CALL pztranc(nrow_global, ncol_global, &
    1188             :                       z_one, aa(1, 1), 1, 1, desca, &
    1189        7964 :                       z_zero, cc(1, 1), 1, 1, descc)
    1190             :       CASE DEFAULT
    1191       14574 :          CPABORT("trans only accepts 'T' or 'C'")
    1192             :       END SELECT
    1193             : #else
    1194             :       SELECT CASE (trans)
    1195             :       CASE ('T')
    1196             :          DO jj = 1, ncol_global
    1197             :             DO ii = 1, nrow_global
    1198             :                cc(ii, jj) = aa(jj, ii)
    1199             :             END DO
    1200             :          END DO
    1201             :       CASE ('C')
    1202             :          DO jj = 1, ncol_global
    1203             :             DO ii = 1, nrow_global
    1204             :                cc(ii, jj) = CONJG(aa(jj, ii))
    1205             :             END DO
    1206             :          END DO
    1207             :       CASE DEFAULT
    1208             :          CPABORT("trans only accepts 'T' or 'C'")
    1209             :       END SELECT
    1210             : #endif
    1211             : 
    1212       14574 :       CALL timestop(handle)
    1213       14574 :    END SUBROUTINE cp_cfm_transpose
    1214             : 
    1215             : ! **************************************************************************************************
    1216             : !> \brief Norm of matrix using (p)zlange.
    1217             : !> \param matrix     input a general matrix
    1218             : !> \param mode       'M' max abs element value,
    1219             : !>                   '1' or 'O' one norm, i.e. maximum column sum,
    1220             : !>                   'I' infinity norm, i.e. maximum row sum,
    1221             : !>                   'F' or 'E' Frobenius norm, i.e. sqrt of sum of all squares of elements
    1222             : !> \return the norm according to mode
    1223             : !> \author Lianheng Tong
    1224             : ! **************************************************************************************************
    1225       88052 :    FUNCTION cp_cfm_norm(matrix, mode) RESULT(res)
    1226             :       TYPE(cp_cfm_type), INTENT(IN)                      :: matrix
    1227             :       CHARACTER, INTENT(IN)                              :: mode
    1228             :       REAL(kind=dp)                                      :: res
    1229             : 
    1230             :       CHARACTER(len=*), PARAMETER :: routineN = 'cp_cfm_norm'
    1231             : 
    1232       88052 :       COMPLEX(kind=dp), DIMENSION(:, :), POINTER         :: aa
    1233             :       INTEGER                                            :: handle, lwork, ncols, ncols_local, &
    1234             :                                                             nrows, nrows_local
    1235       88052 :       REAL(kind=dp), ALLOCATABLE, DIMENSION(:)           :: work
    1236             : 
    1237             : #if defined(__SCALAPACK)
    1238             :       INTEGER, DIMENSION(9)                              :: desca
    1239             : #else
    1240             :       INTEGER                                            :: lda
    1241             : #endif
    1242             : 
    1243       88052 :       CALL timeset(routineN, handle)
    1244             : 
    1245             :       CALL cp_cfm_get_info(matrix=matrix, &
    1246             :                            nrow_global=nrows, &
    1247             :                            ncol_global=ncols, &
    1248             :                            nrow_local=nrows_local, &
    1249       88052 :                            ncol_local=ncols_local)
    1250       88052 :       aa => matrix%local_data
    1251             : 
    1252             :       SELECT CASE (mode)
    1253             :       CASE ('M', 'm')
    1254           0 :          lwork = 1
    1255             :       CASE ('1', 'O', 'o')
    1256             : #if defined(__SCALAPACK)
    1257           0 :          lwork = ncols_local
    1258             : #else
    1259             :          lwork = 1
    1260             : #endif
    1261             :       CASE ('I', 'i')
    1262             : #if defined(__SCALAPACK)
    1263           0 :          lwork = nrows_local
    1264             : #else
    1265             :          lwork = nrows
    1266             : #endif
    1267             :       CASE ('F', 'f', 'E', 'e')
    1268           0 :          lwork = 1
    1269             :       CASE DEFAULT
    1270       88052 :          CPABORT("mode input is not valid")
    1271             :       END SELECT
    1272             : 
    1273      264156 :       ALLOCATE (work(lwork))
    1274             : 
    1275             : #if defined(__SCALAPACK)
    1276      880520 :       desca = matrix%matrix_struct%descriptor
    1277       88052 :       res = pzlange(mode, nrows, ncols, aa(1, 1), 1, 1, desca, work)
    1278             : #else
    1279             :       lda = SIZE(aa, 1)
    1280             :       res = zlange(mode, nrows, ncols, aa(1, 1), lda, work)
    1281             : #endif
    1282             : 
    1283       88052 :       DEALLOCATE (work)
    1284       88052 :       CALL timestop(handle)
    1285       88052 :    END FUNCTION cp_cfm_norm
    1286             : 
    1287             : ! **************************************************************************************************
    1288             : !> \brief Applies a planar rotation defined by cs and sn to the i'th and j'th rows.
    1289             : !> \param matrix ...
    1290             : !> \param irow ...
    1291             : !> \param jrow ...
    1292             : !> \param cs cosine of the rotation angle
    1293             : !> \param sn sinus of the rotation angle
    1294             : !> \author Ole Schuett
    1295             : ! **************************************************************************************************
    1296           0 :    SUBROUTINE cp_cfm_rot_rows(matrix, irow, jrow, cs, sn)
    1297             :       TYPE(cp_cfm_type), INTENT(IN)            :: matrix
    1298             :       INTEGER, INTENT(IN)                      :: irow, jrow
    1299             :       REAL(dp), INTENT(IN)                     :: cs, sn
    1300             : 
    1301             :       CHARACTER(len=*), PARAMETER :: routineN = 'cp_cfm_rot_rows'
    1302             :       INTEGER                                  :: handle, nrow, ncol
    1303             :       COMPLEX(KIND=dp)                         :: sn_cmplx
    1304             : 
    1305             : #if defined(__SCALAPACK)
    1306             :       INTEGER                                  :: info, lwork
    1307             :       INTEGER, DIMENSION(9)                    :: desc
    1308           0 :       REAL(dp), DIMENSION(:), ALLOCATABLE      :: work
    1309             : #endif
    1310             : 
    1311           0 :       CALL timeset(routineN, handle)
    1312           0 :       CALL cp_cfm_get_info(matrix, nrow_global=nrow, ncol_global=ncol)
    1313           0 :       sn_cmplx = CMPLX(sn, 0.0_dp, dp)
    1314             : 
    1315             : #if defined(__SCALAPACK)
    1316           0 :       lwork = 2*ncol + 1
    1317           0 :       ALLOCATE (work(lwork))
    1318           0 :       desc(:) = matrix%matrix_struct%descriptor(:)
    1319             :       CALL pzrot(ncol, &
    1320             :                  matrix%local_data(1, 1), irow, 1, desc, ncol, &
    1321             :                  matrix%local_data(1, 1), jrow, 1, desc, ncol, &
    1322           0 :                  cs, sn_cmplx, work, lwork, info)
    1323           0 :       CPASSERT(info == 0)
    1324           0 :       DEALLOCATE (work)
    1325             : #else
    1326             :       CALL zrot(ncol, matrix%local_data(irow, 1), ncol, matrix%local_data(jrow, 1), ncol, cs, sn_cmplx)
    1327             : #endif
    1328             : 
    1329           0 :       CALL timestop(handle)
    1330           0 :    END SUBROUTINE cp_cfm_rot_rows
    1331             : 
    1332             : ! **************************************************************************************************
    1333             : !> \brief Applies a planar rotation defined by cs and sn to the i'th and j'th columnns.
    1334             : !> \param matrix ...
    1335             : !> \param icol ...
    1336             : !> \param jcol ...
    1337             : !> \param cs cosine of the rotation angle
    1338             : !> \param sn sinus of the rotation angle
    1339             : !> \author Ole Schuett
    1340             : ! **************************************************************************************************
    1341           0 :    SUBROUTINE cp_cfm_rot_cols(matrix, icol, jcol, cs, sn)
    1342             :       TYPE(cp_cfm_type), INTENT(IN)            :: matrix
    1343             :       INTEGER, INTENT(IN)                      :: icol, jcol
    1344             :       REAL(dp), INTENT(IN)                     :: cs, sn
    1345             : 
    1346             :       CHARACTER(len=*), PARAMETER :: routineN = 'cp_cfm_rot_cols'
    1347             :       INTEGER                                  :: handle, nrow, ncol
    1348             :       COMPLEX(KIND=dp)                         :: sn_cmplx
    1349             : 
    1350             : #if defined(__SCALAPACK)
    1351             :       INTEGER                                  :: info, lwork
    1352             :       INTEGER, DIMENSION(9)                    :: desc
    1353           0 :       REAL(dp), DIMENSION(:), ALLOCATABLE      :: work
    1354             : #endif
    1355             : 
    1356           0 :       CALL timeset(routineN, handle)
    1357           0 :       CALL cp_cfm_get_info(matrix, nrow_global=nrow, ncol_global=ncol)
    1358           0 :       sn_cmplx = CMPLX(sn, 0.0_dp, dp)
    1359             : 
    1360             : #if defined(__SCALAPACK)
    1361           0 :       lwork = 2*nrow + 1
    1362           0 :       ALLOCATE (work(lwork))
    1363           0 :       desc(:) = matrix%matrix_struct%descriptor(:)
    1364             :       CALL pzrot(nrow, &
    1365             :                  matrix%local_data(1, 1), 1, icol, desc, 1, &
    1366             :                  matrix%local_data(1, 1), 1, jcol, desc, 1, &
    1367           0 :                  cs, sn_cmplx, work, lwork, info)
    1368           0 :       CPASSERT(info == 0)
    1369           0 :       DEALLOCATE (work)
    1370             : #else
    1371             :       CALL zrot(nrow, matrix%local_data(1, icol), 1, matrix%local_data(1, jcol), 1, cs, sn_cmplx)
    1372             : #endif
    1373             : 
    1374           0 :       CALL timestop(handle)
    1375           0 :    END SUBROUTINE cp_cfm_rot_cols
    1376             : 
    1377             : END MODULE cp_cfm_basic_linalg

Generated by: LCOV version 1.15