LCOV - code coverage report
Current view: top level - src/dbm - dbm_multiply_cpu.c (source / functions) Coverage Total Hit
Test: CP2K Regtests (git:85b8a9b) Lines: 89.1 % 46 41
Test Date: 2026-06-14 06:48:14 Functions: 100.0 % 2 2

            Line data    Source code
       1              : /*----------------------------------------------------------------------------*/
       2              : /*  CP2K: A general program to perform molecular dynamics simulations         */
       3              : /*  Copyright 2000-2026 CP2K developers group <https://cp2k.org>              */
       4              : /*                                                                            */
       5              : /*  SPDX-License-Identifier: BSD-3-Clause                                     */
       6              : /*----------------------------------------------------------------------------*/
       7              : #include "dbm_multiply_cpu.h"
       8              : #include "dbm_hyperparams.h"
       9              : 
      10              : #include <assert.h>
      11              : #include <stddef.h>
      12              : 
      13              : #if defined(__LIBXSMM)
      14              : #include <libxsmm.h>
      15              : #endif
      16              : #if defined(__LIBXS)
      17              : #include <libxs/libxs_gemm.h>
      18              : #endif
      19              : 
      20              : /*******************************************************************************
      21              :  * \brief Prototype for BLAS dgemm.
      22              :  * \author Ole Schuett
      23              :  ******************************************************************************/
      24              : void dgemm_(const char *transa, const char *transb, const int *m, const int *n,
      25              :             const int *k, const double *alpha, const double *a, const int *lda,
      26              :             const double *b, const int *ldb, const double *beta, double *c,
      27              :             const int *ldc);
      28              : 
      29              : /*******************************************************************************
      30              :  * \brief Private convenient wrapper to hide Fortran nature of dgemm_.
      31              :  * \author Ole Schuett
      32              :  ******************************************************************************/
      33            0 : static inline void dbm_dgemm(const char transa, const char transb, const int m,
      34              :                              const int n, const int k, const double alpha,
      35              :                              const double *a, const int lda, const double *b,
      36              :                              const int ldb, const double beta, double *c,
      37              :                              const int ldc) {
      38            0 :   dgemm_(&transa, &transb, &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, c,
      39              :          &ldc);
      40              : }
      41              : 
      42              : /*******************************************************************************
      43              :  * \brief Private hash function based on Szudzik's elegant pairing.
      44              :  *        Using unsigned int to return a positive number even after overflow.
      45              :  *        https://en.wikipedia.org/wiki/Pairing_function#Other_pairing_functions
      46              :  *        https://stackoverflow.com/a/13871379
      47              :  *        http://szudzik.com/ElegantPairing.pdf
      48              :  * \author Ole Schuett
      49              :  ******************************************************************************/
      50     58042318 : static inline unsigned int hash(const dbm_task_t task) {
      51     58042318 :   const unsigned int m = task.m, n = task.n, k = task.k;
      52     58042318 :   const unsigned int mn = (m >= n) ? m * m + m + n : m + n * n;
      53     58042318 :   const unsigned int mnk = (mn >= k) ? mn * mn + mn + k : mn + k * k;
      54     58042318 :   return mnk;
      55              : }
      56              : 
      57              : /*******************************************************************************
      58              :  * \brief Internal routine for executing the tasks in given batch on the CPU.
      59              :  * \author Ole Schuett
      60              :  ******************************************************************************/
      61       263542 : void dbm_multiply_cpu_process_batch(int ntasks, const dbm_task_t batch[ntasks],
      62              :                                     double alpha, const dbm_pack_t *pack_a,
      63              :                                     const dbm_pack_t *pack_b,
      64       263542 :                                     dbm_shard_t *shard_c, int options) {
      65       263542 :   if (0 >= ntasks) { // nothing to do
      66        41853 :     return;
      67              :   }
      68       221689 :   dbm_shard_allocate_promised_blocks(shard_c);
      69              : 
      70       221689 :   int batch_order[ntasks];
      71       221689 :   if (DBM_MULTIPLY_TASK_REORDER & options) {
      72              :     // Sort tasks approximately by m,n,k via bucket sort.
      73       221689 :     int buckets[DBM_BATCH_NUM_BUCKETS] = {0};
      74     29242848 :     for (int itask = 0; itask < ntasks; ++itask) {
      75     29021159 :       const int i = hash(batch[itask]) % DBM_BATCH_NUM_BUCKETS;
      76     29021159 :       ++buckets[i];
      77              :     }
      78    221689000 :     for (int i = 1; i < DBM_BATCH_NUM_BUCKETS; ++i) {
      79    221467311 :       buckets[i] += buckets[i - 1];
      80              :     }
      81       221689 :     assert(buckets[DBM_BATCH_NUM_BUCKETS - 1] == ntasks);
      82     29242848 :     for (int itask = 0; itask < ntasks; ++itask) {
      83     29021159 :       const int i = hash(batch[itask]) % DBM_BATCH_NUM_BUCKETS;
      84     29021159 :       --buckets[i];
      85     29021159 :       batch_order[buckets[i]] = itask;
      86              :     }
      87              :   } else {
      88            0 :     for (int itask = 0; itask < ntasks; ++itask) {
      89            0 :       batch_order[itask] = itask;
      90              :     }
      91              :   }
      92              : 
      93              : #if defined(__LIBXS)
      94       221689 :   const libxs_gemm_config_t *gemm_config = NULL;
      95       221689 :   int kernel_m = 0, kernel_n = 0, kernel_k = 0;
      96              : #endif
      97              : 
      98              :   // Loop over tasks.
      99       221689 :   dbm_task_t task_next = batch[batch_order[0]];
     100     29242848 :   for (int itask = 0; itask < ntasks; ++itask) {
     101     29021159 :     const dbm_task_t task = task_next;
     102     29021159 :     task_next = batch[batch_order[(itask + 1) < ntasks ? (itask + 1) : itask]];
     103              : 
     104              : #if defined(__LIBXS)
     105     29021159 :     if (0 == (DBM_MULTIPLY_BLAS_LIBRARY & options) &&
     106     28299035 :         (task.m != kernel_m || task.n != kernel_n || task.k != kernel_k)) {
     107      1668843 :       const double beta = 1.0;
     108      1668843 :       gemm_config = libxs_gemm_dispatch(LIBXS_DATATYPE_F64, 'N', 'T', task.m,
     109              :                                         task.n, task.k, task.m, task.n, task.m,
     110              :                                         &alpha, &beta, NULL);
     111      1668843 :       kernel_m = task.m;
     112      1668843 :       kernel_n = task.n;
     113      1668843 :       kernel_k = task.k;
     114              :     }
     115              : #endif
     116              : 
     117     29021159 :     double *const data_a = pack_a->data + task.offset_a;
     118     29021159 :     double *const data_b = pack_b->data + task.offset_b;
     119     29021159 :     double *const data_c = shard_c->data + task.offset_c;
     120              : 
     121              : #if defined(__LIBXS)
     122     29021159 :     if (NULL != gemm_config) {
     123     29021159 :       libxs_gemm_call(gemm_config, data_a, data_b, data_c);
     124              :     } else
     125              : #endif
     126              :     {
     127            0 :       dbm_dgemm('N', 'T', task.m, task.n, task.k, alpha, data_a, task.m, data_b,
     128              :                 task.n, 1.0, data_c, task.m);
     129              :     }
     130              :   }
     131              : }
     132              : 
     133              : // EOF
        

Generated by: LCOV version 2.0-1