Line data Source code
1 : /*----------------------------------------------------------------------------*/
2 : /* CP2K: A general program to perform molecular dynamics simulations */
3 : /* Copyright 2000-2026 CP2K developers group <https://cp2k.org> */
4 : /* */
5 : /* SPDX-License-Identifier: BSD-3-Clause */
6 : /*----------------------------------------------------------------------------*/
7 : #include "dbm_multiply_comm.h"
8 : #include "../mpiwrap/cp_mpi.h"
9 : #include "../offload/offload_mempool.h"
10 :
11 : #include <assert.h>
12 : #include <limits.h>
13 : #include <stdlib.h>
14 : #include <string.h>
15 :
// Compile-time switch: while DBM_MULTIPLY_COMM_MEMPOOL is defined, the data
// buffers of the packs are obtained from offload_mempool_host_malloc() and
// released with offload_mempool_host_free(); otherwise cp_mpi_alloc_mem() and
// cp_mpi_free_mem() are used. Change the condition to "#if 0" to disable it.
16 : #if 1
17 : #define DBM_MULTIPLY_COMM_MEMPOOL
18 : #endif
19 :
/*******************************************************************************
 * \brief Private routine for computing greatest common divisor of two numbers.
 *        Returns b when a is zero, hence gcd(0, 0) == 0.
 * \author Ole Schuett
 ******************************************************************************/
static int gcd(const int a, const int b) {
  // Iterative form of Euclid's algorithm: (a, b) -> (b % a, a) until a == 0.
  int x = a, y = b;
  while (x != 0) {
    const int remainder = y % x;
    y = x;
    x = remainder;
  }
  return y;
}

/*******************************************************************************
 * \brief Private routine for computing least common multiple of two numbers.
 *        Returns 0 when both arguments are zero.
 * \author Ole Schuett
 ******************************************************************************/
static int lcm(const int a, const int b) {
  const int divisor = gcd(a, b);
  if (divisor == 0) {
    return 0; // Only reached for a == b == 0; avoids a division by zero.
  }
  // Divide before multiplying: a * b may overflow int even though the least
  // common multiple itself is representable.
  return (a / divisor) * b;
}
36 :
/*******************************************************************************
 * \brief Private routine for converting element counts to byte counts.
 *        Asserts that the count is non-negative and that the product of count
 *        and element size is representable as int.
 * \author Hans Pabst
 ******************************************************************************/
static int checked_byte_count(const int nelements, const size_t element_size) {
  // The result is used as an int count/displacement for MPI; a product that
  // silently wrapped around would corrupt the alltoallv layout.
  assert(nelements >= 0);
  assert(element_size <= INT_MAX);
  const int nbytes_per_element = (int)element_size;
  assert(nelements <= INT_MAX / nbytes_per_element);
  return nbytes_per_element * nelements;
}
49 :
/*******************************************************************************
 * \brief Private routine for computing the sum of the given integers.
 *        Returns 0 when n is not positive.
 * \author Ole Schuett
 ******************************************************************************/
static inline int isum(const int n, const int input[n]) {
  int total = 0;
  int idx = 0;
  while (idx < n) {
    total += input[idx];
    idx++;
  }
  return total;
}
61 :
/*******************************************************************************
 * \brief Private routine for computing the cumulative sums of given numbers.
 *        Writes the exclusive prefix sum: output[0] = 0 and
 *        output[i] = input[0] + ... + input[i-1]. Does nothing for n <= 0.
 * \author Ole Schuett and Hans Pabst
 ******************************************************************************/
static inline void icumsum(const int n, const int input[n], int output[n]) {
  if (n <= 0) {
    return; // Empty arrays: neither input[0] nor output[0] may be touched.
  }
  // Exclusive prefix sum using carried accumulators to avoid a load from
  // output[i-1] each iteration (removes loop-carried memory dependency).
  output[0] = 0;
  int running = 0;
  int pending = input[0];
  for (int i = 1; i < n; i++) {
    running += pending;
    output[i] = running;
    pending = input[i];
  }
}
75 :
76 : /*******************************************************************************
77 : * \brief Private routine computing received data counts from block metadata.
78 : * \author Hans Pabst
79 : ******************************************************************************/
80 506386 : static void compute_data_recv_count(const int nranks,
81 : const int blks_recv_count[nranks],
82 : const int blks_recv_displ[nranks],
83 : const int free_index_sizes[],
84 : const int sum_index_sizes[],
85 : const dbm_pack_block_t blks_recv[],
86 : int data_recv_count[nranks]) {
87 : // Derives per-rank data counts locally from the already-received block
88 : // metadata, eliminating a collective alltoall for exchanging data amounts.
89 506386 : memset(data_recv_count, 0, nranks * sizeof(int));
90 1076990 : for (int irank = 0; irank < nranks; irank++) {
91 14392248 : for (int i = 0; i < blks_recv_count[irank]; i++) {
92 13821644 : const dbm_pack_block_t *const blk =
93 13821644 : &blks_recv[blks_recv_displ[irank] + i];
94 13821644 : const int block_size =
95 13821644 : free_index_sizes[blk->free_index] * sum_index_sizes[blk->sum_index];
96 13821644 : assert(block_size >= 0);
97 13821644 : assert(data_recv_count[irank] <= INT_MAX - block_size);
98 13821644 : data_recv_count[irank] += block_size;
99 : }
100 : }
101 506386 : }
102 :
103 : /*******************************************************************************
104 : * \brief Private struct used for planing during pack_matrix.
105 : * \author Ole Schuett
106 : ******************************************************************************/
107 : typedef struct {
108 : const dbm_block_t *blk; // source block
109 : int rank; // target mpi rank
110 : int row_size;
111 : int col_size;
112 : } plan_t;
113 :
/*******************************************************************************
 * \brief Private routine for calculating tick indices in pack plans.
 *        Returns (sum_index * 1021) modulo nticks, computed in unsigned 64-bit.
 * \author Maximilian Graml
 ******************************************************************************/
