Line data Source code
1 : /*----------------------------------------------------------------------------*/
2 : /* CP2K: A general program to perform molecular dynamics simulations */
3 : /* Copyright 2000-2025 CP2K developers group <https://cp2k.org> */
4 : /* */
5 : /* SPDX-License-Identifier: BSD-3-Clause */
6 : /*----------------------------------------------------------------------------*/
7 : #include "dbm_multiply_cpu.h"
8 : #include "dbm_hyperparams.h"
9 :
10 : #include <assert.h>
11 : #include <stddef.h>
12 : #include <string.h>
13 :
14 : #if defined(__LIBXSMM)
15 : #include <libxsmm.h>
16 : #if !defined(DBM_LIBXSMM_PREFETCH)
17 : // #define DBM_LIBXSMM_PREFETCH LIBXSMM_GEMM_PREFETCH_AL2_AHEAD
18 : #define DBM_LIBXSMM_PREFETCH LIBXSMM_GEMM_PREFETCH_NONE
19 : #endif
20 : #if LIBXSMM_VERSION4(1, 17, 0, 3710) > LIBXSMM_VERSION_NUMBER
21 : #define libxsmm_dispatch_gemm libxsmm_dispatch_gemm_v2
22 : #endif
23 : #endif
24 :
/*******************************************************************************
 * \brief Prototype for the double precision BLAS routine dgemm.
 *        Every argument, scalars included, is taken by address; the
 *        arguments are grouped below by the operand they describe.
 * \author Ole Schuett
 ******************************************************************************/
void dgemm_(const char *transa, const char *transb, // transposition flags
            const int *m, const int *n, const int *k, // problem dimensions
            const double *alpha,                      // scaling of the product
            const double *a, const int *lda,          // matrix A
            const double *b, const int *ldb,          // matrix B
            const double *beta,                       // scaling of C
            double *c, const int *ldc);               // matrix C (in/out)
33 :
/*******************************************************************************
 * \brief Private by-value front end for dgemm_.
 *        Callers pass flags, dimensions and scalars as plain values; this
 *        wrapper forwards their addresses, since dgemm_ takes every
 *        argument through a pointer.
 * \author Ole Schuett
 ******************************************************************************/
