Line data Source code
1 : !--------------------------------------------------------------------------------------------------!
2 : ! CP2K: A general program to perform molecular dynamics simulations !
3 : ! Copyright 2000-2025 CP2K developers group <https://cp2k.org> !
4 : ! !
5 : ! SPDX-License-Identifier: GPL-2.0-or-later !
6 : !--------------------------------------------------------------------------------------------------!
7 :
8 : ! **************************************************************************************************
9 : !> \brief basic linear algebra operations for full matrices
10 : !> \par History
11 : !> 08.2002 split out of qs_blacs [fawzi]
12 : !> \author Fawzi Mohamed
13 : ! **************************************************************************************************
14 : MODULE parallel_gemm_api
15 : USE ISO_C_BINDING, ONLY: C_CHAR,&
16 : C_DOUBLE,&
17 : C_INT,&
18 : C_LOC,&
19 : C_PTR
20 : USE cp_cfm_basic_linalg, ONLY: cp_cfm_gemm
21 : USE cp_cfm_types, ONLY: cp_cfm_type
22 : USE cp_fm_basic_linalg, ONLY: cp_fm_gemm
23 : USE cp_fm_types, ONLY: cp_fm_get_mm_type,&
24 : cp_fm_set_all_submatrix,&
25 : cp_fm_type
26 : USE input_constants, ONLY: do_cosma,&
27 : do_scalapack
28 : USE kinds, ONLY: dp
29 : USE offload_api, ONLY: offload_activate_chosen_device
30 : #include "./base/base_uses.f90"
31 :
32 : IMPLICIT NONE
33 : PRIVATE
34 :
35 : CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'parallel_gemm_api'
36 :
37 : PUBLIC :: parallel_gemm
38 :
39 : INTERFACE parallel_gemm
40 : MODULE PROCEDURE parallel_gemm_fm
41 : MODULE PROCEDURE parallel_gemm_cfm
42 : END INTERFACE parallel_gemm
43 :
44 : CONTAINS
45 :
! **************************************************************************************************
!> \brief Multiplies two real full matrices with the backend reported by cp_fm_get_mm_type():
!>        cp_fm_gemm for do_scalapack, cosma_pdgemm for do_cosma. All arguments are forwarded
!>        to the selected backend; any other backend value aborts.
!> \param transa forwarded unchanged to the backend (presumably the GEMM transpose flag of A)
!> \param transb forwarded unchanged to the backend (presumably the GEMM transpose flag of B)
!> \param m forwarded to the backend; also the row count of the block of C that is cleared
!>          before a COSMA call with beta == 0
!> \param n forwarded to the backend; also the column count of the block of C that is cleared
!>          before a COSMA call with beta == 0
!> \param k forwarded unchanged to the backend
!> \param alpha scalar forwarded unchanged to the backend
!> \param matrix_a first input matrix
!> \param matrix_b second input matrix
!> \param beta scalar forwarded unchanged to the backend; an exact zero triggers the explicit
!>             clearing of C on the COSMA path
!> \param matrix_c result matrix
!> \param a_first_col optional offset into A, forwarded to the backend
!> \param a_first_row optional offset into A, forwarded to the backend
!> \param b_first_col optional offset into B, forwarded to the backend
!> \param b_first_row optional offset into B, forwarded to the backend
!> \param c_first_col optional offset into C, forwarded to the backend (1 is used for the
!>                    clearing step if absent)
!> \param c_first_row optional offset into C, forwarded to the backend (1 is used for the
!>                    clearing step if absent)
!> \note  Every backend branch starts and stops its own timer, so the timer handle is never
!>        used without having been set.
! **************************************************************************************************
   SUBROUTINE parallel_gemm_fm(transa, transb, m, n, k, alpha, matrix_a, matrix_b, beta, &
                               matrix_c, a_first_col, a_first_row, b_first_col, b_first_row, &
                               c_first_col, c_first_row)
      CHARACTER(LEN=1), INTENT(IN)                       :: transa, transb
      INTEGER, INTENT(IN)                                :: m, n, k
      REAL(KIND=dp), INTENT(IN)                          :: alpha
      TYPE(cp_fm_type), INTENT(IN)                       :: matrix_a, matrix_b
      REAL(KIND=dp), INTENT(IN)                          :: beta
      TYPE(cp_fm_type), INTENT(IN)                       :: matrix_c
      INTEGER, INTENT(IN), OPTIONAL                      :: a_first_col, a_first_row, b_first_col, &
                                                            b_first_row, c_first_col, c_first_row

      CHARACTER(len=*), PARAMETER :: routineN = 'parallel_gemm_fm'

      INTEGER                                            :: cfc, cfr, handle, my_multi

      ! cfc and cfr are only referenced when the COSMA branch is compiled in.
      MARK_USED(cfc)
      MARK_USED(cfr)

      my_multi = cp_fm_get_mm_type()

      SELECT CASE (my_multi)
      CASE (do_scalapack)
         CALL timeset(routineN//"_gemm", handle)
         CALL cp_fm_gemm(transa, transb, m, n, k, alpha, matrix_a, matrix_b, beta, matrix_c, &
                         a_first_col=a_first_col, &
                         a_first_row=a_first_row, &
                         b_first_col=b_first_col, &
                         b_first_row=b_first_row, &
                         c_first_col=c_first_col, &
                         c_first_row=c_first_row)
         CALL timestop(handle)
      CASE (do_cosma)
#if defined(__COSMA)
         CALL timeset(routineN//"_cosma", handle)
         ! BLAS states that C need not be set on input when BETA is supplied as zero.
         ! COSMA does not seem to honour this, so the target block of C is cleared explicitly.
         IF (beta == 0.0_dp) THEN
            cfr = 1
            cfc = 1
            IF (PRESENT(c_first_row)) cfr = c_first_row
            IF (PRESENT(c_first_col)) cfc = c_first_col
            CALL cp_fm_set_all_submatrix(matrix_c, 0.0_dp, cfr, cfc, m, n)
         END IF
         CALL offload_activate_chosen_device()
         CALL cosma_pdgemm(transa=transa, transb=transb, m=m, n=n, k=k, alpha=alpha, &
                           matrix_a=matrix_a, matrix_b=matrix_b, beta=beta, matrix_c=matrix_c, &
                           a_first_col=a_first_col, &
                           a_first_row=a_first_row, &
                           b_first_col=b_first_col, &
                           b_first_row=b_first_row, &
                           c_first_col=c_first_col, &
                           c_first_row=c_first_row)
         CALL timestop(handle)
#else
         CPABORT("CP2K compiled without the COSMA library.")
#endif
      CASE DEFAULT
         ! Without this branch an unexpected backend value would skip the product silently.
         CPABORT("Unknown matrix multiplication type.")
      END SELECT

   END SUBROUTINE parallel_gemm_fm
125 :
! **************************************************************************************************
!> \brief Multiplies two complex full matrices with the backend reported by
!>        cp_fm_get_mm_type(): cp_cfm_gemm for do_scalapack, cosma_pzgemm for do_cosma.
!>        All arguments are forwarded to the selected backend; any other backend value aborts.
!> \param transa forwarded unchanged to the backend (presumably the GEMM transpose flag of A)
!> \param transb forwarded unchanged to the backend (presumably the GEMM transpose flag of B)
!> \param m forwarded unchanged to the backend
!> \param n forwarded unchanged to the backend
!> \param k forwarded unchanged to the backend
!> \param alpha complex scalar forwarded unchanged to the backend
!> \param matrix_a first input matrix
!> \param matrix_b second input matrix
!> \param beta complex scalar forwarded unchanged to the backend
!> \param matrix_c result matrix
!> \param a_first_col optional offset into A, forwarded to the backend
!> \param a_first_row optional offset into A, forwarded to the backend
!> \param b_first_col optional offset into B, forwarded to the backend
!> \param b_first_row optional offset into B, forwarded to the backend
!> \param c_first_col optional offset into C, forwarded to the backend
!> \param c_first_row optional offset into C, forwarded to the backend
!> \note  The whole routine is timed under its own name; the backend call is additionally
!>        timed under "<routine>_gemm" or "<routine>_cosma".
! **************************************************************************************************
   SUBROUTINE parallel_gemm_cfm(transa, transb, m, n, k, alpha, matrix_a, matrix_b, beta, &
                                matrix_c, a_first_col, a_first_row, b_first_col, b_first_row, &
                                c_first_col, c_first_row)
      CHARACTER(LEN=1), INTENT(IN)                       :: transa, transb
      INTEGER, INTENT(IN)                                :: m, n, k
      COMPLEX(KIND=dp), INTENT(IN)                       :: alpha
      TYPE(cp_cfm_type), INTENT(IN)                      :: matrix_a, matrix_b
      COMPLEX(KIND=dp), INTENT(IN)                       :: beta
      TYPE(cp_cfm_type), INTENT(IN)                      :: matrix_c
      INTEGER, INTENT(IN), OPTIONAL                      :: a_first_col, a_first_row, b_first_col, &
                                                            b_first_row, c_first_col, c_first_row

      CHARACTER(len=*), PARAMETER :: routineN = 'parallel_gemm_cfm'

      INTEGER                                            :: handle, handle1, my_multi

      CALL timeset(routineN, handle)

      my_multi = cp_fm_get_mm_type()

      SELECT CASE (my_multi)
      CASE (do_scalapack)
         CALL timeset(routineN//"_gemm", handle1)
         CALL cp_cfm_gemm(transa, transb, m, n, k, alpha, matrix_a, matrix_b, beta, matrix_c, &
                          a_first_col=a_first_col, &
                          a_first_row=a_first_row, &
                          b_first_col=b_first_col, &
                          b_first_row=b_first_row, &
                          c_first_col=c_first_col, &
                          c_first_row=c_first_row)
         CALL timestop(handle1)
      CASE (do_cosma)
#if defined(__COSMA)
         CALL timeset(routineN//"_cosma", handle1)
         CALL offload_activate_chosen_device()
         CALL cosma_pzgemm(transa=transa, transb=transb, m=m, n=n, k=k, alpha=alpha, &
                           matrix_a=matrix_a, matrix_b=matrix_b, beta=beta, matrix_c=matrix_c, &
                           a_first_col=a_first_col, &
                           a_first_row=a_first_row, &
                           b_first_col=b_first_col, &
                           b_first_row=b_first_row, &
                           c_first_col=c_first_col, &
                           c_first_row=c_first_row)
         CALL timestop(handle1)
#else
         CPABORT("CP2K compiled without the COSMA library.")
#endif
      CASE DEFAULT
         ! Without this branch an unexpected backend value would return with matrix_c untouched.
         CPABORT("Unknown matrix multiplication type.")
      END SELECT
      CALL timestop(handle)

   END SUBROUTINE parallel_gemm_cfm
196 :
197 : #if defined(__COSMA)
! **************************************************************************************************
!> \brief Fortran wrapper around the C routine bound as "cosma_pdgemm" (real double precision).
!>        Absent offsets are replaced by 1, and the first element of each local data block and
!>        of each matrix descriptor is handed to C as a pointer.
!> \param transa passed by reference to the C routine
!> \param transb passed by reference to the C routine
!> \param m passed by reference to the C routine
!> \param n passed by reference to the C routine
!> \param k passed by reference to the C routine
!> \param alpha passed by reference to the C routine
!> \param matrix_a supplies local_data and matrix_struct%descriptor for the C arguments a, desca
!> \param matrix_b supplies local_data and matrix_struct%descriptor for the C arguments b, descb
!> \param beta passed by reference to the C routine
!> \param matrix_c supplies local_data and matrix_struct%descriptor for the C arguments c, descc
!> \param a_first_col passed as ja, defaults to 1
!> \param a_first_row passed as ia, defaults to 1
!> \param b_first_col passed as jb, defaults to 1
!> \param b_first_row passed as ib, defaults to 1
!> \param c_first_col passed as jc, defaults to 1
!> \param c_first_row passed as ic, defaults to 1
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE cosma_pdgemm(transa, transb, m, n, k, alpha, matrix_a, matrix_b, beta, matrix_c, &
                           a_first_col, a_first_row, b_first_col, b_first_row, &
                           c_first_col, c_first_row)
      CHARACTER(LEN=1), INTENT(IN)                       :: transa, transb
      INTEGER, INTENT(IN)                                :: m, n, k
      REAL(KIND=dp), INTENT(IN)                          :: alpha
      TYPE(cp_fm_type), INTENT(IN)                       :: matrix_a, matrix_b
      REAL(KIND=dp), INTENT(IN)                          :: beta
      TYPE(cp_fm_type), INTENT(IN)                       :: matrix_c
      INTEGER, INTENT(IN), OPTIONAL                      :: a_first_col, a_first_row, b_first_col, &
                                                            b_first_row, c_first_col, c_first_row

      INTEGER                                            :: col_a, col_b, col_c, row_a, row_b, row_c
      INTERFACE
         SUBROUTINE cosma_pdgemm_c(transa, transb, m, n, k, alpha, a, ia, ja, desca, &
                                   b, ib, jb, descb, beta, c, ic, jc, descc) &
            BIND(C, name="cosma_pdgemm")
            IMPORT :: C_PTR, C_INT, C_DOUBLE, C_CHAR
            ! Scalars are passed by reference, matrix data and descriptors as pointer values.
            CHARACTER(KIND=C_CHAR)                    :: transa, transb
            INTEGER(KIND=C_INT)                       :: m, n, k
            REAL(KIND=C_DOUBLE)                       :: alpha, beta
            INTEGER(KIND=C_INT)                       :: ia, ja, ib, jb, ic, jc
            TYPE(C_PTR), VALUE                        :: a, b, c
            TYPE(C_PTR), VALUE                        :: desca, descb, descc
         END SUBROUTINE cosma_pdgemm_c
      END INTERFACE

      ! Every offset starts at 1 and is overridden when the caller supplied one.
      row_a = 1
      col_a = 1
      row_b = 1
      col_b = 1
      row_c = 1
      col_c = 1
      IF (PRESENT(a_first_row)) row_a = a_first_row
      IF (PRESENT(a_first_col)) col_a = a_first_col
      IF (PRESENT(b_first_row)) row_b = b_first_row
      IF (PRESENT(b_first_col)) col_b = b_first_col
      IF (PRESENT(c_first_row)) row_c = c_first_row
      IF (PRESENT(c_first_col)) col_c = c_first_col

      CALL cosma_pdgemm_c(transa=transa, transb=transb, m=m, n=n, k=k, &
                          alpha=alpha, &
                          a=C_LOC(matrix_a%local_data(1, 1)), ia=row_a, ja=col_a, &
                          desca=C_LOC(matrix_a%matrix_struct%descriptor(1)), &
                          b=C_LOC(matrix_b%local_data(1, 1)), ib=row_b, jb=col_b, &
                          descb=C_LOC(matrix_b%matrix_struct%descriptor(1)), &
                          beta=beta, &
                          c=C_LOC(matrix_c%local_data(1, 1)), ic=row_c, jc=col_c, &
                          descc=C_LOC(matrix_c%matrix_struct%descriptor(1)))

   END SUBROUTINE cosma_pdgemm
300 :
! **************************************************************************************************
!> \brief Fortran wrapper around the C routine bound as "cosma_pzgemm" (complex double
!>        precision). Absent offsets are replaced by 1, alpha and beta are unpacked into
!>        two-element real arrays (real part, imaginary part) whose addresses are handed to C,
!>        and the first element of each local data block and descriptor is passed as a pointer.
!> \param transa passed by reference to the C routine
!> \param transb passed by reference to the C routine
!> \param m passed by reference to the C routine
!> \param n passed by reference to the C routine
!> \param k passed by reference to the C routine
!> \param alpha complex scalar, passed to C as a pointer to its (real, imaginary) pair
!> \param matrix_a supplies local_data and matrix_struct%descriptor for the C arguments a, desca
!> \param matrix_b supplies local_data and matrix_struct%descriptor for the C arguments b, descb
!> \param beta complex scalar, passed to C as a pointer to its (real, imaginary) pair
!> \param matrix_c supplies local_data and matrix_struct%descriptor for the C arguments c, descc
!> \param a_first_col passed as ja, defaults to 1
!> \param a_first_row passed as ia, defaults to 1
!> \param b_first_col passed as jb, defaults to 1
!> \param b_first_row passed as ib, defaults to 1
!> \param c_first_col passed as jc, defaults to 1
!> \param c_first_row passed as ic, defaults to 1
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE cosma_pzgemm(transa, transb, m, n, k, alpha, matrix_a, matrix_b, beta, matrix_c, &
                           a_first_col, a_first_row, b_first_col, b_first_row, &
                           c_first_col, c_first_row)
      CHARACTER(LEN=1), INTENT(IN)                       :: transa, transb
      INTEGER, INTENT(IN)                                :: m, n, k
      COMPLEX(KIND=dp), INTENT(IN)                       :: alpha
      TYPE(cp_cfm_type), INTENT(IN)                      :: matrix_a, matrix_b
      COMPLEX(KIND=dp), INTENT(IN)                       :: beta
      TYPE(cp_cfm_type), INTENT(IN)                      :: matrix_c
      INTEGER, INTENT(IN), OPTIONAL                      :: a_first_col, a_first_row, b_first_col, &
                                                            b_first_row, c_first_col, c_first_row

      INTEGER                                            :: col_a, col_b, col_c, row_a, row_b, row_c
      REAL(KIND=dp), DIMENSION(2), TARGET                :: alpha_parts, beta_parts
      INTERFACE
         SUBROUTINE cosma_pzgemm_c(transa, transb, m, n, k, alpha, a, ia, ja, desca, &
                                   b, ib, jb, descb, beta, c, ic, jc, descc) &
            BIND(C, name="cosma_pzgemm")
            IMPORT :: C_PTR, C_INT, C_CHAR
            ! Flags, sizes and offsets go by reference; scalars, data and descriptors are pointers.
            CHARACTER(KIND=C_CHAR)                    :: transa, transb
            INTEGER(KIND=C_INT)                       :: m, n, k
            TYPE(C_PTR), VALUE                        :: alpha, beta
            INTEGER(KIND=C_INT)                       :: ia, ja, ib, jb, ic, jc
            TYPE(C_PTR), VALUE                        :: a, b, c
            TYPE(C_PTR), VALUE                        :: desca, descb, descc
         END SUBROUTINE cosma_pzgemm_c
      END INTERFACE

      ! Every offset starts at 1 and is overridden when the caller supplied one.
      row_a = 1
      col_a = 1
      row_b = 1
      col_b = 1
      row_c = 1
      col_c = 1
      IF (PRESENT(a_first_row)) row_a = a_first_row
      IF (PRESENT(a_first_col)) col_a = a_first_col
      IF (PRESENT(b_first_row)) row_b = b_first_row
      IF (PRESENT(b_first_col)) col_b = b_first_col
      IF (PRESENT(c_first_row)) row_c = c_first_row
      IF (PRESENT(c_first_col)) col_c = c_first_col

      ! Local TARGET copies give the INTENT(IN) complex scalars an address usable with C_LOC.
      alpha_parts = [REAL(alpha, KIND=dp), AIMAG(alpha)]
      beta_parts = [REAL(beta, KIND=dp), AIMAG(beta)]

      CALL cosma_pzgemm_c(transa=transa, transb=transb, m=m, n=n, k=k, &
                          alpha=C_LOC(alpha_parts), &
                          a=C_LOC(matrix_a%local_data(1, 1)), ia=row_a, ja=col_a, &
                          desca=C_LOC(matrix_a%matrix_struct%descriptor(1)), &
                          b=C_LOC(matrix_b%local_data(1, 1)), ib=row_b, jb=col_b, &
                          descb=C_LOC(matrix_b%matrix_struct%descriptor(1)), &
                          beta=C_LOC(beta_parts), &
                          c=C_LOC(matrix_c%local_data(1, 1)), ic=row_c, jc=col_c, &
                          descc=C_LOC(matrix_c%matrix_struct%descriptor(1)))

   END SUBROUTINE cosma_pzgemm
409 : #endif
410 :
411 : END MODULE parallel_gemm_api
|