LCOV - code coverage report
Current view: top level - src - parallel_gemm_api.F (source / functions) Coverage Total Hit
Test: CP2K Regtests (git:936074a) Lines: 82.7 % 75 62
Test Date: 2025-12-04 06:27:48 Functions: 100.0 % 4 4

            Line data    Source code
       1              : !--------------------------------------------------------------------------------------------------!
       2              : !   CP2K: A general program to perform molecular dynamics simulations                              !
       3              : !   Copyright 2000-2025 CP2K developers group <https://cp2k.org>                                   !
       4              : !                                                                                                  !
       5              : !   SPDX-License-Identifier: GPL-2.0-or-later                                                      !
       6              : !--------------------------------------------------------------------------------------------------!
       7              : 
       8              : ! **************************************************************************************************
       9              : !> \brief Generic parallel_gemm interface for real (cp_fm) and complex (cp_cfm) full matrices
      10              : !> \par History
      11              : !>      08.2002 split out of qs_blacs [fawzi]
      12              : !> \author Fawzi Mohamed
      13              : ! **************************************************************************************************
      14              : MODULE parallel_gemm_api
      15              :    USE ISO_C_BINDING,                   ONLY: C_CHAR,&
      16              :                                               C_DOUBLE,&
      17              :                                               C_INT,&
      18              :                                               C_LOC,&
      19              :                                               C_PTR
      20              :    USE cp_cfm_basic_linalg,             ONLY: cp_cfm_gemm
      21              :    USE cp_cfm_types,                    ONLY: cp_cfm_type
      22              :    USE cp_fm_basic_linalg,              ONLY: cp_fm_gemm
      23              :    USE cp_fm_types,                     ONLY: cp_fm_get_mm_type,&
      24              :                                               cp_fm_set_all_submatrix,&
      25              :                                               cp_fm_type
      26              :    USE input_constants,                 ONLY: do_cosma,&
      27              :                                               do_scalapack
      28              :    USE kinds,                           ONLY: dp
      29              :    USE offload_api,                     ONLY: offload_activate_chosen_device
      30              : #include "./base/base_uses.f90"
      31              : 
      32              :    IMPLICIT NONE
      33              :    PRIVATE
      34              : 
      35              :    CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'parallel_gemm_api'
      36              : 
      37              :    PUBLIC :: parallel_gemm
      38              : 
      39              :    INTERFACE parallel_gemm
      40              :       MODULE PROCEDURE parallel_gemm_fm
      41              :       MODULE PROCEDURE parallel_gemm_cfm
      42              :    END INTERFACE parallel_gemm
      43              : 
      44              : CONTAINS
      45              : 
      46              : ! **************************************************************************************************
      47              : !> \brief Real-valued parallel GEMM; dispatches on cp_fm_get_mm_type() to cp_fm_gemm or cosma_pdgemm
      48              : !> \param transa one-character flag for matrix_a, forwarded unchanged (presumably BLAS-style N/T; confirm)
      49              : !> \param transb one-character flag for matrix_b, forwarded unchanged (presumably BLAS-style N/T; confirm)
      50              : !> \param m GEMM dimension, forwarded; with n also passed to cp_fm_set_all_submatrix when C is zeroed
      51              : !> \param n GEMM dimension, forwarded; with m also passed to cp_fm_set_all_submatrix when C is zeroed
      52              : !> \param k GEMM dimension, forwarded unchanged to the backend
      53              : !> \param alpha real scalar alpha, forwarded unchanged to the backend
      54              : !> \param matrix_a real full matrix passed to the backend as operand A
      55              : !> \param matrix_b real full matrix passed to the backend as operand B
      56              : !> \param beta real scalar beta; on the COSMA path an exact 0.0_dp makes the C block be zeroed first
      57              : !> \param matrix_c real full matrix passed as C (presumably receives the result; confirm in the backends)
      58              : !> \param a_first_col optional index into matrix_a, forwarded as-is to the backend (may be absent)
      59              : !> \param a_first_row optional index into matrix_a, forwarded as-is to the backend (may be absent)
      60              : !> \param b_first_col optional index into matrix_b, forwarded as-is to the backend (may be absent)
      61              : !> \param b_first_row optional index into matrix_b, forwarded as-is to the backend (may be absent)
      62              : !> \param c_first_col optional index into matrix_c, forwarded as-is; 1 is used for the zeroing if absent
      63              : !> \param c_first_row optional index into matrix_c, forwarded as-is; 1 is used for the zeroing if absent
      64              : ! **************************************************************************************************
      65      1155502 :    SUBROUTINE parallel_gemm_fm(transa, transb, m, n, k, alpha, matrix_a, matrix_b, beta, &
      66              :                                matrix_c, a_first_col, a_first_row, b_first_col, b_first_row, &
      67              :                                c_first_col, c_first_row)
      68              :       CHARACTER(LEN=1), INTENT(IN)                       :: transa, transb
      69              :       INTEGER, INTENT(IN)                                :: m, n, k
      70              :       REAL(KIND=dp), INTENT(IN)                          :: alpha
      71              :       TYPE(cp_fm_type), INTENT(IN)                       :: matrix_a, matrix_b
      72              :       REAL(KIND=dp), INTENT(IN)                          :: beta
      73              :       TYPE(cp_fm_type), INTENT(IN)                       :: matrix_c
      74              :       INTEGER, INTENT(IN), OPTIONAL                      :: a_first_col, a_first_row, b_first_col, &
      75              :                                                             b_first_row, c_first_col, c_first_row
      76              : 
      77              :       CHARACTER(len=*), PARAMETER                        :: routineN = 'parallel_gemm_fm'
      78              : 
      79              :       INTEGER                                            :: cfc, cfr, handle, my_multi
      80              : 
      81              :       MARK_USED(cfc)
      82              :       MARK_USED(cfr)
      83              : 
      84      1155502 :       my_multi = cp_fm_get_mm_type()
      85              : 
      86            0 :       SELECT CASE (my_multi)
      87              :       CASE (do_scalapack)
      88            0 :          CALL timeset(routineN//"_gemm", handle)
      89              :          CALL cp_fm_gemm(transa, transb, m, n, k, alpha, matrix_a, matrix_b, beta, matrix_c, &
      90              :                          a_first_col=a_first_col, &
      91              :                          a_first_row=a_first_row, &
      92              :                          b_first_col=b_first_col, &
      93              :                          b_first_row=b_first_row, &
      94              :                          c_first_col=c_first_col, &
      95            0 :                          c_first_row=c_first_row)
      96              :       CASE (do_cosma)
      97              : #if defined(__COSMA)
      98      1155502 :          CALL timeset(routineN//"_cosma", handle)
      99              :          !> Workaround: COSMA seems not to honour the following rule from the BLAS definition:
     100              :          !>           On entry,  BETA  specifies the scalar  beta.  When  BETA  is
     101              :          !>           supplied as zero then C need not be set on input. Hence C is zeroed explicitly here.
     102      1155502 :          IF (beta == 0.0_dp) THEN
     103       879555 :             cfr = 1
     104       879555 :             cfc = 1
     105       879555 :             IF (PRESENT(c_first_row)) cfr = c_first_row
     106       879555 :             IF (PRESENT(c_first_col)) cfc = c_first_col
     107       879555 :             CALL cp_fm_set_all_submatrix(matrix_c, 0.0_dp, cfr, cfc, m, n)
     108              :          END IF
     109      1155502 :          CALL offload_activate_chosen_device()
     110              :          CALL cosma_pdgemm(transa=transa, transb=transb, m=m, n=n, k=k, alpha=alpha, &
     111              :                            matrix_a=matrix_a, matrix_b=matrix_b, beta=beta, matrix_c=matrix_c, &
     112              :                            a_first_col=a_first_col, &
     113              :                            a_first_row=a_first_row, &
     114              :                            b_first_col=b_first_col, &
     115              :                            b_first_row=b_first_row, &
     116              :                            c_first_col=c_first_col, &
     117      2311004 :                            c_first_row=c_first_row)
     118              : #else
     119              :          CPABORT("CP2K compiled without the COSMA library.")
     120              : #endif
     121              :       END SELECT
     122      1155502 :       CALL timestop(handle) ! NOTE(review): handle is set only inside the CASE branches above (no CASE DEFAULT)
     123              : 
     124      1155502 :    END SUBROUTINE parallel_gemm_fm
     125              : 
     126              : ! **************************************************************************************************
     127              : !> \brief Complex-valued parallel GEMM; dispatches on cp_fm_get_mm_type() to cp_cfm_gemm or cosma_pzgemm
     128              : !> \param transa one-character flag for matrix_a, forwarded unchanged (presumably BLAS-style; confirm)
     129              : !> \param transb one-character flag for matrix_b, forwarded unchanged (presumably BLAS-style; confirm)
     130              : !> \param m GEMM dimension, forwarded unchanged to the backend
     131              : !> \param n GEMM dimension, forwarded unchanged to the backend
     132              : !> \param k GEMM dimension, forwarded unchanged to the backend
     133              : !> \param alpha complex scalar alpha, forwarded unchanged to the backend
     134              : !> \param matrix_a complex full matrix passed to the backend as operand A
     135              : !> \param matrix_b complex full matrix passed to the backend as operand B
     136              : !> \param beta complex scalar beta, forwarded unchanged (no explicit zeroing of C is done in this routine)
     137              : !> \param matrix_c complex full matrix passed as C (presumably receives the result; confirm in the backends)
     138              : !> \param a_first_col optional index into matrix_a, forwarded as-is to the backend (may be absent)
     139              : !> \param a_first_row optional index into matrix_a, forwarded as-is to the backend (may be absent)
     140              : !> \param b_first_col optional index into matrix_b, forwarded as-is to the backend (may be absent)
     141              : !> \param b_first_row optional index into matrix_b, forwarded as-is to the backend (may be absent)
     142              : !> \param c_first_col optional index into matrix_c, forwarded as-is to the backend (may be absent)
     143              : !> \param c_first_row optional index into matrix_c, forwarded as-is to the backend (may be absent)
     144              : ! **************************************************************************************************
     145       283736 :    SUBROUTINE parallel_gemm_cfm(transa, transb, m, n, k, alpha, matrix_a, matrix_b, beta, &
     146              :                                 matrix_c, a_first_col, a_first_row, b_first_col, b_first_row, &
     147              :                                 c_first_col, c_first_row)
     148              :       CHARACTER(LEN=1), INTENT(IN)                       :: transa, transb
     149              :       INTEGER, INTENT(IN)                                :: m, n, k
     150              :       COMPLEX(KIND=dp), INTENT(IN)                       :: alpha
     151              :       TYPE(cp_cfm_type), INTENT(IN)                      :: matrix_a, matrix_b
     152              :       COMPLEX(KIND=dp), INTENT(IN)                       :: beta
     153              :       TYPE(cp_cfm_type), INTENT(IN)                      :: matrix_c
     154              :       INTEGER, INTENT(IN), OPTIONAL                      :: a_first_col, a_first_row, b_first_col, &
     155              :                                                             b_first_row, c_first_col, c_first_row
     156              : 
     157              :       CHARACTER(len=*), PARAMETER                        :: routineN = 'parallel_gemm_cfm'
     158              : 
     159              :       INTEGER                                            :: handle, handle1, my_multi
     160              : 
     161       283736 :       CALL timeset(routineN, handle)
     162              : 
     163       283736 :       my_multi = cp_fm_get_mm_type()
     164              : 
     165            0 :       SELECT CASE (my_multi)
     166              :       CASE (do_scalapack)
     167            0 :          CALL timeset(routineN//"_gemm", handle1)
     168              :          CALL cp_cfm_gemm(transa, transb, m, n, k, alpha, matrix_a, matrix_b, beta, matrix_c, &
     169              :                           a_first_col=a_first_col, &
     170              :                           a_first_row=a_first_row, &
     171              :                           b_first_col=b_first_col, &
     172              :                           b_first_row=b_first_row, &
     173              :                           c_first_col=c_first_col, &
     174            0 :                           c_first_row=c_first_row)
     175            0 :          CALL timestop(handle1)
     176              :       CASE (do_cosma)
     177              : #if defined(__COSMA)
     178       283736 :          CALL timeset(routineN//"_cosma", handle1)
     179       283736 :          CALL offload_activate_chosen_device()
     180              :          CALL cosma_pzgemm(transa=transa, transb=transb, m=m, n=n, k=k, alpha=alpha, &
     181              :                            matrix_a=matrix_a, matrix_b=matrix_b, beta=beta, matrix_c=matrix_c, &
     182              :                            a_first_col=a_first_col, &
     183              :                            a_first_row=a_first_row, &
     184              :                            b_first_col=b_first_col, &
     185              :                            b_first_row=b_first_row, &
     186              :                            c_first_col=c_first_col, &
     187       283736 :                            c_first_row=c_first_row)
     188       567472 :          CALL timestop(handle1)
     189              : #else
     190              :          CPABORT("CP2K compiled without the COSMA library.")
     191              : #endif
     192              :       END SELECT ! NOTE(review): no CASE DEFAULT, so any other my_multi value performs no multiplication
     193       283736 :       CALL timestop(handle)
     194              : 
     195       283736 :    END SUBROUTINE parallel_gemm_cfm
     196              : 
     197              : #if defined(__COSMA)
     198              : ! **************************************************************************************************
     199              : !> \brief Fortran wrapper for the C routine cosma_pdgemm (real case); absent index arguments become 1
     200              : !> \param transa one-character flag, passed through to the C routine as transa
     201              : !> \param transb one-character flag, passed through to the C routine as transb
     202              : !> \param m passed through to the C routine as m
     203              : !> \param n passed through to the C routine as n
     204              : !> \param k passed through to the C routine as k
     205              : !> \param alpha real scalar, passed through to the C routine as alpha
     206              : !> \param matrix_a addresses of local_data(1, 1) and matrix_struct%descriptor(1) are passed as a and desca
     207              : !> \param matrix_b addresses of local_data(1, 1) and matrix_struct%descriptor(1) are passed as b and descb
     208              : !> \param beta real scalar, passed through to the C routine as beta
     209              : !> \param matrix_c addresses of local_data(1, 1) and matrix_struct%descriptor(1) are passed as c and descc
     210              : !> \param a_first_col optional, passed to the C routine as ja; 1 is used if absent
     211              : !> \param a_first_row optional, passed to the C routine as ia; 1 is used if absent
     212              : !> \param b_first_col optional, passed to the C routine as jb; 1 is used if absent
     213              : !> \param b_first_row optional, passed to the C routine as ib; 1 is used if absent
     214              : !> \param c_first_col optional, passed to the C routine as jc; 1 is used if absent
     215              : !> \param c_first_row optional, passed to the C routine as ic; 1 is used if absent
     216              : !> \author Ole Schuett
     217              : ! **************************************************************************************************
     218      1155502 :    SUBROUTINE cosma_pdgemm(transa, transb, m, n, k, alpha, matrix_a, matrix_b, beta, matrix_c, &
     219              :                            a_first_col, a_first_row, b_first_col, b_first_row, &
     220              :                            c_first_col, c_first_row)
     221              :       CHARACTER(LEN=1), INTENT(IN)                       :: transa, transb
     222              :       INTEGER, INTENT(IN)                                :: m, n, k
     223              :       REAL(KIND=dp), INTENT(IN)                          :: alpha
     224              :       TYPE(cp_fm_type), INTENT(IN)                       :: matrix_a, matrix_b
     225              :       REAL(KIND=dp), INTENT(IN)                          :: beta
     226              :       TYPE(cp_fm_type), INTENT(IN)                       :: matrix_c
     227              :       INTEGER, INTENT(IN), OPTIONAL                      :: a_first_col, a_first_row, b_first_col, &
     228              :                                                             b_first_row, c_first_col, c_first_row
     229              : 
     230              :       INTEGER                                            :: i_a, i_b, i_c, j_a, j_b, j_c
     231              :       INTERFACE
     232              :          SUBROUTINE cosma_pdgemm_c(transa, transb, m, n, k, alpha, a, ia, ja, desca, &
     233              :                                    b, ib, jb, descb, beta, c, ic, jc, descc) &
     234              :             BIND(C, name="cosma_pdgemm")
     235              :             IMPORT :: C_PTR, C_INT, C_DOUBLE, C_CHAR
     236              :             CHARACTER(KIND=C_CHAR)                    :: transa
     237              :             CHARACTER(KIND=C_CHAR)                    :: transb
     238              :             INTEGER(KIND=C_INT)                       :: m
     239              :             INTEGER(KIND=C_INT)                       :: n
     240              :             INTEGER(KIND=C_INT)                       :: k
     241              :             REAL(KIND=C_DOUBLE)                       :: alpha
     242              :             TYPE(C_PTR), VALUE                        :: a
     243              :             INTEGER(KIND=C_INT)                       :: ia
     244              :             INTEGER(KIND=C_INT)                       :: ja
     245              :             TYPE(C_PTR), VALUE                        :: desca
     246              :             TYPE(C_PTR), VALUE                        :: b
     247              :             INTEGER(KIND=C_INT)                       :: ib
     248              :             INTEGER(KIND=C_INT)                       :: jb
     249              :             TYPE(C_PTR), VALUE                        :: descb
     250              :             REAL(KIND=C_DOUBLE)                       :: beta
     251              :             TYPE(C_PTR), VALUE                        :: c
     252              :             INTEGER(KIND=C_INT)                       :: ic
     253              :             INTEGER(KIND=C_INT)                       :: jc
     254              :             TYPE(C_PTR), VALUE                        :: descc
     255              :          END SUBROUTINE cosma_pdgemm_c
     256              :       END INTERFACE
     257              : 
     258      1155502 :       IF (PRESENT(a_first_row)) THEN
     259         2742 :          i_a = a_first_row
     260              :       ELSE
     261      1152760 :          i_a = 1
     262              :       END IF
     263      1155502 :       IF (PRESENT(a_first_col)) THEN
     264         2742 :          j_a = a_first_col
     265              :       ELSE
     266      1152760 :          j_a = 1
     267              :       END IF
     268      1155502 :       IF (PRESENT(b_first_row)) THEN
     269         3100 :          i_b = b_first_row
     270              :       ELSE
     271      1152402 :          i_b = 1
     272              :       END IF
     273      1155502 :       IF (PRESENT(b_first_col)) THEN
     274         4052 :          j_b = b_first_col
     275              :       ELSE
     276      1151450 :          j_b = 1
     277              :       END IF
     278      1155502 :       IF (PRESENT(c_first_row)) THEN
     279         2498 :          i_c = c_first_row
     280              :       ELSE
     281      1153004 :          i_c = 1
     282              :       END IF
     283      1155502 :       IF (PRESENT(c_first_col)) THEN
     284         2516 :          j_c = c_first_col
     285              :       ELSE
     286      1152986 :          j_c = 1
     287              :       END IF
     288              : 
     289              :       CALL cosma_pdgemm_c(transa=transa, transb=transb, m=m, n=n, k=k, &
     290              :                           alpha=alpha, &
     291              :                           a=C_LOC(matrix_a%local_data(1, 1)), ia=i_a, ja=j_a, &
     292              :                           desca=C_LOC(matrix_a%matrix_struct%descriptor(1)), &
     293              :                           b=C_LOC(matrix_b%local_data(1, 1)), ib=i_b, jb=j_b, &
     294              :                           descb=C_LOC(matrix_b%matrix_struct%descriptor(1)), &
     295              :                           beta=beta, &
     296              :                           c=C_LOC(matrix_c%local_data(1, 1)), ic=i_c, jc=j_c, &
     297      1155502 :                           descc=C_LOC(matrix_c%matrix_struct%descriptor(1)))
     298              : 
     299      1155502 :    END SUBROUTINE cosma_pdgemm
     300              : 
     301              : ! **************************************************************************************************
     302              : !> \brief Fortran wrapper for the C routine cosma_pzgemm (complex case); absent index arguments become 1
     303              : !> \param transa one-character flag, passed through to the C routine as transa
     304              : !> \param transb one-character flag, passed through to the C routine as transb
     305              : !> \param m passed through to the C routine as m
     306              : !> \param n passed through to the C routine as n
     307              : !> \param k passed through to the C routine as k
     308              : !> \param alpha complex scalar; copied into a 2-element real array (real, imaginary) passed by address
     309              : !> \param matrix_a addresses of local_data(1, 1) and matrix_struct%descriptor(1) are passed as a and desca
     310              : !> \param matrix_b addresses of local_data(1, 1) and matrix_struct%descriptor(1) are passed as b and descb
     311              : !> \param beta complex scalar; copied into a 2-element real array (real, imaginary) passed by address
     312              : !> \param matrix_c addresses of local_data(1, 1) and matrix_struct%descriptor(1) are passed as c and descc
     313              : !> \param a_first_col optional, passed to the C routine as ja; 1 is used if absent
     314              : !> \param a_first_row optional, passed to the C routine as ia; 1 is used if absent
     315              : !> \param b_first_col optional, passed to the C routine as jb; 1 is used if absent
     316              : !> \param b_first_row optional, passed to the C routine as ib; 1 is used if absent
     317              : !> \param c_first_col optional, passed to the C routine as jc; 1 is used if absent
     318              : !> \param c_first_row optional, passed to the C routine as ic; 1 is used if absent
     319              : !> \author Ole Schuett
     320              : ! **************************************************************************************************
     321       283736 :    SUBROUTINE cosma_pzgemm(transa, transb, m, n, k, alpha, matrix_a, matrix_b, beta, matrix_c, &
     322              :                            a_first_col, a_first_row, b_first_col, b_first_row, &
     323              :                            c_first_col, c_first_row)
     324              :       CHARACTER(LEN=1), INTENT(IN)                       :: transa, transb
     325              :       INTEGER, INTENT(IN)                                :: m, n, k
     326              :       COMPLEX(KIND=dp), INTENT(IN)                       :: alpha
     327              :       TYPE(cp_cfm_type), INTENT(IN)                      :: matrix_a, matrix_b
     328              :       COMPLEX(KIND=dp), INTENT(IN)                       :: beta
     329              :       TYPE(cp_cfm_type), INTENT(IN)                      :: matrix_c
     330              :       INTEGER, INTENT(IN), OPTIONAL                      :: a_first_col, a_first_row, b_first_col, &
     331              :                                                             b_first_row, c_first_col, c_first_row
     332              : 
     333              :       INTEGER                                            :: i_a, i_b, i_c, j_a, j_b, j_c
     334              :       REAL(KIND=dp), DIMENSION(2), TARGET                :: alpha_t, beta_t
     335              :       INTERFACE
     336              :          SUBROUTINE cosma_pzgemm_c(transa, transb, m, n, k, alpha, a, ia, ja, desca, &
     337              :                                    b, ib, jb, descb, beta, c, ic, jc, descc) &
     338              :             BIND(C, name="cosma_pzgemm")
     339              :             IMPORT :: C_PTR, C_INT, C_CHAR
     340              :             CHARACTER(KIND=C_CHAR)                    :: transa
     341              :             CHARACTER(KIND=C_CHAR)                    :: transb
     342              :             INTEGER(KIND=C_INT)                       :: m
     343              :             INTEGER(KIND=C_INT)                       :: n
     344              :             INTEGER(KIND=C_INT)                       :: k
     345              :             TYPE(C_PTR), VALUE                        :: alpha
     346              :             TYPE(C_PTR), VALUE                        :: a
     347              :             INTEGER(KIND=C_INT)                       :: ia
     348              :             INTEGER(KIND=C_INT)                       :: ja
     349              :             TYPE(C_PTR), VALUE                        :: desca
     350              :             TYPE(C_PTR), VALUE                        :: b
     351              :             INTEGER(KIND=C_INT)                       :: ib
     352              :             INTEGER(KIND=C_INT)                       :: jb
     353              :             TYPE(C_PTR), VALUE                        :: descb
     354              :             TYPE(C_PTR), VALUE                        :: beta
     355              :             TYPE(C_PTR), VALUE                        :: c
     356              :             INTEGER(KIND=C_INT)                       :: ic
     357              :             INTEGER(KIND=C_INT)                       :: jc
     358              :             TYPE(C_PTR), VALUE                        :: descc
     359              :          END SUBROUTINE cosma_pzgemm_c
     360              :       END INTERFACE
     361              : 
     362       283736 :       IF (PRESENT(a_first_row)) THEN
     363            0 :          i_a = a_first_row
     364              :       ELSE
     365       283736 :          i_a = 1
     366              :       END IF
     367       283736 :       IF (PRESENT(a_first_col)) THEN
     368            0 :          j_a = a_first_col
     369              :       ELSE
     370       283736 :          j_a = 1
     371              :       END IF
     372       283736 :       IF (PRESENT(b_first_row)) THEN
     373            0 :          i_b = b_first_row
     374              :       ELSE
     375       283736 :          i_b = 1
     376              :       END IF
     377       283736 :       IF (PRESENT(b_first_col)) THEN
     378            0 :          j_b = b_first_col
     379              :       ELSE
     380       283736 :          j_b = 1
     381              :       END IF
     382       283736 :       IF (PRESENT(c_first_row)) THEN
     383            0 :          i_c = c_first_row
     384              :       ELSE
     385       283736 :          i_c = 1
     386              :       END IF
     387       283736 :       IF (PRESENT(c_first_col)) THEN
     388            0 :          j_c = c_first_col
     389              :       ELSE
     390       283736 :          j_c = 1
     391              :       END IF
     392              : 
     393       283736 :       alpha_t(1) = REAL(alpha, KIND=dp)
     394       283736 :       alpha_t(2) = REAL(AIMAG(alpha), KIND=dp)
     395       283736 :       beta_t(1) = REAL(beta, KIND=dp)
     396       283736 :       beta_t(2) = REAL(AIMAG(beta), KIND=dp)
     397              : 
     398              :       CALL cosma_pzgemm_c(transa=transa, transb=transb, m=m, n=n, k=k, &
     399              :                           alpha=C_LOC(alpha_t), &
     400              :                           a=C_LOC(matrix_a%local_data(1, 1)), ia=i_a, ja=j_a, &
     401              :                           desca=C_LOC(matrix_a%matrix_struct%descriptor(1)), &
     402              :                           b=C_LOC(matrix_b%local_data(1, 1)), ib=i_b, jb=j_b, &
     403              :                           descb=C_LOC(matrix_b%matrix_struct%descriptor(1)), &
     404              :                           beta=C_LOC(beta_t), &
     405              :                           c=C_LOC(matrix_c%local_data(1, 1)), ic=i_c, jc=j_c, &
     406       283736 :                           descc=C_LOC(matrix_c%matrix_struct%descriptor(1)))
     407              : 
     408       283736 :    END SUBROUTINE cosma_pzgemm
     409              : #endif
     410              : 
     411              : END MODULE parallel_gemm_api
        

Generated by: LCOV version 2.0-1