Line data Source code
1 : !--------------------------------------------------------------------------------------------------!
2 : ! CP2K: A general program to perform molecular dynamics simulations !
3 : ! Copyright 2000-2026 CP2K developers group <https://cp2k.org> !
4 : ! !
5 : ! SPDX-License-Identifier: GPL-2.0-or-later !
6 : !--------------------------------------------------------------------------------------------------!
7 :
! **************************************************************************************************
!> \brief Equivariant parametrization
!>
!>        Each atom's PAO block A (primary basis x PAO basis) is stored directly as the parameter
!>        vector in matrix_X, reshaped into a single column. From A the block
!>        B = S^{-1} A (A^T S^{-1} A)^{-1} is derived, and an optional penalty restrains every
!>        column of A to unit norm.
!> \author Ole Schuett
! **************************************************************************************************
MODULE pao_param_equi
   USE basis_set_types,                 ONLY: gto_basis_set_type
   USE cp_dbcsr_api,                    ONLY: &
        dbcsr_complete_redistribute, dbcsr_create, dbcsr_distribution_type, dbcsr_get_block_p, &
        dbcsr_get_info, dbcsr_iterator_blocks_left, dbcsr_iterator_next_block, &
        dbcsr_iterator_start, dbcsr_iterator_stop, dbcsr_iterator_type, dbcsr_p_type, &
        dbcsr_release, dbcsr_type
   USE cp_dbcsr_contrib,                ONLY: dbcsr_reserve_diag_blocks
   USE dm_ls_scf_types,                 ONLY: ls_mstruct_type,&
                                              ls_scf_env_type
   USE kinds,                           ONLY: dp
   USE mathlib,                         ONLY: diamat_all
   USE message_passing,                 ONLY: mp_comm_type
   USE pao_param_methods,               ONLY: pao_calc_grad_lnv_wrt_AB
   USE pao_potentials,                  ONLY: pao_guess_initial_potential
   USE pao_types,                       ONLY: pao_env_type
   USE qs_environment_types,            ONLY: get_qs_env,&
                                              qs_environment_type
   USE qs_kind_types,                   ONLY: get_qs_kind,&
                                              qs_kind_type
#include "./base/base_uses.f90"

   IMPLICIT NONE

   PRIVATE

   CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'pao_param_equi'

   PUBLIC :: pao_param_init_equi, pao_param_finalize_equi, pao_calc_AB_equi
   PUBLIC :: pao_param_count_equi, pao_param_initguess_equi

CONTAINS

! **************************************************************************************************
!> \brief Initialize equivariant parametrization
!>
!>        Only validates the settings: preconditioning is not supported by this
!>        parametrization and leads to an abort.
!> \param pao PAO environment whose precondition flag is checked
! **************************************************************************************************
   SUBROUTINE pao_param_init_equi(pao)
      TYPE(pao_env_type), POINTER                        :: pao

      IF (pao%precondition) &
         CPABORT("PAO preconditioning not supported for selected parametrization.")

   END SUBROUTINE pao_param_init_equi

! **************************************************************************************************
!> \brief Finalize equivariant parametrization
!>
!>        The parametrization holds no state of its own, so there is nothing to release.
! **************************************************************************************************
   SUBROUTINE pao_param_finalize_equi()

      ! Nothing to do.

   END SUBROUTINE pao_param_finalize_equi

