Line data Source code
1 : /*----------------------------------------------------------------------------*/
2 : /* CP2K: A general program to perform molecular dynamics simulations */
3 : /* Copyright 2000-2025 CP2K developers group <https://cp2k.org> */
4 : /* */
5 : /* SPDX-License-Identifier: BSD-3-Clause */
6 : /*----------------------------------------------------------------------------*/
7 : #include "dbm_distribution.h"
8 : #include "dbm_hyperparams.h"
9 : #include "dbm_internal.h"
10 :
11 : #include <assert.h>
12 : #include <math.h>
13 : #include <omp.h>
14 : #include <stdbool.h>
15 : #include <stddef.h>
16 : #include <stdlib.h>
17 : #include <string.h>
18 :
19 : /*******************************************************************************
20 : * \brief Private routine for creating a new one dimensional distribution.
21 : * \author Ole Schuett
22 : ******************************************************************************/
23 1662768 : static void dbm_dist_1d_new(dbm_dist_1d_t *dist, const int length,
24 : const int coords[length], const cp_mpi_comm_t comm,
25 : const int nshards) {
26 1662768 : dist->comm = comm;
27 1662768 : dist->nshards = nshards;
28 1662768 : dist->my_rank = cp_mpi_comm_rank(comm);
29 1662768 : dist->nranks = cp_mpi_comm_size(comm);
30 1662768 : dist->length = length;
31 1662768 : dist->index2coord = malloc(length * sizeof(int));
32 1662768 : assert(dist->index2coord != NULL || length == 0);
33 1662768 : if (length != 0) {
34 1659897 : memcpy(dist->index2coord, coords, length * sizeof(int));
35 : }
36 :
37 : // Check that cart coordinates and ranks are equivalent.
38 1662768 : int cart_dims[1], cart_periods[1], cart_coords[1];
39 1662768 : cp_mpi_cart_get(comm, 1, cart_dims, cart_periods, cart_coords);
40 1662768 : assert(dist->nranks == cart_dims[0]);
41 1662768 : assert(dist->my_rank == cart_coords[0]);
42 :
43 : // Count local rows/columns.
44 22359790 : for (int i = 0; i < length; i++) {
45 20697022 : assert(0 <= coords[i] && coords[i] < dist->nranks);
46 20697022 : if (coords[i] == dist->my_rank) {
47 19751315 : dist->nlocals++;
48 : }
49 : }
50 :
51 : // Store local rows/columns.
52 1662768 : dist->local_indicies = malloc(dist->nlocals * sizeof(int));
53 1662768 : assert(dist->local_indicies != NULL || dist->nlocals == 0);
54 : int j = 0;
55 22359790 : for (int i = 0; i < length; i++) {
56 20697022 : if (coords[i] == dist->my_rank) {
57 19751315 : dist->local_indicies[j++] = i;
58 : }
59 : }
60 1662768 : assert(j == dist->nlocals);
61 1662768 : }
62 :
63 : /*******************************************************************************
64 : * \brief Private routine for releasing a one dimensional distribution.
65 : * \author Ole Schuett
66 : ******************************************************************************/
67 1662768 : static void dbm_dist_1d_free(dbm_dist_1d_t *dist) {
68 1662768 : free(dist->index2coord);
69 1662768 : free(dist->local_indicies);
70 1662768 : cp_mpi_comm_free(&dist->comm);
71 1662768 : }
72 :
/*******************************************************************************
 * \brief Private routine for finding the optimal number of shard rows.
 *
 *        Among all factorizations nshards = nrow_shards * ncol_shards, picks
 *        the one whose ratio nrow_shards / ncol_shards is closest, in log
 *        space, to nrows / ncols. Zero extents are treated as 1. The layout
 *        nrow_shards == nshards wins ties against other candidates. Among the
 *        remaining candidates, the smallest nrow_shards wins ties.
 * \author Ole Schuett
 ******************************************************************************/
