Line data Source code
1 : /*----------------------------------------------------------------------------*/
2 : /* CP2K: A general program to perform molecular dynamics simulations */
3 : /* Copyright 2000-2024 CP2K developers group <https://cp2k.org> */
4 : /* */
5 : /* SPDX-License-Identifier: BSD-3-Clause */
6 : /*----------------------------------------------------------------------------*/
7 :
8 : #include "dbm_multiply_comm.h"
9 :
10 : #include <assert.h>
11 : #include <stdlib.h>
12 : #include <string.h>
13 :
14 : #include "dbm_hyperparams.h"
15 : #include "dbm_mempool.h"
16 :
17 : /*******************************************************************************
18 : * \brief Returns the larger of two given integers (missing from the C standard)
19 : * \author Ole Schuett
20 : ******************************************************************************/
21 259162 : static inline int imax(int x, int y) { return (x > y ? x : y); }
22 :
23 : /*******************************************************************************
24 : * \brief Private routine for computing greatest common divisor of two numbers.
25 : * \author Ole Schuett
26 : ******************************************************************************/
27 258810 : static int gcd(const int a, const int b) {
28 258810 : if (a == 0)
29 : return b;
30 132681 : return gcd(b % a, a); // Euclid's algorithm.
31 : }
32 :
33 : /*******************************************************************************
34 : * \brief Private routine for computing least common multiple of two numbers.
35 : * \author Ole Schuett
36 : ******************************************************************************/
37 126129 : static int lcm(const int a, const int b) { return (a * b) / gcd(a, b); }
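
/*******************************************************************************
 * Illustrative example: lcm() determines the number of communication ticks
 * from the cart dimensions (see dbm_comm_iterator_start below). For a
 * hypothetical 2 x 3 process grid:
 *
 *   const int nticks = lcm(2, 3); // gcd(2, 3) == 1, hence nticks == 6
 *
 * so the tick count is divisible by both cart dimensions and every rank row
 * and rank column is visited a whole number of times.
 ******************************************************************************/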
38 :
39 : /*******************************************************************************
40 : * \brief Private routine for computing the sum of the given integers.
41 : * \author Ole Schuett
42 : ******************************************************************************/
43 518324 : static inline int isum(const int n, const int input[n]) {
44 518324 : int output = 0;
45 1078072 : for (int i = 0; i < n; i++) {
46 559748 : output += input[i];
47 : }
48 518324 : return output;
49 : }
50 :
51 : /*******************************************************************************
52 : * \brief Private routine for computing the cumulative sums of given numbers.
53 : * \author Ole Schuett
54 : ******************************************************************************/
55 1295810 : static inline void icumsum(const int n, const int input[n], int output[n]) {
56 1295810 : output[0] = 0;
57 1378658 : for (int i = 1; i < n; i++) {
58 82848 : output[i] = output[i - 1] + input[i - 1];
59 : }
60 1295810 : }
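
/*******************************************************************************
 * Illustrative example with hypothetical values: icumsum() is an exclusive
 * prefix sum, i.e. the count-to-displacement conversion needed for the
 * alltoallv exchanges in pack_matrix below.
 *
 *   const int counts[4] = {3, 1, 0, 2};
 *   int displs[4];
 *   icumsum(4, counts, displs); // displs == {0, 3, 4, 4}
 ******************************************************************************/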
61 :
62 : /*******************************************************************************
63 : * \brief Private struct used for planning during pack_matrix.
64 : * \author Ole Schuett
65 : ******************************************************************************/
66 : typedef struct {
67 : const dbm_block_t *blk; // source block
68 : int rank; // target mpi rank
69 : int row_size;
70 : int col_size;
71 : } plan_t;
72 :
73 : /*******************************************************************************
74 : * \brief Private routine for planing packs.
75 : * \author Ole Schuett
76 : ******************************************************************************/
77 252258 : static void create_pack_plans(const bool trans_matrix, const bool trans_dist,
78 : const dbm_matrix_t *matrix,
79 : const dbm_mpi_comm_t comm,
80 : const dbm_dist_1d_t *dist_indices,
81 : const dbm_dist_1d_t *dist_ticks, const int nticks,
82 : const int npacks, plan_t *plans_per_pack[npacks],
83 : int nblks_per_pack[npacks],
84 : int ndata_per_pack[npacks]) {
85 :
86 252258 : memset(nblks_per_pack, 0, npacks * sizeof(int));
87 252258 : memset(ndata_per_pack, 0, npacks * sizeof(int));
88 :
89 252258 : #pragma omp parallel
90 : {
91 : // 1st pass: Compute number of blocks that will be sent in each pack.
92 : int nblks_mythread[npacks];
93 : memset(nblks_mythread, 0, npacks * sizeof(int));
94 : #pragma omp for schedule(static)
95 : for (int ishard = 0; ishard < dbm_get_num_shards(matrix); ishard++) {
96 : dbm_shard_t *shard = &matrix->shards[ishard];
97 : for (int iblock = 0; iblock < shard->nblocks; iblock++) {
98 : const dbm_block_t *blk = &shard->blocks[iblock];
99 : const int sum_index = (trans_matrix) ? blk->row : blk->col;
100 : const int itick = (1021 * sum_index) % nticks; // 1021 = a random prime
101 : const int ipack = itick / dist_ticks->nranks;
102 : nblks_mythread[ipack]++;
103 : }
104 : }
105 :
106 : // Sum nblocks across threads and allocate arrays for plans.
107 : #pragma omp critical
108 : for (int ipack = 0; ipack < npacks; ipack++) {
109 : nblks_per_pack[ipack] += nblks_mythread[ipack];
110 : nblks_mythread[ipack] = nblks_per_pack[ipack];
111 : }
112 : #pragma omp barrier
113 : #pragma omp for
114 : for (int ipack = 0; ipack < npacks; ipack++) {
115 : plans_per_pack[ipack] = malloc(nblks_per_pack[ipack] * sizeof(plan_t));
116 : }
117 :
118 : // 2nd pass: Plan where to send each block.
119 : int ndata_mythread[npacks];
120 : memset(ndata_mythread, 0, npacks * sizeof(int));
121 : #pragma omp for schedule(static) // Need static to match previous loop.
122 : for (int ishard = 0; ishard < dbm_get_num_shards(matrix); ishard++) {
123 : dbm_shard_t *shard = &matrix->shards[ishard];
124 : for (int iblock = 0; iblock < shard->nblocks; iblock++) {
125 : const dbm_block_t *blk = &shard->blocks[iblock];
126 : const int free_index = (trans_matrix) ? blk->col : blk->row;
127 : const int sum_index = (trans_matrix) ? blk->row : blk->col;
128 : const int itick = (1021 * sum_index) % nticks; // Same mapping as above.
129 : const int ipack = itick / dist_ticks->nranks;
130 : // Compute rank to which this block should be sent.
131 : const int coord_free_idx = dist_indices->index2coord[free_index];
132 : const int coord_sum_idx = itick % dist_ticks->nranks;
133 : const int coords[2] = {(trans_dist) ? coord_sum_idx : coord_free_idx,
134 : (trans_dist) ? coord_free_idx : coord_sum_idx};
135 : const int rank = dbm_mpi_cart_rank(comm, coords);
136 : const int row_size = matrix->row_sizes[blk->row];
137 : const int col_size = matrix->col_sizes[blk->col];
138 : ndata_mythread[ipack] += row_size * col_size;
139 : // Create plan.
140 : const int iplan = --nblks_mythread[ipack];
141 : plans_per_pack[ipack][iplan].blk = blk;
142 : plans_per_pack[ipack][iplan].rank = rank;
143 : plans_per_pack[ipack][iplan].row_size = row_size;
144 : plans_per_pack[ipack][iplan].col_size = col_size;
145 : }
146 : }
147 : #pragma omp critical
148 : for (int ipack = 0; ipack < npacks; ipack++) {
149 : ndata_per_pack[ipack] += ndata_mythread[ipack];
150 : }
151 : } // end of omp parallel region
152 252258 : }
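
/*******************************************************************************
 * Worked example with hypothetical numbers: with nticks == 6 and
 * dist_ticks->nranks == 3, a block with sum_index == 7 is assigned as follows:
 *
 *   itick         = (1021 * 7) % 6 = 1  // pseudo-random but deterministic
 *   ipack         = 1 / 3          = 0  // local send pack it is planned into
 *   coord_sum_idx = 1 % 3          = 1  // cart coordinate along the tick dim
 *
 * All blocks sharing a sum_index therefore land in the same pack and on the
 * same cart coordinate, regardless of which thread planned them.
 ******************************************************************************/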
153 :
154 : /*******************************************************************************
155 : * \brief Private routine for filling send buffers.
156 : * \author Ole Schuett
157 : ******************************************************************************/
158 259162 : static void fill_send_buffers(
159 : const dbm_matrix_t *matrix, const bool trans_matrix, const int nblks_send,
160 : const int ndata_send, plan_t plans[nblks_send], const int nranks,
161 : int blks_send_count[nranks], int data_send_count[nranks],
162 : int blks_send_displ[nranks], int data_send_displ[nranks],
163 : dbm_pack_block_t blks_send[nblks_send], double data_send[ndata_send]) {
164 :
165 259162 : memset(blks_send_count, 0, nranks * sizeof(int));
166 259162 : memset(data_send_count, 0, nranks * sizeof(int));
167 :
168 259162 : #pragma omp parallel
169 : {
170 : // 3rd pass: Compute per-rank nblks and ndata.
171 : int nblks_mythread[nranks], ndata_mythread[nranks];
172 : memset(nblks_mythread, 0, nranks * sizeof(int));
173 : memset(ndata_mythread, 0, nranks * sizeof(int));
174 : #pragma omp for schedule(static)
175 : for (int iblock = 0; iblock < nblks_send; iblock++) {
176 : const plan_t *plan = &plans[iblock];
177 : nblks_mythread[plan->rank] += 1;
178 : ndata_mythread[plan->rank] += plan->row_size * plan->col_size;
179 : }
180 :
181 : // Sum nblks and ndata across threads.
182 : #pragma omp critical
183 : for (int irank = 0; irank < nranks; irank++) {
184 : blks_send_count[irank] += nblks_mythread[irank];
185 : data_send_count[irank] += ndata_mythread[irank];
186 : nblks_mythread[irank] = blks_send_count[irank];
187 : ndata_mythread[irank] = data_send_count[irank];
188 : }
189 : #pragma omp barrier
190 :
191 : // Compute send displacements.
192 : #pragma omp master
193 : {
194 : icumsum(nranks, blks_send_count, blks_send_displ);
195 : icumsum(nranks, data_send_count, data_send_displ);
196 : const int m = nranks - 1;
197 : assert(nblks_send == blks_send_displ[m] + blks_send_count[m]);
198 : assert(ndata_send == data_send_displ[m] + data_send_count[m]);
199 : }
200 : #pragma omp barrier
201 :
202 : // 4th pass: Fill blks_send and data_send arrays.
203 : #pragma omp for schedule(static) // Need static to match previous loop.
204 : for (int iblock = 0; iblock < nblks_send; iblock++) {
205 : const plan_t *plan = &plans[iblock];
206 : const dbm_block_t *blk = plan->blk;
207 : const int ishard = dbm_get_shard_index(matrix, blk->row, blk->col);
208 : const dbm_shard_t *shard = &matrix->shards[ishard];
209 : const double *blk_data = &shard->data[blk->offset];
210 : const int row_size = plan->row_size, col_size = plan->col_size;
211 : const int irank = plan->rank;
212 :
213 : // The data_send buffer is ordered by rank, thread, and block.
214 : // data_send_displ[irank]: Start of the data for irank within data_send.
215 : // ndata_mythread[irank]: Current thread's offset within the data for irank.
216 : nblks_mythread[irank] -= 1;
217 : ndata_mythread[irank] -= row_size * col_size;
218 : const int offset = data_send_displ[irank] + ndata_mythread[irank];
219 : const int jblock = blks_send_displ[irank] + nblks_mythread[irank];
220 :
221 : double norm = 0.0; // Compute norm as double...
222 : if (trans_matrix) {
223 : // Transpose block to allow for outer-product style multiplication.
224 : for (int i = 0; i < row_size; i++) {
225 : for (int j = 0; j < col_size; j++) {
226 : const double element = blk_data[j * row_size + i];
227 : norm += element * element;
228 : data_send[offset + i * col_size + j] = element;
229 : }
230 : }
231 : blks_send[jblock].free_index = plan->blk->col;
232 : blks_send[jblock].sum_index = plan->blk->row;
233 : } else {
234 : for (int i = 0; i < row_size * col_size; i++) {
235 : const double element = blk_data[i];
236 : norm += element * element;
237 : data_send[offset + i] = element;
238 : }
239 : blks_send[jblock].free_index = plan->blk->row;
240 : blks_send[jblock].sum_index = plan->blk->col;
241 : }
242 : blks_send[jblock].norm = (float)norm; // ...store norm as float.
243 :
244 : // After the block exchange data_recv_displ will be added to the offsets.
245 : blks_send[jblock].offset = offset - data_send_displ[irank];
246 : }
247 : } // end of omp parallel region
248 259162 : }
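
/*******************************************************************************
 * A minimal sketch of the counting-sort idiom shared by create_pack_plans,
 * fill_send_buffers, and postprocess_received_blocks, assuming a hypothetical
 * key() function and hypothetical buffers:
 *
 *   int count_mythread[nbuckets];      // per-thread histogram
 *   memset(count_mythread, 0, nbuckets * sizeof(int));
 *   #pragma omp for schedule(static)   // pass 1: count items per bucket
 *   for (int i = 0; i < n; i++)
 *     count_mythread[key(i)]++;
 *   #pragma omp critical               // accumulate, keep own end cursor
 *   for (int b = 0; b < nbuckets; b++) {
 *     count_global[b] += count_mythread[b];
 *     count_mythread[b] = count_global[b];
 *   }
 *   #pragma omp barrier                // displ[] = icumsum of count_global
 *   #pragma omp for schedule(static)   // pass 2: identical iteration split,
 *   for (int i = 0; i < n; i++)        //         fill buckets back-to-front
 *     out[displ[key(i)] + --count_mythread[key(i)]] = in[i];
 *
 * The static schedule guarantees each thread revisits exactly the items it
 * counted, so the decrementing per-thread cursors never collide.
 ******************************************************************************/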
249 :
250 : /*******************************************************************************
251 : * \brief Private comparator passed to qsort to compare two blocks by sum_index.
252 : * \author Ole Schuett
253 : ******************************************************************************/
254 59096854 : static int compare_pack_blocks_by_sum_index(const void *a, const void *b) {
255 59096854 : const dbm_pack_block_t *blk_a = (const dbm_pack_block_t *)a;
256 59096854 : const dbm_pack_block_t *blk_b = (const dbm_pack_block_t *)b;
257 59096854 : return blk_a->sum_index - blk_b->sum_index;
258 : }
259 :
260 : /*******************************************************************************
261 : * \brief Private routine for post-processing received blocks.
262 : * \author Ole Schuett
263 : ******************************************************************************/
264 259162 : static void postprocess_received_blocks(
265 : const int nranks, const int nshards, const int nblocks_recv,
266 : const int blks_recv_count[nranks], const int blks_recv_displ[nranks],
267 : const int data_recv_displ[nranks],
268 259162 : dbm_pack_block_t blks_recv[nblocks_recv]) {
269 :
270 259162 : int nblocks_per_shard[nshards], shard_start[nshards];
271 259162 : memset(nblocks_per_shard, 0, nshards * sizeof(int));
272 259162 : dbm_pack_block_t *blocks_tmp =
273 259162 : malloc(nblocks_recv * sizeof(dbm_pack_block_t));
274 :
275 259162 : #pragma omp parallel
276 : {
277 : // Add data_recv_displ to the received block offsets.
278 : for (int irank = 0; irank < nranks; irank++) {
279 : #pragma omp for
280 : for (int i = 0; i < blks_recv_count[irank]; i++) {
281 : blks_recv[blks_recv_displ[irank] + i].offset += data_recv_displ[irank];
282 : }
283 : }
284 :
285 : // First use counting sort to group blocks by their free_index shard.
286 : int nblocks_mythread[nshards];
287 : memset(nblocks_mythread, 0, nshards * sizeof(int));
288 : #pragma omp for schedule(static)
289 : for (int iblock = 0; iblock < nblocks_recv; iblock++) {
290 : blocks_tmp[iblock] = blks_recv[iblock];
291 : const int ishard = blks_recv[iblock].free_index % nshards;
292 : nblocks_mythread[ishard]++;
293 : }
294 : #pragma omp critical
295 : for (int ishard = 0; ishard < nshards; ishard++) {
296 : nblocks_per_shard[ishard] += nblocks_mythread[ishard];
297 : nblocks_mythread[ishard] = nblocks_per_shard[ishard];
298 : }
299 : #pragma omp barrier
300 : #pragma omp master
301 : icumsum(nshards, nblocks_per_shard, shard_start);
302 : #pragma omp barrier
303 : #pragma omp for schedule(static) // Need static to match previous loop.
304 : for (int iblock = 0; iblock < nblocks_recv; iblock++) {
305 : const int ishard = blocks_tmp[iblock].free_index % nshards;
306 : const int jblock = --nblocks_mythread[ishard] + shard_start[ishard];
307 : blks_recv[jblock] = blocks_tmp[iblock];
308 : }
309 :
310 : // Then sort blocks within each shard by their sum_index.
311 : #pragma omp for
312 : for (int ishard = 0; ishard < nshards; ishard++) {
313 : if (nblocks_per_shard[ishard] > 1) {
314 : qsort(&blks_recv[shard_start[ishard]], nblocks_per_shard[ishard],
315 : sizeof(dbm_pack_block_t), &compare_pack_blocks_by_sum_index);
316 : }
317 : }
318 : } // end of omp parallel region
319 :
320 259162 : free(blocks_tmp);
321 259162 : }
322 :
323 : /*******************************************************************************
324 : * \brief Private routine for redistributing a matrix along selected dimensions.
325 : * \author Ole Schuett
326 : ******************************************************************************/
327 252258 : static dbm_packed_matrix_t pack_matrix(const bool trans_matrix,
328 : const bool trans_dist,
329 : const dbm_matrix_t *matrix,
330 : const dbm_distribution_t *dist,
331 252258 : const int nticks) {
332 :
333 252258 : assert(dbm_mpi_comms_are_similar(matrix->dist->comm, dist->comm));
334 :
335 : // The row/col indices are distributed along one cart dimension and the
336 : // ticks are distributed along the other cart dimension.
337 252258 : const dbm_dist_1d_t *dist_indices = (trans_dist) ? &dist->cols : &dist->rows;
338 252258 : const dbm_dist_1d_t *dist_ticks = (trans_dist) ? &dist->rows : &dist->cols;
339 :
340 : // Allocate packed matrix.
341 252258 : const int nsend_packs = nticks / dist_ticks->nranks;
342 252258 : assert(nsend_packs * dist_ticks->nranks == nticks);
343 252258 : dbm_packed_matrix_t packed;
344 252258 : packed.dist_indices = dist_indices;
345 252258 : packed.dist_ticks = dist_ticks;
346 252258 : packed.nsend_packs = nsend_packs;
347 252258 : packed.send_packs = malloc(nsend_packs * sizeof(dbm_pack_t));
348 :
349 : // Plan all packs.
350 252258 : plan_t *plans_per_pack[nsend_packs];
351 252258 : int nblks_send_per_pack[nsend_packs], ndata_send_per_pack[nsend_packs];
352 252258 : create_pack_plans(trans_matrix, trans_dist, matrix, dist->comm, dist_indices,
353 : dist_ticks, nticks, nsend_packs, plans_per_pack,
354 : nblks_send_per_pack, ndata_send_per_pack);
355 :
356 : // Cannot parallelize over packs because there might be too few of them.
357 511420 : for (int ipack = 0; ipack < nsend_packs; ipack++) {
358 : // Allocate send buffers.
359 259162 : dbm_pack_block_t *blks_send =
360 259162 : malloc(nblks_send_per_pack[ipack] * sizeof(dbm_pack_block_t));
361 259162 : double *data_send =
362 259162 : dbm_mempool_host_malloc(ndata_send_per_pack[ipack] * sizeof(double));
363 :
364 : // Fill send buffers according to plans.
365 259162 : const int nranks = dist->nranks;
366 259162 : int blks_send_count[nranks], data_send_count[nranks];
367 259162 : int blks_send_displ[nranks], data_send_displ[nranks];
368 259162 : fill_send_buffers(matrix, trans_matrix, nblks_send_per_pack[ipack],
369 : ndata_send_per_pack[ipack], plans_per_pack[ipack], nranks,
370 : blks_send_count, data_send_count, blks_send_displ,
371 : data_send_displ, blks_send, data_send);
372 259162 : free(plans_per_pack[ipack]);
373 :
374 : // 1st communication: Exchange block counts.
375 259162 : int blks_recv_count[nranks], blks_recv_displ[nranks];
376 259162 : dbm_mpi_alltoall_int(blks_send_count, 1, blks_recv_count, 1, dist->comm);
377 259162 : icumsum(nranks, blks_recv_count, blks_recv_displ);
378 259162 : const int nblocks_recv = isum(nranks, blks_recv_count);
379 :
380 : // 2nd communication: Exchange blocks.
381 259162 : dbm_pack_block_t *blks_recv =
382 259162 : malloc(nblocks_recv * sizeof(dbm_pack_block_t));
383 259162 : int blks_send_count_byte[nranks], blks_send_displ_byte[nranks];
384 259162 : int blks_recv_count_byte[nranks], blks_recv_displ_byte[nranks];
385 539036 : for (int i = 0; i < nranks; i++) { // TODO: this is ugly!
386 279874 : blks_send_count_byte[i] = blks_send_count[i] * sizeof(dbm_pack_block_t);
387 279874 : blks_send_displ_byte[i] = blks_send_displ[i] * sizeof(dbm_pack_block_t);
388 279874 : blks_recv_count_byte[i] = blks_recv_count[i] * sizeof(dbm_pack_block_t);
389 279874 : blks_recv_displ_byte[i] = blks_recv_displ[i] * sizeof(dbm_pack_block_t);
390 : }
391 259162 : dbm_mpi_alltoallv_byte(
392 : blks_send, blks_send_count_byte, blks_send_displ_byte, blks_recv,
393 259162 : blks_recv_count_byte, blks_recv_displ_byte, dist->comm);
394 259162 : free(blks_send);
395 :
396 : // 3rd communication: Exchange data counts.
397 : // TODO: could be computed from blks_recv.
398 259162 : int data_recv_count[nranks], data_recv_displ[nranks];
399 259162 : dbm_mpi_alltoall_int(data_send_count, 1, data_recv_count, 1, dist->comm);
400 259162 : icumsum(nranks, data_recv_count, data_recv_displ);
401 259162 : const int ndata_recv = isum(nranks, data_recv_count);
402 :
403 : // 4th communication: Exchange data.
404 259162 : double *data_recv = dbm_mempool_host_malloc(ndata_recv * sizeof(double));
405 259162 : dbm_mpi_alltoallv_double(data_send, data_send_count, data_send_displ,
406 : data_recv, data_recv_count, data_recv_displ,
407 259162 : dist->comm);
408 259162 : dbm_mempool_free(data_send);
409 :
410 : // Post-process received blocks and assemble them into a pack.
411 259162 : postprocess_received_blocks(nranks, dist_indices->nshards, nblocks_recv,
412 : blks_recv_count, blks_recv_displ,
413 : data_recv_displ, blks_recv);
414 259162 : packed.send_packs[ipack].nblocks = nblocks_recv;
415 259162 : packed.send_packs[ipack].data_size = ndata_recv;
416 259162 : packed.send_packs[ipack].blocks = blks_recv;
417 259162 : packed.send_packs[ipack].data = data_recv;
418 : }
419 :
420 : // Allocate pack_recv.
421 252258 : int max_nblocks = 0, max_data_size = 0;
422 511420 : for (int ipack = 0; ipack < packed.nsend_packs; ipack++) {
423 259162 : max_nblocks = imax(max_nblocks, packed.send_packs[ipack].nblocks);
424 259162 : max_data_size = imax(max_data_size, packed.send_packs[ipack].data_size);
425 : }
426 252258 : dbm_mpi_max_int(&max_nblocks, 1, packed.dist_ticks->comm);
427 252258 : dbm_mpi_max_int(&max_data_size, 1, packed.dist_ticks->comm);
428 252258 : packed.max_nblocks = max_nblocks;
429 252258 : packed.max_data_size = max_data_size;
430 252258 : packed.recv_pack.blocks =
431 252258 : malloc(packed.max_nblocks * sizeof(dbm_pack_block_t));
432 504516 : packed.recv_pack.data =
433 252258 : dbm_mempool_host_malloc(packed.max_data_size * sizeof(double));
434 :
435 252258 : return packed; // Ownership of packed transfers to caller.
436 : }
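
/*******************************************************************************
 * A minimal sketch, with hypothetical buffers, of the variable-sized
 * all-to-all idiom behind the four communications above, reduced to a single
 * double-precision payload:
 *
 *   int send_count[nranks], send_displ[nranks];  // filled while packing
 *   int recv_count[nranks], recv_displ[nranks];
 *   dbm_mpi_alltoall_int(send_count, 1, recv_count, 1, comm);  // counts
 *   icumsum(nranks, recv_count, recv_displ);                   // displacements
 *   double *recv = dbm_mempool_host_malloc(isum(nranks, recv_count) *
 *                                          sizeof(double));
 *   dbm_mpi_alltoallv_double(send, send_count, send_displ,     // payload
 *                            recv, recv_count, recv_displ, comm);
 *
 * pack_matrix runs this sequence twice per pack: once for the block metadata
 * (via dbm_mpi_alltoallv_byte) and once for the block data itself.
 ******************************************************************************/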
437 :
438 : /*******************************************************************************
439 : * \brief Private routine for sending and receiving the pack for the given tick.
440 : * \author Ole Schuett
441 : ******************************************************************************/
442 266066 : static dbm_pack_t *sendrecv_pack(const int itick, const int nticks,
443 : dbm_packed_matrix_t *packed) {
444 266066 : const int nranks = packed->dist_ticks->nranks;
445 266066 : const int my_rank = packed->dist_ticks->my_rank;
446 :
447 : // Compute send rank and pack.
448 266066 : const int itick_of_rank0 = (itick + nticks - my_rank) % nticks;
449 266066 : const int send_rank = (my_rank + nticks - itick_of_rank0) % nranks;
450 266066 : const int send_itick = (itick_of_rank0 + send_rank) % nticks;
451 266066 : const int send_ipack = send_itick / nranks;
452 266066 : assert(send_itick % nranks == my_rank);
453 :
454 : // Compute receive rank and pack.
455 266066 : const int recv_rank = itick % nranks;
456 266066 : const int recv_ipack = itick / nranks;
457 :
458 266066 : if (send_rank == my_rank) {
459 259162 : assert(send_rank == recv_rank && send_ipack == recv_ipack);
460 259162 : return &packed->send_packs[send_ipack]; // Local pack, no mpi needed.
461 : } else {
462 6904 : const dbm_pack_t *send_pack = &packed->send_packs[send_ipack];
463 :
464 : // Exchange blocks.
465 13808 : const int nblocks_in_bytes = dbm_mpi_sendrecv_byte(
466 6904 : /*sendbuf=*/send_pack->blocks,
467 6904 : /*sendcount=*/send_pack->nblocks * sizeof(dbm_pack_block_t),
468 : /*dest=*/send_rank,
469 : /*sendtag=*/send_ipack,
470 6904 : /*recvbuf=*/packed->recv_pack.blocks,
471 6904 : /*recvcount=*/packed->max_nblocks * sizeof(dbm_pack_block_t),
472 : /*source=*/recv_rank,
473 : /*recvtag=*/recv_ipack,
474 6904 : /*comm=*/packed->dist_ticks->comm);
475 :
476 6904 : assert(nblocks_in_bytes % sizeof(dbm_pack_block_t) == 0);
477 6904 : packed->recv_pack.nblocks = nblocks_in_bytes / sizeof(dbm_pack_block_t);
478 :
479 : // Exchange data.
480 13808 : packed->recv_pack.data_size = dbm_mpi_sendrecv_double(
481 6904 : /*sendbuf=*/send_pack->data,
482 6904 : /*sendcount=*/send_pack->data_size,
483 : /*dest=*/send_rank,
484 : /*sendtag=*/send_ipack,
485 : /*recvbuf=*/packed->recv_pack.data,
486 : /*recvcount=*/packed->max_data_size,
487 : /*source=*/recv_rank,
488 : /*recvtag=*/recv_ipack,
489 6904 : /*comm=*/packed->dist_ticks->comm);
490 :
491 6904 : return &packed->recv_pack;
492 : }
493 : }
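
/*******************************************************************************
 * Worked trace with hypothetical values: nticks == 6, nranks == 2,
 * my_rank == 0, itick == 3.
 *
 *   itick_of_rank0 = (3 + 6 - 0) % 6 = 3
 *   send_rank      = (0 + 6 - 3) % 2 = 1
 *   send_itick     = (3 + 1) % 6     = 4   // 4 % 2 == my_rank, assert holds
 *   send_ipack     = 4 / 2           = 2
 *   recv_rank      = 3 % 2           = 1
 *   recv_ipack     = 3 / 2           = 1
 *
 * Since send_rank != my_rank, local send pack 2 goes to rank 1 and a pack is
 * received from rank 1 into recv_pack; when send_rank == my_rank the local
 * pack is returned directly and no MPI call is made.
 ******************************************************************************/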
494 :
495 : /*******************************************************************************
496 : * \brief Private routine for releasing a packed matrix.
497 : * \author Ole Schuett
498 : ******************************************************************************/
499 252258 : static void free_packed_matrix(dbm_packed_matrix_t *packed) {
500 252258 : free(packed->recv_pack.blocks);
501 252258 : dbm_mempool_free(packed->recv_pack.data);
502 511420 : for (int ipack = 0; ipack < packed->nsend_packs; ipack++) {
503 259162 : free(packed->send_packs[ipack].blocks);
504 259162 : dbm_mempool_free(packed->send_packs[ipack].data);
505 : }
506 252258 : free(packed->send_packs);
507 252258 : }
508 :
509 : /*******************************************************************************
510 : * \brief Internal routine for creating a communication iterator.
511 : * \author Ole Schuett
512 : ******************************************************************************/
513 126129 : dbm_comm_iterator_t *dbm_comm_iterator_start(const bool transa,
514 : const bool transb,
515 : const dbm_matrix_t *matrix_a,
516 : const dbm_matrix_t *matrix_b,
517 : const dbm_matrix_t *matrix_c) {
518 :
519 126129 : dbm_comm_iterator_t *iter = malloc(sizeof(dbm_comm_iterator_t));
520 126129 : iter->dist = matrix_c->dist;
521 :
522 : // During each communication tick we'll fetch a pack_a and pack_b.
523 : // Since the cart might be non-square, the number of communication ticks is
524 : // chosen as the least common multiple of the cart's dimensions.
525 126129 : iter->nticks = lcm(iter->dist->rows.nranks, iter->dist->cols.nranks);
526 126129 : iter->itick = 0;
527 :
528 : // 1st arg: source dimension, 2nd arg: target dimension; false=rows, true=columns.
529 126129 : iter->packed_a =
530 126129 : pack_matrix(transa, false, matrix_a, iter->dist, iter->nticks);
531 126129 : iter->packed_b =
532 126129 : pack_matrix(!transb, true, matrix_b, iter->dist, iter->nticks);
533 :
534 126129 : return iter;
535 : }
536 :
537 : /*******************************************************************************
538 : * \brief Internal routine for retrieving the next pair of packs from a given iterator.
539 : * \author Ole Schuett
540 : ******************************************************************************/
541 259162 : bool dbm_comm_iterator_next(dbm_comm_iterator_t *iter, dbm_pack_t **pack_a,
542 : dbm_pack_t **pack_b) {
543 259162 : if (iter->itick >= iter->nticks) {
544 : return false; // end of iterator reached
545 : }
546 :
547 : // Start each rank at a different tick to spread the load on the sources.
548 133033 : const int shift = iter->dist->rows.my_rank + iter->dist->cols.my_rank;
549 133033 : const int shifted_itick = (iter->itick + shift) % iter->nticks;
550 133033 : *pack_a = sendrecv_pack(shifted_itick, iter->nticks, &iter->packed_a);
551 133033 : *pack_b = sendrecv_pack(shifted_itick, iter->nticks, &iter->packed_b);
552 :
553 133033 : iter->itick++;
554 133033 : return true;
555 : }
556 :
557 : /*******************************************************************************
558 : * \brief Internal routine for releasing the given communication iterator.
559 : * \author Ole Schuett
560 : ******************************************************************************/
561 126129 : void dbm_comm_iterator_stop(dbm_comm_iterator_t *iter) {
562 126129 : free_packed_matrix(&iter->packed_a);
563 126129 : free_packed_matrix(&iter->packed_b);
564 126129 : free(iter);
565 126129 : }
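
/*******************************************************************************
 * Usage sketch: the intended call pattern for the iterator as seen from a
 * multiplication driver; multiply_packs is a hypothetical stand-in for the
 * local block multiplications.
 *
 *   dbm_comm_iterator_t *iter =
 *       dbm_comm_iterator_start(transa, transb, matrix_a, matrix_b, matrix_c);
 *   dbm_pack_t *pack_a, *pack_b;
 *   while (dbm_comm_iterator_next(iter, &pack_a, &pack_b)) {
 *     multiply_packs(pack_a, pack_b, matrix_c); // hypothetical local work
 *   }
 *   dbm_comm_iterator_stop(iter);
 ******************************************************************************/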
566 :
567 : // EOF