LCOV - code coverage report
Current view: top level - src/grid/dgemm - grid_dgemm_utils.h (source / functions) Coverage Total Hit
Test: CP2K Regtests (git:936074a) Lines: 100.0 % 4 4
Test Date: 2025-12-04 06:27:48 Functions: 100.0 % 1 1

            Line data    Source code
       1              : /*----------------------------------------------------------------------------*/
       2              : /*  CP2K: A general program to perform molecular dynamics simulations         */
       3              : /*  Copyright 2000-2025 CP2K developers group <https://cp2k.org>              */
       4              : /*                                                                            */
       5              : /*  SPDX-License-Identifier: BSD-3-Clause                                     */
       6              : /*----------------------------------------------------------------------------*/
       7              : 
       8              : #ifndef GRID_DGEMM_UTILS_H
       9              : #define GRID_DGEMM_UTILS_H
      10              : 
      11              : #include <stdbool.h>
      12              : #include <stdio.h>
      13              : #include <string.h>
      14              : 
      15              : #if defined(__MKL)
      16              : #include <mkl.h>
      17              : #include <mkl_cblas.h>
      18              : #endif
      19              : 
      20              : #if defined(__LIBXSMM)
      21              : #include <libxsmm.h>
      22              : #endif
      23              : 
      24              : #include "../common/grid_common.h"
      25              : #include "grid_dgemm_private_header.h"
      26              : #include "grid_dgemm_tensor_local.h"
      27              : 
      28              : /* inverse of the factorials */
      29              : static const double inv_fac[] = {1.0,
      30              :                                  1.0,
      31              :                                  0.5,
      32              :                                  0.166666666666666666666666666667,
      33              :                                  0.0416666666666666666666666666667,
      34              :                                  0.00833333333333333333333333333333,
      35              :                                  0.00138888888888888888888888888889,
      36              :                                  0.000198412698412698412698412698413,
      37              :                                  0.0000248015873015873015873015873016,
      38              :                                  2.7557319223985890652557319224e-6,
      39              :                                  2.7557319223985890652557319224e-7,
      40              :                                  2.50521083854417187750521083854e-8,
      41              :                                  2.08767569878680989792100903212e-9,
      42              :                                  1.60590438368216145993923771702e-10,
      43              :                                  1.14707455977297247138516979787e-11,
      44              :                                  7.64716373181981647590113198579e-13,
      45              :                                  4.77947733238738529743820749112e-14,
      46              :                                  2.81145725434552076319894558301e-15,
      47              :                                  1.56192069685862264622163643501e-16,
      48              :                                  8.22063524662432971695598123687e-18,
      49              :                                  4.11031762331216485847799061844e-19,
      50              :                                  1.95729410633912612308475743735e-20,
      51              :                                  8.8967913924505732867488974425e-22,
      52              :                                  3.86817017063068403771691193152e-23,
      53              :                                  1.6117375710961183490487133048e-24,
      54              :                                  6.4469502843844733961948532192e-26,
      55              :                                  2.47959626322479746007494354585e-27,
      56              :                                  9.18368986379554614842571683647e-29,
      57              :                                  3.27988923706983791015204172731e-30,
      58              :                                  1.13099628864477169315587645769e-31,
      59              :                                  3.76998762881590564385292152565e-33};
      60              : 
      61              : inline int coset_without_offset(int lx, int ly, int lz) {
      62              :   const int l = lx + ly + lz;
      63              :   if (l == 0) {
      64              :     return 0;
      65              :   } else {
      66              :     return ((l - lx) * (l - lx + 1)) / 2 + lz;
      67              :   }
      68              : }
      69              : 
      70              : typedef struct dgemm_params_ {
      71              :   char storage;
      72              :   char op1;
      73              :   char op2;
      74              :   double alpha;
      75              :   double beta;
      76              :   double *a, *b, *c;
      77              :   int m, n, k, lda, ldb, ldc;
      78              :   int x, y, z;
      79              :   int x1, y1, z1;
      80              :   bool use_libxsmm;
      81              : #if defined(__LIBXSMM)
      82              :   libxsmm_dmmfunction kernel;
      83              :   int prefetch;
      84              :   int flags;
      85              : #endif
      86              : } dgemm_params;
      87              : 
      88              : extern void dgemm_simplified(dgemm_params *const m);
      89              : extern void batched_dgemm_simplified(dgemm_params *const m,
      90              :                                      const int batch_size);
      91              : 
      92              : /*******************************************************************************
      93              :  * \brief Prototype for BLAS dgemm.
      94              :  * \author Ole Schuett
      95              :  ******************************************************************************/
      96              : void dgemm_(const char *transa, const char *transb, const int *m, const int *n,
      97              :             const int *k, const double *alpha, const double *a, const int *lda,
      98              :             const double *b, const int *ldb, const double *beta, double *c,
      99              :             const int *ldc);
     100              : 
     101              : extern void extract_sub_grid(const int *lower_corner, const int *upper_corner,
     102              :                              const int *position, const tensor *const grid,
     103              :                              tensor *const subgrid);
     104              : extern void add_sub_grid(const int *lower_corner, const int *upper_corner,
     105              :                          const int *position, const tensor *subgrid,
     106              :                          tensor *grid);
     107              : extern void return_cube_position(const int *lb_grid, const int *cube_center,
     108              :                                  const int *lower_boundaries_cube,
     109              :                                  const int *period, int *const position);
     110              : 
     111              : extern void verify_orthogonality(const double dh[3][3], bool orthogonal[3]);
     112              : 
     113              : extern int compute_cube_properties(const bool ortho, const double radius,
     114              :                                    const double dh[3][3],
     115              :                                    const double dh_inv[3][3], const double *rp,
     116              :                                    double *disr_radius, double *roffset,
     117              :                                    int *cubecenter, int *lb_cube, int *ub_cube,
     118              :                                    int *cube_size);
     119              : 
     120              : inline int return_offset_l(const int l) {
     121              :   static const int offset_[] = {1,   4,   7,   11,  16,  22,  29,
     122              :                                 37,  46,  56,  67,  79,  92,  106,
     123              :                                 121, 137, 154, 172, 191, 211, 232};
     124              :   return offset_[l];
     125              : }
     126              : 
     127              : inline int return_linear_index_from_exponents(const int alpha, const int beta,
     128              :                                               const int gamma) {
     129              :   const int l = alpha + beta + gamma;
     130              :   return return_offset_l(l) + (l - alpha) * (l - alpha + 1) / 2 + gamma;
     131              : }
     132              : 
     133       192586 : static inline void *grid_allocate_scratch(size_t size) {
     134              : #ifdef __LIBXSMM
     135       192586 :   return libxsmm_aligned_scratch(size, 0 /*auto-alignment*/);
     136              : #else
     137              :   return malloc(size);
     138              : #endif
     139              : }
     140              : 
     141       192586 : static inline void grid_free_scratch(void *ptr) {
     142              : #ifdef __LIBXSMM
     143        96374 :   libxsmm_free(ptr);
     144              : #else
     145              :   free(ptr);
     146              : #endif
     147              : }
     148              : 
     149              : /* even openblas and lapack has cblas versions of lapack and blas. */
     150              : #ifndef __MKL
     151              : enum CBLAS_LAYOUT { CblasRowMajor = 101, CblasColMajor = 102 };
     152              : enum CBLAS_TRANSPOSE {
     153              :   CblasNoTrans = 111,
     154              :   CblasTrans = 112,
     155              :   CblasConjTrans = 113
     156              : };
     157              : enum CBLAS_UPLO { CblasUpper = 121, CblasLower = 122 };
     158              : enum CBLAS_DIAG { CblasNonUnit = 131, CblasUnit = 132 };
     159              : enum CBLAS_SIDE { CblasLeft = 141, CblasRight = 142 };
     160              : 
     161              : typedef enum CBLAS_LAYOUT CBLAS_LAYOUT;
     162              : typedef enum CBLAS_TRANSPOSE CBLAS_TRANSPOSE;
     163              : typedef enum CBLAS_UPLO CBLAS_UPLO;
     164              : typedef enum CBLAS_DIAG CBLAS_DIAG;
     165              : 
     166              : double cblas_ddot(const int N, const double *X, const int incX, const double *Y,
     167              :                   const int incY);
     168              : 
     169              : void cblas_dger(const CBLAS_LAYOUT Layout, const int M, const int N,
     170              :                 const double alpha, const double *X, const int incX,
     171              :                 const double *Y, const int incY, double *A, const int lda);
     172              : 
     173              : void cblas_daxpy(const int N, const double alpha, const double *X,
     174              :                  const int incX, double *Y, const int incY);
     175              : 
     176              : void cblas_dgemv(const CBLAS_LAYOUT Layout, const CBLAS_TRANSPOSE TransA,
     177              :                  const int M, const int N, const double alpha, const double *A,
     178              :                  const int lda, const double *X, const int incX,
     179              :                  const double beta, double *Y, const int incY);
     180              : 
     181              : #endif
     182              : 
     183              : extern void compute_interval(const int *const map, const int full_size,
     184              :                              const int size, const int cube_size, const int x1,
     185              :                              int *x, int *const lower_corner,
     186              :                              int *const upper_corner, Interval window);
     187              : #endif
        

Generated by: LCOV version 2.0-1