Line data Source code
1 : /*----------------------------------------------------------------------------*/
2 : /* CP2K: A general program to perform molecular dynamics simulations */
3 : /* Copyright 2000-2026 CP2K developers group <https://cp2k.org> */
4 : /* */
5 : /* SPDX-License-Identifier: BSD-3-Clause */
6 : /*----------------------------------------------------------------------------*/
7 : #include "dbm_multiply_comm.h"
8 : #include "../mpiwrap/cp_mpi.h"
9 : #include "../offload/offload_mempool.h"
10 :
11 : #include <assert.h>
12 : #include <limits.h>
13 : #include <stdlib.h>
14 : #include <string.h>
15 :
// Compile-time switch for the allocator of the packs' data buffers: while the
// macro is defined the buffers are obtained from offload_mempool_host_malloc
// and released with offload_mempool_host_free, otherwise cp_mpi_alloc_mem and
// cp_mpi_free_mem are used. Changing "#if 1" to "#if 0" disables the mempool.
16 : #if 1
17 : #define DBM_MULTIPLY_COMM_MEMPOOL
18 : #endif
19 :
/*******************************************************************************
 * \brief Private routine for computing greatest common divisor of two numbers.
 *        If one argument is zero the other one is returned, hence
 *        gcd(0, 0) == 0.
 * \author Ole Schuett
 ******************************************************************************/
static int gcd(const int a, const int b) {
  if (a == 0) {
    return b;
  }
  return gcd(b % a, a); // Euclid's algorithm.
}

/*******************************************************************************
 * \brief Private routine for computing least common multiple of two numbers.
 *        Returns zero if both arguments are zero.
 * \author Ole Schuett
 ******************************************************************************/
static int lcm(const int a, const int b) {
  const int divisor = gcd(a, b);
  if (divisor == 0) {
    return 0; // Only possible for a == b == 0; avoids a division by zero.
  }
  // Divide first: a / divisor is exact, so the intermediate value never
  // exceeds the final result, whereas a * b may overflow an int.
  return (a / divisor) * b;
}
36 :
/*******************************************************************************
 * \brief Private routine for converting element counts to byte counts.
 *        Asserts that the count is non-negative and that the byte count fits
 *        into an int.
 * \author Hans Pabst
 ******************************************************************************/
static int checked_byte_count(const int nelements, const size_t element_size) {
  // Validate before multiplying: the product has to be representable as int.
  assert(0 <= nelements);
  assert(element_size <= INT_MAX);
  assert(nelements <= INT_MAX / (int)element_size);
  const int nbytes = (int)element_size * nelements;
  return nbytes;
}
47 :
/*******************************************************************************
 * \brief Private routine for computing the sum of the given integers.
 *        Adds up the first n entries of input; yields zero if n is not
 *        positive.
 * \author Ole Schuett
 ******************************************************************************/
static inline int isum(const int n, const int input[n]) {
  int total = 0;
  int idx = 0;
  while (idx < n) {
    total += input[idx];
    idx++;
  }
  return total;
}
59 :
/*******************************************************************************
 * \brief Private routine for computing the cumulative sums of given numbers.
 *        Writes the exclusive prefix sums, i.e. output[0] = 0 and
 *        output[i] = input[0] + ... + input[i - 1]. The last input entry does
 *        not contribute. Nothing is read or written if n is not positive.
 * \author Ole Schuett and Hans Pabst
 ******************************************************************************/
static inline void icumsum(const int n, const int input[n], int output[n]) {
  if (n <= 0) {
    return; // Empty arrays: neither input[0] nor output[0] may be accessed.
  }
  // oval carries the running sum and ival the entry that is added next.
  int oval = output[0] = 0, ival = input[0];
  for (int i = 1; i < n; i++) {
    output[i] = (oval += ival);
    ival = input[i];
  }
}
71 :
72 : /*******************************************************************************
73 : * \brief Private routine computing received data counts from block metadata.
74 : * \author Hans Pabst
75 : ******************************************************************************/
76 505522 : static void compute_data_recv_count(const int nranks,
77 : const int blks_recv_count[nranks],
78 : const int blks_recv_displ[nranks],
79 : const int free_index_sizes[],
80 : const int sum_index_sizes[],
81 : const dbm_pack_block_t blks_recv[],
82 : int data_recv_count[nranks]) {
83 505522 : memset(data_recv_count, 0, nranks * sizeof(int));
84 1075262 : for (int irank = 0; irank < nranks; irank++) {
85 14386998 : for (int i = 0; i < blks_recv_count[irank]; i++) {
86 13817258 : const dbm_pack_block_t *const blk =
87 13817258 : &blks_recv[blks_recv_displ[irank] + i];
88 13817258 : const int block_size =
89 13817258 : free_index_sizes[blk->free_index] * sum_index_sizes[blk->sum_index];
90 13817258 : assert(block_size >= 0);
91 13817258 : assert(data_recv_count[irank] <= INT_MAX - block_size);
92 13817258 : data_recv_count[irank] += block_size;
93 : }
94 : }
95 505522 : }
96 :
97 : /*******************************************************************************
98 : * \brief Private struct used for planning during pack_matrix.
 *        One plan describes a single block that is going to be sent: it is
 *        filled in by create_pack_plans and consumed by fill_send_buffers.
99 : * \author Ole Schuett
100 : ******************************************************************************/
101 : typedef struct {
102 : const dbm_block_t *blk; // source block
103 : int rank; // target mpi rank
// Dimensions of the source block, taken from the matrix's row_sizes[blk->row]
// and col_sizes[blk->col]; their product is the block's number of elements.
104 : int row_size;
105 : int col_size;
106 : } plan_t;
107 :
/*******************************************************************************
 * \brief Private routine for calculating tick indices in pack plans.
 *        The sum index is scrambled by a multiplication and then reduced
 *        modulo nticks. The arithmetic is carried out in unsigned long long.
 * \author Maximilian Graml
 ******************************************************************************/
