Line data Source code
1 : !--------------------------------------------------------------------------------------------------!
2 : ! CP2K: A general program to perform molecular dynamics simulations !
3 : ! Copyright 2000-2025 CP2K developers group <https://cp2k.org> !
4 : ! !
5 : ! SPDX-License-Identifier: GPL-2.0-or-later !
6 : !--------------------------------------------------------------------------------------------------!
! **************************************************************************************************
!> \brief Unit test for mathlib's gemm_square on 3x3 matrices.
!>        Checks real (X, Y, Z) and complex (A, B, C) two- and three-factor products, with and
!>        without (conjugate) transposition of individual factors. Results are compared against
!>        hard-coded reference values and, for some cases, against the MATMUL intrinsic.
! **************************************************************************************************
PROGRAM gemm_square_unittest
   USE kinds, ONLY: dp
   USE mathlib, ONLY: gemm_square
#include "../base/base_uses.f90"

   IMPLICIT NONE

   ! Maximum allowed absolute difference per matrix element between result and reference.
   REAL(kind=dp), PARAMETER :: tolerance = 1.0e-6_dp

   COMPLEX(kind=dp), DIMENSION(3, 3) :: A_in, B_in, C_in, res_c, res_c_ref
   REAL(kind=dp), DIMENSION(3, 3) :: X_in, Y_in, Z_in, res_r, res_r_ref

   ! Prepare inputs
   A_in(1, 1) = CMPLX(0.8815928086074307_dp, 0.6726432190297216_dp, kind=dp)
   A_in(1, 2) = CMPLX(0.7660079579530265_dp, 0.6663301208376479_dp, kind=dp)
   A_in(1, 3) = CMPLX(0.8910730680466552_dp, 0.6447684662974965_dp, kind=dp)
   A_in(2, 1) = CMPLX(0.270178070784315_dp, 0.9380895276020503_dp, kind=dp)
   A_in(2, 2) = CMPLX(0.4365740872106577_dp, 0.5843460996868933_dp, kind=dp)
   A_in(2, 3) = CMPLX(0.07466461985206008_dp, 0.6899750234684598_dp, kind=dp)
   A_in(3, 1) = CMPLX(0.840974290337725_dp, 0.8395064317543346_dp, kind=dp)
   A_in(3, 2) = CMPLX(0.5872667752635958_dp, 0.6233467352665024_dp, kind=dp)
   A_in(3, 3) = CMPLX(0.5024930933188588_dp, 0.7727803824712417_dp, kind=dp)

   B_in(1, 1) = CMPLX(0.08269815253296364_dp, 0.34184561260312574_dp, kind=dp)
   B_in(1, 2) = CMPLX(0.9876346392802493_dp, 0.26436123295003866_dp, kind=dp)
   B_in(1, 3) = CMPLX(0.780810836185207_dp, 0.376036133357872_dp, kind=dp)
   B_in(2, 1) = CMPLX(0.4787818411690774_dp, 0.7596241356044092_dp, kind=dp)
   B_in(2, 2) = CMPLX(0.4298758196722595_dp, 0.4813479548810141_dp, kind=dp)
   B_in(2, 3) = CMPLX(0.2086685419449945_dp, 0.3860478932514133_dp, kind=dp)
   B_in(3, 1) = CMPLX(0.34008386216308817_dp, 0.8353095227337101_dp, kind=dp)
   B_in(3, 2) = CMPLX(0.7379600798045334_dp, 0.7634442366598211_dp, kind=dp)
   B_in(3, 3) = CMPLX(0.9840849653895581_dp, 0.9273454280026875_dp, kind=dp)

   C_in(1, 1) = CMPLX(0.731192921191078_dp, 0.9732725403607281_dp, kind=dp)
   C_in(1, 2) = CMPLX(0.07386957805916261_dp, 0.14228952898305391_dp, kind=dp)
   C_in(1, 3) = CMPLX(0.12229506342104235_dp, 0.6298697123856768_dp, kind=dp)
   C_in(2, 1) = CMPLX(0.007352653494114958_dp, 0.29359318766569575_dp, kind=dp)
   C_in(2, 2) = CMPLX(0.29087841717040863_dp, 0.48194825561460775_dp, kind=dp)
   C_in(2, 3) = CMPLX(0.22558916232632764_dp, 0.9229223568661166_dp, kind=dp)
   C_in(3, 1) = CMPLX(0.5728946948517463_dp, 0.9149335302204014_dp, kind=dp)
   C_in(3, 2) = CMPLX(0.20475976494474424_dp, 0.6082208447082643_dp, kind=dp)
   C_in(3, 3) = CMPLX(0.9060121198373113_dp, 0.008565705864987172_dp, kind=dp)

   X_in(1, 1) = 0.42929014430726375_dp
   X_in(1, 2) = 0.21820709659663573_dp
   X_in(1, 3) = 0.5394292090282415_dp
   X_in(2, 1) = 0.7828031363115503_dp
   X_in(2, 2) = 0.1422677264194132_dp
   X_in(2, 3) = 0.25344520034350637_dp
   X_in(3, 1) = 0.5044049742159297_dp
   X_in(3, 2) = 0.6969177100349894_dp
   X_in(3, 3) = 0.6999162742203425_dp

   Y_in(1, 1) = 0.5331333823513378_dp
   Y_in(1, 2) = 0.8001773249628732_dp
   Y_in(1, 3) = 0.2850504760853374_dp
   Y_in(2, 1) = 0.23062673571851455_dp
   Y_in(2, 2) = 0.5013417881822918_dp
   Y_in(2, 3) = 0.07530315834987644_dp
   Y_in(3, 1) = 0.2267846125008932_dp
   Y_in(3, 2) = 0.19831160340777076_dp
   Y_in(3, 3) = 0.3050258528838238_dp

   Z_in(1, 1) = 0.5400800562659297_dp
   Z_in(1, 2) = 0.506259700373107_dp
   Z_in(1, 3) = 0.24342576996957088_dp
   Z_in(2, 1) = 0.3517364012861689_dp
   Z_in(2, 2) = 0.04901381134580918_dp
   Z_in(2, 3) = 0.31263102401008236_dp
   Z_in(3, 1) = 0.20684120795408456_dp
   Z_in(3, 2) = 0.8051322416754273_dp
   Z_in(3, 3) = 0.5860282518273413_dp

   ! Test X * Y

   CALL gemm_square(X_in, 'N', Y_in, 'N', res_r)

   res_r_ref(1, 1) = 0.4015275411844552_dp
   res_r_ref(1, 2) = 0.5598796466739117_dp
   res_r_ref(1, 3) = 0.3033408981158978_dp
   res_r_ref(2, 1) = 0.5076266966693295_dp
   res_r_ref(2, 2) = 0.7479672000061859_dp
   res_r_ref(2, 3) = 0.31115895421143025_dp
   res_r_ref(3, 1) = 0.588373227540499_dp
   res_r_ref(3, 2) = 0.8918089125227482_dp
   res_r_ref(3, 3) = 0.4097535412069894_dp

   CALL check_ref_r(res_r, res_r_ref, tolerance)

   ! Test A * B

   CALL gemm_square(A_in, 'N', B_in, 'N', res_c)

   res_c_ref(1, 1) = CMPLX(-0.5319854477298264_dp, 2.221497049670068_dp, kind=dp)
   res_c_ref(1, 2) = CMPLX(0.8667540454951561_dp, 2.708638262772094_dp, kind=dp)
   res_c_ref(1, 3) = CMPLX(0.6169940071822635_dp, 2.7523152479106017_dp, kind=dp)
   res_c_ref(2, 1) = CMPLX(-1.0841486927833452_dp, 1.0783614128260843_dp, kind=dp)
   res_c_ref(2, 2) = CMPLX(-0.5464163853751492_dp, 2.025430919812893_dp, kind=dp)
   res_c_ref(2, 3) = CMPLX(-0.8426527494261556_dp, 1.872774281691093_dp, kind=dp)
   res_c_ref(3, 1) = CMPLX(-0.884392147923063_dp, 1.78400551956788_dp, kind=dp)
   res_c_ref(3, 2) = CMPLX(0.3418926087394045_dp, 2.5559945109530244_dp, kind=dp)
   res_c_ref(3, 3) = CMPLX(0.000721037947416292_dp, 2.5549846237243488_dp, kind=dp)

   CALL check_ref_c(res_c, res_c_ref, tolerance)

   ! Test X * Y * Z

   CALL gemm_square(X_in, 'N', Y_in, 'N', Z_in, 'N', res_r)

   res_r_ref(1, 1) = 0.4765304668978436_dp
   res_r_ref(1, 2) = 0.47494858536191625_dp
   res_r_ref(1, 3) = 0.45054423436947794_dp
   res_r_ref(2, 1) = 0.6016068400643494_dp
   res_r_ref(2, 2) = 0.5441757689127916_dp
   res_r_ref(2, 3) = 0.5397551091346775_dp
   res_r_ref(3, 1) = 0.7162042207878401_dp
   res_r_ref(3, 2) = 0.6714863948435401_dp
   res_r_ref(3, 3) = 0.6621594909204267_dp

   CALL check_ref_r(res_r, res_r_ref, tolerance)

   ! Test A * B * C

   CALL gemm_square(A_in, 'N', B_in, 'N', C_in, 'N', res_c)

   res_c_ref(1, 1) = CMPLX(-5.504683782712595_dp, 3.5222601599484484_dp, kind=dp)
   res_c_ref(1, 2) = CMPLX(-2.9563767075011937_dp, 2.232852141215477_dp, kind=dp)
   res_c_ref(1, 3) = CMPLX(-3.233216866606864_dp, 3.8464986862655572_dp, kind=dp)
   res_c_ref(2, 1) = CMPLX(-4.63714700646176_dp, -0.11028256160215588_dp, kind=dp)
   res_c_ref(2, 2) = CMPLX(-2.680220510420048_dp, 0.12215466665130881_dp, kind=dp)
   res_c_ref(2, 3) = CMPLX(-3.583889556370547_dp, 1.091159499729193_dp, kind=dp)
   res_c_ref(3, 1) = CMPLX(-5.468121643036268_dp, 2.0272651357548415_dp, kind=dp)
   res_c_ref(3, 2) = CMPLX(-3.0054301614279586_dp, 1.4377987781538424_dp, kind=dp)
   res_c_ref(3, 3) = CMPLX(-3.534937025955706_dp, 2.8681214444912606_dp, kind=dp)

   CALL check_ref_c(res_c, res_c_ref, tolerance)

   ! Test X^T * Y * Z

   CALL gemm_square(X_in, 'T', Y_in, 'N', Z_in, 'N', res_r)

   res_r_ref(1, 1) = 0.6462671475738445_dp
   res_r_ref(1, 2) = 0.5760105624808499_dp
   res_r_ref(1, 3) = 0.5852827099256332_dp
   res_r_ref(2, 1) = 0.36007554144181786_dp
   res_r_ref(2, 2) = 0.40420627668516457_dp
   res_r_ref(2, 3) = 0.36217776140102986_dp
   res_r_ref(3, 1) = 0.5978645617240842_dp
   res_r_ref(3, 2) = 0.600788264751924_dp
   res_r_ref(3, 3) = 0.5673424971463421_dp

   CALL check_ref_r(res_r, res_r_ref, tolerance)

   ! Test A^H * B * C

   CALL gemm_square(A_in, 'C', B_in, 'N', C_in, 'N', res_c)

   res_c_ref(1, 1) = CMPLX(3.375089298965469_dp, 5.744913993063936_dp, kind=dp)
   res_c_ref(1, 2) = CMPLX(2.0725172551868294_dp, 3.258926327791143_dp, kind=dp)
   res_c_ref(1, 3) = CMPLX(3.965529787950442_dp, 3.621340775428089_dp, kind=dp)
   res_c_ref(2, 1) = CMPLX(2.4231309591599897_dp, 4.665551869666368_dp, kind=dp)
   res_c_ref(2, 2) = CMPLX(1.5937647760286848_dp, 2.6021783330446246_dp, kind=dp)
   res_c_ref(2, 3) = CMPLX(2.9609793918714686_dp, 2.92153954960111_dp, kind=dp)
   res_c_ref(3, 1) = CMPLX(3.278689562249669_dp, 4.308656958132163_dp, kind=dp)
   res_c_ref(3, 2) = CMPLX(2.05357432643753_dp, 2.5060755291807237_dp, kind=dp)
   res_c_ref(3, 3) = CMPLX(3.646272530313196_dp, 2.5667051324585874_dp, kind=dp)

   CALL check_ref_c(res_c, res_c_ref, tolerance)

   ! Test X^T * Y: the two-factor real product with a transposed first factor,
   ! cross-checked against the MATMUL/TRANSPOSE intrinsics.

   CALL gemm_square(X_in, 'T', Y_in, 'N', res_r)

   res_r_ref = MATMUL(TRANSPOSE(X_in), Y_in)

   CALL check_ref_r(res_r, res_r_ref, tolerance)

   ! Test A * B^H: the two-factor complex product with a conjugate-transposed second factor,
   ! cross-checked against the MATMUL/CONJG/TRANSPOSE intrinsics.

   CALL gemm_square(A_in, 'N', B_in, 'C', res_c)

   res_c_ref = MATMUL(A_in, CONJG(TRANSPOSE(B_in)))

   CALL check_ref_c(res_c, res_c_ref, tolerance)

