LCOV - code coverage report
Current view: top level - src/dbm - dbm_multiply_comm.c (source / functions) Coverage Total Hit
Test: CP2K Regtests (git:936074a) Lines: 100.0 % 167 167
Test Date: 2025-12-04 06:27:48 Functions: 100.0 % 13 13

            Line data    Source code
       1              : /*----------------------------------------------------------------------------*/
       2              : /*  CP2K: A general program to perform molecular dynamics simulations         */
       3              : /*  Copyright 2000-2025 CP2K developers group <https://cp2k.org>              */
       4              : /*                                                                            */
       5              : /*  SPDX-License-Identifier: BSD-3-Clause                                     */
       6              : /*----------------------------------------------------------------------------*/
       7              : #include "dbm_multiply_comm.h"
       8              : #include "../mpiwrap/cp_mpi.h"
       9              : #include "../offload/offload_mempool.h"
      10              : 
      11              : #include <assert.h>
      12              : #include <stdlib.h>
      13              : #include <string.h>
      14              : 
      15              : #if 1
      16              : #define DBM_MULTIPLY_COMM_MEMPOOL
      17              : #endif
      18              : 
      19              : /*******************************************************************************
      20              :  * \brief Private routine for computing greatest common divisor of two numbers.
      21              :  * \author Ole Schuett
      22              :  ******************************************************************************/
      23       445570 : static int gcd(const int a, const int b) {
      24       445570 :   if (a == 0) {
      25              :     return b;
      26              :   }
      27       232318 :   return gcd(b % a, a); // Euclid's algorithm.
      28              : }
      29              : 
      30              : /*******************************************************************************
      31              :  * \brief Private routine for computing least common multiple of two numbers.
      32              :  * \author Ole Schuett
      33              :  ******************************************************************************/
      34       213252 : static int lcm(const int a, const int b) { return (a * b) / gcd(a, b); }
      35              : 
      36              : /*******************************************************************************
      37              :  * \brief Private routine for computing the sum of the given integers.
      38              :  * \author Ole Schuett
      39              :  ******************************************************************************/
      40       891276 : static inline int isum(const int n, const int input[n]) {
      41       891276 :   int output = 0;
      42      1897356 :   for (int i = 0; i < n; i++) {
      43      1006080 :     output += input[i];
      44              :   }
      45       891276 :   return output;
      46              : }
      47              : 
      48              : /*******************************************************************************
      49              :  * \brief Private routine for computing the cumulative sums of given numbers.
      50              :  * \author Ole Schuett
      51              :  ******************************************************************************/
      52      2228190 : static inline void icumsum(const int n, const int input[n], int output[n]) {
      53      2228190 :   output[0] = 0;
      54      2457798 :   for (int i = 1; i < n; i++) {
      55       229608 :     output[i] = output[i - 1] + input[i - 1];
      56              :   }
      57      2228190 : }
      58              : 
      59              : /*******************************************************************************
      60              :  * \brief Private struct used for planing during pack_matrix.
      61              :  * \author Ole Schuett
      62              :  ******************************************************************************/
      63              : typedef struct {
      64              :   const dbm_block_t *blk; // source block
      65              :   int rank;               // target mpi rank
      66              :   int row_size;
      67              :   int col_size;
      68              : } plan_t;
      69              : 
      70              : /*******************************************************************************
      71              :  * \brief Private routine for planing packs.
      72              :  * \author Ole Schuett
      73              :  ******************************************************************************/
      74       426504 : static void create_pack_plans(const bool trans_matrix, const bool trans_dist,
      75              :                               const dbm_matrix_t *matrix,
      76              :                               const cp_mpi_comm_t comm,
      77              :                               const dbm_dist_1d_t *dist_indices,
      78              :                               const dbm_dist_1d_t *dist_ticks, const int nticks,
      79              :                               const int npacks, plan_t *plans_per_pack[npacks],
      80              :                               int nblks_per_pack[npacks],
      81              :                               int ndata_per_pack[npacks]) {
      82              : 
      83       426504 :   memset(nblks_per_pack, 0, npacks * sizeof(int));
      84       426504 :   memset(ndata_per_pack, 0, npacks * sizeof(int));
      85              : 
      86       426504 : #pragma omp parallel
      87              :   {
      88              :     // 1st pass: Compute number of blocks that will be send in each pack.
      89              :     int nblks_mythread[npacks];
      90              :     memset(nblks_mythread, 0, npacks * sizeof(int));
      91              : #pragma omp for schedule(static)
      92              :     for (int ishard = 0; ishard < dbm_get_num_shards(matrix); ishard++) {
      93              :       dbm_shard_t *shard = &matrix->shards[ishard];
      94              :       for (int iblock = 0; iblock < shard->nblocks; iblock++) {
      95              :         const dbm_block_t *blk = &shard->blocks[iblock];
      96              :         const int sum_index = (trans_matrix) ? blk->row : blk->col;
      97              :         const int itick = (1021 * sum_index) % nticks; // 1021 = a random prime
      98              :         const int ipack = itick / dist_ticks->nranks;
      99              :         nblks_mythread[ipack]++;
     100              :       }
     101              :     }
     102              : 
     103              :     // Sum nblocks across threads and allocate arrays for plans.
     104              : #pragma omp critical
     105              :     for (int ipack = 0; ipack < npacks; ipack++) {
     106              :       nblks_per_pack[ipack] += nblks_mythread[ipack];
     107              :       nblks_mythread[ipack] = nblks_per_pack[ipack];
     108              :     }
     109              : #pragma omp barrier
     110              : #pragma omp for
     111              :     for (int ipack = 0; ipack < npacks; ipack++) {
     112              :       const int nblks = nblks_per_pack[ipack];
     113              :       plans_per_pack[ipack] = malloc(nblks * sizeof(plan_t));
     114              :       assert(plans_per_pack[ipack] != NULL || nblks == 0);
     115              :     }
     116              : 
     117              :     // 2nd pass: Plan where to send each block.
     118              :     int ndata_mythread[npacks];
     119              :     memset(ndata_mythread, 0, npacks * sizeof(int));
     120              : #pragma omp for schedule(static) // Need static to match previous loop.
     121              :     for (int ishard = 0; ishard < dbm_get_num_shards(matrix); ishard++) {
     122              :       dbm_shard_t *shard = &matrix->shards[ishard];
     123              :       for (int iblock = 0; iblock < shard->nblocks; iblock++) {
     124              :         const dbm_block_t *blk = &shard->blocks[iblock];
     125              :         const int free_index = (trans_matrix) ? blk->col : blk->row;
     126              :         const int sum_index = (trans_matrix) ? blk->row : blk->col;
     127              :         const int itick = (1021 * sum_index) % nticks; // Same mapping as above.
     128              :         const int ipack = itick / dist_ticks->nranks;
     129              :         // Compute rank to which this block should be sent.
     130              :         const int coord_free_idx = dist_indices->index2coord[free_index];
     131              :         const int coord_sum_idx = itick % dist_ticks->nranks;
     132              :         const int coords[2] = {(trans_dist) ? coord_sum_idx : coord_free_idx,
     133              :                                (trans_dist) ? coord_free_idx : coord_sum_idx};
     134              :         const int rank = cp_mpi_cart_rank(comm, coords);
     135              :         const int row_size = matrix->row_sizes[blk->row];
     136              :         const int col_size = matrix->col_sizes[blk->col];
     137              :         ndata_mythread[ipack] += row_size * col_size;
     138              :         // Create plan.
     139              :         const int iplan = --nblks_mythread[ipack];
     140              :         plans_per_pack[ipack][iplan].blk = blk;
     141              :         plans_per_pack[ipack][iplan].rank = rank;
     142              :         plans_per_pack[ipack][iplan].row_size = row_size;
     143              :         plans_per_pack[ipack][iplan].col_size = col_size;
     144              :       }
     145              :     }
     146              : #pragma omp critical
     147              :     for (int ipack = 0; ipack < npacks; ipack++) {
     148              :       ndata_per_pack[ipack] += ndata_mythread[ipack];
     149              :     }
     150              :   } // end of omp parallel region
     151       426504 : }
     152              : 
     153              : /*******************************************************************************
     154              :  * \brief Private routine for filling send buffers.
     155              :  * \author Ole Schuett
     156              :  ******************************************************************************/
     157       445638 : static void fill_send_buffers(
     158              :     const dbm_matrix_t *matrix, const bool trans_matrix, const int nblks_send,
     159              :     const int ndata_send, plan_t plans[nblks_send], const int nranks,
     160              :     int blks_send_count[nranks], int data_send_count[nranks],
     161              :     int blks_send_displ[nranks], int data_send_displ[nranks],
     162              :     dbm_pack_block_t blks_send[nblks_send], double data_send[ndata_send]) {
     163              : 
     164       445638 :   memset(blks_send_count, 0, nranks * sizeof(int));
     165       445638 :   memset(data_send_count, 0, nranks * sizeof(int));
     166              : 
     167       445638 : #pragma omp parallel
     168              :   {
     169              :     // 3th pass: Compute per rank nblks and ndata.
     170              :     int nblks_mythread[nranks], ndata_mythread[nranks];
     171              :     memset(nblks_mythread, 0, nranks * sizeof(int));
     172              :     memset(ndata_mythread, 0, nranks * sizeof(int));
     173              : #pragma omp for schedule(static)
     174              :     for (int iblock = 0; iblock < nblks_send; iblock++) {
     175              :       const plan_t *plan = &plans[iblock];
     176              :       nblks_mythread[plan->rank] += 1;
     177              :       ndata_mythread[plan->rank] += plan->row_size * plan->col_size;
     178              :     }
     179              : 
     180              :     // Sum nblks and ndata across threads.
     181              : #pragma omp critical
     182              :     for (int irank = 0; irank < nranks; irank++) {
     183              :       blks_send_count[irank] += nblks_mythread[irank];
     184              :       data_send_count[irank] += ndata_mythread[irank];
     185              :       nblks_mythread[irank] = blks_send_count[irank];
     186              :       ndata_mythread[irank] = data_send_count[irank];
     187              :     }
     188              : #pragma omp barrier
     189              : 
     190              :     // Compute send displacements.
     191              : #pragma omp master
     192              :     {
     193              :       icumsum(nranks, blks_send_count, blks_send_displ);
     194              :       icumsum(nranks, data_send_count, data_send_displ);
     195              :       const int m = nranks - 1;
     196              :       assert(nblks_send == blks_send_displ[m] + blks_send_count[m]);
     197              :       assert(ndata_send == data_send_displ[m] + data_send_count[m]);
     198              :     }
     199              : #pragma omp barrier
     200              : 
     201              :     // 4th pass: Fill blks_send and data_send arrays.
     202              : #pragma omp for schedule(static) // Need static to match previous loop.
     203              :     for (int iblock = 0; iblock < nblks_send; iblock++) {
     204              :       const plan_t *plan = &plans[iblock];
     205              :       const dbm_block_t *blk = plan->blk;
     206              :       const int ishard = dbm_get_shard_index(matrix, blk->row, blk->col);
     207              :       const dbm_shard_t *shard = &matrix->shards[ishard];
     208              :       const double *blk_data = &shard->data[blk->offset];
     209              :       const int row_size = plan->row_size, col_size = plan->col_size;
     210              :       const int plan_size = row_size * col_size;
     211              :       const int irank = plan->rank;
     212              : 
     213              :       // The blk_send_data is ordered by rank, thread, and block.
     214              :       //   data_send_displ[irank]: Start of data for irank within blk_send_data.
     215              :       //   ndata_mythread[irank]: Current threads offset within data for irank.
     216              :       nblks_mythread[irank] -= 1;
     217              :       ndata_mythread[irank] -= plan_size;
     218              :       const int offset = data_send_displ[irank] + ndata_mythread[irank];
     219              :       const int jblock = blks_send_displ[irank] + nblks_mythread[irank];
     220              : 
     221              :       double norm = 0.0; // Compute norm as double...
     222              :       if (trans_matrix) {
     223              :         // Transpose block to allow for outer-product style multiplication.
     224              :         for (int i = 0; i < row_size; i++) {
     225              :           for (int j = 0; j < col_size; j++) {
     226              :             const double element = blk_data[j * row_size + i];
     227              :             data_send[offset + i * col_size + j] = element;
     228              :             norm += element * element;
     229              :           }
     230              :         }
     231              :         blks_send[jblock].free_index = plan->blk->col;
     232              :         blks_send[jblock].sum_index = plan->blk->row;
     233              :       } else {
     234              :         for (int i = 0; i < plan_size; i++) {
     235              :           const double element = blk_data[i];
     236              :           data_send[offset + i] = element;
     237              :           norm += element * element;
     238              :         }
     239              :         blks_send[jblock].free_index = plan->blk->row;
     240              :         blks_send[jblock].sum_index = plan->blk->col;
     241              :       }
     242              :       blks_send[jblock].norm = (float)norm; // ...store norm as float.
     243              : 
     244              :       // After the block exchange data_recv_displ will be added to the offsets.
     245              :       blks_send[jblock].offset = offset - data_send_displ[irank];
     246              :     }
     247              :   } // end of omp parallel region
     248       445638 : }
     249              : 
     250              : /*******************************************************************************
     251              :  * \brief Private comperator passed to qsort to compare two blocks by sum_index.
     252              :  * \author Ole Schuett
     253              :  ******************************************************************************/
     254     70747990 : static int compare_pack_blocks_by_sum_index(const void *a, const void *b) {
     255     70747990 :   const dbm_pack_block_t *blk_a = (const dbm_pack_block_t *)a;
     256     70747990 :   const dbm_pack_block_t *blk_b = (const dbm_pack_block_t *)b;
     257     70747990 :   return blk_a->sum_index - blk_b->sum_index;
     258              : }
     259              : 
     260              : /*******************************************************************************
     261              :  * \brief Private routine for post-processing received blocks.
     262              :  * \author Ole Schuett
     263              :  ******************************************************************************/
     264       445638 : static void postprocess_received_blocks(
     265              :     const int nranks, const int nshards, const int nblocks_recv,
     266              :     const int blks_recv_count[nranks], const int blks_recv_displ[nranks],
     267              :     const int data_recv_displ[nranks],
     268       445638 :     dbm_pack_block_t blks_recv[nblocks_recv]) {
     269              : 
     270       445638 :   int nblocks_per_shard[nshards], shard_start[nshards];
     271       445638 :   memset(nblocks_per_shard, 0, nshards * sizeof(int));
     272       445638 :   dbm_pack_block_t *blocks_tmp =
     273       445638 :       malloc(nblocks_recv * sizeof(dbm_pack_block_t));
     274       445638 :   assert(blocks_tmp != NULL || nblocks_recv == 0);
     275              : 
     276       445638 : #pragma omp parallel
     277              :   {
     278              :     // Add data_recv_displ to recveived block offsets.
     279              :     for (int irank = 0; irank < nranks; irank++) {
     280              : #pragma omp for
     281              :       for (int i = 0; i < blks_recv_count[irank]; i++) {
     282              :         blks_recv[blks_recv_displ[irank] + i].offset += data_recv_displ[irank];
     283              :       }
     284              :     }
     285              : 
     286              :     // First use counting sort to group blocks by their free_index shard.
     287              :     int nblocks_mythread[nshards];
     288              :     memset(nblocks_mythread, 0, nshards * sizeof(int));
     289              : #pragma omp for schedule(static)
     290              :     for (int iblock = 0; iblock < nblocks_recv; iblock++) {
     291              :       blocks_tmp[iblock] = blks_recv[iblock];
     292              :       const int ishard = blks_recv[iblock].free_index % nshards;
     293              :       nblocks_mythread[ishard]++;
     294              :     }
     295              : #pragma omp critical
     296              :     for (int ishard = 0; ishard < nshards; ishard++) {
     297              :       nblocks_per_shard[ishard] += nblocks_mythread[ishard];
     298              :       nblocks_mythread[ishard] = nblocks_per_shard[ishard];
     299              :     }
     300              : #pragma omp barrier
     301              : #pragma omp master
     302              :     icumsum(nshards, nblocks_per_shard, shard_start);
     303              : #pragma omp barrier
     304              : #pragma omp for schedule(static) // Need static to match previous loop.
     305              :     for (int iblock = 0; iblock < nblocks_recv; iblock++) {
     306              :       const int ishard = blocks_tmp[iblock].free_index % nshards;
     307              :       const int jblock = --nblocks_mythread[ishard] + shard_start[ishard];
     308              :       blks_recv[jblock] = blocks_tmp[iblock];
     309              :     }
     310              : 
     311              :     // Then sort blocks within each shard by their sum_index.
     312              : #pragma omp for
     313              :     for (int ishard = 0; ishard < nshards; ishard++) {
     314              :       if (nblocks_per_shard[ishard] > 1) {
     315              :         qsort(&blks_recv[shard_start[ishard]], nblocks_per_shard[ishard],
     316              :               sizeof(dbm_pack_block_t), &compare_pack_blocks_by_sum_index);
     317              :       }
     318              :     }
     319              :   } // end of omp parallel region
     320              : 
     321       445638 :   free(blocks_tmp);
     322       445638 : }
     323              : 
     324              : /*******************************************************************************
     325              :  * \brief Private routine for redistributing a matrix along selected dimensions.
     326              :  * \author Ole Schuett
     327              :  ******************************************************************************/
     328       426504 : static dbm_packed_matrix_t pack_matrix(const bool trans_matrix,
     329              :                                        const bool trans_dist,
     330              :                                        const dbm_matrix_t *matrix,
     331              :                                        const dbm_distribution_t *dist,
     332       426504 :                                        const int nticks) {
     333              : 
     334       426504 :   assert(cp_mpi_comms_are_similar(matrix->dist->comm, dist->comm));
     335              : 
     336              :   // The row/col indicies are distributed along one cart dimension and the
     337              :   // ticks are distributed along the other cart dimension.
     338       426504 :   const dbm_dist_1d_t *dist_indices = (trans_dist) ? &dist->cols : &dist->rows;
     339       426504 :   const dbm_dist_1d_t *dist_ticks = (trans_dist) ? &dist->rows : &dist->cols;
     340              : 
     341              :   // Allocate packed matrix.
     342       426504 :   const int nsend_packs = nticks / dist_ticks->nranks;
     343       426504 :   assert(nsend_packs * dist_ticks->nranks == nticks);
     344       426504 :   dbm_packed_matrix_t packed;
     345       426504 :   packed.dist_indices = dist_indices;
     346       426504 :   packed.dist_ticks = dist_ticks;
     347       426504 :   packed.nsend_packs = nsend_packs;
     348       426504 :   packed.send_packs = malloc(nsend_packs * sizeof(dbm_pack_t));
     349       426504 :   assert(packed.send_packs != NULL || nsend_packs == 0);
     350              : 
     351              :   // Plan all packs.
     352       426504 :   plan_t *plans_per_pack[nsend_packs];
     353       426504 :   int nblks_send_per_pack[nsend_packs], ndata_send_per_pack[nsend_packs];
     354       426504 :   create_pack_plans(trans_matrix, trans_dist, matrix, dist->comm, dist_indices,
     355              :                     dist_ticks, nticks, nsend_packs, plans_per_pack,
     356              :                     nblks_send_per_pack, ndata_send_per_pack);
     357              : 
     358              :   // Allocate send buffers for maximum number of blocks/data over all packs.
     359       426504 :   int nblks_send_max = 0, ndata_send_max = 0;
     360       872142 :   for (int ipack = 0; ipack < nsend_packs; ++ipack) {
     361       445638 :     nblks_send_max = imax(nblks_send_max, nblks_send_per_pack[ipack]);
     362       445638 :     ndata_send_max = imax(ndata_send_max, ndata_send_per_pack[ipack]);
     363              :   }
     364       426504 :   dbm_pack_block_t *blks_send =
     365       426504 :       cp_mpi_alloc_mem(nblks_send_max * sizeof(dbm_pack_block_t));
     366       426504 :   double *data_send = cp_mpi_alloc_mem(ndata_send_max * sizeof(double));
     367              : 
     368              :   // Cannot parallelize over packs (there might be too few of them).
     369       872142 :   for (int ipack = 0; ipack < nsend_packs; ipack++) {
     370              :     // Fill send buffers according to plans.
     371       445638 :     const int nranks = dist->nranks;
     372       445638 :     int blks_send_count[nranks], data_send_count[nranks];
     373       445638 :     int blks_send_displ[nranks], data_send_displ[nranks];
     374       445638 :     fill_send_buffers(matrix, trans_matrix, nblks_send_per_pack[ipack],
     375              :                       ndata_send_per_pack[ipack], plans_per_pack[ipack], nranks,
     376              :                       blks_send_count, data_send_count, blks_send_displ,
     377              :                       data_send_displ, blks_send, data_send);
     378       445638 :     free(plans_per_pack[ipack]);
     379              : 
     380              :     // 1st communication: Exchange block counts.
     381       445638 :     int blks_recv_count[nranks], blks_recv_displ[nranks];
     382       445638 :     cp_mpi_alltoall_int(blks_send_count, 1, blks_recv_count, 1, dist->comm);
     383       445638 :     icumsum(nranks, blks_recv_count, blks_recv_displ);
     384       445638 :     const int nblocks_recv = isum(nranks, blks_recv_count);
     385              : 
     386              :     // 2nd communication: Exchange blocks.
     387       445638 :     dbm_pack_block_t *blks_recv =
     388       445638 :         cp_mpi_alloc_mem(nblocks_recv * sizeof(dbm_pack_block_t));
     389       445638 :     int blks_send_count_byte[nranks], blks_send_displ_byte[nranks];
     390       445638 :     int blks_recv_count_byte[nranks], blks_recv_displ_byte[nranks];
     391       948678 :     for (int i = 0; i < nranks; i++) { // TODO: this is ugly!
     392       503040 :       blks_send_count_byte[i] = blks_send_count[i] * sizeof(dbm_pack_block_t);
     393       503040 :       blks_send_displ_byte[i] = blks_send_displ[i] * sizeof(dbm_pack_block_t);
     394       503040 :       blks_recv_count_byte[i] = blks_recv_count[i] * sizeof(dbm_pack_block_t);
     395       503040 :       blks_recv_displ_byte[i] = blks_recv_displ[i] * sizeof(dbm_pack_block_t);
     396              :     }
     397       445638 :     cp_mpi_alltoallv_byte(blks_send, blks_send_count_byte, blks_send_displ_byte,
     398              :                           blks_recv, blks_recv_count_byte, blks_recv_displ_byte,
     399       445638 :                           dist->comm);
     400              : 
     401              :     // 3rd communication: Exchange data counts.
     402              :     // TODO: could be computed from blks_recv.
     403       445638 :     int data_recv_count[nranks], data_recv_displ[nranks];
     404       445638 :     cp_mpi_alltoall_int(data_send_count, 1, data_recv_count, 1, dist->comm);
     405       445638 :     icumsum(nranks, data_recv_count, data_recv_displ);
     406       445638 :     const int ndata_recv = isum(nranks, data_recv_count);
     407              : 
     408              :     // 4th communication: Exchange data.
     409              : #if defined(DBM_MULTIPLY_COMM_MEMPOOL)
     410       445638 :     double *data_recv =
     411       445638 :         offload_mempool_host_malloc(ndata_recv * sizeof(double));
     412              : #else
     413              :     double *data_recv = cp_mpi_alloc_mem(ndata_recv * sizeof(double));
     414              : #endif
     415       445638 :     cp_mpi_alltoallv_double(data_send, data_send_count, data_send_displ,
     416              :                             data_recv, data_recv_count, data_recv_displ,
     417       445638 :                             dist->comm);
     418              : 
     419              :     // Post-process received blocks and assemble them into a pack.
     420       445638 :     postprocess_received_blocks(nranks, dist_indices->nshards, nblocks_recv,
     421              :                                 blks_recv_count, blks_recv_displ,
     422              :                                 data_recv_displ, blks_recv);
     423       445638 :     packed.send_packs[ipack].nblocks = nblocks_recv;
     424       445638 :     packed.send_packs[ipack].data_size = ndata_recv;
     425       445638 :     packed.send_packs[ipack].blocks = blks_recv;
     426       445638 :     packed.send_packs[ipack].data = data_recv;
     427              :   }
     428              : 
     429              :   // Deallocate send buffers.
     430       426504 :   cp_mpi_free_mem(blks_send);
     431       426504 :   cp_mpi_free_mem(data_send);
     432              : 
     433              :   // Allocate pack_recv.
     434       426504 :   int max_nblocks = 0, max_data_size = 0;
     435       872142 :   for (int ipack = 0; ipack < packed.nsend_packs; ipack++) {
     436       445638 :     max_nblocks = imax(max_nblocks, packed.send_packs[ipack].nblocks);
     437       445638 :     max_data_size = imax(max_data_size, packed.send_packs[ipack].data_size);
     438              :   }
     439       426504 :   cp_mpi_max_int(&max_nblocks, 1, packed.dist_ticks->comm);
     440       426504 :   cp_mpi_max_int(&max_data_size, 1, packed.dist_ticks->comm);
     441       426504 :   packed.max_nblocks = max_nblocks;
     442       426504 :   packed.max_data_size = max_data_size;
     443       853008 :   packed.recv_pack.blocks =
     444       426504 :       cp_mpi_alloc_mem(packed.max_nblocks * sizeof(dbm_pack_block_t));
     445              : #if defined(DBM_MULTIPLY_COMM_MEMPOOL)
     446       853008 :   packed.recv_pack.data =
     447       426504 :       offload_mempool_host_malloc(packed.max_data_size * sizeof(double));
     448              : #else
     449              :   packed.recv_pack.data =
     450              :       cp_mpi_alloc_mem(packed.max_data_size * sizeof(double));
     451              : #endif
     452              : 
     453       426504 :   return packed; // Ownership of packed transfers to caller.
     454              : }
     455              : 
     456              : /*******************************************************************************
     457              :  * \brief Private routine for sending and receiving the pack for the given tick.
     458              :  * \author Ole Schuett
     459              :  ******************************************************************************/
     460       464772 : static dbm_pack_t *sendrecv_pack(const int itick, const int nticks,
     461              :                                  dbm_packed_matrix_t *packed) {
     462       464772 :   const int nranks = packed->dist_ticks->nranks;
     463       464772 :   const int my_rank = packed->dist_ticks->my_rank;
     464              : 
     465              :   // Compute send rank and pack.
     466       464772 :   const int itick_of_rank0 = (itick + nticks - my_rank) % nticks;
     467       464772 :   const int send_rank = (my_rank + nticks - itick_of_rank0) % nranks;
     468       464772 :   const int send_itick = (itick_of_rank0 + send_rank) % nticks;
     469       464772 :   const int send_ipack = send_itick / nranks;
     470       464772 :   assert(send_itick % nranks == my_rank);
     471              : 
     472              :   // Compute receive rank and pack.
     473       464772 :   const int recv_rank = itick % nranks;
     474       464772 :   const int recv_ipack = itick / nranks;
     475              : 
     476       464772 :   dbm_pack_t *send_pack = &packed->send_packs[send_ipack];
     477       464772 :   if (send_rank == my_rank) {
     478       445638 :     assert(send_rank == recv_rank && send_ipack == recv_ipack);
     479              :     return send_pack; // Local pack, no mpi needed.
     480              :   } else {
     481              :     // Exchange blocks.
     482        38268 :     const int nblocks_in_bytes = cp_mpi_sendrecv_byte(
     483        19134 :         /*sendbuf=*/send_pack->blocks,
     484        19134 :         /*sendcound=*/send_pack->nblocks * sizeof(dbm_pack_block_t),
     485              :         /*dest=*/send_rank,
     486              :         /*sendtag=*/send_ipack,
     487        19134 :         /*recvbuf=*/packed->recv_pack.blocks,
     488        19134 :         /*recvcount=*/packed->max_nblocks * sizeof(dbm_pack_block_t),
     489              :         /*source=*/recv_rank,
     490              :         /*recvtag=*/recv_ipack,
     491        19134 :         /*comm=*/packed->dist_ticks->comm);
     492              : 
     493        19134 :     assert(nblocks_in_bytes % sizeof(dbm_pack_block_t) == 0);
     494        19134 :     packed->recv_pack.nblocks = nblocks_in_bytes / sizeof(dbm_pack_block_t);
     495              : 
     496              :     // Exchange data.
     497        38268 :     packed->recv_pack.data_size = cp_mpi_sendrecv_double(
     498        19134 :         /*sendbuf=*/send_pack->data,
     499              :         /*sendcound=*/send_pack->data_size,
     500              :         /*dest=*/send_rank,
     501              :         /*sendtag=*/send_ipack,
     502              :         /*recvbuf=*/packed->recv_pack.data,
     503              :         /*recvcount=*/packed->max_data_size,
     504              :         /*source=*/recv_rank,
     505              :         /*recvtag=*/recv_ipack,
     506        19134 :         /*comm=*/packed->dist_ticks->comm);
     507              : 
     508        19134 :     return &packed->recv_pack;
     509              :   }
     510              : }
     511              : 
     512              : /*******************************************************************************
     513              :  * \brief Private routine for releasing a packed matrix.
     514              :  * \author Ole Schuett
     515              :  ******************************************************************************/
     516       426504 : static void free_packed_matrix(dbm_packed_matrix_t *packed) {
     517       426504 :   cp_mpi_free_mem(packed->recv_pack.blocks);
     518              : #if defined(DBM_MULTIPLY_COMM_MEMPOOL)
     519       426504 :   offload_mempool_host_free(packed->recv_pack.data);
     520              : #else
     521              :   cp_mpi_free_mem(packed->recv_pack.data);
     522              : #endif
     523       872142 :   for (int ipack = 0; ipack < packed->nsend_packs; ipack++) {
     524       445638 :     cp_mpi_free_mem(packed->send_packs[ipack].blocks);
     525              : #if defined(DBM_MULTIPLY_COMM_MEMPOOL)
     526       445638 :     offload_mempool_host_free(packed->send_packs[ipack].data);
     527              : #else
     528              :     cp_mpi_free_mem(packed->send_packs[ipack].data);
     529              : #endif
     530              :   }
     531       426504 :   free(packed->send_packs);
     532       426504 : }
     533              : 
     534              : /*******************************************************************************
     535              :  * \brief Internal routine for creating a communication iterator.
     536              :  * \author Ole Schuett
     537              :  ******************************************************************************/
     538       213252 : dbm_comm_iterator_t *dbm_comm_iterator_start(const bool transa,
     539              :                                              const bool transb,
     540              :                                              const dbm_matrix_t *matrix_a,
     541              :                                              const dbm_matrix_t *matrix_b,
     542              :                                              const dbm_matrix_t *matrix_c) {
     543              : 
     544       213252 :   dbm_comm_iterator_t *iter = malloc(sizeof(dbm_comm_iterator_t));
     545       213252 :   assert(iter != NULL);
     546       213252 :   iter->dist = matrix_c->dist;
     547              : 
     548              :   // During each communication tick we'll fetch a pack_a and pack_b.
     549              :   // Since the cart might be non-squared, the number of communication ticks is
     550              :   // chosen as the least common multiple of the cart's dimensions.
     551       213252 :   iter->nticks = lcm(iter->dist->rows.nranks, iter->dist->cols.nranks);
     552       213252 :   iter->itick = 0;
     553              : 
     554              :   // 1.arg=source dimension, 2.arg=target dimension, false=rows, true=columns.
     555       213252 :   iter->packed_a =
     556       213252 :       pack_matrix(transa, false, matrix_a, iter->dist, iter->nticks);
     557       213252 :   iter->packed_b =
     558       213252 :       pack_matrix(!transb, true, matrix_b, iter->dist, iter->nticks);
     559              : 
     560       213252 :   return iter;
     561              : }
     562              : 
     563              : /*******************************************************************************
     564              :  * \brief Internal routine for retriving next pair of packs from given iterator.
     565              :  * \author Ole Schuett
     566              :  ******************************************************************************/
     567       445638 : bool dbm_comm_iterator_next(dbm_comm_iterator_t *iter, dbm_pack_t **pack_a,
     568              :                             dbm_pack_t **pack_b) {
     569       445638 :   if (iter->itick >= iter->nticks) {
     570              :     return false; // end of iterator reached
     571              :   }
     572              : 
     573              :   // Start each rank at a different tick to spread the load on the sources.
     574       232386 :   const int shift = iter->dist->rows.my_rank + iter->dist->cols.my_rank;
     575       232386 :   const int shifted_itick = (iter->itick + shift) % iter->nticks;
     576       232386 :   *pack_a = sendrecv_pack(shifted_itick, iter->nticks, &iter->packed_a);
     577       232386 :   *pack_b = sendrecv_pack(shifted_itick, iter->nticks, &iter->packed_b);
     578              : 
     579       232386 :   iter->itick++;
     580       232386 :   return true;
     581              : }
     582              : 
     583              : /*******************************************************************************
     584              :  * \brief Internal routine for releasing the given communication iterator.
     585              :  * \author Ole Schuett
     586              :  ******************************************************************************/
     587       213252 : void dbm_comm_iterator_stop(dbm_comm_iterator_t *iter) {
     588       213252 :   free_packed_matrix(&iter->packed_a);
     589       213252 :   free_packed_matrix(&iter->packed_b);
     590       213252 :   free(iter);
     591       213252 : }
     592              : 
     593              : // EOF
        

Generated by: LCOV version 2.0-1