Line data Source code
1 : /*----------------------------------------------------------------------------*/
2 : /* CP2K: A general program to perform molecular dynamics simulations */
3 : /* Copyright 2000-2026 CP2K developers group <https://cp2k.org> */
4 : /* */
5 : /* SPDX-License-Identifier: BSD-3-Clause */
6 : /*----------------------------------------------------------------------------*/
7 : #include "dbm_distribution.h"
8 : #include "dbm_hyperparams.h"
9 : #include "dbm_internal.h"
10 :
11 : #include <assert.h>
12 : #include <math.h>
13 : #include <omp.h>
14 : #include <stdbool.h>
15 : #include <stddef.h>
16 : #include <stdlib.h>
17 : #include <string.h>
18 :
19 : /*******************************************************************************
20 : * \brief Private routine for creating a new one dimensional distribution.
21 : * \author Ole Schuett
22 : ******************************************************************************/
23 2316576 : static void dbm_dist_1d_new(dbm_dist_1d_t *dist, const int length,
24 : const int coords[length], const cp_mpi_comm_t comm,
25 : const int nshards) {
26 2316576 : dist->comm = comm;
27 2316576 : dist->nshards = nshards;
28 2316576 : dist->my_rank = cp_mpi_comm_rank(comm);
29 2316576 : dist->nranks = cp_mpi_comm_size(comm);
30 2316576 : dist->length = length;
31 2316576 : dist->index2coord = malloc(length * sizeof(int));
32 2316576 : assert(dist->index2coord != NULL || length == 0);
33 2316576 : if (length != 0) {
34 2313705 : memcpy(dist->index2coord, coords, length * sizeof(int));
35 : }
36 :
37 : // Check that cart coordinates and ranks are equivalent.
38 2316576 : int cart_dims[1], cart_periods[1], cart_coords[1];
39 2316576 : cp_mpi_cart_get(comm, 1, cart_dims, cart_periods, cart_coords);
40 2316576 : assert(dist->nranks == cart_dims[0]);
41 2316576 : assert(dist->my_rank == cart_coords[0]);
42 :
43 : // Count local rows/columns.
44 24429448 : for (int i = 0; i < length; i++) {
45 22112872 : assert(0 <= coords[i] && coords[i] < dist->nranks);
46 22112872 : if (coords[i] == dist->my_rank) {
47 21003923 : dist->nlocals++;
48 : }
49 : }
50 :
51 : // Store local rows/columns.
52 2316576 : dist->local_indicies = malloc(dist->nlocals * sizeof(int));
53 2316576 : assert(dist->local_indicies != NULL || dist->nlocals == 0);
54 : int j = 0;
55 24429448 : for (int i = 0; i < length; i++) {
56 22112872 : if (coords[i] == dist->my_rank) {
57 21003923 : dist->local_indicies[j++] = i;
58 : }
59 : }
60 2316576 : assert(j == dist->nlocals);
61 2316576 : }
62 :
63 : /*******************************************************************************
64 : * \brief Private routine for releasing a one dimensional distribution.
65 : * \author Ole Schuett
66 : ******************************************************************************/
67 2316576 : static void dbm_dist_1d_free(dbm_dist_1d_t *dist) {
68 2316576 : free(dist->index2coord);
69 2316576 : free(dist->local_indicies);
70 2316576 : cp_mpi_comm_free(&dist->comm);
71 2316576 : }
72 :
/*******************************************************************************
 * \brief Private routine for finding the optimal number of shard rows.
 *
 *        Among all factorizations nshards = nrow_shards * ncol_shards, picks
 *        the one whose ratio nrow_shards / ncol_shards is closest (in log
 *        space) to the matrix aspect ratio nrows / ncols. Empty dimensions
 *        are treated as having extent one.
 *
 * \param nshards  Total number of shards to split.
 * \param nrows    Number of block rows of the matrix.
 * \param ncols    Number of block columns of the matrix.
 * \return         The chosen number of shard rows (a divisor of nshards).
 * \author Ole Schuett
 ******************************************************************************/
