Line data Source code
1 : /*----------------------------------------------------------------------------*/
2 : /* CP2K: A general program to perform molecular dynamics simulations */
3 : /* Copyright 2000-2026 CP2K developers group <https://cp2k.org> */
4 : /* */
5 : /* SPDX-License-Identifier: BSD-3-Clause */
6 : /*----------------------------------------------------------------------------*/
7 : #include "dbm_multiply_cpu.h"
8 : #include "dbm_hyperparams.h"
9 :
10 : #include <assert.h>
11 : #include <stddef.h>
12 :
13 : #if defined(__LIBXSMM)
14 : #include <libxsmm.h>
15 : #endif
16 : #if defined(__LIBXS)
17 : #include <libxs/libxs_gemm.h>
18 : #endif
19 :
/*******************************************************************************
 * \brief Declaration of the external BLAS routine dgemm.
 *        Every argument, scalars included, is passed through a pointer.
 * \author Ole Schuett
 ******************************************************************************/
void dgemm_(const char *transa, const char *transb, const int *m,
            const int *n, const int *k, const double *alpha,
            const double *a, const int *lda, const double *b,
            const int *ldb, const double *beta, double *c, const int *ldc);
28 :
/*******************************************************************************
 * \brief Private convenience wrapper around dgemm_ that accepts its scalar
 *        arguments by value and forwards their addresses.
 * \author Ole Schuett
 ******************************************************************************/
static inline void dbm_dgemm(const char transa, const char transb, const int m,
                             const int n, const int k, const double alpha,
                             const double *a, const int lda, const double *b,
                             const int ldb, const double beta, double *c,
                             const int ldc) {
  // dgemm_ takes everything by pointer; the by-value parameters above are
  // local copies, so handing out their addresses is safe for this call.
  dgemm_(&transa, &transb,  // transpose flags
         &m, &n, &k,        // problem dimensions
         &alpha, a, &lda,   // scaling factor and first operand
         b, &ldb,           // second operand
         &beta, c, &ldc);   // scaling factor and result operand
}
41 :
42 : /*******************************************************************************
43 : * \brief Private hash function based on Szudzik's elegant pairing.
44 : * Using unsigned int to return a positive number even after overflow.
45 : * https://en.wikipedia.org/wiki/Pairing_function#Other_pairing_functions
46 : * https://stackoverflow.com/a/13871379
47 : * http://szudzik.com/ElegantPairing.pdf
48 : * \author Ole Schuett
49 : ******************************************************************************/
50 58042318 : static inline unsigned int hash(const dbm_task_t task) {
51 58042318 : const unsigned int m = task.m, n = task.n, k = task.k;
52 58042318 : const unsigned int mn = (m >= n) ? m * m + m + n : m + n * n;
53 58042318 : const unsigned int mnk = (mn >= k) ? mn * mn + mn + k : mn + k * k;
54 58042318 : return mnk;
55 : }
56 :
57 : /*******************************************************************************
58 : * \brief Internal routine for executing the tasks in given batch on the CPU.
59 : * \author Ole Schuett
60 : ******************************************************************************/
61 263542 : void dbm_multiply_cpu_process_batch(int ntasks, const dbm_task_t batch[ntasks],
62 : double alpha, const dbm_pack_t *pack_a,
63 : const dbm_pack_t *pack_b,
64 263542 : dbm_shard_t *shard_c, int options) {
65 263542 : if (0 >= ntasks) { // nothing to do
66 41853 : return;
67 : }
68 221689 : dbm_shard_allocate_promised_blocks(shard_c);
69 :
70 221689 : int batch_order[ntasks];
71 221689 : if (DBM_MULTIPLY_TASK_REORDER & options) {
72 : // Sort tasks approximately by m,n,k via bucket sort.
73 221689 : int buckets[DBM_BATCH_NUM_BUCKETS] = {0};
74 29242848 : for (int itask = 0; itask < ntasks; ++itask) {
75 29021159 : const int i = hash(batch[itask]) % DBM_BATCH_NUM_BUCKETS;
76 29021159 : ++buckets[i];
77 : }
78 221689000 : for (int i = 1; i < DBM_BATCH_NUM_BUCKETS; ++i) {
79 221467311 : buckets[i] += buckets[i - 1];
80 : }
81 221689 : assert(buckets[DBM_BATCH_NUM_BUCKETS - 1] == ntasks);
82 29242848 : for (int itask = 0; itask < ntasks; ++itask) {
83 29021159 : const int i = hash(batch[itask]) % DBM_BATCH_NUM_BUCKETS;
84 29021159 : --buckets[i];
85 29021159 : batch_order[buckets[i]] = itask;
86 : }
87 : } else {
88 0 : for (int itask = 0; itask < ntasks; ++itask) {
89 0 : batch_order[itask] = itask;
90 : }
91 : }
92 :
93 : #if defined(__LIBXS)
94 221689 : const libxs_gemm_config_t *gemm_config = NULL;
95 221689 : int kernel_m = 0, kernel_n = 0, kernel_k = 0;
96 : #endif
97 :
98 : // Loop over tasks.
99 221689 : dbm_task_t task_next = batch[batch_order[0]];
100 29242848 : for (int itask = 0; itask < ntasks; ++itask) {
101 29021159 : const dbm_task_t task = task_next;
102 29021159 : task_next = batch[batch_order[(itask + 1) < ntasks ? (itask + 1) : itask]];
103 :
104 : #if defined(__LIBXS)
105 29021159 : if (0 == (DBM_MULTIPLY_BLAS_LIBRARY & options) &&
106 28299035 : (task.m != kernel_m || task.n != kernel_n || task.k != kernel_k)) {
107 1668843 : const double beta = 1.0;
108 1668843 : gemm_config = libxs_gemm_dispatch(LIBXS_DATATYPE_F64, 'N', 'T', task.m,
109 : task.n, task.k, task.m, task.n, task.m,
110 : &alpha, &beta, NULL);
111 1668843 : kernel_m = task.m;
112 1668843 : kernel_n = task.n;
113 1668843 : kernel_k = task.k;
114 : }
115 : #endif
116 :
117 29021159 : double *const data_a = pack_a->data + task.offset_a;
118 29021159 : double *const data_b = pack_b->data + task.offset_b;
119 29021159 : double *const data_c = shard_c->data + task.offset_c;
120 :
121 : #if defined(__LIBXS)
122 29021159 : if (NULL != gemm_config) {
123 29021159 : libxs_gemm_call(gemm_config, data_a, data_b, data_c);
124 : } else
125 : #endif
126 : {
127 0 : dbm_dgemm('N', 'T', task.m, task.n, task.k, alpha, data_a, task.m, data_b,
128 : task.n, 1.0, data_c, task.m);
129 : }
130 : }
131 : }
132 :
133 : // EOF
|