Line data Source code
1 : /*----------------------------------------------------------------------------*/
2 : /* CP2K: A general program to perform molecular dynamics simulations */
3 : /* Copyright 2000-2024 CP2K developers group <https://cp2k.org> */
4 : /* */
5 : /* SPDX-License-Identifier: BSD-3-Clause */
6 : /*----------------------------------------------------------------------------*/
7 :
8 : #include <assert.h>
9 : #include <limits.h>
10 : #include <stdlib.h>
11 : #include <string.h>
12 :
13 : #include "../offload/offload_runtime.h"
14 : #include "dbm_hyperparams.h"
15 : #include "dbm_library.h"
16 : #include "dbm_multiply.h"
17 : #include "dbm_multiply_comm.h"
18 : #include "dbm_multiply_cpu.h"
19 : #include "dbm_multiply_gpu.h"
20 : #include "dbm_multiply_internal.h"
21 :
/*******************************************************************************
 * \brief Returns the larger of two given integers (a helper that is missing
 *        from the C standard library).
 * \author Ole Schuett
 ******************************************************************************/
static inline int imax(int x, int y) {
  if (x > y) {
    return x;
  }
  return y;
}
27 :
/*******************************************************************************
 * \brief Widens a {min, max} range so that it also covers the given value.
 *        result[0] holds the minimum, result[1] the maximum; callers start
 *        from {INT_MAX, 0}.
 * \author Hans Pabst
 ******************************************************************************/
static inline void min_max(int result[2], int value) {
  int *lower = &result[0];
  int *upper = &result[1];
  if (*lower > value) {
    *lower = value;
  }
  if (value > *upper) {
    *upper = value;
  }
}
40 :
41 : /*******************************************************************************
42 : * \brief Private routine for computing the max filter threshold for each row.
43 : * \author Ole Schuett
44 : ******************************************************************************/
45 157017 : static float *compute_rows_max_eps(const bool trans, const dbm_matrix_t *matrix,
46 : const double filter_eps) {
47 157017 : const int nrows = (trans) ? matrix->ncols : matrix->nrows;
48 157017 : int *nblocks_per_row = calloc(nrows, sizeof(int));
49 157017 : float *row_max_eps = malloc(nrows * sizeof(float));
50 :
51 157017 : #pragma omp parallel
52 : {
53 : #pragma omp for
54 : for (int ishard = 0; ishard < dbm_get_num_shards(matrix); ishard++) {
55 : dbm_shard_t *shard = &matrix->shards[ishard];
56 : for (int iblock = 0; iblock < shard->nblocks; iblock++) {
57 : const dbm_block_t *blk = &shard->blocks[iblock];
58 : const int row = (trans) ? blk->col : blk->row;
59 : #pragma omp atomic
60 : nblocks_per_row[row]++;
61 : }
62 : }
63 : #pragma omp single
64 : dbm_mpi_sum_int(nblocks_per_row, nrows, matrix->dist->comm);
65 : #pragma omp barrier
66 : #pragma omp for
67 : for (int i = 0; i < nrows; i++) {
68 : const float f =
69 : ((float)filter_eps) / ((float)imax(1, nblocks_per_row[i]));
70 : row_max_eps[i] = f * f;
71 : }
72 : } // end of omp parallel region
73 :
74 157017 : free(nblocks_per_row);
75 157017 : return row_max_eps; // Ownership of row_max_eps transfers to caller.
76 : }
77 :
/*******************************************************************************
 * \brief Private struct for storing the context of the multiplication backend.
 *        Holds the GPU context when offloading is enabled; otherwise only a
 *        placeholder, because ISO C does not permit a struct without members.
 * \author Ole Schuett
 ******************************************************************************/
