LCOV - code coverage report
Current view: top level - src/dbm - dbm_multiply.c (source / functions) Coverage Total Hit
Test: CP2K Regtests (git:c24029e) Lines: 76.0 % 100 76
Test Date: 2026-07-04 06:36:57 Functions: 85.7 % 7 6

            Line data    Source code
       1              : /*----------------------------------------------------------------------------*/
       2              : /*  CP2K: A general program to perform molecular dynamics simulations         */
       3              : /*  Copyright 2000-2026 CP2K developers group <https://cp2k.org>              */
       4              : /*                                                                            */
       5              : /*  SPDX-License-Identifier: BSD-3-Clause                                     */
       6              : /*----------------------------------------------------------------------------*/
       7              : #include "dbm_multiply.h"
       8              : #include "../offload/offload_mempool.h"
       9              : #include "../offload/offload_runtime.h"
      10              : #include "dbm_hyperparams.h"
      11              : #include "dbm_internal.h"
      12              : #include "dbm_library.h"
      13              : #include "dbm_multiply_comm.h"
      14              : #include "dbm_multiply_cpu.h"
      15              : #include "dbm_multiply_gpu.h"
      16              : 
      17              : #include <assert.h>
      18              : #include <limits.h>
      19              : #include <math.h>
      20              : #include <omp.h>
      21              : #include <stdio.h>
      22              : #include <stdlib.h>
      23              : #include <string.h>
      24              : 
      25              : /*******************************************************************************
      26              :  * \brief Private routine for computing the max filter threshold for each row.
                      :  *
                      :  * Counts the blocks of every row over all local shards, passes the counts
                      :  * through cp_mpi_sum_int on matrix->dist->comm, and finally stores
                      :  * (filter_eps / imax(1, count))^2 for every row.
                      :  *
                      :  * \param trans       If true, the block columns take the role of the rows.
                      :  * \param matrix      Matrix whose blocks are counted.
                      :  * \param filter_eps  Filter threshold; it is narrowed to float before use.
                      :  * \return Array with one float per row allocated via malloc; the caller has
                      :  *         to free() it.
      27              :  * \author Ole Schuett
      28              :  ******************************************************************************/
      29       242490 : static float *compute_rows_max_eps(const bool trans, const dbm_matrix_t *matrix,
      30              :                                    const double filter_eps) {
      31       242490 :   const int nrows = (trans) ? matrix->ncols : matrix->nrows;
      32       242490 :   int *nblocks_per_row = calloc(nrows, sizeof(int)); // Zero-initialized counters.
      33       242490 :   float *row_max_eps = malloc(nrows * sizeof(float));
      34       242490 :   assert((nblocks_per_row != NULL && row_max_eps != NULL) || nrows == 0); // NULL is tolerated only for nrows == 0.
      35              : 
      36       242490 : #pragma omp parallel
      37              :   {
      38              : #pragma omp for
      39              :     for (int ishard = 0; ishard < dbm_get_num_shards(matrix); ishard++) {
      40              :       dbm_shard_t *shard = &matrix->shards[ishard];
      41              :       for (int iblock = 0; iblock < shard->nblocks; iblock++) {
      42              :         const dbm_block_t *blk = &shard->blocks[iblock];
      43              :         const int row = (trans) ? blk->col : blk->row;
      44              : #pragma omp atomic
      45              :         ++nblocks_per_row[row]; // Threads working on different shards share this array.
      46              :       }
      47              :     }
      48              : #pragma omp master
      49              :     cp_mpi_sum_int(nblocks_per_row, nrows, matrix->dist->comm); // Executed by the master thread only.
      50              : #pragma omp barrier
      51              : #pragma omp for
      52              :     for (int i = 0; i < nrows; i++) {
      53              :       const float f =
      54              :           ((float)filter_eps) / ((float)imax(1, nblocks_per_row[i])); // imax(1, ...) avoids dividing by zero.
      55              :       row_max_eps[i] = f * f; // The squared value is stored.
      56              :     }
      57              :   } // end of omp parallel region
      58              : 
      59       242490 :   free(nblocks_per_row);
      60       242490 :   return row_max_eps; // Ownership of row_max_eps transfers to caller.
      61              : }
      62              : 
      63              : /*******************************************************************************
      64              :  * \brief Private struct for storing the context of the multiplication backend.
                      :  *
                      :  * The gpu member only exists when __OFFLOAD is defined and __NO_OFFLOAD_DBM
                      :  * is not defined; cpu_options is always present.
      65              :  * \author Ole Schuett
      66              :  ******************************************************************************/
      67              : typedef struct {
      68              : #if defined(__OFFLOAD) && !defined(__NO_OFFLOAD_DBM)
      69              :   dbm_multiply_gpu_context_t gpu; // State handed to the dbm_multiply_gpu_* calls.
      70              : #endif
      71              :   int cpu_options; // Binary or'ed dbm_multiply_cpu_options (enum).
      72              : } backend_context_t;
      73              : 
      74              : /*******************************************************************************
      75              :  * \brief Private routine for initializing the multiplication backend.
                      :  *
                      :  * Allocates a zeroed context, sets cpu_options to DBM_MULTIPLY_TASK_REORDER
                      :  * and, in offload builds, calls dbm_multiply_gpu_start with the shards of
                      :  * matrix_c.
                      :  *
                      :  * \param matrix_c  Result matrix; only used in offload builds.
                      :  * \return Heap-allocated context that backend_stop() frees.
                      :  *
                      :  * NOTE(review): the calloc result is dereferenced without a NULL check.
      76              :  * \author Ole Schuett
      77              :  ******************************************************************************/
      78       242490 : static backend_context_t *backend_start(const dbm_matrix_t *matrix_c) {
      79       242490 :   backend_context_t *const ctx = calloc(1, sizeof(backend_context_t));
      80              :   // BLAS and LIBXS benefit in general from DBM_MULTIPLY_TASK_REORDER.
      81       242490 :   ctx->cpu_options = DBM_MULTIPLY_TASK_REORDER;
      82              : 
      83              : #if defined(__OFFLOAD) && !defined(__NO_OFFLOAD_DBM)
      84              :   dbm_multiply_gpu_start(DBM_MAX_BATCH_SIZE, dbm_get_num_shards(matrix_c),
      85              :                          matrix_c->shards, &ctx->gpu);
      86              : #else
      87       242490 :   (void)matrix_c; // mark as used
      88              : #endif
      89              : 
      90       242490 :   return ctx;
      91              : }
      92              : 
      93              : /*******************************************************************************
      94              :  * \brief Private routine for handing newly arrived packs to the backend.
                      :  *
                      :  * In offload builds the packs are forwarded to
                      :  * dbm_multiply_gpu_upload_packs and its result is returned. In all other
                      :  * builds the arguments are ignored and false is returned.
                      :  *
                      :  * \param pack_a  Pack holding blocks of matrix A.
                      :  * \param pack_b  Pack holding blocks of matrix B.
                      :  * \param ctx     Backend context obtained from backend_start().
                      :  * \return Result of the GPU upload call, or false without offload support.
      95              :  * \author Ole Schuett
      96              :  ******************************************************************************/
      97            0 : static bool backend_upload_packs(const dbm_pack_t *pack_a,
      98              :                                  const dbm_pack_t *pack_b,
      99              :                                  backend_context_t *ctx) {
     100              : #if defined(__OFFLOAD) && !defined(__NO_OFFLOAD_DBM)
     101              :   return dbm_multiply_gpu_upload_packs(pack_a, pack_b, &ctx->gpu);
     102              : #else
     103            0 :   (void)pack_a; // mark as used
     104            0 :   (void)pack_b;
     105            0 :   (void)ctx;
     106            0 :   return false; // Nothing is uploaded in this build configuration.
     107              : #endif
     108              : }
     109              : 
     110              : /*******************************************************************************
     111              :  * \brief Private routine for sending a batch to the multiplication backend.
                      :  *
                      :  * Dispatch rules:
                      :  *  - ctx == NULL: the batch goes to dbm_multiply_cpu_process_batch with
                      :  *    the option DBM_MULTIPLY_BLAS_LIBRARY (used for validation).
                      :  *  - ctx != NULL, offload build, and force_cpu is false: the batch goes to
                      :  *    dbm_multiply_gpu_process_batch.
                      :  *  - otherwise: the batch goes to dbm_multiply_cpu_process_batch with
                      :  *    ctx->cpu_options.
                      :  *
                      :  * \param ntasks     Number of valid entries in batch.
                      :  * \param batch      Tasks to be processed.
                      :  * \param alpha      Scaling factor forwarded to the backend.
                      :  * \param pack_a     Pack of A blocks; only passed to the CPU routine.
                      :  * \param pack_b     Pack of B blocks; only passed to the CPU routine.
                      :  * \param kshard     Index of the C shard; only passed to the GPU routine.
                      :  * \param shard_c    Shard of matrix C receiving the results.
                      :  * \param finish     Only passed to the GPU routine.
                      :  * \param force_cpu  If true, the GPU routine is bypassed in offload builds.
                      :  * \param ctx        Backend context, or NULL for the validation path.
     112              :  * \author Ole Schuett
     113              :  ******************************************************************************/
     114       263974 : static void backend_process_batch(const int ntasks,
     115              :                                   const dbm_task_t batch[ntasks],
     116              :                                   const double alpha, const dbm_pack_t *pack_a,
     117              :                                   const dbm_pack_t *pack_b, const int kshard,
     118              :                                   dbm_shard_t *shard_c, const bool finish,
     119              :                                   const bool force_cpu,
     120              :                                   backend_context_t *ctx) {
     121       263974 :   if (NULL != ctx) {
     122              : #if defined(__OFFLOAD) && !defined(__NO_OFFLOAD_DBM)
     123              :     if (!force_cpu) {
     124              :       dbm_multiply_gpu_process_batch(ntasks, batch, alpha, shard_c, kshard,
     125              :                                      finish, &ctx->gpu);
     126              :     } else
     127              : #endif
     128              :     {
     129       263974 :       (void)kshard; // Unused on the CPU path; the casts silence warnings.
     130       263974 :       (void)finish;
     131       263974 :       (void)force_cpu;
     132       263974 :       dbm_multiply_cpu_process_batch(ntasks, batch, alpha, pack_a, pack_b,
     133              :                                      shard_c, ctx->cpu_options);
     134              :     }
     135              :   } else { // Validate against host (aka CPU).
     136            0 :     dbm_multiply_cpu_process_batch(ntasks, batch, alpha, pack_a, pack_b,
     137              :                                    shard_c, DBM_MULTIPLY_BLAS_LIBRARY);
     138              :   }
     139       263974 : }
     140              : 
     141              : /*******************************************************************************
     142              :  * \brief Private routine for shutting down the multiplication backend.
                      :  *
                      :  * In offload builds dbm_multiply_gpu_stop is called first; afterwards the
                      :  * context itself is freed and must not be used anymore.
                      :  *
                      :  * \param ctx  Context obtained from backend_start().
     143              :  * \author Ole Schuett
     144              :  ******************************************************************************/
     145       242490 : static void backend_stop(backend_context_t *ctx) {
     146              : #if defined(__OFFLOAD) && !defined(__NO_OFFLOAD_DBM)
     147              :   dbm_multiply_gpu_stop(&ctx->gpu);
     148              : #endif
     149       242490 :   free(ctx);
     150       242490 : }
     151              : 
     152              : /*******************************************************************************
     153              :  * \brief Private routine for multiplying two packs (C += alpha * A * B).
     154              :  *
     155              :  * Blocks in each pack are grouped by shard (free_index % nshards) and sorted
     156              :  * by sum_index within each group. The algorithm:
     157              :  *  1. Builds shard-boundary lookup tables for A (rows) and B (cols).
     158              :  *  2. For each (shard_row, shard_col) pair, determines the contiguous A and B
     159              :  *     block ranges belonging to that shard.
     160              :  *  3. Performs a merge-join over sum_index: advances A and B cursors in
     161              :  *     lockstep, caching the B sub-range for each sum_index so that multiple
     162              :  *     A blocks with the same sum_index reuse it without rescanning.
     163              :  *  4. Applies a norm-based filter (alpha^2 * norm_a * norm_b < eps) for early
     164              :  *     rejection before looking up or allocating the C block.
     165              :  *  5. Accumulates matching pairs into a batched GEMM task list, flushing to the
     166              :  *     backend (CPU or GPU) every DBM_MAX_BATCH_SIZE tasks.
     167              :  *
                      :  * \param transa           Selects which of matrix_a's row_sizes/col_sizes
                      :  *                         act as sum and free index sizes.
                      :  * \param transb           Same as transa, but for matrix_b.
                      :  * \param alpha            Scaling factor; alpha^2 enters the norm filter.
                      :  * \param pack_a           Pack with the blocks of A.
                      :  * \param pack_b           Pack with the blocks of B.
                      :  * \param matrix_a         Only its row_sizes/col_sizes are read here.
                      :  * \param matrix_b         Only its row_sizes/col_sizes are read here.
                      :  * \param matrix_c         Result matrix whose shards receive the products.
                      :  * \param rows_max_eps     Filter threshold indexed by the free index of A.
                      :  * \param retain_sparsity  If true, products without an existing C block are
                      :  *                         skipped instead of creating a new block.
                      :  * \param force_cpu        Forwarded to backend_process_batch().
                      :  * \param flop             If not NULL, incremented by the counted flops. If
                      :  *                         NULL, ctx is ignored and a NULL context is passed
                      :  *                         to backend_process_batch().
                      :  * \param ctx              Backend context; only used if flop is not NULL.
     168              :  * \author Ole Schuett and Hans Pabst
     169              :  ******************************************************************************/
     170       263896 : static void multiply_packs(const bool transa, const bool transb,
     171              :                            const double alpha, const dbm_pack_t *pack_a,
     172              :                            const dbm_pack_t *pack_b,
     173              :                            const dbm_matrix_t *matrix_a,
     174              :                            const dbm_matrix_t *matrix_b, dbm_matrix_t *matrix_c,
     175              :                            const float *rows_max_eps,
     176              :                            const bool retain_sparsity, const bool force_cpu,
     177              :                            int64_t *flop, backend_context_t *ctx) {
     178              :   // For validation, FLOPS do not count, and relying on ctx is not necessary.
     179       263896 :   backend_context_t *const context = (NULL != flop ? ctx : NULL);
     180       263896 :   const float alpha2 = (float)(alpha * alpha);
     181       263896 :   int64_t flop_sum = 0;
     182              : 
     183       263896 :   const int nshard_rows = matrix_c->dist->rows.nshards;
     184       263896 :   const int nshard_cols = matrix_c->dist->cols.nshards;
     185       263896 :   int *shard_row_start = calloc(nshard_rows, sizeof(int)); // Entries default to 0.
     186       263896 :   int *shard_col_start = calloc(nshard_cols, sizeof(int)); // Entries default to 0.
     187       263896 :   assert(NULL != shard_row_start && NULL != shard_col_start);
     188              : 
     189       263896 :   const int *sum_index_sizes_a =
     190              :       (transa) ? matrix_a->row_sizes : matrix_a->col_sizes;
     191       263896 :   const int *sum_index_sizes_b =
     192              :       (transb) ? matrix_b->col_sizes : matrix_b->row_sizes;
     193       263896 :   const int *free_index_sizes_a =
     194              :       (transa) ? matrix_a->col_sizes : matrix_a->row_sizes;
     195       263896 :   const int *free_index_sizes_b =
     196              :       (transb) ? matrix_b->row_sizes : matrix_b->col_sizes;
     197              : 
     198       263896 : #pragma omp parallel reduction(+ : flop_sum)
     199              :   {
     200              :     // Thread-private array covering given work in piece-wise fashion.
     201              :     dbm_task_t *batch =
     202              :         offload_mempool_host_malloc(sizeof(dbm_task_t) * DBM_MAX_BATCH_SIZE);
     203              : 
     204              :     // Blocks are ordered first by shard. Creating lookup tables of boundaries.
                      :     // Both loops start at index 1, so the shard of block 0 and any shard
                      :     // without blocks keep the start index 0 from calloc.
     205              : #pragma omp for nowait
     206              :     for (int iblock = 1; iblock < pack_a->nblocks; iblock++) {
     207              :       const int shard_row = pack_a->blocks[iblock].free_index % nshard_rows;
     208              :       const int prev_shard_row =
     209              :           pack_a->blocks[iblock - 1].free_index % nshard_rows;
     210              :       if (prev_shard_row != shard_row) {
     211              :         shard_row_start[shard_row] = iblock;
     212              :       }
     213              :     }
     214              : #pragma omp for
     215              :     for (int jblock = 1; jblock < pack_b->nblocks; jblock++) {
     216              :       const int shard_col = pack_b->blocks[jblock].free_index % nshard_cols;
     217              :       const int prev_shard_col =
     218              :           pack_b->blocks[jblock - 1].free_index % nshard_cols;
     219              :       if (prev_shard_col != shard_col) {
     220              :         shard_col_start[shard_col] = jblock;
     221              :       }
     222              :     }
     223              : 
     224              : #pragma omp for collapse(2) DBM_OMP_SCHEDULE
     225              :     for (int shard_row = 0; shard_row < nshard_rows; shard_row++) {
     226              :       for (int shard_col = 0; shard_col < nshard_cols; shard_col++) {
     227              :         const int ishard = shard_row * nshard_cols + shard_col;
     228              :         dbm_shard_t *const shard_c = &matrix_c->shards[ishard];
     229              :         int ntasks = 0;
     230              : 
     231              :         // Determine contiguous block ranges for this shard in A and B.
     232              :         // Use a merge-join to find pairs of blocks with matching sum indices.
     233              :         // This utilizes that blocks within a shard are ordered by sum_index.
     234              :         const int iblock_start = shard_row_start[shard_row];
     235              :         int iblock_end = pack_a->nblocks;
     236              :         for (int t = iblock_start; t < pack_a->nblocks; ++t) {
     237              :           if (pack_a->blocks[t].free_index % nshard_rows != shard_row) {
     238              :             iblock_end = t; // First block of another shard ends the range.
     239              :             break;
     240              :           }
     241              :         }
     242              :         const int jblock_start = shard_col_start[shard_col];
     243              :         int jblock_end = pack_b->nblocks;
     244              :         for (int t = jblock_start; t < pack_b->nblocks; ++t) {
     245              :           if (pack_b->blocks[t].free_index % nshard_cols != shard_col) {
     246              :             jblock_end = t; // First block of another shard ends the range.
     247              :             break;
     248              :           }
     249              :         }
     250              :         if (iblock_start >= iblock_end || jblock_start >= jblock_end) {
     251              :           backend_process_batch(ntasks, batch, alpha, pack_a, pack_b, ishard,
     252              :                                 shard_c, true, force_cpu, context); // Empty batch (ntasks == 0) with finish == true.
     253              :           continue;
     254              :         }
     255              : 
     256              :         // Merge over sum_index (both ranges sorted by sum_index).
     257              :         // Cache the B sub-range for each sum_index so that multiple A blocks
     258              :         // sharing the same sum_index reuse it without re-scanning B.
     259              :         int i = iblock_start, j = jblock_start, last_sum_index = -1;
     260              :         int b_range_start = -1, b_range_end = -1;
     261              : 
     262              :         while (i < iblock_end) {
     263              :           const dbm_pack_block_t *blk_a = &pack_a->blocks[i];
     264              :           const int sum_a = blk_a->sum_index;
     265              : 
     266              :           // Advance j until sum_b >= sum_a.
     267              :           while (j < jblock_end && pack_b->blocks[j].sum_index < sum_a) {
     268              :             ++j;
     269              :           }
     270              :           if (j >= jblock_end) {
     271              :             break; // No more matches possible.
     272              :           }
     273              : 
     274              :           const int sum_b = pack_b->blocks[j].sum_index;
     275              :           if (sum_b > sum_a) {
     276              :             ++i;
     277              :             continue; // Need next A block with higher sum_index.
     278              :           }
     279              : 
     280              :           // sum_a == sum_b: establish (or reuse) B range with this sum_index.
     281              :           if (sum_a != last_sum_index) {
     282              :             b_range_start = j;
     283              :             int t = j + 1;
     284              :             while (t < jblock_end && pack_b->blocks[t].sum_index == sum_a) {
     285              :               ++t;
     286              :             }
     287              :             b_range_end = t; // Exclusive end of the cached B range.
     288              :             last_sum_index = sum_a;
     289              :           }
     290              : 
     291              :           // Iterate over B blocks in current sum_index range.
     292              :           for (int jb = b_range_start; jb < b_range_end; ++jb) {
     293              :             const dbm_pack_block_t *const blk_b = &pack_b->blocks[jb];
     294              : 
     295              :             // Norm filter first (early reject).
     296              :             const float result_norm = alpha2 * blk_a->norm * blk_b->norm;
     297              :             if (result_norm < rows_max_eps[blk_a->free_index]) {
     298              :               continue;
     299              :             }
     300              : 
     301              :             // Check block sizes.
     302              :             const int m = free_index_sizes_a[blk_a->free_index];
     303              :             const int n = free_index_sizes_b[blk_b->free_index];
     304              :             const int k = sum_index_sizes_a[sum_a];
     305              :             assert(m == matrix_c->row_sizes[blk_a->free_index]);
     306              :             assert(n == matrix_c->col_sizes[blk_b->free_index]);
     307              :             assert(k == sum_index_sizes_b[blk_b->sum_index]);
     308              : 
     309              :             if (m == 0 || n == 0 || k == 0) {
     310              :               continue; // Empty blocks yield no task.
     311              :             }
     312              : 
     313              :             // Get C block.
     314              :             const int row = blk_a->free_index, col = blk_b->free_index;
     315              :             dbm_block_t *blk_c = dbm_shard_lookup(shard_c, row, col);
     316              :             if (blk_c == NULL) {
     317              :               if (retain_sparsity) {
     318              :                 continue; // No new C blocks are created in this mode.
     319              :               }
     320              :               assert(dbm_get_shard_index(matrix_c, row, col) == ishard);
     321              :               assert(dbm_get_stored_coordinates(matrix_c, row, col) ==
     322              :                      matrix_c->dist->my_rank);
     323              :               blk_c = dbm_shard_promise_new_block(shard_c, row, col, m * n);
     324              :             }
     325              : 
     326              :             // Count flops.
     327              :             const int64_t task_flops = 2LL * m * n * k;
     328              :             flop_sum += task_flops;
     329              :             dbm_library_counter_increment(m, n, k);
     330              : 
     331              :             // Add block multiplication to batch.
     332              :             dbm_task_t *const tptr = &batch[ntasks];
     333              :             tptr->offset_a = blk_a->offset;
     334              :             tptr->offset_b = blk_b->offset;
     335              :             tptr->offset_c = blk_c->offset;
     336              :             tptr->m = m;
     337              :             tptr->n = n;
     338              :             tptr->k = k;
     339              :             ++ntasks;
     340              : 
     341              :             if (ntasks == DBM_MAX_BATCH_SIZE) {
     342              :               backend_process_batch(ntasks, batch, alpha, pack_a, pack_b,
     343              :                                     ishard, shard_c, false, force_cpu, context);
     344              :               ntasks = 0; // The batch buffer is reused for the next tasks.
     345              :             }
     346              :           }
     347              : 
     348              :           // Advance i; if next A block has same sum_index, B range is reused.
     349              :           ++i;
     350              :         }
     351              :         backend_process_batch(ntasks, batch, alpha, pack_a, pack_b, ishard,
     352              :                               shard_c, true, force_cpu, context); // Remaining tasks with finish == true.
     353              :       }
     354              :     }
     355              : 
     356              :     offload_mempool_host_free(batch);
     357              :   }
     358              : 
     359       263896 :   free(shard_row_start);
     360       263896 :   free(shard_col_start);
     361              : 
     362       263896 :   if (NULL != flop) {
     363       263896 :     *flop += flop_sum; // Accumulates; the caller resets the counter.
     364              :   }
     365       263896 : }
     366              : 
     367              : /*******************************************************************************
     368              :  * \brief Performs a multiplication of two dbm_matrix_t matrices.
     369              :  *        See dbm_matrix.h for details.
                      :  *
                      :  * Must be called with omp_get_num_threads() == 1 (asserted).
                      :  *
                      :  * Validation mode is controlled by two environment variables:
                      :  *  - DBM_MULTIPLY_MAXEPS: tolerance, read as fabs(atof(...)); the default
                      :  *    is 1E-1 if the variable is unset.
                      :  *  - DBM_MULTIPLY_VERIFY: read with atoi. If unset, verification is on (1)
                      :  *    exactly when DBM_MULTIPLY_MAXEPS is set.
                      :  * With verification enabled the product is recomputed into a copy of the
                      :  * scaled matrix_c and compared via dbm_maxeps. A difference above the
                      :  * tolerance prints a warning if the verify value is 1, otherwise it prints
                      :  * an error and calls exit(EXIT_FAILURE).
                      :  *
                      :  * \param transa           If true, rows of matrix_a act as summation index.
                      :  * \param transb           If true, columns of matrix_b act as summation index.
                      :  * \param alpha            Scaling factor applied to the product.
                      :  * \param matrix_a         Left operand, must not be NULL.
                      :  * \param matrix_b         Right operand, must not be NULL.
                      :  * \param beta             matrix_c is scaled by beta before the product is
                      :  *                         added.
                      :  * \param matrix_c         Result matrix, must not be NULL.
                      :  * \param retain_sparsity  Forwarded to multiply_packs().
                      :  * \param filter_eps       Used for the per-row thresholds and for the final
                      :  *                         dbm_filter() call.
                      :  * \param flop             If not NULL, reset to 0 and then incremented by
                      :  *                         the counted flops.
                      :  *
                      :  * NOTE(review): with flop == NULL, multiply_packs() hands a NULL context to
                      :  * the backend also in the main loop - confirm that this is intended.
     370              :  * \author Ole Schuett
     371              :  ******************************************************************************/
     372       242490 : void dbm_multiply(const bool transa, const bool transb, const double alpha,
     373              :                   const dbm_matrix_t *matrix_a, const dbm_matrix_t *matrix_b,
     374              :                   const double beta, dbm_matrix_t *matrix_c,
     375              :                   const bool retain_sparsity, const double filter_eps,
     376              :                   int64_t *flop) {
     377       242490 :   assert(omp_get_num_threads() == 1);
     378       242490 :   assert(matrix_a != NULL && matrix_b != NULL && matrix_c != NULL);
     379              : 
     380              :   // Throughout the matrix multiplication code the "sum_index" and "free_index"
     381              :   // denote the summation (aka dummy) and free index from the Einstein notation.
     382       242490 :   const int num_sum_index_a = (transa) ? matrix_a->nrows : matrix_a->ncols;
     383       242490 :   const int num_sum_index_b = (transb) ? matrix_b->ncols : matrix_b->nrows;
     384       242490 :   const int num_free_index_a = (transa) ? matrix_a->ncols : matrix_a->nrows;
     385       242490 :   const int num_free_index_b = (transb) ? matrix_b->nrows : matrix_b->ncols;
     386              : 
     387              :   // Sanity check matrix dimensions.
     388       242490 :   assert(num_sum_index_a == num_sum_index_b);
     389       242490 :   assert(num_free_index_a == matrix_c->nrows);
     390       242490 :   assert(num_free_index_b == matrix_c->ncols);
     391              : 
     392              :   // Prepare matrix_c (host).
     393       242490 :   dbm_scale(matrix_c, beta);
     394              : 
     395              :   // Determine if validation shall be performed.
     396       242490 :   const char *const maxeps_env = getenv("DBM_MULTIPLY_MAXEPS");
     397       242490 :   const char *const verify_env = getenv("DBM_MULTIPLY_VERIFY");
     398       242490 :   const double maxeps = (NULL == maxeps_env ? 1E-1 : fabs(atof(maxeps_env)));
     399       484980 :   const int verify =
     400       242490 :       (NULL == verify_env ? (NULL == maxeps_env ? 0 : 1) : atoi(verify_env));
     401       242490 :   dbm_matrix_t *matrix_d = NULL; // Reference result; stays NULL without verification.
     402       242490 :   if (0 != verify) {
     403            0 :     dbm_distribution_t *const dist_shared = matrix_c->dist;
     404            0 :     dbm_create(&matrix_d, dist_shared, matrix_c->name, matrix_c->nrows,
     405            0 :                matrix_c->ncols, matrix_c->row_sizes, matrix_c->col_sizes);
     406            0 :     dbm_copy(matrix_d, matrix_c); // Copy is taken after matrix_c was scaled by beta.
     407              :   }
     408              : 
     409              :   // Compute filter thresholds for each row.
     410       242490 :   float *rows_max_eps = compute_rows_max_eps(transa, matrix_a, filter_eps);
     411              : 
     412              :   // Start uploading matrix_c to the GPU.
     413       242490 :   backend_context_t *ctx = backend_start(matrix_c);
     414              : 
     415              :   // Redistribute matrix_a and matrix_b across MPI ranks.
     416       242490 :   dbm_comm_iterator_t *iter =
     417       242490 :       dbm_comm_iterator_start(transa, transb, matrix_a, matrix_b, matrix_c);
     418              : 
     419              :   // Count flops if requested.
     420       242490 :   if (NULL != flop) {
     421       242490 :     *flop = 0;
     422              :   }
     423              : 
     424              :   // Main loop.
     425              :   dbm_pack_t *pack_a, *pack_b;
     426       506386 :   while (dbm_comm_iterator_next(iter, &pack_a, &pack_b)) {
     427       263896 :     const bool uploaded = backend_upload_packs(pack_a, pack_b, ctx);
     428       263896 :     (void)uploaded; // mark used
     429       263896 :     multiply_packs(transa, transb, alpha, pack_a, pack_b, matrix_a, matrix_b,
     430              :                    matrix_c, rows_max_eps, retain_sparsity, false /*!uploaded*/,
     431              :                    flop, ctx);
     432              :   }
     433              : 
     434              :   // Wait for all other MPI ranks to complete, then release resources.
     435       242490 :   dbm_comm_iterator_stop(iter);
     436       242490 :   backend_stop(ctx);
     437              : 
     438       242490 :   if (NULL != matrix_d) {
     439            0 :     ctx = backend_start(matrix_d);
     440            0 :     iter =
     441            0 :         dbm_comm_iterator_start(transa, transb, matrix_a, matrix_b, matrix_d);
     442            0 :     while (dbm_comm_iterator_next(iter, &pack_a, &pack_b)) {
     443            0 :       multiply_packs(transa, transb, alpha, pack_a, pack_b, matrix_a, matrix_b,
     444              :                      matrix_d, rows_max_eps, retain_sparsity, true, NULL, ctx); // force_cpu = true, flop = NULL.
     445              :     }
     446            0 :     dbm_comm_iterator_stop(iter);
     447            0 :     backend_stop(ctx);
     448            0 :     const double epsilon = dbm_maxeps(matrix_d, matrix_c);
     449            0 :     if (maxeps < epsilon) {
     450            0 :       if (1 == verify) {
     451            0 :         fprintf(stderr, "WARN ACC/LIBDBM: diff=%g\n", epsilon);
     452              :       } else {
     453            0 :         fprintf(stderr, "ERROR ACC/LIBDBM: diff=%g\n", epsilon);
     454            0 :         exit(EXIT_FAILURE);
     455              :       }
     456              :     }
     457            0 :     dbm_release(matrix_d);
     458              :   }
     459              : 
     460              :   // Release filter thresholds.
     461       242490 :   free(rows_max_eps);
     462              : 
     463              :   // Final filter pass.
     464       242490 :   dbm_filter(matrix_c, filter_eps);
     465       242490 : }
     466              : 
     467              : // EOF
        

Generated by: LCOV version 2.0-1