Line data Source code
1 : /*----------------------------------------------------------------------------*/
2 : /* CP2K: A general program to perform molecular dynamics simulations */
3 : /* Copyright 2000-2026 CP2K developers group <https://cp2k.org> */
4 : /* */
5 : /* SPDX-License-Identifier: BSD-3-Clause */
6 : /*----------------------------------------------------------------------------*/
7 : #include "dbm_multiply.h"
8 : #include "../offload/offload_mempool.h"
9 : #include "../offload/offload_runtime.h"
10 : #include "dbm_hyperparams.h"
11 : #include "dbm_internal.h"
12 : #include "dbm_library.h"
13 : #include "dbm_multiply_comm.h"
14 : #include "dbm_multiply_cpu.h"
15 : #include "dbm_multiply_gpu.h"
16 :
17 : #include <assert.h>
18 : #include <limits.h>
19 : #include <math.h>
20 : #include <omp.h>
21 : #include <stdio.h>
22 : #include <stdlib.h>
23 : #include <string.h>
24 :
25 : /*******************************************************************************
26 : * \brief Private routine for computing the max filter threshold for each row.
27 : * \author Ole Schuett
28 : ******************************************************************************/
29 242490 : static float *compute_rows_max_eps(const bool trans, const dbm_matrix_t *matrix,
30 : const double filter_eps) {
31 242490 : const int nrows = (trans) ? matrix->ncols : matrix->nrows;
32 242490 : int *nblocks_per_row = calloc(nrows, sizeof(int));
33 242490 : float *row_max_eps = malloc(nrows * sizeof(float));
34 242490 : assert((nblocks_per_row != NULL && row_max_eps != NULL) || nrows == 0);
35 :
36 242490 : #pragma omp parallel
37 : {
38 : #pragma omp for
39 : for (int ishard = 0; ishard < dbm_get_num_shards(matrix); ishard++) {
40 : dbm_shard_t *shard = &matrix->shards[ishard];
41 : for (int iblock = 0; iblock < shard->nblocks; iblock++) {
42 : const dbm_block_t *blk = &shard->blocks[iblock];
43 : const int row = (trans) ? blk->col : blk->row;
44 : #pragma omp atomic
45 : ++nblocks_per_row[row];
46 : }
47 : }
48 : #pragma omp master
49 : cp_mpi_sum_int(nblocks_per_row, nrows, matrix->dist->comm);
50 : #pragma omp barrier
51 : #pragma omp for
52 : for (int i = 0; i < nrows; i++) {
53 : const float f =
54 : ((float)filter_eps) / ((float)imax(1, nblocks_per_row[i]));
55 : row_max_eps[i] = f * f;
56 : }
57 : } // end of omp parallel region
58 :
59 242490 : free(nblocks_per_row);
60 242490 : return row_max_eps; // Ownership of row_max_eps transfers to caller.
61 : }
62 :
/*******************************************************************************
 * \brief Private struct for storing the context of the multiplication backend.
 *        Allocated by backend_start and released by backend_stop.
 * \author Ole Schuett
 ******************************************************************************/
