LCOV - code coverage report
Current view: top level - src - hfx_ace_methods.F (source / functions) Coverage Total Hit
Test: CP2K Regtests (git:06f838d) Lines: 83.2 % 191 159
Test Date: 2026-06-05 07:04:50 Functions: 83.3 % 6 5

            Line data    Source code
       1              : !--------------------------------------------------------------------------------------------------!
       2              : !   CP2K: A general program to perform molecular dynamics simulations                              !
       3              : !   Copyright 2000-2026 CP2K developers group <https://cp2k.org>                                   !
       4              : !                                                                                                  !
       5              : !   SPDX-License-Identifier: GPL-2.0-or-later                                                      !
       6              : !--------------------------------------------------------------------------------------------------!
       7              : 
       8              : ! **************************************************************************************************
       9              : !> \brief Adaptively Compressed Exchange (ACE) operator for HFX.
      10              : !>        Reference: Lin, J. Chem. Theory Comput. 2016, 12, 5, 2242-2249
      11              : !>
      12              : !>  Algorithm (per spin):
      13              : !>
      14              : !>    BUILD  (first call, or every rebuild_frequency steps):
      15              : !>      1. Full HFX: ks_matrix = K_HFX + H_core, energy%ex = E_x(full)
      16              : !>      2. K_AO  = ks_matrix - H_core           (nao x nao, negative semidefinite)
      17              : !>      3. C_occ = first nocc columns of mo_coeff, redistributed to a layout
      18              : !>                 compatible with K_AO so that PDGEMM works correctly
      19              : !>      4. xi    = K_AO * C_occ                 (nao x nocc)
      20              : !>      5. M     = C_occ^T * xi                 (nocc x nocc, negative definite)
      21              : !>      6. -M    = U^T U  via Cholesky          (U upper triangular, stored in M_fm)
      22              : !>      7. W     = xi * U^{-1}                  (nao x nocc, the ACE projector)
      23              : !>      8. Apply (see below) to update ks_matrix and energy%ex
      24              : !>
      25              : !>    APPLY  (all other steps):
      26              : !>      ks_matrix = H_core - W * W^T
      27              : !>      E_x       = -0.5 * Tr[W^T * P * W]
      28              : !>
      29              : !>  Diagnostics (controlled by DBG_STALE / DBG_EXACT_EX module flags):
      30              : !>
      31              : !>    DIAG A  projector staleness  (cheap, runs every APPLY step)
      32              : !>      Computes ||W^T C_occ^current||_F / ||W^T C_occ^BUILD||_F.
      33              : !>      Ratio = 1: W still accurate.  Ratio -> 0: W is stale.
      34              : !>
      35              : !>    DIAG B  exact vs ACE exchange energy  (expensive: one full HFX per APPLY)
      36              : !>      Calls full HFX with just_energy=.TRUE. to get E_x^exact[P^k] and
      37              : !>      compares to E_x^ACE[P^k].  Growing |delta| confirms stale W.
      38              : !>      ACE ks_matrix and energy%ex are restored after the diagnostic.
      39              : !>
      40              : !> \author Ritama Kar
      41              : ! **************************************************************************************************
      42              : 
      43              : MODULE hfx_ace_methods
      44              : 
      45              :    USE admm_types,                      ONLY: admm_type,&
      46              :                                               get_admm_env
      47              :    USE bibliography,                    ONLY: Lin2016ACE,&
      48              :                                               cite_reference
      49              :    USE cp_blacs_env,                    ONLY: cp_blacs_env_type
      50              :    USE cp_control_types,                ONLY: dft_control_type
      51              :    USE cp_dbcsr_api,                    ONLY: dbcsr_add,&
      52              :                                               dbcsr_copy,&
      53              :                                               dbcsr_create,&
      54              :                                               dbcsr_p_type,&
      55              :                                               dbcsr_release,&
      56              :                                               dbcsr_set,&
      57              :                                               dbcsr_type
      58              :    USE cp_dbcsr_operations,             ONLY: copy_dbcsr_to_fm,&
      59              :                                               cp_dbcsr_plus_fm_fm_t
      60              :    USE cp_fm_basic_linalg,              ONLY: cp_fm_scale,&
      61              :                                               cp_fm_trace,&
      62              :                                               cp_fm_triangular_multiply
      63              :    USE cp_fm_cholesky,                  ONLY: cp_fm_cholesky_decompose
      64              :    USE cp_fm_struct,                    ONLY: cp_fm_struct_create,&
      65              :                                               cp_fm_struct_release,&
      66              :                                               cp_fm_struct_type
      67              :    USE cp_fm_types,                     ONLY: cp_fm_create,&
      68              :                                               cp_fm_get_info,&
      69              :                                               cp_fm_release,&
      70              :                                               cp_fm_to_fm,&
      71              :                                               cp_fm_type
      72              :    USE cp_log_handling,                 ONLY: cp_get_default_logger,&
      73              :                                               cp_logger_get_default_io_unit,&
      74              :                                               cp_logger_type
      75              :    USE hfx_admm_utils,                  ONLY: hfx_ks_matrix
      76              :    USE hfx_types,                       ONLY: hfx_type
      77              :    USE input_section_types,             ONLY: section_vals_type
      78              :    USE kinds,                           ONLY: dp
      79              :    USE message_passing,                 ONLY: mp_para_env_type
      80              :    USE parallel_gemm_api,               ONLY: parallel_gemm
      81              :    USE pw_types,                        ONLY: pw_r3d_rs_type
      82              :    USE qs_energy_types,                 ONLY: qs_energy_type
      83              :    USE qs_environment_types,            ONLY: get_qs_env,&
      84              :                                               qs_environment_type
      85              :    USE qs_mo_types,                     ONLY: get_mo_set,&
      86              :                                               mo_set_type
      87              :    USE qs_rho_types,                    ONLY: qs_rho_get,&
      88              :                                               qs_rho_type
      89              :    USE scf_control_types,               ONLY: scf_control_type
      90              : #include "./base/base_uses.f90"
      91              : 
      92              :    IMPLICIT NONE
      93              :    PRIVATE
      94              : 
      95              :    CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'hfx_ace_methods'
      96              : 
      97              :    ! -----------------------------------------------------------------------
      98              :    ! Module-level state: persists across SCF steps within one run.
      99              :    !
     100              :    !   ace_W(1,ispin)   ACE projector W, shape nao x nocc
     101              :    !   ace_is_built     .FALSE. until at least one successful BUILD
     102              :    !   ace_step_counter counts calls since last BUILD
     103              :    !   ace_W_ref_norm   ||W^T C_occ^BUILD||_F stored at BUILD for DIAG A
     104              :    ! -----------------------------------------------------------------------
     105              :    TYPE(cp_fm_type), ALLOCATABLE, SAVE :: ace_W(:, :)
     106              :    LOGICAL, SAVE :: ace_is_built = .FALSE.
     107              :    INTEGER, SAVE :: ace_step_counter = 0
     108              :    REAL(dp), SAVE :: ace_W_ref_norm = 0.0_dp
     109              :    INTEGER, SAVE :: ace_geo_step = 0   ! incremented on every calculate_forces call; 0 triggers the full-HFX bypass in dynamic mode
     110              : 
     111              :    ! -----------------------------------------------------------------------
     112              :    ! Debug / diagnostic flags  — set all .FALSE. for production.
     113              :    !
     114              :    !   DBG_ROUTING   BUILD/APPLY/DEFER decisions, counters
     115              :    !   DBG_BUILD     norms during BUILD
     116              :    !   DBG_ENERGY    E_x(ACE) at every step; DIAG C on BUILD steps
     117              :    !   DBG_STALE     DIAG A: staleness ratio at every APPLY step (cheap)
     118              :    !   DBG_EXACT_EX  DIAG B: full HFX energy at every APPLY step (expensive)
     119              :    ! -----------------------------------------------------------------------
     120              :    LOGICAL, PARAMETER, PRIVATE :: DBG_ROUTING = .FALSE.
     121              :    LOGICAL, PARAMETER, PRIVATE :: DBG_BUILD = .FALSE.
     122              :    LOGICAL, PARAMETER, PRIVATE :: DBG_ENERGY = .FALSE.
     123              :    LOGICAL, PARAMETER, PRIVATE :: DBG_STALE = .FALSE.
     124              :    LOGICAL, PARAMETER, PRIVATE :: DBG_EXACT_EX = .FALSE.
     125              : 
     126              :    LOGICAL, SAVE :: ace_dynamic_mode = .FALSE.
     127              :    ! Set to .TRUE. by hfx_ace_set_dynamic_mode before geo_opt/MD starts.
     128              :    ! Stays .FALSE. for ENERGY/ENERGY_FORCE single-point runs.
     129              : 
     130              :    PUBLIC :: hfx_ace_ks_matrix, hfx_ace_release, hfx_ace_set_dynamic_mode
     131              : 
     132              : CONTAINS
     133              : 
     134              : ! **************************************************************************************************
     135              : !> \brief Main ACE entry point, replacing hfx_ks_matrix in qs_ks_methods.
     136              : !> \param qs_env ...
     137              : !> \param ks_matrix ...
     138              : !> \param rho ...
     139              : !> \param energy ...
     140              : !> \param calculate_forces ...
     141              : !> \param just_energy ...
     142              : !> \param v_rspace_new ...
     143              : !> \param v_tau_rspace ...
     144              : !> \param ace_rebuild_frequency ...
     145              : !> \param ext_xc_section ...
     146              : ! **************************************************************************************************
     147           48 :    SUBROUTINE hfx_ace_ks_matrix(qs_env, ks_matrix, rho, energy, &
     148              :                                 calculate_forces, just_energy, &
     149              :                                 v_rspace_new, v_tau_rspace, &
     150              :                                 ace_rebuild_frequency, ext_xc_section)
     151              : 
     152              :       TYPE(qs_environment_type), POINTER                 :: qs_env
     153              :       TYPE(dbcsr_p_type), DIMENSION(:, :), POINTER       :: ks_matrix
     154              :       TYPE(qs_rho_type), POINTER                         :: rho
     155              :       TYPE(qs_energy_type), POINTER                      :: energy
     156              :       LOGICAL, INTENT(IN)                                :: calculate_forces, just_energy
     157              :       TYPE(pw_r3d_rs_type), DIMENSION(:), POINTER        :: v_rspace_new, v_tau_rspace
     158              :       INTEGER, INTENT(IN)                                :: ace_rebuild_frequency
     159              :       TYPE(section_vals_type), OPTIONAL, POINTER         :: ext_xc_section
     160              : 
     161              :       CHARACTER(LEN=*), PARAMETER                        :: routineN = 'hfx_ace_ks_matrix'
     162              : 
     163              :       INTEGER                                            :: handle, iw, n_rep_hf, nspins, &
     164              :                                                             rebuild_freq
     165              :       LOGICAL                                            :: ace_built_now, rebuild_ace
     166              :       REAL(dp)                                           :: ex_ace
     167              :       TYPE(cp_logger_type), POINTER                      :: logger
     168              :       TYPE(dft_control_type), POINTER                    :: dft_control
     169           48 :       TYPE(hfx_type), DIMENSION(:, :), POINTER           :: x_data
     170              :       TYPE(scf_control_type), POINTER                    :: scf_control
     171              : 
     172           48 :       CALL timeset(routineN, handle)
     173              : 
     174           48 :       CALL cite_reference(Lin2016ACE)
     175           48 :       NULLIFY (logger, dft_control, x_data, scf_control)
     176              : 
     177           48 :       logger => cp_get_default_logger()
     178           48 :       iw = cp_logger_get_default_io_unit(logger)
     179              : 
     180           48 :       CALL get_qs_env(qs_env, x_data=x_data, dft_control=dft_control)
     181           48 :       n_rep_hf = SIZE(x_data, 1)
     182           48 :       nspins = dft_control%nspins
     183              : 
     184           48 :       IF (n_rep_hf /= 1) CPABORT("ACE: only one &HF section is supported.")
     185           48 :       IF (dft_control%nimages /= 1) &
     186            0 :          CPABORT("ACE: k-points / multiple images are not implemented.")
     187              : 
     188              :       ! ACE requires explicit MO coefficients (C_occ) which are only available
     189              :       ! with diagonalization-based SCF.  OT never constructs mo_coeff during
     190              :       ! the SCF, so the ACE projector build loop would silently get garbage.
     191           48 :       CALL get_qs_env(qs_env, scf_control=scf_control)
     192           48 :       IF (scf_control%use_ot) &
     193            0 :          CPABORT("ACE: OT doesn't work, use diagonalization-based SCF.")
     194              : 
     195           48 :       rebuild_freq = MAX(1, ace_rebuild_frequency)
     196              : 
     197              :       ! ------------------------------------------------------------------
     198              :       ! Bypass A: energy-only call
     199              :       ! ------------------------------------------------------------------
     200           48 :       IF (just_energy) THEN
     201              :          IF (DBG_ROUTING .AND. iw > 0) &
     202              :             WRITE (iw, '(T2,A)') 'ACE | just_energy=T: full HFX (no matrix update)'
     203              :          CALL hfx_call(qs_env, ks_matrix, rho, energy, &
     204              :                        calculate_forces, just_energy, &
     205            0 :                        v_rspace_new, v_tau_rspace, ext_xc_section)
     206            0 :          CALL timestop(handle)
     207            8 :          RETURN
     208              :       END IF
     209              : 
     210              :       ! ------------------------------------------------------------------
     211              :       ! Bypass B: ionic forces requested
     212              :       ! ------------------------------------------------------------------
     213           48 :       IF (calculate_forces) THEN
     214              :          IF (DBG_ROUTING .AND. iw > 0) &
     215              :             WRITE (iw, '(T2,A)') 'ACE | calculate_forces=T: full HFX for exact forces'
     216              :          CALL hfx_call(qs_env, ks_matrix, rho, energy, &
     217              :                        calculate_forces, just_energy, &
     218            8 :                        v_rspace_new, v_tau_rspace, ext_xc_section)
     219            8 :          ace_is_built = .FALSE.
     220            8 :          ace_step_counter = 0
     221            8 :          ace_geo_step = ace_geo_step + 1   ! NEW: step 0 done, ACE active from now
     222            8 :          CALL timestop(handle)
     223            8 :          RETURN
     224              :       END IF
     225              : 
     226              :       ! ------------------------------------------------------------------
     227              :       ! Bypass C: first geometry step.
     228              :       !
     229              :       ! The ATOMIC initial guess gives C_occ far from self-consistency,
     230              :       ! which would produce an inaccurate W.  Running full HFX for the
     231              :       ! entire first geometry step ensures that wavefunction extrapolation
     232              :       ! delivers a near-converged C_occ to geometry step 1, making the
     233              :       ! ACE BUILD there accurate from the first application.
     234              :       ! ------------------------------------------------------------------
     235           40 :       IF (ace_geo_step == 0 .AND. ace_dynamic_mode) THEN
     236            0 :          IF (iw > 0) WRITE (iw, '(T2,A)') &
     237            0 :             'ACE | geo_step=0 (MD/GEO_OPT): full HFX for reference wavefunction'
     238              :          CALL hfx_call(qs_env, ks_matrix, rho, energy, &
     239              :                        .FALSE., just_energy, &
     240            0 :                        v_rspace_new, v_tau_rspace, ext_xc_section)
     241            0 :          CALL timestop(handle)
     242            0 :          RETURN
     243              :       END IF
     244              : 
     245              :       ! ------------------------------------------------------------------
     246              :       ! Rebuild decision
     247              :       ! ------------------------------------------------------------------
     248              :       rebuild_ace = (.NOT. ace_is_built) .OR. &
     249           40 :                     (MOD(ace_step_counter, rebuild_freq) == 0)
     250              : 
     251              :       IF (DBG_ROUTING .AND. iw > 0) THEN
     252              :          WRITE (iw, '(/,T2,A)') REPEAT('-', 56)
     253              :          WRITE (iw, '(T2,A)') 'ACE | hfx_ace_ks_matrix'
     254              :          WRITE (iw, '(T4,A,L1)') 'ace_is_built  = ', ace_is_built
     255              :          WRITE (iw, '(T4,A,L1)') 'rebuild_ace   = ', rebuild_ace
     256              :          WRITE (iw, '(T4,A,I6)') 'step_counter  = ', ace_step_counter
     257              :          WRITE (iw, '(T4,A,I6)') 'rebuild_freq  = ', rebuild_freq
     258              :          WRITE (iw, '(T4,A,I4)') 'nspins        = ', nspins
     259              :          WRITE (iw, '(T4,A)') MERGE('-> BUILD', '-> APPLY', rebuild_ace)
     260              :          WRITE (iw, '(T2,A)') REPEAT('-', 56)
     261              :       END IF
     262              : 
     263           10 :       IF (rebuild_ace) THEN
     264              : 
     265              :          ace_built_now = .FALSE.
     266              :          CALL hfx_ace_build_projector(qs_env, ks_matrix, rho, energy, &
     267              :                                       just_energy, &
     268              :                                       v_rspace_new, v_tau_rspace, &
     269              :                                       nspins, iw, ace_built_now, &
     270           10 :                                       ext_xc_section)
     271           10 :          IF (ace_built_now) THEN
     272            8 :             ace_is_built = .TRUE.
     273            8 :             ace_step_counter = 1
     274              :             IF (DBG_ROUTING .AND. iw > 0) &
     275              :                WRITE (iw, '(T4,A)') 'ACE | W built. Projector live from next step.'
     276              :          ELSE
     277            2 :             ace_is_built = .FALSE.
     278            2 :             ace_step_counter = 0
     279              :             IF (DBG_ROUTING .AND. iw > 0) &
     280              :                WRITE (iw, '(T4,A)') 'ACE | Build deferred (C_occ=0). Full HFX in ks_matrix.'
     281              :          END IF
     282              : 
     283              :       ELSE
     284              : 
     285           30 :          CALL hfx_ace_apply_projector(qs_env, ks_matrix, rho, energy, nspins, iw)
     286           30 :          ace_step_counter = ace_step_counter + 1
     287              : 
     288              :          ! ----------------------------------------------------------------
     289              :          ! DIAGNOSTIC B: compare E_x^ACE[P^k] with E_x^exact[P^k].
     290              :          !
     291              :          ! Calls full HFX (just_energy=.TRUE.) to get the exact exchange
     292              :          ! energy at the current ACE-converging density P^k.  Compares
     293              :          ! with E_x^ACE[P^k] already stored in energy%ex.
     294              :          !
     295              :          ! Growing |delta| over the SCF confirms the root cause: W was
     296              :          ! built from C_occ^(step 1), which is far from the converged
     297              :          ! C_occ, so K_ACE = -WW^T no longer represents K_x accurately.
     298              :          !
     299              :          ! After the comparison, hfx_ace_apply_projector is called a
     300              :          ! second time to restore ks_matrix and energy%ex to ACE values
     301              :          ! so the SCF continues correctly.  Cost: +2 full HFX per step.
     302              :          ! ----------------------------------------------------------------
     303              :          IF (DBG_EXACT_EX) THEN
     304              :             ex_ace = energy%ex
     305              : 
     306              :             CALL hfx_call(qs_env, ks_matrix, rho, energy, &
     307              :                           .FALSE., .TRUE., &
     308              :                           v_rspace_new, v_tau_rspace, ext_xc_section)
     309              :             IF (iw > 0) THEN
     310              :                WRITE (iw, '(/,T2,A)') REPEAT('-', 56)
     311              :                WRITE (iw, '(T2,A,I6)') 'ACE DIAG B | ace_step_counter = ', ace_step_counter
     312              :                WRITE (iw, '(T4,A,F20.10)') 'E_x(exact, P^k)  = ', energy%ex
     313              :                WRITE (iw, '(T4,A,F20.10)') 'E_x(ACE,   P^k)  = ', ex_ace
     314              :                WRITE (iw, '(T4,A,ES12.4)') '|delta|          = ', ABS(ex_ace - energy%ex)
     315              :                WRITE (iw, '(T4,A)') &
     316              :                   '|delta|->0 on BUILD step; growth confirms stale projector'
     317              :                WRITE (iw, '(T2,A)') REPEAT('-', 56)
     318              :             END IF
     319              : 
     320              :             ! Restore ACE ks_matrix and energy%ex
     321              :             CALL hfx_ace_apply_projector(qs_env, ks_matrix, rho, energy, nspins, iw)
     322              :          END IF
     323              : 
     324              :       END IF
     325              : 
     326              :       IF (DBG_ROUTING .AND. iw > 0) THEN
     327              :          WRITE (iw, '(T4,A,F20.10)') 'energy%ex on exit = ', energy%ex
     328              :          WRITE (iw, '(T4,A,I6)') 'step_counter now  = ', ace_step_counter
     329              :       END IF
     330              : 
     331           40 :       CALL timestop(handle)
     332              : 
     333           48 :    END SUBROUTINE hfx_ace_ks_matrix
     334              : 
     335              : ! **************************************************************************************************
     336              : !> \brief Build the ACE projector W.
     337              : !> \param qs_env ...
     338              : !> \param ks_matrix ...
     339              : !> \param rho ...
     340              : !> \param energy ...
     341              : !> \param just_energy ...
     342              : !> \param v_rspace_new ...
     343              : !> \param v_tau_rspace ...
     344              : !> \param nspins ...
     345              : !> \param iw ...
     346              : !> \param build_succeeded ...
     347              : !> \param ext_xc_section ...
     348              : ! **************************************************************************************************
     349           10 :    SUBROUTINE hfx_ace_build_projector(qs_env, ks_matrix, rho, energy, &
     350              :                                       just_energy, &
     351              :                                       v_rspace_new, v_tau_rspace, &
     352              :                                       nspins, iw, build_succeeded, &
     353              :                                       ext_xc_section)
     354              : 
     355              :       TYPE(qs_environment_type), POINTER                 :: qs_env
     356              :       TYPE(dbcsr_p_type), DIMENSION(:, :), POINTER       :: ks_matrix
     357              :       TYPE(qs_rho_type), POINTER                         :: rho
     358              :       TYPE(qs_energy_type), POINTER                      :: energy
     359              :       LOGICAL, INTENT(IN)                                :: just_energy
     360              :       TYPE(pw_r3d_rs_type), DIMENSION(:), POINTER        :: v_rspace_new, v_tau_rspace
     361              :       INTEGER, INTENT(IN)                                :: nspins, iw
     362              :       LOGICAL, INTENT(OUT)                               :: build_succeeded
     363              :       TYPE(section_vals_type), OPTIONAL, POINTER         :: ext_xc_section
     364              : 
     365              :       CHARACTER(LEN=*), PARAMETER :: routineN = 'hfx_ace_build_projector'
     366              : 
     367              :       INTEGER                                            :: handle, info_chol, ispin, nao, nmo, nocc
     368              :       LOGICAL                                            :: do_admm
     369              :       REAL(dp)                                           :: ehfx_full, frob
     370           10 :       REAL(dp), DIMENSION(:), POINTER                    :: occ_nums
     371              :       TYPE(admm_type), POINTER                           :: admm_env
     372              :       TYPE(cp_blacs_env_type), POINTER                   :: blacs_env
     373              :       TYPE(cp_fm_struct_type), POINTER                   :: fmstruct
     374              :       TYPE(cp_fm_type)                                   :: A_ref_fm, C_occ_fm, K_fm, M_fm, xi_fm
     375              :       TYPE(cp_fm_type), POINTER                          :: mo_coeff
     376           10 :       TYPE(dbcsr_p_type), DIMENSION(:), POINTER          :: ks_aux_fit
     377           10 :       TYPE(dbcsr_p_type), DIMENSION(:, :), POINTER       :: matrix_h
     378              :       TYPE(dbcsr_type)                                   :: K_ao_dbcsr
     379              :       TYPE(dft_control_type), POINTER                    :: dft_control
     380           10 :       TYPE(mo_set_type), DIMENSION(:), POINTER           :: mos, mos_for_ace
     381              :       TYPE(mp_para_env_type), POINTER                    :: para_env
     382              : 
     383              : ! A_ref_fm: temporary nocc x nocc scratch for DIAG A reference norm
     384              : 
     385           10 :       CALL timeset(routineN, handle)
     386           10 :       NULLIFY (blacs_env, para_env, mos, mo_coeff, matrix_h, occ_nums, fmstruct)
     387           10 :       NULLIFY (dft_control, admm_env, mos_for_ace, ks_aux_fit)
     388              : 
     389           10 :       build_succeeded = .FALSE.
     390              : 
     391              :       CALL get_qs_env(qs_env, blacs_env=blacs_env, para_env=para_env, &
     392              :                       mos=mos, matrix_h_kp=matrix_h, &
     393           10 :                       dft_control=dft_control)
     394              : 
     395           10 :       do_admm = dft_control%do_admm
     396              : 
     397           10 :       IF (do_admm) THEN
     398            6 :          CALL get_qs_env(qs_env, admm_env=admm_env)
     399              :          CALL get_admm_env(admm_env, &
     400              :                            matrix_ks_aux_fit=ks_aux_fit, &
     401            6 :                            mos_aux_fit=mos_for_ace)
     402              :       ELSE
     403            4 :          mos_for_ace => mos
     404              :       END IF
     405              : 
     406              :       ! Step 1: full HFX
     407              :       CALL hfx_call(qs_env, ks_matrix, rho, energy, &
     408              :                     .FALSE., just_energy, &
     409           10 :                     v_rspace_new, v_tau_rspace, ext_xc_section)
     410           10 :       ehfx_full = energy%ex
     411              : 
     412              :       IF (DBG_BUILD .AND. iw > 0) &
     413              :          WRITE (iw, '(/,T2,A,F20.10)') 'ACE BUILD | E_x(full HFX) = ', ehfx_full
     414              : 
     415              :       ! Allocate / resize module storage
     416           10 :       IF (ALLOCATED(ace_W)) THEN
     417            2 :          IF (SIZE(ace_W, 2) /= nspins) CALL hfx_ace_release()
     418              :       END IF
     419           42 :       IF (.NOT. ALLOCATED(ace_W)) ALLOCATE (ace_W(1, nspins))
     420              : 
     421              :       ! Reset reference norm; accumulated per spin in the loop below
     422           10 :       ace_W_ref_norm = 0.0_dp
     423              : 
     424              :       ! ----------------------------------------------------------------
     425              :       ! Per-spin build loop
     426              :       ! ----------------------------------------------------------------
     427           18 :       DO ispin = 1, nspins
     428              : 
     429           10 :          IF (mos_for_ace(ispin)%use_mo_coeff_b) &
     430              :             CALL copy_dbcsr_to_fm(mos_for_ace(ispin)%mo_coeff_b, &
     431            0 :                                   mos_for_ace(ispin)%mo_coeff)
     432              : 
     433              :          CALL get_mo_set(mos_for_ace(ispin), mo_coeff=mo_coeff, &
     434              :                          nao=nao, nmo=nmo, homo=nocc, &
     435           10 :                          occupation_numbers=occ_nums)
     436              : 
     437           10 :          IF (nocc <= 0) CPABORT("ACE: homo <= 0.")
     438           10 :          IF (nocc > nmo) CPABORT("ACE: homo > nmo.")
     439           10 :          CPASSERT(ASSOCIATED(mo_coeff))
     440              : 
     441           10 :          CALL cp_fm_trace(mo_coeff, mo_coeff, frob)
     442              : 
     443              :          IF (DBG_BUILD .AND. iw > 0) THEN
     444              :             WRITE (iw, '(/,T2,A,I4)') 'ACE BUILD | ispin = ', ispin
     445              :             WRITE (iw, '(T4,A,I8)') 'nao  = ', nao
     446              :             WRITE (iw, '(T4,A,I8)') 'nmo  = ', nmo
     447              :             WRITE (iw, '(T4,A,I8)') 'nocc = ', nocc
     448              :             WRITE (iw, '(T4,A,L1)') 'use_mo_coeff_b = ', mos_for_ace(ispin)%use_mo_coeff_b
     449              :             WRITE (iw, '(T4,A,ES12.4)') '||mo_coeff||_F = ', SQRT(MAX(frob, 0.0_dp))
     450              :          END IF
     451              : 
     452           10 :          IF (frob < 1.0e-20_dp) THEN
     453              :             IF (DBG_BUILD .AND. iw > 0) &
     454              :                WRITE (iw, '(T4,A)') 'mo_coeff=0: build deferred to next step.'
     455            2 :             CALL timestop(handle)
     456            2 :             RETURN
     457              :          END IF
     458              : 
     459              :          ! Step 2: K_AO
     460            8 :          IF (do_admm) THEN
     461              :             CALL dbcsr_create(K_ao_dbcsr, template=ks_aux_fit(ispin)%matrix, &
     462            6 :                               name="K_ACE_aux")
     463            6 :             CALL dbcsr_copy(K_ao_dbcsr, ks_aux_fit(ispin)%matrix)
     464              :          ELSE
     465              :             CALL dbcsr_create(K_ao_dbcsr, template=ks_matrix(ispin, 1)%matrix, &
     466            2 :                               name="K_AO")
     467            2 :             CALL dbcsr_copy(K_ao_dbcsr, ks_matrix(ispin, 1)%matrix)
     468            2 :             CALL dbcsr_add(K_ao_dbcsr, matrix_h(1, 1)%matrix, 1.0_dp, -1.0_dp)
     469              :          END IF
     470              : 
     471            8 :          NULLIFY (fmstruct)
     472              :          CALL cp_fm_struct_create(fmstruct, context=blacs_env, para_env=para_env, &
     473            8 :                                   nrow_global=nao, ncol_global=nao)
     474            8 :          CALL cp_fm_create(K_fm, fmstruct, name="K_dense")
     475            8 :          CALL cp_fm_struct_release(fmstruct)
     476            8 :          CALL copy_dbcsr_to_fm(K_ao_dbcsr, K_fm)
     477            8 :          CALL dbcsr_release(K_ao_dbcsr)
     478              : 
     479              :          IF (DBG_BUILD .AND. iw > 0) THEN
     480              :             CALL cp_fm_trace(K_fm, K_fm, frob)
     481              :             WRITE (iw, '(T4,A,ES12.4)') '||K_AO||_F = ', SQRT(MAX(frob, 0.0_dp))
     482              :          END IF
     483              : 
     484              :          ! Step 3: C_occ
     485            8 :          NULLIFY (fmstruct)
     486              :          CALL cp_fm_struct_create(fmstruct, context=blacs_env, para_env=para_env, &
     487            8 :                                   nrow_global=nao, ncol_global=nocc)
     488            8 :          CALL cp_fm_create(C_occ_fm, fmstruct, name="C_occ")
     489            8 :          CALL cp_fm_create(xi_fm, fmstruct, name="xi")
     490            8 :          CALL cp_fm_struct_release(fmstruct)
     491              : 
     492            8 :          CALL cp_fm_to_fm(mo_coeff, C_occ_fm)
     493              : 
     494              :          ! Step 4: xi = K_AO * C_occ
     495              :          CALL parallel_gemm('N', 'N', nao, nocc, nao, &
     496            8 :                             1.0_dp, K_fm, C_occ_fm, 0.0_dp, xi_fm)
     497            8 :          CALL cp_fm_release(K_fm)
     498              : 
     499              :          IF (DBG_BUILD .AND. iw > 0) THEN
     500              :             CALL cp_fm_trace(xi_fm, xi_fm, frob)
     501              :             WRITE (iw, '(T4,A,ES12.4)') '||xi||_F = ', SQRT(MAX(frob, 0.0_dp))
     502              :          END IF
     503              : 
     504              :          ! Step 5: M = C_occ^T * xi
     505            8 :          NULLIFY (fmstruct)
     506              :          CALL cp_fm_struct_create(fmstruct, context=blacs_env, para_env=para_env, &
     507            8 :                                   nrow_global=nocc, ncol_global=nocc)
     508            8 :          CALL cp_fm_create(M_fm, fmstruct, name="M")
     509            8 :          CALL cp_fm_struct_release(fmstruct)
     510              : 
     511              :          CALL parallel_gemm('T', 'N', nocc, nocc, nao, &
     512            8 :                             1.0_dp, C_occ_fm, xi_fm, 0.0_dp, M_fm)
     513            8 :          CALL cp_fm_release(C_occ_fm)
     514              : 
     515              :          IF (DBG_BUILD .AND. iw > 0) THEN
     516              :             CALL cp_fm_trace(M_fm, M_fm, frob)
     517              :             WRITE (iw, '(T4,A,ES12.4)') '||M||_F = ', SQRT(MAX(frob, 0.0_dp))
     518              :          END IF
     519              : 
     520              :          ! Step 6: Cholesky of -M = U^T U
     521            8 :          CALL cp_fm_scale(-1.0_dp, M_fm)
     522            8 :          CALL cp_fm_cholesky_decompose(M_fm, n=nocc, info_out=info_chol)
     523              : 
     524            8 :          IF (info_chol /= 0) THEN
     525            0 :             IF (iw > 0) THEN
     526            0 :                WRITE (iw, '(T4,A,I6)') 'ACE | Cholesky failed, info = ', info_chol
     527            0 :                WRITE (iw, '(T4,A,F20.10)') 'ACE | E_x(full) = ', ehfx_full
     528            0 :                WRITE (iw, '(T4,A,I8,A,I8)') 'ACE | nao=', nao, '  nocc=', nocc
     529              :             END IF
     530            0 :             CPABORT("ACE: Cholesky of -M failed (not positive definite).")
     531              :          END IF
     532              : 
     533              :          IF (DBG_BUILD .AND. iw > 0) &
     534              :             WRITE (iw, '(T4,A)') 'Cholesky OK (info=0).'
     535              : 
     536              :          ! Step 7: W = xi * U^{-1}
     537            8 :          IF (ASSOCIATED(ace_W(1, ispin)%matrix_struct)) &
     538            0 :             CALL cp_fm_release(ace_W(1, ispin))
     539              : 
     540            8 :          CALL cp_fm_create(ace_W(1, ispin), xi_fm%matrix_struct, name="W_ACE")
     541            8 :          CALL cp_fm_to_fm(xi_fm, ace_W(1, ispin))
     542              : 
     543              :          CALL cp_fm_triangular_multiply(M_fm, ace_W(1, ispin), &
     544              :                                         side='R', uplo_tr='U', &
     545              :                                         transpose_tr=.FALSE., &
     546              :                                         invert_tr=.TRUE., &
     547              :                                         n_rows=nao, n_cols=nocc, &
     548            8 :                                         alpha=1.0_dp)
     549              : 
     550            8 :          CALL cp_fm_release(xi_fm)
     551            8 :          CALL cp_fm_release(M_fm)
     552              : 
     553              :          IF (DBG_BUILD .AND. iw > 0) THEN
     554              :             CALL cp_fm_trace(ace_W(1, ispin), ace_W(1, ispin), frob)
     555              :             WRITE (iw, '(T4,A,I4,A,2I8,A,ES12.4)') &
     556              :                'W spin=', ispin, ' shape=', nao, nocc, &
     557              :                '  ||W||_F=', SQRT(MAX(frob, 0.0_dp))
     558              :          END IF
     559              : 
     560              :          ! ----------------------------------------------------------------
     561              :          ! DIAG A reference: compute ||W^T C_occ^BUILD||_F for this spin.
     562              :          !
     563              :          ! C_occ_fm was released after step 5, but mo_coeff is still valid
     564              :          ! (it is a pointer into mos_for_ace, not allocated here).
     565              :          ! We re-create C_occ_fm from mo_coeff.
     566              :          !
     567              :          ! Theory: W^T C_occ^BUILD = U^{-T}(-U^T U) = -U  →  norm = ||U||_F
     568              :          ! (computed directly rather than storing U).
     569              :          ! ----------------------------------------------------------------
     570           60 :          IF (DBG_STALE) THEN
     571              :             NULLIFY (fmstruct)
     572              :             CALL cp_fm_struct_create(fmstruct, context=blacs_env, &
     573              :                                      para_env=para_env, &
     574              :                                      nrow_global=nao, ncol_global=nocc)
     575              :             CALL cp_fm_create(C_occ_fm, fmstruct, name="C_occ_ref_diag")
     576              :             CALL cp_fm_struct_release(fmstruct)
     577              :             CALL cp_fm_to_fm(mo_coeff, C_occ_fm)
     578              : 
     579              :             NULLIFY (fmstruct)
     580              :             CALL cp_fm_struct_create(fmstruct, context=blacs_env, &
     581              :                                      para_env=para_env, &
     582              :                                      nrow_global=nocc, ncol_global=nocc)
     583              :             CALL cp_fm_create(A_ref_fm, fmstruct, name="WtC_ref")
     584              :             CALL cp_fm_struct_release(fmstruct)
     585              : 
     586              :             CALL parallel_gemm('T', 'N', nocc, nocc, nao, &
     587              :                                1.0_dp, ace_W(1, ispin), C_occ_fm, 0.0_dp, A_ref_fm)
     588              :             CALL cp_fm_trace(A_ref_fm, A_ref_fm, frob)
     589              :             ace_W_ref_norm = ace_W_ref_norm + SQRT(MAX(frob, 0.0_dp))
     590              : 
     591              :             CALL cp_fm_release(C_occ_fm)
     592              :             CALL cp_fm_release(A_ref_fm)
     593              :          END IF
     594              : 
     595              :       END DO  ! ispin
     596              : 
     597              :       IF (DBG_STALE .AND. iw > 0) THEN
     598              :          WRITE (iw, '(/,T2,A)') REPEAT('-', 56)
     599              :          WRITE (iw, '(T2,A)') 'ACE DIAG A | Reference norm stored at BUILD'
     600              :          WRITE (iw, '(T4,A,ES12.4)') &
     601              :             '||W^T C_occ^BUILD||_F (sum over spins) = ', ace_W_ref_norm
     602              :          WRITE (iw, '(T4,A)') &
     603              :             'Staleness ratio = 1.0 at BUILD step; decreasing means W is becoming stale'
     604              :          WRITE (iw, '(T2,A)') REPEAT('-', 56)
     605              :       END IF
     606              : 
     607            8 :       build_succeeded = .TRUE.
     608              : 
     609              :       ! Step 8: apply immediately (DIAG C ratio printed via ehfx_full_ref)
     610              :       CALL hfx_ace_apply_projector(qs_env, ks_matrix, rho, energy, &
     611            8 :                                    nspins, iw, ehfx_full_ref=ehfx_full)
     612              : 
     613            8 :       CALL timestop(handle)
     614              : 
     615           10 :    END SUBROUTINE hfx_ace_build_projector
     616              : 
     617              : ! **************************************************************************************************
     618              : !> \brief Apply the stored ACE projector: refill the KS matrices from ace_W and H_core, reset energy%ex.
     619              : !> \param qs_env source of blacs_env, para_env, matrix_h_kp, dft_control (plus admm_env and mos)
     620              : !> \param ks_matrix KS matrices; element (ispin, 1) is zeroed and refilled for every spin
     621              : !> \param rho density whose rho_ao_kp enters the energy trace when ADMM is off
     622              : !> \param energy energy%ex is overwritten with the ACE exchange energy
     623              : !> \param nspins number of spin channels; ace_W(1, 1:nspins) is read
     624              : !> \param iw output unit for the diagnostics; nothing is written unless iw > 0
     625              : !> \param ehfx_full_ref optional full-HFX E_x, printed beside E_x(ACE) when DBG_ENERGY is set
     626              : ! **************************************************************************************************
     627           38 :    SUBROUTINE hfx_ace_apply_projector(qs_env, ks_matrix, rho, energy, &
     628              :                                       nspins, iw, ehfx_full_ref)
     629              : 
     630              :       TYPE(qs_environment_type), POINTER                 :: qs_env
     631              :       TYPE(dbcsr_p_type), DIMENSION(:, :), POINTER       :: ks_matrix
     632              :       TYPE(qs_rho_type), POINTER                         :: rho
     633              :       TYPE(qs_energy_type), POINTER                      :: energy
     634              :       INTEGER, INTENT(IN)                                :: nspins, iw
     635              :       REAL(dp), INTENT(IN), OPTIONAL                     :: ehfx_full_ref
     636              : 
     637              :       CHARACTER(LEN=*), PARAMETER :: routineN = 'hfx_ace_apply_projector'
     638              : 
     639              :       INTEGER                                            :: handle, ispin, nao, nao_d, nmo_d, nocc, &
     640              :                                                             nocc_d
     641              :       LOGICAL                                            :: do_admm
     642              :       REAL(dp)                                           :: ehfx_ace, frob_A, stale_norm, trace_val
     643              :       TYPE(admm_type), POINTER                           :: admm_env
     644              :       TYPE(cp_blacs_env_type), POINTER                   :: blacs_env
     645              :       TYPE(cp_fm_struct_type), POINTER                   :: fmstruct, fmstruct_diag
     646              :       TYPE(cp_fm_type)                                   :: A_diag_fm, C_occ_diag, P_fm, PW_fm
     647              :       TYPE(cp_fm_type), POINTER                          :: mo_coeff_diag
     648           38 :       TYPE(dbcsr_p_type), DIMENSION(:), POINTER          :: ks_aux_fit, ks_aux_fit_hfx
     649           38 :       TYPE(dbcsr_p_type), DIMENSION(:, :), POINTER       :: matrix_h, rho_ao
     650              :       TYPE(dft_control_type), POINTER                    :: dft_control
     651           38 :       TYPE(mo_set_type), DIMENSION(:), POINTER           :: mos_aux_diag, mos_diag
     652              :       TYPE(mp_para_env_type), POINTER                    :: para_env
     653              :       TYPE(qs_rho_type), POINTER                         :: rho_aux_fit
     654              : 
     655              : ! ------------------------------------------------------------------
     656              : ! DIAG A local variables
     657              : !   stale_norm    ||W^T C_occ^current||_F summed over spins
     658              : !   frob_A        scratch for cp_fm_trace
     659              : !   C_occ_diag    current C_occ redistributed to W layout
     660              : !   A_diag_fm     nocc x nocc overlap  W^T C_occ^current
     661              : !   mos_diag      primary mos (non-ADMM path)
     662              : !   mos_aux_diag  auxiliary mos (ADMM path)
     663              : !   mo_coeff_diag pointer to the relevant mo_coeff
     664              : !   nao_d, nmo_d, nocc_d  dimensions from get_mo_set
     665              : ! ------------------------------------------------------------------
     666              : 
     667           38 :       CALL timeset(routineN, handle)
     668           38 :       NULLIFY (blacs_env, para_env, matrix_h, rho_ao, fmstruct)
     669           38 :       NULLIFY (dft_control, admm_env, ks_aux_fit, ks_aux_fit_hfx, rho_aux_fit)
     670           38 :       NULLIFY (mos_diag, mos_aux_diag, mo_coeff_diag, fmstruct_diag)
     671           38 :       CPASSERT(ALLOCATED(ace_W))
     672              :       ! NOTE(review): only allocation is asserted; SIZE(ace_W, 2) >= nspins is assumed by the loops below
     673              :       CALL get_qs_env(qs_env, blacs_env=blacs_env, para_env=para_env, &
     674           38 :                       matrix_h_kp=matrix_h, dft_control=dft_control)
     675           38 :       do_admm = dft_control%do_admm
     676              : 
     677           38 :       IF (do_admm) THEN
     678           30 :          CALL get_qs_env(qs_env, admm_env=admm_env)
     679              :          CALL get_admm_env(admm_env, &
     680              :                            matrix_ks_aux_fit=ks_aux_fit, &
     681              :                            matrix_ks_aux_fit_hfx=ks_aux_fit_hfx, &
     682           30 :                            rho_aux_fit=rho_aux_fit)
     683           30 :          CALL qs_rho_get(rho_aux_fit, rho_ao_kp=rho_ao)
     684              :       ELSE
     685            8 :          CALL qs_rho_get(rho, rho_ao_kp=rho_ao)
     686              :       END IF
     687              : 
     688           76 :       DO ispin = 1, nspins
     689           38 :          CALL cp_fm_get_info(ace_W(1, ispin), nrow_global=nao, ncol_global=nocc)
     690              :          ! H_core always goes to ks_matrix; the alpha = -1 W term (presumably -W*W^T) goes to ks_aux_fit if ADMM
     691           76 :          IF (do_admm) THEN
     692           30 :             CALL dbcsr_set(ks_matrix(ispin, 1)%matrix, 0.0_dp)
     693              :             CALL dbcsr_add(ks_matrix(ispin, 1)%matrix, &
     694           30 :                            matrix_h(1, 1)%matrix, 1.0_dp, 1.0_dp)
     695           30 :             CALL dbcsr_set(ks_aux_fit(ispin)%matrix, 0.0_dp)
     696              :             CALL cp_dbcsr_plus_fm_fm_t( &
     697              :                sparse_matrix=ks_aux_fit(ispin)%matrix, &
     698              :                matrix_v=ace_W(1, ispin), &
     699              :                ncol=nocc, &
     700              :                alpha=-1.0_dp, &
     701           30 :                keep_sparsity=.TRUE.)
     702              :             CALL dbcsr_add(ks_aux_fit_hfx(ispin)%matrix, &
     703           30 :                            ks_aux_fit(ispin)%matrix, 0.0_dp, 1.0_dp)
     704              :          ELSE
     705            8 :             CALL dbcsr_set(ks_matrix(ispin, 1)%matrix, 0.0_dp)
     706              :             CALL cp_dbcsr_plus_fm_fm_t( &
     707              :                sparse_matrix=ks_matrix(ispin, 1)%matrix, &
     708              :                matrix_v=ace_W(1, ispin), &
     709              :                ncol=nocc, &
     710              :                alpha=-1.0_dp, &
     711            8 :                keep_sparsity=.TRUE.)
     712              :             CALL dbcsr_add(ks_matrix(ispin, 1)%matrix, &
     713            8 :                            matrix_h(1, 1)%matrix, 1.0_dp, 1.0_dp)
     714              :          END IF
     715              :       END DO
     716              : 
     717              :       ! Exchange energy: E_x = -0.5 * sum_spin Tr[ W^T * P * W ]
     718              :       ehfx_ace = 0.0_dp
     719           76 :       DO ispin = 1, nspins
     720           38 :          CALL cp_fm_get_info(ace_W(1, ispin), nrow_global=nao, ncol_global=nocc)
     721              : 
     722           38 :          NULLIFY (fmstruct)
     723              :          CALL cp_fm_struct_create(fmstruct, context=blacs_env, para_env=para_env, &
     724           38 :                                   nrow_global=nao, ncol_global=nao)
     725           38 :          CALL cp_fm_create(P_fm, fmstruct, name="P_dense")
     726           38 :          CALL cp_fm_struct_release(fmstruct)
     727           38 :          CALL copy_dbcsr_to_fm(rho_ao(ispin, 1)%matrix, P_fm)
     728              :          ! PW = P*W, so the trace below is Tr[W^T P W] for this spin
     729           38 :          CALL cp_fm_create(PW_fm, ace_W(1, ispin)%matrix_struct, name="PW")
     730              :          CALL parallel_gemm('N', 'N', nao, nocc, nao, &
     731           38 :                             1.0_dp, P_fm, ace_W(1, ispin), 0.0_dp, PW_fm)
     732           38 :          CALL cp_fm_trace(ace_W(1, ispin), PW_fm, trace_val)
     733           38 :          ehfx_ace = ehfx_ace - 0.5_dp*trace_val
     734              : 
     735           38 :          CALL cp_fm_release(P_fm)
     736           38 :          CALL cp_fm_release(PW_fm)
     737              : 
     738              :          IF (DBG_ENERGY .AND. iw > 0) &
     739              :             WRITE (iw, '(T4,A,I4,A,F20.10)') &
     740          190 :             'ispin=', ispin, '  E_x(ACE) += ', -0.5_dp*trace_val
     741              :       END DO
     742              : 
     743           38 :       energy%ex = ehfx_ace
     744              : 
     745              :       ! DIAG C: BUILD-step consistency check (printed when ehfx_full_ref present)
     746              :       IF (DBG_ENERGY .AND. iw > 0) THEN
     747              :          WRITE (iw, '(T2,A,F20.10)') 'ACE | E_x(ACE)  = ', ehfx_ace
     748              :          IF (PRESENT(ehfx_full_ref)) THEN
     749              :             WRITE (iw, '(T2,A,F20.10)') 'ACE | E_x(full) = ', ehfx_full_ref
     750              :             WRITE (iw, '(T2,A,ES12.4)') 'ACE | |delta|   = ', ABS(ehfx_ace - ehfx_full_ref)
     751              :             WRITE (iw, '(T2,A)') '(|delta| should be ~0 on BUILD steps; small is good)'
     752              :          END IF
     753              :       END IF
     754              : 
     755              :       ! ----------------------------------------------------------------
     756              :       ! DIAG A: projector staleness check.
     757              :       !
     758              :       ! Computes ||W^T C_occ^current||_F (summed over spins) and divides
     759              :       ! by ace_W_ref_norm = ||W^T C_occ^BUILD||_F stored at BUILD time.
     760              :       !
     761              :       ! staleness_ratio:
     762              :       !   1.0  → C_occ hasn't changed since BUILD; projector is fresh
     763              :       !   < 1  → C_occ has rotated; how much depends on the SCF dynamics
     764              :       !   → 0  → C_occ is orthogonal to the BUILD-time span; W is useless
     765              :       !
     766              :       ! For non-ADMM: C_occ comes from primary mos (nao_orb x nocc).
     767              :       ! For ADMM:     C_occ comes from mos_aux_fit (nao_aux x nocc_aux),
     768              :       !               consistent with ace_W dimensions.
     769              :       ! ----------------------------------------------------------------
     770              :       IF (DBG_STALE .AND. ace_W_ref_norm > 0.0_dp) THEN
     771              :          stale_norm = 0.0_dp
     772              :          CALL get_qs_env(qs_env, mos=mos_diag)
     773              : 
     774              :          DO ispin = 1, nspins
     775              :             CALL cp_fm_get_info(ace_W(1, ispin), nrow_global=nao_d, ncol_global=nocc_d)
     776              :             ! NOTE(review): nao_d and nocc_d read from W here are overwritten by get_mo_set below
     777              :             IF (do_admm) THEN
     778              :                CALL get_admm_env(admm_env, mos_aux_fit=mos_aux_diag)
     779              :                IF (mos_aux_diag(ispin)%use_mo_coeff_b) &
     780              :                   CALL copy_dbcsr_to_fm(mos_aux_diag(ispin)%mo_coeff_b, &
     781              :                                         mos_aux_diag(ispin)%mo_coeff)
     782              :                CALL get_mo_set(mos_aux_diag(ispin), mo_coeff=mo_coeff_diag, &
     783              :                                nao=nao_d, nmo=nmo_d, homo=nocc_d)
     784              :             ELSE
     785              :                IF (mos_diag(ispin)%use_mo_coeff_b) &
     786              :                   CALL copy_dbcsr_to_fm(mos_diag(ispin)%mo_coeff_b, &
     787              :                                         mos_diag(ispin)%mo_coeff)
     788              :                CALL get_mo_set(mos_diag(ispin), mo_coeff=mo_coeff_diag, &
     789              :                                nao=nao_d, nmo=nmo_d, homo=nocc_d)
     790              :             END IF
     791              :             ! NOTE(review): mo_coeff_diag presumably has nmo_d columns, the target nocc_d; confirm cp_fm_to_fm allows it
     792              :             NULLIFY (fmstruct_diag)
     793              :             CALL cp_fm_struct_create(fmstruct_diag, context=blacs_env, &
     794              :                                      para_env=para_env, &
     795              :                                      nrow_global=nao_d, ncol_global=nocc_d)
     796              :             CALL cp_fm_create(C_occ_diag, fmstruct_diag, name="C_stale")
     797              :             CALL cp_fm_struct_release(fmstruct_diag)
     798              :             CALL cp_fm_to_fm(mo_coeff_diag, C_occ_diag)
     799              : 
     800              :             NULLIFY (fmstruct_diag)
     801              :             CALL cp_fm_struct_create(fmstruct_diag, context=blacs_env, &
     802              :                                      para_env=para_env, &
     803              :                                      nrow_global=nocc_d, ncol_global=nocc_d)
     804              :             CALL cp_fm_create(A_diag_fm, fmstruct_diag, name="WtC_stale")
     805              :             CALL cp_fm_struct_release(fmstruct_diag)
     806              : 
     807              :             CALL parallel_gemm('T', 'N', nocc_d, nocc_d, nao_d, &
     808              :                                1.0_dp, ace_W(1, ispin), C_occ_diag, 0.0_dp, A_diag_fm)
     809              :             CALL cp_fm_trace(A_diag_fm, A_diag_fm, frob_A)
     810              :             stale_norm = stale_norm + SQRT(MAX(frob_A, 0.0_dp))
     811              : 
     812              :             CALL cp_fm_release(C_occ_diag)
     813              :             CALL cp_fm_release(A_diag_fm)
     814              :          END DO
     815              :          IF (iw > 0) THEN
     816              :             WRITE (iw, '(/,T2,A)') REPEAT('-', 56)
     817              :             WRITE (iw, '(T2,A,I6)') 'ACE DIAG A | ace_step_counter      = ', ace_step_counter
     818              :             WRITE (iw, '(T4,A,ES12.4)') '||W^T C_occ^current||_F          = ', stale_norm
     819              :             WRITE (iw, '(T4,A,ES12.4)') '||W^T C_occ^BUILD||_F  (ref)     = ', ace_W_ref_norm
     820              :             WRITE (iw, '(T4,A,F10.6)') 'staleness ratio (1=fresh, 0=stale) = ', &
     821              :                stale_norm/MAX(ace_W_ref_norm, 1.0e-30_dp)
     822              :             WRITE (iw, '(T2,A)') REPEAT('-', 56)
     823              :          END IF
     824              :       END IF
     825              : 
     826           38 :       CALL timestop(handle)
     827              : 
     828           38 :    END SUBROUTINE hfx_ace_apply_projector
     829              : 
     830              : ! **************************************************************************************************
     831              : !> \brief Release all ACE storage (every created ace_W matrix and the array itself) and reset state flags.
     832              : !> \param iw_opt optional output unit; a confirmation line is written only if it is present and > 0
     833              : ! **************************************************************************************************
     834            0 :    SUBROUTINE hfx_ace_release(iw_opt)
     835              : 
     836              :       INTEGER, INTENT(IN), OPTIONAL                      :: iw_opt
     837              : 
     838              :       INTEGER                                            :: i, iw, j
     839              :       ! Default of -1 keeps the routine silent when iw_opt is absent
     840            0 :       iw = -1
     841            0 :       IF (PRESENT(iw_opt)) iw = iw_opt
     842              :       ! Release only entries whose matrix_struct is associated, then drop the container
     843            0 :       IF (ALLOCATED(ace_W)) THEN
     844            0 :          DO j = 1, SIZE(ace_W, 2)
     845            0 :             DO i = 1, SIZE(ace_W, 1)
     846            0 :                IF (ASSOCIATED(ace_W(i, j)%matrix_struct)) CALL cp_fm_release(ace_W(i, j))
     847              :             END DO
     848              :          END DO
     849            0 :          DEALLOCATE (ace_W)
     850              :       END IF
     851              :       ! Reset the ACE state variables (none of them is local to this routine)
     852            0 :       ace_is_built = .FALSE.
     853            0 :       ace_step_counter = 0
     854            0 :       ace_W_ref_norm = 0.0_dp
     855            0 :       ace_geo_step = 0
     856            0 :       ace_dynamic_mode = .FALSE.   ! not sticky: dynamic mode must be re-enabled explicitly for GEO_OPT/MD runs
     857              : 
     858            0 :       IF (iw > 0) WRITE (iw, '(T2,A)') 'ACE | storage released, counters reset'
     859              : 
     860            0 :    END SUBROUTINE hfx_ace_release
     861              : 
     862              : ! **************************************************************************************************
     863              : !> \brief Mark this run as dynamic (GEO_OPT/MD) so Bypass C fires for geo step 0.
     864              : !>        Call this once from the geo_opt or MD driver before the first SCF.
     865              : !>
     866              : !>        The value is stored in ace_dynamic_mode, which is not local to this routine.
     867              : !>        hfx_ace_release resets that flag to .FALSE., so it has to be set again afterwards.
     868              : !>
     869              : !> \param is_dynamic .TRUE. for GEO_OPT/MD, .FALSE. to reset.
     870              : ! **************************************************************************************************
     871         2636 :    SUBROUTINE hfx_ace_set_dynamic_mode(is_dynamic)
     872              :       LOGICAL, INTENT(IN)                                :: is_dynamic
     873              : 
     874         2636 :       ace_dynamic_mode = is_dynamic
     875         2636 :    END SUBROUTINE hfx_ace_set_dynamic_mode
     876              : 
     877              : 
     878              : ! **************************************************************************************************
     879              : !> \brief Private helper: call hfx_ks_matrix with or without ext_xc_section.
     880              : !>
     881              : !>        Every argument is handed to hfx_ks_matrix unchanged. The PRESENT test on ext_xc_section
     882              : !>        only selects which of the two call forms is used; nothing is computed here.
     883              : !>        hfx_ace_build_projector calls this helper with calculate_forces = .FALSE.
     884              : !>
     885              : !> \note  NOTE(review): Fortran lets an absent OPTIONAL actual be passed on to an OPTIONAL dummy,
     886              : !>        so a single call may be enough. Confirm against the hfx_ks_matrix interface before
     887              : !>        simplifying.
     888              : !>
     889              : !> \param qs_env forwarded to hfx_ks_matrix
     890              : !> \param ks_matrix forwarded to hfx_ks_matrix
     891              : !> \param rho forwarded to hfx_ks_matrix
     892              : !> \param energy forwarded to hfx_ks_matrix
     893              : !> \param calculate_forces forwarded to hfx_ks_matrix
     894              : !> \param just_energy forwarded to hfx_ks_matrix
     895              : !> \param v_rspace_new forwarded to hfx_ks_matrix
     896              : !> \param v_tau_rspace forwarded to hfx_ks_matrix
     897              : !> \param ext_xc_section optional; forwarded by keyword when present, left out of the call
     898              : !>        otherwise
     899              : !>
     900              : ! **************************************************************************************************
     901           18 :    SUBROUTINE hfx_call(qs_env, ks_matrix, rho, energy, &
     902              :                        calculate_forces, just_energy, &
     903              :                        v_rspace_new, v_tau_rspace, ext_xc_section)
     904              : 
     905              :       TYPE(qs_environment_type), POINTER                 :: qs_env
     906              :       TYPE(dbcsr_p_type), DIMENSION(:, :), POINTER       :: ks_matrix
     907              :       TYPE(qs_rho_type), POINTER                         :: rho
     908              :       TYPE(qs_energy_type), POINTER                      :: energy
     909              :       LOGICAL, INTENT(IN)                                :: calculate_forces, just_energy
     910              :       TYPE(pw_r3d_rs_type), DIMENSION(:), POINTER        :: v_rspace_new, v_tau_rspace
     911              :       TYPE(section_vals_type), OPTIONAL, POINTER         :: ext_xc_section
     912              : 
     913           18 :       IF (PRESENT(ext_xc_section)) THEN
     914              :          CALL hfx_ks_matrix(qs_env, ks_matrix, rho, energy, &
     915              :                             calculate_forces, just_energy, &
     916              :                             v_rspace_new, v_tau_rspace, &
     917           18 :                             ext_xc_section=ext_xc_section)
     918              :       ELSE
     919              :          CALL hfx_ks_matrix(qs_env, ks_matrix, rho, energy, &
     920              :                             calculate_forces, just_energy, &
     921            0 :                             v_rspace_new, v_tau_rspace)
     922              :       END IF
     923              : 
     924           18 :    END SUBROUTINE hfx_call
     925              : 
     926              : END MODULE hfx_ace_methods
        

Generated by: LCOV version 2.0-1