CONTAINS
! **************************************************************************************************
!> \brief Asserts that two real 3x3 matrices agree element-wise within an absolute tolerance.
!> \param mat computed matrix
!> \param ref reference matrix
!> \param tolerance maximum allowed absolute difference per element
! **************************************************************************************************
   SUBROUTINE check_ref_r(mat, ref, tolerance)
      REAL(kind=dp), DIMENSION(3, 3), INTENT(IN) :: mat, ref
      REAL(kind=dp), INTENT(IN) :: tolerance

      INTEGER :: i, j

      ! Column-major traversal: the first (row) index varies fastest.
      DO j = 1, 3
         DO i = 1, 3
            CPASSERT(ABS(mat(i, j) - ref(i, j)) <= tolerance)
         END DO
      END DO
   END SUBROUTINE check_ref_r
! **************************************************************************************************
!> \brief Asserts that two complex 3x3 matrices agree element-wise, i.e. that the modulus of the
!>        difference of each pair of elements does not exceed the tolerance.
!> \param mat computed matrix
!> \param ref reference matrix
!> \param tolerance maximum allowed modulus of the difference per element
! **************************************************************************************************
   SUBROUTINE check_ref_c(mat, ref, tolerance)
      COMPLEX(kind=dp), DIMENSION(3, 3), INTENT(IN) :: mat, ref
      REAL(kind=dp), INTENT(IN) :: tolerance

      INTEGER :: i, j

      ! Column-major traversal: the first (row) index varies fastest.
      DO j = 1, 3
         DO i = 1, 3
            CPASSERT(ABS(mat(i, j) - ref(i, j)) <= tolerance)
         END DO
      END DO
   END SUBROUTINE check_ref_c
END PROGRAM gemm_square_unittest
|