Line data Source code
1 : !--------------------------------------------------------------------------------------------------!
2 : ! CP2K: A general program to perform molecular dynamics simulations !
3 : ! Copyright 2000-2025 CP2K developers group <https://cp2k.org> !
4 : ! !
5 : ! SPDX-License-Identifier: GPL-2.0-or-later !
6 : !--------------------------------------------------------------------------------------------------!
7 :
8 : MODULE local_gemm_api
9 : USE ISO_C_BINDING, ONLY: C_NULL_PTR, &
10 : C_PTR
11 : #if defined(__SPLA) && defined(__OFFLOAD_GEMM)
12 : USE input_constants, ONLY: do_dgemm_spla
13 : USE ISO_C_BINDING, ONLY: C_ASSOCIATED, &
14 : C_LOC
15 : USE spla, ONLY: SPLA_PU_HOST, &
16 : SPLA_PU_GPU, &
17 : SPLA_OP_NONE, &
18 : SPLA_OP_TRANSPOSE, &
19 : SPLA_OP_CONJ_TRANSPOSE, &
20 : spla_ctx_create, &
21 : spla_ctx_destroy, &
22 : spla_dgemm, &
23 : spla_sgemm, &
24 : spla_cgemm, &
25 : spla_zgemm, &
26 : spla_ctx_set_op_threshold_gpu, &
27 : SPLA_SUCCESS
28 : #endif
29 :
30 : USE cp_log_handling, ONLY: cp_to_string
31 : USE offload_api, ONLY: offload_activate_chosen_device
32 :
33 : #include "./base/base_uses.f90"
34 :
35 : IMPLICIT NONE
36 :
37 : PRIVATE
38 :
39 : CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'local_gemm_api'
40 :
41 : PUBLIC :: local_gemm_ctxt_type, &
42 : local_gemm_set_library
43 :
44 : INTEGER, PARAMETER, PUBLIC :: &
45 : LOCAL_GEMM_PU_HOST = 0, &
46 : LOCAL_GEMM_PU_GPU = 1
47 :
48 : INTEGER, PRIVATE :: do_dgemm = 1
49 :
50 : TYPE local_gemm_ctxt_type
51 : TYPE(C_PTR) :: spla_context = C_NULL_PTR
52 : CONTAINS
53 : PROCEDURE, PASS(ctx), NON_OVERRIDABLE :: create => local_gemm_create
54 : PROCEDURE, PASS(ctx), NON_OVERRIDABLE :: destroy => local_gemm_destroy
55 : PROCEDURE, PASS(ctx), NON_OVERRIDABLE :: set_op_threshold_gpu => local_gemm_set_op_threshold_gpu
56 : PROCEDURE, PASS(ctx), NON_OVERRIDABLE :: gemm => local_gemm
57 : END TYPE
58 :
59 : CONTAINS
60 :
61 : ! **************************************************************************************************
62 : !> \brief ...
63 : !> \param opA ...
64 : !> \param opB ...
65 : !> \param m ...
66 : !> \param n ...
67 : !> \param k ...
68 : !> \param alpha ...
69 : !> \param A ...
70 : !> \param lda ...
71 : !> \param B ...
72 : !> \param ldb ...
73 : !> \param beta ...
74 : !> \param C ...
75 : !> \param ldc ...
76 : !> \param ctx ...
77 : ! **************************************************************************************************
78 106744 : SUBROUTINE local_gemm(opA, opB, m, n, k, &
79 53372 : alpha, A, lda, B, ldb, &
80 53372 : beta, C, ldc, ctx)
81 : CHARACTER, INTENT(in) :: opA
82 : CHARACTER, INTENT(in) :: opB
83 : INTEGER, INTENT(in) :: m
84 : INTEGER, INTENT(in) :: n
85 : INTEGER, INTENT(in) :: k
86 : REAL(8), INTENT(in) :: alpha
87 : #if defined(__SPLA) && defined(__OFFLOAD_GEMM)
88 : REAL(8), DIMENSION(*), INTENT(in), TARGET :: A
89 : #else
90 : REAL(8), DIMENSION(:, :), INTENT(in), TARGET :: A
91 : #endif
92 : INTEGER, INTENT(in) :: lda
93 : #if defined(__SPLA) && defined(__OFFLOAD_GEMM)
94 : REAL(8), DIMENSION(*), INTENT(in), TARGET :: B
95 : #else
96 : REAL(8), DIMENSION(:, :), INTENT(in), TARGET :: B
97 : #endif
98 :
99 : INTEGER, INTENT(in) :: ldb
100 : REAL(8), INTENT(in) :: beta
101 : #if defined(__SPLA) && defined(__OFFLOAD_GEMM)
102 : REAL(8), DIMENSION(*), INTENT(inout), TARGET ::C
103 : #else
104 : REAL(8), DIMENSION(:, :), INTENT(inout), TARGET :: C
105 : #endif
106 : INTEGER, INTENT(in) :: ldc
107 : CLASS(local_gemm_ctxt_type), INTENT(inout) :: ctx
108 :
109 : INTEGER :: handle
110 : ! no point of using SPLA offloading on CPU ONLY nodes
111 : #if defined(__SPLA) && defined(__OFFLOAD_GEMM)
112 : INTEGER :: spla_op_A, spla_op_B, spla_error
113 : #endif
114 : CHARACTER(LEN=*), PARAMETER :: routineN = 'local_gemm'
115 53372 : CALL timeset(routineN, handle)
116 :
117 : ! no point of using SPLA offloading on CPU ONLY nodes
118 : #if defined(__SPLA) && defined(__OFFLOAD_GEMM)
119 : IF (do_dgemm == do_dgemm_spla) THEN
120 :
121 : IF (opA == 'N') spla_op_A = SPLA_OP_NONE
122 : IF (opA == 'T') spla_op_A = SPLA_OP_TRANSPOSE
123 :
124 : IF (opB == 'N') spla_op_B = SPLA_OP_NONE
125 : IF (opB == 'T') spla_op_B = SPLA_OP_TRANSPOSE
126 :
127 : #if __GNUC__ >= 9
128 : CPASSERT(IS_CONTIGUOUS(A))
129 : CPASSERT(IS_CONTIGUOUS(B))
130 : CPASSERT(IS_CONTIGUOUS(C))
131 : #endif
132 :
133 : CALL offload_activate_chosen_device()
134 : spla_error = spla_dgemm(spla_op_A, spla_op_B, &
135 : m, n, k, alpha, &
136 : c_loc(A), lda, &
137 : c_loc(B), ldb, &
138 : beta, c_loc(C), ldc, ctx%spla_context)
139 : IF (spla_error /= SPLA_SUCCESS) &
140 : CPABORT("spla_dgemm failed: "//cp_to_string(spla_error))
141 : ELSE
142 : #endif
143 : CALL dgemm(opA, opB, m, n, k, alpha, &
144 : A, lda, &
145 1523922 : B, ldb, beta, C, ldc)
146 : #if defined(__SPLA) && defined(__OFFLOAD_GEMM)
147 : END IF
148 : #else
149 : MARK_USED(ctx)
150 : #endif
151 53372 : CALL timestop(handle)
152 :
153 53372 : END SUBROUTINE local_gemm
154 :
155 : ! **************************************************************************************************
156 : !> \brief create a context for handling gemm offloading
157 : !> \param ctx newly created context
158 : !> \param pu processing unit to run the (s,d,c,z}dgemm
159 : ! **************************************************************************************************
160 412 : SUBROUTINE local_gemm_create(ctx, pu)
161 : CLASS(local_gemm_ctxt_type), INTENT(out) :: ctx
162 : INTEGER, INTENT(in) :: pu
163 :
164 : #if defined(__SPLA) && defined(__OFFLOAD_GEMM)
165 : INTEGER :: error_
166 :
167 : IF (.NOT. C_ASSOCIATED(ctx%spla_context)) THEN
168 : IF (do_dgemm == do_dgemm_spla) THEN
169 : CALL offload_activate_chosen_device()
170 :
171 : error_ = spla_ctx_create(ctx%spla_context, pu)
172 : IF (error_ /= SPLA_SUCCESS) &
173 : CPABORT("spla_ctx_create failed: "//cp_to_string(error_))
174 : ELSE
175 : ctx%spla_context = C_NULL_PTR
176 : END IF
177 : END IF
178 : #else
179 : MARK_USED(pu)
180 412 : ctx%spla_context = C_NULL_PTR
181 : #endif
182 412 : END SUBROUTINE local_gemm_create
183 :
184 : ! **************************************************************************************************
185 : !> \brief release resources associated to a gemm context
186 : !> \param ctx handle
187 : ! **************************************************************************************************
188 882 : SUBROUTINE local_gemm_destroy(ctx)
189 : CLASS(local_gemm_ctxt_type), INTENT(inout) :: ctx
190 :
191 : #if defined(__SPLA) && defined(__OFFLOAD_GEMM)
192 : INTEGER :: error_
193 :
194 : IF (do_dgemm == do_dgemm_spla) THEN
195 : CALL offload_activate_chosen_device()
196 :
197 : error_ = spla_ctx_destroy(ctx%spla_context)
198 : IF (error_ /= SPLA_SUCCESS) &
199 : CPABORT("spla_ctx_destroy failed: "//cp_to_string(error_))
200 : END IF
201 : #endif
202 882 : ctx%spla_context = C_NULL_PTR
203 882 : END SUBROUTINE local_gemm_destroy
204 :
205 : ! **************************************************************************************************
206 : !> \brief ...
207 : !> \param ctx ...
208 : !> \param opThresholdGPU ...
209 : ! **************************************************************************************************
210 412 : SUBROUTINE local_gemm_set_op_threshold_gpu(ctx, opThresholdGPU)
211 : CLASS(local_gemm_ctxt_type), INTENT(INOUT) :: ctx
212 : INTEGER, INTENT(in) :: opThresholdGPU
213 :
214 : #if defined(__SPLA) && defined(__OFFLOAD_GEMM)
215 : INTEGER :: error__
216 :
217 : CALL offload_activate_chosen_device()
218 : error__ = spla_ctx_set_op_threshold_gpu(ctx%spla_context, opThresholdGPU)
219 : #else
220 : MARK_USED(ctx)
221 : MARK_USED(opThresholdGPU)
222 : #endif
223 412 : END SUBROUTINE local_gemm_set_op_threshold_gpu
224 :
225 : ! **************************************************************************************************
226 : !> \brief ...
227 : !> \param dgemm_library ...
228 : ! **************************************************************************************************
229 9881 : SUBROUTINE local_gemm_set_library(dgemm_library)
230 : INTEGER, INTENT(IN) :: dgemm_library
231 :
232 9881 : do_dgemm = dgemm_library
233 9881 : END SUBROUTINE local_gemm_set_library
234 :
235 0 : END MODULE local_gemm_api
|