LCOV - code coverage report
Current view: top level - src/dbm - dbm_multiply_comm.c (source / functions) Coverage Total Hit
Test: CP2K Regtests (git:c24029e) Lines: 100.0 % 188 188
Test Date: 2026-07-04 06:36:57 Functions: 100.0 % 14 14

            Line data    Source code
       1              : /*----------------------------------------------------------------------------*/
       2              : /*  CP2K: A general program to perform molecular dynamics simulations         */
       3              : /*  Copyright 2000-2026 CP2K developers group <https://cp2k.org>              */
       4              : /*                                                                            */
       5              : /*  SPDX-License-Identifier: BSD-3-Clause                                     */
       6              : /*----------------------------------------------------------------------------*/
       7              : #include "dbm_multiply_comm.h"
       8              : #include "../mpiwrap/cp_mpi.h"
       9              : #include "../offload/offload_mempool.h"
      10              : 
      11              : #include <assert.h>
      12              : #include <limits.h>
      13              : #include <stdlib.h>
      14              : #include <string.h>
      15              : 
// Compile-time switch: when DBM_MULTIPLY_COMM_MEMPOOL is defined, pack_matrix
// takes the receive buffer for block data from offload_mempool_host_malloc
// instead of cp_mpi_alloc_mem. Flip the condition to "#if 0" to disable it.
#if 1
#define DBM_MULTIPLY_COMM_MEMPOOL
#endif
      19              : 
      20              : /*******************************************************************************
      21              :  * \brief Private routine for computing greatest common divisor of two numbers.
      22              :  * \author Ole Schuett
      23              :  ******************************************************************************/
      24       506318 : static int gcd(const int a, const int b) {
      25       506318 :   if (a == 0) {
      26              :     return b;
      27              :   }
      28       263828 :   return gcd(b % a, a); // Euclid's algorithm.
      29              : }
      30              : 
      31              : /*******************************************************************************
      32              :  * \brief Private routine for computing least common multiple of two numbers.
      33              :  * \author Ole Schuett
      34              :  ******************************************************************************/
      35       242490 : static int lcm(const int a, const int b) { return (a * b) / gcd(a, b); }
      36              : 
      37              : /*******************************************************************************
      38              :  * \brief Private routine for converting element counts to byte counts.
      39              :  * \author Hans Pabst
      40              :  ******************************************************************************/
      41      2325228 : static int checked_byte_count(const int nelements, const size_t element_size) {
      42              :   // MPI displacements/counts are int; unchecked multiplication could overflow
      43              :   // silently and corrupt the alltoallv layout.
      44      2325228 :   assert(0 <= nelements);
      45      2325228 :   assert(element_size <= INT_MAX);
      46      2325228 :   assert(nelements <= INT_MAX / (int)element_size);
      47      2325228 :   return nelements * (int)element_size;
      48              : }
      49              : 
      50              : /*******************************************************************************
      51              :  * \brief Private routine for computing the sum of the given integers.
      52              :  * \author Ole Schuett
      53              :  ******************************************************************************/
      54              : static inline int isum(const int n, const int input[n]) {
      55              :   int output = 0;
      56      3295188 :   for (int i = 0; i < n; i++) {
      57      1141208 :     output += input[i];
      58              :   }
      59      1012772 :   return output;
      60              : }
      61              : 
      62              : /*******************************************************************************
      63              :  * \brief Private routine for computing the cumulative sums of given numbers.
      64              :  * \author Ole Schuett and Hans Pabst
      65              :  ******************************************************************************/
      66      1012772 : static inline void icumsum(const int n, const int input[n], int output[n]) {
      67              :   // Exclusive prefix sum using carried accumulators to avoid a load from
      68              :   // output[i-1] each iteration (removes loop-carried memory dependency).
      69      1012772 :   int oval = output[0] = 0, ival = input[0];
      70      1141208 :   for (int i = 1; i < n; i++) {
      71       128436 :     output[i] = (oval += ival);
      72       128436 :     ival = input[i];
      73              :   }
      74              : }
      75              : 
      76              : /*******************************************************************************
      77              :  * \brief Private routine computing received data counts from block metadata.
      78              :  * \author Hans Pabst
      79              :  ******************************************************************************/
      80       506386 : static void compute_data_recv_count(const int nranks,
      81              :                                     const int blks_recv_count[nranks],
      82              :                                     const int blks_recv_displ[nranks],
      83              :                                     const int free_index_sizes[],
      84              :                                     const int sum_index_sizes[],
      85              :                                     const dbm_pack_block_t blks_recv[],
      86              :                                     int data_recv_count[nranks]) {
      87              :   // Derives per-rank data counts locally from the already-received block
      88              :   // metadata, eliminating a collective alltoall for exchanging data amounts.
      89       506386 :   memset(data_recv_count, 0, nranks * sizeof(int));
      90      1076990 :   for (int irank = 0; irank < nranks; irank++) {
      91     14392248 :     for (int i = 0; i < blks_recv_count[irank]; i++) {
      92     13821644 :       const dbm_pack_block_t *const blk =
      93     13821644 :           &blks_recv[blks_recv_displ[irank] + i];
      94     13821644 :       const int block_size =
      95     13821644 :           free_index_sizes[blk->free_index] * sum_index_sizes[blk->sum_index];
      96     13821644 :       assert(block_size >= 0);
      97     13821644 :       assert(data_recv_count[irank] <= INT_MAX - block_size);
      98     13821644 :       data_recv_count[irank] += block_size;
      99              :     }
     100              :   }
     101       506386 : }
     102              : 
