LCOV - code coverage report
Current view: top level - src/common - gemm_square_unittest.F (source / functions) Coverage Total Hit
Test: CP2K Regtests (git:936074a) Lines: 100.0 % 132 132
Test Date: 2025-12-04 06:27:48 Functions: 100.0 % 4 4

            Line data    Source code
       1              : !--------------------------------------------------------------------------------------------------!
       2              : !   CP2K: A general program to perform molecular dynamics simulations                              !
       3              : !   Copyright 2000-2025 CP2K developers group <https://cp2k.org>                                   !
       4              : !                                                                                                  !
       5              : !   SPDX-License-Identifier: GPL-2.0-or-later                                                      !
       6              : !--------------------------------------------------------------------------------------------------!
       7            2 : PROGRAM gemm_square_unittest
       8            2 :    USE kinds,                           ONLY: dp
       9              :    USE mathlib,                         ONLY: gemm_square
      10              : #include "../base/base_uses.f90"
      11              : 
      12              :    IMPLICIT NONE
      13              : 
      14              :    COMPLEX(kind=dp), DIMENSION(3, 3) :: A_in, B_in, C_in, res_c, res_c_ref
      15              :    REAL(kind=dp), DIMENSION(3, 3) :: X_in, Y_in, Z_in, res_r, res_r_ref
      16              :    REAL(kind=dp) :: tolerance = 1.0e-6_dp
      17              : 
      18              :    ! Prepare inputs
      19            2 :    A_in(1, 1) = CMPLX(0.8815928086074307_dp, 0.6726432190297216_dp, kind=dp)
      20            2 :    A_in(1, 2) = CMPLX(0.7660079579530265_dp, 0.6663301208376479_dp, kind=dp)
      21            2 :    A_in(1, 3) = CMPLX(0.8910730680466552_dp, 0.6447684662974965_dp, kind=dp)
      22            2 :    A_in(2, 1) = CMPLX(0.270178070784315_dp, 0.9380895276020503_dp, kind=dp)
      23            2 :    A_in(2, 2) = CMPLX(0.4365740872106577_dp, 0.5843460996868933_dp, kind=dp)
      24            2 :    A_in(2, 3) = CMPLX(0.07466461985206008_dp, 0.6899750234684598_dp, kind=dp)
      25            2 :    A_in(3, 1) = CMPLX(0.840974290337725_dp, 0.8395064317543346_dp, kind=dp)
      26            2 :    A_in(3, 2) = CMPLX(0.5872667752635958_dp, 0.6233467352665024_dp, kind=dp)
      27            2 :    A_in(3, 3) = CMPLX(0.5024930933188588_dp, 0.7727803824712417_dp, kind=dp)
      28              : 
      29            2 :    B_in(1, 1) = CMPLX(0.08269815253296364_dp, 0.34184561260312574_dp, kind=dp)
      30            2 :    B_in(1, 2) = CMPLX(0.9876346392802493_dp, 0.26436123295003866_dp, kind=dp)
      31            2 :    B_in(1, 3) = CMPLX(0.780810836185207_dp, 0.376036133357872_dp, kind=dp)
      32            2 :    B_in(2, 1) = CMPLX(0.4787818411690774_dp, 0.7596241356044092_dp, kind=dp)
      33            2 :    B_in(2, 2) = CMPLX(0.4298758196722595_dp, 0.4813479548810141_dp, kind=dp)
      34            2 :    B_in(2, 3) = CMPLX(0.2086685419449945_dp, 0.3860478932514133_dp, kind=dp)
      35            2 :    B_in(3, 1) = CMPLX(0.34008386216308817_dp, 0.8353095227337101_dp, kind=dp)
      36            2 :    B_in(3, 2) = CMPLX(0.7379600798045334_dp, 0.7634442366598211_dp, kind=dp)
      37            2 :    B_in(3, 3) = CMPLX(0.9840849653895581_dp, 0.9273454280026875_dp, kind=dp)
      38              : 
      39            2 :    C_in(1, 1) = CMPLX(0.731192921191078_dp, 0.9732725403607281_dp, kind=dp)
      40            2 :    C_in(1, 2) = CMPLX(0.07386957805916261_dp, 0.14228952898305391_dp, kind=dp)
      41            2 :    C_in(1, 3) = CMPLX(0.12229506342104235_dp, 0.6298697123856768_dp, kind=dp)
      42            2 :    C_in(2, 1) = CMPLX(0.007352653494114958_dp, 0.29359318766569575_dp, kind=dp)
      43            2 :    C_in(2, 2) = CMPLX(0.29087841717040863_dp, 0.48194825561460775_dp, kind=dp)
      44            2 :    C_in(2, 3) = CMPLX(0.22558916232632764_dp, 0.9229223568661166_dp, kind=dp)
      45            2 :    C_in(3, 1) = CMPLX(0.5728946948517463_dp, 0.9149335302204014_dp, kind=dp)
      46            2 :    C_in(3, 2) = CMPLX(0.20475976494474424_dp, 0.6082208447082643_dp, kind=dp)
      47            2 :    C_in(3, 3) = CMPLX(0.9060121198373113_dp, 0.008565705864987172_dp, kind=dp)
      48              : 
      49            2 :    X_in(1, 1) = 0.42929014430726375_dp
      50            2 :    X_in(1, 2) = 0.21820709659663573_dp
      51            2 :    X_in(1, 3) = 0.5394292090282415_dp
      52            2 :    X_in(2, 1) = 0.7828031363115503_dp
      53            2 :    X_in(2, 2) = 0.1422677264194132_dp
      54            2 :    X_in(2, 3) = 0.25344520034350637_dp
      55            2 :    X_in(3, 1) = 0.5044049742159297_dp
      56            2 :    X_in(3, 2) = 0.6969177100349894_dp
      57            2 :    X_in(3, 3) = 0.6999162742203425_dp
      58              : 
      59            2 :    Y_in(1, 1) = 0.5331333823513378_dp
      60            2 :    Y_in(1, 2) = 0.8001773249628732_dp
      61            2 :    Y_in(1, 3) = 0.2850504760853374_dp
      62            2 :    Y_in(2, 1) = 0.23062673571851455_dp
      63            2 :    Y_in(2, 2) = 0.5013417881822918_dp
      64            2 :    Y_in(2, 3) = 0.07530315834987644_dp
      65            2 :    Y_in(3, 1) = 0.2267846125008932_dp
      66            2 :    Y_in(3, 2) = 0.19831160340777076_dp
      67            2 :    Y_in(3, 3) = 0.3050258528838238_dp
      68              : 
      69            2 :    Z_in(1, 1) = 0.5400800562659297_dp
      70            2 :    Z_in(1, 2) = 0.506259700373107_dp
      71            2 :    Z_in(1, 3) = 0.24342576996957088_dp
      72            2 :    Z_in(2, 1) = 0.3517364012861689_dp
      73            2 :    Z_in(2, 2) = 0.04901381134580918_dp
      74            2 :    Z_in(2, 3) = 0.31263102401008236_dp
      75            2 :    Z_in(3, 1) = 0.20684120795408456_dp
      76            2 :    Z_in(3, 2) = 0.8051322416754273_dp
      77            2 :    Z_in(3, 3) = 0.5860282518273413_dp
      78              : 
      79              :    ! Test X * Y
      80              : 
      81            2 :    CALL gemm_square(X_in, 'N', Y_in, 'N', res_r)
      82              : 
      83            2 :    res_r_ref(1, 1) = 0.4015275411844552_dp
      84            2 :    res_r_ref(1, 2) = 0.5598796466739117_dp
      85            2 :    res_r_ref(1, 3) = 0.3033408981158978_dp
      86            2 :    res_r_ref(2, 1) = 0.5076266966693295_dp
      87            2 :    res_r_ref(2, 2) = 0.7479672000061859_dp
      88            2 :    res_r_ref(2, 3) = 0.31115895421143025_dp
      89            2 :    res_r_ref(3, 1) = 0.588373227540499_dp
      90            2 :    res_r_ref(3, 2) = 0.8918089125227482_dp
      91            2 :    res_r_ref(3, 3) = 0.4097535412069894_dp
      92              : 
      93            2 :    CALL check_ref_r(res_r, res_r_ref, tolerance)
      94              : 
      95              :    ! Test A * B
      96              : 
      97            2 :    CALL gemm_square(A_in, 'N', B_in, 'N', res_c)
      98              : 
      99            2 :    res_c_ref(1, 1) = CMPLX(-0.5319854477298264_dp, 2.221497049670068_dp, kind=dp)
     100            2 :    res_c_ref(1, 2) = CMPLX(0.8667540454951561_dp, 2.708638262772094_dp, kind=dp)
     101            2 :    res_c_ref(1, 3) = CMPLX(0.6169940071822635_dp, 2.7523152479106017_dp, kind=dp)
     102            2 :    res_c_ref(2, 1) = CMPLX(-1.0841486927833452_dp, 1.0783614128260843_dp, kind=dp)
     103            2 :    res_c_ref(2, 2) = CMPLX(-0.5464163853751492_dp, 2.025430919812893_dp, kind=dp)
     104            2 :    res_c_ref(2, 3) = CMPLX(-0.8426527494261556_dp, 1.872774281691093_dp, kind=dp)
     105            2 :    res_c_ref(3, 1) = CMPLX(-0.884392147923063_dp, 1.78400551956788_dp, kind=dp)
     106            2 :    res_c_ref(3, 2) = CMPLX(0.3418926087394045_dp, 2.5559945109530244_dp, kind=dp)
     107            2 :    res_c_ref(3, 3) = CMPLX(0.000721037947416292_dp, 2.5549846237243488_dp, kind=dp)
     108              : 
     109            2 :    CALL check_ref_c(res_c, res_c_ref, tolerance)
     110              : 
     111              :    ! Test X * Y * Z
     112              : 
     113            2 :    CALL gemm_square(X_in, 'N', Y_in, 'N', Z_in, 'N', res_r)
     114              : 
     115            2 :    res_r_ref(1, 1) = 0.4765304668978436_dp
     116            2 :    res_r_ref(1, 2) = 0.47494858536191625_dp
     117            2 :    res_r_ref(1, 3) = 0.45054423436947794_dp
     118            2 :    res_r_ref(2, 1) = 0.6016068400643494_dp
     119            2 :    res_r_ref(2, 2) = 0.5441757689127916_dp
     120            2 :    res_r_ref(2, 3) = 0.5397551091346775_dp
     121            2 :    res_r_ref(3, 1) = 0.7162042207878401_dp
     122            2 :    res_r_ref(3, 2) = 0.6714863948435401_dp
     123            2 :    res_r_ref(3, 3) = 0.6621594909204267_dp
     124              : 
     125            2 :    CALL check_ref_r(res_r, res_r_ref, tolerance)
     126              : 
     127              :    ! Test A * B * C
     128              : 
     129            2 :    CALL gemm_square(A_in, 'N', B_in, 'N', C_in, 'N', res_c)
     130              : 
     131            2 :    res_c_ref(1, 1) = CMPLX(-5.504683782712595_dp, 3.5222601599484484_dp, kind=dp)
     132            2 :    res_c_ref(1, 2) = CMPLX(-2.9563767075011937_dp, 2.232852141215477_dp, kind=dp)
     133            2 :    res_c_ref(1, 3) = CMPLX(-3.233216866606864_dp, 3.8464986862655572_dp, kind=dp)
     134            2 :    res_c_ref(2, 1) = CMPLX(-4.63714700646176_dp, -0.11028256160215588_dp, kind=dp)
     135            2 :    res_c_ref(2, 2) = CMPLX(-2.680220510420048_dp, 0.12215466665130881_dp, kind=dp)
     136            2 :    res_c_ref(2, 3) = CMPLX(-3.583889556370547_dp, 1.091159499729193_dp, kind=dp)
     137            2 :    res_c_ref(3, 1) = CMPLX(-5.468121643036268_dp, 2.0272651357548415_dp, kind=dp)
     138            2 :    res_c_ref(3, 2) = CMPLX(-3.0054301614279586_dp, 1.4377987781538424_dp, kind=dp)
     139            2 :    res_c_ref(3, 3) = CMPLX(-3.534937025955706_dp, 2.8681214444912606_dp, kind=dp)
     140              : 
     141            2 :    CALL check_ref_c(res_c, res_c_ref, tolerance)
     142              : 
     143              :    ! Test X^T * Y * Z
     144              : 
     145            2 :    CALL gemm_square(X_in, 'T', Y_in, 'N', Z_in, 'N', res_r)
     146              : 
     147            2 :    res_r_ref(1, 1) = 0.6462671475738445_dp
     148            2 :    res_r_ref(1, 2) = 0.5760105624808499_dp
     149            2 :    res_r_ref(1, 3) = 0.5852827099256332_dp
     150            2 :    res_r_ref(2, 1) = 0.36007554144181786_dp
     151            2 :    res_r_ref(2, 2) = 0.40420627668516457_dp
     152            2 :    res_r_ref(2, 3) = 0.36217776140102986_dp
     153            2 :    res_r_ref(3, 1) = 0.5978645617240842_dp
     154            2 :    res_r_ref(3, 2) = 0.600788264751924_dp
     155            2 :    res_r_ref(3, 3) = 0.5673424971463421_dp
     156              : 
     157            2 :    CALL check_ref_r(res_r, res_r_ref, tolerance)
     158              : 
     159              :    ! Test A^H * B * C
     160              : 
     161            2 :    CALL gemm_square(A_in, 'C', B_in, 'N', C_in, 'N', res_c)
     162              : 
     163            2 :    res_c_ref(1, 1) = CMPLX(3.375089298965469_dp, 5.744913993063936_dp, kind=dp)
     164            2 :    res_c_ref(1, 2) = CMPLX(2.0725172551868294_dp, 3.258926327791143_dp, kind=dp)
     165            2 :    res_c_ref(1, 3) = CMPLX(3.965529787950442_dp, 3.621340775428089_dp, kind=dp)
     166            2 :    res_c_ref(2, 1) = CMPLX(2.4231309591599897_dp, 4.665551869666368_dp, kind=dp)
     167            2 :    res_c_ref(2, 2) = CMPLX(1.5937647760286848_dp, 2.6021783330446246_dp, kind=dp)
     168            2 :    res_c_ref(2, 3) = CMPLX(2.9609793918714686_dp, 2.92153954960111_dp, kind=dp)
     169            2 :    res_c_ref(3, 1) = CMPLX(3.278689562249669_dp, 4.308656958132163_dp, kind=dp)
     170            2 :    res_c_ref(3, 2) = CMPLX(2.05357432643753_dp, 2.5060755291807237_dp, kind=dp)
     171            2 :    res_c_ref(3, 3) = CMPLX(3.646272530313196_dp, 2.5667051324585874_dp, kind=dp)
     172              : 
     173            2 :    CALL check_ref_c(res_c, res_c_ref, tolerance)
     174              : 
     175              : CONTAINS
     176              : ! **************************************************************************************************
     177              : !> \brief ...
     178              : !> \param mat ...
     179              : !> \param ref ...
     180              : !> \param tolerance ...
     181              : ! **************************************************************************************************
     182            6 :    SUBROUTINE check_ref_r(mat, ref, tolerance)
     183              :       REAL(kind=dp), DIMENSION(3, 3)                     :: mat, ref
     184              :       REAL(kind=dp)                                      :: tolerance
     185              : 
     186              :       INTEGER                                            :: i, j
     187              : 
     188           24 :       DO i = 1, 3
     189           78 :          DO j = 1, 3
     190           72 :             CPASSERT(ABS(mat(i, j) - ref(i, j)) <= tolerance)
     191              :          END DO
     192              :       END DO
     193            6 :    END SUBROUTINE check_ref_r
     194              : ! **************************************************************************************************
     195              : !> \brief ...
     196              : !> \param mat ...
     197              : !> \param ref ...
     198              : !> \param tolerance ...
     199              : ! **************************************************************************************************
     200            6 :    SUBROUTINE check_ref_c(mat, ref, tolerance)
     201              :       COMPLEX(kind=dp), DIMENSION(3, 3)                  :: mat, ref
     202              :       REAL(kind=dp)                                      :: tolerance
     203              : 
     204              :       INTEGER                                            :: i, j
     205              : 
     206           24 :       DO i = 1, 3
     207           78 :          DO j = 1, 3
     208           72 :             CPASSERT(ABS(mat(i, j) - ref(i, j)) <= tolerance)
     209              :          END DO
     210              :       END DO
     211            6 :    END SUBROUTINE check_ref_c
     212              : END PROGRAM gemm_square_unittest
        

Generated by: LCOV version 2.0-1