typedef struct {
#if defined(__OFFLOAD) && !defined(__NO_OFFLOAD_DBM)
  dbm_multiply_gpu_context_t gpu;
#else
  char unused; // Placeholder only, never read: keeps the struct non-empty.
#endif
} backend_context_t;
87 :
88 : /*******************************************************************************
89 : * \brief Private routine for intializing the multiplication backend.
90 : * \author Ole Schuett
91 : ******************************************************************************/
92 157017 : static backend_context_t *backend_start(const dbm_matrix_t *matrix_c) {
93 157017 : backend_context_t *ctx = calloc(1, sizeof(backend_context_t));
94 :
95 : #if defined(__OFFLOAD) && !defined(__NO_OFFLOAD_DBM)
96 : dbm_multiply_gpu_start(MAX_BATCH_SIZE, dbm_get_num_shards(matrix_c),
97 : matrix_c->shards, &ctx->gpu);
98 : #else
99 157017 : (void)matrix_c; // mark as used
100 : #endif
101 :
102 157017 : return ctx;
103 : }
104 :
105 : /*******************************************************************************
106 : * \brief Private routine for handing newly arrived packs to the backend.
107 : * \author Ole Schuett
108 : ******************************************************************************/
109 0 : static void backend_upload_packs(const dbm_pack_t *pack_a,
110 : const dbm_pack_t *pack_b,
111 : backend_context_t *ctx) {
112 :
113 : #if defined(__OFFLOAD) && !defined(__NO_OFFLOAD_DBM)
114 : dbm_multiply_gpu_upload_packs(pack_a, pack_b, &ctx->gpu);
115 : #else
116 0 : (void)pack_a; // mark as used
117 0 : (void)pack_b;
118 0 : (void)ctx;
119 : #endif
120 0 : }
121 :
122 : /*******************************************************************************
123 : * \brief Private routine for sending a batch to the multiplication backend.
124 : * \author Ole Schuett
125 : ******************************************************************************/
126 164044 : static void backend_process_batch(const int ntasks, dbm_task_t batch[ntasks],
127 : const int mnk_range[3][2], const double alpha,
128 : const dbm_pack_t *pack_a,
129 : const dbm_pack_t *pack_b, const int kshard,
130 : dbm_shard_t *shard_c,
131 : backend_context_t *ctx) {
132 : #if defined(__OFFLOAD) && !defined(__NO_OFFLOAD_DBM)
133 : (void)pack_a; // mark as used
134 : (void)pack_b;
135 : (void)shard_c;
136 : dbm_multiply_gpu_process_batch(ntasks, batch, mnk_range, alpha, kshard,
137 : &ctx->gpu);
138 : #else
139 164044 : (void)mnk_range; // mark as used
140 164044 : (void)kshard;
141 164044 : (void)ctx;
142 164044 : dbm_multiply_cpu_process_batch(ntasks, batch, alpha, pack_a, pack_b, shard_c);
143 : #endif
144 164044 : }
145 :
146 : /*******************************************************************************
147 : * \brief Private routine for downloading results of the multiplication backend.
148 : * \author Ole Schuett
149 : ******************************************************************************/
150 0 : static void backend_download_results(backend_context_t *ctx) {
151 : #if defined(__OFFLOAD) && !defined(__NO_OFFLOAD_DBM)
152 : dbm_multiply_gpu_download_results(&ctx->gpu);
153 : #else
154 0 : (void)ctx; // mark as used
155 : #endif
156 0 : }
157 :
158 : /*******************************************************************************
159 : * \brief Private routine for shutting down the multiplication backend.
160 : * \author Ole Schuett
161 : ******************************************************************************/
162 157017 : static void backend_stop(backend_context_t *ctx) {
163 : #if defined(__OFFLOAD) && !defined(__NO_OFFLOAD_DBM)
164 : dbm_multiply_gpu_stop(&ctx->gpu);
165 : #endif
166 157017 : free(ctx);
167 157017 : }
168 :
169 : /*******************************************************************************
170 : * \brief Private routine for multipling two packs.
171 : * \author Ole Schuett
172 : ******************************************************************************/
173 163677 : static void multiply_packs(const bool transa, const bool transb,
174 : const double alpha, const dbm_pack_t *pack_a,
175 : const dbm_pack_t *pack_b,
176 : const dbm_matrix_t *matrix_a,
177 : const dbm_matrix_t *matrix_b, dbm_matrix_t *matrix_c,
178 : const bool retain_sparsity,
179 : const float *rows_max_eps, int64_t *flop,
180 163677 : backend_context_t *ctx) {
181 163677 : const float alpha2 = alpha * alpha;
182 163677 : int64_t flop_sum = 0;
183 :
184 163677 : const int nshard_rows = matrix_c->dist->rows.nshards;
185 163677 : const int nshard_cols = matrix_c->dist->cols.nshards;
186 163677 : int shard_row_start[nshard_rows], shard_col_start[nshard_cols];
187 163677 : memset(shard_row_start, 0, nshard_rows * sizeof(int));
188 163677 : memset(shard_col_start, 0, nshard_cols * sizeof(int));
189 :
190 163677 : const int *sum_index_sizes_a =
191 : (transa) ? matrix_a->row_sizes : matrix_a->col_sizes;
192 163677 : const int *sum_index_sizes_b =
193 : (transb) ? matrix_b->col_sizes : matrix_b->row_sizes;
194 163677 : const int *free_index_sizes_a =
195 : (transa) ? matrix_a->col_sizes : matrix_a->row_sizes;
196 163677 : const int *free_index_sizes_b =
197 : (transb) ? matrix_b->row_sizes : matrix_b->col_sizes;
198 :
199 163677 : #pragma omp parallel reduction(+ : flop_sum)
200 : {
201 :
202 : // Blocks are ordered first by shard. Creating lookup tables of boundaries.
203 : #pragma omp for
204 : for (int iblock = 1; iblock < pack_a->nblocks; iblock++) {
205 : const int shard_row = pack_a->blocks[iblock].free_index % nshard_rows;
206 : const int prev_shard_row =
207 : pack_a->blocks[iblock - 1].free_index % nshard_rows;
208 : if (prev_shard_row != shard_row) {
209 : shard_row_start[shard_row] = iblock;
210 : }
211 : }
212 : #pragma omp for
213 : for (int jblock = 1; jblock < pack_b->nblocks; jblock++) {
214 : const int shard_col = pack_b->blocks[jblock].free_index % nshard_cols;
215 : const int prev_shard_col =
216 : pack_b->blocks[jblock - 1].free_index % nshard_cols;
217 : if (prev_shard_col != shard_col) {
218 : shard_col_start[shard_col] = jblock;
219 : }
220 : }
221 :
222 : #pragma omp for collapse(2) schedule(dynamic)
223 : for (int shard_row = 0; shard_row < nshard_rows; shard_row++) {
224 : for (int shard_col = 0; shard_col < nshard_cols; shard_col++) {
225 : const int ishard = shard_row * nshard_cols + shard_col;
226 : dbm_shard_t *shard_c = &matrix_c->shards[ishard];
227 : dbm_task_t batch[MAX_BATCH_SIZE];
228 : int mnk_range[][2] = {{INT_MAX, 0}, {INT_MAX, 0}, {INT_MAX, 0}};
229 : int ntasks = 0;
230 :
231 : // Use a merge-join to find pairs of blocks with matching sum indices.
232 : // This utilizes that blocks within a shard are ordered by sum_index.
233 : const int iblock_start = shard_row_start[shard_row];
234 : int jblock_start = shard_col_start[shard_col];
235 : for (int iblock = iblock_start; iblock < pack_a->nblocks; iblock++) {
236 : const dbm_pack_block_t *blk_a = &pack_a->blocks[iblock];
237 : if (blk_a->free_index % nshard_rows != shard_row) {
238 : break;
239 : }
240 : for (int jblock = jblock_start; jblock < pack_b->nblocks; jblock++) {
241 : const dbm_pack_block_t *blk_b = &pack_b->blocks[jblock];
242 : if (blk_b->free_index % nshard_cols != shard_col) {
243 : break;
244 : }
245 : if (blk_a->sum_index < blk_b->sum_index) {
246 : break;
247 : }
248 : if (blk_a->sum_index > blk_b->sum_index) {
249 : jblock_start++;
250 : continue;
251 : }
252 : // Found block pair with blk_a->sum_index == blk_b->sum_index.
253 :
254 : // Check norms.
255 : const float result_norm = alpha2 * blk_a->norm * blk_b->norm;
256 : if (result_norm < rows_max_eps[blk_a->free_index]) {
257 : continue;
258 : }
259 :
260 : // Check block sizes.
261 : const int m = free_index_sizes_a[blk_a->free_index];
262 : const int n = free_index_sizes_b[blk_b->free_index];
263 : const int k = sum_index_sizes_a[blk_a->sum_index];
264 : assert(m == matrix_c->row_sizes[blk_a->free_index]);
265 : assert(n == matrix_c->col_sizes[blk_b->free_index]);
266 : assert(k == sum_index_sizes_b[blk_b->sum_index]);
267 :
268 : // Get C block.
269 : const int row = blk_a->free_index, col = blk_b->free_index;
270 : dbm_block_t *blk_c = dbm_shard_lookup(shard_c, row, col);
271 : if (blk_c == NULL && retain_sparsity) {
272 : continue;
273 : } else if (blk_c == NULL) {
274 : assert(dbm_get_shard_index(matrix_c, row, col) == ishard);
275 : assert(dbm_get_stored_coordinates(matrix_c, row, col) ==
276 : matrix_c->dist->my_rank);
277 : blk_c = dbm_shard_promise_new_block(shard_c, row, col, m * n);
278 : }
279 :
280 : // Count flops.
281 : dbm_library_counter_increment(m, n, k);
282 : const int task_flops = 2 * m * n * k;
283 : flop_sum += task_flops;
284 : if (task_flops == 0) {
285 : continue;
286 : }
287 :
288 : // Add block multiplication to batch.
289 : batch[ntasks].m = m;
290 : batch[ntasks].n = n;
291 : batch[ntasks].k = k;
292 : batch[ntasks].offset_a = blk_a->offset;
293 : batch[ntasks].offset_b = blk_b->offset;
294 : batch[ntasks].offset_c = blk_c->offset;
295 : ntasks++;
296 :
297 : // track MxN-shape covering an entire batch
298 : min_max(mnk_range[0], m);
299 : min_max(mnk_range[1], n);
300 : min_max(mnk_range[2], k);
301 :
302 : if (ntasks == MAX_BATCH_SIZE) {
303 : backend_process_batch(ntasks, batch, mnk_range, alpha, pack_a,
304 : pack_b, ishard, shard_c, ctx);
305 : mnk_range[0][0] = mnk_range[1][0] = mnk_range[2][0] = INT_MAX;
306 : mnk_range[0][1] = mnk_range[1][1] = mnk_range[2][1] = 0;
307 : ntasks = 0;
308 : }
309 : }
310 : }
311 : backend_process_batch(ntasks, batch, mnk_range, alpha, pack_a, pack_b,
312 : ishard, shard_c, ctx);
313 : }
314 : }
315 : }
316 163677 : *flop += flop_sum;
317 163677 : }
318 :
319 : /*******************************************************************************
320 : * \brief Performs a multiplication of two dbm_matrix_t matrices.
321 : * See dbm_matrix.h for details.
322 : * \author Ole Schuett
323 : ******************************************************************************/
324 157017 : void dbm_multiply(const bool transa, const bool transb, const double alpha,
325 : const dbm_matrix_t *matrix_a, const dbm_matrix_t *matrix_b,
326 : const double beta, dbm_matrix_t *matrix_c,
327 : const bool retain_sparsity, const double filter_eps,
328 : int64_t *flop) {
329 :
330 157017 : assert(omp_get_num_threads() == 1);
331 :
332 : // Throughout the matrix multiplication code the "sum_index" and "free_index"
333 : // denote the summation (aka dummy) and free index from the Einstein notation.
334 157017 : const int num_sum_index_a = (transa) ? matrix_a->nrows : matrix_a->ncols;
335 157017 : const int num_sum_index_b = (transb) ? matrix_b->ncols : matrix_b->nrows;
336 157017 : const int num_free_index_a = (transa) ? matrix_a->ncols : matrix_a->nrows;
337 157017 : const int num_free_index_b = (transb) ? matrix_b->nrows : matrix_b->ncols;
338 :
339 : // Sanity check matrix dimensions.
340 157017 : assert(num_sum_index_a == num_sum_index_b);
341 157017 : assert(num_free_index_a == matrix_c->nrows);
342 157017 : assert(num_free_index_b == matrix_c->ncols);
343 :
344 : // Prepare matrix_c.
345 157017 : dbm_scale(matrix_c, beta);
346 :
347 : // Start uploading matrix_c to the GPU.
348 157017 : backend_context_t *ctx = backend_start(matrix_c);
349 :
350 : // Compute filter thresholds for each row.
351 157017 : float *rows_max_eps = compute_rows_max_eps(transa, matrix_a, filter_eps);
352 :
353 : // Redistribute matrix_a and matrix_b across MPI ranks.
354 157017 : dbm_comm_iterator_t *iter =
355 157017 : dbm_comm_iterator_start(transa, transb, matrix_a, matrix_b, matrix_c);
356 :
357 : // Main loop.
358 157017 : *flop = 0;
359 157017 : dbm_pack_t *pack_a, *pack_b;
360 320694 : while (dbm_comm_iterator_next(iter, &pack_a, &pack_b)) {
361 163677 : backend_upload_packs(pack_a, pack_b, ctx);
362 163677 : multiply_packs(transa, transb, alpha, pack_a, pack_b, matrix_a, matrix_b,
363 : matrix_c, retain_sparsity, rows_max_eps, flop, ctx);
364 : }
365 :
366 : // Start downloading matrix_c from the GPU.
367 157017 : backend_download_results(ctx);
368 :
369 : // Wait for all other MPI ranks to complete, then release ressources.
370 157017 : dbm_comm_iterator_stop(iter);
371 157017 : free(rows_max_eps);
372 157017 : backend_stop(ctx);
373 :
374 : // Compute average flops per rank.
375 157017 : dbm_mpi_sum_int64(flop, 1, matrix_c->dist->comm);
376 157017 : *flop = (*flop + matrix_c->dist->nranks - 1) / matrix_c->dist->nranks;
377 :
378 : // Final filter pass.
379 157017 : dbm_filter(matrix_c, filter_eps);
380 157017 : }
381 :
382 : // EOF
|