Line data Source code
1 : /*----------------------------------------------------------------------------*/
2 : /* CP2K: A general program to perform molecular dynamics simulations */
3 : /* Copyright 2000-2025 CP2K developers group <https://cp2k.org> */
4 : /* */
5 : /* SPDX-License-Identifier: BSD-3-Clause */
6 : /*----------------------------------------------------------------------------*/
7 :
8 : #include <assert.h>
9 : #include <omp.h>
10 : #include <stdio.h>
11 : #include <stdlib.h>
12 : #include <string.h>
13 :
14 : #include "dbm_mpi.h"
15 :
16 : #if defined(__parallel)
17 : /*******************************************************************************
18 : * \brief Check given MPI status and upon failure abort with a nice message.
19 : * \author Ole Schuett
20 : ******************************************************************************/
21 : #define CHECK(STATUS) \
22 : do { \
23 : if (MPI_SUCCESS != (STATUS)) { \
24 : fprintf(stderr, "MPI error #%i in %s:%i\n", STATUS, __FILE__, __LINE__); \
25 : MPI_Abort(MPI_COMM_WORLD, EXIT_FAILURE); \
26 : } \
27 : } while (0)
28 : #endif
29 :
30 : /*******************************************************************************
31 : * \brief Wrapper around MPI_Init.
32 : * \author Ole Schuett
33 : ******************************************************************************/
34 0 : void dbm_mpi_init(int *argc, char ***argv) {
35 : #if defined(__parallel)
36 0 : CHECK(MPI_Init(argc, argv));
37 : #else
38 : (void)argc; // mark used
39 : (void)argv;
40 : #endif
41 0 : }
42 :
43 : /*******************************************************************************
44 : * \brief Wrapper around MPI_Finalize.
45 : * \author Ole Schuett
46 : ******************************************************************************/
47 0 : void dbm_mpi_finalize() {
48 : #if defined(__parallel)
49 0 : CHECK(MPI_Finalize());
50 : #endif
51 0 : }
52 :
53 : /*******************************************************************************
54 : * \brief Returns MPI_COMM_WORLD.
55 : * \author Ole Schuett
56 : ******************************************************************************/
57 0 : dbm_mpi_comm_t dbm_mpi_get_comm_world() {
58 : #if defined(__parallel)
59 0 : return MPI_COMM_WORLD;
60 : #else
61 : return -1;
62 : #endif
63 : }
64 :
65 : /*******************************************************************************
66 : * \brief Wrapper around MPI_Comm_f2c.
67 : * \author Ole Schuett
68 : ******************************************************************************/
69 809316 : dbm_mpi_comm_t dbm_mpi_comm_f2c(const int fortran_comm) {
70 : #if defined(__parallel)
71 809316 : return MPI_Comm_f2c(fortran_comm);
72 : #else
73 : (void)fortran_comm; // mark used
74 : return -1;
75 : #endif
76 : }
77 :
78 : /*******************************************************************************
79 : * \brief Wrapper around MPI_Comm_c2f.
80 : * \author Ole Schuett
81 : ******************************************************************************/
82 0 : int dbm_mpi_comm_c2f(const dbm_mpi_comm_t comm) {
83 : #if defined(__parallel)
84 0 : return MPI_Comm_c2f(comm);
85 : #else
86 : (void)comm; // mark used
87 : return -1;
88 : #endif
89 : }
90 :
91 : /*******************************************************************************
92 : * \brief Wrapper around MPI_Comm_rank.
93 : * \author Ole Schuett
94 : ******************************************************************************/
95 2400150 : int dbm_mpi_comm_rank(const dbm_mpi_comm_t comm) {
96 : #if defined(__parallel)
97 2400150 : int rank;
98 2400150 : CHECK(MPI_Comm_rank(comm, &rank));
99 2400150 : return rank;
100 : #else
101 : (void)comm; // mark used
102 : return 0;
103 : #endif
104 : }
105 :
106 : /*******************************************************************************
107 : * \brief Wrapper around MPI_Comm_size.
108 : * \author Ole Schuett
109 : ******************************************************************************/
110 2400294 : int dbm_mpi_comm_size(const dbm_mpi_comm_t comm) {
111 : #if defined(__parallel)
112 2400294 : int nranks;
113 2400294 : CHECK(MPI_Comm_size(comm, &nranks));
114 2400294 : return nranks;
115 : #else
116 : (void)comm; // mark used
117 : return 1;
118 : #endif
119 : }
120 :
121 : /*******************************************************************************
122 : * \brief Wrapper around MPI_Dims_create.
123 : * \author Ole Schuett
124 : ******************************************************************************/
125 0 : void dbm_mpi_dims_create(const int nnodes, const int ndims, int dims[]) {
126 : #if defined(__parallel)
127 0 : CHECK(MPI_Dims_create(nnodes, ndims, dims));
128 : #else
129 : dims[0] = nnodes;
130 : for (int i = 1; i < ndims; i++) {
131 : dims[i] = 1;
132 : }
133 : #endif
134 0 : }
135 :
136 : /*******************************************************************************
137 : * \brief Wrapper around MPI_Cart_create.
138 : * \author Ole Schuett
139 : ******************************************************************************/
140 0 : dbm_mpi_comm_t dbm_mpi_cart_create(const dbm_mpi_comm_t comm_old,
141 : const int ndims, const int dims[],
142 : const int periods[], const int reorder) {
143 : #if defined(__parallel)
144 0 : dbm_mpi_comm_t comm_cart;
145 0 : CHECK(MPI_Cart_create(comm_old, ndims, dims, periods, reorder, &comm_cart));
146 0 : return comm_cart;
147 : #else
148 : (void)comm_old; // mark used
149 : (void)ndims;
150 : (void)dims;
151 : (void)periods;
152 : (void)reorder;
153 : return -1;
154 : #endif
155 : }
156 :
157 : /*******************************************************************************
158 : * \brief Wrapper around MPI_Cart_get.
159 : * \author Ole Schuett
160 : ******************************************************************************/
161 1600100 : void dbm_mpi_cart_get(const dbm_mpi_comm_t comm, int maxdims, int dims[],
162 : int periods[], int coords[]) {
163 : #if defined(__parallel)
164 1600100 : CHECK(MPI_Cart_get(comm, maxdims, dims, periods, coords));
165 : #else
166 : (void)comm; // mark used
167 : for (int i = 0; i < maxdims; i++) {
168 : dims[i] = 1;
169 : periods[i] = 1;
170 : coords[i] = 0;
171 : }
172 : #endif
173 1600100 : }
174 :
175 : /*******************************************************************************
176 : * \brief Wrapper around MPI_Cart_rank.
177 : * \author Ole Schuett
178 : ******************************************************************************/
179 109290995 : int dbm_mpi_cart_rank(const dbm_mpi_comm_t comm, const int coords[]) {
180 : #if defined(__parallel)
181 109290995 : int rank;
182 109290995 : CHECK(MPI_Cart_rank(comm, coords, &rank));
183 109290995 : return rank;
184 : #else
185 : (void)comm; // mark used
186 : (void)coords;
187 : return 0;
188 : #endif
189 : }
190 :
191 : /*******************************************************************************
192 : * \brief Wrapper around MPI_Cart_sub.
193 : * \author Ole Schuett
194 : ******************************************************************************/
195 1600100 : dbm_mpi_comm_t dbm_mpi_cart_sub(const dbm_mpi_comm_t comm,
196 : const int remain_dims[]) {
197 : #if defined(__parallel)
198 1600100 : dbm_mpi_comm_t newcomm;
199 1600100 : CHECK(MPI_Cart_sub(comm, remain_dims, &newcomm));
200 1600100 : return newcomm;
201 : #else
202 : (void)comm; // mark used
203 : (void)remain_dims;
204 : return -1;
205 : #endif
206 : }
207 :
208 : /*******************************************************************************
209 : * \brief Wrapper around MPI_Comm_free.
210 : * \author Ole Schuett
211 : ******************************************************************************/
212 1600100 : void dbm_mpi_comm_free(dbm_mpi_comm_t *comm) {
213 : #if defined(__parallel)
214 1600100 : CHECK(MPI_Comm_free(comm));
215 : #else
216 : (void)comm; // mark used
217 : #endif
218 1600100 : }
219 :
220 : /*******************************************************************************
221 : * \brief Wrapper around MPI_Comm_compare.
222 : * \author Ole Schuett
223 : ******************************************************************************/
224 394878 : bool dbm_mpi_comms_are_similar(const dbm_mpi_comm_t comm1,
225 : const dbm_mpi_comm_t comm2) {
226 : #if defined(__parallel)
227 394878 : int res;
228 394878 : CHECK(MPI_Comm_compare(comm1, comm2, &res));
229 394878 : return res == MPI_IDENT || res == MPI_CONGRUENT || res == MPI_SIMILAR;
230 : #else
231 : (void)comm1; // mark used
232 : (void)comm2;
233 : return true;
234 : #endif
235 : }
236 :
237 : /*******************************************************************************
238 : * \brief Wrapper around MPI_Allreduce for op MPI_MAX and datatype MPI_INT.
239 : * \author Ole Schuett
240 : ******************************************************************************/
241 789468 : void dbm_mpi_max_int(int *values, const int count, const dbm_mpi_comm_t comm) {
242 : #if defined(__parallel)
243 789468 : int value = 0;
244 789468 : void *recvbuf = (1 < count ? dbm_mpi_alloc_mem(count * sizeof(int)) : &value);
245 789468 : CHECK(MPI_Allreduce(values, recvbuf, count, MPI_INT, MPI_MAX, comm));
246 789468 : memcpy(values, recvbuf, count * sizeof(int));
247 789468 : if (1 < count) {
248 0 : dbm_mpi_free_mem(recvbuf);
249 : }
250 : #else
251 : (void)comm; // mark used
252 : (void)values;
253 : (void)count;
254 : #endif
255 789468 : }
256 :
257 : /*******************************************************************************
258 : * \brief Wrapper around MPI_Allreduce for op MPI_MAX and datatype MPI_UINT64_T.
259 : * \author Ole Schuett
260 : ******************************************************************************/
261 18910 : void dbm_mpi_max_uint64(uint64_t *values, const int count,
262 : const dbm_mpi_comm_t comm) {
263 : #if defined(__parallel)
264 18910 : uint64_t value = 0;
265 37820 : void *recvbuf =
266 18910 : (1 < count ? dbm_mpi_alloc_mem(count * sizeof(uint64_t)) : &value);
267 18910 : CHECK(MPI_Allreduce(values, recvbuf, count, MPI_UINT64_T, MPI_MAX, comm));
268 18910 : memcpy(values, recvbuf, count * sizeof(uint64_t));
269 18910 : if (1 < count) {
270 0 : dbm_mpi_free_mem(recvbuf);
271 : }
272 : #else
273 : (void)comm; // mark used
274 : (void)values;
275 : (void)count;
276 : #endif
277 18910 : }
278 :
279 : /*******************************************************************************
280 : * \brief Wrapper around MPI_Allreduce for op MPI_MAX and datatype MPI_DOUBLE.
281 : * \author Ole Schuett
282 : ******************************************************************************/
283 48 : void dbm_mpi_max_double(double *values, const int count,
284 : const dbm_mpi_comm_t comm) {
285 : #if defined(__parallel)
286 48 : double value = 0;
287 96 : void *recvbuf =
288 48 : (1 < count ? dbm_mpi_alloc_mem(count * sizeof(double)) : &value);
289 48 : CHECK(MPI_Allreduce(values, recvbuf, count, MPI_DOUBLE, MPI_MAX, comm));
290 48 : memcpy(values, recvbuf, count * sizeof(double));
291 48 : if (1 < count) {
292 0 : dbm_mpi_free_mem(recvbuf);
293 : }
294 : #else
295 : (void)comm; // mark used
296 : (void)values;
297 : (void)count;
298 : #endif
299 48 : }
300 :
301 : /*******************************************************************************
302 : * \brief Wrapper around MPI_Allreduce for op MPI_SUM and datatype MPI_INT.
303 : * \author Ole Schuett
304 : ******************************************************************************/
305 197367 : void dbm_mpi_sum_int(int *values, const int count, const dbm_mpi_comm_t comm) {
306 : #if defined(__parallel)
307 197367 : int value = 0;
308 197367 : void *recvbuf = (1 < count ? dbm_mpi_alloc_mem(count * sizeof(int)) : &value);
309 197367 : CHECK(MPI_Allreduce(values, recvbuf, count, MPI_INT, MPI_SUM, comm));
310 197367 : memcpy(values, recvbuf, count * sizeof(int));
311 197367 : if (1 < count) {
312 194780 : dbm_mpi_free_mem(recvbuf);
313 : }
314 : #else
315 : (void)comm; // mark used
316 : (void)values;
317 : (void)count;
318 : #endif
319 197367 : }
320 :
321 : /*******************************************************************************
322 : * \brief Wrapper around MPI_Allreduce for op MPI_SUM and datatype MPI_INT64_T.
323 : * \author Ole Schuett
324 : ******************************************************************************/
325 593024 : void dbm_mpi_sum_int64(int64_t *values, const int count,
326 : const dbm_mpi_comm_t comm) {
327 : #if defined(__parallel)
328 593024 : int64_t value = 0;
329 1186048 : void *recvbuf =
330 593024 : (1 < count ? dbm_mpi_alloc_mem(count * sizeof(int64_t)) : &value);
331 593024 : CHECK(MPI_Allreduce(values, recvbuf, count, MPI_INT64_T, MPI_SUM, comm));
332 593024 : memcpy(values, recvbuf, count * sizeof(int64_t));
333 593024 : if (1 < count) {
334 0 : dbm_mpi_free_mem(recvbuf);
335 : }
336 : #else
337 : (void)comm; // mark used
338 : (void)values;
339 : (void)count;
340 : #endif
341 593024 : }
342 :
343 : /*******************************************************************************
344 : * \brief Wrapper around MPI_Allreduce for op MPI_SUM and datatype MPI_DOUBLE.
345 : * \author Ole Schuett
346 : ******************************************************************************/
347 190 : void dbm_mpi_sum_double(double *values, const int count,
348 : const dbm_mpi_comm_t comm) {
349 : #if defined(__parallel)
350 190 : double value = 0;
351 380 : void *recvbuf =
352 190 : (1 < count ? dbm_mpi_alloc_mem(count * sizeof(double)) : &value);
353 190 : CHECK(MPI_Allreduce(values, recvbuf, count, MPI_DOUBLE, MPI_SUM, comm));
354 190 : memcpy(values, recvbuf, count * sizeof(double));
355 190 : if (1 < count) {
356 0 : dbm_mpi_free_mem(recvbuf);
357 : }
358 : #else
359 : (void)comm; // mark used
360 : (void)values;
361 : (void)count;
362 : #endif
363 190 : }
364 :
365 : /*******************************************************************************
366 : * \brief Wrapper around MPI_Sendrecv for datatype MPI_BYTE.
367 : * \author Ole Schuett
368 : ******************************************************************************/
369 17612 : int dbm_mpi_sendrecv_byte(const void *sendbuf, const int sendcount,
370 : const int dest, const int sendtag, void *recvbuf,
371 : const int recvcount, const int source,
372 : const int recvtag, const dbm_mpi_comm_t comm) {
373 : #if defined(__parallel)
374 17612 : MPI_Status status;
375 17612 : CHECK(MPI_Sendrecv(sendbuf, sendcount, MPI_BYTE, dest, sendtag, recvbuf,
376 : recvcount, MPI_BYTE, source, recvtag, comm, &status));
377 17612 : int count_received;
378 17612 : CHECK(MPI_Get_count(&status, MPI_BYTE, &count_received));
379 17612 : return count_received;
380 : #else
381 : (void)sendbuf; // mark used
382 : (void)sendcount;
383 : (void)dest;
384 : (void)sendtag;
385 : (void)recvbuf;
386 : (void)recvcount;
387 : (void)source;
388 : (void)recvtag;
389 : (void)comm;
390 : fprintf(stderr, "Error: dbm_mpi_sendrecv_byte not available without MPI\n");
391 : abort();
392 : #endif
393 : }
394 :
395 : /*******************************************************************************
396 : * \brief Wrapper around MPI_Sendrecv for datatype MPI_DOUBLE.
397 : * \author Ole Schuett
398 : ******************************************************************************/
399 17612 : int dbm_mpi_sendrecv_double(const double *sendbuf, const int sendcount,
400 : const int dest, const int sendtag, double *recvbuf,
401 : const int recvcount, const int source,
402 : const int recvtag, const dbm_mpi_comm_t comm) {
403 : #if defined(__parallel)
404 17612 : MPI_Status status;
405 17612 : CHECK(MPI_Sendrecv(sendbuf, sendcount, MPI_DOUBLE, dest, sendtag, recvbuf,
406 : recvcount, MPI_DOUBLE, source, recvtag, comm, &status));
407 17612 : int count_received;
408 17612 : CHECK(MPI_Get_count(&status, MPI_DOUBLE, &count_received));
409 17612 : return count_received;
410 : #else
411 : (void)sendbuf; // mark used
412 : (void)sendcount;
413 : (void)dest;
414 : (void)sendtag;
415 : (void)recvbuf;
416 : (void)recvcount;
417 : (void)source;
418 : (void)recvtag;
419 : (void)comm;
420 : fprintf(stderr, "Error: dbm_mpi_sendrecv_double not available without MPI\n");
421 : abort();
422 : #endif
423 : }
424 :
425 : /*******************************************************************************
426 : * \brief Wrapper around MPI_Alltoall for datatype MPI_INT.
427 : * \author Ole Schuett
428 : ******************************************************************************/
429 824836 : void dbm_mpi_alltoall_int(const int *sendbuf, const int sendcount, int *recvbuf,
430 : const int recvcount, const dbm_mpi_comm_t comm) {
431 : #if defined(__parallel)
432 824836 : CHECK(MPI_Alltoall(sendbuf, sendcount, MPI_INT, recvbuf, recvcount, MPI_INT,
433 : comm));
434 : #else
435 : (void)comm; // mark used
436 : assert(sendcount == recvcount);
437 : memcpy(recvbuf, sendbuf, sendcount * sizeof(int));
438 : #endif
439 824836 : }
440 :
441 : /*******************************************************************************
442 : * \brief Wrapper around MPI_Alltoallv for datatype MPI_BYTE.
443 : * \author Ole Schuett
444 : ******************************************************************************/
445 412346 : void dbm_mpi_alltoallv_byte(const void *sendbuf, const int *sendcounts,
446 : const int *sdispls, void *recvbuf,
447 : const int *recvcounts, const int *rdispls,
448 : const dbm_mpi_comm_t comm) {
449 : #if defined(__parallel)
450 412346 : CHECK(MPI_Alltoallv(sendbuf, sendcounts, sdispls, MPI_BYTE, recvbuf,
451 : recvcounts, rdispls, MPI_BYTE, comm));
452 : #else
453 : (void)comm; // mark used
454 : assert(sendcounts[0] == recvcounts[0]);
455 : assert(sdispls[0] == 0 && rdispls[0] == 0);
456 : memcpy(recvbuf, sendbuf, sendcounts[0]);
457 : #endif
458 412346 : }
459 :
460 : /*******************************************************************************
461 : * \brief Wrapper around MPI_Alltoallv for datatype MPI_DOUBLE.
462 : * \author Ole Schuett
463 : ******************************************************************************/
464 412490 : void dbm_mpi_alltoallv_double(const double *sendbuf, const int *sendcounts,
465 : const int *sdispls, double *recvbuf,
466 : const int *recvcounts, const int *rdispls,
467 : const dbm_mpi_comm_t comm) {
468 : #if defined(__parallel)
469 412490 : CHECK(MPI_Alltoallv(sendbuf, sendcounts, sdispls, MPI_DOUBLE, recvbuf,
470 : recvcounts, rdispls, MPI_DOUBLE, comm));
471 : #else
472 : (void)comm; // mark used
473 : assert(sendcounts[0] == recvcounts[0]);
474 : assert(sdispls[0] == 0 && rdispls[0] == 0);
475 : memcpy(recvbuf, sendbuf, sendcounts[0] * sizeof(double));
476 : #endif
477 412490 : }
478 :
479 : /*******************************************************************************
480 : * \brief Wrapper around MPI_Alloc_mem.
481 : * \author Hans Pabst
482 : ******************************************************************************/
483 4556919 : void *dbm_mpi_alloc_mem(size_t size) {
484 4556919 : void *result = NULL;
485 : #if DBM_ALLOC_MPI && defined(__parallel)
486 : CHECK(MPI_Alloc_mem((MPI_Aint)size, MPI_INFO_NULL, &result));
487 : #elif DBM_ALLOC_OPENMP && (201811 /*v5.0*/ <= _OPENMP)
488 : result = omp_alloc(size, omp_null_allocator);
489 : #else
490 4556919 : result = malloc(size);
491 : #endif
492 4556919 : return result;
493 : }
494 :
495 : /*******************************************************************************
496 : * \brief Wrapper around MPI_Free_mem.
497 : * \author Hans Pabst
498 : ******************************************************************************/
499 4556919 : void dbm_mpi_free_mem(void *mem) {
500 : #if DBM_ALLOC_MPI && defined(__parallel)
501 : CHECK(MPI_Free_mem(mem));
502 : #elif DBM_ALLOC_OPENMP && (201811 /*v5.0*/ <= _OPENMP)
503 : omp_free(mem, omp_null_allocator);
504 : #else
505 4556919 : free(mem);
506 : #endif
507 4556919 : }
508 :
509 : // EOF
|