static inline unsigned long long calculate_tick_index(int sum_index,
                                                      int nticks) {
  // 1021 is used as a random prime to scramble the index
  const unsigned long long scramble_prime = 1021ULL;
  const unsigned long long modulus = (unsigned long long)nticks;
  const unsigned long long scrambled =
      (unsigned long long)sum_index * scramble_prime;
  return scrambled % modulus;
}
117 :
118 : /*******************************************************************************
119 : * \brief Private routine for planning packs.
 *        Assigns every block of the matrix to a pack and to a target rank.
 *        On return plans_per_pack[ipack] points to a malloc'ed array holding
 *        nblks_per_pack[ipack] plans (pack_matrix frees these arrays), and
 *        ndata_per_pack[ipack] is the summed row_size * col_size of the
 *        blocks in that pack.
120 : * \author Ole Schuett
121 : ******************************************************************************/
122 484116 : static void create_pack_plans(const bool trans_matrix, const bool trans_dist,
123 : const dbm_matrix_t *matrix,
124 : const cp_mpi_comm_t comm,
125 : const dbm_dist_1d_t *dist_indices,
126 : const dbm_dist_1d_t *dist_ticks, const int nticks,
127 : const int npacks, plan_t *plans_per_pack[npacks],
128 : int nblks_per_pack[npacks],
129 : int ndata_per_pack[npacks]) {
130 484116 : memset(nblks_per_pack, 0, npacks * sizeof(int));
131 484116 : memset(ndata_per_pack, 0, npacks * sizeof(int));
132 :
133 484116 : #pragma omp parallel
134 : {
135 : // 1st pass: Compute number of blocks that will be sent in each pack.
136 : int nblks_mythread[npacks];
137 : memset(nblks_mythread, 0, npacks * sizeof(int));
138 : #pragma omp for schedule(static)
139 : for (int ishard = 0; ishard < dbm_get_num_shards(matrix); ishard++) {
140 : dbm_shard_t *shard = &matrix->shards[ishard];
141 : for (int iblock = 0; iblock < shard->nblocks; iblock++) {
142 : const dbm_block_t *blk = &shard->blocks[iblock];
143 : const int sum_index = (trans_matrix) ? blk->row : blk->col;
144 : unsigned long long itick64 = calculate_tick_index(sum_index, nticks);
// The quotient selects the pack; the remainder (used in the 2nd pass) is the
// coordinate along the ticks dimension.
145 : const int ipack = itick64 / dist_ticks->nranks;
146 : nblks_mythread[ipack]++;
147 : }
148 : }
149 :
150 : // Sum nblocks across threads and allocate arrays for plans.
// Inside the critical section each thread also replaces its own count by the
// running total, which becomes the end (exclusive) of the thread's slice within
// plans_per_pack[ipack].
151 : #pragma omp critical
152 : for (int ipack = 0; ipack < npacks; ipack++) {
153 : nblks_per_pack[ipack] += nblks_mythread[ipack];
154 : nblks_mythread[ipack] = nblks_per_pack[ipack];
155 : }
// All threads must have contributed before the totals are used for allocation.
156 : #pragma omp barrier
157 : #pragma omp for
158 : for (int ipack = 0; ipack < npacks; ipack++) {
159 : const int nblks = nblks_per_pack[ipack];
160 : plans_per_pack[ipack] = malloc(nblks * sizeof(plan_t));
161 : assert(plans_per_pack[ipack] != NULL || nblks == 0);
162 : }
163 :
164 : // 2nd pass: Plan where to send each block.
165 : int ndata_mythread[npacks];
166 : memset(ndata_mythread, 0, npacks * sizeof(int));
167 : #pragma omp for schedule(static) // Need static to match previous loop.
168 : for (int ishard = 0; ishard < dbm_get_num_shards(matrix); ishard++) {
169 : dbm_shard_t *shard = &matrix->shards[ishard];
170 : for (int iblock = 0; iblock < shard->nblocks; iblock++) {
171 : const dbm_block_t *blk = &shard->blocks[iblock];
172 : const int free_index = (trans_matrix) ? blk->col : blk->row;
173 : const int sum_index = (trans_matrix) ? blk->row : blk->col;
174 : unsigned long long itick64 = calculate_tick_index(sum_index, nticks);
175 : const int ipack = itick64 / dist_ticks->nranks;
176 : // Compute rank to which this block should be sent.
177 : const int coord_free_idx = dist_indices->index2coord[free_index];
178 : const int coord_sum_idx = itick64 % dist_ticks->nranks;
// trans_dist swaps which of the two cart coordinates comes first.
179 : const int coords[2] = {(trans_dist) ? coord_sum_idx : coord_free_idx,
180 : (trans_dist) ? coord_free_idx : coord_sum_idx};
181 : const int rank = cp_mpi_cart_rank(comm, coords);
182 : const int row_size = matrix->row_sizes[blk->row];
183 : const int col_size = matrix->col_sizes[blk->col];
184 : ndata_mythread[ipack] += row_size * col_size;
185 : // Create plan.
// The thread's slice is filled from its end towards its start.
186 : const int iplan = --nblks_mythread[ipack];
187 : plans_per_pack[ipack][iplan].blk = blk;
188 : plans_per_pack[ipack][iplan].rank = rank;
189 : plans_per_pack[ipack][iplan].row_size = row_size;
190 : plans_per_pack[ipack][iplan].col_size = col_size;
191 : }
192 : }
// Sum the per-thread data sizes into the shared result.
193 : #pragma omp critical
194 : for (int ipack = 0; ipack < npacks; ipack++) {
195 : ndata_per_pack[ipack] += ndata_mythread[ipack];
196 : }
197 : } // end of omp parallel region
198 484116 : }
199 :
200 : /*******************************************************************************
201 : * \brief Private routine for filling send buffers.
 *        Copies the blocks described by plans into blks_send and data_send,
 *        grouped by target rank, and computes the per-rank counts and
 *        displacements of both buffers. If trans_matrix is set, each block is
 *        transposed while it is copied and its row/col indices are swapped
 *        into free_index/sum_index accordingly.
202 : * \author Ole Schuett
203 : ******************************************************************************/
204 505522 : static void fill_send_buffers(
205 : const dbm_matrix_t *matrix, const bool trans_matrix, const int nblks_send,
206 : const int ndata_send, plan_t plans[nblks_send], const int nranks,
207 : int blks_send_count[nranks], int data_send_count[nranks],
208 : int blks_send_displ[nranks], int data_send_displ[nranks],
209 : dbm_pack_block_t blks_send[nblks_send], double data_send[ndata_send]) {
210 505522 : memset(blks_send_count, 0, nranks * sizeof(int));
211 505522 : memset(data_send_count, 0, nranks * sizeof(int));
212 :
213 505522 : #pragma omp parallel
214 : {
215 : // 3rd pass: Compute per rank nblks and ndata.
216 : int nblks_mythread[nranks], ndata_mythread[nranks];
217 : memset(nblks_mythread, 0, nranks * sizeof(int));
218 : memset(ndata_mythread, 0, nranks * sizeof(int));
219 : #pragma omp for schedule(static)
220 : for (int iblock = 0; iblock < nblks_send; iblock++) {
221 : const plan_t *plan = &plans[iblock];
222 : nblks_mythread[plan->rank] += 1;
223 : ndata_mythread[plan->rank] += plan->row_size * plan->col_size;
224 : }
225 :
226 : // Sum nblks and ndata across threads.
// Each thread also keeps the running totals, which mark the end (exclusive) of
// its own slice within the part of the buffers that belongs to a rank.
227 : #pragma omp critical
228 : for (int irank = 0; irank < nranks; irank++) {
229 : blks_send_count[irank] += nblks_mythread[irank];
230 : data_send_count[irank] += ndata_mythread[irank];
231 : nblks_mythread[irank] = blks_send_count[irank];
232 : ndata_mythread[irank] = data_send_count[irank];
233 : }
234 : #pragma omp barrier
235 :
236 : // Compute send displacements.
237 : #pragma omp single
238 : {
239 : icumsum(nranks, blks_send_count, blks_send_displ);
240 : icumsum(nranks, data_send_count, data_send_displ);
// Consistency check: the last rank's section must end exactly at the totals
// passed in by the caller.
241 : const int m = nranks - 1;
242 : assert(nblks_send == blks_send_displ[m] + blks_send_count[m]);
243 : assert(ndata_send == data_send_displ[m] + data_send_count[m]);
244 : }
245 : #pragma omp barrier
246 :
247 : // 4th pass: Fill blks_send and data_send arrays.
248 : #pragma omp for schedule(static) // Need static to match previous loop.
249 : for (int iblock = 0; iblock < nblks_send; iblock++) {
250 : const plan_t *const plan = &plans[iblock];
251 : const dbm_block_t *const blk = plan->blk;
252 : const int ishard = dbm_get_shard_index(matrix, blk->row, blk->col);
253 : const dbm_shard_t *const shard = &matrix->shards[ishard];
254 : const double *blk_data = &shard->data[blk->offset];
255 : const int row_size = plan->row_size, col_size = plan->col_size;
256 : const int plan_size = row_size * col_size;
257 : const int irank = plan->rank;
258 :
259 : // The blk_send_data is ordered by rank, thread, and block.
260 : // data_send_displ[irank]: Start of data for irank within blk_send_data.
261 : // ndata_mythread[irank]: Current threads offset within data for irank.
// Both per-thread counters are decremented first, so the thread's slice is
// filled from its end towards its start.
262 : nblks_mythread[irank] -= 1;
263 : ndata_mythread[irank] -= plan_size;
264 : const int offset = data_send_displ[irank] + ndata_mythread[irank];
265 : const int jblock = blks_send_displ[irank] + nblks_mythread[irank];
266 :
// The value accumulated below is the sum of the squared elements (no square
// root is taken).
267 : double norm = 0.0; // Compute norm as double...
268 : if (trans_matrix) {
269 : // Transpose block to allow for outer-product style multiplication.
270 : for (int i = 0; i < row_size; i++) {
271 : for (int j = 0; j < col_size; j++) {
272 : const double element = blk_data[j * row_size + i];
273 : data_send[offset + i * col_size + j] = element;
274 : norm += element * element;
275 : }
276 : }
277 : blks_send[jblock].free_index = plan->blk->col;
278 : blks_send[jblock].sum_index = plan->blk->row;
279 : } else {
// No transpose needed: plain element-wise copy.
280 : for (int i = 0; i < plan_size; i++) {
281 : const double element = blk_data[i];
282 : data_send[offset + i] = element;
283 : norm += element * element;
284 : }
285 : blks_send[jblock].free_index = plan->blk->row;
286 : blks_send[jblock].sum_index = plan->blk->col;
287 : }
288 : blks_send[jblock].norm = (float)norm; // ...store norm as float.
289 :
290 : // After the block exchange data_recv_displ will be added to the offsets.
// Hence the offset stored here is relative to the start of irank's data.
291 : blks_send[jblock].offset = offset - data_send_displ[irank];
292 : }
293 : } // end of omp parallel region
294 505522 : }
295 :
296 : /*******************************************************************************
297 : * \brief Private comperator passed to qsort to compare two blocks by sum_index.
298 : * \author Ole Schuett
299 : ******************************************************************************/
300 75113207 : static int compare_pack_blocks_by_sum_index(const void *a, const void *b) {
301 75113207 : const dbm_pack_block_t *blk_a = (const dbm_pack_block_t *)a;
302 75113207 : const dbm_pack_block_t *blk_b = (const dbm_pack_block_t *)b;
303 75113207 : return blk_a->sum_index - blk_b->sum_index;
304 : }
305 :
306 : /*******************************************************************************
307 : * \brief Private routine for post-processing received blocks.
 *        Three steps are applied to blks_recv in place: the offsets are made
 *        absolute by adding the sending rank's data_recv_displ, the blocks are
 *        grouped by free_index % nshards using a counting sort, and each
 *        group is sorted by sum_index with qsort.
308 : * \author Ole Schuett
309 : ******************************************************************************/
310 505522 : static void postprocess_received_blocks(
311 : const int nranks, const int nshards, const int nblocks_recv,
312 : const int blks_recv_count[nranks], const int blks_recv_displ[nranks],
313 : const int data_recv_displ[nranks],
314 505522 : dbm_pack_block_t blks_recv[nblocks_recv]) {
315 505522 : int nblocks_per_shard[nshards], shard_start[nshards];
316 505522 : memset(nblocks_per_shard, 0, nshards * sizeof(int));
// Scratch copy of the blocks used as the source of the counting sort; it is
// freed at the end of this routine.
317 505522 : dbm_pack_block_t *blocks_tmp =
318 505522 : malloc(nblocks_recv * sizeof(dbm_pack_block_t));
319 505522 : assert(blocks_tmp != NULL || nblocks_recv == 0);
320 :
321 505522 : #pragma omp parallel
322 : {
323 : // Add data_recv_displ to received block offsets.
// Every thread runs the loop over ranks; only the inner loop is work-shared.
324 : for (int irank = 0; irank < nranks; irank++) {
325 : #pragma omp for
326 : for (int i = 0; i < blks_recv_count[irank]; i++) {
327 : blks_recv[blks_recv_displ[irank] + i].offset += data_recv_displ[irank];
328 : }
329 : }
330 :
331 : // First use counting sort to group blocks by their free_index shard.
332 : int nblocks_mythread[nshards];
333 : memset(nblocks_mythread, 0, nshards * sizeof(int));
334 : #pragma omp for schedule(static)
335 : for (int iblock = 0; iblock < nblocks_recv; iblock++) {
336 : blocks_tmp[iblock] = blks_recv[iblock];
337 : const int ishard = blks_recv[iblock].free_index % nshards;
338 : nblocks_mythread[ishard]++;
339 : }
// Sum the counts across threads; each thread keeps the running total as the
// end (exclusive) of its own slice within a shard's group.
340 : #pragma omp critical
341 : for (int ishard = 0; ishard < nshards; ishard++) {
342 : nblocks_per_shard[ishard] += nblocks_mythread[ishard];
343 : nblocks_mythread[ishard] = nblocks_per_shard[ishard];
344 : }
345 : #pragma omp barrier
// Turn the per-shard counts into the start index of each shard's group.
346 : #pragma omp single
347 : icumsum(nshards, nblocks_per_shard, shard_start);
348 : #pragma omp barrier
349 : #pragma omp for schedule(static) // Need static to match previous loop.
350 : for (int iblock = 0; iblock < nblocks_recv; iblock++) {
351 : const int ishard = blocks_tmp[iblock].free_index % nshards;
// Pre-decrement: the thread's slice is filled from its end towards its start.
352 : const int jblock = --nblocks_mythread[ishard] + shard_start[ishard];
353 : blks_recv[jblock] = blocks_tmp[iblock];
354 : }
355 :
356 : // Then sort blocks within each shard by their sum_index.
357 : #pragma omp for
358 : for (int ishard = 0; ishard < nshards; ishard++) {
// Groups with fewer than two blocks are already sorted.
359 : if (nblocks_per_shard[ishard] > 1) {
360 : qsort(&blks_recv[shard_start[ishard]], nblocks_per_shard[ishard],
361 : sizeof(dbm_pack_block_t), &compare_pack_blocks_by_sum_index);
362 : }
363 : }
364 : } // end of omp parallel region
365 :
366 505522 : free(blocks_tmp);
367 505522 : }
368 :
369 : /*******************************************************************************
370 : * \brief Private routine for redistributing a matrix along selected dimensions.
 *        Plans all packs, then for each pack fills the send buffers, exchanges
 *        block counts, block metadata and block data with all ranks of
 *        dist->comm, and post-processes the received blocks. The packs that
 *        were received become the send_packs of the returned struct, which
 *        additionally gets a recv_pack sized for the largest pack found on
 *        dist_ticks->comm. The result is released with free_packed_matrix.
 *        Asserts that nticks is a multiple of the number of ranks along the
 *        ticks dimension.
371 : * \author Ole Schuett
372 : ******************************************************************************/
373 484116 : static dbm_packed_matrix_t pack_matrix(const bool trans_matrix,
374 : const bool trans_dist,
375 : const dbm_matrix_t *restrict matrix,
376 : const dbm_distribution_t *restrict dist,
377 484116 : const int nticks) {
378 484116 : assert(cp_mpi_comms_are_similar(matrix->dist->comm, dist->comm));
379 :
380 : // The row/col indices are distributed along one cart dimension and the
381 : // ticks are distributed along the other cart dimension.
382 484116 : const dbm_dist_1d_t *dist_indices = (trans_dist) ? &dist->cols : &dist->rows;
383 484116 : const dbm_dist_1d_t *dist_ticks = (trans_dist) ? &dist->rows : &dist->cols;
// With trans_matrix the roles of rows and columns are swapped.
384 484116 : const int *free_index_sizes =
385 : (trans_matrix) ? matrix->col_sizes : matrix->row_sizes;
386 484116 : const int *sum_index_sizes =
387 : (trans_matrix) ? matrix->row_sizes : matrix->col_sizes;
388 :
389 : // Allocate packed matrix.
390 484116 : const int nsend_packs = nticks / dist_ticks->nranks;
391 484116 : assert(nsend_packs * dist_ticks->nranks == nticks);
392 484116 : dbm_packed_matrix_t packed;
393 484116 : packed.dist_indices = dist_indices;
394 484116 : packed.dist_ticks = dist_ticks;
395 484116 : packed.nsend_packs = nsend_packs;
396 484116 : packed.send_packs = malloc(nsend_packs * sizeof(dbm_pack_t));
397 484116 : assert(packed.send_packs != NULL || nsend_packs == 0);
398 :
399 : // Plan all packs.
// create_pack_plans allocates plans_per_pack[ipack]; each array is freed below
// right after its send buffers have been filled.
400 484116 : plan_t *plans_per_pack[nsend_packs];
401 484116 : int nblks_send_per_pack[nsend_packs], ndata_send_per_pack[nsend_packs];
402 484116 : create_pack_plans(trans_matrix, trans_dist, matrix, dist->comm, dist_indices,
403 : dist_ticks, nticks, nsend_packs, plans_per_pack,
404 : nblks_send_per_pack, ndata_send_per_pack);
405 :
406 : // Allocate send buffers for maximum number of blocks/data over all packs.
// The same two buffers are reused for every pack in the loop below.
// NOTE(review): imax is not defined in this file - presumably an integer
// maximum helper from an included header.
407 484116 : int nblks_send_max = 0, ndata_send_max = 0;
408 989638 : for (int ipack = 0; ipack < nsend_packs; ++ipack) {
409 505522 : nblks_send_max = imax(nblks_send_max, nblks_send_per_pack[ipack]);
410 505522 : ndata_send_max = imax(ndata_send_max, ndata_send_per_pack[ipack]);
411 : }
412 484116 : dbm_pack_block_t *blks_send =
413 484116 : cp_mpi_alloc_mem(nblks_send_max * sizeof(dbm_pack_block_t));
414 484116 : double *data_send = cp_mpi_alloc_mem(ndata_send_max * sizeof(double));
415 :
416 : // Cannot parallelize over packs (there might be too few of them).
417 989638 : for (int ipack = 0; ipack < nsend_packs; ipack++) {
418 : // Fill send buffers according to plans.
419 505522 : const int nranks = dist->nranks;
420 505522 : int blks_send_count[nranks], data_send_count[nranks];
421 505522 : int blks_send_displ[nranks], data_send_displ[nranks];
422 505522 : fill_send_buffers(matrix, trans_matrix, nblks_send_per_pack[ipack],
423 505522 : ndata_send_per_pack[ipack], plans_per_pack[ipack], nranks,
424 : blks_send_count, data_send_count, blks_send_displ,
425 : data_send_displ, blks_send, data_send);
426 505522 : free(plans_per_pack[ipack]);
427 :
428 : // 1st communication: Exchange block counts.
429 505522 : int blks_recv_count[nranks], blks_recv_displ[nranks];
430 505522 : cp_mpi_alltoall_int(blks_send_count, 1, blks_recv_count, 1, dist->comm);
431 1580784 : icumsum(nranks, blks_recv_count, blks_recv_displ);
432 505522 : const int nblocks_recv = isum(nranks, blks_recv_count);
433 :
434 : // 2nd communication: Exchange blocks.
// The block structs are exchanged as raw bytes, hence all counts and
// displacements are converted from blocks to bytes first.
435 505522 : dbm_pack_block_t *blks_recv =
436 505522 : cp_mpi_alloc_mem(nblocks_recv * sizeof(dbm_pack_block_t));
437 505522 : int blks_send_count_byte[nranks], blks_send_displ_byte[nranks];
438 505522 : int blks_recv_count_byte[nranks], blks_recv_displ_byte[nranks];
439 1075262 : for (int i = 0; i < nranks; i++) { // TODO: this is ugly!
440 1139480 : blks_send_count_byte[i] =
441 569740 : checked_byte_count(blks_send_count[i], sizeof(dbm_pack_block_t));
442 1139480 : blks_send_displ_byte[i] =
443 569740 : checked_byte_count(blks_send_displ[i], sizeof(dbm_pack_block_t));
444 1139480 : blks_recv_count_byte[i] =
445 569740 : checked_byte_count(blks_recv_count[i], sizeof(dbm_pack_block_t));
446 569740 : blks_recv_displ_byte[i] =
447 569740 : checked_byte_count(blks_recv_displ[i], sizeof(dbm_pack_block_t));
448 : }
449 505522 : cp_mpi_alltoallv_byte(blks_send, blks_send_count_byte, blks_send_displ_byte,
450 : blks_recv, blks_recv_count_byte, blks_recv_displ_byte,
451 : dist->comm);
452 :
453 : // Compute data counts from the received block metadata.
// This is done locally from the block sizes, so no communication is needed.
454 505522 : int data_recv_count[nranks], data_recv_displ[nranks];
455 505522 : compute_data_recv_count(nranks, blks_recv_count, blks_recv_displ,
456 : free_index_sizes, sum_index_sizes, blks_recv,
457 : data_recv_count);
458 1580784 : icumsum(nranks, data_recv_count, data_recv_displ);
459 505522 : const int ndata_recv = isum(nranks, data_recv_count);
460 :
461 : // 3rd communication: Exchange data.
462 : #if defined(DBM_MULTIPLY_COMM_MEMPOOL)
463 505522 : double *data_recv =
464 505522 : offload_mempool_host_malloc(ndata_recv * sizeof(double));
465 : #else
466 : double *data_recv = cp_mpi_alloc_mem(ndata_recv * sizeof(double));
467 : #endif
468 505522 : cp_mpi_alltoallv_double(data_send, data_send_count, data_send_displ,
469 : data_recv, data_recv_count, data_recv_displ,
470 : dist->comm);
471 :
472 : // Post-process received blocks and assemble them into a pack.
// The received buffers are handed over to the pack; they are released later by
// free_packed_matrix.
473 505522 : postprocess_received_blocks(nranks, dist_indices->nshards, nblocks_recv,
474 : blks_recv_count, blks_recv_displ,
475 : data_recv_displ, blks_recv);
476 505522 : packed.send_packs[ipack].nblocks = nblocks_recv;
477 505522 : packed.send_packs[ipack].data_size = ndata_recv;
478 505522 : packed.send_packs[ipack].blocks = blks_recv;
479 505522 : packed.send_packs[ipack].data = data_recv;
480 : }
481 :
482 : // Deallocate send buffers.
483 484116 : cp_mpi_free_mem(blks_send);
484 484116 : cp_mpi_free_mem(data_send);
485 :
486 : // Allocate pack_recv.
// The receive pack has to hold the largest pack of any rank on the ticks
// communicator, hence the local maxima are reduced with cp_mpi_max_int.
487 484116 : int max_nblocks = 0, max_data_size = 0;
488 989638 : for (int ipack = 0; ipack < packed.nsend_packs; ipack++) {
489 505522 : max_nblocks = imax(max_nblocks, packed.send_packs[ipack].nblocks);
490 505522 : max_data_size = imax(max_data_size, packed.send_packs[ipack].data_size);
491 : }
492 484116 : cp_mpi_max_int(&max_nblocks, 1, packed.dist_ticks->comm);
493 484116 : cp_mpi_max_int(&max_data_size, 1, packed.dist_ticks->comm);
494 484116 : packed.max_nblocks = max_nblocks;
495 484116 : packed.max_data_size = max_data_size;
496 968232 : packed.recv_pack.blocks =
497 484116 : cp_mpi_alloc_mem(packed.max_nblocks * sizeof(dbm_pack_block_t));
498 : #if defined(DBM_MULTIPLY_COMM_MEMPOOL)
499 968232 : packed.recv_pack.data =
500 484116 : offload_mempool_host_malloc(packed.max_data_size * sizeof(double));
501 : #else
502 : packed.recv_pack.data =
503 : cp_mpi_alloc_mem(packed.max_data_size * sizeof(double));
504 : #endif
505 :
506 484116 : return packed; // Ownership of packed transfers to caller.
507 : }
508 :
509 : /*******************************************************************************
510 : * \brief Private routine for sending and receiving the pack for the given tick.
511 : * \author Ole Schuett
512 : ******************************************************************************/
513 526928 : static dbm_pack_t *sendrecv_pack(const int itick, const int nticks,
514 : dbm_packed_matrix_t *packed) {
515 526928 : const int nranks = packed->dist_ticks->nranks;
516 526928 : const int my_rank = packed->dist_ticks->my_rank;
517 :
518 : // Compute send rank and pack.
519 526928 : const int itick_of_rank0 = (itick + nticks - my_rank) % nticks;
520 526928 : const int send_rank = (my_rank + nticks - itick_of_rank0) % nranks;
521 526928 : const int send_itick = (itick_of_rank0 + send_rank) % nticks;
522 526928 : const int send_ipack = send_itick / nranks;
523 526928 : assert(send_itick % nranks == my_rank);
524 :
525 : // Compute receive rank and pack.
526 526928 : const int recv_rank = itick % nranks;
527 526928 : const int recv_ipack = itick / nranks;
528 :
529 526928 : dbm_pack_t *send_pack = &packed->send_packs[send_ipack];
530 526928 : if (send_rank == my_rank) {
531 505522 : assert(send_rank == recv_rank && send_ipack == recv_ipack);
532 : return send_pack; // Local pack, no mpi needed.
533 : } else {
534 : // Exchange blocks.
535 64218 : const int nblocks_in_bytes = cp_mpi_sendrecv_byte(
536 21406 : /*sendbuf=*/send_pack->blocks,
537 : /*sendcound=*/
538 21406 : checked_byte_count(send_pack->nblocks, sizeof(dbm_pack_block_t)),
539 : /*dest=*/send_rank,
540 : /*sendtag=*/send_ipack,
541 21406 : /*recvbuf=*/packed->recv_pack.blocks,
542 : /*recvcount=*/
543 21406 : checked_byte_count(packed->max_nblocks, sizeof(dbm_pack_block_t)),
544 : /*source=*/recv_rank,
545 : /*recvtag=*/recv_ipack,
546 : /*comm=*/packed->dist_ticks->comm);
547 :
548 21406 : assert(nblocks_in_bytes % sizeof(dbm_pack_block_t) == 0);
549 21406 : packed->recv_pack.nblocks = nblocks_in_bytes / sizeof(dbm_pack_block_t);
550 :
551 : // Exchange data.
552 42812 : packed->recv_pack.data_size = cp_mpi_sendrecv_double(
553 21406 : /*sendbuf=*/send_pack->data,
554 21406 : /*sendcound=*/send_pack->data_size,
555 : /*dest=*/send_rank,
556 : /*sendtag=*/send_ipack,
557 : /*recvbuf=*/packed->recv_pack.data,
558 21406 : /*recvcount=*/packed->max_data_size,
559 : /*source=*/recv_rank,
560 : /*recvtag=*/recv_ipack,
561 21406 : /*comm=*/packed->dist_ticks->comm);
562 :
563 21406 : return &packed->recv_pack;
564 : }
565 : }
566 :
567 : /*******************************************************************************
568 : * \brief Private routine for releasing a packed matrix.
569 : * \author Ole Schuett
570 : ******************************************************************************/
571 484116 : static void free_packed_matrix(dbm_packed_matrix_t *packed) {
572 484116 : cp_mpi_free_mem(packed->recv_pack.blocks);
573 : #if defined(DBM_MULTIPLY_COMM_MEMPOOL)
574 484116 : offload_mempool_host_free(packed->recv_pack.data);
575 : #else
576 : cp_mpi_free_mem(packed->recv_pack.data);
577 : #endif
578 989638 : for (int ipack = 0; ipack < packed->nsend_packs; ipack++) {
579 505522 : cp_mpi_free_mem(packed->send_packs[ipack].blocks);
580 : #if defined(DBM_MULTIPLY_COMM_MEMPOOL)
581 505522 : offload_mempool_host_free(packed->send_packs[ipack].data);
582 : #else
583 : cp_mpi_free_mem(packed->send_packs[ipack].data);
584 : #endif
585 : }
586 484116 : free(packed->send_packs);
587 484116 : }
588 :
589 : /*******************************************************************************
590 : * \brief Internal routine for creating a communication iterator.
591 : * \author Ole Schuett
592 : ******************************************************************************/
593 242058 : dbm_comm_iterator_t *dbm_comm_iterator_start(const bool transa,
594 : const bool transb,
595 : const dbm_matrix_t *matrix_a,
596 : const dbm_matrix_t *matrix_b,
597 : const dbm_matrix_t *matrix_c) {
598 242058 : dbm_comm_iterator_t *iter = malloc(sizeof(dbm_comm_iterator_t));
599 242058 : assert(iter != NULL);
600 242058 : iter->dist = matrix_c->dist;
601 :
602 : // During each communication tick we'll fetch a pack_a and pack_b.
603 : // Since the cart might be non-squared, the number of communication ticks is
604 : // chosen as the least common multiple of the cart's dimensions.
605 242058 : iter->nticks = lcm(iter->dist->rows.nranks, iter->dist->cols.nranks);
606 242058 : iter->itick = 0;
607 :
608 : // 1.arg=source dimension, 2.arg=target dimension, false=rows, true=columns.
609 242058 : iter->packed_a =
610 242058 : pack_matrix(transa, false, matrix_a, iter->dist, iter->nticks);
611 242058 : iter->packed_b =
612 242058 : pack_matrix(!transb, true, matrix_b, iter->dist, iter->nticks);
613 :
614 242058 : return iter;
615 : }
616 :
617 : /*******************************************************************************
618 : * \brief Internal routine for retrieving next pair of packs of given iterator.
619 : * \author Ole Schuett
620 : ******************************************************************************/
621 505522 : bool dbm_comm_iterator_next(dbm_comm_iterator_t *iter, dbm_pack_t **pack_a,
622 : dbm_pack_t **pack_b) {
623 505522 : if (iter->itick >= iter->nticks) {
624 : return false; // end of iterator reached
625 : }
626 :
627 : // Start each rank at a different tick to spread the load on the sources.
628 263464 : const int shift = iter->dist->rows.my_rank + iter->dist->cols.my_rank;
629 263464 : const int itick = (iter->itick + shift) % iter->nticks;
630 263464 : *pack_a = sendrecv_pack(itick, iter->nticks, &iter->packed_a);
631 263464 : *pack_b = sendrecv_pack(itick, iter->nticks, &iter->packed_b);
632 :
633 263464 : ++iter->itick;
634 263464 : return true;
635 : }
636 :
637 : /*******************************************************************************
638 : * \brief Internal routine for releasing the given communication iterator.
639 : * \author Ole Schuett
640 : ******************************************************************************/
641 242058 : void dbm_comm_iterator_stop(dbm_comm_iterator_t *iter) {
642 242058 : free_packed_matrix(&iter->packed_a);
643 242058 : free_packed_matrix(&iter->packed_b);
644 242058 : free(iter);
645 242058 : }
646 :
647 : // EOF
|