static int find_best_nrow_shards(const int nshards, const int nrows,
                                 const int ncols) {
  const double aspect = imax(nrows, 1) / (double)imax(ncols, 1);

  // Baseline: the nshards x 1 layout. It wins ties against all other layouts.
  int winner = nshards;
  double lowest_error = fabs(log(aspect / (double)nshards));

  for (int candidate = 1; candidate <= nshards; candidate++) {
    if (nshards % candidate != 0) {
      continue; // Only exact factorizations of nshards are admissible.
    }
    const int partner = nshards / candidate;
    const double shape = (double)candidate / (double)partner;
    const double mismatch = fabs(log(aspect / shape));
    if (mismatch < lowest_error) {
      lowest_error = mismatch;
      winner = candidate;
    }
  }
  return winner;
}
97 :
98 : /*******************************************************************************
99 : * \brief Creates a new two dimensional distribution.
100 : * \author Ole Schuett
101 : ******************************************************************************/
102 1158288 : void dbm_distribution_new(dbm_distribution_t **dist_out, const int fortran_comm,
103 : const int nrows, const int ncols,
104 : const int row_dist[nrows],
105 : const int col_dist[ncols]) {
106 1158288 : assert(omp_get_num_threads() == 1);
107 1158288 : dbm_distribution_t *dist = calloc(1, sizeof(dbm_distribution_t));
108 1158288 : dist->ref_count = 1;
109 :
110 1158288 : dist->comm = cp_mpi_comm_f2c(fortran_comm);
111 1158288 : dist->my_rank = cp_mpi_comm_rank(dist->comm);
112 1158288 : dist->nranks = cp_mpi_comm_size(dist->comm);
113 :
114 1158288 : const int row_dim_remains[2] = {1, 0};
115 1158288 : const cp_mpi_comm_t row_comm = cp_mpi_cart_sub(dist->comm, row_dim_remains);
116 :
117 1158288 : const int col_dim_remains[2] = {0, 1};
118 1158288 : const cp_mpi_comm_t col_comm = cp_mpi_cart_sub(dist->comm, col_dim_remains);
119 :
120 1158288 : const int nshards = DBM_SHARDS_PER_THREAD * omp_get_max_threads();
121 1158288 : const int nrow_shards = find_best_nrow_shards(nshards, nrows, ncols);
122 1158288 : const int ncol_shards = nshards / nrow_shards;
123 :
124 1158288 : dbm_dist_1d_new(&dist->rows, nrows, row_dist, row_comm, nrow_shards);
125 1158288 : dbm_dist_1d_new(&dist->cols, ncols, col_dist, col_comm, ncol_shards);
126 :
127 1158288 : assert(*dist_out == NULL);
128 1158288 : *dist_out = dist;
129 1158288 : }
130 :
131 : /*******************************************************************************
132 : * \brief Increases the reference counter of the given distribution.
133 : * \author Ole Schuett
134 : ******************************************************************************/
135 3331071 : void dbm_distribution_hold(dbm_distribution_t *dist) {
136 3331071 : assert(dist->ref_count > 0);
137 3331071 : dist->ref_count++;
138 3331071 : }
139 :
140 : /*******************************************************************************
141 : * \brief Decreases the reference counter of the given distribution.
142 : * \author Ole Schuett
143 : ******************************************************************************/
144 4489359 : void dbm_distribution_release(dbm_distribution_t *dist) {
145 4489359 : assert(dist->ref_count > 0);
146 4489359 : dist->ref_count--;
147 4489359 : if (dist->ref_count == 0) {
148 1158288 : dbm_dist_1d_free(&dist->rows);
149 1158288 : dbm_dist_1d_free(&dist->cols);
150 1158288 : free(dist);
151 : }
152 4489359 : }
153 :
154 : /*******************************************************************************
155 : * \brief Returns the rows of the given distribution.
156 : * \author Ole Schuett
157 : ******************************************************************************/
158 437023 : void dbm_distribution_row_dist(const dbm_distribution_t *dist, int *nrows,
159 : const int **row_dist) {
160 437023 : assert(dist->ref_count > 0);
161 437023 : *nrows = dist->rows.length;
162 437023 : *row_dist = dist->rows.index2coord;
163 437023 : }
164 :
165 : /*******************************************************************************
166 : * \brief Returns the columns of the given distribution.
167 : * \author Ole Schuett
168 : ******************************************************************************/
169 437023 : void dbm_distribution_col_dist(const dbm_distribution_t *dist, int *ncols,
170 : const int **col_dist) {
171 437023 : assert(dist->ref_count > 0);
172 437023 : *ncols = dist->cols.length;
173 437023 : *col_dist = dist->cols.index2coord;
174 437023 : }
175 :
176 : /*******************************************************************************
177 : * \brief Returns the MPI rank on which the given block should be stored.
178 : * \author Ole Schuett
179 : ******************************************************************************/
180 95922731 : int dbm_distribution_stored_coords(const dbm_distribution_t *dist,
181 : const int row, const int col) {
182 95922731 : assert(dist->ref_count > 0);
183 95922731 : assert(0 <= row && row < dist->rows.length);
184 95922731 : assert(0 <= col && col < dist->cols.length);
185 95922731 : int coords[2] = {dist->rows.index2coord[row], dist->cols.index2coord[col]};
186 95922731 : return cp_mpi_cart_rank(dist->comm, coords);
187 : }
188 :
189 : // EOF
|