LCOV - code coverage report
Current view: top level - src - pao_main.F (source / functions) Coverage Total Hit
Test: CP2K Regtests (git:3db43b4) Lines: 98.0 % 147 144
Test Date: 2026-04-03 06:55:30 Functions: 100.0 % 5 5

            Line data    Source code
       1              : !--------------------------------------------------------------------------------------------------!
       2              : !   CP2K: A general program to perform molecular dynamics simulations                              !
       3              : !   Copyright 2000-2026 CP2K developers group <https://cp2k.org>                                   !
       4              : !                                                                                                  !
       5              : !   SPDX-License-Identifier: GPL-2.0-or-later                                                      !
       6              : !--------------------------------------------------------------------------------------------------!
       7              : 
       8              : ! **************************************************************************************************
       9              : !> \brief Main module for the PAO method
      10              : !> \author Ole Schuett
      11              : ! **************************************************************************************************
      12              : MODULE pao_main
      13              :    USE bibliography,                    ONLY: Schuett2018,&
      14              :                                               cite_reference
      15              :    USE cp_dbcsr_api,                    ONLY: dbcsr_add,&
      16              :                                               dbcsr_copy,&
      17              :                                               dbcsr_create,&
      18              :                                               dbcsr_p_type,&
      19              :                                               dbcsr_release,&
      20              :                                               dbcsr_set,&
      21              :                                               dbcsr_type
      22              :    USE cp_dbcsr_contrib,                ONLY: dbcsr_reserve_diag_blocks
      23              :    USE cp_external_control,             ONLY: external_control
      24              :    USE dm_ls_scf_types,                 ONLY: ls_mstruct_type,&
      25              :                                               ls_scf_env_type
      26              :    USE input_section_types,             ONLY: section_vals_get_subs_vals,&
      27              :                                               section_vals_type
      28              :    USE kinds,                           ONLY: dp
      29              :    USE linesearch,                      ONLY: linesearch_finalize,&
      30              :                                               linesearch_init,&
      31              :                                               linesearch_reset,&
      32              :                                               linesearch_step
      33              :    USE machine,                         ONLY: m_walltime
      34              :    USE pao_input,                       ONLY: parse_pao_section
      35              :    USE pao_io,                          ONLY: pao_read_restart,&
      36              :                                               pao_write_hcore_matrix_csr,&
      37              :                                               pao_write_ks_matrix_csr,&
      38              :                                               pao_write_p_matrix_csr,&
      39              :                                               pao_write_restart,&
      40              :                                               pao_write_s_matrix_csr
      41              :    USE pao_methods,                     ONLY: &
      42              :         pao_add_forces, pao_build_core_hamiltonian, pao_build_diag_distribution, &
      43              :         pao_build_matrix_X, pao_build_orthogonalizer, pao_build_selector, pao_calc_energy, &
      44              :         pao_check_grad, pao_check_trace_ps, pao_guess_initial_P, pao_init_kinds, &
      45              :         pao_print_atom_info, pao_store_P, pao_test_convergence
      46              :    USE pao_ml,                          ONLY: pao_ml_init,&
      47              :                                               pao_ml_predict
      48              :    USE pao_model,                       ONLY: pao_model_predict
      49              :    USE pao_optimizer,                   ONLY: pao_opt_finalize,&
      50              :                                               pao_opt_init,&
      51              :                                               pao_opt_new_dir
      52              :    USE pao_param,                       ONLY: pao_calc_AB,&
      53              :                                               pao_param_finalize,&
      54              :                                               pao_param_init,&
      55              :                                               pao_param_initial_guess
      56              :    USE pao_types,                       ONLY: pao_env_type
      57              :    USE qs_environment_types,            ONLY: get_qs_env,&
      58              :                                               qs_environment_type
      59              :    USE virial_types,                    ONLY: virial_type
      60              : #include "./base/base_uses.f90"
      61              : 
      62              :    IMPLICIT NONE
      63              : 
      64              :    PRIVATE
      65              : 
      66              :    CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'pao_main'
      67              : 
      68              :    PUBLIC :: pao_init, pao_update, pao_post_scf, pao_optimization_start, pao_optimization_end
      69              : 
      70              : CONTAINS
      71              : 
      72              : ! **************************************************************************************************
      73              : !> \brief Initialize the PAO environment
      74              : !> \param qs_env ...
      75              : !> \param ls_scf_env ...
      76              : ! **************************************************************************************************
      77          438 :    SUBROUTINE pao_init(qs_env, ls_scf_env)
      78              :       TYPE(qs_environment_type), POINTER                 :: qs_env
      79              :       TYPE(ls_scf_env_type), TARGET                      :: ls_scf_env
      80              : 
      81              :       CHARACTER(len=*), PARAMETER                        :: routineN = 'pao_init'
      82              : 
      83              :       INTEGER                                            :: handle
      84          340 :       TYPE(dbcsr_p_type), DIMENSION(:), POINTER          :: matrix_s
      85              :       TYPE(pao_env_type), POINTER                        :: pao
      86              :       TYPE(section_vals_type), POINTER                   :: input
      87              : 
      88          242 :       IF (.NOT. ls_scf_env%do_pao) RETURN
      89              : 
      90           98 :       CALL timeset(routineN, handle)
      91           98 :       CALL cite_reference(Schuett2018)
      92           98 :       pao => ls_scf_env%pao_env
      93           98 :       CALL get_qs_env(qs_env=qs_env, input=input, matrix_s=matrix_s)
      94              : 
      95              :       ! parse input
      96           98 :       CALL parse_pao_section(pao, input)
      97              : 
      98           98 :       CALL pao_init_kinds(pao, qs_env)
      99              : 
     100              :       ! train machine learning
     101           98 :       CALL pao_ml_init(pao, qs_env)
     102              : 
     103           98 :       CALL timestop(handle)
     104          340 :    END SUBROUTINE pao_init
     105              : 
     106              : ! **************************************************************************************************
     107              : !> \brief Start a PAO optimization run.
     108              : !> \param qs_env ...
     109              : !> \param ls_scf_env ...
     110              : ! **************************************************************************************************
     111          946 :    SUBROUTINE pao_optimization_start(qs_env, ls_scf_env)
     112              :       TYPE(qs_environment_type), POINTER                 :: qs_env
     113              :       TYPE(ls_scf_env_type), TARGET                      :: ls_scf_env
     114              : 
     115              :       CHARACTER(len=*), PARAMETER :: routineN = 'pao_optimization_start'
     116              : 
     117              :       INTEGER                                            :: handle
     118              :       TYPE(ls_mstruct_type), POINTER                     :: ls_mstruct
     119              :       TYPE(pao_env_type), POINTER                        :: pao
     120              :       TYPE(section_vals_type), POINTER                   :: input, section
     121              : 
     122          652 :       IF (.NOT. ls_scf_env%do_pao) RETURN
     123              : 
     124          294 :       CALL timeset(routineN, handle)
     125          294 :       CALL get_qs_env(qs_env, input=input)
     126          294 :       pao => ls_scf_env%pao_env
     127          294 :       ls_mstruct => ls_scf_env%ls_mstruct
     128              : 
     129              :       ! reset state
     130          294 :       pao%step_start_time = m_walltime()
     131          294 :       pao%istep = 0
     132          294 :       pao%matrix_P_ready = .FALSE.
     133              : 
     134              :       ! ready stuff that does not depend on atom positions
     135          294 :       IF (.NOT. pao%constants_ready) THEN
     136           98 :          CALL pao_build_diag_distribution(pao, qs_env)
     137           98 :          CALL pao_build_orthogonalizer(pao, qs_env)
     138           98 :          CALL pao_build_selector(pao, qs_env)
     139           98 :          CALL pao_build_core_hamiltonian(pao, qs_env)
     140           98 :          pao%constants_ready = .TRUE.
     141              :       END IF
     142              : 
     143          294 :       CALL pao_param_init(pao, qs_env)
     144              : 
     145              :       ! ready PAO parameter matrix_X
     146          294 :       IF (.NOT. pao%matrix_X_ready) THEN
     147           98 :          CALL pao_build_matrix_X(pao, qs_env)
     148           98 :          CALL pao_print_atom_info(pao)
     149           98 :          IF (LEN_TRIM(pao%restart_file) > 0) THEN
     150            8 :             CALL pao_read_restart(pao, qs_env)
     151           90 :          ELSE IF (SIZE(pao%ml_training_set) > 0) THEN
     152           18 :             CALL pao_ml_predict(pao, qs_env)
     153           72 :          ELSE IF (ALLOCATED(pao%models)) THEN
     154            4 :             CALL pao_model_predict(pao, qs_env)
     155              :          ELSE
     156           68 :             CALL pao_param_initial_guess(pao, qs_env)
     157              :          END IF
     158           98 :          pao%matrix_X_ready = .TRUE.
     159          196 :       ELSE IF (SIZE(pao%ml_training_set) > 0) THEN
     160          120 :          CALL pao_ml_predict(pao, qs_env)
     161           76 :       ELSE IF (ALLOCATED(pao%models)) THEN
     162           12 :          CALL pao_model_predict(pao, qs_env)
     163              :       ELSE
     164           64 :          IF (pao%iw > 0) WRITE (pao%iw, *) "PAO| reusing matrix_X from previous optimization"
     165              :       END IF
     166              : 
     167              :       ! init line-search
     168          294 :       section => section_vals_get_subs_vals(input, "DFT%LS_SCF%PAO%LINE_SEARCH")
     169          294 :       CALL linesearch_init(pao%linesearch, section, "PAO|")
     170              : 
     171              :       ! create some more matrices
     172          294 :       CALL dbcsr_copy(pao%matrix_G, pao%matrix_X)
     173          294 :       CALL dbcsr_set(pao%matrix_G, 0.0_dp)
     174              : 
     175          294 :       CALL dbcsr_create(ls_mstruct%matrix_A, template=pao%matrix_Y)
     176          294 :       CALL dbcsr_reserve_diag_blocks(ls_mstruct%matrix_A)
     177          294 :       CALL dbcsr_create(ls_mstruct%matrix_B, template=pao%matrix_Y)
     178          294 :       CALL dbcsr_reserve_diag_blocks(ls_mstruct%matrix_B)
     179              : 
     180              :       ! fill PAO transformation matrices
     181          294 :       CALL pao_calc_AB(pao, qs_env, ls_scf_env, gradient=.FALSE.)
     182              : 
     183          294 :       CALL timestop(handle)
     184              :    END SUBROUTINE pao_optimization_start
     185              : 
     186              : ! **************************************************************************************************
     187              : !> \brief Called after the SCF optimization, updates the PAO basis.
     188              : !> \param qs_env ...
     189              : !> \param ls_scf_env ...
     190              : !> \param pao_is_done ...
     191              : ! **************************************************************************************************
     192         1116 :    SUBROUTINE pao_update(qs_env, ls_scf_env, pao_is_done)
     193              :       TYPE(qs_environment_type), POINTER                 :: qs_env
     194              :       TYPE(ls_scf_env_type), TARGET                      :: ls_scf_env
     195              :       LOGICAL, INTENT(OUT)                               :: pao_is_done
     196              : 
     197              :       CHARACTER(len=*), PARAMETER                        :: routineN = 'pao_update'
     198              : 
     199              :       INTEGER                                            :: handle, icycle
     200              :       LOGICAL                                            :: cycle_converged, do_mixing, should_stop
     201              :       REAL(KIND=dp)                                      :: energy, penalty
     202              :       TYPE(dbcsr_type)                                   :: matrix_X_mixing
     203              :       TYPE(ls_mstruct_type), POINTER                     :: ls_mstruct
     204              :       TYPE(pao_env_type), POINTER                        :: pao
     205              : 
     206          870 :       IF (.NOT. ls_scf_env%do_pao) THEN
     207          358 :          pao_is_done = .TRUE.
     208          624 :          RETURN
     209              :       END IF
     210              : 
     211          512 :       ls_mstruct => ls_scf_env%ls_mstruct
     212          512 :       pao => ls_scf_env%pao_env
     213              : 
     214          512 :       IF (.NOT. pao%matrix_P_ready) THEN
     215          294 :          CALL pao_guess_initial_P(pao, qs_env, ls_scf_env)
     216          294 :          pao%matrix_P_ready = .TRUE.
     217              :       END IF
     218              : 
     219          512 :       IF (pao%max_pao == 0) THEN
     220          218 :          pao_is_done = .TRUE.
     221          218 :          RETURN
     222              :       END IF
     223              : 
     224          294 :       IF (pao%need_initial_scf) THEN
     225           48 :          pao_is_done = .FALSE.
     226           48 :          pao%need_initial_scf = .FALSE.
     227           48 :          IF (pao%iw > 0) WRITE (pao%iw, *) "PAO| Performing initial SCF optimization."
     228           48 :          RETURN
     229              :       END IF
     230              : 
     231          246 :       CALL timeset(routineN, handle)
     232              : 
     233              :       ! perform mixing once we are well into the optimization
     234          246 :       do_mixing = pao%mixing /= 1.0_dp .AND. pao%istep > 1
     235              :       IF (do_mixing) THEN
     236          128 :          CALL dbcsr_copy(matrix_X_mixing, pao%matrix_X)
     237              :       END IF
     238              : 
     239          246 :       cycle_converged = .FALSE.
     240          246 :       icycle = 0
     241          246 :       CALL linesearch_reset(pao%linesearch)
     242          246 :       CALL pao_opt_init(pao)
     243              : 
     244        20024 :       DO WHILE (.TRUE.)
     245        10126 :          pao%istep = pao%istep + 1
     246              : 
     247        15189 :          IF (pao%iw > 0) WRITE (pao%iw, "(A,I9,A)") " PAO| ======================= Iteration: ", &
     248        10126 :             pao%istep, " ============================="
     249              : 
     250              :          ! calc energy and check trace_PS
     251        10126 :          CALL pao_calc_energy(pao, qs_env, ls_scf_env, energy)
     252        10126 :          CALL pao_check_trace_PS(ls_scf_env)
     253              : 
     254        10126 :          IF (pao%linesearch%starts) THEN
     255         2616 :             icycle = icycle + 1
     256              :             ! calc new gradient including penalty terms
     257         2616 :             CALL pao_calc_AB(pao, qs_env, ls_scf_env, gradient=.TRUE., penalty=penalty)
     258         2616 :             CALL pao_check_grad(pao, qs_env, ls_scf_env)
     259              : 
     260              :             ! calculate new direction for line-search
     261         2616 :             CALL pao_opt_new_dir(pao, icycle)
     262              : 
     263              :             !backup X
     264         2616 :             CALL dbcsr_copy(pao%matrix_X_orig, pao%matrix_X)
     265              : 
     266              :             ! print info and convergence test
     267         2616 :             CALL pao_test_convergence(pao, ls_scf_env, energy, cycle_converged)
     268         2616 :             IF (cycle_converged) THEN
     269          210 :                pao_is_done = icycle < 3
     270          210 :                IF (pao_is_done .AND. pao%iw > 0) WRITE (pao%iw, *) "PAO| converged after ", pao%istep, " steps :-)"
     271              :                EXIT
     272              :             END IF
     273              : 
     274              :             ! if we have reached the maximum number of cycles exit in order
     275              :             ! to restart with a fresh hamiltonian
     276         2406 :             IF (icycle >= pao%max_cycles) THEN
     277           18 :                IF (pao%iw > 0) WRITE (pao%iw, *) "PAO| CG not yet converged after ", icycle, " cylces."
     278           18 :                pao_is_done = .FALSE.
     279           18 :                EXIT
     280              :             END IF
     281              : 
     282         2388 :             IF (MOD(icycle, pao%write_cycles) == 0) &
     283            8 :                CALL pao_write_restart(pao, qs_env, energy) ! write an intermediate restart file
     284              :          END IF
     285              : 
     286              :          ! check for early abort without convergence?
     287         9898 :          CALL external_control(should_stop, "PAO", start_time=qs_env%start_time, target_time=qs_env%target_time)
     288         9898 :          IF (should_stop .OR. pao%istep >= pao%max_pao) THEN
     289           18 :             IF (pao%ignore_convergence_failure) THEN
     290           18 :                CPWARN("PAO not converged!")
     291              :             ELSE
     292            0 :                CPABORT("PAO not converged!")
     293              :             END IF
     294           18 :             pao_is_done = .TRUE.
     295           18 :             EXIT
     296              :          END IF
     297              : 
     298              :          ! perform line-search step
     299         9880 :          CALL linesearch_step(pao%linesearch, energy=energy, slope=pao%norm_G**2)
     300              : 
     301         9880 :          IF (pao%linesearch%step_size < 1e-9_dp) CPABORT("PAO gradient is wrong.")
     302              : 
     303         9880 :          CALL dbcsr_copy(pao%matrix_X, pao%matrix_X_orig) !restore X
     304         9880 :          CALL dbcsr_add(pao%matrix_X, pao%matrix_D, 1.0_dp, pao%linesearch%step_size)
     305              :       END DO
     306              : 
     307              :       ! perform mixing of matrix_X
     308          246 :       IF (do_mixing) THEN
     309          128 :          CALL dbcsr_add(pao%matrix_X, matrix_X_mixing, pao%mixing, 1.0_dp - pao%mixing)
     310          128 :          CALL dbcsr_release(matrix_X_mixing)
     311          128 :          IF (pao%iw > 0) WRITE (pao%iw, *) "PAO| Recalculating energy after mixing."
     312          128 :          CALL pao_calc_energy(pao, qs_env, ls_scf_env, energy)
     313              :       END IF
     314              : 
     315          246 :       CALL pao_write_restart(pao, qs_env, energy)
     316          246 :       CALL pao_opt_finalize(pao)
     317              : 
     318          246 :       CALL timestop(handle)
     319          870 :    END SUBROUTINE pao_update
     320              : 
     321              : ! **************************************************************************************************
     322              : !> \brief Calculate PAO forces and store density matrix for future ASPC extrapolations
     323              : !> \param qs_env ...
     324              : !> \param ls_scf_env ...
     325              : !> \param pao_is_done ...
     326              : ! **************************************************************************************************
     327         1164 :    SUBROUTINE pao_post_scf(qs_env, ls_scf_env, pao_is_done)
     328              :       TYPE(qs_environment_type), POINTER                 :: qs_env
     329              :       TYPE(ls_scf_env_type), TARGET                      :: ls_scf_env
     330              :       LOGICAL, INTENT(IN)                                :: pao_is_done
     331              : 
     332              :       CHARACTER(len=*), PARAMETER                        :: routineN = 'pao_post_scf'
     333              : 
     334              :       INTEGER                                            :: handle
     335              :       LOGICAL                                            :: use_virial
     336              :       TYPE(virial_type), POINTER                         :: virial
     337              : 
     338         1088 :       IF (.NOT. ls_scf_env%do_pao) RETURN
     339          512 :       IF (.NOT. pao_is_done) RETURN
     340              : 
     341          294 :       CALL timeset(routineN, handle)
     342              : 
     343              :       ! print out the matrices here before pao_store_P converts them back into matrices in
     344              :       ! terms of the primary basis
     345          294 :       CALL pao_write_ks_matrix_csr(qs_env, ls_scf_env)
     346          294 :       CALL pao_write_s_matrix_csr(qs_env, ls_scf_env)
     347          294 :       CALL pao_write_hcore_matrix_csr(qs_env, ls_scf_env)
     348          294 :       CALL pao_write_p_matrix_csr(qs_env, ls_scf_env)
     349              : 
     350          294 :       CALL pao_store_P(qs_env, ls_scf_env)
     351          294 :       IF (ls_scf_env%calculate_forces) CALL pao_add_forces(qs_env, ls_scf_env)
     352              : 
     353          294 :       CALL get_qs_env(qs_env=qs_env, virial=virial)
     354          294 :       use_virial = virial%pv_availability .AND. (.NOT. virial%pv_numer)
     355            0 :       IF (use_virial .AND. ls_scf_env%calculate_forces) THEN
     356            0 :          CPABORT("Analytical stress tensor not implemented for PAO.")
     357              :       END IF
     358              : 
     359          294 :       CALL timestop(handle)
     360              :    END SUBROUTINE pao_post_scf
     361              : 
     362              : ! **************************************************************************************************
     363              : !> \brief Finish a PAO optimization run.
     364              : !> \param ls_scf_env ...
     365              : ! **************************************************************************************************
     366          946 :    SUBROUTINE pao_optimization_end(ls_scf_env)
     367              :       TYPE(ls_scf_env_type), TARGET                      :: ls_scf_env
     368              : 
     369              :       CHARACTER(len=*), PARAMETER :: routineN = 'pao_optimization_end'
     370              : 
     371              :       INTEGER                                            :: handle
     372              :       TYPE(ls_mstruct_type), POINTER                     :: ls_mstruct
     373              :       TYPE(pao_env_type), POINTER                        :: pao
     374              : 
     375          652 :       IF (.NOT. ls_scf_env%do_pao) RETURN
     376              : 
     377          294 :       pao => ls_scf_env%pao_env
     378          294 :       ls_mstruct => ls_scf_env%ls_mstruct
     379              : 
     380          294 :       CALL timeset(routineN, handle)
     381              : 
     382          294 :       CALL pao_param_finalize(pao)
     383              : 
     384              :       ! We keep pao%matrix_X for next scf-run, e.g. during MD or GEO-OPT
     385          294 :       CALL dbcsr_release(pao%matrix_X_orig)
     386          294 :       CALL dbcsr_release(pao%matrix_G)
     387          294 :       CALL dbcsr_release(ls_mstruct%matrix_A)
     388          294 :       CALL dbcsr_release(ls_mstruct%matrix_B)
     389              : 
     390          294 :       CALL linesearch_finalize(pao%linesearch)
     391              : 
     392          294 :       CALL timestop(handle)
     393              :    END SUBROUTINE pao_optimization_end
     394              : 
     395              : END MODULE pao_main
        

Generated by: LCOV version 2.0-1