/*******************************************************************************
 * \brief Private struct used for planning during pack_matrix.
 *        One entry describes a single source block, the rank it is sent to,
 *        and the block's dimensions.
 * \author Ole Schuett
 ******************************************************************************/
typedef struct {
  const dbm_block_t *blk; // source block
  int rank;               // target mpi rank
  int row_size;           // matrix->row_sizes[blk->row] of the source block
  int col_size;           // matrix->col_sizes[blk->col] of the source block
} plan_t;
     113              : 
     114              : /*******************************************************************************
     115              :  * \brief Private routine for calculating tick indices in pack plans.
     116              :  * \author Maximilian Graml
     117              :  ******************************************************************************/
     118              : static inline unsigned long long calculate_tick_index(int sum_index,
     119              :                                                       int nticks) {
     120              :   // 1021 is used as a random prime to scramble the index
     121              :   return ((unsigned long long)sum_index * 1021ULL) % (unsigned long long)nticks;
     122              : }
     123              : 
/*******************************************************************************
 * \brief Private routine for planning packs.
 *
 *        Walks over all blocks of all shards twice. The first pass counts how
 *        many blocks fall into each pack; the second pass records, for every
 *        block, the target rank and the block dimensions in a plan entry.
 *
 * \param trans_matrix   Selects which block index is the summation index:
 *                       blk->row if true, blk->col otherwise. The other index
 *                       is the free index.
 * \param trans_dist     Selects the order of the two cart coordinates that are
 *                       passed to cp_mpi_cart_rank.
 * \param matrix         Matrix whose shards and blocks are planned.
 * \param comm           Communicator used to translate coordinates to a rank.
 * \param dist_indices   Provides index2coord for the free index.
 * \param dist_ticks     Its nranks splits a tick into pack (tick / nranks) and
 *                       coordinate (tick % nranks).
 * \param nticks         Modulus used by calculate_tick_index.
 * \param npacks         Number of packs, i.e. length of the output arrays.
 * \param plans_per_pack Output: per pack a malloc'ed array of plans. The
 *                       arrays are not freed here; that is left to the caller.
 * \param nblks_per_pack Output: number of plans (blocks) in each pack.
 * \param ndata_per_pack Output: sum of row_size * col_size over each pack.
 * \author Ole Schuett
 ******************************************************************************/
static void create_pack_plans(const bool trans_matrix, const bool trans_dist,
                              const dbm_matrix_t *matrix,
                              const cp_mpi_comm_t comm,
                              const dbm_dist_1d_t *dist_indices,
                              const dbm_dist_1d_t *dist_ticks, const int nticks,
                              const int npacks, plan_t *plans_per_pack[npacks],
                              int nblks_per_pack[npacks],
                              int ndata_per_pack[npacks]) {
  memset(nblks_per_pack, 0, npacks * sizeof(int));
  memset(ndata_per_pack, 0, npacks * sizeof(int));

#pragma omp parallel
  {
    // 1st pass: Compute number of blocks that will be sent in each pack.
    // Every thread counts into its own private array first.
    int nblks_mythread[npacks];
    memset(nblks_mythread, 0, npacks * sizeof(int));
#pragma omp for schedule(static)
    for (int ishard = 0; ishard < dbm_get_num_shards(matrix); ishard++) {
      dbm_shard_t *shard = &matrix->shards[ishard];
      for (int iblock = 0; iblock < shard->nblocks; iblock++) {
        const dbm_block_t *blk = &shard->blocks[iblock];
        const int sum_index = (trans_matrix) ? blk->row : blk->col;
        unsigned long long itick64 = calculate_tick_index(sum_index, nticks);
        const int ipack = itick64 / dist_ticks->nranks;
        nblks_mythread[ipack]++;
      }
    }

    // Sum nblocks across threads and allocate arrays for plans.
    // Inside the critical section each thread adds its counts to the shared
    // totals and then keeps the running total. That value is the exclusive
    // end of the slot range this thread will fill in the 2nd pass.
#pragma omp critical
    for (int ipack = 0; ipack < npacks; ipack++) {
      nblks_per_pack[ipack] += nblks_mythread[ipack];
      nblks_mythread[ipack] = nblks_per_pack[ipack];
    }
    // All threads must have contributed before the totals are used as sizes.
#pragma omp barrier
#pragma omp for
    for (int ipack = 0; ipack < npacks; ipack++) {
      const int nblks = nblks_per_pack[ipack];
      plans_per_pack[ipack] = malloc(nblks * sizeof(plan_t));
      // A NULL result is only tolerated for an empty pack.
      assert(plans_per_pack[ipack] != NULL || nblks == 0);
    }

    // 2nd pass: Plan where to send each block.
    int ndata_mythread[npacks];
    memset(ndata_mythread, 0, npacks * sizeof(int));
#pragma omp for schedule(static) // Need static to match previous loop.
    for (int ishard = 0; ishard < dbm_get_num_shards(matrix); ishard++) {
      dbm_shard_t *shard = &matrix->shards[ishard];
      for (int iblock = 0; iblock < shard->nblocks; iblock++) {
        const dbm_block_t *blk = &shard->blocks[iblock];
        const int free_index = (trans_matrix) ? blk->col : blk->row;
        const int sum_index = (trans_matrix) ? blk->row : blk->col;
        unsigned long long itick64 = calculate_tick_index(sum_index, nticks);
        const int ipack = itick64 / dist_ticks->nranks;
        // Compute rank to which this block should be sent.
        const int coord_free_idx = dist_indices->index2coord[free_index];
        const int coord_sum_idx = itick64 % dist_ticks->nranks;
        const int coords[2] = {(trans_dist) ? coord_sum_idx : coord_free_idx,
                               (trans_dist) ? coord_free_idx : coord_sum_idx};
        const int rank = cp_mpi_cart_rank(comm, coords);
        const int row_size = matrix->row_sizes[blk->row];
        const int col_size = matrix->col_sizes[blk->col];
        ndata_mythread[ipack] += row_size * col_size;
        // Create plan.
        // The thread's slots are consumed from the end of its range downwards.
        const int iplan = --nblks_mythread[ipack];
        plans_per_pack[ipack][iplan].blk = blk;
        plans_per_pack[ipack][iplan].rank = rank;
        plans_per_pack[ipack][iplan].row_size = row_size;
        plans_per_pack[ipack][iplan].col_size = col_size;
      }
    }
    // Merge the per-thread data sizes into the shared per-pack totals.
#pragma omp critical
    for (int ipack = 0; ipack < npacks; ipack++) {
      ndata_per_pack[ipack] += ndata_mythread[ipack];
    }
  } // end of omp parallel region
}
     205              : 
