LCOV - code coverage report
Current view: top level - src/dbm - dbm_distribution.c (source / functions) Hit Total Coverage
Test: CP2K Regtests (git:e7e05ae) Lines: 87 90 96.7 %
Date: 2024-04-18 06:59:28 Functions: 9 9 100.0 %

          Line data    Source code
       1             : /*----------------------------------------------------------------------------*/
       2             : /*  CP2K: A general program to perform molecular dynamics simulations         */
       3             : /*  Copyright 2000-2024 CP2K developers group <https://cp2k.org>              */
       4             : /*                                                                            */
       5             : /*  SPDX-License-Identifier: BSD-3-Clause                                     */
       6             : /*----------------------------------------------------------------------------*/
       7             : 
       8             : #include <assert.h>
       9             : #include <math.h>
      10             : #include <omp.h>
      11             : #include <stdbool.h>
      12             : #include <stddef.h>
      13             : #include <stdlib.h>
      14             : #include <string.h>
      15             : 
      16             : #include "dbm_distribution.h"
      17             : #include "dbm_hyperparams.h"
      18             : 
      19             : /*******************************************************************************
      20             :  * \brief Private routine for creating a new one dimensional distribution.
      21             :  * \author Ole Schuett
      22             :  ******************************************************************************/
      23     1179144 : static void dbm_dist_1d_new(dbm_dist_1d_t *dist, const int length,
      24             :                             const int coords[length], const dbm_mpi_comm_t comm,
      25             :                             const int nshards) {
      26     1179144 :   dist->comm = comm;
      27     1179144 :   dist->nshards = nshards;
      28     1179144 :   dist->my_rank = dbm_mpi_comm_rank(comm);
      29     1179144 :   dist->nranks = dbm_mpi_comm_size(comm);
      30     1179144 :   dist->length = length;
      31     1179144 :   dist->index2coord = malloc(length * sizeof(int));
      32     1179144 :   memcpy(dist->index2coord, coords, length * sizeof(int));
      33             : 
      34             :   // Check that cart coordinates and ranks are equivalent.
      35     1179144 :   int cart_dims[1], cart_periods[1], cart_coords[1];
      36     1179144 :   dbm_mpi_cart_get(comm, 1, cart_dims, cart_periods, cart_coords);
      37     1179144 :   assert(dist->nranks == cart_dims[0]);
      38     1179144 :   assert(dist->my_rank == cart_coords[0]);
      39             : 
      40             :   // Count local rows/columns.
      41    19346560 :   for (int i = 0; i < length; i++) {
      42    18167416 :     assert(0 <= coords[i] && coords[i] < dist->nranks);
      43    18167416 :     if (coords[i] == dist->my_rank) {
      44    17315955 :       dist->nlocals++;
      45             :     }
      46             :   }
      47             : 
      48             :   // Store local rows/columns.
      49     1179144 :   dist->local_indicies = malloc(dist->nlocals * sizeof(int));
      50     1179144 :   int j = 0;
      51    19346560 :   for (int i = 0; i < length; i++) {
      52    18167416 :     if (coords[i] == dist->my_rank) {
      53    17315955 :       dist->local_indicies[j++] = i;
      54             :     }
      55             :   }
      56     1179144 :   assert(j == dist->nlocals);
      57     1179144 : }
      58             : 
      59             : /*******************************************************************************
      60             :  * \brief Private routine for releasing a one dimensional distribution.
      61             :  * \author Ole Schuett
      62             :  ******************************************************************************/
      63     1179144 : static void dbm_dist_1d_free(dbm_dist_1d_t *dist) {
      64     1179144 :   free(dist->index2coord);
      65     1179144 :   free(dist->local_indicies);
      66     1179144 :   dbm_mpi_comm_free(&dist->comm);
      67     1179144 : }
      68             : 
      69             : /*******************************************************************************
      70             :  * \brief Returns the larger of two given integer (missing from the C standard)
      71             :  * \author Ole Schuett
      72             :  ******************************************************************************/
      73      589572 : static inline int imax(int x, int y) { return (x > y ? x : y); }
      74             : 
      75             : /*******************************************************************************
      76             :  * \brief Private routine for finding the optimal number of shard rows.
      77             :  * \author Ole Schuett
      78             :  ******************************************************************************/
      79      589572 : static int find_best_nrow_shards(const int nshards, const int nrows,
      80             :                                  const int ncols) {
      81      589572 :   const double target = (double)imax(nrows, 1) / (double)imax(ncols, 1);
      82      589572 :   int best_nrow_shards = nshards;
      83      589572 :   double best_error = fabs(log(target / (double)nshards));
      84             : 
      85     1179144 :   for (int nrow_shards = 1; nrow_shards <= nshards; nrow_shards++) {
      86      589572 :     const int ncol_shards = nshards / nrow_shards;
      87      589572 :     if (nrow_shards * ncol_shards != nshards)
      88           0 :       continue; // Not a factor of nshards.
      89      589572 :     const double ratio = (double)nrow_shards / (double)ncol_shards;
      90      589572 :     const double error = fabs(log(target / ratio));
      91      589572 :     if (error < best_error) {
      92           0 :       best_error = error;
      93           0 :       best_nrow_shards = nrow_shards;
      94             :     }
      95             :   }
      96      589572 :   return best_nrow_shards;
      97             : }
      98             : 
      99             : /*******************************************************************************
     100             :  * \brief Creates a new two dimensional distribution.
     101             :  * \author Ole Schuett
     102             :  ******************************************************************************/
     103      589572 : void dbm_distribution_new(dbm_distribution_t **dist_out, const int fortran_comm,
     104             :                           const int nrows, const int ncols,
     105             :                           const int row_dist[nrows],
     106             :                           const int col_dist[ncols]) {
     107      589572 :   assert(omp_get_num_threads() == 1);
     108      589572 :   dbm_distribution_t *dist = calloc(1, sizeof(dbm_distribution_t));
     109      589572 :   dist->ref_count = 1;
     110             : 
     111      589572 :   dist->comm = dbm_mpi_comm_f2c(fortran_comm);
     112      589572 :   dist->my_rank = dbm_mpi_comm_rank(dist->comm);
     113      589572 :   dist->nranks = dbm_mpi_comm_size(dist->comm);
     114             : 
     115      589572 :   const int row_dim_remains[2] = {1, 0};
     116      589572 :   const dbm_mpi_comm_t row_comm = dbm_mpi_cart_sub(dist->comm, row_dim_remains);
     117             : 
     118      589572 :   const int col_dim_remains[2] = {0, 1};
     119      589572 :   const dbm_mpi_comm_t col_comm = dbm_mpi_cart_sub(dist->comm, col_dim_remains);
     120             : 
     121      589572 :   const int nshards = SHARDS_PER_THREAD * omp_get_max_threads();
     122      589572 :   const int nrow_shards = find_best_nrow_shards(nshards, nrows, ncols);
     123      589572 :   const int ncol_shards = nshards / nrow_shards;
     124             : 
     125      589572 :   dbm_dist_1d_new(&dist->rows, nrows, row_dist, row_comm, nrow_shards);
     126      589572 :   dbm_dist_1d_new(&dist->cols, ncols, col_dist, col_comm, ncol_shards);
     127             : 
     128      589572 :   assert(*dist_out == NULL);
     129      589572 :   *dist_out = dist;
     130      589572 : }
     131             : 
     132             : /*******************************************************************************
     133             :  * \brief Increases the reference counter of the given distribution.
     134             :  * \author Ole Schuett
     135             :  ******************************************************************************/
     136     1597771 : void dbm_distribution_hold(dbm_distribution_t *dist) {
     137     1597771 :   assert(dist->ref_count > 0);
     138     1597771 :   dist->ref_count++;
     139     1597771 : }
     140             : 
     141             : /*******************************************************************************
     142             :  * \brief Decreases the reference counter of the given distribution.
     143             :  * \author Ole Schuett
     144             :  ******************************************************************************/
     145     2187343 : void dbm_distribution_release(dbm_distribution_t *dist) {
     146     2187343 :   assert(dist->ref_count > 0);
     147     2187343 :   dist->ref_count--;
     148     2187343 :   if (dist->ref_count == 0) {
     149      589572 :     dbm_dist_1d_free(&dist->rows);
     150      589572 :     dbm_dist_1d_free(&dist->cols);
     151      589572 :     free(dist);
     152             :   }
     153     2187343 : }
     154             : 
     155             : /*******************************************************************************
     156             :  * \brief Returns the rows of the given distribution.
     157             :  * \author Ole Schuett
     158             :  ******************************************************************************/
     159      241284 : void dbm_distribution_row_dist(const dbm_distribution_t *dist, int *nrows,
     160             :                                const int **row_dist) {
     161      241284 :   assert(dist->ref_count > 0);
     162      241284 :   *nrows = dist->rows.length;
     163      241284 :   *row_dist = dist->rows.index2coord;
     164      241284 : }
     165             : 
     166             : /*******************************************************************************
     167             :  * \brief Returns the columns of the given distribution.
     168             :  * \author Ole Schuett
     169             :  ******************************************************************************/
     170      241284 : void dbm_distribution_col_dist(const dbm_distribution_t *dist, int *ncols,
     171             :                                const int **col_dist) {
     172      241284 :   assert(dist->ref_count > 0);
     173      241284 :   *ncols = dist->cols.length;
     174      241284 :   *col_dist = dist->cols.index2coord;
     175      241284 : }
     176             : 
     177             : /*******************************************************************************
     178             :  * \brief Returns the MPI rank on which the given block should be stored.
     179             :  * \author Ole Schuett
     180             :  ******************************************************************************/
     181    74793680 : int dbm_distribution_stored_coords(const dbm_distribution_t *dist,
     182             :                                    const int row, const int col) {
     183    74793680 :   assert(dist->ref_count > 0);
     184    74793680 :   assert(0 <= row && row < dist->rows.length);
     185    74793680 :   assert(0 <= col && col < dist->cols.length);
     186    74793680 :   int coords[2] = {dist->rows.index2coord[row], dist->cols.index2coord[col]};
     187    74793680 :   return dbm_mpi_cart_rank(dist->comm, coords);
     188             : }
     189             : 
     190             : // EOF

Generated by: LCOV version 1.15