LCOV - code coverage report
Current view: top level - src - pao_param_equi.F (source / functions) Coverage Total Hit
Test: CP2K Regtests (git:42dac4a) Lines: 98.1 % 54 53
Test Date: 2025-07-25 12:55:17 Functions: 100.0 % 5 5

            Line data    Source code
       1              : !--------------------------------------------------------------------------------------------------!
       2              : !   CP2K: A general program to perform molecular dynamics simulations                              !
       3              : !   Copyright 2000-2025 CP2K developers group <https://cp2k.org>                                   !
       4              : !                                                                                                  !
       5              : !   SPDX-License-Identifier: GPL-2.0-or-later                                                      !
       6              : !--------------------------------------------------------------------------------------------------!
       7              : 
       8              : ! **************************************************************************************************
       9              : !> \brief Equivariant parametrization
      10              : !> \author Ole Schuett
      11              : ! **************************************************************************************************
      12              : MODULE pao_param_equi
      13              :    USE basis_set_types,                 ONLY: gto_basis_set_type
      14              :    USE cp_dbcsr_api,                    ONLY: &
      15              :         dbcsr_complete_redistribute, dbcsr_create, dbcsr_distribution_type, dbcsr_get_block_p, &
      16              :         dbcsr_get_info, dbcsr_iterator_blocks_left, dbcsr_iterator_next_block, &
      17              :         dbcsr_iterator_start, dbcsr_iterator_stop, dbcsr_iterator_type, dbcsr_p_type, &
      18              :         dbcsr_release, dbcsr_type
      19              :    USE cp_dbcsr_contrib,                ONLY: dbcsr_reserve_diag_blocks
      20              :    USE dm_ls_scf_types,                 ONLY: ls_mstruct_type,&
      21              :                                               ls_scf_env_type
      22              :    USE kinds,                           ONLY: dp
      23              :    USE mathlib,                         ONLY: diamat_all
      24              :    USE message_passing,                 ONLY: mp_comm_type
      25              :    USE pao_param_methods,               ONLY: pao_calc_grad_lnv_wrt_AB
      26              :    USE pao_potentials,                  ONLY: pao_guess_initial_potential
      27              :    USE pao_types,                       ONLY: pao_env_type
      28              :    USE qs_environment_types,            ONLY: get_qs_env,&
      29              :                                               qs_environment_type
      30              :    USE qs_kind_types,                   ONLY: get_qs_kind,&
      31              :                                               qs_kind_type
      32              : #include "./base/base_uses.f90"
      33              : 
      34              :    IMPLICIT NONE
      35              : 
      36              :    PRIVATE
      37              : 
      38              :    CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'pao_param_equi'
      39              : 
      40              :    PUBLIC :: pao_param_init_equi, pao_param_finalize_equi, pao_calc_AB_equi
      41              :    PUBLIC :: pao_param_count_equi, pao_param_initguess_equi
      42              : 
      43              : CONTAINS
      44              : 
      45              : ! **************************************************************************************************
      46              : !> \brief Initialize equivariant parametrization
      47              : !> \param pao ...
      48              : ! **************************************************************************************************
      49           26 :    SUBROUTINE pao_param_init_equi(pao)
      50              :       TYPE(pao_env_type), POINTER                        :: pao
      51              : 
      52           26 :       IF (pao%precondition) &
      53            0 :          CPABORT("PAO preconditioning not supported for selected parametrization.")
      54              : 
      55           26 :    END SUBROUTINE pao_param_init_equi
      56              : 
      57              : ! **************************************************************************************************
      58              : !> \brief Finalize equivariant parametrization
      59              : ! **************************************************************************************************
      60           26 :    SUBROUTINE pao_param_finalize_equi()
      61              : 
      62              :       ! Nothing to do.
      63              : 
      64           26 :    END SUBROUTINE pao_param_finalize_equi
      65              : 
      66              : ! **************************************************************************************************
      67              : !> \brief Returns the number of parameters for given atomic kind
      68              : !> \param qs_env ...
      69              : !> \param ikind ...
      70              : !> \param nparams ...
      71              : ! **************************************************************************************************
      72          112 :    SUBROUTINE pao_param_count_equi(qs_env, ikind, nparams)
      73              :       TYPE(qs_environment_type), POINTER                 :: qs_env
      74              :       INTEGER, INTENT(IN)                                :: ikind
      75              :       INTEGER, INTENT(OUT)                               :: nparams
      76              : 
      77              :       INTEGER                                            :: pao_basis_size, pri_basis_size
      78              :       TYPE(gto_basis_set_type), POINTER                  :: basis_set
      79           56 :       TYPE(qs_kind_type), DIMENSION(:), POINTER          :: qs_kind_set
      80              : 
      81           56 :       CALL get_qs_env(qs_env, qs_kind_set=qs_kind_set)
      82              :       CALL get_qs_kind(qs_kind_set(ikind), &
      83              :                        basis_set=basis_set, &
      84           56 :                        pao_basis_size=pao_basis_size)
      85           56 :       pri_basis_size = basis_set%nsgf
      86              : 
      87           56 :       nparams = pao_basis_size*pri_basis_size
      88              : 
      89           56 :    END SUBROUTINE pao_param_count_equi
      90              : 
      91              : ! **************************************************************************************************
      92              : !> \brief Fills matrix_X with an initial guess
      93              : !> \param pao ...
      94              : !> \param qs_env ...
      95              : ! **************************************************************************************************
      96           10 :    SUBROUTINE pao_param_initguess_equi(pao, qs_env)
      97              :       TYPE(pao_env_type), POINTER                        :: pao
      98              :       TYPE(qs_environment_type), POINTER                 :: qs_env
      99              : 
     100              :       CHARACTER(len=*), PARAMETER :: routineN = 'pao_param_initguess_equi'
     101              : 
     102              :       INTEGER                                            :: acol, arow, handle, i, iatom, m, n
     103           10 :       INTEGER, DIMENSION(:), POINTER                     :: blk_sizes_pao, blk_sizes_pri
     104              :       LOGICAL                                            :: found
     105           10 :       REAL(dp), DIMENSION(:), POINTER                    :: H_evals
     106           10 :       REAL(dp), DIMENSION(:, :), POINTER                 :: A, block_H0, block_N, block_N_inv, &
     107           10 :                                                             block_X, H, H_evecs, V0
     108              :       TYPE(dbcsr_iterator_type)                          :: iter
     109              : 
     110           10 :       CALL timeset(routineN, handle)
     111              : 
     112           10 :       CALL dbcsr_get_info(pao%matrix_Y, row_blk_size=blk_sizes_pri, col_blk_size=blk_sizes_pao)
     113              : 
     114              : !$OMP PARALLEL DEFAULT(NONE) SHARED(pao,qs_env,blk_sizes_pri,blk_sizes_pao) &
     115              : !$OMP PRIVATE(iter,arow,acol,iatom,n,m,i,found) &
     116           10 : !$OMP PRIVATE(block_X,block_H0,block_N,block_N_inv,A,H,H_evecs,H_evals,V0)
     117              :       CALL dbcsr_iterator_start(iter, pao%matrix_X)
     118              :       DO WHILE (dbcsr_iterator_blocks_left(iter))
     119              :          CALL dbcsr_iterator_next_block(iter, arow, acol, block_X)
     120              :          iatom = arow; CPASSERT(arow == acol)
     121              : 
     122              :          CALL dbcsr_get_block_p(matrix=pao%matrix_H0, row=iatom, col=iatom, block=block_H0, found=found)
     123              :          CALL dbcsr_get_block_p(matrix=pao%matrix_N_diag, row=iatom, col=iatom, block=block_N, found=found)
     124              :          CALL dbcsr_get_block_p(matrix=pao%matrix_N_inv_diag, row=iatom, col=iatom, block=block_N_inv, found=found)
     125              :          CPASSERT(ASSOCIATED(block_H0) .AND. ASSOCIATED(block_N) .AND. ASSOCIATED(block_N_inv))
     126              : 
     127              :          n = blk_sizes_pri(iatom) ! size of primary basis
     128              :          m = blk_sizes_pao(iatom) ! size of pao basis
     129              : 
     130              :          ALLOCATE (V0(n, n))
     131              :          CALL pao_guess_initial_potential(qs_env, iatom, V0)
     132              : 
     133              :          ! construct H
     134              :          ALLOCATE (H(n, n))
     135              :          H = MATMUL(MATMUL(block_N, block_H0 + V0), block_N) ! transform into orthonormal basis
     136              : 
     137              :          ! diagonalize H
     138              :          ALLOCATE (H_evecs(n, n), H_evals(n))
     139              :          H_evecs = H
     140              :          CALL diamat_all(H_evecs, H_evals)
     141              : 
     142              :          ! use first m eigenvectors as initial guess
     143              :          ALLOCATE (A(n, m))
     144              :          A = MATMUL(block_N_inv, H_evecs(:, 1:m))
     145              : 
     146              :          ! normalize vectors
     147              :          DO i = 1, m
     148              :             A(:, i) = A(:, i)/NORM2(A(:, i))
     149              :          END DO
     150              : 
     151              :          block_X = RESHAPE(A, (/n*m, 1/))
     152              :          DEALLOCATE (H, V0, A, H_evecs, H_evals)
     153              : 
     154              :       END DO
     155              :       CALL dbcsr_iterator_stop(iter)
     156              : !$OMP END PARALLEL
     157              : 
     158           10 :       CALL timestop(handle)
     159              : 
     160           10 :    END SUBROUTINE pao_param_initguess_equi
     161              : 
     162              : ! **************************************************************************************************
     163              : !> \brief Takes current matrix_X and calculates the matrices A and B.
     164              : !> \param pao ...
     165              : !> \param qs_env ...
     166              : !> \param ls_scf_env ...
     167              : !> \param gradient ...
     168              : !> \param penalty ...
     169              : ! **************************************************************************************************
     170         3412 :    SUBROUTINE pao_calc_AB_equi(pao, qs_env, ls_scf_env, gradient, penalty)
     171              :       TYPE(pao_env_type), POINTER                        :: pao
     172              :       TYPE(qs_environment_type), POINTER                 :: qs_env
     173              :       TYPE(ls_scf_env_type), TARGET                      :: ls_scf_env
     174              :       LOGICAL, INTENT(IN)                                :: gradient
     175              :       REAL(dp), INTENT(INOUT), OPTIONAL                  :: penalty
     176              : 
     177              :       CHARACTER(len=*), PARAMETER                        :: routineN = 'pao_calc_AB_equi'
     178              : 
     179              :       INTEGER                                            :: acol, arow, handle, i, iatom, j, k, m, n
     180              :       LOGICAL                                            :: found
     181              :       REAL(dp)                                           :: denom, w
     182         1706 :       REAL(dp), DIMENSION(:), POINTER                    :: ANNA_evals
     183         1706 :       REAL(dp), DIMENSION(:, :), POINTER                 :: ANNA, ANNA_evecs, ANNA_inv, block_A, &
     184         1706 :                                                             block_B, block_G, block_Ma, block_Mb, &
     185         1706 :                                                             block_N, block_X, D, G, M1, M2, M3, &
     186         1706 :                                                             M4, M5, NN
     187              :       TYPE(dbcsr_distribution_type)                      :: main_dist
     188              :       TYPE(dbcsr_iterator_type)                          :: iter
     189         1706 :       TYPE(dbcsr_p_type), DIMENSION(:), POINTER          :: matrix_s
     190              :       TYPE(dbcsr_type)                                   :: matrix_G_nondiag, matrix_Ma, matrix_Mb, &
     191              :                                                             matrix_X_nondiag
     192              :       TYPE(ls_mstruct_type), POINTER                     :: ls_mstruct
     193              :       TYPE(mp_comm_type)                                 :: group
     194              : 
     195         1706 :       CALL timeset(routineN, handle)
     196         1706 :       ls_mstruct => ls_scf_env%ls_mstruct
     197              : 
     198         1706 :       IF (gradient) THEN
     199          234 :          CALL pao_calc_grad_lnv_wrt_AB(qs_env, ls_scf_env, matrix_Ma, matrix_Mb)
     200              :       END IF
     201              : 
     202              :       ! Redistribute matrix_X from diag_distribution to distribution of matrix_s.
     203         1706 :       CALL get_qs_env(qs_env, matrix_s=matrix_s)
     204         1706 :       CALL dbcsr_get_info(matrix=matrix_s(1)%matrix, distribution=main_dist)
     205              :       CALL dbcsr_create(matrix_X_nondiag, &
     206              :                         name="PAO matrix_X_nondiag", &
     207              :                         dist=main_dist, &
     208         1706 :                         template=pao%matrix_X)
     209         1706 :       CALL dbcsr_reserve_diag_blocks(matrix_X_nondiag)
     210         1706 :       CALL dbcsr_complete_redistribute(pao%matrix_X, matrix_X_nondiag)
     211              : 
     212              :       ! Compuation of matrix_G uses distr. of matrix_s, afterwards we redistribute to diag_distribution.
     213         1706 :       IF (gradient) THEN
     214              :          CALL dbcsr_create(matrix_G_nondiag, &
     215              :                            name="PAO matrix_G_nondiag", &
     216              :                            dist=main_dist, &
     217          234 :                            template=pao%matrix_G)
     218          234 :          CALL dbcsr_reserve_diag_blocks(matrix_G_nondiag)
     219              :       END IF
     220              : 
     221              : !$OMP PARALLEL DEFAULT(NONE) &
     222              : !$OMP SHARED(pao,ls_mstruct,matrix_X_nondiag,matrix_G_nondiag,matrix_Ma,matrix_Mb,gradient,penalty) &
     223              : !$OMP PRIVATE(iter,arow,acol,iatom,found,n,m,w,i,j,k,denom) &
     224              : !$OMP PRIVATE(NN,ANNA,ANNA_evals,ANNA_evecs,ANNA_inv,D,G,M1,M2,M3,M4,M5) &
     225         1706 : !$OMP PRIVATE(block_X,block_A,block_B,block_N,block_Ma, block_Mb, block_G)
     226              :       CALL dbcsr_iterator_start(iter, matrix_X_nondiag)
     227              :       DO WHILE (dbcsr_iterator_blocks_left(iter))
     228              :          CALL dbcsr_iterator_next_block(iter, arow, acol, block_X)
     229              :          iatom = arow; CPASSERT(arow == acol)
     230              :          CALL dbcsr_get_block_p(matrix=ls_mstruct%matrix_A, row=iatom, col=iatom, block=block_A, found=found)
     231              :          CPASSERT(ASSOCIATED(block_A))
     232              :          CALL dbcsr_get_block_p(matrix=ls_mstruct%matrix_B, row=iatom, col=iatom, block=block_B, found=found)
     233              :          CPASSERT(ASSOCIATED(block_B))
     234              :          CALL dbcsr_get_block_p(matrix=pao%matrix_N, row=iatom, col=iatom, block=block_N, found=found)
     235              :          CPASSERT(ASSOCIATED(block_N))
     236              : 
     237              :          n = SIZE(block_A, 1) ! size of primary basis
     238              :          m = SIZE(block_A, 2) ! size of pao basis
     239              :          block_A = RESHAPE(block_X, (/n, m/))
     240              : 
     241              :          ! restrain pao basis vectors to unit norm
     242              :          IF (PRESENT(penalty)) THEN
     243              :             DO i = 1, m
     244              :                w = 1.0_dp - SUM(block_A(:, i)**2)
     245              :                penalty = penalty + pao%penalty_strength*w**2
     246              :             END DO
     247              :          END IF
     248              : 
     249              :          ALLOCATE (NN(n, n), ANNA(m, m))
     250              :          NN = MATMUL(block_N, block_N) ! it's actually S^{-1}
     251              :          ANNA = MATMUL(MATMUL(TRANSPOSE(block_A), NN), block_A)
     252              : 
     253              :          ! diagonalize ANNA
     254              :          ALLOCATE (ANNA_evecs(m, m), ANNA_evals(m))
     255              :          ANNA_evecs(:, :) = ANNA
     256              :          CALL diamat_all(ANNA_evecs, ANNA_evals)
     257              :          IF (MINVAL(ABS(ANNA_evals)) < 1e-10_dp) CPABORT("PAO basis singualar.")
     258              : 
     259              :          ! build ANNA_inv
     260              :          ALLOCATE (ANNA_inv(m, m))
     261              :          ANNA_inv(:, :) = 0.0_dp
     262              :          DO k = 1, m
     263              :             w = 1.0_dp/ANNA_evals(k)
     264              :             DO i = 1, m
     265              :             DO j = 1, m
     266              :                ANNA_inv(i, j) = ANNA_inv(i, j) + w*ANNA_evecs(i, k)*ANNA_evecs(j, k)
     267              :             END DO
     268              :             END DO
     269              :          END DO
     270              : 
     271              :          !B = 1/S * A * 1/(A^T 1/S A)
     272              :          block_B = MATMUL(MATMUL(NN, block_A), ANNA_inv)
     273              : 
     274              :          ! TURNING POINT (if calc grad) ------------------------------------------
     275              :          IF (gradient) THEN
     276              :             CALL dbcsr_get_block_p(matrix=matrix_G_nondiag, row=iatom, col=iatom, block=block_G, found=found)
     277              :             CPASSERT(ASSOCIATED(block_G))
     278              :             CALL dbcsr_get_block_p(matrix=matrix_Ma, row=iatom, col=iatom, block=block_Ma, found=found)
     279              :             CALL dbcsr_get_block_p(matrix=matrix_Mb, row=iatom, col=iatom, block=block_Mb, found=found)
     280              :             ! don't check ASSOCIATED(block_M), it might have been filtered out.
     281              : 
     282              :             ALLOCATE (G(n, m))
     283              :             G(:, :) = 0.0_dp
     284              : 
     285              :             IF (PRESENT(penalty)) THEN
     286              :                DO i = 1, m
     287              :                   w = 1.0_dp - SUM(block_A(:, i)**2)
     288              :                   G(:, i) = -4.0_dp*pao%penalty_strength*w*block_A(:, i)
     289              :                END DO
     290              :             END IF
     291              : 
     292              :             IF (ASSOCIATED(block_Ma)) THEN
     293              :                G = G + block_Ma
     294              :             END IF
     295              : 
     296              :             IF (ASSOCIATED(block_Mb)) THEN
     297              :                G = G + MATMUL(MATMUL(NN, block_Mb), ANNA_inv)
     298              : 
     299              :                ! calculate derivatives dAA_inv/ dAA
     300              :                ALLOCATE (D(m, m), M1(m, m), M2(m, m), M3(m, m), M4(m, m), M5(m, m))
     301              : 
     302              :                DO i = 1, m
     303              :                DO j = 1, m
     304              :                   denom = ANNA_evals(i) - ANNA_evals(j)
     305              :                   IF (i == j) THEN
     306              :                      D(i, i) = -1.0_dp/ANNA_evals(i)**2 ! diagonal elements
     307              :                   ELSE IF (ABS(denom) > 1e-10_dp) THEN
     308              :                      D(i, j) = (1.0_dp/ANNA_evals(i) - 1.0_dp/ANNA_evals(j))/denom
     309              :                   ELSE
     310              :                      D(i, j) = -1.0_dp ! limit according to L'Hospital's rule
     311              :                   END IF
     312              :                END DO
     313              :                END DO
     314              : 
     315              :                M1 = MATMUL(MATMUL(TRANSPOSE(block_A), NN), block_Mb)
     316              :                M2 = MATMUL(MATMUL(TRANSPOSE(ANNA_evecs), M1), ANNA_evecs)
     317              :                M3 = M2*D ! Hadamard product
     318              :                M4 = MATMUL(MATMUL(ANNA_evecs, M3), TRANSPOSE(ANNA_evecs))
     319              :                M5 = 0.5_dp*(M4 + TRANSPOSE(M4))
     320              :                G = G + 2.0_dp*MATMUL(MATMUL(NN, block_A), M5)
     321              : 
     322              :                DEALLOCATE (D, M1, M2, M3, M4, M5)
     323              :             END IF
     324              : 
     325              :             block_G = RESHAPE(G, (/n*m, 1/))
     326              :             DEALLOCATE (G)
     327              :          END IF
     328              : 
     329              :          DEALLOCATE (NN, ANNA, ANNA_evecs, ANNA_evals, ANNA_inv)
     330              :       END DO
     331              :       CALL dbcsr_iterator_stop(iter)
     332              : !$OMP END PARALLEL
     333              : 
     334              :       ! sum penalty energies across ranks
     335         1706 :       IF (PRESENT(penalty)) THEN
     336         1678 :          CALL dbcsr_get_info(pao%matrix_X, group=group)
     337         1678 :          CALL group%sum(penalty)
     338              :       END IF
     339              : 
     340         1706 :       CALL dbcsr_release(matrix_X_nondiag)
     341              : 
     342         1706 :       IF (gradient) THEN
     343          234 :          CALL dbcsr_complete_redistribute(matrix_G_nondiag, pao%matrix_G)
     344          234 :          CALL dbcsr_release(matrix_G_nondiag)
     345          234 :          CALL dbcsr_release(matrix_Ma)
     346          234 :          CALL dbcsr_release(matrix_Mb)
     347              :       END IF
     348              : 
     349         1706 :       CALL timestop(handle)
     350              : 
     351         1706 :    END SUBROUTINE pao_calc_AB_equi
     352              : 
     353              : END MODULE pao_param_equi
        

Generated by: LCOV version 2.0-1