struct backend_context {
#if defined(__OFFLOAD) && !defined(__NO_OFFLOAD_DBM)
  // State of the GPU multiplication backend (offload builds only).
  dbm_multiply_gpu_context_t gpu;
#endif
  // Flags passed to the CPU backend: bitwise OR of dbm_multiply_cpu_options.
  int cpu_options;
};
typedef struct backend_context backend_context_t;
73 :
74 : /*******************************************************************************
75 : * \brief Private routine for initializing the multiplication backend.
76 : * \author Ole Schuett
77 : ******************************************************************************/
78 242490 : static backend_context_t *backend_start(const dbm_matrix_t *matrix_c) {
79 242490 : backend_context_t *const ctx = calloc(1, sizeof(backend_context_t));
80 : // BLAS and LIBXS benefit in general from DBM_MULTIPLY_TASK_REORDER.
81 242490 : ctx->cpu_options = DBM_MULTIPLY_TASK_REORDER;
82 :
83 : #if defined(__OFFLOAD) && !defined(__NO_OFFLOAD_DBM)
84 : dbm_multiply_gpu_start(DBM_MAX_BATCH_SIZE, dbm_get_num_shards(matrix_c),
85 : matrix_c->shards, &ctx->gpu);
86 : #else
87 242490 : (void)matrix_c; // mark as used
88 : #endif
89 :
90 242490 : return ctx;
91 : }
92 :
93 : /*******************************************************************************
94 : * \brief Private routine for handing newly arrived packs to the backend.
95 : * \author Ole Schuett
96 : ******************************************************************************/
97 0 : static bool backend_upload_packs(const dbm_pack_t *pack_a,
98 : const dbm_pack_t *pack_b,
99 : backend_context_t *ctx) {
100 : #if defined(__OFFLOAD) && !defined(__NO_OFFLOAD_DBM)
101 : return dbm_multiply_gpu_upload_packs(pack_a, pack_b, &ctx->gpu);
102 : #else
103 0 : (void)pack_a; // mark as used
104 0 : (void)pack_b;
105 0 : (void)ctx;
106 0 : return false;
107 : #endif
108 : }
109 :
110 : /*******************************************************************************
111 : * \brief Private routine for sending a batch to the multiplication backend.
112 : * \author Ole Schuett
113 : ******************************************************************************/
114 263974 : static void backend_process_batch(const int ntasks,
115 : const dbm_task_t batch[ntasks],
116 : const double alpha, const dbm_pack_t *pack_a,
117 : const dbm_pack_t *pack_b, const int kshard,
118 : dbm_shard_t *shard_c, const bool finish,
119 : const bool force_cpu,
120 : backend_context_t *ctx) {
121 263974 : if (NULL != ctx) {
122 : #if defined(__OFFLOAD) && !defined(__NO_OFFLOAD_DBM)
123 : if (!force_cpu) {
124 : dbm_multiply_gpu_process_batch(ntasks, batch, alpha, shard_c, kshard,
125 : finish, &ctx->gpu);
126 : } else
127 : #endif
128 : {
129 263974 : (void)kshard;
130 263974 : (void)finish;
131 263974 : (void)force_cpu;
132 263974 : dbm_multiply_cpu_process_batch(ntasks, batch, alpha, pack_a, pack_b,
133 : shard_c, ctx->cpu_options);
134 : }
135 : } else { // Validate against host (aka CPU).
136 0 : dbm_multiply_cpu_process_batch(ntasks, batch, alpha, pack_a, pack_b,
137 : shard_c, DBM_MULTIPLY_BLAS_LIBRARY);
138 : }
139 263974 : }
140 :
141 : /*******************************************************************************
142 : * \brief Private routine for shutting down the multiplication backend.
143 : * \author Ole Schuett
144 : ******************************************************************************/
145 242490 : static void backend_stop(backend_context_t *ctx) {
146 : #if defined(__OFFLOAD) && !defined(__NO_OFFLOAD_DBM)
147 : dbm_multiply_gpu_stop(&ctx->gpu);
148 : #endif
149 242490 : free(ctx);
150 242490 : }
151 :
152 : /*******************************************************************************
153 : * \brief Private routine for multiplying two packs (C += alpha * A * B).
154 : *
155 : * Blocks in each pack are grouped by shard (free_index % nshards) and sorted
156 : * by sum_index within each group. The algorithm:
157 : * 1. Builds shard-boundary lookup tables for A (rows) and B (cols).
158 : * 2. For each (shard_row, shard_col) pair, determines the contiguous A and B
159 : * block ranges belonging to that shard.
160 : * 3. Performs a merge-join over sum_index: advances A and B cursors in
161 : *lockstep, caching the B sub-range for each sum_index so that multiple A blocks
162 : *with the same sum_index reuse it without rescanning.
163 : * 4. Applies a norm-based filter (alpha^2 * norm_a * norm_b < eps) for early
164 : * rejection before looking up or allocating the C block.
165 : * 5. Accumulates matching pairs into a batched GEMM task list, flushing to the
166 : * backend (CPU or GPU) every DBM_MAX_BATCH_SIZE tasks.
167 : *
168 : * \author Ole Schuett and Hans Pabst
169 : ******************************************************************************/
170 263896 : static void multiply_packs(const bool transa, const bool transb,
171 : const double alpha, const dbm_pack_t *pack_a,
172 : const dbm_pack_t *pack_b,
173 : const dbm_matrix_t *matrix_a,
174 : const dbm_matrix_t *matrix_b, dbm_matrix_t *matrix_c,
175 : const float *rows_max_eps,
176 : const bool retain_sparsity, const bool force_cpu,
177 : int64_t *flop, backend_context_t *ctx) {
178 : // For validation, FLOPS do not count, and relying on ctx is not necessary.
179 263896 : backend_context_t *const context = (NULL != flop ? ctx : NULL);
180 263896 : const float alpha2 = (float)(alpha * alpha);
181 263896 : int64_t flop_sum = 0;
182 :
183 263896 : const int nshard_rows = matrix_c->dist->rows.nshards;
184 263896 : const int nshard_cols = matrix_c->dist->cols.nshards;
185 263896 : int *shard_row_start = calloc(nshard_rows, sizeof(int));
186 263896 : int *shard_col_start = calloc(nshard_cols, sizeof(int));
187 263896 : assert(NULL != shard_row_start && NULL != shard_col_start);
188 :
189 263896 : const int *sum_index_sizes_a =
190 : (transa) ? matrix_a->row_sizes : matrix_a->col_sizes;
191 263896 : const int *sum_index_sizes_b =
192 : (transb) ? matrix_b->col_sizes : matrix_b->row_sizes;
193 263896 : const int *free_index_sizes_a =
194 : (transa) ? matrix_a->col_sizes : matrix_a->row_sizes;
195 263896 : const int *free_index_sizes_b =
196 : (transb) ? matrix_b->row_sizes : matrix_b->col_sizes;
197 :
198 263896 : #pragma omp parallel reduction(+ : flop_sum)
199 : {
200 : // Thread-private array covering given work in piece-wise fashion.
201 : dbm_task_t *batch =
202 : offload_mempool_host_malloc(sizeof(dbm_task_t) * DBM_MAX_BATCH_SIZE);
203 :
204 : // Blocks are ordered first by shard. Creating lookup tables of boundaries.
205 : #pragma omp for nowait
206 : for (int iblock = 1; iblock < pack_a->nblocks; iblock++) {
207 : const int shard_row = pack_a->blocks[iblock].free_index % nshard_rows;
208 : const int prev_shard_row =
209 : pack_a->blocks[iblock - 1].free_index % nshard_rows;
210 : if (prev_shard_row != shard_row) {
211 : shard_row_start[shard_row] = iblock;
212 : }
213 : }
214 : #pragma omp for
215 : for (int jblock = 1; jblock < pack_b->nblocks; jblock++) {
216 : const int shard_col = pack_b->blocks[jblock].free_index % nshard_cols;
217 : const int prev_shard_col =
218 : pack_b->blocks[jblock - 1].free_index % nshard_cols;
219 : if (prev_shard_col != shard_col) {
220 : shard_col_start[shard_col] = jblock;
221 : }
222 : }
223 :
224 : #pragma omp for collapse(2) DBM_OMP_SCHEDULE
225 : for (int shard_row = 0; shard_row < nshard_rows; shard_row++) {
226 : for (int shard_col = 0; shard_col < nshard_cols; shard_col++) {
227 : const int ishard = shard_row * nshard_cols + shard_col;
228 : dbm_shard_t *const shard_c = &matrix_c->shards[ishard];
229 : int ntasks = 0;
230 :
231 : // Determine contiguous block ranges for this shard in A and B.
232 : // Use a merge-join to find pairs of blocks with matching sum indices.
233 : // This utilizes that blocks within a shard are ordered by sum_index.
234 : const int iblock_start = shard_row_start[shard_row];
235 : int iblock_end = pack_a->nblocks;
236 : for (int t = iblock_start; t < pack_a->nblocks; ++t) {
237 : if (pack_a->blocks[t].free_index % nshard_rows != shard_row) {
238 : iblock_end = t;
239 : break;
240 : }
241 : }
242 : const int jblock_start = shard_col_start[shard_col];
243 : int jblock_end = pack_b->nblocks;
244 : for (int t = jblock_start; t < pack_b->nblocks; ++t) {
245 : if (pack_b->blocks[t].free_index % nshard_cols != shard_col) {
246 : jblock_end = t;
247 : break;
248 : }
249 : }
250 : if (iblock_start >= iblock_end || jblock_start >= jblock_end) {
251 : backend_process_batch(ntasks, batch, alpha, pack_a, pack_b, ishard,
252 : shard_c, true, force_cpu, context);
253 : continue;
254 : }
255 :
256 : // Merge over sum_index (both ranges sorted by sum_index).
257 : // Cache the B sub-range for each sum_index so that multiple A blocks
258 : // sharing the same sum_index reuse it without re-scanning B.
259 : int i = iblock_start, j = jblock_start, last_sum_index = -1;
260 : int b_range_start = -1, b_range_end = -1;
261 :
262 : while (i < iblock_end) {
263 : const dbm_pack_block_t *blk_a = &pack_a->blocks[i];
264 : const int sum_a = blk_a->sum_index;
265 :
266 : // Advance j until sum_b >= sum_a.
267 : while (j < jblock_end && pack_b->blocks[j].sum_index < sum_a) {
268 : ++j;
269 : }
270 : if (j >= jblock_end) {
271 : break; // No more matches possible.
272 : }
273 :
274 : const int sum_b = pack_b->blocks[j].sum_index;
275 : if (sum_b > sum_a) {
276 : ++i;
277 : continue; // Need next A block with higher sum_index.
278 : }
279 :
280 : // sum_a == sum_b: establish (or reuse) B range with this sum_index.
281 : if (sum_a != last_sum_index) {
282 : b_range_start = j;
283 : int t = j + 1;
284 : while (t < jblock_end && pack_b->blocks[t].sum_index == sum_a) {
285 : ++t;
286 : }
287 : b_range_end = t;
288 : last_sum_index = sum_a;
289 : }
290 :
291 : // Iterate over B blocks in current sum_index range.
292 : for (int jb = b_range_start; jb < b_range_end; ++jb) {
293 : const dbm_pack_block_t *const blk_b = &pack_b->blocks[jb];
294 :
295 : // Norm filter first (early reject).
296 : const float result_norm = alpha2 * blk_a->norm * blk_b->norm;
297 : if (result_norm < rows_max_eps[blk_a->free_index]) {
298 : continue;
299 : }
300 :
301 : // Check block sizes.
302 : const int m = free_index_sizes_a[blk_a->free_index];
303 : const int n = free_index_sizes_b[blk_b->free_index];
304 : const int k = sum_index_sizes_a[sum_a];
305 : assert(m == matrix_c->row_sizes[blk_a->free_index]);
306 : assert(n == matrix_c->col_sizes[blk_b->free_index]);
307 : assert(k == sum_index_sizes_b[blk_b->sum_index]);
308 :
309 : if (m == 0 || n == 0 || k == 0) {
310 : continue;
311 : }
312 :
313 : // Get C block.
314 : const int row = blk_a->free_index, col = blk_b->free_index;
315 : dbm_block_t *blk_c = dbm_shard_lookup(shard_c, row, col);
316 : if (blk_c == NULL) {
317 : if (retain_sparsity) {
318 : continue;
319 : }
320 : assert(dbm_get_shard_index(matrix_c, row, col) == ishard);
321 : assert(dbm_get_stored_coordinates(matrix_c, row, col) ==
322 : matrix_c->dist->my_rank);
323 : blk_c = dbm_shard_promise_new_block(shard_c, row, col, m * n);
324 : }
325 :
326 : // Count flops.
327 : const int64_t task_flops = 2LL * m * n * k;
328 : flop_sum += task_flops;
329 : dbm_library_counter_increment(m, n, k);
330 :
331 : // Add block multiplication to batch.
332 : dbm_task_t *const tptr = &batch[ntasks];
333 : tptr->offset_a = blk_a->offset;
334 : tptr->offset_b = blk_b->offset;
335 : tptr->offset_c = blk_c->offset;
336 : tptr->m = m;
337 : tptr->n = n;
338 : tptr->k = k;
339 : ++ntasks;
340 :
341 : if (ntasks == DBM_MAX_BATCH_SIZE) {
342 : backend_process_batch(ntasks, batch, alpha, pack_a, pack_b,
343 : ishard, shard_c, false, force_cpu, context);
344 : ntasks = 0;
345 : }
346 : }
347 :
348 : // Advance i; if next A block has same sum_index, B range is reused.
349 : ++i;
350 : }
351 : backend_process_batch(ntasks, batch, alpha, pack_a, pack_b, ishard,
352 : shard_c, true, force_cpu, context);
353 : }
354 : }
355 :
356 : offload_mempool_host_free(batch);
357 : }
358 :
359 263896 : free(shard_row_start);
360 263896 : free(shard_col_start);
361 :
362 263896 : if (NULL != flop) {
363 263896 : *flop += flop_sum;
364 : }
365 263896 : }
366 :
367 : /*******************************************************************************
368 : * \brief Performs a multiplication of two dbm_matrix_t matrices.
369 : * See dbm_matrix.h for details.
370 : * \author Ole Schuett
371 : ******************************************************************************/
372 242490 : void dbm_multiply(const bool transa, const bool transb, const double alpha,
373 : const dbm_matrix_t *matrix_a, const dbm_matrix_t *matrix_b,
374 : const double beta, dbm_matrix_t *matrix_c,
375 : const bool retain_sparsity, const double filter_eps,
376 : int64_t *flop) {
377 242490 : assert(omp_get_num_threads() == 1);
378 242490 : assert(matrix_a != NULL && matrix_b != NULL && matrix_c != NULL);
379 :
380 : // Throughout the matrix multiplication code the "sum_index" and "free_index"
381 : // denote the summation (aka dummy) and free index from the Einstein notation.
382 242490 : const int num_sum_index_a = (transa) ? matrix_a->nrows : matrix_a->ncols;
383 242490 : const int num_sum_index_b = (transb) ? matrix_b->ncols : matrix_b->nrows;
384 242490 : const int num_free_index_a = (transa) ? matrix_a->ncols : matrix_a->nrows;
385 242490 : const int num_free_index_b = (transb) ? matrix_b->nrows : matrix_b->ncols;
386 :
387 : // Sanity check matrix dimensions.
388 242490 : assert(num_sum_index_a == num_sum_index_b);
389 242490 : assert(num_free_index_a == matrix_c->nrows);
390 242490 : assert(num_free_index_b == matrix_c->ncols);
391 :
392 : // Prepare matrix_c (host).
393 242490 : dbm_scale(matrix_c, beta);
394 :
395 : // Determine if validation shall be performed.
396 242490 : const char *const maxeps_env = getenv("DBM_MULTIPLY_MAXEPS");
397 242490 : const char *const verify_env = getenv("DBM_MULTIPLY_VERIFY");
398 242490 : const double maxeps = (NULL == maxeps_env ? 1E-1 : fabs(atof(maxeps_env)));
399 484980 : const int verify =
400 242490 : (NULL == verify_env ? (NULL == maxeps_env ? 0 : 1) : atoi(verify_env));
401 242490 : dbm_matrix_t *matrix_d = NULL;
402 242490 : if (0 != verify) {
403 0 : dbm_distribution_t *const dist_shared = matrix_c->dist;
404 0 : dbm_create(&matrix_d, dist_shared, matrix_c->name, matrix_c->nrows,
405 0 : matrix_c->ncols, matrix_c->row_sizes, matrix_c->col_sizes);
406 0 : dbm_copy(matrix_d, matrix_c);
407 : }
408 :
409 : // Compute filter thresholds for each row.
410 242490 : float *rows_max_eps = compute_rows_max_eps(transa, matrix_a, filter_eps);
411 :
412 : // Start uploading matrix_c to the GPU.
413 242490 : backend_context_t *ctx = backend_start(matrix_c);
414 :
415 : // Redistribute matrix_a and matrix_b across MPI ranks.
416 242490 : dbm_comm_iterator_t *iter =
417 242490 : dbm_comm_iterator_start(transa, transb, matrix_a, matrix_b, matrix_c);
418 :
419 : // Count flops if requested.
420 242490 : if (NULL != flop) {
421 242490 : *flop = 0;
422 : }
423 :
424 : // Main loop.
425 : dbm_pack_t *pack_a, *pack_b;
426 506386 : while (dbm_comm_iterator_next(iter, &pack_a, &pack_b)) {
427 263896 : const bool uploaded = backend_upload_packs(pack_a, pack_b, ctx);
428 263896 : (void)uploaded; // mark used
429 263896 : multiply_packs(transa, transb, alpha, pack_a, pack_b, matrix_a, matrix_b,
430 : matrix_c, rows_max_eps, retain_sparsity, false /*!uploaded*/,
431 : flop, ctx);
432 : }
433 :
434 : // Wait for all other MPI ranks to complete, then release ressources.
435 242490 : dbm_comm_iterator_stop(iter);
436 242490 : backend_stop(ctx);
437 :
438 242490 : if (NULL != matrix_d) {
439 0 : ctx = backend_start(matrix_d);
440 0 : iter =
441 0 : dbm_comm_iterator_start(transa, transb, matrix_a, matrix_b, matrix_d);
442 0 : while (dbm_comm_iterator_next(iter, &pack_a, &pack_b)) {
443 0 : multiply_packs(transa, transb, alpha, pack_a, pack_b, matrix_a, matrix_b,
444 : matrix_d, rows_max_eps, retain_sparsity, true, NULL, ctx);
445 : }
446 0 : dbm_comm_iterator_stop(iter);
447 0 : backend_stop(ctx);
448 0 : const double epsilon = dbm_maxeps(matrix_d, matrix_c);
449 0 : if (maxeps < epsilon) {
450 0 : if (1 == verify) {
451 0 : fprintf(stderr, "WARN ACC/LIBDBM: diff=%g\n", epsilon);
452 : } else {
453 0 : fprintf(stderr, "ERROR ACC/LIBDBM: diff=%g\n", epsilon);
454 0 : exit(EXIT_FAILURE);
455 : }
456 : }
457 0 : dbm_release(matrix_d);
458 : }
459 :
460 : // Release filter thresholds.
461 242490 : free(rows_max_eps);
462 :
463 : // Final filter pass.
464 242490 : dbm_filter(matrix_c, filter_eps);
465 242490 : }
466 :
467 : // EOF
|