! **************************************************************************************************
!> \brief Returns the number of parameters for given atomic kind
!>
!>        The parameters are the elements of the (primary basis x PAO basis) block A, so the count
!>        is pao_basis_size times the number of spherical functions of the primary basis set.
!> \param qs_env QS environment providing the kind set
!> \param ikind index of the atomic kind in qs_kind_set
!> \param nparams number of parameters for this kind
! **************************************************************************************************
   SUBROUTINE pao_param_count_equi(qs_env, ikind, nparams)
      TYPE(qs_environment_type), POINTER                 :: qs_env
      INTEGER, INTENT(IN)                                :: ikind
      INTEGER, INTENT(OUT)                               :: nparams

      INTEGER                                            :: pao_basis_size, pri_basis_size
      TYPE(gto_basis_set_type), POINTER                  :: basis_set
      TYPE(qs_kind_type), DIMENSION(:), POINTER          :: qs_kind_set

      CALL get_qs_env(qs_env, qs_kind_set=qs_kind_set)
      CALL get_qs_kind(qs_kind_set(ikind), &
                       basis_set=basis_set, &
                       pao_basis_size=pao_basis_size)
      pri_basis_size = basis_set%nsgf

      nparams = pao_basis_size*pri_basis_size

   END SUBROUTINE pao_param_count_equi

! **************************************************************************************************
!> \brief Fills matrix_X with an initial guess
!>
!>        For every diagonal block of matrix_X:
!>        - the guess potential V0 is added to the H0 block;
!>        - the sum is transformed with the matrix_N_diag block and diagonalized;
!>        - the first m eigenvectors (m = PAO basis size) are mapped back with the
!>          matrix_N_inv_diag block, normalized column-wise and stored as the parameters.
!> \param pao PAO environment; matrix_X is overwritten
!> \param qs_env QS environment passed to pao_guess_initial_potential
! **************************************************************************************************
   SUBROUTINE pao_param_initguess_equi(pao, qs_env)
      TYPE(pao_env_type), POINTER                        :: pao
      TYPE(qs_environment_type), POINTER                 :: qs_env

      CHARACTER(len=*), PARAMETER :: routineN = 'pao_param_initguess_equi'

      INTEGER                                            :: acol, arow, handle, i, iatom, m, n
      INTEGER, DIMENSION(:), POINTER                     :: blk_sizes_pao, blk_sizes_pri
      LOGICAL                                            :: found
      REAL(dp), DIMENSION(:), POINTER                    :: H_evals
      REAL(dp), DIMENSION(:, :), POINTER                 :: A, block_H0, block_N, block_N_inv, &
                                                            block_X, H_evecs, V0
      TYPE(dbcsr_iterator_type)                          :: iter

      CALL timeset(routineN, handle)

      CALL dbcsr_get_info(pao%matrix_Y, row_blk_size=blk_sizes_pri, col_blk_size=blk_sizes_pao)

!$OMP PARALLEL DEFAULT(NONE) SHARED(pao,qs_env,blk_sizes_pri,blk_sizes_pao) &
!$OMP PRIVATE(iter,arow,acol,iatom,n,m,i,found) &
!$OMP PRIVATE(block_X,block_H0,block_N,block_N_inv,A,H_evecs,H_evals,V0)
      CALL dbcsr_iterator_start(iter, pao%matrix_X)
      DO WHILE (dbcsr_iterator_blocks_left(iter))
         CALL dbcsr_iterator_next_block(iter, arow, acol, block_X)
         iatom = arow
         CPASSERT(arow == acol)

         CALL dbcsr_get_block_p(matrix=pao%matrix_H0, row=iatom, col=iatom, block=block_H0, found=found)
         CALL dbcsr_get_block_p(matrix=pao%matrix_N_diag, row=iatom, col=iatom, block=block_N, found=found)
         CALL dbcsr_get_block_p(matrix=pao%matrix_N_inv_diag, row=iatom, col=iatom, block=block_N_inv, found=found)
         CPASSERT(ASSOCIATED(block_H0) .AND. ASSOCIATED(block_N) .AND. ASSOCIATED(block_N_inv))

         n = blk_sizes_pri(iatom) ! size of primary basis
         m = blk_sizes_pao(iatom) ! size of pao basis

         ALLOCATE (V0(n, n))
         CALL pao_guess_initial_potential(qs_env, iatom, V0)

         ! Build H = N (H0 + V0) N, i.e. transform into the orthonormal basis, directly in the
         ! array that diamat_all overwrites with the eigenvectors.
         ALLOCATE (H_evecs(n, n), H_evals(n))
         H_evecs = MATMUL(MATMUL(block_N, block_H0 + V0), block_N)
         CALL diamat_all(H_evecs, H_evals)

         ! use first m eigenvectors as initial guess, mapped back with N^{-1}
         ALLOCATE (A(n, m))
         A = MATMUL(block_N_inv, H_evecs(:, 1:m))

         ! normalize vectors
         DO i = 1, m
            A(:, i) = A(:, i)/NORM2(A(:, i))
         END DO

         block_X = RESHAPE(A, [n*m, 1])
         DEALLOCATE (V0, A, H_evecs, H_evals)

      END DO
      CALL dbcsr_iterator_stop(iter)
