LCOV - code coverage report
Current view: top level - src - pao_main.F (source / functions) Hit Total Coverage
Test: CP2K Regtests (git:1f285aa) Lines: 135 135 100.0 %
Date: 2024-04-23 06:49:27 Functions: 5 5 100.0 %

          Line data    Source code
       1             : !--------------------------------------------------------------------------------------------------!
       2             : !   CP2K: A general program to perform molecular dynamics simulations                              !
       3             : !   Copyright 2000-2024 CP2K developers group <https://cp2k.org>                                   !
       4             : !                                                                                                  !
       5             : !   SPDX-License-Identifier: GPL-2.0-or-later                                                      !
       6             : !--------------------------------------------------------------------------------------------------!
       7             : 
       8             : ! **************************************************************************************************
       9             : !> \brief Main module for the PAO method
      10             : !> \author Ole Schuett
      11             : ! **************************************************************************************************
      12             : MODULE pao_main
      13             :    USE bibliography,                    ONLY: Schuett2018,&
      14             :                                               cite_reference
      15             :    USE cp_external_control,             ONLY: external_control
      16             :    USE dbcsr_api,                       ONLY: dbcsr_add,&
      17             :                                               dbcsr_copy,&
      18             :                                               dbcsr_create,&
      19             :                                               dbcsr_p_type,&
      20             :                                               dbcsr_release,&
      21             :                                               dbcsr_reserve_diag_blocks,&
      22             :                                               dbcsr_set,&
      23             :                                               dbcsr_type
      24             :    USE dm_ls_scf_types,                 ONLY: ls_mstruct_type,&
      25             :                                               ls_scf_env_type
      26             :    USE input_section_types,             ONLY: section_vals_get_subs_vals,&
      27             :                                               section_vals_type
      28             :    USE kinds,                           ONLY: dp
      29             :    USE linesearch,                      ONLY: linesearch_finalize,&
      30             :                                               linesearch_init,&
      31             :                                               linesearch_reset,&
      32             :                                               linesearch_step
      33             :    USE machine,                         ONLY: m_walltime
      34             :    USE pao_input,                       ONLY: parse_pao_section
      35             :    USE pao_io,                          ONLY: pao_read_restart,&
      36             :                                               pao_write_ks_matrix_csr,&
      37             :                                               pao_write_restart,&
      38             :                                               pao_write_s_matrix_csr
      39             :    USE pao_methods,                     ONLY: &
      40             :         pao_add_forces, pao_build_core_hamiltonian, pao_build_diag_distribution, &
      41             :         pao_build_matrix_X, pao_build_orthogonalizer, pao_build_selector, pao_calc_energy, &
      42             :         pao_check_grad, pao_check_trace_ps, pao_guess_initial_P, pao_init_kinds, &
      43             :         pao_print_atom_info, pao_store_P, pao_test_convergence
      44             :    USE pao_ml,                          ONLY: pao_ml_init,&
      45             :                                               pao_ml_predict
      46             :    USE pao_optimizer,                   ONLY: pao_opt_finalize,&
      47             :                                               pao_opt_init,&
      48             :                                               pao_opt_new_dir
      49             :    USE pao_param,                       ONLY: pao_calc_AB,&
      50             :                                               pao_param_finalize,&
      51             :                                               pao_param_init,&
      52             :                                               pao_param_initial_guess
      53             :    USE pao_types,                       ONLY: pao_env_type
      54             :    USE qs_environment_types,            ONLY: get_qs_env,&
      55             :                                               qs_environment_type
      56             : #include "./base/base_uses.f90"
      57             : 
      58             :    IMPLICIT NONE
      59             : 
      60             :    PRIVATE
      61             : 
      62             :    CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'pao_main'
      63             : 
      64             :    PUBLIC :: pao_init, pao_update, pao_post_scf, pao_optimization_start, pao_optimization_end
      65             : 
      66             : CONTAINS
      67             : 
      68             : ! **************************************************************************************************
      69             : !> \brief Initialize the PAO environment
      70             : !> \param qs_env ...
      71             : !> \param ls_scf_env ...
      72             : ! **************************************************************************************************
      73         430 :    SUBROUTINE pao_init(qs_env, ls_scf_env)
      74             :       TYPE(qs_environment_type), POINTER                 :: qs_env
      75             :       TYPE(ls_scf_env_type), TARGET                      :: ls_scf_env
      76             : 
      77             :       CHARACTER(len=*), PARAMETER                        :: routineN = 'pao_init'
      78             : 
      79             :       INTEGER                                            :: handle
      80         336 :       TYPE(dbcsr_p_type), DIMENSION(:), POINTER          :: matrix_s
      81             :       TYPE(pao_env_type), POINTER                        :: pao
      82             :       TYPE(section_vals_type), POINTER                   :: input
      83             : 
      84         242 :       IF (.NOT. ls_scf_env%do_pao) RETURN
      85             : 
      86          94 :       CALL timeset(routineN, handle)
      87          94 :       CALL cite_reference(Schuett2018)
      88          94 :       pao => ls_scf_env%pao_env
      89          94 :       CALL get_qs_env(qs_env=qs_env, input=input, matrix_s=matrix_s)
      90             : 
      91             :       ! parse input
      92          94 :       CALL parse_pao_section(pao, input)
      93             : 
      94          94 :       CALL pao_init_kinds(pao, qs_env)
      95             : 
      96             :       ! train machine learning
      97          94 :       CALL pao_ml_init(pao, qs_env)
      98             : 
      99          94 :       CALL timestop(handle)
     100         336 :    END SUBROUTINE pao_init
     101             : 
     102             : ! **************************************************************************************************
     103             : !> \brief Start a PAO optimization run.
     104             : !> \param qs_env ...
     105             : !> \param ls_scf_env ...
     106             : ! **************************************************************************************************
     107         874 :    SUBROUTINE pao_optimization_start(qs_env, ls_scf_env)
     108             :       TYPE(qs_environment_type), POINTER                 :: qs_env
     109             :       TYPE(ls_scf_env_type), TARGET                      :: ls_scf_env
     110             : 
     111             :       CHARACTER(len=*), PARAMETER :: routineN = 'pao_optimization_start'
     112             : 
     113             :       INTEGER                                            :: handle
     114             :       TYPE(ls_mstruct_type), POINTER                     :: ls_mstruct
     115             :       TYPE(pao_env_type), POINTER                        :: pao
     116             :       TYPE(section_vals_type), POINTER                   :: input, section
     117             : 
     118         596 :       IF (.NOT. ls_scf_env%do_pao) RETURN
     119             : 
     120         278 :       CALL timeset(routineN, handle)
     121         278 :       CALL get_qs_env(qs_env, input=input)
     122         278 :       pao => ls_scf_env%pao_env
     123         278 :       ls_mstruct => ls_scf_env%ls_mstruct
     124             : 
     125             :       ! reset state
     126         278 :       pao%step_start_time = m_walltime()
     127         278 :       pao%istep = 0
     128         278 :       pao%matrix_P_ready = .FALSE.
     129             : 
     130             :       ! ready stuff that does not depend on atom positions
     131         278 :       IF (.NOT. pao%constants_ready) THEN
     132          94 :          CALL pao_build_diag_distribution(pao, qs_env)
     133          94 :          CALL pao_build_orthogonalizer(pao, qs_env)
     134          94 :          CALL pao_build_selector(pao, qs_env)
     135          94 :          CALL pao_build_core_hamiltonian(pao, qs_env)
     136          94 :          pao%constants_ready = .TRUE.
     137             :       END IF
     138             : 
     139         278 :       CALL pao_param_init(pao, qs_env)
     140             : 
     141             :       ! ready PAO parameter matrix_X
     142         278 :       IF (.NOT. pao%matrix_X_ready) THEN
     143          94 :          CALL pao_build_matrix_X(pao, qs_env)
     144          94 :          CALL pao_print_atom_info(pao)
     145          94 :          IF (LEN_TRIM(pao%restart_file) > 0) THEN
     146           8 :             CALL pao_read_restart(pao, qs_env)
     147          86 :          ELSE IF (SIZE(pao%ml_training_set) > 0) THEN
     148          18 :             CALL pao_ml_predict(pao, qs_env)
     149             :          ELSE
     150          68 :             CALL pao_param_initial_guess(pao, qs_env)
     151             :          END IF
     152          94 :          pao%matrix_X_ready = .TRUE.
     153         184 :       ELSE IF (SIZE(pao%ml_training_set) > 0) THEN
     154         120 :          CALL pao_ml_predict(pao, qs_env)
     155             :       ELSE
     156          64 :          IF (pao%iw > 0) WRITE (pao%iw, *) "PAO| reusing matrix_X from previous optimization"
     157             :       END IF
     158             : 
     159             :       ! init line-search
     160         278 :       section => section_vals_get_subs_vals(input, "DFT%LS_SCF%PAO%LINE_SEARCH")
     161         278 :       CALL linesearch_init(pao%linesearch, section, "PAO|")
     162             : 
     163             :       ! create some more matrices
     164         278 :       CALL dbcsr_copy(pao%matrix_G, pao%matrix_X)
     165         278 :       CALL dbcsr_set(pao%matrix_G, 0.0_dp)
     166             : 
     167         278 :       CALL dbcsr_create(ls_mstruct%matrix_A, template=pao%matrix_Y)
     168         278 :       CALL dbcsr_reserve_diag_blocks(ls_mstruct%matrix_A)
     169         278 :       CALL dbcsr_create(ls_mstruct%matrix_B, template=pao%matrix_Y)
     170         278 :       CALL dbcsr_reserve_diag_blocks(ls_mstruct%matrix_B)
     171             : 
     172             :       ! fill PAO transformation matrices
     173         278 :       CALL pao_calc_AB(pao, qs_env, ls_scf_env, gradient=.FALSE.)
     174             : 
     175         278 :       CALL timestop(handle)
     176             :    END SUBROUTINE pao_optimization_start
     177             : 
     178             : ! **************************************************************************************************
     179             : !> \brief Called after the SCF optimization, updates the PAO basis.
     180             : !> \param qs_env ...
     181             : !> \param ls_scf_env ...
     182             : !> \param pao_is_done ...
     183             : ! **************************************************************************************************
     184        1060 :    SUBROUTINE pao_update(qs_env, ls_scf_env, pao_is_done)
     185             :       TYPE(qs_environment_type), POINTER                 :: qs_env
     186             :       TYPE(ls_scf_env_type), TARGET                      :: ls_scf_env
     187             :       LOGICAL, INTENT(OUT)                               :: pao_is_done
     188             : 
     189             :       CHARACTER(len=*), PARAMETER                        :: routineN = 'pao_update'
     190             : 
     191             :       INTEGER                                            :: handle, icycle
     192             :       LOGICAL                                            :: cycle_converged, do_mixing, should_stop
     193             :       REAL(KIND=dp)                                      :: energy, penalty
     194             :       TYPE(dbcsr_type)                                   :: matrix_X_mixing
     195             :       TYPE(ls_mstruct_type), POINTER                     :: ls_mstruct
     196             :       TYPE(pao_env_type), POINTER                        :: pao
     197             : 
     198         814 :       IF (.NOT. ls_scf_env%do_pao) THEN
     199         318 :          pao_is_done = .TRUE.
     200         568 :          RETURN
     201             :       END IF
     202             : 
     203         496 :       ls_mstruct => ls_scf_env%ls_mstruct
     204         496 :       pao => ls_scf_env%pao_env
     205             : 
     206         496 :       IF (.NOT. pao%matrix_P_ready) THEN
     207         278 :          CALL pao_guess_initial_P(pao, qs_env, ls_scf_env)
     208         278 :          pao%matrix_P_ready = .TRUE.
     209             :       END IF
     210             : 
     211         496 :       IF (pao%max_pao == 0) THEN
     212         202 :          pao_is_done = .TRUE.
     213         202 :          RETURN
     214             :       END IF
     215             : 
     216         294 :       IF (pao%need_initial_scf) THEN
     217          48 :          pao_is_done = .FALSE.
     218          48 :          pao%need_initial_scf = .FALSE.
     219          48 :          IF (pao%iw > 0) WRITE (pao%iw, *) "PAO| Performing initial SCF optimization."
     220          48 :          RETURN
     221             :       END IF
     222             : 
     223         246 :       CALL timeset(routineN, handle)
     224             : 
     225             :       ! perform mixing once we are well into the optimization
     226         246 :       do_mixing = pao%mixing /= 1.0_dp .AND. pao%istep > 1
     227             :       IF (do_mixing) THEN
     228         128 :          CALL dbcsr_copy(matrix_X_mixing, pao%matrix_X)
     229             :       END IF
     230             : 
     231         246 :       cycle_converged = .FALSE.
     232         246 :       icycle = 0
     233         246 :       CALL linesearch_reset(pao%linesearch)
     234         246 :       CALL pao_opt_init(pao)
     235             : 
     236       20056 :       DO WHILE (.TRUE.)
     237       10142 :          pao%istep = pao%istep + 1
     238             : 
     239       15213 :          IF (pao%iw > 0) WRITE (pao%iw, "(A,I9,A)") " PAO| ======================= Iteration: ", &
     240       10142 :             pao%istep, " ============================="
     241             : 
     242             :          ! calc energy and check trace_PS
     243       10142 :          CALL pao_calc_energy(pao, qs_env, ls_scf_env, energy)
     244       10142 :          CALL pao_check_trace_PS(ls_scf_env)
     245             : 
     246       10142 :          IF (pao%linesearch%starts) THEN
     247        2622 :             icycle = icycle + 1
     248             :             ! calc new gradient including penalty terms
     249        2622 :             CALL pao_calc_AB(pao, qs_env, ls_scf_env, gradient=.TRUE., penalty=penalty)
     250        2622 :             CALL pao_check_grad(pao, qs_env, ls_scf_env)
     251             : 
     252             :             ! calculate new direction for line-search
     253        2622 :             CALL pao_opt_new_dir(pao, icycle)
     254             : 
     255             :             !backup X
     256        2622 :             CALL dbcsr_copy(pao%matrix_X_orig, pao%matrix_X)
     257             : 
     258             :             ! print info and convergence test
     259        2622 :             CALL pao_test_convergence(pao, ls_scf_env, energy, cycle_converged)
     260        2622 :             IF (cycle_converged) THEN
     261         212 :                pao_is_done = icycle < 3
     262         212 :                IF (pao_is_done .AND. pao%iw > 0) WRITE (pao%iw, *) "PAO| converged after ", pao%istep, " steps :-)"
     263             :                EXIT
     264             :             END IF
     265             : 
     266             :             ! if we have reached the maximum number of cycles exit in order
     267             :             ! to restart with a fresh hamiltonian
     268        2410 :             IF (icycle >= pao%max_cycles) THEN
     269          16 :                IF (pao%iw > 0) WRITE (pao%iw, *) "PAO| CG not yet converged after ", icycle, " cylces."
     270          16 :                pao_is_done = .FALSE.
     271          16 :                EXIT
     272             :             END IF
     273             : 
     274        2394 :             IF (MOD(icycle, pao%write_cycles) == 0) &
     275           8 :                CALL pao_write_restart(pao, qs_env, energy) ! write an intermediate restart file
     276             :          END IF
     277             : 
     278             :          ! check for early abort without convergence?
     279        9914 :          CALL external_control(should_stop, "PAO", start_time=qs_env%start_time, target_time=qs_env%target_time)
     280        9914 :          IF (should_stop .OR. pao%istep >= pao%max_pao) THEN
     281          18 :             CPWARN("PAO not converged!")
     282          18 :             pao_is_done = .TRUE.
     283          18 :             EXIT
     284             :          END IF
     285             : 
     286             :          ! perform line-search step
     287        9896 :          CALL linesearch_step(pao%linesearch, energy=energy, slope=pao%norm_G**2)
     288             : 
     289        9896 :          IF (pao%linesearch%step_size < 1e-10_dp) CPABORT("PAO gradient is wrong.")
     290             : 
     291        9896 :          CALL dbcsr_copy(pao%matrix_X, pao%matrix_X_orig) !restore X
     292        9896 :          CALL dbcsr_add(pao%matrix_X, pao%matrix_D, 1.0_dp, pao%linesearch%step_size)
     293             :       END DO
     294             : 
     295             :       ! perform mixing of matrix_X
     296         246 :       IF (do_mixing) THEN
     297         128 :          CALL dbcsr_add(pao%matrix_X, matrix_X_mixing, pao%mixing, 1.0_dp - pao%mixing)
     298         128 :          CALL dbcsr_release(matrix_X_mixing)
     299         128 :          IF (pao%iw > 0) WRITE (pao%iw, *) "PAO| Recalculating energy after mixing."
     300         128 :          CALL pao_calc_energy(pao, qs_env, ls_scf_env, energy)
     301             :       END IF
     302             : 
     303         246 :       CALL pao_write_restart(pao, qs_env, energy)
     304         246 :       CALL pao_opt_finalize(pao)
     305             : 
     306         246 :       CALL timestop(handle)
     307         814 :    END SUBROUTINE pao_update
     308             : 
     309             : ! **************************************************************************************************
     310             : !> \brief Calculate PAO forces and store density matrix for future ASPC extrapolations
     311             : !> \param qs_env ...
     312             : !> \param ls_scf_env ...
     313             : !> \param pao_is_done ...
     314             : ! **************************************************************************************************
     315        1092 :    SUBROUTINE pao_post_scf(qs_env, ls_scf_env, pao_is_done)
     316             :       TYPE(qs_environment_type), POINTER                 :: qs_env
     317             :       TYPE(ls_scf_env_type), TARGET                      :: ls_scf_env
     318             :       LOGICAL, INTENT(IN)                                :: pao_is_done
     319             : 
     320             :       CHARACTER(len=*), PARAMETER                        :: routineN = 'pao_post_scf'
     321             : 
     322             :       INTEGER                                            :: handle
     323             : 
     324        1032 :       IF (.NOT. ls_scf_env%do_pao) RETURN
     325         496 :       IF (.NOT. pao_is_done) RETURN
     326             : 
     327         278 :       CALL timeset(routineN, handle)
     328             : 
     329             :       ! print out the matrices here before pao_store_P converts them back into matrices in
     330             :       ! terms of the primary basis
     331         278 :       CALL pao_write_ks_matrix_csr(qs_env, ls_scf_env)
     332         278 :       CALL pao_write_s_matrix_csr(qs_env, ls_scf_env)
     333             : 
     334         278 :       CALL pao_store_P(qs_env, ls_scf_env)
     335         278 :       IF (ls_scf_env%calculate_forces) CALL pao_add_forces(qs_env, ls_scf_env)
     336             : 
     337         278 :       CALL timestop(handle)
     338             :    END SUBROUTINE
     339             : 
     340             : ! **************************************************************************************************
     341             : !> \brief Finish a PAO optimization run.
     342             : !> \param ls_scf_env ...
     343             : ! **************************************************************************************************
     344         874 :    SUBROUTINE pao_optimization_end(ls_scf_env)
     345             :       TYPE(ls_scf_env_type), TARGET                      :: ls_scf_env
     346             : 
     347             :       CHARACTER(len=*), PARAMETER :: routineN = 'pao_optimization_end'
     348             : 
     349             :       INTEGER                                            :: handle
     350             :       TYPE(ls_mstruct_type), POINTER                     :: ls_mstruct
     351             :       TYPE(pao_env_type), POINTER                        :: pao
     352             : 
     353         596 :       IF (.NOT. ls_scf_env%do_pao) RETURN
     354             : 
     355         278 :       pao => ls_scf_env%pao_env
     356         278 :       ls_mstruct => ls_scf_env%ls_mstruct
     357             : 
     358         278 :       CALL timeset(routineN, handle)
     359             : 
     360         278 :       CALL pao_param_finalize(pao)
     361             : 
     362             :       ! We keep pao%matrix_X for next scf-run, e.g. during MD or GEO-OPT
     363         278 :       CALL dbcsr_release(pao%matrix_X_orig)
     364         278 :       CALL dbcsr_release(pao%matrix_G)
     365         278 :       CALL dbcsr_release(ls_mstruct%matrix_A)
     366         278 :       CALL dbcsr_release(ls_mstruct%matrix_B)
     367             : 
     368         278 :       CALL linesearch_finalize(pao%linesearch)
     369             : 
     370         278 :       CALL timestop(handle)
     371             :    END SUBROUTINE pao_optimization_end
     372             : 
     373             : END MODULE pao_main

Generated by: LCOV version 1.15