LCOV - code coverage report
Current view: top level - src/dbm - dbm_distribution.c (source / functions) Coverage Total Hit
Test: CP2K Regtests (git:936074a) Lines: 96.7 % 91 88
Test Date: 2025-12-04 06:27:48 Functions: 100.0 % 9 9

            Line data    Source code
       1              : /*----------------------------------------------------------------------------*/
       2              : /*  CP2K: A general program to perform molecular dynamics simulations         */
       3              : /*  Copyright 2000-2025 CP2K developers group <https://cp2k.org>              */
       4              : /*                                                                            */
       5              : /*  SPDX-License-Identifier: BSD-3-Clause                                     */
       6              : /*----------------------------------------------------------------------------*/
       7              : #include "dbm_distribution.h"
       8              : #include "dbm_hyperparams.h"
       9              : #include "dbm_internal.h"
      10              : 
      11              : #include <assert.h>
      12              : #include <math.h>
      13              : #include <omp.h>
      14              : #include <stdbool.h>
      15              : #include <stddef.h>
      16              : #include <stdlib.h>
      17              : #include <string.h>
      18              : 
      19              : /*******************************************************************************
      20              :  * \brief Private routine for creating a new one dimensional distribution.
      21              :  * \author Ole Schuett
      22              :  ******************************************************************************/
      23      1662768 : static void dbm_dist_1d_new(dbm_dist_1d_t *dist, const int length,
      24              :                             const int coords[length], const cp_mpi_comm_t comm,
      25              :                             const int nshards) {
      26      1662768 :   dist->comm = comm;
      27      1662768 :   dist->nshards = nshards;
      28      1662768 :   dist->my_rank = cp_mpi_comm_rank(comm);
      29      1662768 :   dist->nranks = cp_mpi_comm_size(comm);
      30      1662768 :   dist->length = length;
      31      1662768 :   dist->index2coord = malloc(length * sizeof(int));
      32      1662768 :   assert(dist->index2coord != NULL || length == 0);
      33      1662768 :   if (length != 0) {
      34      1659897 :     memcpy(dist->index2coord, coords, length * sizeof(int));
      35              :   }
      36              : 
      37              :   // Check that cart coordinates and ranks are equivalent.
      38      1662768 :   int cart_dims[1], cart_periods[1], cart_coords[1];
      39      1662768 :   cp_mpi_cart_get(comm, 1, cart_dims, cart_periods, cart_coords);
      40      1662768 :   assert(dist->nranks == cart_dims[0]);
      41      1662768 :   assert(dist->my_rank == cart_coords[0]);
      42              : 
      43              :   // Count local rows/columns.
      44     22359790 :   for (int i = 0; i < length; i++) {
      45     20697022 :     assert(0 <= coords[i] && coords[i] < dist->nranks);
      46     20697022 :     if (coords[i] == dist->my_rank) {
      47     19751315 :       dist->nlocals++;
      48              :     }
      49              :   }
      50              : 
      51              :   // Store local rows/columns.
      52      1662768 :   dist->local_indicies = malloc(dist->nlocals * sizeof(int));
      53      1662768 :   assert(dist->local_indicies != NULL || dist->nlocals == 0);
      54              :   int j = 0;
      55     22359790 :   for (int i = 0; i < length; i++) {
      56     20697022 :     if (coords[i] == dist->my_rank) {
      57     19751315 :       dist->local_indicies[j++] = i;
      58              :     }
      59              :   }
      60      1662768 :   assert(j == dist->nlocals);
      61      1662768 : }
      62              : 
      63              : /*******************************************************************************
      64              :  * \brief Private routine for releasing a one dimensional distribution.
      65              :  * \author Ole Schuett
      66              :  ******************************************************************************/
      67      1662768 : static void dbm_dist_1d_free(dbm_dist_1d_t *dist) {
      68      1662768 :   free(dist->index2coord);
      69      1662768 :   free(dist->local_indicies);
      70      1662768 :   cp_mpi_comm_free(&dist->comm);
      71      1662768 : }
      72              : 
      73              : /*******************************************************************************
      74              :  * \brief Private routine for finding the optimal number of shard rows.
      75              :  * \author Ole Schuett
      76              :  ******************************************************************************/
      77       831384 : static int find_best_nrow_shards(const int nshards, const int nrows,
      78              :                                  const int ncols) {
      79       831384 :   const double target = imax(nrows, 1) / (double)imax(ncols, 1);
      80       831384 :   int best_nrow_shards = nshards;
      81       831384 :   double best_error = fabs(log(target / (double)nshards));
      82              : 
      83      1662768 :   for (int nrow_shards = 1; nrow_shards <= nshards; nrow_shards++) {
      84       831384 :     const int ncol_shards = nshards / nrow_shards;
      85       831384 :     if (nrow_shards * ncol_shards != nshards) {
      86            0 :       continue; // Not a factor of nshards.
      87              :     }
      88       831384 :     const double ratio = (double)nrow_shards / (double)ncol_shards;
      89       831384 :     const double error = fabs(log(target / ratio));
      90       831384 :     if (error < best_error) {
      91            0 :       best_error = error;
      92            0 :       best_nrow_shards = nrow_shards;
      93              :     }
      94              :   }
      95       831384 :   return best_nrow_shards;
      96              : }
      97              : 
      98              : /*******************************************************************************
      99              :  * \brief Creates a new two dimensional distribution.
     100              :  * \author Ole Schuett
     101              :  ******************************************************************************/
     102       831384 : void dbm_distribution_new(dbm_distribution_t **dist_out, const int fortran_comm,
     103              :                           const int nrows, const int ncols,
     104              :                           const int row_dist[nrows],
     105              :                           const int col_dist[ncols]) {
     106       831384 :   assert(omp_get_num_threads() == 1);
     107       831384 :   dbm_distribution_t *dist = calloc(1, sizeof(dbm_distribution_t));
     108       831384 :   dist->ref_count = 1;
     109              : 
     110       831384 :   dist->comm = cp_mpi_comm_f2c(fortran_comm);
     111       831384 :   dist->my_rank = cp_mpi_comm_rank(dist->comm);
     112       831384 :   dist->nranks = cp_mpi_comm_size(dist->comm);
     113              : 
     114       831384 :   const int row_dim_remains[2] = {1, 0};
     115       831384 :   const cp_mpi_comm_t row_comm = cp_mpi_cart_sub(dist->comm, row_dim_remains);
     116              : 
     117       831384 :   const int col_dim_remains[2] = {0, 1};
     118       831384 :   const cp_mpi_comm_t col_comm = cp_mpi_cart_sub(dist->comm, col_dim_remains);
     119              : 
     120       831384 :   const int nshards = DBM_SHARDS_PER_THREAD * omp_get_max_threads();
     121       831384 :   const int nrow_shards = find_best_nrow_shards(nshards, nrows, ncols);
     122       831384 :   const int ncol_shards = nshards / nrow_shards;
     123              : 
     124       831384 :   dbm_dist_1d_new(&dist->rows, nrows, row_dist, row_comm, nrow_shards);
     125       831384 :   dbm_dist_1d_new(&dist->cols, ncols, col_dist, col_comm, ncol_shards);
     126              : 
     127       831384 :   assert(*dist_out == NULL);
     128       831384 :   *dist_out = dist;
     129       831384 : }
     130              : 
     131              : /*******************************************************************************
     132              :  * \brief Increases the reference counter of the given distribution.
     133              :  * \author Ole Schuett
     134              :  ******************************************************************************/
     135      2409786 : void dbm_distribution_hold(dbm_distribution_t *dist) {
     136      2409786 :   assert(dist->ref_count > 0);
     137      2409786 :   dist->ref_count++;
     138      2409786 : }
     139              : 
     140              : /*******************************************************************************
     141              :  * \brief Decreases the reference counter of the given distribution.
     142              :  * \author Ole Schuett
     143              :  ******************************************************************************/
     144      3241170 : void dbm_distribution_release(dbm_distribution_t *dist) {
     145      3241170 :   assert(dist->ref_count > 0);
     146      3241170 :   dist->ref_count--;
     147      3241170 :   if (dist->ref_count == 0) {
     148       831384 :     dbm_dist_1d_free(&dist->rows);
     149       831384 :     dbm_dist_1d_free(&dist->cols);
     150       831384 :     free(dist);
     151              :   }
     152      3241170 : }
     153              : 
     154              : /*******************************************************************************
     155              :  * \brief Returns the rows of the given distribution.
     156              :  * \author Ole Schuett
     157              :  ******************************************************************************/
     158       339146 : void dbm_distribution_row_dist(const dbm_distribution_t *dist, int *nrows,
     159              :                                const int **row_dist) {
     160       339146 :   assert(dist->ref_count > 0);
     161       339146 :   *nrows = dist->rows.length;
     162       339146 :   *row_dist = dist->rows.index2coord;
     163       339146 : }
     164              : 
     165              : /*******************************************************************************
     166              :  * \brief Returns the columns of the given distribution.
     167              :  * \author Ole Schuett
     168              :  ******************************************************************************/
     169       339146 : void dbm_distribution_col_dist(const dbm_distribution_t *dist, int *ncols,
     170              :                                const int **col_dist) {
     171       339146 :   assert(dist->ref_count > 0);
     172       339146 :   *ncols = dist->cols.length;
     173       339146 :   *col_dist = dist->cols.index2coord;
     174       339146 : }
     175              : 
     176              : /*******************************************************************************
     177              :  * \brief Returns the MPI rank on which the given block should be stored.
     178              :  * \author Ole Schuett
     179              :  ******************************************************************************/
     180     92208768 : int dbm_distribution_stored_coords(const dbm_distribution_t *dist,
     181              :                                    const int row, const int col) {
     182     92208768 :   assert(dist->ref_count > 0);
     183     92208768 :   assert(0 <= row && row < dist->rows.length);
     184     92208768 :   assert(0 <= col && col < dist->cols.length);
     185     92208768 :   int coords[2] = {dist->rows.index2coord[row], dist->cols.index2coord[col]};
     186     92208768 :   return cp_mpi_cart_rank(dist->comm, coords);
     187              : }
     188              : 
     189              : // EOF
        

Generated by: LCOV version 2.0-1