LCOV - code coverage report
Current view: top level - src/dbm - dbm_multiply.c
Test: CP2K Regtests (git:a2cdc02)
Date: 2025-04-17 08:15:26
                Hit    Total    Coverage
  Lines:         62       70      88.6 %
  Functions:      6        8      75.0 %

          Line data    Source code
       1             : /*----------------------------------------------------------------------------*/
       2             : /*  CP2K: A general program to perform molecular dynamics simulations         */
       3             : /*  Copyright 2000-2025 CP2K developers group <https://cp2k.org>              */
       4             : /*                                                                            */
       5             : /*  SPDX-License-Identifier: BSD-3-Clause                                     */
       6             : /*----------------------------------------------------------------------------*/
       7             : 
       8             : #include <assert.h>
       9             : #include <limits.h>
      10             : #include <omp.h>
      11             : #include <stdlib.h>
      12             : #include <string.h>
      13             : 
      14             : #include "../offload/offload_runtime.h"
      15             : #include "dbm_hyperparams.h"
      16             : #include "dbm_internal.h"
      17             : #include "dbm_library.h"
      18             : #include "dbm_mempool.h"
      19             : #include "dbm_multiply.h"
      20             : #include "dbm_multiply_comm.h"
      21             : #include "dbm_multiply_cpu.h"
      22             : #include "dbm_multiply_gpu.h"
      23             : 
      24             : #if defined(__LIBXSMM)
      25             : #include <libxsmm.h>
      26             : #endif
      27             : 
      28             : #if !defined(DBM_VALIDATE_AGAINST_LIBXSMM) && 0
      29             : #define DBM_VALIDATE_AGAINST_LIBXSMM
      30             : #endif
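                      : // Note: changing the "&& 0" above to "&& 1" (or defining
                      : // DBM_VALIDATE_AGAINST_LIBXSMM at build time) enables validation of GPU
                      : // batch results against a LIBXSMM-based CPU reference in
                      : // backend_process_batch. This only takes effect in offload builds with
                      : // __LIBXSMM; the tolerance can be tuned via the DBM_MULTIPLY_MAXEPS
                      : // environment variable (default 1E-13).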
      31             : 
      32             : /*******************************************************************************
      33             :  * \brief Private routine for computing the max filter threshold for each row.
      34             :  * \author Ole Schuett
      35             :  ******************************************************************************/
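                      : // Each row receives an equal share of the overall filter threshold: with n
                      : // blocks contributing to a row, the per-block cutoff is filter_eps / n, so
                      : // the accumulated filtering error of the row stays below filter_eps. The
                      : // returned thresholds are the squares of these cutoffs, presumably because
                      : // multiply_packs compares them against alpha^2 * norm_a * norm_b, i.e.
                      : // against products of squared block norms. For example, filter_eps = 1e-10
                      : // and 4 blocks per row give a stored threshold of (1e-10 / 4)^2 = 6.25e-22.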
      36      197367 : static float *compute_rows_max_eps(const bool trans, const dbm_matrix_t *matrix,
      37             :                                    const double filter_eps) {
      38      197367 :   const int nrows = (trans) ? matrix->ncols : matrix->nrows;
      39      197367 :   int *nblocks_per_row = calloc(nrows, sizeof(int));
      40      197367 :   float *row_max_eps = malloc(nrows * sizeof(float));
      41      197367 :   assert(row_max_eps != NULL);
      42             : 
      43      197367 : #pragma omp parallel
      44             :   {
      45             : #pragma omp for
      46             :     for (int ishard = 0; ishard < dbm_get_num_shards(matrix); ishard++) {
      47             :       dbm_shard_t *shard = &matrix->shards[ishard];
      48             :       for (int iblock = 0; iblock < shard->nblocks; iblock++) {
      49             :         const dbm_block_t *blk = &shard->blocks[iblock];
      50             :         const int row = (trans) ? blk->col : blk->row;
      51             : #pragma omp atomic
      52             :         nblocks_per_row[row]++;
      53             :       }
      54             :     }
      55             : #pragma omp single
      56             :     dbm_mpi_sum_int(nblocks_per_row, nrows, matrix->dist->comm);
      57             : #pragma omp barrier
      58             : #pragma omp for
      59             :     for (int i = 0; i < nrows; i++) {
      60             :       const float f =
      61             :           ((float)filter_eps) / ((float)imax(1, nblocks_per_row[i]));
      62             :       row_max_eps[i] = f * f;
      63             :     }
      64             :   } // end of omp parallel region
      65             : 
      66      197367 :   free(nblocks_per_row);
      67      197367 :   return row_max_eps; // Ownership of row_max_eps transfers to caller.
      68             : }
      69             : 
      70             : /*******************************************************************************
      71             :  * \brief Private struct for storing the context of the multiplication backend.
      72             :  * \author Ole Schuett
      73             :  ******************************************************************************/
      74             : typedef struct {
      75             : #if defined(__OFFLOAD) && !defined(__NO_OFFLOAD_DBM)
      76             :   dbm_multiply_gpu_context_t gpu;
      77             : #endif
      78             : } backend_context_t;
      79             : 
      80             : /*******************************************************************************
      81             :  * \brief Private routine for initializing the multiplication backend.
      82             :  * \author Ole Schuett
      83             :  ******************************************************************************/
      84      197367 : static backend_context_t *backend_start(const dbm_matrix_t *matrix_c) {
      85      197367 :   backend_context_t *ctx = calloc(1, sizeof(backend_context_t));
      86             : 
      87             : #if defined(__OFFLOAD) && !defined(__NO_OFFLOAD_DBM)
      88             :   dbm_multiply_gpu_start(DBM_MAX_BATCH_SIZE, dbm_get_num_shards(matrix_c),
      89             :                          matrix_c->shards, &ctx->gpu);
      90             : #else
      91      197367 :   (void)matrix_c; // mark as used
      92             : #endif
      93             : 
      94      197367 :   return ctx;
      95             : }
      96             : 
      97             : /*******************************************************************************
      98             :  * \brief Private routine for handing newly arrived packs to the backend.
      99             :  * \author Ole Schuett
     100             :  ******************************************************************************/
     101           0 : static void backend_upload_packs(const dbm_pack_t *pack_a,
     102             :                                  const dbm_pack_t *pack_b,
     103             :                                  backend_context_t *ctx) {
     104             : 
     105             : #if defined(__OFFLOAD) && !defined(__NO_OFFLOAD_DBM)
     106             :   dbm_multiply_gpu_upload_packs(pack_a, pack_b, &ctx->gpu);
     107             : #else
     108           0 :   (void)pack_a; // mark as used
     109           0 :   (void)pack_b;
     110           0 :   (void)ctx;
     111             : #endif
     112           0 : }
     113             : 
     114             : /*******************************************************************************
     115             :  * \brief Private routine for sending a batch to the multiplication backend.
     116             :  * \author Ole Schuett
     117             :  ******************************************************************************/
     118      215034 : static void backend_process_batch(const int ntasks,
     119             :                                   const dbm_task_t batch[ntasks],
     120             :                                   const double alpha, const dbm_pack_t *pack_a,
     121             :                                   const dbm_pack_t *pack_b, const int kshard,
     122             :                                   dbm_shard_t *shard_c,
     123             :                                   backend_context_t *ctx) {
     124             : #if defined(__OFFLOAD) && !defined(__NO_OFFLOAD_DBM)
     125             :   dbm_multiply_gpu_process_batch(ntasks, batch, alpha, kshard, &ctx->gpu);
     126             : #if defined(DBM_VALIDATE_AGAINST_LIBXSMM) && defined(__LIBXSMM)
     127             :   dbm_shard_gpu_t *const shard_g = &ctx->gpu.shards_c_dev[kshard];
     128             :   dbm_shard_t shard_r;
     129             :   dbm_shard_allocate_promised_blocks(shard_c);
     130             :   /* start transferring GPU result to host */
     131             :   assert(shard_c->data_size == shard_g->data_size);
     132             :   dbm_shard_init(&shard_r);
     133             :   dbm_shard_copy(&shard_r, shard_c);
     134             :   offloadMemcpyAsyncDtoH(shard_c->data, shard_g->data,
     135             :                          shard_c->data_size * sizeof(double), shard_g->stream);
     136             :   dbm_multiply_cpu_process_batch(ntasks, batch, alpha, pack_a, pack_b,
     137             :                                  &shard_r);
     138             :   /* finish transferring GPU result to host */
     139             :   offloadStreamSynchronize(shard_g->stream);
     140             :   libxsmm_matdiff_info diff;
     141             :   libxsmm_matdiff_clear(&diff);
     142             :   for (int itask = 0; itask < ntasks; ++itask) {
     143             :     const dbm_task_t task = batch[itask];
     144             :     const double *const tst = &shard_c->data[task.offset_c];
     145             :     const double *const ref = &shard_r.data[task.offset_c];
     146             :     libxsmm_matdiff_info d;
     147             :     if (EXIT_SUCCESS == libxsmm_matdiff(&d, LIBXSMM_DATATYPE(double), task.m,
     148             :                                         task.n, ref, tst, NULL /*ldref*/,
     149             :                                         NULL /*ldtst*/)) {
     150             :       libxsmm_matdiff_reduce(&diff, &d);
     151             :     }
     152             :   }
     153             :   const char *const maxeps_env = getenv("DBM_MULTIPLY_MAXEPS");
     154             :   const double maxeps = (NULL == maxeps_env ? 1E-13 : fabs(atof(maxeps_env)));
     155             :   const double epsilon = libxsmm_matdiff_epsilon(&diff);
     156             :   if (maxeps < epsilon) {
     157             :     if (LIBXSMM_NOTNAN(diff.v_tst)) {
     158             :       fprintf(stderr, "INFO ACC/LIBDBM: diff=%g (|%g-%g|=%g)\n", epsilon,
     159             :               diff.v_ref, diff.v_tst, diff.linf_abs);
     160             :     } else {
     161             :       fprintf(stderr, "INFO ACC/LIBDBM: diff=%g\n", epsilon);
     162             :     }
     163             :   }
     164             :   dbm_shard_release(&shard_r);
     165             : #else
     166             :   (void)pack_a;
     167             :   (void)pack_b;
     168             :   (void)shard_c; // mark as used
     169             : #endif
     170             : #else
     171      215034 :   (void)kshard;
     172      215034 :   (void)ctx; // mark as used
     173      215034 :   dbm_multiply_cpu_process_batch(ntasks, batch, alpha, pack_a, pack_b, shard_c);
     174             : #endif
     175      215034 : }
     176             : 
     177             : /*******************************************************************************
     178             :  * \brief Private routine for downloading results of the multiplication backend.
     179             :  * \author Ole Schuett
     180             :  ******************************************************************************/
     181           0 : static void backend_download_results(backend_context_t *ctx) {
     182             : #if defined(__OFFLOAD) && !defined(__NO_OFFLOAD_DBM)
     183             :   dbm_multiply_gpu_download_results(&ctx->gpu);
     184             : #else
     185           0 :   (void)ctx; // mark as used
     186             : #endif
     187           0 : }
     188             : 
     189             : /*******************************************************************************
     190             :  * \brief Private routine for shutting down the multiplication backend.
     191             :  * \author Ole Schuett
     192             :  ******************************************************************************/
     193      197367 : static void backend_stop(backend_context_t *ctx) {
     194             : #if defined(__OFFLOAD) && !defined(__NO_OFFLOAD_DBM)
     195             :   dbm_multiply_gpu_stop(&ctx->gpu);
     196             : #endif
     197      197367 :   free(ctx);
     198      197367 : }
     199             : 
     200             : /*******************************************************************************
      201             :  * \brief Private routine for multiplying two packs.
     202             :  * \author Ole Schuett
     203             :  ******************************************************************************/
     204      214979 : static void multiply_packs(const bool transa, const bool transb,
     205             :                            const double alpha, const dbm_pack_t *pack_a,
     206             :                            const dbm_pack_t *pack_b,
     207             :                            const dbm_matrix_t *matrix_a,
     208             :                            const dbm_matrix_t *matrix_b, dbm_matrix_t *matrix_c,
     209             :                            const bool retain_sparsity,
     210             :                            const float *rows_max_eps, int64_t *flop,
     211             :                            backend_context_t *ctx) {
     212      214979 :   const float alpha2 = alpha * alpha;
     213      214979 :   int64_t flop_sum = 0;
     214             : 
     215      214979 :   const int nshard_rows = matrix_c->dist->rows.nshards;
     216      214979 :   const int nshard_cols = matrix_c->dist->cols.nshards;
     217      214979 :   int *shard_row_start = calloc(nshard_rows, sizeof(int));
     218      214979 :   int *shard_col_start = calloc(nshard_cols, sizeof(int));
     219      214979 :   assert(NULL != shard_row_start && NULL != shard_col_start);
     220             : 
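                      :   // With a transposed operand the roles of its indices swap: when transa is
                      :   // set, the summation index of op(matrix_a) runs over matrix_a's rows and
                      :   // its free index over matrix_a's columns, hence the size arrays below.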
     221      214979 :   const int *sum_index_sizes_a =
     222             :       (transa) ? matrix_a->row_sizes : matrix_a->col_sizes;
     223      214979 :   const int *sum_index_sizes_b =
     224             :       (transb) ? matrix_b->col_sizes : matrix_b->row_sizes;
     225      214979 :   const int *free_index_sizes_a =
     226             :       (transa) ? matrix_a->col_sizes : matrix_a->row_sizes;
     227      214979 :   const int *free_index_sizes_b =
     228             :       (transb) ? matrix_b->row_sizes : matrix_b->col_sizes;
     229             : 
     230      214979 : #pragma omp parallel reduction(+ : flop_sum)
     231             :   {
      232             :     // Thread-private batch buffer; the assigned work is processed in pieces
      233             :     // of at most DBM_MAX_BATCH_SIZE tasks.
     233             :     dbm_task_t *batch =
     234             :         dbm_mempool_host_malloc(sizeof(dbm_task_t) * DBM_MAX_BATCH_SIZE);
     235             : 
      236             :     // Blocks are ordered first by shard. Create lookup tables of the shard boundaries.
     237             : #pragma omp for nowait
     238             :     for (int iblock = 1; iblock < pack_a->nblocks; iblock++) {
     239             :       const int shard_row = pack_a->blocks[iblock].free_index % nshard_rows;
     240             :       const int prev_shard_row =
     241             :           pack_a->blocks[iblock - 1].free_index % nshard_rows;
     242             :       if (prev_shard_row != shard_row) {
     243             :         shard_row_start[shard_row] = iblock;
     244             :       }
     245             :     }
     246             : #pragma omp for
     247             :     for (int jblock = 1; jblock < pack_b->nblocks; jblock++) {
     248             :       const int shard_col = pack_b->blocks[jblock].free_index % nshard_cols;
     249             :       const int prev_shard_col =
     250             :           pack_b->blocks[jblock - 1].free_index % nshard_cols;
     251             :       if (prev_shard_col != shard_col) {
     252             :         shard_col_start[shard_col] = jblock;
     253             :       }
     254             :     }
     255             : 
     256             : #pragma omp for collapse(2) DBM_OMP_SCHEDULE
     257             :     for (int shard_row = 0; shard_row < nshard_rows; shard_row++) {
     258             :       for (int shard_col = 0; shard_col < nshard_cols; shard_col++) {
     259             :         const int ishard = shard_row * nshard_cols + shard_col;
     260             :         dbm_shard_t *shard_c = &matrix_c->shards[ishard];
     261             :         int ntasks = 0;
     262             : 
     263             :         // Use a merge-join to find pairs of blocks with matching sum indices.
     264             :         // This utilizes that blocks within a shard are ordered by sum_index.
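                      :         // For illustration: with sum indices A = [3, 3, 7] and B = [3, 5, 7, 7]
                      :         // the pairs (A0,B0), (A1,B0), (A2,B2) and (A2,B3) are formed. Advancing
                      :         // jblock_start past B blocks with a smaller sum_index is safe because
                      :         // they cannot match any of the remaining (non-decreasing) A blocks.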
     265             :         const int iblock_start = shard_row_start[shard_row];
     266             :         int jblock_start = shard_col_start[shard_col];
     267             :         for (int iblock = iblock_start; iblock < pack_a->nblocks; iblock++) {
     268             :           const dbm_pack_block_t *blk_a = &pack_a->blocks[iblock];
     269             :           if (blk_a->free_index % nshard_rows != shard_row) {
     270             :             break;
     271             :           }
     272             :           for (int jblock = jblock_start; jblock < pack_b->nblocks; jblock++) {
     273             :             const dbm_pack_block_t *blk_b = &pack_b->blocks[jblock];
     274             :             if (blk_b->free_index % nshard_cols != shard_col) {
     275             :               jblock = pack_b->nblocks; // break
     276             :               continue;
     277             :             }
     278             :             if (blk_a->sum_index < blk_b->sum_index) {
     279             :               jblock = pack_b->nblocks; // break
     280             :               continue;
     281             :             }
     282             :             if (blk_a->sum_index > blk_b->sum_index) {
     283             :               jblock_start++;
     284             :               continue;
     285             :             }
     286             :             // Found block pair with blk_a->sum_index == blk_b->sum_index.
     287             : 
     288             :             // Check norms.
     289             :             const float result_norm = alpha2 * blk_a->norm * blk_b->norm;
     290             :             if (result_norm < rows_max_eps[blk_a->free_index]) {
     291             :               continue;
     292             :             }
     293             : 
     294             :             // Check block sizes.
     295             :             const int m = free_index_sizes_a[blk_a->free_index];
     296             :             const int n = free_index_sizes_b[blk_b->free_index];
     297             :             const int k = sum_index_sizes_a[blk_a->sum_index];
     298             :             assert(m == matrix_c->row_sizes[blk_a->free_index]);
     299             :             assert(n == matrix_c->col_sizes[blk_b->free_index]);
     300             :             assert(k == sum_index_sizes_b[blk_b->sum_index]);
     301             : 
     302             :             // Get C block.
     303             :             const int row = blk_a->free_index, col = blk_b->free_index;
     304             :             dbm_block_t *blk_c = dbm_shard_lookup(shard_c, row, col);
     305             :             if (blk_c == NULL && retain_sparsity) {
     306             :               continue;
     307             :             } else if (blk_c == NULL) {
     308             :               assert(dbm_get_shard_index(matrix_c, row, col) == ishard);
     309             :               assert(dbm_get_stored_coordinates(matrix_c, row, col) ==
     310             :                      matrix_c->dist->my_rank);
     311             :               blk_c = dbm_shard_promise_new_block(shard_c, row, col, m * n);
     312             :             }
     313             : 
     314             :             // Count flops.
     315             :             const int64_t task_flops = 2LL * m * n * k;
     316             :             if (task_flops == 0) {
     317             :               continue;
     318             :             }
     319             :             flop_sum += task_flops;
     320             :             dbm_library_counter_increment(m, n, k);
     321             : 
     322             :             // Add block multiplication to batch.
     323             :             batch[ntasks].m = m;
     324             :             batch[ntasks].n = n;
     325             :             batch[ntasks].k = k;
     326             :             batch[ntasks].offset_a = blk_a->offset;
     327             :             batch[ntasks].offset_b = blk_b->offset;
     328             :             batch[ntasks].offset_c = blk_c->offset;
     329             :             ++ntasks;
     330             : 
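                      :             // Flush the batch to the backend once it is full; any remaining
                      :             // tasks are submitted by the trailing call after the block loops.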
     331             :             if (ntasks == DBM_MAX_BATCH_SIZE) {
     332             :               backend_process_batch(ntasks, batch, alpha, pack_a, pack_b,
     333             :                                     ishard, shard_c, ctx);
     334             :               ntasks = 0;
     335             :             }
     336             :           }
     337             :         }
     338             :         backend_process_batch(ntasks, batch, alpha, pack_a, pack_b, ishard,
     339             :                               shard_c, ctx);
     340             :       }
     341             :     }
     342             : 
     343             :     dbm_mempool_host_free(batch);
     344             :   }
     345             : 
     346      214979 :   free(shard_row_start);
     347      214979 :   free(shard_col_start);
     348             : 
     349      214979 :   *flop += flop_sum;
     350      214979 : }
     351             : 
     352             : /*******************************************************************************
     353             :  * \brief Performs a multiplication of two dbm_matrix_t matrices.
     354             :  *        See dbm_matrix.h for details.
     355             :  * \author Ole Schuett
     356             :  ******************************************************************************/
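                      : // In effect this computes
                      : //   matrix_c := beta * matrix_c + alpha * op(matrix_a) * op(matrix_b),
                      : // where op() transposes its argument when the corresponding trans flag is
                      : // set; block products whose estimated contribution falls below filter_eps
                      : // are skipped and *flop returns the number of floating-point operations
                      : // performed. A minimal call sketch (matrices assumed to be set up via the
                      : // dbm_matrix API):
                      : //
                      : //   int64_t flops = 0;
                      : //   dbm_multiply(/*transa=*/false, /*transb=*/false, /*alpha=*/1.0, a, b,
                      : //                /*beta=*/0.0, c, /*retain_sparsity=*/false,
                      : //                /*filter_eps=*/1e-10, &flops); // c := a * b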
     357      197367 : void dbm_multiply(const bool transa, const bool transb, const double alpha,
     358             :                   const dbm_matrix_t *matrix_a, const dbm_matrix_t *matrix_b,
     359             :                   const double beta, dbm_matrix_t *matrix_c,
     360             :                   const bool retain_sparsity, const double filter_eps,
     361             :                   int64_t *flop) {
     362             : 
     363      197367 :   assert(omp_get_num_threads() == 1);
     364             : 
      365             :   // Throughout the matrix multiplication code, "sum_index" and "free_index"
      366             :   // denote the summation (aka dummy) and free indices of Einstein notation.
     367      197367 :   const int num_sum_index_a = (transa) ? matrix_a->nrows : matrix_a->ncols;
     368      197367 :   const int num_sum_index_b = (transb) ? matrix_b->ncols : matrix_b->nrows;
     369      197367 :   const int num_free_index_a = (transa) ? matrix_a->ncols : matrix_a->nrows;
     370      197367 :   const int num_free_index_b = (transb) ? matrix_b->nrows : matrix_b->ncols;
     371             : 
     372             :   // Sanity check matrix dimensions.
     373      197367 :   assert(num_sum_index_a == num_sum_index_b);
     374      197367 :   assert(num_free_index_a == matrix_c->nrows);
     375      197367 :   assert(num_free_index_b == matrix_c->ncols);
     376             : 
     377             :   // Prepare matrix_c.
     378      197367 :   dbm_scale(matrix_c, beta);
     379             : 
     380             :   // Start uploading matrix_c to the GPU.
     381      197367 :   backend_context_t *ctx = backend_start(matrix_c);
     382             : 
     383             :   // Compute filter thresholds for each row.
     384      197367 :   float *rows_max_eps = compute_rows_max_eps(transa, matrix_a, filter_eps);
     385             : 
     386             :   // Redistribute matrix_a and matrix_b across MPI ranks.
     387      197367 :   dbm_comm_iterator_t *iter =
     388      197367 :       dbm_comm_iterator_start(transa, transb, matrix_a, matrix_b, matrix_c);
     389             : 
     390             :   // Main loop.
     391      197367 :   *flop = 0;
     392      197367 :   dbm_pack_t *pack_a, *pack_b;
     393      412346 :   while (dbm_comm_iterator_next(iter, &pack_a, &pack_b)) {
     394      214979 :     backend_upload_packs(pack_a, pack_b, ctx);
     395      214979 :     multiply_packs(transa, transb, alpha, pack_a, pack_b, matrix_a, matrix_b,
     396             :                    matrix_c, retain_sparsity, rows_max_eps, flop, ctx);
     397             :   }
     398             : 
     399             :   // Start downloading matrix_c from the GPU.
     400      197367 :   backend_download_results(ctx);
     401             : 
      402             :   // Wait for all other MPI ranks to complete, then release resources.
     403      197367 :   dbm_comm_iterator_stop(iter);
     404      197367 :   free(rows_max_eps);
     405      197367 :   backend_stop(ctx);
     406             : 
     407             :   // Final filter pass.
     408      197367 :   dbm_filter(matrix_c, filter_eps);
     409      197367 : }
     410             : 
     411             : // EOF

Generated by: LCOV version 1.15