static inline unsigned long long calculate_tick_index(int sum_index,
                                                      int nticks) {
  // Multiplying by the prime 1021 scrambles the index before the modulo.
  const unsigned long long scrambled = 1021ULL * (unsigned long long)sum_index;
  const unsigned long long modulus = (unsigned long long)nticks;
  return scrambled % modulus;
}
123 :
124 : /*******************************************************************************
125 : * \brief Private routine for planning packs.
 *        For every block of the matrix this determines the pack it belongs to
 *        and the MPI rank it has to be sent to. On return,
 *        plans_per_pack[ipack] points to a malloc'ed array holding
 *        nblks_per_pack[ipack] plans and ndata_per_pack[ipack] holds the sum
 *        of row_size * col_size over those blocks. The plan arrays are
 *        allocated with malloc and are not released in this routine.
126 : * \author Ole Schuett
127 : ******************************************************************************/
128 484980 : static void create_pack_plans(const bool trans_matrix, const bool trans_dist,
129 : const dbm_matrix_t *matrix,
130 : const cp_mpi_comm_t comm,
131 : const dbm_dist_1d_t *dist_indices,
132 : const dbm_dist_1d_t *dist_ticks, const int nticks,
133 : const int npacks, plan_t *plans_per_pack[npacks],
134 : int nblks_per_pack[npacks],
135 : int ndata_per_pack[npacks]) {
136 484980 : memset(nblks_per_pack, 0, npacks * sizeof(int));
137 484980 : memset(ndata_per_pack, 0, npacks * sizeof(int));
138 :
139 484980 : #pragma omp parallel
140 : {
141 : // 1st pass: Compute number of blocks that will be sent in each pack.
142 : int nblks_mythread[npacks];
143 : memset(nblks_mythread, 0, npacks * sizeof(int));
144 : #pragma omp for schedule(static)
145 : for (int ishard = 0; ishard < dbm_get_num_shards(matrix); ishard++) {
146 : dbm_shard_t *shard = &matrix->shards[ishard];
147 : for (int iblock = 0; iblock < shard->nblocks; iblock++) {
148 : const dbm_block_t *blk = &shard->blocks[iblock];
149 : const int sum_index = (trans_matrix) ? blk->row : blk->col;
150 : unsigned long long itick64 = calculate_tick_index(sum_index, nticks);
// The pack index is the tick index divided by the number of tick ranks.
151 : const int ipack = itick64 / dist_ticks->nranks;
152 : nblks_mythread[ipack]++;
153 : }
154 : }
155 :
156 : // Sum nblocks across threads and allocate arrays for plans.
// Inside the critical section each thread adds its counts to the shared
// totals and remembers the running total it observed. That value is the
// exclusive upper end of this thread's private slot range within
// plans_per_pack[ipack]; the 2nd pass fills the range backwards.
157 : #pragma omp critical
158 : for (int ipack = 0; ipack < npacks; ipack++) {
159 : nblks_per_pack[ipack] += nblks_mythread[ipack];
160 : nblks_mythread[ipack] = nblks_per_pack[ipack];
161 : }
// All threads must have contributed before the totals are used below.
162 : #pragma omp barrier
163 : #pragma omp for
164 : for (int ipack = 0; ipack < npacks; ipack++) {
165 : const int nblks = nblks_per_pack[ipack];
166 : plans_per_pack[ipack] = malloc(nblks * sizeof(plan_t));
167 : assert(plans_per_pack[ipack] != NULL || nblks == 0);
168 : }
169 :
170 : // 2nd pass: Plan where to send each block.
171 : int ndata_mythread[npacks];
172 : memset(ndata_mythread, 0, npacks * sizeof(int));
173 : #pragma omp for schedule(static) // Need static to match previous loop.
174 : for (int ishard = 0; ishard < dbm_get_num_shards(matrix); ishard++) {
175 : dbm_shard_t *shard = &matrix->shards[ishard];
176 : for (int iblock = 0; iblock < shard->nblocks; iblock++) {
177 : const dbm_block_t *blk = &shard->blocks[iblock];
178 : const int free_index = (trans_matrix) ? blk->col : blk->row;
179 : const int sum_index = (trans_matrix) ? blk->row : blk->col;
180 : unsigned long long itick64 = calculate_tick_index(sum_index, nticks);
181 : const int ipack = itick64 / dist_ticks->nranks;
182 : // Compute rank to which this block should be sent.
// The free index selects one cart coordinate via index2coord, the tick
// selects the other one; trans_dist decides which of them comes first.
183 : const int coord_free_idx = dist_indices->index2coord[free_index];
184 : const int coord_sum_idx = itick64 % dist_ticks->nranks;
185 : const int coords[2] = {(trans_dist) ? coord_sum_idx : coord_free_idx,
186 : (trans_dist) ? coord_free_idx : coord_sum_idx};
187 : const int rank = cp_mpi_cart_rank(comm, coords);
188 : const int row_size = matrix->row_sizes[blk->row];
189 : const int col_size = matrix->col_sizes[blk->col];
190 : ndata_mythread[ipack] += row_size * col_size;
191 : // Create plan.
// Slots are handed out from the end of this thread's range downwards.
192 : const int iplan = --nblks_mythread[ipack];
193 : plans_per_pack[ipack][iplan].blk = blk;
194 : plans_per_pack[ipack][iplan].rank = rank;
195 : plans_per_pack[ipack][iplan].row_size = row_size;
196 : plans_per_pack[ipack][iplan].col_size = col_size;
197 : }
198 : }
// Accumulate the per-thread data sizes into the shared per-pack totals.
199 : #pragma omp critical
200 : for (int ipack = 0; ipack < npacks; ipack++) {
201 : ndata_per_pack[ipack] += ndata_mythread[ipack];
202 : }
203 : } // end of omp parallel region
204 484980 : }
205 :
206 : /*******************************************************************************
207 : * \brief Private routine for filling send buffers.
 *        Counts the blocks and data elements destined for each rank, derives
 *        the send displacements as exclusive prefix sums of those counts, and
 *        copies block metadata and block data into blks_send and data_send,
 *        grouped by target rank. With trans_matrix the block data is
 *        transposed while copying and row/col swap their roles as free/sum
 *        index. Every stored block offset is relative to the start of the
 *        target rank's section of data_send.
208 : * \author Ole Schuett
209 : ******************************************************************************/
210 506386 : static void fill_send_buffers(
211 : const dbm_matrix_t *matrix, const bool trans_matrix, const int nblks_send,
212 : const int ndata_send, plan_t plans[nblks_send], const int nranks,
213 : int blks_send_count[nranks], int data_send_count[nranks],
214 : int blks_send_displ[nranks], int data_send_displ[nranks],
215 : dbm_pack_block_t blks_send[nblks_send], double data_send[ndata_send]) {
216 506386 : memset(blks_send_count, 0, nranks * sizeof(int));
217 506386 : memset(data_send_count, 0, nranks * sizeof(int));
218 :
219 506386 : #pragma omp parallel
220 : {
221 : // 3rd pass: Compute per rank nblks and ndata.
222 : int nblks_mythread[nranks], ndata_mythread[nranks];
223 : memset(nblks_mythread, 0, nranks * sizeof(int));
224 : memset(ndata_mythread, 0, nranks * sizeof(int));
225 : #pragma omp for schedule(static)
226 : for (int iblock = 0; iblock < nblks_send; iblock++) {
227 : const plan_t *plan = &plans[iblock];
228 : nblks_mythread[plan->rank] += 1;
229 : ndata_mythread[plan->rank] += plan->row_size * plan->col_size;
230 : }
231 :
232 : // Sum nblks and ndata across threads.
// Each thread also keeps the running totals it observed; they mark the
// exclusive upper end of this thread's range within each rank's section,
// which the 4th pass consumes by decrementing.
233 : #pragma omp critical
234 : for (int irank = 0; irank < nranks; irank++) {
235 : blks_send_count[irank] += nblks_mythread[irank];
236 : data_send_count[irank] += ndata_mythread[irank];
237 : nblks_mythread[irank] = blks_send_count[irank];
238 : ndata_mythread[irank] = data_send_count[irank];
239 : }
240 : #pragma omp barrier
241 :
242 : // Compute send displacements.
243 : #pragma omp single
244 : {
245 : icumsum(nranks, blks_send_count, blks_send_displ);
246 : icumsum(nranks, data_send_count, data_send_displ);
// Consistency check: the last displacement plus the last count must
// reproduce the totals passed in by the caller.
247 : const int m = nranks - 1;
248 : assert(nblks_send == blks_send_displ[m] + blks_send_count[m]);
249 : assert(ndata_send == data_send_displ[m] + data_send_count[m]);
250 : }
251 : #pragma omp barrier
252 :
253 : // 4th pass: Fill blks_send and data_send arrays.
254 : #pragma omp for schedule(static) // Need static to match previous loop.
255 : for (int iblock = 0; iblock < nblks_send; iblock++) {
256 : const plan_t *const plan = &plans[iblock];
257 : const dbm_block_t *const blk = plan->blk;
258 : const int ishard = dbm_get_shard_index(matrix, blk->row, blk->col);
259 : const dbm_shard_t *const shard = &matrix->shards[ishard];
260 : const double *blk_data = &shard->data[blk->offset];
261 : const int row_size = plan->row_size, col_size = plan->col_size;
262 : const int plan_size = row_size * col_size;
263 : const int irank = plan->rank;
264 :
265 : // The blk_send_data is ordered by rank, thread, and block.
266 : // data_send_displ[irank]: Start of data for irank within blk_send_data.
267 : // ndata_mythread[irank]: Current threads offset within data for irank.
268 : nblks_mythread[irank] -= 1;
269 : ndata_mythread[irank] -= plan_size;
270 : const int offset = data_send_displ[irank] + ndata_mythread[irank];
271 : const int jblock = blks_send_displ[irank] + nblks_mythread[irank];
272 :
// NOTE: norm accumulates the sum of squared elements; no square root is
// taken in this routine.
273 : double norm = 0.0; // Compute norm as double...
274 : if (trans_matrix) {
275 : // Transpose block to allow for outer-product style multiplication.
276 : for (int i = 0; i < row_size; i++) {
277 : for (int j = 0; j < col_size; j++) {
278 : const double element = blk_data[j * row_size + i];
279 : data_send[offset + i * col_size + j] = element;
280 : norm += element * element;
281 : }
282 : }
283 : blks_send[jblock].free_index = plan->blk->col;
284 : blks_send[jblock].sum_index = plan->blk->row;
285 : } else {
// Without transposition the block data is copied verbatim.
286 : for (int i = 0; i < plan_size; i++) {
287 : const double element = blk_data[i];
288 : data_send[offset + i] = element;
289 : norm += element * element;
290 : }
291 : blks_send[jblock].free_index = plan->blk->row;
292 : blks_send[jblock].sum_index = plan->blk->col;
293 : }
294 : blks_send[jblock].norm = (float)norm; // ...store norm as float.
295 :
296 : // After the block exchange data_recv_displ will be added to the offsets.
297 : blks_send[jblock].offset = offset - data_send_displ[irank];
298 : }
299 : } // end of omp parallel region
300 506386 : }
301 :
302 : /*******************************************************************************
303 : * \brief Private comperator passed to qsort to compare two blocks by sum_index.
304 : * \author Ole Schuett
305 : ******************************************************************************/
306 75118936 : static int compare_pack_blocks_by_sum_index(const void *a, const void *b) {
307 75118936 : const dbm_pack_block_t *blk_a = (const dbm_pack_block_t *)a;
308 75118936 : const dbm_pack_block_t *blk_b = (const dbm_pack_block_t *)b;
309 75118936 : return blk_a->sum_index - blk_b->sum_index;
310 : }
311 :
312 : /*******************************************************************************
313 : * \brief Private routine for post-processing received blocks.
 *        First adds data_recv_displ of the sending rank to the offset of
 *        every received block. Then reorders blks_recv in place: blocks are
 *        grouped by shard, i.e. by free_index modulo nshards, and within each
 *        shard they are sorted by sum_index.
314 : * \author Ole Schuett
315 : ******************************************************************************/
316 506386 : static void postprocess_received_blocks(
317 : const int nranks, const int nshards, const int nblocks_recv,
318 : const int blks_recv_count[nranks], const int blks_recv_displ[nranks],
319 : const int data_recv_displ[nranks],
320 506386 : dbm_pack_block_t blks_recv[nblocks_recv]) {
321 506386 : int nblocks_per_shard[nshards], shard_start[nshards];
322 506386 : memset(nblocks_per_shard, 0, nshards * sizeof(int));
// Scratch copy of all blocks; the counting sort scatters from it back into
// blks_recv. It is freed at the end of this routine.
323 506386 : dbm_pack_block_t *blocks_tmp =
324 506386 : malloc(nblocks_recv * sizeof(dbm_pack_block_t));
325 506386 : assert(blocks_tmp != NULL || nblocks_recv == 0);
326 :
327 506386 : #pragma omp parallel
328 : {
329 : // Add data_recv_displ to received block offsets.
330 : for (int irank = 0; irank < nranks; irank++) {
331 : #pragma omp for
332 : for (int i = 0; i < blks_recv_count[irank]; i++) {
333 : blks_recv[blks_recv_displ[irank] + i].offset += data_recv_displ[irank];
334 : }
335 : }
336 :
337 : // First use counting sort to group blocks by their free_index shard.
338 : int nblocks_mythread[nshards];
339 : memset(nblocks_mythread, 0, nshards * sizeof(int));
340 : #pragma omp for schedule(static)
341 : for (int iblock = 0; iblock < nblocks_recv; iblock++) {
342 : blocks_tmp[iblock] = blks_recv[iblock];
343 : const int ishard = blks_recv[iblock].free_index % nshards;
344 : nblocks_mythread[ishard]++;
345 : }
// Accumulate the per-thread counts; each thread keeps the running total it
// observed as the exclusive upper end of its slot range within the shard.
346 : #pragma omp critical
347 : for (int ishard = 0; ishard < nshards; ishard++) {
348 : nblocks_per_shard[ishard] += nblocks_mythread[ishard];
349 : nblocks_mythread[ishard] = nblocks_per_shard[ishard];
350 : }
351 : #pragma omp barrier
// The start of every shard is the exclusive prefix sum of the shard sizes.
352 : #pragma omp single
353 : icumsum(nshards, nblocks_per_shard, shard_start);
354 : #pragma omp barrier
355 : #pragma omp for schedule(static) // Need static to match previous loop.
356 : for (int iblock = 0; iblock < nblocks_recv; iblock++) {
357 : const int ishard = blocks_tmp[iblock].free_index % nshards;
358 : const int jblock = --nblocks_mythread[ishard] + shard_start[ishard];
359 : blks_recv[jblock] = blocks_tmp[iblock];
360 : }
361 :
362 : // Then sort blocks within each shard by their sum_index.
363 : #pragma omp for
364 : for (int ishard = 0; ishard < nshards; ishard++) {
// Shards with fewer than two blocks are already sorted.
365 : if (nblocks_per_shard[ishard] > 1) {
366 : qsort(&blks_recv[shard_start[ishard]], nblocks_per_shard[ishard],
367 : sizeof(dbm_pack_block_t), &compare_pack_blocks_by_sum_index);
368 : }
369 : }
370 : } // end of omp parallel region
371 :
372 506386 : free(blocks_tmp);
373 506386 : }
374 :
375 : /*******************************************************************************
376 : * \brief Private routine for redistributing a matrix along selected dimensions.
 *        trans_matrix selects whether the rows or the columns of the matrix
 *        act as free index (the other one acts as sum index); trans_dist
 *        selects whether dist->rows or dist->cols distributes the free
 *        indices (the other one distributes the ticks). nticks has to be a
 *        multiple of the number of ranks along the tick dimension.
 *        The returned struct holds nticks / dist_ticks->nranks send packs,
 *        each assembled from the blocks and data received from all ranks of
 *        dist->comm, plus one receive pack sized for max_nblocks blocks and
 *        max_data_size data elements. The caller owns the result.
377 : * \author Ole Schuett
378 : ******************************************************************************/
379 484980 : static dbm_packed_matrix_t pack_matrix(const bool trans_matrix,
380 : const bool trans_dist,
381 : const dbm_matrix_t *restrict matrix,
382 : const dbm_distribution_t *restrict dist,
383 484980 : const int nticks) {
384 484980 : assert(cp_mpi_comms_are_similar(matrix->dist->comm, dist->comm));
385 :
386 : // The row/col indices are distributed along one cart dimension and the
387 : // ticks are distributed along the other cart dimension.
388 484980 : const dbm_dist_1d_t *dist_indices = (trans_dist) ? &dist->cols : &dist->rows;
389 484980 : const dbm_dist_1d_t *dist_ticks = (trans_dist) ? &dist->rows : &dist->cols;
390 484980 : const int *free_index_sizes =
391 : (trans_matrix) ? matrix->col_sizes : matrix->row_sizes;
392 484980 : const int *sum_index_sizes =
393 : (trans_matrix) ? matrix->row_sizes : matrix->col_sizes;
394 :
395 : // Allocate packed matrix.
396 484980 : const int nsend_packs = nticks / dist_ticks->nranks;
397 484980 : assert(nsend_packs * dist_ticks->nranks == nticks);
398 484980 : dbm_packed_matrix_t packed;
399 484980 : packed.dist_indices = dist_indices;
400 484980 : packed.dist_ticks = dist_ticks;
401 484980 : packed.nsend_packs = nsend_packs;
402 484980 : packed.send_packs = malloc(nsend_packs * sizeof(dbm_pack_t));
403 484980 : assert(packed.send_packs != NULL || nsend_packs == 0);
404 :
405 : // Plan all packs.
406 484980 : plan_t *plans_per_pack[nsend_packs];
407 484980 : int nblks_send_per_pack[nsend_packs], ndata_send_per_pack[nsend_packs];
408 484980 : create_pack_plans(trans_matrix, trans_dist, matrix, dist->comm, dist_indices,
409 : dist_ticks, nticks, nsend_packs, plans_per_pack,
410 : nblks_send_per_pack, ndata_send_per_pack);
411 :
412 : // Allocate send buffers for maximum number of blocks/data over all packs.
// The same two buffers are reused for every pack in the loop below.
413 484980 : int nblks_send_max = 0, ndata_send_max = 0;
414 991366 : for (int ipack = 0; ipack < nsend_packs; ++ipack) {
415 506386 : nblks_send_max = imax(nblks_send_max, nblks_send_per_pack[ipack]);
416 506386 : ndata_send_max = imax(ndata_send_max, ndata_send_per_pack[ipack]);
417 : }
418 484980 : dbm_pack_block_t *blks_send =
419 484980 : cp_mpi_alloc_mem(nblks_send_max * sizeof(dbm_pack_block_t));
420 484980 : double *data_send = cp_mpi_alloc_mem(ndata_send_max * sizeof(double));
421 :
422 : // Cannot parallelize over packs (there might be too few of them).
423 991366 : for (int ipack = 0; ipack < nsend_packs; ipack++) {
424 : // Fill send buffers according to plans.
425 506386 : const int nranks = dist->nranks;
426 506386 : int blks_send_count[nranks], data_send_count[nranks];
427 506386 : int blks_send_displ[nranks], data_send_displ[nranks];
428 506386 : fill_send_buffers(matrix, trans_matrix, nblks_send_per_pack[ipack],
429 506386 : ndata_send_per_pack[ipack], plans_per_pack[ipack], nranks,
430 : blks_send_count, data_send_count, blks_send_displ,
431 : data_send_displ, blks_send, data_send);
// The plan of this pack is no longer needed once the buffers are filled.
432 506386 : free(plans_per_pack[ipack]);
433 :
434 : // 1st communication: Exchange block counts.
435 506386 : int blks_recv_count[nranks], blks_recv_displ[nranks];
436 506386 : cp_mpi_alltoall_int(blks_send_count, 1, blks_recv_count, 1, dist->comm);
437 1583376 : icumsum(nranks, blks_recv_count, blks_recv_displ);
438 506386 : const int nblocks_recv = isum(nranks, blks_recv_count);
439 :
440 : // 2nd communication: Exchange blocks.
// The block structs are exchanged as raw bytes, hence all counts and
// displacements are converted from blocks to bytes (with overflow checks).
441 506386 : dbm_pack_block_t *blks_recv =
442 506386 : cp_mpi_alloc_mem(nblocks_recv * sizeof(dbm_pack_block_t));
443 506386 : int blks_send_count_byte[nranks], blks_send_displ_byte[nranks];
444 506386 : int blks_recv_count_byte[nranks], blks_recv_displ_byte[nranks];
445 1076990 : for (int i = 0; i < nranks; i++) { // TODO: this is ugly!
446 1141208 : blks_send_count_byte[i] =
447 570604 : checked_byte_count(blks_send_count[i], sizeof(dbm_pack_block_t));
448 1141208 : blks_send_displ_byte[i] =
449 570604 : checked_byte_count(blks_send_displ[i], sizeof(dbm_pack_block_t));
450 1141208 : blks_recv_count_byte[i] =
451 570604 : checked_byte_count(blks_recv_count[i], sizeof(dbm_pack_block_t));
452 570604 : blks_recv_displ_byte[i] =
453 570604 : checked_byte_count(blks_recv_displ[i], sizeof(dbm_pack_block_t));
454 : }
455 506386 : cp_mpi_alltoallv_byte(blks_send, blks_send_count_byte, blks_send_displ_byte,
456 : blks_recv, blks_recv_count_byte, blks_recv_displ_byte,
457 : dist->comm);
458 :
459 : // 3rd compute data counts from the received block metadata.
460 506386 : int data_recv_count[nranks], data_recv_displ[nranks];
461 506386 : compute_data_recv_count(nranks, blks_recv_count, blks_recv_displ,
462 : free_index_sizes, sum_index_sizes, blks_recv,
463 : data_recv_count);
464 1583376 : icumsum(nranks, data_recv_count, data_recv_displ);
465 506386 : const int ndata_recv = isum(nranks, data_recv_count);
466 :
467 : // 4th communication: Exchange data.
// The receive buffer comes from the offload memory pool when
// DBM_MULTIPLY_COMM_MEMPOOL is defined, otherwise from cp_mpi_alloc_mem.
468 : #if defined(DBM_MULTIPLY_COMM_MEMPOOL)
469 506386 : double *data_recv =
470 506386 : offload_mempool_host_malloc(ndata_recv * sizeof(double));
471 : #else
472 : double *data_recv = cp_mpi_alloc_mem(ndata_recv * sizeof(double));
473 : #endif
474 506386 : cp_mpi_alltoallv_double(data_send, data_send_count, data_send_displ,
475 : data_recv, data_recv_count, data_recv_displ,
476 : dist->comm);
477 :
478 : // 5th post-process received blocks and assemble them into a pack.
// The pack takes over the buffers blks_recv and data_recv.
479 506386 : postprocess_received_blocks(nranks, dist_indices->nshards, nblocks_recv,
480 : blks_recv_count, blks_recv_displ,
481 : data_recv_displ, blks_recv);
482 506386 : packed.send_packs[ipack].nblocks = nblocks_recv;
483 506386 : packed.send_packs[ipack].data_size = ndata_recv;
484 506386 : packed.send_packs[ipack].blocks = blks_recv;
485 506386 : packed.send_packs[ipack].data = data_recv;
486 : }
487 :
488 : // Deallocate send buffers.
489 484980 : cp_mpi_free_mem(blks_send);
490 484980 : cp_mpi_free_mem(data_send);
491 :
492 : // Allocate pack_recv.
// Start from the local maxima over all send packs, then combine them across
// dist_ticks->comm with cp_mpi_max_int (presumably a max-reduction; confirm
// against cp_mpi.h).
493 484980 : int max_nblocks = 0, max_data_size = 0;
494 991366 : for (int ipack = 0; ipack < packed.nsend_packs; ipack++) {
495 506386 : max_nblocks = imax(max_nblocks, packed.send_packs[ipack].nblocks);
496 506386 : max_data_size = imax(max_data_size, packed.send_packs[ipack].data_size);
497 : }
498 484980 : cp_mpi_max_int(&max_nblocks, 1, packed.dist_ticks->comm);
499 484980 : cp_mpi_max_int(&max_data_size, 1, packed.dist_ticks->comm);
500 484980 : packed.max_nblocks = max_nblocks;
501 484980 : packed.max_data_size = max_data_size;
502 969960 : packed.recv_pack.blocks =
503 484980 : cp_mpi_alloc_mem(packed.max_nblocks * sizeof(dbm_pack_block_t));
504 : #if defined(DBM_MULTIPLY_COMM_MEMPOOL)
505 969960 : packed.recv_pack.data =
506 484980 : offload_mempool_host_malloc(packed.max_data_size * sizeof(double));
507 : #else
508 : packed.recv_pack.data =
509 : cp_mpi_alloc_mem(packed.max_data_size * sizeof(double));
510 : #endif
511 :
512 484980 : return packed; // Ownership of packed transfers to caller.
513 : }
514 :
515 : /*******************************************************************************
516 : * \brief Private routine for sending and receiving the pack for the given tick.
517 : * \author Ole Schuett
518 : ******************************************************************************/
519 527792 : static dbm_pack_t *sendrecv_pack(const int itick, const int nticks,
520 : dbm_packed_matrix_t *packed) {
521 527792 : const int nranks = packed->dist_ticks->nranks;
522 527792 : const int my_rank = packed->dist_ticks->my_rank;
523 :
524 : // Compute send rank and pack.
525 527792 : const int itick_of_rank0 = (itick + nticks - my_rank) % nticks;
526 527792 : const int send_rank = (my_rank + nticks - itick_of_rank0) % nranks;
527 527792 : const int send_itick = (itick_of_rank0 + send_rank) % nticks;
528 527792 : const int send_ipack = send_itick / nranks;
529 527792 : assert(send_itick % nranks == my_rank);
530 :
531 : // Compute receive rank and pack.
532 527792 : const int recv_rank = itick % nranks;
533 527792 : const int recv_ipack = itick / nranks;
534 :
535 527792 : dbm_pack_t *send_pack = &packed->send_packs[send_ipack];
536 527792 : if (send_rank == my_rank) {
537 506386 : assert(send_rank == recv_rank && send_ipack == recv_ipack);
538 : return send_pack; // Local pack, no mpi needed.
539 : } else {
540 : // Exchange blocks.
541 64218 : const int nblocks_in_bytes = cp_mpi_sendrecv_byte(
542 21406 : /*sendbuf=*/send_pack->blocks,
543 : /*sendcound=*/
544 21406 : checked_byte_count(send_pack->nblocks, sizeof(dbm_pack_block_t)),
545 : /*dest=*/send_rank,
546 : /*sendtag=*/send_ipack,
547 21406 : /*recvbuf=*/packed->recv_pack.blocks,
548 : /*recvcount=*/
549 21406 : checked_byte_count(packed->max_nblocks, sizeof(dbm_pack_block_t)),
550 : /*source=*/recv_rank,
551 : /*recvtag=*/recv_ipack,
552 : /*comm=*/packed->dist_ticks->comm);
553 :
554 21406 : assert(nblocks_in_bytes % sizeof(dbm_pack_block_t) == 0);
555 21406 : packed->recv_pack.nblocks = nblocks_in_bytes / sizeof(dbm_pack_block_t);
556 :
557 : // Exchange data.
558 42812 : packed->recv_pack.data_size = cp_mpi_sendrecv_double(
559 21406 : /*sendbuf=*/send_pack->data,
560 21406 : /*sendcound=*/send_pack->data_size,
561 : /*dest=*/send_rank,
562 : /*sendtag=*/send_ipack,
563 : /*recvbuf=*/packed->recv_pack.data,
564 21406 : /*recvcount=*/packed->max_data_size,
565 : /*source=*/recv_rank,
566 : /*recvtag=*/recv_ipack,
567 21406 : /*comm=*/packed->dist_ticks->comm);
568 :
569 21406 : return &packed->recv_pack;
570 : }
571 : }
572 :
573 : /*******************************************************************************
574 : * \brief Private routine for releasing a packed matrix.
575 : * \author Ole Schuett
576 : ******************************************************************************/
577 484980 : static void free_packed_matrix(dbm_packed_matrix_t *packed) {
578 484980 : cp_mpi_free_mem(packed->recv_pack.blocks);
579 : #if defined(DBM_MULTIPLY_COMM_MEMPOOL)
580 484980 : offload_mempool_host_free(packed->recv_pack.data);
581 : #else
582 : cp_mpi_free_mem(packed->recv_pack.data);
583 : #endif
584 991366 : for (int ipack = 0; ipack < packed->nsend_packs; ipack++) {
585 506386 : cp_mpi_free_mem(packed->send_packs[ipack].blocks);
586 : #if defined(DBM_MULTIPLY_COMM_MEMPOOL)
587 506386 : offload_mempool_host_free(packed->send_packs[ipack].data);
588 : #else
589 : cp_mpi_free_mem(packed->send_packs[ipack].data);
590 : #endif
591 : }
592 484980 : free(packed->send_packs);
593 484980 : }
594 :
595 : /*******************************************************************************
596 : * \brief Internal routine for creating a communication iterator.
597 : * \author Ole Schuett
598 : ******************************************************************************/
599 242490 : dbm_comm_iterator_t *dbm_comm_iterator_start(const bool transa,
600 : const bool transb,
601 : const dbm_matrix_t *matrix_a,
602 : const dbm_matrix_t *matrix_b,
603 : const dbm_matrix_t *matrix_c) {
604 242490 : dbm_comm_iterator_t *iter = malloc(sizeof(dbm_comm_iterator_t));
605 242490 : assert(iter != NULL);
606 242490 : iter->dist = matrix_c->dist;
607 :
608 : // During each communication tick we'll fetch a pack_a and pack_b.
609 : // Since the cart might be non-squared, the number of communication ticks is
610 : // chosen as the least common multiple of the cart's dimensions.
611 242490 : iter->nticks = lcm(iter->dist->rows.nranks, iter->dist->cols.nranks);
612 242490 : iter->itick = 0;
613 :
614 : // 1.arg=source dimension, 2.arg=target dimension, false=rows, true=columns.
615 242490 : iter->packed_a =
616 242490 : pack_matrix(transa, false, matrix_a, iter->dist, iter->nticks);
617 242490 : iter->packed_b =
618 242490 : pack_matrix(!transb, true, matrix_b, iter->dist, iter->nticks);
619 :
620 242490 : return iter;
621 : }
622 :
623 : /*******************************************************************************
624 : * \brief Internal routine for retrieving next pair of packs of given iterator.
625 : * \author Ole Schuett
626 : ******************************************************************************/
627 506386 : bool dbm_comm_iterator_next(dbm_comm_iterator_t *iter, dbm_pack_t **pack_a,
628 : dbm_pack_t **pack_b) {
629 506386 : if (iter->itick >= iter->nticks) {
630 : return false; // end of iterator reached
631 : }
632 :
633 : // Start each rank at a different tick to spread the load on the sources.
634 263896 : const int shift = iter->dist->rows.my_rank + iter->dist->cols.my_rank;
635 263896 : const int itick = (iter->itick + shift) % iter->nticks;
636 263896 : *pack_a = sendrecv_pack(itick, iter->nticks, &iter->packed_a);
637 263896 : *pack_b = sendrecv_pack(itick, iter->nticks, &iter->packed_b);
638 :
639 263896 : ++iter->itick;
640 263896 : return true;
641 : }
642 :
643 : /*******************************************************************************
644 : * \brief Internal routine for releasing the given communication iterator.
645 : * \author Ole Schuett
646 : ******************************************************************************/
647 242490 : void dbm_comm_iterator_stop(dbm_comm_iterator_t *iter) {
648 242490 : free_packed_matrix(&iter->packed_a);
649 242490 : free_packed_matrix(&iter->packed_b);
650 242490 : free(iter);
651 242490 : }
652 :
653 : // EOF
|