LCOV - code coverage report
Current view: top level - src/dbm - dbm_multiply_cpu.c (source / functions) Coverage Total Hit
Test: CP2K Regtests (git:f56c6e3) Lines: 96.2 % 52 50
Test Date: 2025-11-15 06:45:58 Functions: 100.0 % 2 2

            Line data    Source code
       1              : /*----------------------------------------------------------------------------*/
       2              : /*  CP2K: A general program to perform molecular dynamics simulations         */
       3              : /*  Copyright 2000-2025 CP2K developers group <https://cp2k.org>              */
       4              : /*                                                                            */
       5              : /*  SPDX-License-Identifier: BSD-3-Clause                                     */
       6              : /*----------------------------------------------------------------------------*/
       7              : #include "dbm_multiply_cpu.h"
       8              : #include "dbm_hyperparams.h"
       9              : 
      10              : #include <assert.h>
      11              : #include <stddef.h>
      12              : #include <string.h>
      13              : 
      14              : #if defined(__LIBXSMM)
      15              : #include <libxsmm.h>
      16              : #if !defined(DBM_LIBXSMM_PREFETCH)
      17              : // #define DBM_LIBXSMM_PREFETCH LIBXSMM_GEMM_PREFETCH_AL2_AHEAD
      18              : #define DBM_LIBXSMM_PREFETCH LIBXSMM_GEMM_PREFETCH_NONE
      19              : #endif
      20              : #if LIBXSMM_VERSION4(1, 17, 0, 3710) > LIBXSMM_VERSION_NUMBER
      21              : #define libxsmm_dispatch_gemm libxsmm_dispatch_gemm_v2
      22              : #endif
      23              : #endif
      24              : 
      25              : /*******************************************************************************
      26              :  * \brief Prototype for BLAS dgemm.
      27              :  * \author Ole Schuett
      28              :  ******************************************************************************/
      29              : void dgemm_(const char *transa, const char *transb, const int *m, const int *n,
      30              :             const int *k, const double *alpha, const double *a, const int *lda,
      31              :             const double *b, const int *ldb, const double *beta, double *c,
      32              :             const int *ldc);
      33              : 
      34              : /*******************************************************************************
      35              :  * \brief Private convenient wrapper to hide Fortran nature of dgemm_.
      36              :  * \author Ole Schuett
      37              :  ******************************************************************************/
      38      5395004 : static inline void dbm_dgemm(const char transa, const char transb, const int m,
      39              :                              const int n, const int k, const double alpha,
      40              :                              const double *a, const int lda, const double *b,
      41              :                              const int ldb, const double beta, double *c,
      42              :                              const int ldc) {
      43              : 
      44      5395004 :   dgemm_(&transa, &transb, &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, c,
      45              :          &ldc);
      46              : }
      47              : 
      48              : /*******************************************************************************
      49              :  * \brief Private hash function based on Szudzik's elegant pairing.
      50              :  *        Using unsigned int to return a positive number even after overflow.
      51              :  *        https://en.wikipedia.org/wiki/Pairing_function#Other_pairing_functions
      52              :  *        https://stackoverflow.com/a/13871379
      53              :  *        http://szudzik.com/ElegantPairing.pdf
      54              :  * \author Ole Schuett
      55              :  ******************************************************************************/
      56     55882708 : static inline unsigned int hash(const dbm_task_t task) {
      57     55882708 :   const unsigned int m = task.m, n = task.n, k = task.k;
      58     55882708 :   const unsigned int mn = (m >= n) ? m * m + m + n : m + n * n;
      59     55882708 :   const unsigned int mnk = (mn >= k) ? mn * mn + mn + k : mn + k * k;
      60     55882708 :   return mnk;
      61              : }
      62              : 
      63              : /*******************************************************************************
      64              :  * \brief Internal routine for executing the tasks in given batch on the CPU.
      65              :  * \author Ole Schuett
      66              :  ******************************************************************************/
      67       232460 : void dbm_multiply_cpu_process_batch(int ntasks, const dbm_task_t batch[ntasks],
      68              :                                     double alpha, const dbm_pack_t *pack_a,
      69              :                                     const dbm_pack_t *pack_b,
      70       232460 :                                     dbm_shard_t *shard_c, int options) {
      71              : 
      72       232460 :   if (0 >= ntasks) { // nothing to do
      73        35875 :     return;
      74              :   }
      75       196585 :   dbm_shard_allocate_promised_blocks(shard_c);
      76              : 
      77       196585 :   int batch_order[ntasks];
      78       196585 :   if (DBM_MULTIPLY_TASK_REORDER & options) {
      79              :     // Sort tasks approximately by m,n,k via bucket sort.
      80       196585 :     int buckets[DBM_BATCH_NUM_BUCKETS] = {0};
      81     28137939 :     for (int itask = 0; itask < ntasks; ++itask) {
      82     27941354 :       const int i = hash(batch[itask]) % DBM_BATCH_NUM_BUCKETS;
      83     27941354 :       ++buckets[i];
      84              :     }
      85    196585000 :     for (int i = 1; i < DBM_BATCH_NUM_BUCKETS; ++i) {
      86    196388415 :       buckets[i] += buckets[i - 1];
      87              :     }
      88       196585 :     assert(buckets[DBM_BATCH_NUM_BUCKETS - 1] == ntasks);
      89     28137939 :     for (int itask = 0; itask < ntasks; ++itask) {
      90     27941354 :       const int i = hash(batch[itask]) % DBM_BATCH_NUM_BUCKETS;
      91     27941354 :       --buckets[i];
      92     27941354 :       batch_order[buckets[i]] = itask;
      93              :     }
      94              :   } else {
      95            0 :     for (int itask = 0; itask < ntasks; ++itask) {
      96            0 :       batch_order[itask] = itask;
      97              :     }
      98              :   }
      99              : 
     100              : #if defined(__LIBXSMM)
     101              :   // Prepare arguments for libxsmm's kernel-dispatch.
     102       196585 :   const int flags = LIBXSMM_GEMM_FLAG_TRANS_B; // transa = "N", transb = "T"
     103       196585 :   const int prefetch = DBM_LIBXSMM_PREFETCH;
     104       196585 :   int kernel_m = 0, kernel_n = 0, kernel_k = 0;
     105              : #if (LIBXSMM_GEMM_PREFETCH_NONE != DBM_LIBXSMM_PREFETCH)
     106              :   double *data_a_next = NULL, *data_b_next = NULL, *data_c_next = NULL;
     107              : #endif
     108              : #if LIBXSMM_VERSION2(1, 17) < LIBXSMM_VERSION_NUMBER
     109       196585 :   libxsmm_gemmfunction kernel_func = NULL;
     110              : #else
     111              :   libxsmm_dmmfunction kernel_func = NULL;
     112              :   const double beta = 1.0;
     113              : #endif
     114              : #endif
     115              : 
     116              :   // Loop over tasks.
     117       196585 :   dbm_task_t task_next = batch[batch_order[0]];
     118     28137939 :   for (int itask = 0; itask < ntasks; ++itask) {
     119     27941354 :     const dbm_task_t task = task_next;
     120     27941354 :     task_next = batch[batch_order[(itask + 1) < ntasks ? (itask + 1) : itask]];
     121              : 
     122              : #if defined(__LIBXSMM)
     123     27941354 :     if (0 == (DBM_MULTIPLY_BLAS_LIBRARY & options) &&
     124     27254844 :         (task.m != kernel_m || task.n != kernel_n || task.k != kernel_k)) {
     125      1613684 :       if (LIBXSMM_SMM(task.m, task.n, task.m, 1 /*assume in-$, no RFO*/,
     126              :                       sizeof(double))) {
     127              : #if LIBXSMM_VERSION2(1, 17) < LIBXSMM_VERSION_NUMBER
     128      1568857 :         const libxsmm_gemm_shape shape = libxsmm_create_gemm_shape(
     129              :             task.m, task.n, task.k, task.m /*lda*/, task.n /*ldb*/,
     130              :             task.m /*ldc*/, LIBXSMM_DATATYPE_F64 /*aprec*/,
     131              :             LIBXSMM_DATATYPE_F64 /*bprec*/, LIBXSMM_DATATYPE_F64 /*cprec*/,
     132              :             LIBXSMM_DATATYPE_F64 /*calcp*/);
     133      1568857 :         kernel_func =
     134              :             (LIBXSMM_FEQ(1.0, alpha)
     135      1272297 :                  ? libxsmm_dispatch_gemm(shape, (libxsmm_bitfield)flags,
     136              :                                          (libxsmm_bitfield)prefetch)
     137      1568857 :                  : NULL);
     138              : #else
     139              :         kernel_func = libxsmm_dmmdispatch(task.m, task.n, task.k, NULL /*lda*/,
     140              :                                           NULL /*ldb*/, NULL /*ldc*/, &alpha,
     141              :                                           &beta, &flags, &prefetch);
     142              : #endif
     143              :       } else {
     144              :         kernel_func = NULL;
     145              :       }
     146              :       kernel_m = task.m;
     147              :       kernel_n = task.n;
     148              :       kernel_k = task.k;
     149              :     }
     150              : #endif
     151              :     // gemm_param wants non-const data even for A and B
     152     27941354 :     double *const data_a = pack_a->data + task.offset_a;
     153     27941354 :     double *const data_b = pack_b->data + task.offset_b;
     154     27941354 :     double *const data_c = shard_c->data + task.offset_c;
     155              : 
     156              : #if defined(__LIBXSMM)
     157     27941354 :     if (kernel_func != NULL) {
     158              : #if LIBXSMM_VERSION2(1, 17) < LIBXSMM_VERSION_NUMBER
     159     22546350 :       libxsmm_gemm_param gemm_param;
     160     22546350 :       gemm_param.a.primary = data_a;
     161     22546350 :       gemm_param.b.primary = data_b;
     162     22546350 :       gemm_param.c.primary = data_c;
     163              : #if (LIBXSMM_GEMM_PREFETCH_NONE != DBM_LIBXSMM_PREFETCH)
     164              :       gemm_param.a.quaternary = pack_a->data + task_next.offset_a;
     165              :       gemm_param.b.quaternary = pack_b->data + task_next.offset_b;
     166              :       gemm_param.c.quaternary = shard_c->data + task_next.offset_c;
     167              : #endif
     168     22546350 :       kernel_func(&gemm_param);
     169              : #elif (LIBXSMM_GEMM_PREFETCH_NONE != DBM_LIBXSMM_PREFETCH)
     170              :       kernel_func(data_a, data_b, data_c, pack_a->data + task_next.offset_a,
     171              :                   pack_b->data + task_next.offset_b,
     172              :                   shard_c->data + task_next.offset_c);
     173              : #else
     174              :       kernel_func(data_a, data_b, data_c);
     175              : #endif
     176              :     } else
     177              : #endif
     178              :     { // Fallback to BLAS when libxsmm is not available.
     179      5395004 :       dbm_dgemm('N', 'T', task.m, task.n, task.k, alpha, data_a, task.m, data_b,
     180              :                 task.n, 1.0, data_c, task.m);
     181              :     }
     182              :   }
     183              : }
     184              : 
     185              : // EOF
        

Generated by: LCOV version 2.0-1