!$OMP END PARALLEL

      CALL timestop(handle)

   END SUBROUTINE pao_param_initguess_equi

! **************************************************************************************************
!> \brief Takes current matrix_X and calculates the matrices A and B.
!>
!>        For every atom, A is unpacked from matrix_X and B = S^{-1} A (A^T S^{-1} A)^{-1} is
!>        built, with S^{-1} = N*N. Optionally the derivative w.r.t. the parameters and the
!>        unit-norm penalty are computed as well.
!>
!>        The work is done in the distribution of matrix_s: matrix_X is redistributed into it, and
!>        the gradient is redistributed back into pao%matrix_G at the end.
!> \param pao PAO environment; matrix_X is read, matrix_G is written when gradient is true
!> \param qs_env QS environment (provides matrix_s, passed to the gradient routine)
!> \param ls_scf_env LS-SCF environment whose ls_mstruct%matrix_A / matrix_B blocks are overwritten
!> \param gradient if true, also compute the derivative and store it in pao%matrix_G
!> \param penalty if present, receives the norm-restraint penalty summed over the communicator
!>        of pao%matrix_X
! **************************************************************************************************
   SUBROUTINE pao_calc_AB_equi(pao, qs_env, ls_scf_env, gradient, penalty)
      TYPE(pao_env_type), POINTER                        :: pao
      TYPE(qs_environment_type), POINTER                 :: qs_env
      TYPE(ls_scf_env_type), TARGET                      :: ls_scf_env
      LOGICAL, INTENT(IN)                                :: gradient
      REAL(dp), INTENT(INOUT), OPTIONAL                  :: penalty

      CHARACTER(len=*), PARAMETER                        :: routineN = 'pao_calc_AB_equi'

      INTEGER                                            :: acol, arow, handle, i, iatom, j, m, n
      LOGICAL                                            :: found
      REAL(dp)                                           :: penalty_sum, w
      REAL(dp), DIMENSION(:), POINTER                    :: ANNA_evals
      REAL(dp), DIMENSION(:, :), POINTER                 :: ANNA, ANNA_evecs, ANNA_inv, block_A, &
                                                            block_B, block_G, block_Ma, block_Mb, &
                                                            block_N, block_X, D, G, M1, M2, M3, &
                                                            M4, M5, NN
      TYPE(dbcsr_distribution_type)                      :: main_dist
      TYPE(dbcsr_iterator_type)                          :: iter
      TYPE(dbcsr_p_type), DIMENSION(:), POINTER          :: matrix_s
      TYPE(dbcsr_type)                                   :: matrix_G_nondiag, matrix_Ma, matrix_Mb, &
                                                            matrix_X_nondiag
      TYPE(ls_mstruct_type), POINTER                     :: ls_mstruct
      TYPE(mp_comm_type)                                 :: group

      CALL timeset(routineN, handle)
      ls_mstruct => ls_scf_env%ls_mstruct

      IF (gradient) THEN
         CALL pao_calc_grad_lnv_wrt_AB(qs_env, ls_scf_env, matrix_Ma, matrix_Mb)
      END IF

      ! Redistribute matrix_X from diag_distribution to distribution of matrix_s.
      CALL get_qs_env(qs_env, matrix_s=matrix_s)
      CALL dbcsr_get_info(matrix=matrix_s(1)%matrix, distribution=main_dist)
      CALL dbcsr_create(matrix_X_nondiag, &
                        name="PAO matrix_X_nondiag", &
                        dist=main_dist, &
                        template=pao%matrix_X)
      CALL dbcsr_reserve_diag_blocks(matrix_X_nondiag)
      CALL dbcsr_complete_redistribute(pao%matrix_X, matrix_X_nondiag)

      ! Computation of matrix_G uses distr. of matrix_s, afterwards we redistribute to diag_distribution.
      IF (gradient) THEN
         CALL dbcsr_create(matrix_G_nondiag, &
                           name="PAO matrix_G_nondiag", &
                           dist=main_dist, &
                           template=pao%matrix_G)
         CALL dbcsr_reserve_diag_blocks(matrix_G_nondiag)
      END IF

      penalty_sum = 0.0_dp