static inline void dbm_dgemm(const char transa, const char transb, const int m,
                             const int n, const int k, const double alpha,
                             const double *a, const int lda, const double *b,
                             const int ldb, const double beta, double *c,
                             const int ldc) {

  dgemm_(&transa, &transb, // transposition flags
         &m, &n, &k,       // problem dimensions
         &alpha, a, &lda,  // scaling factor and matrix A
         b, &ldb,          // matrix B
         &beta, c, &ldc);  // scaling of C and matrix C itself
}
47 :
48 : /*******************************************************************************
49 : * \brief Private hash function based on Szudzik's elegant pairing.
50 : * Using unsigned int to return a positive number even after overflow.
51 : * https://en.wikipedia.org/wiki/Pairing_function#Other_pairing_functions
52 : * https://stackoverflow.com/a/13871379
53 : * http://szudzik.com/ElegantPairing.pdf
54 : * \author Ole Schuett
55 : ******************************************************************************/
56 55882708 : static inline unsigned int hash(const dbm_task_t task) {
57 55882708 : const unsigned int m = task.m, n = task.n, k = task.k;
58 55882708 : const unsigned int mn = (m >= n) ? m * m + m + n : m + n * n;
59 55882708 : const unsigned int mnk = (mn >= k) ? mn * mn + mn + k : mn + k * k;
60 55882708 : return mnk;
61 : }
62 :
63 : /*******************************************************************************
64 : * \brief Internal routine for executing the tasks in given batch on the CPU.
65 : * \author Ole Schuett
66 : ******************************************************************************/
67 232460 : void dbm_multiply_cpu_process_batch(int ntasks, const dbm_task_t batch[ntasks],
68 : double alpha, const dbm_pack_t *pack_a,
69 : const dbm_pack_t *pack_b,
70 232460 : dbm_shard_t *shard_c, int options) {
71 :
72 232460 : if (0 >= ntasks) { // nothing to do
73 35875 : return;
74 : }
75 196585 : dbm_shard_allocate_promised_blocks(shard_c);
76 :
77 196585 : int batch_order[ntasks];
78 196585 : if (DBM_MULTIPLY_TASK_REORDER & options) {
79 : // Sort tasks approximately by m,n,k via bucket sort.
80 196585 : int buckets[DBM_BATCH_NUM_BUCKETS] = {0};
81 28137939 : for (int itask = 0; itask < ntasks; ++itask) {
82 27941354 : const int i = hash(batch[itask]) % DBM_BATCH_NUM_BUCKETS;
83 27941354 : ++buckets[i];
84 : }
85 196585000 : for (int i = 1; i < DBM_BATCH_NUM_BUCKETS; ++i) {
86 196388415 : buckets[i] += buckets[i - 1];
87 : }
88 196585 : assert(buckets[DBM_BATCH_NUM_BUCKETS - 1] == ntasks);
89 28137939 : for (int itask = 0; itask < ntasks; ++itask) {
90 27941354 : const int i = hash(batch[itask]) % DBM_BATCH_NUM_BUCKETS;
91 27941354 : --buckets[i];
92 27941354 : batch_order[buckets[i]] = itask;
93 : }
94 : } else {
95 0 : for (int itask = 0; itask < ntasks; ++itask) {
96 0 : batch_order[itask] = itask;
97 : }
98 : }
99 :
100 : #if defined(__LIBXSMM)
101 : // Prepare arguments for libxsmm's kernel-dispatch.
102 196585 : const int flags = LIBXSMM_GEMM_FLAG_TRANS_B; // transa = "N", transb = "T"
103 196585 : const int prefetch = DBM_LIBXSMM_PREFETCH;
104 196585 : int kernel_m = 0, kernel_n = 0, kernel_k = 0;
105 : #if (LIBXSMM_GEMM_PREFETCH_NONE != DBM_LIBXSMM_PREFETCH)
106 : double *data_a_next = NULL, *data_b_next = NULL, *data_c_next = NULL;
107 : #endif
108 : #if LIBXSMM_VERSION2(1, 17) < LIBXSMM_VERSION_NUMBER
109 196585 : libxsmm_gemmfunction kernel_func = NULL;
110 : #else
111 : libxsmm_dmmfunction kernel_func = NULL;
112 : const double beta = 1.0;
113 : #endif
114 : #endif
115 :
116 : // Loop over tasks.
117 196585 : dbm_task_t task_next = batch[batch_order[0]];
118 28137939 : for (int itask = 0; itask < ntasks; ++itask) {
119 27941354 : const dbm_task_t task = task_next;
120 27941354 : task_next = batch[batch_order[(itask + 1) < ntasks ? (itask + 1) : itask]];
121 :
122 : #if defined(__LIBXSMM)
123 27941354 : if (0 == (DBM_MULTIPLY_BLAS_LIBRARY & options) &&
124 27254844 : (task.m != kernel_m || task.n != kernel_n || task.k != kernel_k)) {
125 1613684 : if (LIBXSMM_SMM(task.m, task.n, task.m, 1 /*assume in-$, no RFO*/,
126 : sizeof(double))) {
127 : #if LIBXSMM_VERSION2(1, 17) < LIBXSMM_VERSION_NUMBER
128 1568857 : const libxsmm_gemm_shape shape = libxsmm_create_gemm_shape(
129 : task.m, task.n, task.k, task.m /*lda*/, task.n /*ldb*/,
130 : task.m /*ldc*/, LIBXSMM_DATATYPE_F64 /*aprec*/,
131 : LIBXSMM_DATATYPE_F64 /*bprec*/, LIBXSMM_DATATYPE_F64 /*cprec*/,
132 : LIBXSMM_DATATYPE_F64 /*calcp*/);
133 1568857 : kernel_func =
134 : (LIBXSMM_FEQ(1.0, alpha)
135 1272297 : ? libxsmm_dispatch_gemm(shape, (libxsmm_bitfield)flags,
136 : (libxsmm_bitfield)prefetch)
137 1568857 : : NULL);
138 : #else
139 : kernel_func = libxsmm_dmmdispatch(task.m, task.n, task.k, NULL /*lda*/,
140 : NULL /*ldb*/, NULL /*ldc*/, &alpha,
141 : &beta, &flags, &prefetch);
142 : #endif
143 : } else {
144 : kernel_func = NULL;
145 : }
146 : kernel_m = task.m;
147 : kernel_n = task.n;
148 : kernel_k = task.k;
149 : }
150 : #endif
151 : // gemm_param wants non-const data even for A and B
152 27941354 : double *const data_a = pack_a->data + task.offset_a;
153 27941354 : double *const data_b = pack_b->data + task.offset_b;
154 27941354 : double *const data_c = shard_c->data + task.offset_c;
155 :
156 : #if defined(__LIBXSMM)
157 27941354 : if (kernel_func != NULL) {
158 : #if LIBXSMM_VERSION2(1, 17) < LIBXSMM_VERSION_NUMBER
159 22546350 : libxsmm_gemm_param gemm_param;
160 22546350 : gemm_param.a.primary = data_a;
161 22546350 : gemm_param.b.primary = data_b;
162 22546350 : gemm_param.c.primary = data_c;
163 : #if (LIBXSMM_GEMM_PREFETCH_NONE != DBM_LIBXSMM_PREFETCH)
164 : gemm_param.a.quaternary = pack_a->data + task_next.offset_a;
165 : gemm_param.b.quaternary = pack_b->data + task_next.offset_b;
166 : gemm_param.c.quaternary = shard_c->data + task_next.offset_c;
167 : #endif
168 22546350 : kernel_func(&gemm_param);
169 : #elif (LIBXSMM_GEMM_PREFETCH_NONE != DBM_LIBXSMM_PREFETCH)
170 : kernel_func(data_a, data_b, data_c, pack_a->data + task_next.offset_a,
171 : pack_b->data + task_next.offset_b,
172 : shard_c->data + task_next.offset_c);
173 : #else
174 : kernel_func(data_a, data_b, data_c);
175 : #endif
176 : } else
177 : #endif
178 : { // Fallback to BLAS when libxsmm is not available.
179 5395004 : dbm_dgemm('N', 'T', task.m, task.n, task.k, alpha, data_a, task.m, data_b,
180 : task.n, 1.0, data_c, task.m);
181 : }
182 : }
183 : }
184 :
185 : // EOF
|