Line data Source code
1 : !--------------------------------------------------------------------------------------------------!
2 : ! CP2K: A general program to perform molecular dynamics simulations !
3 : ! Copyright 2000-2024 CP2K developers group <https://cp2k.org> !
4 : ! !
5 : ! SPDX-License-Identifier: GPL-2.0-or-later !
6 : !--------------------------------------------------------------------------------------------------!
7 :
8 : MODULE local_gemm_api
9 : USE ISO_C_BINDING, ONLY: C_LOC, &
10 : C_NULL_PTR, &
11 : C_PTR
12 : #if defined(__SPLA) && defined(__OFFLOAD_GEMM)
13 : USE input_constants, ONLY: do_dgemm_spla
14 : USE spla, ONLY: SPLA_PU_HOST, &
15 : SPLA_PU_GPU, &
16 : SPLA_OP_NONE, &
17 : SPLA_OP_TRANSPOSE, &
18 : SPLA_OP_CONJ_TRANSPOSE, &
19 : spla_ctx_create, &
20 : spla_ctx_destroy, &
21 : spla_dgemm, &
22 : spla_sgemm, &
23 : spla_cgemm, &
24 : spla_zgemm, &
25 : spla_ctx_set_op_threshold_gpu, &
26 : SPLA_SUCCESS
27 : #endif
28 :
29 : USE offload_api, ONLY: offload_activate_chosen_device
30 :
31 : #include "./base/base_uses.f90"
32 :
33 : IMPLICIT NONE
34 :
35 : PRIVATE
36 :
37 : CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'local_gemm_api'
38 :
39 : PUBLIC :: local_gemm, &
40 : local_gemm_create, &
41 : local_gemm_destroy, &
42 : local_gemm_set_op_threshold_gpu, &
43 : local_gemm_set_library
44 :
45 : INTEGER, PARAMETER, PUBLIC :: &
46 : LOCAL_GEMM_PU_HOST = 0, &
47 : LOCAL_GEMM_PU_GPU = 1
48 :
49 : INTEGER, PRIVATE :: do_dgemm = 1
50 :
51 : CONTAINS
52 :
53 : ! **************************************************************************************************
54 : !> \brief ...
55 : !> \param opA ...
56 : !> \param opB ...
57 : !> \param m ...
58 : !> \param n ...
59 : !> \param k ...
60 : !> \param alpha ...
61 : !> \param A ...
62 : !> \param lda ...
63 : !> \param B ...
64 : !> \param ldb ...
65 : !> \param beta ...
66 : !> \param C ...
67 : !> \param ldc ...
68 : !> \param ctx ...
69 : ! **************************************************************************************************
70 20528 : SUBROUTINE local_gemm(opA, opB, m, n, k, &
71 10264 : alpha, A, lda, B, ldb, &
72 10264 : beta, C, ldc, ctx)
73 : CHARACTER, INTENT(in) :: opA
74 : CHARACTER, INTENT(in) :: opB
75 : INTEGER, INTENT(in) :: m
76 : INTEGER, INTENT(in) :: n
77 : INTEGER, INTENT(in) :: k
78 : REAL(8), INTENT(in) :: alpha
79 : #if defined(__SPLA) && defined(__OFFLOAD_GEMM)
80 : REAL(8), DIMENSION(*), INTENT(in), TARGET :: A
81 : #else
82 : REAL(8), DIMENSION(:, :), INTENT(in), TARGET :: A
83 : #endif
84 : INTEGER, INTENT(in) :: lda
85 : #if defined(__SPLA) && defined(__OFFLOAD_GEMM)
86 : REAL(8), DIMENSION(*), INTENT(in), TARGET :: B
87 : #else
88 : REAL(8), DIMENSION(:, :), INTENT(in), TARGET :: B
89 : #endif
90 :
91 : INTEGER, INTENT(in) :: ldb
92 : REAL(8), INTENT(in) :: beta
93 : #if defined(__SPLA) && defined(__OFFLOAD_GEMM)
94 : REAL(8), DIMENSION(*), INTENT(inout), TARGET ::C
95 : #else
96 : REAL(8), DIMENSION(:, :), INTENT(inout), TARGET :: C
97 : #endif
98 : INTEGER, INTENT(in) :: ldc
99 : TYPE(C_ptr), OPTIONAL, INTENT(inout) :: ctx
100 :
101 : INTEGER :: handle
102 : ! no point of using SPLA offloading on CPU ONLY nodes
103 : #if defined(__SPLA) && defined(__OFFLOAD_GEMM)
104 : INTEGER :: spla_op_A, spla_op_B, spla_error
105 : #endif
106 : CHARACTER(LEN=*), PARAMETER :: routineN = 'local_gemm'
107 10264 : CALL timeset(routineN, handle)
108 :
109 : ! no point of using SPLA offloading on CPU ONLY nodes
110 : #if defined(__SPLA) && defined(__OFFLOAD_GEMM)
111 : IF (PRESENT(ctx) .AND. do_dgemm == do_dgemm_spla) THEN
112 :
113 : IF (opA == 'N') spla_op_A = SPLA_OP_NONE
114 : IF (opA == 'T') spla_op_A = SPLA_OP_TRANSPOSE
115 :
116 : IF (opB == 'N') spla_op_B = SPLA_OP_NONE
117 : IF (opB == 'T') spla_op_B = SPLA_OP_TRANSPOSE
118 :
119 : #if __GNUC__ >= 9
120 : CPASSERT(IS_CONTIGUOUS(A))
121 : CPASSERT(IS_CONTIGUOUS(B))
122 : CPASSERT(IS_CONTIGUOUS(C))
123 : #endif
124 :
125 : CALL offload_activate_chosen_device()
126 : spla_error = spla_dgemm(spla_op_A, spla_op_B, &
127 : m, n, k, alpha, &
128 : c_loc(A), lda, &
129 : c_loc(B), ldb, &
130 : beta, c_loc(C), ldc, ctx)
131 : CPASSERT(spla_error == SPLA_SUCCESS)
132 : ELSE
133 : #endif
134 : CALL dgemm(opA, opB, m, n, k, alpha, &
135 : A, lda, &
136 1480814 : B, ldb, beta, C, ldc)
137 : #if defined(__SPLA) && defined(__OFFLOAD_GEMM)
138 : END IF
139 : #else
140 : MARK_USED(ctx)
141 : #endif
142 10264 : CALL timestop(handle)
143 :
144 10264 : END SUBROUTINE local_gemm
145 :
146 : ! **************************************************************************************************
147 : !> \brief create a context for handling gemm offloading
148 : !> \param ctx newly created context
149 : !> \param pu processing unit to run the (s,d,c,z}dgemm
150 : ! **************************************************************************************************
151 806 : SUBROUTINE local_gemm_create(ctx, pu)
152 : TYPE(c_ptr), INTENT(out) :: ctx
153 : INTEGER, INTENT(in) :: pu
154 :
155 : #if defined(__SPLA) && defined(__OFFLOAD_GEMM)
156 : INTEGER :: error_
157 :
158 : IF (do_dgemm == do_dgemm_spla) THEN
159 : CALL offload_activate_chosen_device()
160 :
161 : error_ = spla_ctx_create(ctx, pu)
162 : CPASSERT(error_ == SPLA_SUCCESS)
163 : ELSE
164 : ctx = C_NULL_PTR
165 : END IF
166 : #else
167 : MARK_USED(pu)
168 : MARK_USED(ctx)
169 806 : ctx = C_NULL_PTR
170 : #endif
171 806 : END SUBROUTINE local_gemm_create
172 :
173 : ! **************************************************************************************************
174 : !> \brief release resources associated to a gemm context
175 : !> \param ctx handle
176 : ! **************************************************************************************************
177 806 : SUBROUTINE local_gemm_destroy(ctx)
178 : TYPE(c_ptr), INTENT(inout) :: ctx
179 :
180 : #if defined(__SPLA) && defined(__OFFLOAD_GEMM)
181 : INTEGER :: error_
182 :
183 : IF (do_dgemm == do_dgemm_spla) THEN
184 : CALL offload_activate_chosen_device()
185 :
186 : error_ = spla_ctx_destroy(ctx)
187 : CPASSERT(error_ == SPLA_SUCCESS)
188 : END IF
189 : #else
190 : MARK_USED(ctx)
191 : #endif
192 806 : ctx = C_NULL_PTR
193 806 : END SUBROUTINE local_gemm_destroy
194 :
195 : ! **************************************************************************************************
196 : !> \brief ...
197 : !> \param ctx ...
198 : !> \param opThresholdGPU ...
199 : ! **************************************************************************************************
200 806 : SUBROUTINE local_gemm_set_op_threshold_gpu(ctx, opThresholdGPU)
201 : TYPE(c_ptr) :: ctx
202 : INTEGER, INTENT(in) :: opThresholdGPU
203 :
204 : #if defined(__SPLA) && defined(__OFFLOAD_GEMM)
205 : INTEGER :: error__
206 :
207 : CALL offload_activate_chosen_device()
208 : error__ = spla_ctx_set_op_threshold_gpu(ctx, opThresholdGPU)
209 : #else
210 : MARK_USED(ctx)
211 : MARK_USED(opThresholdGPU)
212 : #endif
213 806 : END SUBROUTINE local_gemm_set_op_threshold_gpu
214 :
215 : ! **************************************************************************************************
216 : !> \brief ...
217 : !> \param dgemm_library ...
218 : ! **************************************************************************************************
219 8989 : SUBROUTINE local_gemm_set_library(dgemm_library)
220 : INTEGER, INTENT(IN) :: dgemm_library
221 :
222 8989 : do_dgemm = dgemm_library
223 8989 : END SUBROUTINE local_gemm_set_library
224 :
225 : END MODULE local_gemm_api
|