!$OMP PARALLEL DEFAULT(NONE) &
!$OMP SHARED(pao,ls_mstruct,matrix_X_nondiag,matrix_G_nondiag,matrix_Ma,matrix_Mb,gradient,penalty) &
!$OMP PRIVATE(iter,arow,acol,iatom,found,n,m,w,i,j) &
!$OMP PRIVATE(NN,ANNA,ANNA_evals,ANNA_evecs,ANNA_inv,D,G,M1,M2,M3,M4,M5) &
!$OMP PRIVATE(block_X,block_A,block_B,block_N,block_Ma,block_Mb,block_G) &
!$OMP REDUCTION(+:penalty_sum)
      CALL dbcsr_iterator_start(iter, matrix_X_nondiag)
      DO WHILE (dbcsr_iterator_blocks_left(iter))
         CALL dbcsr_iterator_next_block(iter, arow, acol, block_X)
         iatom = arow
         CPASSERT(arow == acol)
         CALL dbcsr_get_block_p(matrix=ls_mstruct%matrix_A, row=iatom, col=iatom, block=block_A, found=found)
         CPASSERT(ASSOCIATED(block_A))
         CALL dbcsr_get_block_p(matrix=ls_mstruct%matrix_B, row=iatom, col=iatom, block=block_B, found=found)
         CPASSERT(ASSOCIATED(block_B))
         CALL dbcsr_get_block_p(matrix=pao%matrix_N, row=iatom, col=iatom, block=block_N, found=found)
         CPASSERT(ASSOCIATED(block_N))

         n = SIZE(block_A, 1) ! size of primary basis
         m = SIZE(block_A, 2) ! size of pao basis
         block_A = RESHAPE(block_X, [n, m])

         ! restrain pao basis vectors to unit norm
         IF (PRESENT(penalty)) THEN
            DO i = 1, m
               w = 1.0_dp - SUM(block_A(:, i)**2)
               penalty_sum = penalty_sum + pao%penalty_strength*w**2
            END DO
         END IF

         ALLOCATE (NN(n, n), ANNA(m, m))
         NN = MATMUL(block_N, block_N) ! it's actually S^{-1}
         ANNA = MATMUL(MATMUL(TRANSPOSE(block_A), NN), block_A)

         ! diagonalize ANNA
         ALLOCATE (ANNA_evecs(m, m), ANNA_evals(m))
         ANNA_evecs(:, :) = ANNA
         CALL diamat_all(ANNA_evecs, ANNA_evals)
         IF (MINVAL(ABS(ANNA_evals)) < 1e-10_dp) CPABORT("PAO basis singular.")

         ! ANNA_inv = ANNA_evecs * diag(1/ANNA_evals) * ANNA_evecs^T
         ! (SPREAD along dim 1 scales column k of ANNA_evecs by 1/ANNA_evals(k))
         ALLOCATE (ANNA_inv(m, m))
         ANNA_inv = MATMUL(ANNA_evecs*SPREAD(1.0_dp/ANNA_evals, 1, m), TRANSPOSE(ANNA_evecs))

         !B = 1/S * A * 1/(A^T 1/S A)
         block_B = MATMUL(MATMUL(NN, block_A), ANNA_inv)

         ! TURNING POINT (if calc grad) ------------------------------------------
         IF (gradient) THEN
            CALL dbcsr_get_block_p(matrix=matrix_G_nondiag, row=iatom, col=iatom, block=block_G, found=found)
            CPASSERT(ASSOCIATED(block_G))
            CALL dbcsr_get_block_p(matrix=matrix_Ma, row=iatom, col=iatom, block=block_Ma, found=found)
            CALL dbcsr_get_block_p(matrix=matrix_Mb, row=iatom, col=iatom, block=block_Mb, found=found)
            ! don't check ASSOCIATED(block_M), it might have been filtered out.

            ALLOCATE (G(n, m))
            G(:, :) = 0.0_dp

            ! derivative of the penalty: d/dA [s*(1 - |a_i|^2)^2] = -4*s*(1 - |a_i|^2)*a_i
            IF (PRESENT(penalty)) THEN
               DO i = 1, m
                  w = 1.0_dp - SUM(block_A(:, i)**2)
                  G(:, i) = -4.0_dp*pao%penalty_strength*w*block_A(:, i)
               END DO
            END IF

            IF (ASSOCIATED(block_Ma)) THEN
               G = G + block_Ma
            END IF

            IF (ASSOCIATED(block_Mb)) THEN
               G = G + MATMUL(MATMUL(NN, block_Mb), ANNA_inv)

               ! calculate derivatives dANNA_inv/dANNA in the eigenbasis of ANNA
               ALLOCATE (D(m, m), M1(m, m), M2(m, m), M3(m, m), M4(m, m), M5(m, m))

               ! D holds the divided differences of f(x) = 1/x at the eigenvalues:
               !   (1/a - 1/b)/(a - b) = -1/(a*b)   exactly, for a /= b,
               ! and the same expression is the correct limit -1/a**2 for a == b.
               ! Using this closed form for all pairs avoids the cancellation of the explicit
               ! quotient for (near-)degenerate eigenvalues.
               DO j = 1, m
                  DO i = 1, m
                     D(i, j) = -1.0_dp/(ANNA_evals(i)*ANNA_evals(j))
                  END DO
               END DO

               M1 = MATMUL(MATMUL(TRANSPOSE(block_A), NN), block_Mb)
               M2 = MATMUL(MATMUL(TRANSPOSE(ANNA_evecs), M1), ANNA_evecs)
               M3 = M2*D ! Hadamard product
               M4 = MATMUL(MATMUL(ANNA_evecs, M3), TRANSPOSE(ANNA_evecs))
               M5 = 0.5_dp*(M4 + TRANSPOSE(M4))
               G = G + 2.0_dp*MATMUL(MATMUL(NN, block_A), M5)

               DEALLOCATE (D, M1, M2, M3, M4, M5)
            END IF

            block_G = RESHAPE(G, [n*m, 1])
            DEALLOCATE (G)
         END IF

         DEALLOCATE (NN, ANNA, ANNA_evecs, ANNA_evals, ANNA_inv)
      END DO
      CALL dbcsr_iterator_stop(iter)
!$OMP END PARALLEL

      ! sum penalty energies across ranks
      IF (PRESENT(penalty)) THEN
         CALL dbcsr_get_info(pao%matrix_X, group=group)
         CALL group%sum(penalty_sum)
         penalty = penalty_sum
      END IF

      CALL dbcsr_release(matrix_X_nondiag)

      IF (gradient) THEN
         CALL dbcsr_complete_redistribute(matrix_G_nondiag, pao%matrix_G)
         CALL dbcsr_release(matrix_G_nondiag)
         CALL dbcsr_release(matrix_Ma)
         CALL dbcsr_release(matrix_Mb)
      END IF

      CALL timestop(handle)

   END SUBROUTINE pao_calc_AB_equi

END MODULE pao_param_equi
|