LCOV - code coverage report
Current view: top level - src - local_gemm_api.F (source / functions) Hit Total Coverage
Test: CP2K Regtests (git:e7e05ae) Lines: 18 18 100.0 %
Date: 2024-04-18 06:59:28 Functions: 5 5 100.0 %

          Line data    Source code
       1             : !--------------------------------------------------------------------------------------------------!
       2             : !   CP2K: A general program to perform molecular dynamics simulations                              !
       3             : !   Copyright 2000-2024 CP2K developers group <https://cp2k.org>                                   !
       4             : !                                                                                                  !
       5             : !   SPDX-License-Identifier: GPL-2.0-or-later                                                      !
       6             : !--------------------------------------------------------------------------------------------------!
       7             : 
       8             : MODULE local_gemm_api
       9             :    USE ISO_C_BINDING, ONLY: C_LOC, &
      10             :                             C_NULL_PTR, &
      11             :                             C_PTR
      12             : #if defined(__SPLA) && defined(__OFFLOAD_GEMM)
      13             :    USE input_constants, ONLY: do_dgemm_spla
      14             :    USE spla, ONLY: SPLA_PU_HOST, &
      15             :                    SPLA_PU_GPU, &
      16             :                    SPLA_OP_NONE, &
      17             :                    SPLA_OP_TRANSPOSE, &
      18             :                    SPLA_OP_CONJ_TRANSPOSE, &
      19             :                    spla_ctx_create, &
      20             :                    spla_ctx_destroy, &
      21             :                    spla_dgemm, &
      22             :                    spla_sgemm, &
      23             :                    spla_cgemm, &
      24             :                    spla_zgemm, &
      25             :                    spla_ctx_set_op_threshold_gpu, &
      26             :                    SPLA_SUCCESS
      27             : #endif
      28             : 
      29             :    USE offload_api, ONLY: offload_activate_chosen_device
      30             : 
      31             : #include "./base/base_uses.f90"
      32             : 
      33             :    IMPLICIT NONE
      34             : 
      35             :    PRIVATE
      36             : 
      37             :    CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'local_gemm_api'
      38             : 
      39             :    PUBLIC :: local_gemm, &
      40             :              local_gemm_create, &
      41             :              local_gemm_destroy, &
      42             :              local_gemm_set_op_threshold_gpu, &
      43             :              local_gemm_set_library
      44             : 
      45             :    INTEGER, PARAMETER, PUBLIC :: &
      46             :       LOCAL_GEMM_PU_HOST = 0, &
      47             :       LOCAL_GEMM_PU_GPU = 1
      48             : 
      49             :    INTEGER, PRIVATE :: do_dgemm = 1
      50             : 
      51             : CONTAINS
      52             : 
      53             : ! **************************************************************************************************
      54             : !> \brief ...
      55             : !> \param opA ...
      56             : !> \param opB ...
      57             : !> \param m ...
      58             : !> \param n ...
      59             : !> \param k ...
      60             : !> \param alpha ...
      61             : !> \param A ...
      62             : !> \param lda ...
      63             : !> \param B ...
      64             : !> \param ldb ...
      65             : !> \param beta ...
      66             : !> \param C ...
      67             : !> \param ldc ...
      68             : !> \param ctx ...
      69             : ! **************************************************************************************************
      70       20528 :    SUBROUTINE local_gemm(opA, opB, m, n, k, &
      71       10264 :                          alpha, A, lda, B, ldb, &
      72       10264 :                          beta, C, ldc, ctx)
      73             :       CHARACTER, INTENT(in) :: opA
      74             :       CHARACTER, INTENT(in) :: opB
      75             :       INTEGER, INTENT(in) :: m
      76             :       INTEGER, INTENT(in) :: n
      77             :       INTEGER, INTENT(in) :: k
      78             :       REAL(8), INTENT(in) :: alpha
      79             : #if defined(__SPLA) && defined(__OFFLOAD_GEMM)
      80             :       REAL(8), DIMENSION(*), INTENT(in), TARGET :: A
      81             : #else
      82             :       REAL(8), DIMENSION(:, :), INTENT(in), TARGET :: A
      83             : #endif
      84             :       INTEGER, INTENT(in) :: lda
      85             : #if defined(__SPLA) && defined(__OFFLOAD_GEMM)
      86             :       REAL(8), DIMENSION(*), INTENT(in), TARGET :: B
      87             : #else
      88             :       REAL(8), DIMENSION(:, :), INTENT(in), TARGET :: B
      89             : #endif
      90             : 
      91             :       INTEGER, INTENT(in) :: ldb
      92             :       REAL(8), INTENT(in) :: beta
      93             : #if defined(__SPLA) && defined(__OFFLOAD_GEMM)
      94             :       REAL(8), DIMENSION(*), INTENT(inout), TARGET ::C
      95             : #else
      96             :       REAL(8), DIMENSION(:, :), INTENT(inout), TARGET :: C
      97             : #endif
      98             :       INTEGER, INTENT(in) :: ldc
      99             :       TYPE(C_ptr), OPTIONAL, INTENT(inout) :: ctx
     100             : 
     101             :       INTEGER                                            :: handle
     102             : !     no point of using SPLA offloading on CPU ONLY nodes
     103             : #if defined(__SPLA) && defined(__OFFLOAD_GEMM)
     104             :       INTEGER :: spla_op_A, spla_op_B, spla_error
     105             : #endif
     106             :       CHARACTER(LEN=*), PARAMETER :: routineN = 'local_gemm'
     107       10264 :       CALL timeset(routineN, handle)
     108             : 
     109             : !     no point of using SPLA offloading on CPU ONLY nodes
     110             : #if defined(__SPLA) && defined(__OFFLOAD_GEMM)
     111             :       IF (PRESENT(ctx) .AND. do_dgemm == do_dgemm_spla) THEN
     112             : 
     113             :          IF (opA == 'N') spla_op_A = SPLA_OP_NONE
     114             :          IF (opA == 'T') spla_op_A = SPLA_OP_TRANSPOSE
     115             : 
     116             :          IF (opB == 'N') spla_op_B = SPLA_OP_NONE
     117             :          IF (opB == 'T') spla_op_B = SPLA_OP_TRANSPOSE
     118             : 
     119             : #if __GNUC__ >= 9
     120             :          CPASSERT(IS_CONTIGUOUS(A))
     121             :          CPASSERT(IS_CONTIGUOUS(B))
     122             :          CPASSERT(IS_CONTIGUOUS(C))
     123             : #endif
     124             : 
     125             :          CALL offload_activate_chosen_device()
     126             :          spla_error = spla_dgemm(spla_op_A, spla_op_B, &
     127             :                                  m, n, k, alpha, &
     128             :                                  c_loc(A), lda, &
     129             :                                  c_loc(B), ldb, &
     130             :                                  beta, c_loc(C), ldc, ctx)
     131             :          CPASSERT(spla_error == SPLA_SUCCESS)
     132             :       ELSE
     133             : #endif
     134             :          CALL dgemm(opA, opB, m, n, k, alpha, &
     135             :                     A, lda, &
     136     1480814 :                     B, ldb, beta, C, ldc)
     137             : #if defined(__SPLA) && defined(__OFFLOAD_GEMM)
     138             :       END IF
     139             : #else
     140             :       MARK_USED(ctx)
     141             : #endif
     142       10264 :       CALL timestop(handle)
     143             : 
     144       10264 :    END SUBROUTINE local_gemm
     145             : 
     146             : ! **************************************************************************************************
     147             : !> \brief create a context for handling gemm offloading
     148             : !> \param ctx newly created context
     149             : !> \param pu processing unit to run the (s,d,c,z}dgemm
     150             : ! **************************************************************************************************
     151         806 :    SUBROUTINE local_gemm_create(ctx, pu)
     152             :       TYPE(c_ptr), INTENT(out) :: ctx
     153             :       INTEGER, INTENT(in) :: pu
     154             : 
     155             : #if defined(__SPLA) && defined(__OFFLOAD_GEMM)
     156             :       INTEGER :: error_
     157             : 
     158             :       IF (do_dgemm == do_dgemm_spla) THEN
     159             :          CALL offload_activate_chosen_device()
     160             : 
     161             :          error_ = spla_ctx_create(ctx, pu)
     162             :          CPASSERT(error_ == SPLA_SUCCESS)
     163             :       ELSE
     164             :          ctx = C_NULL_PTR
     165             :       END IF
     166             : #else
     167             :       MARK_USED(pu)
     168             :       MARK_USED(ctx)
     169         806 :       ctx = C_NULL_PTR
     170             : #endif
     171         806 :    END SUBROUTINE local_gemm_create
     172             : 
     173             : ! **************************************************************************************************
     174             : !> \brief release resources associated to a gemm context
     175             : !> \param ctx handle
     176             : ! **************************************************************************************************
     177         806 :    SUBROUTINE local_gemm_destroy(ctx)
     178             :       TYPE(c_ptr), INTENT(inout) :: ctx
     179             : 
     180             : #if defined(__SPLA) && defined(__OFFLOAD_GEMM)
     181             :       INTEGER :: error_
     182             : 
     183             :       IF (do_dgemm == do_dgemm_spla) THEN
     184             :          CALL offload_activate_chosen_device()
     185             : 
     186             :          error_ = spla_ctx_destroy(ctx)
     187             :          CPASSERT(error_ == SPLA_SUCCESS)
     188             :       END IF
     189             : #else
     190             :       MARK_USED(ctx)
     191             : #endif
     192         806 :       ctx = C_NULL_PTR
     193         806 :    END SUBROUTINE local_gemm_destroy
     194             : 
     195             : ! **************************************************************************************************
     196             : !> \brief ...
     197             : !> \param ctx ...
     198             : !> \param opThresholdGPU ...
     199             : ! **************************************************************************************************
     200         806 :    SUBROUTINE local_gemm_set_op_threshold_gpu(ctx, opThresholdGPU)
     201             :       TYPE(c_ptr)                                        :: ctx
     202             :       INTEGER, INTENT(in)                                :: opThresholdGPU
     203             : 
     204             : #if defined(__SPLA) && defined(__OFFLOAD_GEMM)
     205             :       INTEGER                                            :: error__
     206             : 
     207             :       CALL offload_activate_chosen_device()
     208             :       error__ = spla_ctx_set_op_threshold_gpu(ctx, opThresholdGPU)
     209             : #else
     210             :       MARK_USED(ctx)
     211             :       MARK_USED(opThresholdGPU)
     212             : #endif
     213         806 :    END SUBROUTINE local_gemm_set_op_threshold_gpu
     214             : 
     215             : ! **************************************************************************************************
     216             : !> \brief ...
     217             : !> \param dgemm_library ...
     218             : ! **************************************************************************************************
     219        8989 :    SUBROUTINE local_gemm_set_library(dgemm_library)
     220             :       INTEGER, INTENT(IN)                                :: dgemm_library
     221             : 
     222        8989 :       do_dgemm = dgemm_library
     223        8989 :    END SUBROUTINE local_gemm_set_library
     224             : 
     225             : END MODULE local_gemm_api

Generated by: LCOV version 1.15