LCOV - code coverage report
Current view: top level - src/grid/dgemm - grid_dgemm_utils.c (source / functions) Coverage Total Hit
Test: CP2K Regtests (git:85b8a9b) Lines: 71.2 % 184 131
Test Date: 2026-06-14 06:48:14 Functions: 83.3 % 12 10

            Line data    Source code
       1              : /*----------------------------------------------------------------------------*/
       2              : /*  CP2K: A general program to perform molecular dynamics simulations         */
       3              : /*  Copyright 2000-2026 CP2K developers group <https://cp2k.org>              */
       4              : /*                                                                            */
       5              : /*  SPDX-License-Identifier: BSD-3-Clause                                     */
       6              : /*----------------------------------------------------------------------------*/
       7              : 
       8              : #include <assert.h>
       9              : #include <limits.h>
      10              : #include <math.h>
      11              : #include <stdbool.h>
      12              : #include <stdio.h>
      13              : #include <string.h>
      14              : 
      15              : #ifdef __MKL
      16              : #include <mkl.h>
      17              : #endif
      18              : 
      19              : #if defined(__LIBXSMM)
      20              : #include <libxsmm.h>
      21              : #endif
      22              : #if defined(__LIBXS)
      23              : #include <libxs/libxs_gemm.h>
      24              : #endif
      25              : 
      26              : #include "../common/grid_common.h"
      27              : #include "grid_dgemm_tensor_local.h"
      28              : #include "grid_dgemm_utils.h"
      29              : 
      30            0 : void convert_to_lattice_coordinates(const double dh_inv_[3][3],
      31              :                                     const double *restrict const rp,
      32              :                                     double *restrict rp_c) {
      33            0 :   rp_c[0] =
      34            0 :       dh_inv_[0][0] * rp[0] + dh_inv_[1][0] * rp[1] + dh_inv_[0][0] * rp[2];
      35            0 :   rp_c[1] =
      36            0 :       dh_inv_[0][1] * rp[0] + dh_inv_[1][1] * rp[1] + dh_inv_[1][1] * rp[2];
      37            0 :   rp_c[2] =
      38            0 :       dh_inv_[0][2] * rp[0] + dh_inv_[1][2] * rp[1] + dh_inv_[2][2] * rp[2];
      39            0 : }
      40              : 
      41       230646 : void dgemm_simplified(dgemm_params *const m) {
      42       230646 :   if (m == NULL)
      43            0 :     abort();
      44              : 
      45              : #if defined(__LIBXS)
      46              :   {
      47       230646 :     const char col_transa = (m->op2 == 'N') ? 'N' : 'T';
      48       230646 :     const char col_transb = (m->op1 == 'N') ? 'N' : 'T';
      49       461292 :     const libxs_gemm_config_t *cfg = libxs_gemm_dispatch(
      50              :         LIBXS_DATATYPE_F64, col_transa, col_transb, m->n, m->m, m->k, m->ldb,
      51       230646 :         m->lda, m->ldc, &m->alpha, &m->beta, NULL);
      52       230646 :     if (NULL != cfg) {
      53       230646 :       libxs_gemm_call(cfg, m->b, m->a, m->c);
      54       230646 :       return;
      55              :     }
      56              :   }
      57              : #endif
      58              : 
      59              : #if defined(__MKL)
      60              :   if ((m->op1 == 'N') && (m->op2 == 'N'))
      61              :     cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, m->m, m->n, m->k,
      62              :                 m->alpha, m->a, m->lda, m->b, m->ldb, m->beta, m->c, m->ldc);
      63              : 
      64              :   if ((m->op1 == 'T') && (m->op2 == 'N'))
      65              :     cblas_dgemm(CblasRowMajor, CblasTrans, CblasNoTrans, m->m, m->n, m->k,
      66              :                 m->alpha, m->a, m->lda, m->b, m->ldb, m->beta, m->c, m->ldc);
      67              : 
      68              :   if ((m->op1 == 'N') && (m->op2 == 'T'))
      69              :     cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasTrans, m->m, m->n, m->k,
      70              :                 m->alpha, m->a, m->lda, m->b, m->ldb, m->beta, m->c, m->ldc);
      71              : 
      72              :   if ((m->op1 == 'T') && (m->op2 == 'T'))
      73              :     cblas_dgemm(CblasRowMajor, CblasTrans, CblasTrans, m->m, m->n, m->k,
      74              :                 m->alpha, m->a, m->lda, m->b, m->ldb, m->beta, m->c, m->ldc);
      75              : 
      76              : #else
      77              : 
      78            0 :   if ((m->op1 == 'N') && (m->op2 == 'N'))
      79            0 :     dgemm_("N", "N", &m->n, &m->m, &m->k, &m->alpha, m->b, &m->ldb, m->a,
      80            0 :            &m->lda, &m->beta, m->c, &m->ldc);
      81              : 
      82            0 :   if ((m->op1 == 'T') && (m->op2 == 'N'))
      83            0 :     dgemm_("N", "T", &m->n, &m->m, &m->k, &m->alpha, m->b, &m->ldb, m->a,
      84            0 :            &m->lda, &m->beta, m->c, &m->ldc);
      85              : 
      86            0 :   if ((m->op1 == 'T') && (m->op2 == 'T'))
      87            0 :     dgemm_("T", "T", &m->n, &m->m, &m->k, &m->alpha, m->b, &m->ldb, m->a,
      88            0 :            &m->lda, &m->beta, m->c, &m->ldc);
      89              : 
      90            0 :   if ((m->op1 == 'N') && (m->op2 == 'T'))
      91            0 :     dgemm_("T", "N", &m->n, &m->m, &m->k, &m->alpha, m->b, &m->ldb, m->a,
      92            0 :            &m->lda, &m->beta, m->c, &m->ldc);
      93              : 
      94              : #endif
      95              : }
      96              : 
      97       380820 : void extract_sub_grid(const int *lower_corner, const int *upper_corner,
      98              :                       const int *position, const tensor *const grid,
      99              :                       tensor *const subgrid) {
     100       380820 :   int position1[3] = {0, 0, 0};
     101              : 
     102       380820 :   if (position) {
     103       380820 :     position1[0] = position[0];
     104       380820 :     position1[1] = position[1];
     105       380820 :     position1[2] = position[2];
     106              :   }
     107              : 
     108       380820 :   const int sizex = upper_corner[2] - lower_corner[2];
     109       380820 :   const int sizey = upper_corner[1] - lower_corner[1];
     110       380820 :   const int sizez = upper_corner[0] - lower_corner[0];
     111              : 
     112      4974324 :   for (int z = 0; z < sizez; z++) {
     113     61956638 :     for (int y = 0; y < sizey; y++) {
     114     57363134 :       double *restrict src =
     115     57363134 :           &idx3(grid[0], lower_corner[0] + z - grid->window_shift[0],
     116              :                 lower_corner[1] + y - grid->window_shift[1],
     117              :                 lower_corner[2] - grid->window_shift[2]);
     118     57363134 :       double *restrict dst =
     119     57363134 :           &idx3(subgrid[0], position1[0] + z, position1[1] + y, position1[2]);
     120     57363134 :       GRID_PRAGMA_SIMD((dst, src), 8)
     121     57363134 :       for (int x = 0; x < sizex; x++) {
     122    778255433 :         dst[x] = src[x];
     123              :       }
     124              :     }
     125              :   }
     126              : 
     127       380820 :   return;
     128              : }
     129              : 
     130       414166 : void add_sub_grid(const int *lower_corner, const int *upper_corner,
     131              :                   const int *position, const tensor *subgrid, tensor *grid) {
     132       414166 :   int position1[3] = {0, 0, 0};
     133              : 
     134       414166 :   if (position) {
     135       414166 :     position1[0] = position[0];
     136       414166 :     position1[1] = position[1];
     137       414166 :     position1[2] = position[2];
     138              :   }
     139              : 
     140       414166 :   const int sizex = upper_corner[2] - lower_corner[2];
     141       414166 :   const int sizey = upper_corner[1] - lower_corner[1];
     142       414166 :   const int sizez = upper_corner[0] - lower_corner[0];
     143              : 
     144      5315394 :   for (int z = 0; z < sizez; z++) {
     145      4901228 :     double *restrict dst =
     146      4901228 :         &idx3(grid[0], lower_corner[0] + z, lower_corner[1], lower_corner[2]);
     147      4901228 :     double *restrict src =
     148      4901228 :         &idx3(subgrid[0], position1[0] + z, position1[1], position1[2]);
     149     59167776 :     for (int y = 0; y < sizey - 1; y++) {
     150              :       GRID_PRAGMA_SIMD((dst, src), 8)
     151              :       for (int x = 0; x < sizex; x++) {
     152    731179166 :         dst[x] += src[x];
     153              :       }
     154              : 
     155     54266548 :       dst += grid->ld_;
     156     54266548 :       src += subgrid->ld_;
     157              :     }
     158              : 
     159              :     // #pragma omp simd linear(dst, src) simdlen(8)
     160      4901228 :     GRID_PRAGMA_SIMD((dst, src), 8)
     161              :     for (int x = 0; x < sizex; x++) {
     162     59381511 :       dst[x] += src[x];
     163              :     }
     164              :   }
     165       414166 :   return;
     166              : }
     167              : 
     168        90334 : int compute_cube_properties(const bool ortho, const double radius,
     169              :                             const double dh[3][3], const double dh_inv[3][3],
     170              :                             const double *rp, double *disr_radius,
     171              :                             double *roffset, int *cubecenter, int *lb_cube,
     172              :                             int *ub_cube, int *cube_size) {
     173        90334 :   int cmax = 0;
     174              : 
     175              :   /* center of the gaussian in the lattice coordinates */
     176        90334 :   double rp1[3];
     177              : 
     178              :   /* it is in the lattice vector frame */
     179       361336 :   for (int i = 0; i < 3; i++) {
     180              :     double dh_inv_rp = 0.0;
     181      1084008 :     for (int j = 0; j < 3; j++) {
     182       813006 :       dh_inv_rp += dh_inv[j][i] * rp[j];
     183              :     }
     184       271002 :     rp1[2 - i] = dh_inv_rp;
     185       271002 :     cubecenter[2 - i] = floor(dh_inv_rp);
     186              :   }
     187              : 
     188        90334 :   if (ortho) {
     189              :     /* seting up the cube parameters */
     190        42228 :     const double dx[3] = {dh[2][2], dh[1][1], dh[0][0]};
     191        42228 :     const double dx_inv[3] = {dh_inv[2][2], dh_inv[1][1], dh_inv[0][0]};
     192              :     /* cube center */
     193              : 
     194              :     /* lower and upper bounds */
     195              : 
     196              :     // Historically, the radius gets discretized.
     197        42228 :     const double drmin = fmin(dh[0][0], fmin(dh[1][1], dh[2][2]));
     198        42228 :     *disr_radius = drmin * fmax(1.0, ceil(radius / drmin));
     199              : 
     200       168912 :     for (int i = 0; i < 3; i++) {
     201       126684 :       roffset[i] = rp[2 - i] - ((double)cubecenter[i]) * dx[i];
     202              :     }
     203              : 
     204       168912 :     for (int i = 0; i < 3; i++) {
     205       126684 :       lb_cube[i] = ceil(-1e-8 - *disr_radius * dx_inv[i]);
     206              :     }
     207              : 
     208              :     // Symmetric interval
     209       168912 :     for (int i = 0; i < 3; i++) {
     210       126684 :       ub_cube[i] = 1 - lb_cube[i];
     211              :     }
     212              : 
     213              :   } else {
     214       192424 :     for (int idir = 0; idir < 3; idir++) {
     215       144318 :       lb_cube[idir] = INT_MAX;
     216       144318 :       ub_cube[idir] = INT_MIN;
     217              :     }
     218              : 
     219              :     /* compute the size of the box. It is a fairly trivial way to compute
     220              :      * the box and it may have far more point than needed */
     221       192424 :     for (int i = -1; i <= 1; i++) {
     222       577272 :       for (int j = -1; j <= 1; j++) {
     223      1731816 :         for (int k = -1; k <= 1; k++) {
     224      1298862 :           double x[3] = {/* rp[0] + */ ((double)i) * radius,
     225      1298862 :                          /* rp[1] + */ ((double)j) * radius,
     226      1298862 :                          /* rp[2] + */ ((double)k) * radius};
     227              :           /* convert_to_lattice_coordinates(dh_inv, x, y); */
     228      5195448 :           for (int idir = 0; idir < 3; idir++) {
     229      3896586 :             const double resc = dh_inv[0][idir] * x[0] +
     230      3896586 :                                 dh_inv[1][idir] * x[1] + dh_inv[2][idir] * x[2];
     231      3896586 :             lb_cube[2 - idir] = imin(lb_cube[2 - idir], floor(resc));
     232      3896586 :             ub_cube[2 - idir] = imax(ub_cube[2 - idir], ceil(resc));
     233              :           }
     234              :         }
     235              :       }
     236              :     }
     237              : 
     238              :     /* compute the offset in lattice coordinates */
     239              : 
     240       192424 :     for (int i = 0; i < 3; i++) {
     241       144318 :       roffset[i] = rp1[i] - cubecenter[i];
     242              :     }
     243              : 
     244        48106 :     *disr_radius = radius;
     245              :   }
     246              : 
     247              :   /* compute the cube size ignoring periodicity */
     248              : 
     249              :   /* the +1 is normal here */
     250        90334 :   cube_size[0] = ub_cube[0] - lb_cube[0] + 1;
     251        90334 :   cube_size[1] = ub_cube[1] - lb_cube[1] + 1;
     252        90334 :   cube_size[2] = ub_cube[2] - lb_cube[2] + 1;
     253              : 
     254       361336 :   for (int i = 0; i < 3; i++) {
     255       271002 :     cmax = imax(cmax, cube_size[i]);
     256              :   }
     257              : 
     258        90334 :   return cmax;
     259              : }
     260              : 
     261            0 : void return_cube_position(const int *const lb_grid,
     262              :                           const int *const cube_center,
     263              :                           const int *const lower_boundaries_cube,
     264              :                           const int *const period, int *const position) {
     265            0 :   for (int i = 0; i < 3; i++)
     266            0 :     position[i] = modulo(cube_center[i] - lb_grid[i] + lower_boundaries_cube[i],
     267            0 :                          period[i]);
     268            0 : }
     269              : 
     270         1208 : void verify_orthogonality(const double dh[3][3], bool orthogonal[3]) {
     271         1208 :   double norm1, norm2, norm3;
     272              : 
     273         1208 :   norm1 = dh[0][0] * dh[0][0] + dh[0][1] * dh[0][1] + dh[0][2] * dh[0][2];
     274         1208 :   norm2 = dh[1][0] * dh[1][0] + dh[1][1] * dh[1][1] + dh[1][2] * dh[1][2];
     275         1208 :   norm3 = dh[2][0] * dh[2][0] + dh[2][1] * dh[2][1] + dh[2][2] * dh[2][2];
     276              : 
     277         1208 :   norm1 = 1.0 / sqrt(norm1);
     278         1208 :   norm2 = 1.0 / sqrt(norm2);
     279         1208 :   norm3 = 1.0 / sqrt(norm3);
     280              : 
     281              :   /* x z */
     282         1208 :   orthogonal[0] =
     283         1208 :       ((fabs(dh[0][0] * dh[2][0] + dh[0][1] * dh[2][1] + dh[0][2] * dh[2][2]) *
     284         1208 :         norm1 * norm3) < 1e-12);
     285              :   /* y z */
     286         1208 :   orthogonal[1] =
     287         1208 :       ((fabs(dh[1][0] * dh[2][0] + dh[1][1] * dh[2][1] + dh[1][2] * dh[2][2]) *
     288         1208 :         norm2 * norm3) < 1e-12);
     289              :   /* x y */
     290         1208 :   orthogonal[2] =
     291         1208 :       ((fabs(dh[0][0] * dh[1][0] + dh[0][1] * dh[1][1] + dh[0][2] * dh[1][2]) *
     292         1208 :         norm1 * norm2) < 1e-12);
     293         1208 : }
     294              : 
     295              : #ifndef __MKL
     296              : extern void dger_(const int *M, const int *N, const double *alpha,
     297              :                   const double *X, const int *incX, const double *Y,
     298              :                   const int *incY, double *A, const int *lda);
     299              : extern void dgemv_(const char *Trans, const int *M, const int *N,
     300              :                    const double *alpha, const double *A, const int *lda,
     301              :                    const double *X, const int *incX, const double *beta,
     302              :                    double *Y, const int *incY);
     303              : 
     304           76 : void cblas_daxpy(const int N, const double alpha, const double *X,
     305              :                  const int incX, double *Y, const int incY) {
     306           76 :   if ((incX == 1) && (incY == 1)) {
     307          760 :     for (int i = 0; i < N; i++)
     308          684 :       Y[i] += alpha * X[i];
     309              :     return;
     310              :   }
     311              : 
     312            0 :   if (incX == 1) {
     313            0 :     for (int i = 0; i < N; i++)
     314            0 :       Y[i + incY] += alpha * X[i];
     315              :     return;
     316              :   }
     317              : 
     318            0 :   if (incY == 1) {
     319            0 :     for (int i = 0; i < N; i++)
     320            0 :       Y[i] += alpha * X[i + incX];
     321              :     return;
     322              :   }
     323              : 
     324            0 :   for (int i = 0; i < N; i++)
     325            0 :     Y[i + incY] += alpha * X[i + incX];
     326              :   return;
     327              : }
     328              : 
     329         9280 : double cblas_ddot(const int N, const double *X, const int incX, const double *Y,
     330              :                   const int incY) {
     331         9280 :   if ((incX == incY) && (incY == 1)) {
     332              :     double res = 0.0;
     333              : 
     334       230673 :     for (int i = 0; i < N; i++) {
     335       221393 :       res += X[i] * Y[i];
     336              :     }
     337              :     return res;
     338              :   }
     339              : 
     340            0 :   if (incX == 1) {
     341              :     double res = 0.0;
     342              : 
     343            0 :     for (int i = 0; i < N; i++) {
     344            0 :       res += X[i] * Y[i + incY];
     345              :     }
     346              :     return res;
     347              :   }
     348              : 
     349            0 :   if (incY == 1) {
     350              :     double res = 0.0;
     351              : 
     352            0 :     for (int i = 0; i < N; i++) {
     353            0 :       res += X[i + incX] * Y[i];
     354              :     }
     355              :     return res;
     356              :   }
     357              : 
     358              :   double res = 0.0;
     359              : 
     360            0 :   for (int i = 0; i < N; i++) {
     361            0 :     res += X[i + incX] * Y[i + incY];
     362              :   }
     363              :   return res;
     364              : }
     365              : 
     366        69226 : void cblas_dger(const CBLAS_LAYOUT Layout, const int M, const int N,
     367              :                 const double alpha, const double *X, const int incX,
     368              :                 const double *Y, const int incY, double *A, const int lda) {
     369        69226 :   if (Layout == CblasRowMajor) {
     370        69226 :     dger_(&N, &M, &alpha, Y, &incY, X, &incX, A, &lda);
     371              :   } else {
     372            0 :     dger_(&N, &M, &alpha, X, &incX, Y, &incY, A, &lda);
     373              :   }
     374        69226 : }
     375              : 
     376              : /* code taken from gsl_cblas. We really need to use a proper cblas interface and
     377              :  * build system.... */
     378        18560 : void cblas_dgemv(const CBLAS_LAYOUT order, const CBLAS_TRANSPOSE TransA,
     379              :                  const int M, const int N, const double alpha, const double *A,
     380              :                  const int lda, const double *X, const int incX,
     381              :                  const double beta, double *Y, const int incY) {
     382              : 
     383        18560 :   if (order == CblasColMajor) {
     384            0 :     if (TransA == CblasTrans)
     385            0 :       dgemv_("T", &M, &N, &alpha, A, &lda, X, &incX, &beta, Y, &incY);
     386              :     else {
     387            0 :       dgemv_("N", &M, &N, &alpha, A, &lda, X, &incX, &beta, Y, &incY);
     388              :     }
     389              :   } else {
     390        18560 :     if (TransA == CblasTrans)
     391            0 :       dgemv_("N", &N, &M, &alpha, A, &lda, X, &incX, &beta, Y, &incY);
     392              :     else {
     393        18560 :       dgemv_("T", &N, &M, &alpha, A, &lda, X, &incX, &beta, Y, &incY);
     394              :     }
     395              :   }
     396        18560 : }
     397              : #endif
     398              : 
     399      1321626 : void compute_interval(const int *const map, const int full_size, const int size,
     400              :                       const int cube_size, const int x1, int *x,
     401              :                       int *const lower_corner, int *const upper_corner,
     402              :                       const Interval window) {
     403      1321626 :   if (size == full_size) {
     404              :     /* we have the full grid in that direction */
     405              :     /* lower boundary is within the window */
     406      1321626 :     *lower_corner = x1;
     407              :     /* now compute the upper corner */
     408              :     /* needs to be as large as possible. basically I take [x1..
     409              :      * min(grid.full_size, cube_size - x)] */
     410              : 
     411      1321626 :     *upper_corner = compute_next_boundaries(x1, *x, full_size, cube_size);
     412              : 
     413              :     /* { */
     414              :     /*   Interval tz = create_interval(*lower_corner, *upper_corner); */
     415              :     /*   Interval res = intersection_interval(tz, window); */
     416              :     /*   *lower_corner = res.xmin; */
     417              :     /*   *upper_corner = res.xmax; */
     418              :     /* } */
     419              :   } else {
     420            0 :     *lower_corner = x1;
     421            0 :     *upper_corner = x1 + 1;
     422              : 
     423              :     // the map is always increasing by 1 except when we cross the boundaries of
     424              :     // the grid and pbc are applied. Since we are only interested in by a
     425              :     // subwindow of the full table we check that the next point is inside the
     426              :     // window of interest and is also equal to the previous point + 1. The last
     427              :     // check is pointless in practice.
     428              : 
     429            0 :     for (int i = *x + 1; (i < cube_size) && (*upper_corner == map[i]) &&
     430            0 :                          is_point_in_interval(map[i], window);
     431            0 :          i++) {
     432            0 :       (*upper_corner)++;
     433              :     }
     434              :   }
     435      1321626 : }
        

Generated by: LCOV version 2.0-1