LCOV - code coverage report
Current view: top level - src - local_gemm_api.F (source / functions) Coverage Total Hit
Test: CP2K Regtests (git:936074a) Lines: 94.7 % 19 18
Test Date: 2025-12-04 06:27:48 Functions: 83.3 % 6 5

            Line data    Source code
       1              : !--------------------------------------------------------------------------------------------------!
       2              : !   CP2K: A general program to perform molecular dynamics simulations                              !
       3              : !   Copyright 2000-2025 CP2K developers group <https://cp2k.org>                                   !
       4              : !                                                                                                  !
       5              : !   SPDX-License-Identifier: GPL-2.0-or-later                                                      !
       6              : !--------------------------------------------------------------------------------------------------!
       7              : 
       8              : MODULE local_gemm_api
       9              :    USE ISO_C_BINDING, ONLY: C_NULL_PTR, &
      10              :                             C_PTR
      11              : #if defined(__SPLA) && defined(__OFFLOAD_GEMM)
      12              :    USE input_constants, ONLY: do_dgemm_spla
      13              :    USE ISO_C_BINDING, ONLY: C_ASSOCIATED, &
      14              :                             C_LOC
      15              :    USE spla, ONLY: SPLA_PU_HOST, &
      16              :                    SPLA_PU_GPU, &
      17              :                    SPLA_OP_NONE, &
      18              :                    SPLA_OP_TRANSPOSE, &
      19              :                    SPLA_OP_CONJ_TRANSPOSE, &
      20              :                    spla_ctx_create, &
      21              :                    spla_ctx_destroy, &
      22              :                    spla_dgemm, &
      23              :                    spla_sgemm, &
      24              :                    spla_cgemm, &
      25              :                    spla_zgemm, &
      26              :                    spla_ctx_set_op_threshold_gpu, &
      27              :                    SPLA_SUCCESS
      28              : #endif
      29              : 
      30              :    USE cp_log_handling, ONLY: cp_to_string
      31              :    USE offload_api, ONLY: offload_activate_chosen_device
      32              : 
      33              : #include "./base/base_uses.f90"
      34              : 
      35              :    IMPLICIT NONE
      36              : 
      37              :    PRIVATE
      38              : 
      39              :    CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'local_gemm_api'
      40              : 
      41              :    PUBLIC :: local_gemm_ctxt_type, &
      42              :              local_gemm_set_library
      43              : 
      44              :    INTEGER, PARAMETER, PUBLIC :: &
      45              :       LOCAL_GEMM_PU_HOST = 0, &
      46              :       LOCAL_GEMM_PU_GPU = 1
      47              : 
      48              :    INTEGER, PRIVATE :: do_dgemm = 1
      49              : 
      50              :    TYPE local_gemm_ctxt_type
      51              :       TYPE(C_PTR) :: spla_context = C_NULL_PTR
      52              :    CONTAINS
      53              :       PROCEDURE, PASS(ctx), NON_OVERRIDABLE :: create => local_gemm_create
      54              :       PROCEDURE, PASS(ctx), NON_OVERRIDABLE :: destroy => local_gemm_destroy
      55              :       PROCEDURE, PASS(ctx), NON_OVERRIDABLE :: set_op_threshold_gpu => local_gemm_set_op_threshold_gpu
      56              :       PROCEDURE, PASS(ctx), NON_OVERRIDABLE :: gemm => local_gemm
      57              :    END TYPE
      58              : 
      59              : CONTAINS
      60              : 
      61              : ! **************************************************************************************************
      62              : !> \brief ...
      63              : !> \param opA ...
      64              : !> \param opB ...
      65              : !> \param m ...
      66              : !> \param n ...
      67              : !> \param k ...
      68              : !> \param alpha ...
      69              : !> \param A ...
      70              : !> \param lda ...
      71              : !> \param B ...
      72              : !> \param ldb ...
      73              : !> \param beta ...
      74              : !> \param C ...
      75              : !> \param ldc ...
      76              : !> \param ctx ...
      77              : ! **************************************************************************************************
      78       106744 :    SUBROUTINE local_gemm(opA, opB, m, n, k, &
      79        53372 :                          alpha, A, lda, B, ldb, &
      80        53372 :                          beta, C, ldc, ctx)
      81              :       CHARACTER, INTENT(in) :: opA
      82              :       CHARACTER, INTENT(in) :: opB
      83              :       INTEGER, INTENT(in) :: m
      84              :       INTEGER, INTENT(in) :: n
      85              :       INTEGER, INTENT(in) :: k
      86              :       REAL(8), INTENT(in) :: alpha
      87              : #if defined(__SPLA) && defined(__OFFLOAD_GEMM)
      88              :       REAL(8), DIMENSION(*), INTENT(in), TARGET :: A
      89              : #else
      90              :       REAL(8), DIMENSION(:, :), INTENT(in), TARGET :: A
      91              : #endif
      92              :       INTEGER, INTENT(in) :: lda
      93              : #if defined(__SPLA) && defined(__OFFLOAD_GEMM)
      94              :       REAL(8), DIMENSION(*), INTENT(in), TARGET :: B
      95              : #else
      96              :       REAL(8), DIMENSION(:, :), INTENT(in), TARGET :: B
      97              : #endif
      98              : 
      99              :       INTEGER, INTENT(in) :: ldb
     100              :       REAL(8), INTENT(in) :: beta
     101              : #if defined(__SPLA) && defined(__OFFLOAD_GEMM)
     102              :       REAL(8), DIMENSION(*), INTENT(inout), TARGET ::C
     103              : #else
     104              :       REAL(8), DIMENSION(:, :), INTENT(inout), TARGET :: C
     105              : #endif
     106              :       INTEGER, INTENT(in) :: ldc
     107              :       CLASS(local_gemm_ctxt_type), INTENT(inout) :: ctx
     108              : 
     109              :       INTEGER                                            :: handle
     110              : !     no point of using SPLA offloading on CPU ONLY nodes
     111              : #if defined(__SPLA) && defined(__OFFLOAD_GEMM)
     112              :       INTEGER :: spla_op_A, spla_op_B, spla_error
     113              : #endif
     114              :       CHARACTER(LEN=*), PARAMETER :: routineN = 'local_gemm'
     115        53372 :       CALL timeset(routineN, handle)
     116              : 
     117              : !     no point of using SPLA offloading on CPU ONLY nodes
     118              : #if defined(__SPLA) && defined(__OFFLOAD_GEMM)
     119              :       IF (do_dgemm == do_dgemm_spla) THEN
     120              : 
     121              :          IF (opA == 'N') spla_op_A = SPLA_OP_NONE
     122              :          IF (opA == 'T') spla_op_A = SPLA_OP_TRANSPOSE
     123              : 
     124              :          IF (opB == 'N') spla_op_B = SPLA_OP_NONE
     125              :          IF (opB == 'T') spla_op_B = SPLA_OP_TRANSPOSE
     126              : 
     127              : #if __GNUC__ >= 9
     128              :          CPASSERT(IS_CONTIGUOUS(A))
     129              :          CPASSERT(IS_CONTIGUOUS(B))
     130              :          CPASSERT(IS_CONTIGUOUS(C))
     131              : #endif
     132              : 
     133              :          CALL offload_activate_chosen_device()
     134              :          spla_error = spla_dgemm(spla_op_A, spla_op_B, &
     135              :                                  m, n, k, alpha, &
     136              :                                  c_loc(A), lda, &
     137              :                                  c_loc(B), ldb, &
     138              :                                  beta, c_loc(C), ldc, ctx%spla_context)
     139              :          IF (spla_error /= SPLA_SUCCESS) &
     140              :             CPABORT("spla_dgemm failed: "//cp_to_string(spla_error))
     141              :       ELSE
     142              : #endif
     143              :          CALL dgemm(opA, opB, m, n, k, alpha, &
     144              :                     A, lda, &
     145      1523922 :                     B, ldb, beta, C, ldc)
     146              : #if defined(__SPLA) && defined(__OFFLOAD_GEMM)
     147              :       END IF
     148              : #else
     149              :       MARK_USED(ctx)
     150              : #endif
     151        53372 :       CALL timestop(handle)
     152              : 
     153        53372 :    END SUBROUTINE local_gemm
     154              : 
     155              : ! **************************************************************************************************
     156              : !> \brief create a context for handling gemm offloading
     157              : !> \param ctx newly created context
     158              : !> \param pu processing unit to run the (s,d,c,z}dgemm
     159              : ! **************************************************************************************************
     160          412 :    SUBROUTINE local_gemm_create(ctx, pu)
     161              :       CLASS(local_gemm_ctxt_type), INTENT(out) :: ctx
     162              :       INTEGER, INTENT(in) :: pu
     163              : 
     164              : #if defined(__SPLA) && defined(__OFFLOAD_GEMM)
     165              :       INTEGER :: error_
     166              : 
     167              :       IF (.NOT. C_ASSOCIATED(ctx%spla_context)) THEN
     168              :          IF (do_dgemm == do_dgemm_spla) THEN
     169              :             CALL offload_activate_chosen_device()
     170              : 
     171              :             error_ = spla_ctx_create(ctx%spla_context, pu)
     172              :             IF (error_ /= SPLA_SUCCESS) &
     173              :                CPABORT("spla_ctx_create failed: "//cp_to_string(error_))
     174              :          ELSE
     175              :             ctx%spla_context = C_NULL_PTR
     176              :          END IF
     177              :       END IF
     178              : #else
     179              :       MARK_USED(pu)
     180          412 :       ctx%spla_context = C_NULL_PTR
     181              : #endif
     182          412 :    END SUBROUTINE local_gemm_create
     183              : 
     184              : ! **************************************************************************************************
     185              : !> \brief release resources associated to a gemm context
     186              : !> \param ctx handle
     187              : ! **************************************************************************************************
     188          882 :    SUBROUTINE local_gemm_destroy(ctx)
     189              :       CLASS(local_gemm_ctxt_type), INTENT(inout) :: ctx
     190              : 
     191              : #if defined(__SPLA) && defined(__OFFLOAD_GEMM)
     192              :       INTEGER :: error_
     193              : 
     194              :       IF (do_dgemm == do_dgemm_spla) THEN
     195              :          CALL offload_activate_chosen_device()
     196              : 
     197              :          error_ = spla_ctx_destroy(ctx%spla_context)
     198              :          IF (error_ /= SPLA_SUCCESS) &
     199              :             CPABORT("spla_ctx_destroy failed: "//cp_to_string(error_))
     200              :       END IF
     201              : #endif
     202          882 :       ctx%spla_context = C_NULL_PTR
     203          882 :    END SUBROUTINE local_gemm_destroy
     204              : 
     205              : ! **************************************************************************************************
     206              : !> \brief ...
     207              : !> \param ctx ...
     208              : !> \param opThresholdGPU ...
     209              : ! **************************************************************************************************
     210          412 :    SUBROUTINE local_gemm_set_op_threshold_gpu(ctx, opThresholdGPU)
     211              :       CLASS(local_gemm_ctxt_type), INTENT(INOUT)                                        :: ctx
     212              :       INTEGER, INTENT(in)                                :: opThresholdGPU
     213              : 
     214              : #if defined(__SPLA) && defined(__OFFLOAD_GEMM)
     215              :       INTEGER                                            :: error__
     216              : 
     217              :       CALL offload_activate_chosen_device()
     218              :       error__ = spla_ctx_set_op_threshold_gpu(ctx%spla_context, opThresholdGPU)
     219              : #else
     220              :       MARK_USED(ctx)
     221              :       MARK_USED(opThresholdGPU)
     222              : #endif
     223          412 :    END SUBROUTINE local_gemm_set_op_threshold_gpu
     224              : 
     225              : ! **************************************************************************************************
     226              : !> \brief ...
     227              : !> \param dgemm_library ...
     228              : ! **************************************************************************************************
     229         9881 :    SUBROUTINE local_gemm_set_library(dgemm_library)
     230              :       INTEGER, INTENT(IN)                                :: dgemm_library
     231              : 
     232         9881 :       do_dgemm = dgemm_library
     233         9881 :    END SUBROUTINE local_gemm_set_library
     234              : 
     235            0 : END MODULE local_gemm_api
        

Generated by: LCOV version 2.0-1