LCOV - code coverage report
Current view: top level - src - pao_main.F (source / functions) Coverage Total Hit
Test: CP2K Regtests (git:42dac4a) Lines: 100.0 % 139 139
Test Date: 2025-07-25 12:55:17 Functions: 100.0 % 5 5

            Line data    Source code
       1              : !--------------------------------------------------------------------------------------------------!
       2              : !   CP2K: A general program to perform molecular dynamics simulations                              !
       3              : !   Copyright 2000-2025 CP2K developers group <https://cp2k.org>                                   !
       4              : !                                                                                                  !
       5              : !   SPDX-License-Identifier: GPL-2.0-or-later                                                      !
       6              : !--------------------------------------------------------------------------------------------------!
       7              : 
       8              : ! **************************************************************************************************
       9              : !> \brief Main module for the PAO method
      10              : !> \author Ole Schuett
      11              : ! **************************************************************************************************
      12              : MODULE pao_main
      13              :    USE bibliography,                    ONLY: Schuett2018,&
      14              :                                               cite_reference
      15              :    USE cp_dbcsr_api,                    ONLY: dbcsr_add,&
      16              :                                               dbcsr_copy,&
      17              :                                               dbcsr_create,&
      18              :                                               dbcsr_p_type,&
      19              :                                               dbcsr_release,&
      20              :                                               dbcsr_set,&
      21              :                                               dbcsr_type
      22              :    USE cp_dbcsr_contrib,                ONLY: dbcsr_reserve_diag_blocks
      23              :    USE cp_external_control,             ONLY: external_control
      24              :    USE dm_ls_scf_types,                 ONLY: ls_mstruct_type,&
      25              :                                               ls_scf_env_type
      26              :    USE input_section_types,             ONLY: section_vals_get_subs_vals,&
      27              :                                               section_vals_type
      28              :    USE kinds,                           ONLY: dp
      29              :    USE linesearch,                      ONLY: linesearch_finalize,&
      30              :                                               linesearch_init,&
      31              :                                               linesearch_reset,&
      32              :                                               linesearch_step
      33              :    USE machine,                         ONLY: m_walltime
      34              :    USE pao_input,                       ONLY: parse_pao_section
      35              :    USE pao_io,                          ONLY: pao_read_restart,&
      36              :                                               pao_write_ks_matrix_csr,&
      37              :                                               pao_write_restart,&
      38              :                                               pao_write_s_matrix_csr
      39              :    USE pao_methods,                     ONLY: &
      40              :         pao_add_forces, pao_build_core_hamiltonian, pao_build_diag_distribution, &
      41              :         pao_build_matrix_X, pao_build_orthogonalizer, pao_build_selector, pao_calc_energy, &
      42              :         pao_check_grad, pao_check_trace_ps, pao_guess_initial_P, pao_init_kinds, &
      43              :         pao_print_atom_info, pao_store_P, pao_test_convergence
      44              :    USE pao_ml,                          ONLY: pao_ml_init,&
      45              :                                               pao_ml_predict
      46              :    USE pao_model,                       ONLY: pao_model_predict
      47              :    USE pao_optimizer,                   ONLY: pao_opt_finalize,&
      48              :                                               pao_opt_init,&
      49              :                                               pao_opt_new_dir
      50              :    USE pao_param,                       ONLY: pao_calc_AB,&
      51              :                                               pao_param_finalize,&
      52              :                                               pao_param_init,&
      53              :                                               pao_param_initial_guess
      54              :    USE pao_types,                       ONLY: pao_env_type
      55              :    USE qs_environment_types,            ONLY: get_qs_env,&
      56              :                                               qs_environment_type
      57              : #include "./base/base_uses.f90"
      58              : 
      59              :    IMPLICIT NONE
      60              : 
      61              :    PRIVATE
      62              : 
      63              :    CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'pao_main'
      64              : 
      65              :    PUBLIC :: pao_init, pao_update, pao_post_scf, pao_optimization_start, pao_optimization_end
      66              : 
      67              : CONTAINS
      68              : 
      69              : ! **************************************************************************************************
      70              : !> \brief Initialize the PAO environment
      71              : !> \param qs_env ...
      72              : !> \param ls_scf_env ...
      73              : ! **************************************************************************************************
      74          434 :    SUBROUTINE pao_init(qs_env, ls_scf_env)
      75              :       TYPE(qs_environment_type), POINTER                 :: qs_env
      76              :       TYPE(ls_scf_env_type), TARGET                      :: ls_scf_env
      77              : 
      78              :       CHARACTER(len=*), PARAMETER                        :: routineN = 'pao_init'
      79              : 
      80              :       INTEGER                                            :: handle
      81          336 :       TYPE(dbcsr_p_type), DIMENSION(:), POINTER          :: matrix_s
      82              :       TYPE(pao_env_type), POINTER                        :: pao
      83              :       TYPE(section_vals_type), POINTER                   :: input
      84              : 
      85          238 :       IF (.NOT. ls_scf_env%do_pao) RETURN
      86              : 
      87           98 :       CALL timeset(routineN, handle)
      88           98 :       CALL cite_reference(Schuett2018)
      89           98 :       pao => ls_scf_env%pao_env
      90           98 :       CALL get_qs_env(qs_env=qs_env, input=input, matrix_s=matrix_s)
      91              : 
      92              :       ! parse input
      93           98 :       CALL parse_pao_section(pao, input)
      94              : 
      95           98 :       CALL pao_init_kinds(pao, qs_env)
      96              : 
      97              :       ! train machine learning
      98           98 :       CALL pao_ml_init(pao, qs_env)
      99              : 
     100           98 :       CALL timestop(handle)
     101          336 :    END SUBROUTINE pao_init
     102              : 
     103              : ! **************************************************************************************************
     104              : !> \brief Start a PAO optimization run.
     105              : !> \param qs_env ...
     106              : !> \param ls_scf_env ...
     107              : ! **************************************************************************************************
     108          892 :    SUBROUTINE pao_optimization_start(qs_env, ls_scf_env)
     109              :       TYPE(qs_environment_type), POINTER                 :: qs_env
     110              :       TYPE(ls_scf_env_type), TARGET                      :: ls_scf_env
     111              : 
     112              :       CHARACTER(len=*), PARAMETER :: routineN = 'pao_optimization_start'
     113              : 
     114              :       INTEGER                                            :: handle
     115              :       TYPE(ls_mstruct_type), POINTER                     :: ls_mstruct
     116              :       TYPE(pao_env_type), POINTER                        :: pao
     117              :       TYPE(section_vals_type), POINTER                   :: input, section
     118              : 
     119          598 :       IF (.NOT. ls_scf_env%do_pao) RETURN
     120              : 
     121          294 :       CALL timeset(routineN, handle)
     122          294 :       CALL get_qs_env(qs_env, input=input)
     123          294 :       pao => ls_scf_env%pao_env
     124          294 :       ls_mstruct => ls_scf_env%ls_mstruct
     125              : 
     126              :       ! reset state
     127          294 :       pao%step_start_time = m_walltime()
     128          294 :       pao%istep = 0
     129          294 :       pao%matrix_P_ready = .FALSE.
     130              : 
     131              :       ! ready stuff that does not depend on atom positions
     132          294 :       IF (.NOT. pao%constants_ready) THEN
     133           98 :          CALL pao_build_diag_distribution(pao, qs_env)
     134           98 :          CALL pao_build_orthogonalizer(pao, qs_env)
     135           98 :          CALL pao_build_selector(pao, qs_env)
     136           98 :          CALL pao_build_core_hamiltonian(pao, qs_env)
     137           98 :          pao%constants_ready = .TRUE.
     138              :       END IF
     139              : 
     140          294 :       CALL pao_param_init(pao, qs_env)
     141              : 
     142              :       ! ready PAO parameter matrix_X
     143          294 :       IF (.NOT. pao%matrix_X_ready) THEN
     144           98 :          CALL pao_build_matrix_X(pao, qs_env)
     145           98 :          CALL pao_print_atom_info(pao)
     146           98 :          IF (LEN_TRIM(pao%restart_file) > 0) THEN
     147            8 :             CALL pao_read_restart(pao, qs_env)
     148           90 :          ELSE IF (SIZE(pao%ml_training_set) > 0) THEN
     149           18 :             CALL pao_ml_predict(pao, qs_env)
     150           72 :          ELSE IF (ALLOCATED(pao%models)) THEN
     151            4 :             CALL pao_model_predict(pao, qs_env)
     152              :          ELSE
     153           68 :             CALL pao_param_initial_guess(pao, qs_env)
     154              :          END IF
     155           98 :          pao%matrix_X_ready = .TRUE.
     156          196 :       ELSE IF (SIZE(pao%ml_training_set) > 0) THEN
     157          120 :          CALL pao_ml_predict(pao, qs_env)
     158           76 :       ELSE IF (ALLOCATED(pao%models)) THEN
     159           12 :          CALL pao_model_predict(pao, qs_env)
     160              :       ELSE
     161           64 :          IF (pao%iw > 0) WRITE (pao%iw, *) "PAO| reusing matrix_X from previous optimization"
     162              :       END IF
     163              : 
     164              :       ! init line-search
     165          294 :       section => section_vals_get_subs_vals(input, "DFT%LS_SCF%PAO%LINE_SEARCH")
     166          294 :       CALL linesearch_init(pao%linesearch, section, "PAO|")
     167              : 
     168              :       ! create some more matrices
     169          294 :       CALL dbcsr_copy(pao%matrix_G, pao%matrix_X)
     170          294 :       CALL dbcsr_set(pao%matrix_G, 0.0_dp)
     171              : 
     172          294 :       CALL dbcsr_create(ls_mstruct%matrix_A, template=pao%matrix_Y)
     173          294 :       CALL dbcsr_reserve_diag_blocks(ls_mstruct%matrix_A)
     174          294 :       CALL dbcsr_create(ls_mstruct%matrix_B, template=pao%matrix_Y)
     175          294 :       CALL dbcsr_reserve_diag_blocks(ls_mstruct%matrix_B)
     176              : 
     177              :       ! fill PAO transformation matrices
     178          294 :       CALL pao_calc_AB(pao, qs_env, ls_scf_env, gradient=.FALSE.)
     179              : 
     180          294 :       CALL timestop(handle)
     181              :    END SUBROUTINE pao_optimization_start
     182              : 
     183              : ! **************************************************************************************************
     184              : !> \brief Called after the SCF optimization, updates the PAO basis.
     185              : !> \param qs_env ...
     186              : !> \param ls_scf_env ...
     187              : !> \param pao_is_done ...
     188              : ! **************************************************************************************************
     189         1062 :    SUBROUTINE pao_update(qs_env, ls_scf_env, pao_is_done)
     190              :       TYPE(qs_environment_type), POINTER                 :: qs_env
     191              :       TYPE(ls_scf_env_type), TARGET                      :: ls_scf_env
     192              :       LOGICAL, INTENT(OUT)                               :: pao_is_done
     193              : 
     194              :       CHARACTER(len=*), PARAMETER                        :: routineN = 'pao_update'
     195              : 
     196              :       INTEGER                                            :: handle, icycle
     197              :       LOGICAL                                            :: cycle_converged, do_mixing, should_stop
     198              :       REAL(KIND=dp)                                      :: energy, penalty
     199              :       TYPE(dbcsr_type)                                   :: matrix_X_mixing
     200              :       TYPE(ls_mstruct_type), POINTER                     :: ls_mstruct
     201              :       TYPE(pao_env_type), POINTER                        :: pao
     202              : 
     203          816 :       IF (.NOT. ls_scf_env%do_pao) THEN
     204          304 :          pao_is_done = .TRUE.
     205          570 :          RETURN
     206              :       END IF
     207              : 
     208          512 :       ls_mstruct => ls_scf_env%ls_mstruct
     209          512 :       pao => ls_scf_env%pao_env
     210              : 
     211          512 :       IF (.NOT. pao%matrix_P_ready) THEN
     212          294 :          CALL pao_guess_initial_P(pao, qs_env, ls_scf_env)
     213          294 :          pao%matrix_P_ready = .TRUE.
     214              :       END IF
     215              : 
     216          512 :       IF (pao%max_pao == 0) THEN
     217          218 :          pao_is_done = .TRUE.
     218          218 :          RETURN
     219              :       END IF
     220              : 
     221          294 :       IF (pao%need_initial_scf) THEN
     222           48 :          pao_is_done = .FALSE.
     223           48 :          pao%need_initial_scf = .FALSE.
     224           48 :          IF (pao%iw > 0) WRITE (pao%iw, *) "PAO| Performing initial SCF optimization."
     225           48 :          RETURN
     226              :       END IF
     227              : 
     228          246 :       CALL timeset(routineN, handle)
     229              : 
     230              :       ! perform mixing once we are well into the optimization
     231          246 :       do_mixing = pao%mixing /= 1.0_dp .AND. pao%istep > 1
     232              :       IF (do_mixing) THEN
     233          128 :          CALL dbcsr_copy(matrix_X_mixing, pao%matrix_X)
     234              :       END IF
     235              : 
     236          246 :       cycle_converged = .FALSE.
     237          246 :       icycle = 0
     238          246 :       CALL linesearch_reset(pao%linesearch)
     239          246 :       CALL pao_opt_init(pao)
     240              : 
     241        20024 :       DO WHILE (.TRUE.)
     242        10126 :          pao%istep = pao%istep + 1
     243              : 
     244        15189 :          IF (pao%iw > 0) WRITE (pao%iw, "(A,I9,A)") " PAO| ======================= Iteration: ", &
     245        10126 :             pao%istep, " ============================="
     246              : 
     247              :          ! calc energy and check trace_PS
     248        10126 :          CALL pao_calc_energy(pao, qs_env, ls_scf_env, energy)
     249        10126 :          CALL pao_check_trace_PS(ls_scf_env)
     250              : 
     251        10126 :          IF (pao%linesearch%starts) THEN
     252         2616 :             icycle = icycle + 1
     253              :             ! calc new gradient including penalty terms
     254         2616 :             CALL pao_calc_AB(pao, qs_env, ls_scf_env, gradient=.TRUE., penalty=penalty)
     255         2616 :             CALL pao_check_grad(pao, qs_env, ls_scf_env)
     256              : 
     257              :             ! calculate new direction for line-search
     258         2616 :             CALL pao_opt_new_dir(pao, icycle)
     259              : 
     260              :             !backup X
     261         2616 :             CALL dbcsr_copy(pao%matrix_X_orig, pao%matrix_X)
     262              : 
     263              :             ! print info and convergence test
     264         2616 :             CALL pao_test_convergence(pao, ls_scf_env, energy, cycle_converged)
     265         2616 :             IF (cycle_converged) THEN
     266          210 :                pao_is_done = icycle < 3
     267          210 :                IF (pao_is_done .AND. pao%iw > 0) WRITE (pao%iw, *) "PAO| converged after ", pao%istep, " steps :-)"
     268              :                EXIT
     269              :             END IF
     270              : 
     271              :             ! if we have reached the maximum number of cycles exit in order
     272              :             ! to restart with a fresh hamiltonian
     273         2406 :             IF (icycle >= pao%max_cycles) THEN
     274           18 :                IF (pao%iw > 0) WRITE (pao%iw, *) "PAO| CG not yet converged after ", icycle, " cylces."
     275           18 :                pao_is_done = .FALSE.
     276           18 :                EXIT
     277              :             END IF
     278              : 
     279         2388 :             IF (MOD(icycle, pao%write_cycles) == 0) &
     280            8 :                CALL pao_write_restart(pao, qs_env, energy) ! write an intermediate restart file
     281              :          END IF
     282              : 
     283              :          ! check for early abort without convergence?
     284         9898 :          CALL external_control(should_stop, "PAO", start_time=qs_env%start_time, target_time=qs_env%target_time)
     285         9898 :          IF (should_stop .OR. pao%istep >= pao%max_pao) THEN
     286           18 :             CPWARN("PAO not converged!")
     287           18 :             pao_is_done = .TRUE.
     288           18 :             EXIT
     289              :          END IF
     290              : 
     291              :          ! perform line-search step
     292         9880 :          CALL linesearch_step(pao%linesearch, energy=energy, slope=pao%norm_G**2)
     293              : 
     294         9880 :          IF (pao%linesearch%step_size < 1e-9_dp) CPABORT("PAO gradient is wrong.")
     295              : 
     296         9880 :          CALL dbcsr_copy(pao%matrix_X, pao%matrix_X_orig) !restore X
     297         9880 :          CALL dbcsr_add(pao%matrix_X, pao%matrix_D, 1.0_dp, pao%linesearch%step_size)
     298              :       END DO
     299              : 
     300              :       ! perform mixing of matrix_X
     301          246 :       IF (do_mixing) THEN
     302          128 :          CALL dbcsr_add(pao%matrix_X, matrix_X_mixing, pao%mixing, 1.0_dp - pao%mixing)
     303          128 :          CALL dbcsr_release(matrix_X_mixing)
     304          128 :          IF (pao%iw > 0) WRITE (pao%iw, *) "PAO| Recalculating energy after mixing."
     305          128 :          CALL pao_calc_energy(pao, qs_env, ls_scf_env, energy)
     306              :       END IF
     307              : 
     308          246 :       CALL pao_write_restart(pao, qs_env, energy)
     309          246 :       CALL pao_opt_finalize(pao)
     310              : 
     311          246 :       CALL timestop(handle)
     312          816 :    END SUBROUTINE pao_update
     313              : 
     314              : ! **************************************************************************************************
     315              : !> \brief Calculate PAO forces and store density matrix for future ASPC extrapolations
     316              : !> \param qs_env ...
     317              : !> \param ls_scf_env ...
     318              : !> \param pao_is_done ...
     319              : ! **************************************************************************************************
     320         1110 :    SUBROUTINE pao_post_scf(qs_env, ls_scf_env, pao_is_done)
     321              :       TYPE(qs_environment_type), POINTER                 :: qs_env
     322              :       TYPE(ls_scf_env_type), TARGET                      :: ls_scf_env
     323              :       LOGICAL, INTENT(IN)                                :: pao_is_done
     324              : 
     325              :       CHARACTER(len=*), PARAMETER                        :: routineN = 'pao_post_scf'
     326              : 
     327              :       INTEGER                                            :: handle
     328              : 
     329         1034 :       IF (.NOT. ls_scf_env%do_pao) RETURN
     330          512 :       IF (.NOT. pao_is_done) RETURN
     331              : 
     332          294 :       CALL timeset(routineN, handle)
     333              : 
     334              :       ! print out the matrices here before pao_store_P converts them back into matrices in
     335              :       ! terms of the primary basis
     336          294 :       CALL pao_write_ks_matrix_csr(qs_env, ls_scf_env)
     337          294 :       CALL pao_write_s_matrix_csr(qs_env, ls_scf_env)
     338              : 
     339          294 :       CALL pao_store_P(qs_env, ls_scf_env)
     340          294 :       IF (ls_scf_env%calculate_forces) CALL pao_add_forces(qs_env, ls_scf_env)
     341              : 
     342          294 :       CALL timestop(handle)
     343              :    END SUBROUTINE
     344              : 
     345              : ! **************************************************************************************************
     346              : !> \brief Finish a PAO optimization run.
     347              : !> \param ls_scf_env ...
     348              : ! **************************************************************************************************
     349          892 :    SUBROUTINE pao_optimization_end(ls_scf_env)
     350              :       TYPE(ls_scf_env_type), TARGET                      :: ls_scf_env
     351              : 
     352              :       CHARACTER(len=*), PARAMETER :: routineN = 'pao_optimization_end'
     353              : 
     354              :       INTEGER                                            :: handle
     355              :       TYPE(ls_mstruct_type), POINTER                     :: ls_mstruct
     356              :       TYPE(pao_env_type), POINTER                        :: pao
     357              : 
     358          598 :       IF (.NOT. ls_scf_env%do_pao) RETURN
     359              : 
     360          294 :       pao => ls_scf_env%pao_env
     361          294 :       ls_mstruct => ls_scf_env%ls_mstruct
     362              : 
     363          294 :       CALL timeset(routineN, handle)
     364              : 
     365          294 :       CALL pao_param_finalize(pao)
     366              : 
     367              :       ! We keep pao%matrix_X for next scf-run, e.g. during MD or GEO-OPT
     368          294 :       CALL dbcsr_release(pao%matrix_X_orig)
     369          294 :       CALL dbcsr_release(pao%matrix_G)
     370          294 :       CALL dbcsr_release(ls_mstruct%matrix_A)
     371          294 :       CALL dbcsr_release(ls_mstruct%matrix_B)
     372              : 
     373          294 :       CALL linesearch_finalize(pao%linesearch)
     374              : 
     375          294 :       CALL timestop(handle)
     376              :    END SUBROUTINE pao_optimization_end
     377              : 
     378              : END MODULE pao_main
        

Generated by: LCOV version 2.0-1