/*******************************************************************************
 * \brief Private routine for filling send buffers.
 *
 *        Counts blocks and data elements per target rank, derives the send
 *        displacements, and then copies every planned block (transposed if
 *        requested) together with its metadata into the send buffers.
 *
 * \param matrix          Matrix that owns the block data referenced by plans.
 * \param trans_matrix    If true, blocks are transposed while being copied and
 *                        free/sum index are taken from col/row instead of
 *                        row/col.
 * \param nblks_send      Number of plans, i.e. blocks to send.
 * \param ndata_send      Total number of data elements to send.
 * \param plans           Plans describing the blocks and their target ranks.
 * \param nranks          Number of ranks, i.e. length of count/displ arrays.
 * \param blks_send_count Output: number of blocks sent to each rank.
 * \param data_send_count Output: number of data elements sent to each rank.
 * \param blks_send_displ Output: exclusive prefix sums of blks_send_count.
 * \param data_send_displ Output: exclusive prefix sums of data_send_count.
 * \param blks_send       Output: block metadata, grouped by target rank.
 * \param data_send       Output: block data, grouped by target rank.
 * \author Ole Schuett
 ******************************************************************************/
static void fill_send_buffers(
    const dbm_matrix_t *matrix, const bool trans_matrix, const int nblks_send,
    const int ndata_send, plan_t plans[nblks_send], const int nranks,
    int blks_send_count[nranks], int data_send_count[nranks],
    int blks_send_displ[nranks], int data_send_displ[nranks],
    dbm_pack_block_t blks_send[nblks_send], double data_send[ndata_send]) {
  memset(blks_send_count, 0, nranks * sizeof(int));
  memset(data_send_count, 0, nranks * sizeof(int));

#pragma omp parallel
  {
    // 3rd pass: Compute per rank nblks and ndata.
    // Every thread counts into its own private arrays first.
    int nblks_mythread[nranks], ndata_mythread[nranks];
    memset(nblks_mythread, 0, nranks * sizeof(int));
    memset(ndata_mythread, 0, nranks * sizeof(int));
#pragma omp for schedule(static)
    for (int iblock = 0; iblock < nblks_send; iblock++) {
      const plan_t *plan = &plans[iblock];
      nblks_mythread[plan->rank] += 1;
      ndata_mythread[plan->rank] += plan->row_size * plan->col_size;
    }

    // Sum nblks and ndata across threads.
    // Each thread also keeps the running totals: they mark the exclusive end
    // of the region (within a rank's section) that this thread fills below.
#pragma omp critical
    for (int irank = 0; irank < nranks; irank++) {
      blks_send_count[irank] += nblks_mythread[irank];
      data_send_count[irank] += ndata_mythread[irank];
      nblks_mythread[irank] = blks_send_count[irank];
      ndata_mythread[irank] = data_send_count[irank];
    }
#pragma omp barrier

    // Compute send displacements.
    // Done by one thread only; the asserts check that counts and
    // displacements add up to the totals passed in by the caller.
#pragma omp single
    {
      icumsum(nranks, blks_send_count, blks_send_displ);
      icumsum(nranks, data_send_count, data_send_displ);
      const int m = nranks - 1;
      assert(nblks_send == blks_send_displ[m] + blks_send_count[m]);
      assert(ndata_send == data_send_displ[m] + data_send_count[m]);
    }
#pragma omp barrier

    // 4th pass: Fill blks_send and data_send arrays.
#pragma omp for schedule(static) // Need static to match previous loop.
    for (int iblock = 0; iblock < nblks_send; iblock++) {
      const plan_t *const plan = &plans[iblock];
      const dbm_block_t *const blk = plan->blk;
      const int ishard = dbm_get_shard_index(matrix, blk->row, blk->col);
      const dbm_shard_t *const shard = &matrix->shards[ishard];
      const double *blk_data = &shard->data[blk->offset];
      const int row_size = plan->row_size, col_size = plan->col_size;
      const int plan_size = row_size * col_size;
      const int irank = plan->rank;

      // The data_send array is ordered by rank, thread, and block.
      //   data_send_displ[irank]: Start of data for irank within data_send.
      //   ndata_mythread[irank]: Current thread's offset within data for irank.
      // Both per-thread counters are decremented first, so the thread's region
      // is filled from its end towards its start.
      nblks_mythread[irank] -= 1;
      ndata_mythread[irank] -= plan_size;
      const int offset = data_send_displ[irank] + ndata_mythread[irank];
      const int jblock = blks_send_displ[irank] + nblks_mythread[irank];

      // Accumulates the sum of squared elements; no square root is taken.
      double norm = 0.0; // Compute norm as double...
      if (trans_matrix) {
        // Transpose block to allow for outer-product style multiplication.
        // Source element (j * row_size + i) lands at (i * col_size + j).
        for (int i = 0; i < row_size; i++) {
          for (int j = 0; j < col_size; j++) {
            const double element = blk_data[j * row_size + i];
            data_send[offset + i * col_size + j] = element;
            norm += element * element;
          }
        }
        blks_send[jblock].free_index = plan->blk->col;
        blks_send[jblock].sum_index = plan->blk->row;
      } else {
        // Straight copy of the block's plan_size elements.
        for (int i = 0; i < plan_size; i++) {
          const double element = blk_data[i];
          data_send[offset + i] = element;
          norm += element * element;
        }
        blks_send[jblock].free_index = plan->blk->row;
        blks_send[jblock].sum_index = plan->blk->col;
      }
      blks_send[jblock].norm = (float)norm; // ...store norm as float.

      // After the block exchange data_recv_displ will be added to the offsets.
      // Hence the offset is stored relative to the start of irank's data.
      blks_send[jblock].offset = offset - data_send_displ[irank];
    }
  } // end of omp parallel region
}
     301              : 
     302              : /*******************************************************************************
     303              :  * \brief Private comperator passed to qsort to compare two blocks by sum_index.
     304              :  * \author Ole Schuett
     305              :  ******************************************************************************/
     306     75118936 : static int compare_pack_blocks_by_sum_index(const void *a, const void *b) {
     307     75118936 :   const dbm_pack_block_t *blk_a = (const dbm_pack_block_t *)a;
     308     75118936 :   const dbm_pack_block_t *blk_b = (const dbm_pack_block_t *)b;
     309     75118936 :   return blk_a->sum_index - blk_b->sum_index;
     310              : }
     311              : 
