Line data Source code
1 : /*----------------------------------------------------------------------------*/
2 : /* CP2K: A general program to perform molecular dynamics simulations */
3 : /* Copyright 2000-2025 CP2K developers group <https://cp2k.org> */
4 : /* */
5 : /* SPDX-License-Identifier: BSD-3-Clause */
6 : /*----------------------------------------------------------------------------*/
7 :
8 : #include "dbm_multiply_comm.h"
9 :
10 : #include <assert.h>
11 : #include <stdlib.h>
12 : #include <string.h>
13 :
14 : #include "dbm_hyperparams.h"
15 : #include "dbm_mempool.h"
16 : #include "dbm_mpi.h"
17 :
18 : /*******************************************************************************
19 : * \brief Private routine for computing greatest common divisor of two numbers.
20 : * \author Ole Schuett
21 : ******************************************************************************/
22 434134 : static int gcd(const int a, const int b) {
23 434134 : if (a == 0)
24 : return b;
25 225814 : return gcd(b % a, a); // Euclid's algorithm.
26 : }
27 :
28 : /*******************************************************************************
29 : * \brief Private routine for computing least common multiple of two numbers.
30 : * \author Ole Schuett
31 : ******************************************************************************/
32 208320 : static int lcm(const int a, const int b) { return (a * b) / gcd(a, b); }
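: // Example: lcm(4, 6) = (4 * 6) / gcd(4, 6) = 24 / 2 = 12.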
33 :
34 : /*******************************************************************************
35 : * \brief Private routine for computing the sum of the given integers.
36 : * \author Ole Schuett
37 : ******************************************************************************/
38 868404 : static inline int isum(const int n, const int input[n]) {
39 868404 : int output = 0;
40 1842180 : for (int i = 0; i < n; i++) {
41 973776 : output += input[i];
42 : }
43 868404 : return output;
44 : }
45 :
46 : /*******************************************************************************
47 : * \brief Private routine for computing the cumulative sums of given numbers.
48 : * \author Ole Schuett
49 : ******************************************************************************/
50 2171010 : static inline void icumsum(const int n, const int input[n], int output[n]) {
51 2171010 : output[0] = 0;
52 2381754 : for (int i = 1; i < n; i++) {
53 210744 : output[i] = output[i - 1] + input[i - 1];
54 : }
55 2171010 : }
56 :
57 : /*******************************************************************************
58 : * \brief Private struct used for planning during pack_matrix.
59 : * \author Ole Schuett
60 : ******************************************************************************/
61 : typedef struct {
62 : const dbm_block_t *blk; // source block
63 : int rank; // target mpi rank
64 : int row_size;
65 : int col_size;
66 : } plan_t;
67 :
68 : /*******************************************************************************
69 : * \brief Private routine for planning packs.
70 : * \author Ole Schuett
71 : ******************************************************************************/
72 416640 : static void create_pack_plans(const bool trans_matrix, const bool trans_dist,
73 : const dbm_matrix_t *matrix,
74 : const dbm_mpi_comm_t comm,
75 : const dbm_dist_1d_t *dist_indices,
76 : const dbm_dist_1d_t *dist_ticks, const int nticks,
77 : const int npacks, plan_t *plans_per_pack[npacks],
78 : int nblks_per_pack[npacks],
79 : int ndata_per_pack[npacks]) {
80 :
81 416640 : memset(nblks_per_pack, 0, npacks * sizeof(int));
82 416640 : memset(ndata_per_pack, 0, npacks * sizeof(int));
83 :
84 416640 : #pragma omp parallel
85 : {
86 : // 1st pass: Compute the number of blocks that will be sent in each pack.
87 : int nblks_mythread[npacks];
88 : memset(nblks_mythread, 0, npacks * sizeof(int));
89 : #pragma omp for schedule(static)
90 : for (int ishard = 0; ishard < dbm_get_num_shards(matrix); ishard++) {
91 : dbm_shard_t *shard = &matrix->shards[ishard];
92 : for (int iblock = 0; iblock < shard->nblocks; iblock++) {
93 : const dbm_block_t *blk = &shard->blocks[iblock];
94 : const int sum_index = (trans_matrix) ? blk->row : blk->col;
95 : const int itick = (1021 * sum_index) % nticks; // 1021 = a random prime
96 : const int ipack = itick / dist_ticks->nranks;
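: // The prime factor spreads consecutive sum indices over different ticks;
: // itick / nranks selects the pack, and itick % nranks later selects the
: // coordinate along the tick dimension (see the 2nd pass below).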
97 : nblks_mythread[ipack]++;
98 : }
99 : }
100 :
101 : // Sum nblocks across threads and allocate arrays for plans.
102 : #pragma omp critical
103 : for (int ipack = 0; ipack < npacks; ipack++) {
104 : nblks_per_pack[ipack] += nblks_mythread[ipack];
105 : nblks_mythread[ipack] = nblks_per_pack[ipack];
106 : }
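: // After this critical section nblks_mythread[ipack] holds this thread's end
: // offset within the pack; the 2nd pass below fills plan slots backwards from
: // it via --nblks_mythread[ipack].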
107 : #pragma omp barrier
108 : #pragma omp for
109 : for (int ipack = 0; ipack < npacks; ipack++) {
110 : const int nblks = nblks_per_pack[ipack];
111 : plans_per_pack[ipack] = malloc(nblks * sizeof(plan_t));
112 : assert(plans_per_pack[ipack] != NULL || nblks == 0);
113 : }
114 :
115 : // 2nd pass: Plan where to send each block.
116 : int ndata_mythread[npacks];
117 : memset(ndata_mythread, 0, npacks * sizeof(int));
118 : #pragma omp for schedule(static) // Need static to match previous loop.
119 : for (int ishard = 0; ishard < dbm_get_num_shards(matrix); ishard++) {
120 : dbm_shard_t *shard = &matrix->shards[ishard];
121 : for (int iblock = 0; iblock < shard->nblocks; iblock++) {
122 : const dbm_block_t *blk = &shard->blocks[iblock];
123 : const int free_index = (trans_matrix) ? blk->col : blk->row;
124 : const int sum_index = (trans_matrix) ? blk->row : blk->col;
125 : const int itick = (1021 * sum_index) % nticks; // Same mapping as above.
126 : const int ipack = itick / dist_ticks->nranks;
127 : // Compute rank to which this block should be sent.
128 : const int coord_free_idx = dist_indices->index2coord[free_index];
129 : const int coord_sum_idx = itick % dist_ticks->nranks;
130 : const int coords[2] = {(trans_dist) ? coord_sum_idx : coord_free_idx,
131 : (trans_dist) ? coord_free_idx : coord_sum_idx};
132 : const int rank = dbm_mpi_cart_rank(comm, coords);
133 : const int row_size = matrix->row_sizes[blk->row];
134 : const int col_size = matrix->col_sizes[blk->col];
135 : ndata_mythread[ipack] += row_size * col_size;
136 : // Create plan.
137 : const int iplan = --nblks_mythread[ipack];
138 : plans_per_pack[ipack][iplan].blk = blk;
139 : plans_per_pack[ipack][iplan].rank = rank;
140 : plans_per_pack[ipack][iplan].row_size = row_size;
141 : plans_per_pack[ipack][iplan].col_size = col_size;
142 : }
143 : }
144 : #pragma omp critical
145 : for (int ipack = 0; ipack < npacks; ipack++) {
146 : ndata_per_pack[ipack] += ndata_mythread[ipack];
147 : }
148 : } // end of omp parallel region
149 416640 : }
150 :
151 : /*******************************************************************************
152 : * \brief Private routine for filling send buffers.
153 : * \author Ole Schuett
154 : ******************************************************************************/
155 434202 : static void fill_send_buffers(
156 : const dbm_matrix_t *matrix, const bool trans_matrix, const int nblks_send,
157 : const int ndata_send, plan_t plans[nblks_send], const int nranks,
158 : int blks_send_count[nranks], int data_send_count[nranks],
159 : int blks_send_displ[nranks], int data_send_displ[nranks],
160 : dbm_pack_block_t blks_send[nblks_send], double data_send[ndata_send]) {
161 :
162 434202 : memset(blks_send_count, 0, nranks * sizeof(int));
163 434202 : memset(data_send_count, 0, nranks * sizeof(int));
164 :
165 434202 : #pragma omp parallel
166 : {
167 : // 3rd pass: Compute per-rank nblks and ndata.
168 : int nblks_mythread[nranks], ndata_mythread[nranks];
169 : memset(nblks_mythread, 0, nranks * sizeof(int));
170 : memset(ndata_mythread, 0, nranks * sizeof(int));
171 : #pragma omp for schedule(static)
172 : for (int iblock = 0; iblock < nblks_send; iblock++) {
173 : const plan_t *plan = &plans[iblock];
174 : nblks_mythread[plan->rank] += 1;
175 : ndata_mythread[plan->rank] += plan->row_size * plan->col_size;
176 : }
177 :
178 : // Sum nblks and ndata across threads.
179 : #pragma omp critical
180 : for (int irank = 0; irank < nranks; irank++) {
181 : blks_send_count[irank] += nblks_mythread[irank];
182 : data_send_count[irank] += ndata_mythread[irank];
183 : nblks_mythread[irank] = blks_send_count[irank];
184 : ndata_mythread[irank] = data_send_count[irank];
185 : }
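: // After this critical section nblks_mythread[irank] and ndata_mythread[irank]
: // hold this thread's end offsets within the per-rank block and data regions;
: // the 4th pass below fills both buffers backwards from these offsets.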
186 : #pragma omp barrier
187 :
188 : // Compute send displacements.
189 : #pragma omp master
190 : {
191 : icumsum(nranks, blks_send_count, blks_send_displ);
192 : icumsum(nranks, data_send_count, data_send_displ);
193 : const int m = nranks - 1;
194 : assert(nblks_send == blks_send_displ[m] + blks_send_count[m]);
195 : assert(ndata_send == data_send_displ[m] + data_send_count[m]);
196 : }
197 : #pragma omp barrier
198 :
199 : // 4th pass: Fill blks_send and data_send arrays.
200 : #pragma omp for schedule(static) // Need static to match previous loop.
201 : for (int iblock = 0; iblock < nblks_send; iblock++) {
202 : const plan_t *plan = &plans[iblock];
203 : const dbm_block_t *blk = plan->blk;
204 : const int ishard = dbm_get_shard_index(matrix, blk->row, blk->col);
205 : const dbm_shard_t *shard = &matrix->shards[ishard];
206 : const double *blk_data = &shard->data[blk->offset];
207 : const int row_size = plan->row_size, col_size = plan->col_size;
208 : const int plan_size = row_size * col_size;
209 : const int irank = plan->rank;
210 :
211 : // The blks_send and data_send buffers are ordered by rank, thread, and block.
212 : // data_send_displ[irank]: Start of the data for irank within data_send.
213 : // ndata_mythread[irank]: Current thread's offset within the data for irank.
214 : nblks_mythread[irank] -= 1;
215 : ndata_mythread[irank] -= plan_size;
216 : const int offset = data_send_displ[irank] + ndata_mythread[irank];
217 : const int jblock = blks_send_displ[irank] + nblks_mythread[irank];
218 :
219 : double norm = 0.0; // Compute norm as double...
220 : if (trans_matrix) {
221 : // Transpose block to allow for outer-product style multiplication.
222 : for (int i = 0; i < row_size; i++) {
223 : for (int j = 0; j < col_size; j++) {
224 : const double element = blk_data[j * row_size + i];
225 : data_send[offset + i * col_size + j] = element;
226 : norm += element * element;
227 : }
228 : }
229 : blks_send[jblock].free_index = plan->blk->col;
230 : blks_send[jblock].sum_index = plan->blk->row;
231 : } else {
232 : for (int i = 0; i < plan_size; i++) {
233 : const double element = blk_data[i];
234 : data_send[offset + i] = element;
235 : norm += element * element;
236 : }
237 : blks_send[jblock].free_index = plan->blk->row;
238 : blks_send[jblock].sum_index = plan->blk->col;
239 : }
240 : blks_send[jblock].norm = (float)norm; // ...store norm as float.
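: // Note: the stored norms are presumably used by the multiplication driver to
: // skip products of small blocks when filtering is enabled (not defined here).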
241 :
242 : // After the block exchange data_recv_displ will be added to the offsets.
243 : blks_send[jblock].offset = offset - data_send_displ[irank];
244 : }
245 : } // end of omp parallel region
246 434202 : }
247 :
248 : /*******************************************************************************
249 : * \brief Private comparator passed to qsort to compare two blocks by sum_index.
250 : * \author Ole Schuett
251 : ******************************************************************************/
252 70680206 : static int compare_pack_blocks_by_sum_index(const void *a, const void *b) {
253 70680206 : const dbm_pack_block_t *blk_a = (const dbm_pack_block_t *)a;
254 70680206 : const dbm_pack_block_t *blk_b = (const dbm_pack_block_t *)b;
255 70680206 : return blk_a->sum_index - blk_b->sum_index;
256 : }
257 :
258 : /*******************************************************************************
259 : * \brief Private routine for post-processing received blocks.
260 : * \author Ole Schuett
261 : ******************************************************************************/
262 434202 : static void postprocess_received_blocks(
263 : const int nranks, const int nshards, const int nblocks_recv,
264 : const int blks_recv_count[nranks], const int blks_recv_displ[nranks],
265 : const int data_recv_displ[nranks],
266 434202 : dbm_pack_block_t blks_recv[nblocks_recv]) {
267 :
268 434202 : int nblocks_per_shard[nshards], shard_start[nshards];
269 434202 : memset(nblocks_per_shard, 0, nshards * sizeof(int));
270 434202 : dbm_pack_block_t *blocks_tmp =
271 434202 : malloc(nblocks_recv * sizeof(dbm_pack_block_t));
272 434202 : assert(blocks_tmp != NULL || nblocks_recv == 0);
273 :
274 434202 : #pragma omp parallel
275 : {
276 : // Add data_recv_displ to received block offsets.
277 : for (int irank = 0; irank < nranks; irank++) {
278 : #pragma omp for
279 : for (int i = 0; i < blks_recv_count[irank]; i++) {
280 : blks_recv[blks_recv_displ[irank] + i].offset += data_recv_displ[irank];
281 : }
282 : }
283 :
284 : // First use counting sort to group blocks by their free_index shard.
285 : int nblocks_mythread[nshards];
286 : memset(nblocks_mythread, 0, nshards * sizeof(int));
287 : #pragma omp for schedule(static)
288 : for (int iblock = 0; iblock < nblocks_recv; iblock++) {
289 : blocks_tmp[iblock] = blks_recv[iblock];
290 : const int ishard = blks_recv[iblock].free_index % nshards;
291 : nblocks_mythread[ishard]++;
292 : }
293 : #pragma omp critical
294 : for (int ishard = 0; ishard < nshards; ishard++) {
295 : nblocks_per_shard[ishard] += nblocks_mythread[ishard];
296 : nblocks_mythread[ishard] = nblocks_per_shard[ishard];
297 : }
298 : #pragma omp barrier
299 : #pragma omp master
300 : icumsum(nshards, nblocks_per_shard, shard_start);
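: // shard_start[ishard] is the offset of shard ishard's first block in blks_recv.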
301 : #pragma omp barrier
302 : #pragma omp for schedule(static) // Need static to match previous loop.
303 : for (int iblock = 0; iblock < nblocks_recv; iblock++) {
304 : const int ishard = blocks_tmp[iblock].free_index % nshards;
305 : const int jblock = --nblocks_mythread[ishard] + shard_start[ishard];
306 : blks_recv[jblock] = blocks_tmp[iblock];
307 : }
308 :
309 : // Then sort blocks within each shard by their sum_index.
310 : #pragma omp for
311 : for (int ishard = 0; ishard < nshards; ishard++) {
312 : if (nblocks_per_shard[ishard] > 1) {
313 : qsort(&blks_recv[shard_start[ishard]], nblocks_per_shard[ishard],
314 : sizeof(dbm_pack_block_t), &compare_pack_blocks_by_sum_index);
315 : }
316 : }
317 : } // end of omp parallel region
318 :
319 434202 : free(blocks_tmp);
320 434202 : }
321 :
322 : /*******************************************************************************
323 : * \brief Private routine for redistributing a matrix along selected dimensions.
324 : * \author Ole Schuett
325 : ******************************************************************************/
326 416640 : static dbm_packed_matrix_t pack_matrix(const bool trans_matrix,
327 : const bool trans_dist,
328 : const dbm_matrix_t *matrix,
329 : const dbm_distribution_t *dist,
330 416640 : const int nticks) {
331 :
332 416640 : assert(dbm_mpi_comms_are_similar(matrix->dist->comm, dist->comm));
333 :
334 : // The row/col indices are distributed along one cart dimension and the
335 : // ticks are distributed along the other cart dimension.
336 416640 : const dbm_dist_1d_t *dist_indices = (trans_dist) ? &dist->cols : &dist->rows;
337 416640 : const dbm_dist_1d_t *dist_ticks = (trans_dist) ? &dist->rows : &dist->cols;
338 :
339 : // Allocate packed matrix.
340 416640 : const int nsend_packs = nticks / dist_ticks->nranks;
341 416640 : assert(nsend_packs * dist_ticks->nranks == nticks);
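: // nticks is a multiple of the tick dimension's nranks because it is the lcm of
: // the cart dimensions (see dbm_comm_iterator_start), so every rank sends the
: // same number of packs.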
342 416640 : dbm_packed_matrix_t packed;
343 416640 : packed.dist_indices = dist_indices;
344 416640 : packed.dist_ticks = dist_ticks;
345 416640 : packed.nsend_packs = nsend_packs;
346 416640 : packed.send_packs = malloc(nsend_packs * sizeof(dbm_pack_t));
347 416640 : assert(packed.send_packs != NULL || nsend_packs == 0);
348 :
349 : // Plan all packs.
350 416640 : plan_t *plans_per_pack[nsend_packs];
351 416640 : int nblks_send_per_pack[nsend_packs], ndata_send_per_pack[nsend_packs];
352 416640 : create_pack_plans(trans_matrix, trans_dist, matrix, dist->comm, dist_indices,
353 : dist_ticks, nticks, nsend_packs, plans_per_pack,
354 : nblks_send_per_pack, ndata_send_per_pack);
355 :
356 : // Allocate send buffers for maximum number of blocks/data over all packs.
357 416640 : int nblks_send_max = 0, ndata_send_max = 0;
358 850842 : for (int ipack = 0; ipack < nsend_packs; ++ipack) {
359 434202 : nblks_send_max = imax(nblks_send_max, nblks_send_per_pack[ipack]);
360 434202 : ndata_send_max = imax(ndata_send_max, ndata_send_per_pack[ipack]);
361 : }
362 416640 : dbm_pack_block_t *blks_send =
363 416640 : dbm_mpi_alloc_mem(nblks_send_max * sizeof(dbm_pack_block_t));
364 416640 : double *data_send = dbm_mpi_alloc_mem(ndata_send_max * sizeof(double));
365 :
366 : // Cannot parallelize over packs (there might be too few of them).
367 850842 : for (int ipack = 0; ipack < nsend_packs; ipack++) {
368 : // Fill send buffers according to plans.
369 434202 : const int nranks = dist->nranks;
370 434202 : int blks_send_count[nranks], data_send_count[nranks];
371 434202 : int blks_send_displ[nranks], data_send_displ[nranks];
372 434202 : fill_send_buffers(matrix, trans_matrix, nblks_send_per_pack[ipack],
373 : ndata_send_per_pack[ipack], plans_per_pack[ipack], nranks,
374 : blks_send_count, data_send_count, blks_send_displ,
375 : data_send_displ, blks_send, data_send);
376 434202 : free(plans_per_pack[ipack]);
377 :
378 : // 1st communication: Exchange block counts.
379 434202 : int blks_recv_count[nranks], blks_recv_displ[nranks];
380 434202 : dbm_mpi_alltoall_int(blks_send_count, 1, blks_recv_count, 1, dist->comm);
381 434202 : icumsum(nranks, blks_recv_count, blks_recv_displ);
382 434202 : const int nblocks_recv = isum(nranks, blks_recv_count);
383 :
384 : // 2nd communication: Exchange blocks.
385 434202 : dbm_pack_block_t *blks_recv =
386 434202 : dbm_mpi_alloc_mem(nblocks_recv * sizeof(dbm_pack_block_t));
387 434202 : int blks_send_count_byte[nranks], blks_send_displ_byte[nranks];
388 434202 : int blks_recv_count_byte[nranks], blks_recv_displ_byte[nranks];
389 921090 : for (int i = 0; i < nranks; i++) { // TODO: this is ugly!
390 486888 : blks_send_count_byte[i] = blks_send_count[i] * sizeof(dbm_pack_block_t);
391 486888 : blks_send_displ_byte[i] = blks_send_displ[i] * sizeof(dbm_pack_block_t);
392 486888 : blks_recv_count_byte[i] = blks_recv_count[i] * sizeof(dbm_pack_block_t);
393 486888 : blks_recv_displ_byte[i] = blks_recv_displ[i] * sizeof(dbm_pack_block_t);
394 : }
395 434202 : dbm_mpi_alltoallv_byte(
396 : blks_send, blks_send_count_byte, blks_send_displ_byte, blks_recv,
397 434202 : blks_recv_count_byte, blks_recv_displ_byte, dist->comm);
398 :
399 : // 3rd communication: Exchange data counts.
400 : // TODO: could be computed from blks_recv.
401 434202 : int data_recv_count[nranks], data_recv_displ[nranks];
402 434202 : dbm_mpi_alltoall_int(data_send_count, 1, data_recv_count, 1, dist->comm);
403 434202 : icumsum(nranks, data_recv_count, data_recv_displ);
404 434202 : const int ndata_recv = isum(nranks, data_recv_count);
405 :
406 : // 4th communication: Exchange data.
407 434202 : double *data_recv = dbm_mempool_host_malloc(ndata_recv * sizeof(double));
408 434202 : dbm_mpi_alltoallv_double(data_send, data_send_count, data_send_displ,
409 : data_recv, data_recv_count, data_recv_displ,
410 434202 : dist->comm);
411 :
412 : // Post-process received blocks and assemble them into a pack.
413 434202 : postprocess_received_blocks(nranks, dist_indices->nshards, nblocks_recv,
414 : blks_recv_count, blks_recv_displ,
415 : data_recv_displ, blks_recv);
416 434202 : packed.send_packs[ipack].nblocks = nblocks_recv;
417 434202 : packed.send_packs[ipack].data_size = ndata_recv;
418 434202 : packed.send_packs[ipack].blocks = blks_recv;
419 434202 : packed.send_packs[ipack].data = data_recv;
420 : }
421 :
422 : // Deallocate send buffers.
423 416640 : dbm_mpi_free_mem(blks_send);
424 416640 : dbm_mpi_free_mem(data_send);
425 :
426 : // Allocate pack_recv.
427 416640 : int max_nblocks = 0, max_data_size = 0;
428 850842 : for (int ipack = 0; ipack < packed.nsend_packs; ipack++) {
429 434202 : max_nblocks = imax(max_nblocks, packed.send_packs[ipack].nblocks);
430 434202 : max_data_size = imax(max_data_size, packed.send_packs[ipack].data_size);
431 : }
432 416640 : dbm_mpi_max_int(&max_nblocks, 1, packed.dist_ticks->comm);
433 416640 : dbm_mpi_max_int(&max_data_size, 1, packed.dist_ticks->comm);
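: // The receive buffers below are sized to the maximum pack size over all ranks
: // of the tick communicator, so a single allocation can hold any pack arriving
: // in sendrecv_pack.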
434 416640 : packed.max_nblocks = max_nblocks;
435 416640 : packed.max_data_size = max_data_size;
436 833280 : packed.recv_pack.blocks =
437 416640 : dbm_mpi_alloc_mem(packed.max_nblocks * sizeof(dbm_pack_block_t));
438 833280 : packed.recv_pack.data =
439 416640 : dbm_mempool_host_malloc(packed.max_data_size * sizeof(double));
440 :
441 416640 : return packed; // Ownership of packed transfers to caller.
442 : }
443 :
444 : /*******************************************************************************
445 : * \brief Private routine for sending and receiving the pack for the given tick.
446 : * \author Ole Schuett
447 : ******************************************************************************/
448 451764 : static dbm_pack_t *sendrecv_pack(const int itick, const int nticks,
449 : dbm_packed_matrix_t *packed) {
450 451764 : const int nranks = packed->dist_ticks->nranks;
451 451764 : const int my_rank = packed->dist_ticks->my_rank;
452 :
453 : // Compute send rank and pack.
454 451764 : const int itick_of_rank0 = (itick + nticks - my_rank) % nticks;
455 451764 : const int send_rank = (my_rank + nticks - itick_of_rank0) % nranks;
456 451764 : const int send_itick = (itick_of_rank0 + send_rank) % nticks;
457 451764 : const int send_ipack = send_itick / nranks;
458 451764 : assert(send_itick % nranks == my_rank);
459 :
460 : // Compute receive rank and pack.
461 451764 : const int recv_rank = itick % nranks;
462 451764 : const int recv_ipack = itick / nranks;
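: // At this tick we receive pack itick / nranks from rank itick % nranks. The
: // send side picks the pack we own for tick send_itick (the assert above checks
: // send_itick % nranks == my_rank, i.e. we are that tick's source); send_rank is
: // presumably chosen so that, together with the per-rank tick shift applied in
: // dbm_comm_iterator_next, each send pairs up with the matching receive.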
463 :
464 451764 : dbm_pack_t *send_pack = &packed->send_packs[send_ipack];
465 451764 : if (send_rank == my_rank) {
466 434202 : assert(send_rank == recv_rank && send_ipack == recv_ipack);
467 : return send_pack; // Local pack, no mpi needed.
468 : } else {
469 : // Exchange blocks.
470 35124 : const int nblocks_in_bytes = dbm_mpi_sendrecv_byte(
471 17562 : /*sendbuf=*/send_pack->blocks,
472 17562 : /*sendcount=*/send_pack->nblocks * sizeof(dbm_pack_block_t),
473 : /*dest=*/send_rank,
474 : /*sendtag=*/send_ipack,
475 17562 : /*recvbuf=*/packed->recv_pack.blocks,
476 17562 : /*recvcount=*/packed->max_nblocks * sizeof(dbm_pack_block_t),
477 : /*source=*/recv_rank,
478 : /*recvtag=*/recv_ipack,
479 17562 : /*comm=*/packed->dist_ticks->comm);
480 :
481 17562 : assert(nblocks_in_bytes % sizeof(dbm_pack_block_t) == 0);
482 17562 : packed->recv_pack.nblocks = nblocks_in_bytes / sizeof(dbm_pack_block_t);
483 :
484 : // Exchange data.
485 35124 : packed->recv_pack.data_size = dbm_mpi_sendrecv_double(
486 17562 : /*sendbuf=*/send_pack->data,
487 : /*sendcount=*/send_pack->data_size,
488 : /*dest=*/send_rank,
489 : /*sendtag=*/send_ipack,
490 : /*recvbuf=*/packed->recv_pack.data,
491 : /*recvcount=*/packed->max_data_size,
492 : /*source=*/recv_rank,
493 : /*recvtag=*/recv_ipack,
494 17562 : /*comm=*/packed->dist_ticks->comm);
495 :
496 17562 : return &packed->recv_pack;
497 : }
498 : }
499 :
500 : /*******************************************************************************
501 : * \brief Private routine for releasing a packed matrix.
502 : * \author Ole Schuett
503 : ******************************************************************************/
504 416640 : static void free_packed_matrix(dbm_packed_matrix_t *packed) {
505 416640 : dbm_mpi_free_mem(packed->recv_pack.blocks);
506 416640 : dbm_mempool_host_free(packed->recv_pack.data);
507 850842 : for (int ipack = 0; ipack < packed->nsend_packs; ipack++) {
508 434202 : dbm_mpi_free_mem(packed->send_packs[ipack].blocks);
509 434202 : dbm_mempool_host_free(packed->send_packs[ipack].data);
510 : }
511 416640 : free(packed->send_packs);
512 416640 : }
513 :
514 : /*******************************************************************************
515 : * \brief Internal routine for creating a communication iterator.
516 : * \author Ole Schuett
517 : ******************************************************************************/
518 208320 : dbm_comm_iterator_t *dbm_comm_iterator_start(const bool transa,
519 : const bool transb,
520 : const dbm_matrix_t *matrix_a,
521 : const dbm_matrix_t *matrix_b,
522 : const dbm_matrix_t *matrix_c) {
523 :
524 208320 : dbm_comm_iterator_t *iter = malloc(sizeof(dbm_comm_iterator_t));
525 208320 : assert(iter != NULL);
526 208320 : iter->dist = matrix_c->dist;
527 :
528 : // During each communication tick we'll fetch a pack_a and pack_b.
529 : // Since the cart might be non-square, the number of communication ticks is
530 : // chosen as the least common multiple of the cart's dimensions.
531 208320 : iter->nticks = lcm(iter->dist->rows.nranks, iter->dist->cols.nranks);
532 208320 : iter->itick = 0;
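: // Example: on a 2x3 cart nticks = lcm(2, 3) = 6, so pack_matrix creates
: // 6 / 3 = 2 send packs for matrix_a (ticks along columns) and 6 / 2 = 3 send
: // packs for matrix_b (ticks along rows).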
533 :
534 : // 1st arg = source dimension, 2nd arg = target dimension; false = rows, true = columns.
535 208320 : iter->packed_a =
536 208320 : pack_matrix(transa, false, matrix_a, iter->dist, iter->nticks);
537 208320 : iter->packed_b =
538 208320 : pack_matrix(!transb, true, matrix_b, iter->dist, iter->nticks);
539 :
540 208320 : return iter;
541 : }
542 :
543 : /*******************************************************************************
544 : * \brief Internal routine for retrieving the next pair of packs from a given iterator.
545 : * \author Ole Schuett
546 : ******************************************************************************/
547 434202 : bool dbm_comm_iterator_next(dbm_comm_iterator_t *iter, dbm_pack_t **pack_a,
548 : dbm_pack_t **pack_b) {
549 434202 : if (iter->itick >= iter->nticks) {
550 : return false; // end of iterator reached
551 : }
552 :
553 : // Start each rank at a different tick to spread the load on the sources.
554 225882 : const int shift = iter->dist->rows.my_rank + iter->dist->cols.my_rank;
555 225882 : const int shifted_itick = (iter->itick + shift) % iter->nticks;
556 225882 : *pack_a = sendrecv_pack(shifted_itick, iter->nticks, &iter->packed_a);
557 225882 : *pack_b = sendrecv_pack(shifted_itick, iter->nticks, &iter->packed_b);
558 :
559 225882 : iter->itick++;
560 225882 : return true;
561 : }
562 :
563 : /*******************************************************************************
564 : * \brief Internal routine for releasing the given communication iterator.
565 : * \author Ole Schuett
566 : ******************************************************************************/
567 208320 : void dbm_comm_iterator_stop(dbm_comm_iterator_t *iter) {
568 208320 : free_packed_matrix(&iter->packed_a);
569 208320 : free_packed_matrix(&iter->packed_b);
570 208320 : free(iter);
571 208320 : }
572 :
573 : // EOF