static int find_best_nrow_shards(const int nshards, const int nrows,
                                 const int ncols) {
  // Desired aspect ratio of the shard grid.
  const double target = imax(nrows, 1) / (double)imax(ncols, 1);

  // Seed with the nshards x 1 layout; others must strictly improve on it.
  int best = nshards;
  double lowest_error = fabs(log(target / (double)nshards));

  for (int candidate = 1; candidate <= nshards; candidate++) {
    if (nshards % candidate != 0) {
      continue; // Only exact divisors yield a valid shard grid.
    }
    const int partner = nshards / candidate;
    const double ratio = (double)candidate / (double)partner;
    const double deviation = fabs(log(target / ratio));
    if (deviation < lowest_error) {
      lowest_error = deviation;
      best = candidate;
    }
  }
  return best;
}
97 :
98 : /*******************************************************************************
99 : * \brief Creates a new two dimensional distribution.
100 : * \author Ole Schuett
101 : ******************************************************************************/
102 831384 : void dbm_distribution_new(dbm_distribution_t **dist_out, const int fortran_comm,
103 : const int nrows, const int ncols,
104 : const int row_dist[nrows],
105 : const int col_dist[ncols]) {
106 831384 : assert(omp_get_num_threads() == 1);
107 831384 : dbm_distribution_t *dist = calloc(1, sizeof(dbm_distribution_t));
108 831384 : dist->ref_count = 1;
109 :
110 831384 : dist->comm = cp_mpi_comm_f2c(fortran_comm);
111 831384 : dist->my_rank = cp_mpi_comm_rank(dist->comm);
112 831384 : dist->nranks = cp_mpi_comm_size(dist->comm);
113 :
114 831384 : const int row_dim_remains[2] = {1, 0};
115 831384 : const cp_mpi_comm_t row_comm = cp_mpi_cart_sub(dist->comm, row_dim_remains);
116 :
117 831384 : const int col_dim_remains[2] = {0, 1};
118 831384 : const cp_mpi_comm_t col_comm = cp_mpi_cart_sub(dist->comm, col_dim_remains);
119 :
120 831384 : const int nshards = DBM_SHARDS_PER_THREAD * omp_get_max_threads();
121 831384 : const int nrow_shards = find_best_nrow_shards(nshards, nrows, ncols);
122 831384 : const int ncol_shards = nshards / nrow_shards;
123 :
124 831384 : dbm_dist_1d_new(&dist->rows, nrows, row_dist, row_comm, nrow_shards);
125 831384 : dbm_dist_1d_new(&dist->cols, ncols, col_dist, col_comm, ncol_shards);
126 :
127 831384 : assert(*dist_out == NULL);
128 831384 : *dist_out = dist;
129 831384 : }
130 :
131 : /*******************************************************************************
132 : * \brief Increases the reference counter of the given distribution.
133 : * \author Ole Schuett
134 : ******************************************************************************/
135 2409786 : void dbm_distribution_hold(dbm_distribution_t *dist) {
136 2409786 : assert(dist->ref_count > 0);
137 2409786 : dist->ref_count++;
138 2409786 : }
139 :
140 : /*******************************************************************************
141 : * \brief Decreases the reference counter of the given distribution.
142 : * \author Ole Schuett
143 : ******************************************************************************/
144 3241170 : void dbm_distribution_release(dbm_distribution_t *dist) {
145 3241170 : assert(dist->ref_count > 0);
146 3241170 : dist->ref_count--;
147 3241170 : if (dist->ref_count == 0) {
148 831384 : dbm_dist_1d_free(&dist->rows);
149 831384 : dbm_dist_1d_free(&dist->cols);
150 831384 : free(dist);
151 : }
152 3241170 : }
153 :
154 : /*******************************************************************************
155 : * \brief Returns the rows of the given distribution.
156 : * \author Ole Schuett
157 : ******************************************************************************/
158 339146 : void dbm_distribution_row_dist(const dbm_distribution_t *dist, int *nrows,
159 : const int **row_dist) {
160 339146 : assert(dist->ref_count > 0);
161 339146 : *nrows = dist->rows.length;
162 339146 : *row_dist = dist->rows.index2coord;
163 339146 : }
164 :
165 : /*******************************************************************************
166 : * \brief Returns the columns of the given distribution.
167 : * \author Ole Schuett
168 : ******************************************************************************/
169 339146 : void dbm_distribution_col_dist(const dbm_distribution_t *dist, int *ncols,
170 : const int **col_dist) {
171 339146 : assert(dist->ref_count > 0);
172 339146 : *ncols = dist->cols.length;
173 339146 : *col_dist = dist->cols.index2coord;
174 339146 : }
175 :
176 : /*******************************************************************************
177 : * \brief Returns the MPI rank on which the given block should be stored.
178 : * \author Ole Schuett
179 : ******************************************************************************/
180 92208768 : int dbm_distribution_stored_coords(const dbm_distribution_t *dist,
181 : const int row, const int col) {
182 92208768 : assert(dist->ref_count > 0);
183 92208768 : assert(0 <= row && row < dist->rows.length);
184 92208768 : assert(0 <= col && col < dist->cols.length);
185 92208768 : int coords[2] = {dist->rows.index2coord[row], dist->cols.index2coord[col]};
186 92208768 : return cp_mpi_cart_rank(dist->comm, coords);
187 : }
188 :
189 : // EOF
|