/*******************************************************************************
 * \brief Private routine for post-processing received blocks.
 *
 *        Shifts the offset of every received block by the data displacement
 *        of the rank it came from, groups the blocks by shard
 *        (free_index % nshards) with a counting sort, and finally sorts the
 *        blocks of each shard by sum_index.
 *
 * \param nranks          Number of ranks, i.e. length of count/displ arrays.
 * \param nshards         Number of shards used for grouping.
 * \param nblocks_recv    Total number of received blocks.
 * \param blks_recv_count Number of blocks received from each rank.
 * \param blks_recv_displ Index of the first block received from each rank.
 * \param data_recv_displ Data displacement per rank, added to block offsets.
 * \param blks_recv       In/out: received blocks, reordered in place.
 * \author Ole Schuett
 ******************************************************************************/
static void postprocess_received_blocks(
    const int nranks, const int nshards, const int nblocks_recv,
    const int blks_recv_count[nranks], const int blks_recv_displ[nranks],
    const int data_recv_displ[nranks],
    dbm_pack_block_t blks_recv[nblocks_recv]) {
  int nblocks_per_shard[nshards], shard_start[nshards];
  memset(nblocks_per_shard, 0, nshards * sizeof(int));
  // Scratch copy of all blocks, used as the source of the counting sort.
  dbm_pack_block_t *blocks_tmp =
      malloc(nblocks_recv * sizeof(dbm_pack_block_t));
  assert(blocks_tmp != NULL || nblocks_recv == 0);

#pragma omp parallel
  {
    // Add data_recv_displ to received block offsets.
    for (int irank = 0; irank < nranks; irank++) {
#pragma omp for
      for (int i = 0; i < blks_recv_count[irank]; i++) {
        blks_recv[blks_recv_displ[irank] + i].offset += data_recv_displ[irank];
      }
    }

    // First use counting sort to group blocks by their free_index shard.
    // Every thread copies its blocks to blocks_tmp and counts them per shard
    // in a private array.
    int nblocks_mythread[nshards];
    memset(nblocks_mythread, 0, nshards * sizeof(int));
#pragma omp for schedule(static)
    for (int iblock = 0; iblock < nblocks_recv; iblock++) {
      blocks_tmp[iblock] = blks_recv[iblock];
      const int ishard = blks_recv[iblock].free_index % nshards;
      nblocks_mythread[ishard]++;
    }
    // Merge the counts; each thread keeps the running total, which marks the
    // exclusive end of its own slot range within the shard's section.
#pragma omp critical
    for (int ishard = 0; ishard < nshards; ishard++) {
      nblocks_per_shard[ishard] += nblocks_mythread[ishard];
      nblocks_mythread[ishard] = nblocks_per_shard[ishard];
    }
#pragma omp barrier
    // One thread turns the per-shard counts into start positions.
#pragma omp single
    icumsum(nshards, nblocks_per_shard, shard_start);
#pragma omp barrier
#pragma omp for schedule(static) // Need static to match previous loop.
    for (int iblock = 0; iblock < nblocks_recv; iblock++) {
      const int ishard = blocks_tmp[iblock].free_index % nshards;
      // Slots are handed out from the end of the thread's range downwards.
      const int jblock = --nblocks_mythread[ishard] + shard_start[ishard];
      blks_recv[jblock] = blocks_tmp[iblock];
    }

    // Then sort blocks within each shard by their sum_index.
    // Shards with fewer than two blocks are already sorted.
#pragma omp for
    for (int ishard = 0; ishard < nshards; ishard++) {
      if (nblocks_per_shard[ishard] > 1) {
        qsort(&blks_recv[shard_start[ishard]], nblocks_per_shard[ishard],
              sizeof(dbm_pack_block_t), &compare_pack_blocks_by_sum_index);
      }
    }
  } // end of omp parallel region

  free(blocks_tmp);
}
     374              : 
     375              : /*******************************************************************************
     376              :  * \brief Private routine for redistributing a matrix along selected dimensions.
     377              :  * \author Ole Schuett
     378              :  ******************************************************************************/
     379       484980 : static dbm_packed_matrix_t pack_matrix(const bool trans_matrix,
     380              :                                        const bool trans_dist,
     381              :                                        const dbm_matrix_t *restrict matrix,
     382              :                                        const dbm_distribution_t *restrict dist,
     383       484980 :                                        const int nticks) {
     384       484980 :   assert(cp_mpi_comms_are_similar(matrix->dist->comm, dist->comm));
     385              : 
     386              :   // The row/col indicies are distributed along one cart dimension and the
     387              :   // ticks are distributed along the other cart dimension.
     388       484980 :   const dbm_dist_1d_t *dist_indices = (trans_dist) ? &dist->cols : &dist->rows;
     389       484980 :   const dbm_dist_1d_t *dist_ticks = (trans_dist) ? &dist->rows : &dist->cols;
     390       484980 :   const int *free_index_sizes =
     391              :       (trans_matrix) ? matrix->col_sizes : matrix->row_sizes;
     392       484980 :   const int *sum_index_sizes =
     393              :       (trans_matrix) ? matrix->row_sizes : matrix->col_sizes;
     394              : 
     395              :   // Allocate packed matrix.
     396       484980 :   const int nsend_packs = nticks / dist_ticks->nranks;
     397       484980 :   assert(nsend_packs * dist_ticks->nranks == nticks);
     398       484980 :   dbm_packed_matrix_t packed;
     399       484980 :   packed.dist_indices = dist_indices;
     400       484980 :   packed.dist_ticks = dist_ticks;
     401       484980 :   packed.nsend_packs = nsend_packs;
     402       484980 :   packed.send_packs = malloc(nsend_packs * sizeof(dbm_pack_t));
     403       484980 :   assert(packed.send_packs != NULL || nsend_packs == 0);
     404              : 
     405              :   // Plan all packs.
     406       484980 :   plan_t *plans_per_pack[nsend_packs];
     407       484980 :   int nblks_send_per_pack[nsend_packs], ndata_send_per_pack[nsend_packs];
     408       484980 :   create_pack_plans(trans_matrix, trans_dist, matrix, dist->comm, dist_indices,
     409              :                     dist_ticks, nticks, nsend_packs, plans_per_pack,
     410              :                     nblks_send_per_pack, ndata_send_per_pack);
     411              : 
     412              :   // Allocate send buffers for maximum number of blocks/data over all packs.
     413       484980 :   int nblks_send_max = 0, ndata_send_max = 0;
     414       991366 :   for (int ipack = 0; ipack < nsend_packs; ++ipack) {
     415       506386 :     nblks_send_max = imax(nblks_send_max, nblks_send_per_pack[ipack]);
     416       506386 :     ndata_send_max = imax(ndata_send_max, ndata_send_per_pack[ipack]);
     417              :   }
     418       484980 :   dbm_pack_block_t *blks_send =
     419       484980 :       cp_mpi_alloc_mem(nblks_send_max * sizeof(dbm_pack_block_t));
     420       484980 :   double *data_send = cp_mpi_alloc_mem(ndata_send_max * sizeof(double));
     421              : 
     422              :   // Cannot parallelize over packs (there might be too few of them).
     423       991366 :   for (int ipack = 0; ipack < nsend_packs; ipack++) {
     424              :     // Fill send buffers according to plans.
     425       506386 :     const int nranks = dist->nranks;
     426       506386 :     int blks_send_count[nranks], data_send_count[nranks];
     427       506386 :     int blks_send_displ[nranks], data_send_displ[nranks];
     428       506386 :     fill_send_buffers(matrix, trans_matrix, nblks_send_per_pack[ipack],
     429       506386 :                       ndata_send_per_pack[ipack], plans_per_pack[ipack], nranks,
     430              :                       blks_send_count, data_send_count, blks_send_displ,
     431              :                       data_send_displ, blks_send, data_send);
     432       506386 :     free(plans_per_pack[ipack]);
     433              : 
     434              :     // 1st communication: Exchange block counts.
     435       506386 :     int blks_recv_count[nranks], blks_recv_displ[nranks];
     436       506386 :     cp_mpi_alltoall_int(blks_send_count, 1, blks_recv_count, 1, dist->comm);
     437      1583376 :     icumsum(nranks, blks_recv_count, blks_recv_displ);
     438       506386 :     const int nblocks_recv = isum(nranks, blks_recv_count);
     439              : 
     440              :     // 2nd communication: Exchange blocks.
     441       506386 :     dbm_pack_block_t *blks_recv =
     442       506386 :         cp_mpi_alloc_mem(nblocks_recv * sizeof(dbm_pack_block_t));
     443       506386 :     int blks_send_count_byte[nranks], blks_send_displ_byte[nranks];
     444       506386 :     int blks_recv_count_byte[nranks], blks_recv_displ_byte[nranks];
     445      1076990 :     for (int i = 0; i < nranks; i++) { // TODO: this is ugly!
     446      1141208 :       blks_send_count_byte[i] =
     447       570604 :           checked_byte_count(blks_send_count[i], sizeof(dbm_pack_block_t));
     448      1141208 :       blks_send_displ_byte[i] =
     449       570604 :           checked_byte_count(blks_send_displ[i], sizeof(dbm_pack_block_t));
     450      1141208 :       blks_recv_count_byte[i] =
     451       570604 :           checked_byte_count(blks_recv_count[i], sizeof(dbm_pack_block_t));
     452       570604 :       blks_recv_displ_byte[i] =
     453       570604 :           checked_byte_count(blks_recv_displ[i], sizeof(dbm_pack_block_t));
     454              :     }
     455       506386 :     cp_mpi_alltoallv_byte(blks_send, blks_send_count_byte, blks_send_displ_byte,
     456              :                           blks_recv, blks_recv_count_byte, blks_recv_displ_byte,
     457              :                           dist->comm);
     458              : 
     459              :     // 3rd compute data counts from the received block metadata.
     460       506386 :     int data_recv_count[nranks], data_recv_displ[nranks];
     461       506386 :     compute_data_recv_count(nranks, blks_recv_count, blks_recv_displ,
     462              :                             free_index_sizes, sum_index_sizes, blks_recv,
     463              :                             data_recv_count);
     464      1583376 :     icumsum(nranks, data_recv_count, data_recv_displ);
     465       506386 :     const int ndata_recv = isum(nranks, data_recv_count);
     466              : 
     467              :     // 4th communication: Exchange data.
     468              : #if defined(DBM_MULTIPLY_COMM_MEMPOOL)
     469       506386 :     double *data_recv =
     470       506386 :         offload_mempool_host_malloc(ndata_recv * sizeof(double));
     471              : #else
     472              :     double *data_recv = cp_mpi_alloc_mem(ndata_recv * sizeof(double));
     473              : #endif
     474       506386 :     cp_mpi_alltoallv_double(data_send, data_send_count, data_send_displ,
     475              :                             data_recv, data_recv_count, data_recv_displ,
     476              :                             dist->comm);
     477              : 
     478              :     // 5th post-process received blocks and assemble them into a pack.
     479       506386 :     postprocess_received_blocks(nranks, dist_indices->nshards, nblocks_recv,
     480              :                                 blks_recv_count, blks_recv_displ,
     481              :                                 data_recv_displ, blks_recv);
     482       506386 :     packed.send_packs[ipack].nblocks = nblocks_recv;
     483       506386 :     packed.send_packs[ipack].data_size = ndata_recv;
     484       506386 :     packed.send_packs[ipack].blocks = blks_recv;
     485       506386 :     packed.send_packs[ipack].data = data_recv;
     486              :   }
     487              : 
     488              :   // Deallocate send buffers.
     489       484980 :   cp_mpi_free_mem(blks_send);
     490       484980 :   cp_mpi_free_mem(data_send);
     491              : 
     492              :   // Allocate pack_recv.
     493       484980 :   int max_nblocks = 0, max_data_size = 0;
     494       991366 :   for (int ipack = 0; ipack < packed.nsend_packs; ipack++) {
     495       506386 :     max_nblocks = imax(max_nblocks, packed.send_packs[ipack].nblocks);
     496       506386 :     max_data_size = imax(max_data_size, packed.send_packs[ipack].data_size);
     497              :   }
     498       484980 :   cp_mpi_max_int(&max_nblocks, 1, packed.dist_ticks->comm);
     499       484980 :   cp_mpi_max_int(&max_data_size, 1, packed.dist_ticks->comm);
     500       484980 :   packed.max_nblocks = max_nblocks;
     501       484980 :   packed.max_data_size = max_data_size;
     502       969960 :   packed.recv_pack.blocks =
     503       484980 :       cp_mpi_alloc_mem(packed.max_nblocks * sizeof(dbm_pack_block_t));
     504              : #if defined(DBM_MULTIPLY_COMM_MEMPOOL)
     505       969960 :   packed.recv_pack.data =
     506       484980 :       offload_mempool_host_malloc(packed.max_data_size * sizeof(double));
     507              : #else
     508              :   packed.recv_pack.data =
     509              :       cp_mpi_alloc_mem(packed.max_data_size * sizeof(double));
     510              : #endif
     511              : 
     512       484980 :   return packed; // Ownership of packed transfers to caller.
     513              : }
     514              : 
     515              : /*******************************************************************************
     516              :  * \brief Private routine for sending and receiving the pack for the given tick.
                      :  *        If the computed send rank is the calling rank, the local send pack
                      :  *        is returned without any MPI call. Otherwise the blocks (as bytes)
                      :  *        and the data (as doubles) are exchanged into packed->recv_pack,
                      :  *        which is overwritten by every such call.
                      :  * \param itick  Tick for which the pack is requested.
                      :  * \param nticks Number of ticks; used as modulus for the tick arithmetic.
                      :  * \param packed Packed matrix providing send_packs, recv_pack and dist_ticks.
                      :  * \return Pointer to a pack stored inside packed (one of send_packs or
                      :  *         recv_pack). Its buffers are released by free_packed_matrix, so the
                      :  *         caller must not free it.
     517              :  * \author Ole Schuett
     518              :  ******************************************************************************/
     519       527792 : static dbm_pack_t *sendrecv_pack(const int itick, const int nticks,
     520              :                                  dbm_packed_matrix_t *packed) {
     521       527792 :   const int nranks = packed->dist_ticks->nranks;
     522       527792 :   const int my_rank = packed->dist_ticks->my_rank;
     523              : 
     524              :   // Compute send rank and pack (pack index = send tick divided by nranks).
     525       527792 :   const int itick_of_rank0 = (itick + nticks - my_rank) % nticks;
     526       527792 :   const int send_rank = (my_rank + nticks - itick_of_rank0) % nranks;
     527       527792 :   const int send_itick = (itick_of_rank0 + send_rank) % nticks;
     528       527792 :   const int send_ipack = send_itick / nranks;
     529       527792 :   assert(send_itick % nranks == my_rank); // Send tick must map onto this rank.
     530              : 
     531              :   // Compute receive rank and pack directly from the requested tick.
     532       527792 :   const int recv_rank = itick % nranks;
     533       527792 :   const int recv_ipack = itick / nranks;
     534              : 
     535       527792 :   dbm_pack_t *send_pack = &packed->send_packs[send_ipack];
     536       527792 :   if (send_rank == my_rank) {
     537       506386 :     assert(send_rank == recv_rank && send_ipack == recv_ipack);
     538              :     return send_pack; // Local pack, no MPI needed.
     539              :   } else {
     540              :     // Exchange blocks as raw bytes; the pack indices serve as message tags.
     541        64218 :     const int nblocks_in_bytes = cp_mpi_sendrecv_byte(
     542        21406 :         /*sendbuf=*/send_pack->blocks,
     543              :         /*sendcount=*/
     544        21406 :         checked_byte_count(send_pack->nblocks, sizeof(dbm_pack_block_t)),
     545              :         /*dest=*/send_rank,
     546              :         /*sendtag=*/send_ipack,
     547        21406 :         /*recvbuf=*/packed->recv_pack.blocks,
     548              :         /*recvcount=*/
     549        21406 :         checked_byte_count(packed->max_nblocks, sizeof(dbm_pack_block_t)),
     550              :         /*source=*/recv_rank,
     551              :         /*recvtag=*/recv_ipack,
     552              :         /*comm=*/packed->dist_ticks->comm);
     553              : 
     554        21406 :     assert(nblocks_in_bytes % sizeof(dbm_pack_block_t) == 0); // Whole blocks only.
     555        21406 :     packed->recv_pack.nblocks = nblocks_in_bytes / sizeof(dbm_pack_block_t);
     556              : 
     557              :     // Exchange data; the returned count is stored as the received data size.
     558        42812 :     packed->recv_pack.data_size = cp_mpi_sendrecv_double(
     559        21406 :         /*sendbuf=*/send_pack->data,
     560        21406 :         /*sendcount=*/send_pack->data_size,
     561              :         /*dest=*/send_rank,
     562              :         /*sendtag=*/send_ipack,
     563              :         /*recvbuf=*/packed->recv_pack.data,
     564        21406 :         /*recvcount=*/packed->max_data_size,
     565              :         /*source=*/recv_rank,
     566              :         /*recvtag=*/recv_ipack,
     567        21406 :         /*comm=*/packed->dist_ticks->comm);
     568              : 
     569        21406 :     return &packed->recv_pack;
     570              :   }
     571              : }
     572              : 
     573              : /*******************************************************************************
     574              :  * \brief Private routine for releasing a packed matrix.
                      :  *        Frees the buffers of recv_pack, the buffers of every send pack and
                      :  *        the send_packs array. The dbm_packed_matrix_t struct itself is not
                      :  *        freed here. Block buffers go through cp_mpi_free_mem; data buffers
                      :  *        go through offload_mempool_host_free when
                      :  *        DBM_MULTIPLY_COMM_MEMPOOL is defined and cp_mpi_free_mem otherwise.
                      :  * \param packed Packed matrix whose buffers are released.
     575              :  * \author Ole Schuett
     576              :  ******************************************************************************/
     577       484980 : static void free_packed_matrix(dbm_packed_matrix_t *packed) {
     578       484980 :   cp_mpi_free_mem(packed->recv_pack.blocks);
     579              : #if defined(DBM_MULTIPLY_COMM_MEMPOOL)
     580       484980 :   offload_mempool_host_free(packed->recv_pack.data);
     581              : #else
     582              :   cp_mpi_free_mem(packed->recv_pack.data);
     583              : #endif
     584       991366 :   for (int ipack = 0; ipack < packed->nsend_packs; ipack++) {
     585       506386 :     cp_mpi_free_mem(packed->send_packs[ipack].blocks);
     586              : #if defined(DBM_MULTIPLY_COMM_MEMPOOL)
     587       506386 :     offload_mempool_host_free(packed->send_packs[ipack].data);
     588              : #else
     589              :     cp_mpi_free_mem(packed->send_packs[ipack].data);
     590              : #endif
     591              :   }
     592       484980 :   free(packed->send_packs); // The array of packs itself, after its members.
     593       484980 : }
     594              : 
     595              : /*******************************************************************************
     596              :  * \brief Internal routine for creating a communication iterator.
                      :  *        Allocates the iterator, takes the distribution from matrix_c, sets
                      :  *        the number of ticks and packs matrix_a and matrix_b via pack_matrix.
                      :  * \param transa   Forwarded unchanged as first argument when packing matrix_a.
                      :  * \param transb   Forwarded negated as first argument when packing matrix_b.
                      :  * \param matrix_a Matrix packed into iter->packed_a.
                      :  * \param matrix_b Matrix packed into iter->packed_b.
                      :  * \param matrix_c Only its distribution (matrix_c->dist) is used here.
                      :  * \return Newly malloc'ed iterator with itick set to zero; it is to be
                      :  *         released with dbm_comm_iterator_stop.
     597              :  * \author Ole Schuett
     598              :  ******************************************************************************/
     599       242490 : dbm_comm_iterator_t *dbm_comm_iterator_start(const bool transa,
     600              :                                              const bool transb,
     601              :                                              const dbm_matrix_t *matrix_a,
     602              :                                              const dbm_matrix_t *matrix_b,
     603              :                                              const dbm_matrix_t *matrix_c) {
     604       242490 :   dbm_comm_iterator_t *iter = malloc(sizeof(dbm_comm_iterator_t));
     605       242490 :   assert(iter != NULL);
     606       242490 :   iter->dist = matrix_c->dist;
     607              : 
     608              :   // During each communication tick we'll fetch a pack_a and pack_b.
     609              :   // Since the cart might be non-squared, the number of communication ticks is
     610              :   // chosen as the least common multiple of the cart's dimensions.
     611       242490 :   iter->nticks = lcm(iter->dist->rows.nranks, iter->dist->cols.nranks);
     612       242490 :   iter->itick = 0;
     613              : 
     614              :   // 1.arg=source dimension, 2.arg=target dimension, false=rows, true=columns.
     615       242490 :   iter->packed_a =
     616       242490 :       pack_matrix(transa, false, matrix_a, iter->dist, iter->nticks);
     617       242490 :   iter->packed_b =
     618       242490 :       pack_matrix(!transb, true, matrix_b, iter->dist, iter->nticks);
     619              : 
     620       242490 :   return iter;
     621              : }
     622              : 
     623              : /*******************************************************************************
     624              :  * \brief Internal routine for retrieving next pair of packs of given iterator.
                      :  * \param iter   Iterator created by dbm_comm_iterator_start; its itick is
                      :  *               incremented on every successful call.
                      :  * \param pack_a Output: pack obtained from iter->packed_a for the current tick.
                      :  * \param pack_b Output: pack obtained from iter->packed_b for the current tick.
                      :  * \return false once itick has reached nticks (outputs are left untouched),
                      :  *         true otherwise. The returned packs point into the iterator's
                      :  *         packed matrices and must not be freed by the caller.
     625              :  * \author Ole Schuett
     626              :  ******************************************************************************/
     627       506386 : bool dbm_comm_iterator_next(dbm_comm_iterator_t *iter, dbm_pack_t **pack_a,
     628              :                             dbm_pack_t **pack_b) {
     629       506386 :   if (iter->itick >= iter->nticks) {
     630              :     return false; // end of iterator reached
     631              :   }
     632              : 
     633              :   // Start each rank at a different tick to spread the load on the sources.
     634       263896 :   const int shift = iter->dist->rows.my_rank + iter->dist->cols.my_rank;
     635       263896 :   const int itick = (iter->itick + shift) % iter->nticks;
     636       263896 :   *pack_a = sendrecv_pack(itick, iter->nticks, &iter->packed_a);
     637       263896 :   *pack_b = sendrecv_pack(itick, iter->nticks, &iter->packed_b);
     638              : 
     639       263896 :   ++iter->itick;
     640       263896 :   return true;
     641              : }
     642              : 
     643              : /*******************************************************************************
     644              :  * \brief Internal routine for releasing the given communication iterator.
                      :  *        Releases the buffers of both packed matrices and then frees the
                      :  *        iterator itself, so iter must not be used afterwards.
                      :  * \param iter Iterator to release.
     645              :  * \author Ole Schuett
     646              :  ******************************************************************************/
     647       242490 : void dbm_comm_iterator_stop(dbm_comm_iterator_t *iter) {
     648       242490 :   free_packed_matrix(&iter->packed_a);
     649       242490 :   free_packed_matrix(&iter->packed_b);
     650       242490 :   free(iter);
     651       242490 : }
     652              : 
     653              : // EOF
        

Generated by: LCOV version 2.0-1