LCOV - code coverage report
Current view: top level - src - rpa_exchange.F (source / functions) Coverage Total Hit
Test: CP2K Regtests (git:42dac4a) Lines: 97.6 % 373 364
Test Date: 2025-07-25 12:55:17 Functions: 60.0 % 15 9

            Line data    Source code
       1              : !--------------------------------------------------------------------------------------------------!
       2              : !   CP2K: A general program to perform molecular dynamics simulations                              !
       3              : !   Copyright 2000-2025 CP2K developers group <https://cp2k.org>                                   !
       4              : !                                                                                                  !
       5              : !   SPDX-License-Identifier: GPL-2.0-or-later                                                      !
       6              : !--------------------------------------------------------------------------------------------------!
       7              : 
       8              : ! **************************************************************************************************
!> \brief Auxiliary routines needed for the RPA exchange corrections (AXK and SOSEX)
      11              : !> \par History
      12              : !>      09.2016 created [Vladimir Rybkin]
      13              : !>      03.2019 Renamed [Frederick Stein]
      14              : !>      03.2019 Moved Functions from rpa_ri_gpw.F [Frederick Stein]
      15              : !>      04.2024 Added open-shell calculations, SOSEX [Frederick Stein]
      16              : !> \author Vladimir Rybkin
      17              : ! **************************************************************************************************
      18              : MODULE rpa_exchange
      19              :    USE atomic_kind_types,               ONLY: atomic_kind_type
      20              :    USE cell_types,                      ONLY: cell_type
      21              :    USE cp_blacs_env,                    ONLY: cp_blacs_env_type
      22              :    USE cp_control_types,                ONLY: dft_control_type
      23              :    USE cp_dbcsr_api,                    ONLY: &
      24              :         dbcsr_copy, dbcsr_create, dbcsr_get_info, dbcsr_init_p, dbcsr_multiply, dbcsr_p_type, &
      25              :         dbcsr_release, dbcsr_set, dbcsr_type, dbcsr_type_no_symmetry
      26              :    USE cp_dbcsr_contrib,                ONLY: dbcsr_trace
      27              :    USE cp_dbcsr_operations,             ONLY: dbcsr_allocate_matrix_set
      28              :    USE cp_fm_basic_linalg,              ONLY: cp_fm_column_scale
      29              :    USE cp_fm_diag,                      ONLY: choose_eigv_solver
      30              :    USE cp_fm_struct,                    ONLY: cp_fm_struct_create,&
      31              :                                               cp_fm_struct_p_type,&
      32              :                                               cp_fm_struct_release
      33              :    USE cp_fm_types,                     ONLY: cp_fm_create,&
      34              :                                               cp_fm_get_info,&
      35              :                                               cp_fm_release,&
      36              :                                               cp_fm_set_all,&
      37              :                                               cp_fm_to_fm,&
      38              :                                               cp_fm_to_fm_submat_general,&
      39              :                                               cp_fm_type
      40              :    USE group_dist_types,                ONLY: create_group_dist,&
      41              :                                               get_group_dist,&
      42              :                                               group_dist_d1_type,&
      43              :                                               group_dist_proc,&
      44              :                                               maxsize,&
      45              :                                               release_group_dist
      46              :    USE hfx_admm_utils,                  ONLY: tddft_hfx_matrix
      47              :    USE hfx_types,                       ONLY: hfx_create,&
      48              :                                               hfx_release,&
      49              :                                               hfx_type
      50              :    USE input_constants,                 ONLY: rpa_exchange_axk,&
      51              :                                               rpa_exchange_none,&
      52              :                                               rpa_exchange_sosex
      53              :    USE input_section_types,             ONLY: section_vals_get_subs_vals,&
      54              :                                               section_vals_type
      55              :    USE kinds,                           ONLY: dp,&
      56              :                                               int_8
      57              :    USE local_gemm_api,                  ONLY: LOCAL_GEMM_PU_GPU
      58              :    USE mathconstants,                   ONLY: sqrthalf
      59              :    USE message_passing,                 ONLY: mp_para_env_type,&
      60              :                                               mp_proc_null
      61              :    USE mp2_types,                       ONLY: mp2_type
      62              :    USE parallel_gemm_api,               ONLY: parallel_gemm
      63              :    USE particle_types,                  ONLY: particle_type
      64              :    USE qs_environment_types,            ONLY: get_qs_env,&
      65              :                                               qs_environment_type
      66              :    USE qs_kind_types,                   ONLY: qs_kind_type
      67              :    USE qs_subsys_types,                 ONLY: qs_subsys_get,&
      68              :                                               qs_subsys_type
      69              :    USE rpa_communication,               ONLY: gamma_fm_to_dbcsr
      70              :    USE rpa_util,                        ONLY: calc_fm_mat_S_rpa,&
      71              :                                               remove_scaling_factor_rpa
      72              :    USE scf_control_types,               ONLY: scf_control_type
      73              : #include "./base/base_uses.f90"
      74              : 
   IMPLICIT NONE

   ! Everything in this module is private unless listed in the PUBLIC statement below
   PRIVATE

   CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'rpa_exchange'

   ! Only the work type (used through its public type-bound procedures) and the memory estimate are exported
   PUBLIC :: rpa_exchange_work_type, rpa_exchange_needed_mem
      82              : 
   !> Subgroup environment of the HFX-based RPA exchange implementation (module-private type).
   !> Filled by create (hfx_create_subgroup) and emptied by release (hfx_release_subgroup).
   TYPE rpa_exchange_env_type
      PRIVATE
      ! QS environment passed to create (not owned)
      TYPE(qs_environment_type), POINTER             :: qs_env => NULL()
      ! Per-spin HFX matrices, created as nonsymmetric copies of mat_munu
      TYPE(dbcsr_p_type), DIMENSION(:), POINTER      :: mat_hfx => NULL()
      ! Per-spin work matrices created with mat_munu as template, zero-initialized by create
      TYPE(dbcsr_p_type), DIMENSION(:), POINTER      :: dbcsr_Gamma_munu_P => NULL()
      ! Per-spin work matrices created with mo_coeff_o as template, zero-initialized by create
      TYPE(dbcsr_type), ALLOCATABLE, DIMENSION(:)    :: dbcsr_Gamma_inu_P
      ! Workaround GCC 8
      ! (presumably the reason the two sets below are POINTER rather than ALLOCATABLE - TODO confirm)
      ! Taken over from qs_env%mp2_env%ri_rpa%mo_coeff_o/_v by create and released by release
      TYPE(dbcsr_type), DIMENSION(:), POINTER :: mo_coeff_o => NULL()
      TYPE(dbcsr_type), DIMENSION(:), POINTER :: mo_coeff_v => NULL()
      ! Nonsymmetric work matrix created with mat_munu as template
      TYPE(dbcsr_type)                               :: work_ao
      ! HFX data created on the subgroup parallel environment
      TYPE(hfx_type), DIMENSION(:, :), POINTER       :: x_data => NULL()
      ! Subgroup parallel environment (not owned)
      TYPE(mp_para_env_type), POINTER                :: para_env => NULL()
      ! The DFT%XC%WF_CORRELATION%RI_RPA%HF input section
      TYPE(section_vals_type), POINTER               :: hfx_sections => NULL()
      ! Set to .TRUE. by create; presumably requests (re)computation of the HFX integrals - TODO confirm
      LOGICAL :: my_recalc_hfx_integrals = .FALSE.
      ! Filter threshold copied from mp2_env%mp2_gpw%eps_filter
      REAL(KIND=dp) :: eps_filter = 0.0_dp
      ! Per-spin full-matrix structures of size dimen_ia x dimen_RI (the transposed shape of fm_mat_S)
      TYPE(cp_fm_struct_p_type), DIMENSION(:), ALLOCATABLE :: struct_Gamma
   CONTAINS
      PROCEDURE, PASS(exchange_env), NON_OVERRIDABLE :: create => hfx_create_subgroup
      !PROCEDURE, PASS(exchange_env), NON_OVERRIDABLE :: integrate => integrate_exchange
      PROCEDURE, PASS(exchange_env), NON_OVERRIDABLE :: release => hfx_release_subgroup
   END TYPE
     104              : 
   !> Wrapper around an allocatable set of DBCSR matrices.
   !> NOTE(review): not referenced in the visible part of this module - confirm it is still used.
   TYPE dbcsr_matrix_p_set
      TYPE(dbcsr_type), ALLOCATABLE, DIMENSION(:) :: matrix_set
   END TYPE
     108              : 
   !> Work environment of the RPA exchange correction; the public entry type of this module.
   TYPE rpa_exchange_work_type
      PRIVATE
      ! Selected exchange correction; rpa_exchange_none makes create and compute return immediately
      INTEGER :: exchange_correction = rpa_exchange_none
      ! Subgroup environment, only set up when mp2_env%ri_rpa%use_hfx_implementation is true
      TYPE(rpa_exchange_env_type) :: exchange_env
      ! Per-spin numbers of occupied and virtual orbitals, and their product
      INTEGER, DIMENSION(:), ALLOCATABLE :: homo, virtual, dimen_ia
      ! Distribution of the dimen_RI auxiliary functions over the ngroup subgroups
      TYPE(group_dist_d1_type) :: aux_func_dist = group_dist_d1_type()
      ! Number of locally held rows of fm_mat_S(1) that belong to each subgroup, indexed 0:ngroup-1
      INTEGER, DIMENSION(:), ALLOCATABLE :: aux2send
      INTEGER :: dimen_RI = 0
      ! Block size from the input; replaced by dimen_RI if not positive
      INTEGER :: block_size = 0
      ! Index of the subgroup of this rank and the number of subgroups
      INTEGER :: color_sub = 0
      INTEGER :: ngroup = 0
      ! Work matrices: Q_tmp and U use the structure of fm_mat_Q, R_half_gemm that of fm_mat_Q_gemm
      TYPE(cp_fm_type) :: fm_mat_Q_tmp = cp_fm_type()
      TYPE(cp_fm_type) :: fm_mat_R_half_gemm = cp_fm_type()
      TYPE(cp_fm_type) :: fm_mat_U = cp_fm_type()
      ! Subgroup parallel environment (not owned; only nullified on release)
      TYPE(mp_para_env_type), POINTER :: para_env_sub => NULL()
   CONTAINS
      PROCEDURE, PUBLIC, PASS(exchange_work), NON_OVERRIDABLE :: create => rpa_exchange_work_create
      PROCEDURE, PUBLIC, PASS(exchange_work), NON_OVERRIDABLE :: compute => rpa_exchange_work_compute
      PROCEDURE, PUBLIC, PASS(exchange_work), NON_OVERRIDABLE :: release => rpa_exchange_work_release
      PROCEDURE, PRIVATE, PASS(exchange_work), NON_OVERRIDABLE :: redistribute_into_subgroups
      PROCEDURE, PRIVATE, PASS(exchange_work), NON_OVERRIDABLE :: compute_fm => rpa_exchange_work_compute_fm
      PROCEDURE, PRIVATE, PASS(exchange_work), NON_OVERRIDABLE :: compute_hfx => rpa_exchange_work_compute_hfx
   END TYPE
     132              : 
     133              : CONTAINS
     134              : 
     135              : ! **************************************************************************************************
     136              : !> \brief ...
     137              : !> \param mp2_env ...
     138              : !> \param homo ...
     139              : !> \param virtual ...
     140              : !> \param dimen_RI ...
     141              : !> \param para_env ...
     142              : !> \param mem_per_rank ...
     143              : !> \param mem_per_repl ...
     144              : ! **************************************************************************************************
     145          134 :    SUBROUTINE rpa_exchange_needed_mem(mp2_env, homo, virtual, dimen_RI, para_env, mem_per_rank, mem_per_repl)
     146              :       TYPE(mp2_type), INTENT(IN)                         :: mp2_env
     147              :       INTEGER, DIMENSION(:), INTENT(IN)                  :: homo, virtual
     148              :       INTEGER, INTENT(IN)                                :: dimen_RI
     149              :       TYPE(mp_para_env_type), INTENT(IN)                 :: para_env
     150              :       REAL(KIND=dp), INTENT(INOUT)                       :: mem_per_rank, mem_per_repl
     151              : 
     152              :       INTEGER                                            :: block_size
     153              : 
     154              :       ! We need the block size and if it is unknown, an upper bound
     155          134 :       block_size = mp2_env%ri_rpa%exchange_block_size
     156          134 :       IF (block_size <= 0) block_size = MAX(1, (dimen_RI + para_env%num_pe - 1)/para_env%num_pe)
     157              : 
     158              :       ! storage of product matrix (upper bound only as it depends on the square of the potential still unknown block size)
     159          298 :       mem_per_rank = mem_per_rank + REAL(MAXVAL(homo), KIND=dp)**2*block_size**2*8.0_dp/(1024_dp**2)
     160              : 
     161              :       ! work arrays R (2x) and U, copies of Gamma (2x), communication buffer (as expensive as Gamma)
     162              :       mem_per_repl = mem_per_repl + 3.0_dp*dimen_RI*dimen_RI*8.0_dp/(1024_dp**2) &
     163          298 :                      + 3.0_dp*MAXVAL(homo*virtual)*dimen_RI*8.0_dp/(1024_dp**2)
     164          134 :    END SUBROUTINE rpa_exchange_needed_mem
     165              : 
     166              : ! **************************************************************************************************
     167              : !> \brief ...
     168              : !> \param exchange_work ...
     169              : !> \param qs_env ...
     170              : !> \param para_env_sub ...
     171              : !> \param mat_munu ...
     172              : !> \param dimen_RI ...
     173              : !> \param fm_mat_S ...
     174              : !> \param fm_mat_Q ...
     175              : !> \param fm_mat_Q_gemm ...
     176              : !> \param homo ...
     177              : !> \param virtual ...
     178              : ! **************************************************************************************************
     179          138 :    SUBROUTINE rpa_exchange_work_create(exchange_work, qs_env, para_env_sub, mat_munu, dimen_RI, &
     180          138 :                                        fm_mat_S, fm_mat_Q, fm_mat_Q_gemm, homo, virtual)
     181              :       CLASS(rpa_exchange_work_type), INTENT(INOUT) :: exchange_work
     182              :       TYPE(qs_environment_type), POINTER :: qs_env
     183              :       TYPE(mp_para_env_type), POINTER, INTENT(IN) :: para_env_sub
     184              :       TYPE(dbcsr_p_type), INTENT(IN) :: mat_munu
     185              :       INTEGER, INTENT(IN) :: dimen_RI
     186              :       TYPE(cp_fm_type), DIMENSION(:), INTENT(IN) :: fm_mat_S
     187              :       TYPE(cp_fm_type), INTENT(IN) :: fm_mat_Q, fm_mat_Q_gemm
     188              :       INTEGER, DIMENSION(SIZE(fm_mat_S)), INTENT(IN) :: homo, virtual
     189              : 
     190              :       INTEGER :: nspins, aux_global, aux_local, my_process_row, proc, ispin
     191          138 :       INTEGER, DIMENSION(:), POINTER :: row_indices, aux_distribution_fm
     192              :       TYPE(cp_blacs_env_type), POINTER :: context
     193              : 
     194          138 :       exchange_work%exchange_correction = qs_env%mp2_env%ri_rpa%exchange_correction
     195              : 
     196          138 :       IF (exchange_work%exchange_correction == rpa_exchange_none) RETURN
     197              : 
     198              :       ASSOCIATE (para_env => fm_mat_S(1)%matrix_struct%para_env)
     199           12 :          exchange_work%para_env_sub => para_env_sub
     200           12 :          exchange_work%ngroup = para_env%num_pe/para_env_sub%num_pe
     201           12 :          exchange_work%color_sub = para_env%mepos/para_env_sub%num_pe
     202              :       END ASSOCIATE
     203              : 
     204           12 :       CALL cp_fm_get_info(fm_mat_S(1), row_indices=row_indices, nrow_locals=aux_distribution_fm, context=context)
     205           12 :       CALL context%get(my_process_row=my_process_row)
     206              : 
     207           12 :       CALL create_group_dist(exchange_work%aux_func_dist, exchange_work%ngroup, dimen_RI)
     208           36 :       ALLOCATE (exchange_work%aux2send(0:exchange_work%ngroup - 1))
     209           36 :       exchange_work%aux2send = 0
     210          499 :       DO aux_local = 1, aux_distribution_fm(my_process_row)
     211          487 :          aux_global = row_indices(aux_local)
     212          487 :          proc = group_dist_proc(exchange_work%aux_func_dist, aux_global)
     213          499 :          exchange_work%aux2send(proc) = exchange_work%aux2send(proc) + 1
     214              :       END DO
     215              : 
     216           12 :       nspins = SIZE(fm_mat_S)
     217              : 
     218           60 :       ALLOCATE (exchange_work%homo(nspins), exchange_work%virtual(nspins), exchange_work%dimen_ia(nspins))
     219           26 :       exchange_work%homo(:) = homo
     220           26 :       exchange_work%virtual(:) = virtual
     221           26 :       exchange_work%dimen_ia(:) = homo*virtual
     222           12 :       exchange_work%dimen_RI = dimen_RI
     223              : 
     224           12 :       exchange_work%block_size = qs_env%mp2_env%ri_rpa%exchange_block_size
     225           12 :       IF (exchange_work%block_size <= 0) exchange_work%block_size = dimen_RI
     226              : 
     227           12 :       CALL cp_fm_create(exchange_work%fm_mat_U, fm_mat_Q%matrix_struct, name="fm_mat_U")
     228           12 :       CALL cp_fm_create(exchange_work%fm_mat_Q_tmp, fm_mat_Q%matrix_struct, name="fm_mat_Q_tmp")
     229           12 :       CALL cp_fm_create(exchange_work%fm_mat_R_half_gemm, fm_mat_Q_gemm%matrix_struct)
     230              : 
     231           12 :       IF (qs_env%mp2_env%ri_rpa%use_hfx_implementation) THEN
     232            2 :          CALL exchange_work%exchange_env%create(qs_env, mat_munu%matrix, para_env_sub, fm_mat_S)
     233              :       END IF
     234              : 
     235           12 :       IF (ASSOCIATED(qs_env%mp2_env%ri_rpa%mo_coeff_o)) THEN
     236           22 :          DO ispin = 1, SIZE(qs_env%mp2_env%ri_rpa%mo_coeff_o)
     237           22 :             CALL dbcsr_release(qs_env%mp2_env%ri_rpa%mo_coeff_o(ispin))
     238              :          END DO
     239           10 :          DEALLOCATE (qs_env%mp2_env%ri_rpa%mo_coeff_o)
     240              :       END IF
     241              : 
     242           12 :       IF (ASSOCIATED(qs_env%mp2_env%ri_rpa%mo_coeff_v)) THEN
     243           22 :          DO ispin = 1, SIZE(qs_env%mp2_env%ri_rpa%mo_coeff_v)
     244           22 :             CALL dbcsr_release(qs_env%mp2_env%ri_rpa%mo_coeff_v(ispin))
     245              :          END DO
     246           10 :          DEALLOCATE (qs_env%mp2_env%ri_rpa%mo_coeff_v)
     247              :       END IF
     248          138 :    END SUBROUTINE
     249              : 
     250              : ! **************************************************************************************************
     251              : !> \brief ... Initializes x_data on a subgroup
     252              : !> \param exchange_env ...
     253              : !> \param qs_env ...
     254              : !> \param mat_munu ...
     255              : !> \param para_env_sub ...
     256              : !> \param fm_mat_S ...
     257              : !> \author Vladimir Rybkin
     258              : ! **************************************************************************************************
     259            2 :    SUBROUTINE hfx_create_subgroup(exchange_env, qs_env, mat_munu, para_env_sub, fm_mat_S)
     260              :       CLASS(rpa_exchange_env_type), INTENT(INOUT) :: exchange_env
     261              :       TYPE(dbcsr_type), INTENT(IN) :: mat_munu
     262              :       TYPE(qs_environment_type), POINTER   :: qs_env
     263              :       TYPE(mp_para_env_type), POINTER, INTENT(IN)            :: para_env_sub
     264              :       TYPE(cp_fm_type), DIMENSION(:), INTENT(IN) :: fm_mat_S
     265              : 
     266              :       CHARACTER(LEN=*), PARAMETER :: routineN = 'hfx_create_subgroup'
     267              : 
     268              :       INTEGER                                            :: handle, nelectron_total, ispin, &
     269              :                                                             number_of_aos, nspins, dimen_RI, dimen_ia
     270            2 :       TYPE(atomic_kind_type), DIMENSION(:), POINTER      :: atomic_kind_set
     271              :       TYPE(cell_type), POINTER                           :: my_cell
     272              :       TYPE(dft_control_type), POINTER                    :: dft_control
     273            2 :       TYPE(particle_type), DIMENSION(:), POINTER         :: particle_set
     274            2 :       TYPE(qs_kind_type), DIMENSION(:), POINTER          :: qs_kind_set
     275              :       TYPE(qs_subsys_type), POINTER                      :: subsys
     276              :       TYPE(scf_control_type), POINTER                    :: scf_control
     277              :       TYPE(section_vals_type), POINTER                   :: input
     278              : 
     279            2 :       CALL timeset(routineN, handle)
     280              : 
     281            2 :       exchange_env%mo_coeff_o => qs_env%mp2_env%ri_rpa%mo_coeff_o
     282            2 :       exchange_env%mo_coeff_v => qs_env%mp2_env%ri_rpa%mo_coeff_v
     283            2 :       NULLIFY (qs_env%mp2_env%ri_rpa%mo_coeff_o, qs_env%mp2_env%ri_rpa%mo_coeff_v)
     284              : 
     285            2 :       nspins = SIZE(exchange_env%mo_coeff_o)
     286              : 
     287            2 :       exchange_env%qs_env => qs_env
     288            2 :       exchange_env%para_env => para_env_sub
     289            2 :       exchange_env%eps_filter = qs_env%mp2_env%mp2_gpw%eps_filter
     290              : 
     291            2 :       NULLIFY (my_cell, atomic_kind_set, particle_set, dft_control, qs_kind_set, scf_control)
     292              : 
     293              :       CALL get_qs_env(qs_env, &
     294              :                       subsys=subsys, &
     295              :                       input=input, &
     296              :                       scf_control=scf_control, &
     297            2 :                       nelectron_total=nelectron_total)
     298              : 
     299              :       CALL qs_subsys_get(subsys, &
     300              :                          cell=my_cell, &
     301              :                          atomic_kind_set=atomic_kind_set, &
     302              :                          qs_kind_set=qs_kind_set, &
     303            2 :                          particle_set=particle_set)
     304              : 
     305            2 :       exchange_env%hfx_sections => section_vals_get_subs_vals(input, "DFT%XC%WF_CORRELATION%RI_RPA%HF")
     306            2 :       CALL get_qs_env(qs_env, dft_control=dft_control)
     307              : 
     308              :       ! Retrieve particle_set and atomic_kind_set
     309              :       CALL hfx_create(exchange_env%x_data, para_env_sub, exchange_env%hfx_sections, atomic_kind_set, &
     310              :                       qs_kind_set, particle_set, dft_control, my_cell, orb_basis='ORB', &
     311            2 :                       nelectron_total=nelectron_total)
     312              : 
     313            2 :       exchange_env%my_recalc_hfx_integrals = .TRUE.
     314              : 
     315            2 :       CALL dbcsr_allocate_matrix_set(exchange_env%mat_hfx, nspins)
     316            4 :       DO ispin = 1, nspins
     317            2 :          ALLOCATE (exchange_env%mat_hfx(ispin)%matrix)
     318            2 :          CALL dbcsr_init_p(exchange_env%mat_hfx(ispin)%matrix)
     319              :          CALL dbcsr_create(exchange_env%mat_hfx(ispin)%matrix, template=mat_munu, &
     320            2 :                            matrix_type=dbcsr_type_no_symmetry)
     321            4 :          CALL dbcsr_copy(exchange_env%mat_hfx(ispin)%matrix, mat_munu)
     322              :       END DO
     323              : 
     324            2 :       CALL dbcsr_get_info(mat_munu, nfullcols_total=number_of_aos)
     325              : 
     326              :       CALL dbcsr_create(exchange_env%work_ao, template=mat_munu, &
     327            2 :                         matrix_type=dbcsr_type_no_symmetry)
     328              : 
     329            8 :       ALLOCATE (exchange_env%dbcsr_Gamma_inu_P(nspins))
     330            2 :       CALL dbcsr_allocate_matrix_set(exchange_env%dbcsr_Gamma_munu_P, nspins)
     331            4 :       DO ispin = 1, nspins
     332            2 :          ALLOCATE (exchange_env%dbcsr_Gamma_munu_P(ispin)%matrix)
     333              :          CALL dbcsr_create(exchange_env%dbcsr_Gamma_munu_P(ispin)%matrix, template=mat_munu, &
     334            2 :                            matrix_type=dbcsr_type_no_symmetry)
     335            2 :          CALL dbcsr_copy(exchange_env%dbcsr_Gamma_munu_P(ispin)%matrix, mat_munu)
     336            2 :          CALL dbcsr_set(exchange_env%dbcsr_Gamma_munu_P(ispin)%matrix, 0.0_dp)
     337              : 
     338            2 :          CALL dbcsr_create(exchange_env%dbcsr_Gamma_inu_P(ispin), template=exchange_env%mo_coeff_o(ispin))
     339            2 :          CALL dbcsr_copy(exchange_env%dbcsr_Gamma_inu_P(ispin), exchange_env%mo_coeff_o(ispin))
     340            4 :          CALL dbcsr_set(exchange_env%dbcsr_Gamma_inu_P(ispin), 0.0_dp)
     341              :       END DO
     342              : 
     343            8 :       ALLOCATE (exchange_env%struct_Gamma(nspins))
     344            4 :       DO ispin = 1, nspins
     345            2 :          CALL cp_fm_get_info(fm_mat_S(ispin), nrow_global=dimen_RI, ncol_global=dimen_ia)
     346              :          CALL cp_fm_struct_create(exchange_env%struct_Gamma(ispin)%struct, template_fmstruct=fm_mat_S(ispin)%matrix_struct, &
     347            4 :                                   nrow_global=dimen_ia, ncol_global=dimen_RI)
     348              :       END DO
     349              : 
     350            2 :       CALL timestop(handle)
     351              : 
     352            4 :    END SUBROUTINE hfx_create_subgroup
     353              : 
     354              : ! **************************************************************************************************
     355              : !> \brief ...
     356              : !> \param exchange_work ...
     357              : ! **************************************************************************************************
     358          138 :    SUBROUTINE rpa_exchange_work_release(exchange_work)
     359              :       CLASS(rpa_exchange_work_type), INTENT(INOUT) :: exchange_work
     360              : 
     361          138 :       IF (ALLOCATED(exchange_work%homo)) DEALLOCATE (exchange_work%homo)
     362          138 :       IF (ALLOCATED(exchange_work%virtual)) DEALLOCATE (exchange_work%virtual)
     363          138 :       IF (ALLOCATED(exchange_work%dimen_ia)) DEALLOCATE (exchange_work%dimen_ia)
     364          138 :       NULLIFY (exchange_work%para_env_sub)
     365          138 :       CALL release_group_dist(exchange_work%aux_func_dist)
     366          138 :       IF (ALLOCATED(exchange_work%aux2send)) DEALLOCATE (exchange_work%aux2send)
     367          138 :       CALL cp_fm_release(exchange_work%fm_mat_Q_tmp)
     368          138 :       CALL cp_fm_release(exchange_work%fm_mat_U)
     369          138 :       CALL cp_fm_release(exchange_work%fm_mat_R_half_gemm)
     370              : 
     371          138 :       CALL exchange_work%exchange_env%release()
     372          138 :    END SUBROUTINE
     373              : 
     374              : ! **************************************************************************************************
     375              : !> \brief ...
     376              : !> \param exchange_env ...
     377              : ! **************************************************************************************************
     378          138 :    SUBROUTINE hfx_release_subgroup(exchange_env)
     379              :       CLASS(rpa_exchange_env_type), INTENT(INOUT) :: exchange_env
     380              : 
     381              :       INTEGER :: ispin
     382              : 
     383          138 :       NULLIFY (exchange_env%para_env, exchange_env%hfx_sections)
     384              : 
     385          138 :       IF (ASSOCIATED(exchange_env%x_data)) THEN
     386            2 :          CALL hfx_release(exchange_env%x_data)
     387            2 :          NULLIFY (exchange_env%x_data)
     388              :       END IF
     389              : 
     390          138 :       CALL dbcsr_release(exchange_env%work_ao)
     391              : 
     392          138 :       IF (ASSOCIATED(exchange_env%dbcsr_Gamma_munu_P)) THEN
     393            4 :          DO ispin = 1, SIZE(exchange_env%mat_hfx, 1)
     394            2 :             CALL dbcsr_release(exchange_env%dbcsr_Gamma_munu_P(ispin)%matrix)
     395            2 :             CALL dbcsr_release(exchange_env%mat_hfx(ispin)%matrix)
     396            2 :             CALL dbcsr_release(exchange_env%dbcsr_Gamma_inu_P(ispin))
     397            2 :             CALL dbcsr_release(exchange_env%mo_coeff_o(ispin))
     398            2 :             CALL dbcsr_release(exchange_env%mo_coeff_v(ispin))
     399            2 :             DEALLOCATE (exchange_env%mat_hfx(ispin)%matrix)
     400            4 :             DEALLOCATE (exchange_env%dbcsr_Gamma_munu_P(ispin)%matrix)
     401              :          END DO
     402            2 :          DEALLOCATE (exchange_env%mat_hfx, exchange_env%dbcsr_Gamma_munu_P)
     403            2 :          DEALLOCATE (exchange_env%dbcsr_Gamma_inu_P, exchange_env%mo_coeff_o, exchange_env%mo_coeff_v)
     404            2 :          NULLIFY (exchange_env%mat_hfx, exchange_env%dbcsr_Gamma_munu_P)
     405              :       END IF
     406          138 :       IF (ALLOCATED(exchange_env%struct_Gamma)) THEN
     407            4 :       DO ispin = 1, SIZE(exchange_env%struct_Gamma)
     408            4 :          CALL cp_fm_struct_release(exchange_env%struct_Gamma(ispin)%struct)
     409              :       END DO
     410            2 :       DEALLOCATE (exchange_env%struct_Gamma)
     411              :       END IF
     412          138 :    END SUBROUTINE hfx_release_subgroup
     413              : 
     414              : ! **************************************************************************************************
     415              : !> \brief Main driver for RPA-exchange energies
     416              : !> \param exchange_work ...
     417              : !> \param fm_mat_Q ...
     418              : !> \param eig ...
     419              : !> \param fm_mat_S ...
     420              : !> \param omega ...
     421              : !> \param e_exchange_corr exchange energy correction for a quadrature point
     422              : !> \param mp2_env ...
     423              : !> \author Vladimir Rybkin, 07/2016
     424              : ! **************************************************************************************************
     425           12 :    SUBROUTINE rpa_exchange_work_compute(exchange_work, fm_mat_Q, eig, fm_mat_S, omega, &
     426              :                                         e_exchange_corr, mp2_env)
     427              :       CLASS(rpa_exchange_work_type), INTENT(INOUT) :: exchange_work
     428              :       TYPE(cp_fm_type), INTENT(IN)                       :: fm_mat_Q
     429              :       REAL(KIND=dp), DIMENSION(:, :), INTENT(IN)         :: eig
     430              :       TYPE(cp_fm_type), DIMENSION(:), INTENT(INOUT)         :: fm_mat_S
     431              :       REAL(KIND=dp), INTENT(IN)                          :: omega
     432              :       REAL(KIND=dp), INTENT(INOUT)                       :: e_exchange_corr
     433              :       TYPE(mp2_type), INTENT(INOUT) :: mp2_env
     434              : 
     435              :       CHARACTER(LEN=*), PARAMETER                        :: routineN = 'rpa_exchange_work_compute'
     436              :       REAL(KIND=dp), PARAMETER                           :: thresh = 0.0000001_dp
     437              : 
     438              :       INTEGER :: handle, nspins, dimen_RI, iiB
     439           12 :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:)           :: eigenval
     440              : 
     441           12 :       IF (exchange_work%exchange_correction == rpa_exchange_none) RETURN
     442              : 
     443           12 :       CALL timeset(routineN, handle)
     444              : 
     445           12 :       CALL cp_fm_get_info(fm_mat_Q, ncol_global=dimen_RI)
     446              : 
     447           12 :       nspins = SIZE(fm_mat_S)
     448              : 
     449              :       ! Eigenvalues
     450           36 :       ALLOCATE (eigenval(dimen_RI))
     451          986 :       eigenval = 0.0_dp
     452              : 
     453           12 :       CALL cp_fm_set_all(matrix=exchange_work%fm_mat_Q_tmp, alpha=0.0_dp)
     454           12 :       CALL cp_fm_set_all(matrix=exchange_work%fm_mat_U, alpha=0.0_dp)
     455              : 
     456              :       ! Copy Q to Q_tmp
     457           12 :       CALL cp_fm_to_fm(fm_mat_Q, exchange_work%fm_mat_Q_tmp)
     458              :       ! Diagonalize Q
     459           12 :       CALL choose_eigv_solver(exchange_work%fm_mat_Q_tmp, exchange_work%fm_mat_U, eigenval)
     460              : 
     461              :       ! Calculate diagonal matrix for R_half
     462              : 
     463              :       ! Manipulate eigenvalues to get diagonal matrix
     464           12 :       IF (exchange_work%exchange_correction == rpa_exchange_axk) THEN
     465          818 :          DO iib = 1, dimen_RI
     466          818 :             IF (ABS(eigenval(iib)) .GE. thresh) THEN
     467              :                eigenval(iib) = &
     468              :                   SQRT((1.0_dp/(eigenval(iib)**2))*LOG(1.0_dp + eigenval(iib)) &
     469          718 :                        - 1.0_dp/(eigenval(iib)*(eigenval(iib) + 1.0_dp)))
     470              :             ELSE
     471           90 :                eigenval(iib) = sqrthalf
     472              :             END IF
     473              :          END DO
     474            2 :       ELSE IF (exchange_work%exchange_correction == rpa_exchange_sosex) THEN
     475          168 :          DO iib = 1, dimen_RI
     476          168 :             IF (ABS(eigenval(iib)) .GE. thresh) THEN
     477              :                eigenval(iib) = &
     478              :                   SQRT(-(1.0_dp/(eigenval(iib)**2))*LOG(1.0_dp + eigenval(iib)) &
     479          144 :                        + 1.0_dp/eigenval(iib))
     480              :             ELSE
     481           22 :                eigenval(iib) = sqrthalf
     482              :             END IF
     483              :          END DO
     484              :       ELSE
     485            0 :          CPABORT("Unknown RPA exchange correction")
     486              :       END IF
     487              : 
     488              :       ! fm_mat_U now contains some sqrt of the required matrix-valued function
     489           12 :       CALL cp_fm_column_scale(exchange_work%fm_mat_U, eigenval)
     490              : 
     491              :       ! Release memory
     492           12 :       DEALLOCATE (eigenval)
     493              : 
     494              :       ! Redistribute fm_mat_U for "rectangular" multiplication: ia*P P*P
     495           12 :       CALL cp_fm_set_all(matrix=exchange_work%fm_mat_R_half_gemm, alpha=0.0_dp)
     496              : 
     497              :       CALL cp_fm_to_fm_submat_general(exchange_work%fm_mat_U, exchange_work%fm_mat_R_half_gemm, dimen_RI, &
     498           12 :                                       dimen_RI, 1, 1, 1, 1, exchange_work%fm_mat_U%matrix_struct%context)
     499              : 
     500           12 :       IF (mp2_env%ri_rpa%use_hfx_implementation) THEN
     501            2 :          CALL exchange_work%compute_hfx(fm_mat_S, eig, omega, e_exchange_corr)
     502              :       ELSE
     503           10 :          CALL exchange_work%compute_fm(fm_mat_S, eig, omega, e_exchange_corr, mp2_env)
     504              :       END IF
     505              : 
     506           12 :       CALL timestop(handle)
     507              : 
     508           12 :    END SUBROUTINE rpa_exchange_work_compute
     509              : 
     510              : ! **************************************************************************************************
     511              : !> \brief Main driver for RPA-exchange energies
     512              : !> \param exchange_work ...
     513              : !> \param fm_mat_S ...
     514              : !> \param eig ...
     515              : !> \param omega ...
     516              : !> \param e_exchange_corr exchange energy correction for a quadrature point
     517              : !> \param mp2_env ...
     518              : !> \author Frederick Stein, May-June 2024
     519              : ! **************************************************************************************************
     520           10 :    SUBROUTINE rpa_exchange_work_compute_fm(exchange_work, fm_mat_S, eig, omega, &
     521              :                                            e_exchange_corr, mp2_env)
     522              :       CLASS(rpa_exchange_work_type), INTENT(INOUT) :: exchange_work
     523              :       TYPE(cp_fm_type), DIMENSION(:), INTENT(IN)         :: fm_mat_S
     524              :       REAL(KIND=dp), DIMENSION(:, :), INTENT(IN)         :: eig
     525              :       REAL(KIND=dp), INTENT(IN)                          :: omega
     526              :       REAL(KIND=dp), INTENT(INOUT)                       :: e_exchange_corr
     527              :       TYPE(mp2_type), INTENT(INOUT) :: mp2_env
     528              : 
     529              :       CHARACTER(LEN=*), PARAMETER                        :: routineN = 'rpa_exchange_work_compute_fm'
     530              : 
     531              :       INTEGER :: handle, ispin, nspins, P, Q, L_size_Gamma, hom, virt, i, &
     532              :                  send_proc, recv_proc, recv_size, max_aux_size, proc_shift, dimen_ia, &
     533              :                  block_size, P_start, P_end, P_size, Q_start, Q_size, Q_end, handle2, my_aux_size, my_virt
     534           10 :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :, :), TARGET :: mat_Gamma_3_3D
     535           10 :       REAL(KIND=dp), POINTER, DIMENSION(:), CONTIGUOUS :: mat_Gamma_3_1D
     536           10 :       REAL(KIND=dp), POINTER, DIMENSION(:, :), CONTIGUOUS :: mat_Gamma_3_2D
     537           10 :       REAL(KIND=dp), ALLOCATABLE, TARGET, DIMENSION(:) :: recv_buffer_1D
     538           10 :       REAL(KIND=dp), POINTER, DIMENSION(:, :), CONTIGUOUS :: recv_buffer_2D
     539           10 :       REAL(KIND=dp), POINTER, DIMENSION(:, :, :), CONTIGUOUS :: recv_buffer_3D
     540           10 :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :, :) :: mat_B_iaP
     541           10 :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:), TARGET :: product_matrix_1D
     542           10 :       REAL(KIND=dp), POINTER, DIMENSION(:, :), CONTIGUOUS :: product_matrix_2D
     543           10 :       REAL(KIND=dp), POINTER, DIMENSION(:, :, :, :), CONTIGUOUS :: product_matrix_4D
     544              :       TYPE(cp_fm_type)        :: fm_mat_Gamma_3
     545              :       TYPE(mp_para_env_type), POINTER :: para_env
     546           10 :       TYPE(group_dist_d1_type)                           :: virt_dist
     547              : 
     548           10 :       CALL timeset(routineN, handle)
     549              : 
     550           10 :       nspins = SIZE(fm_mat_S)
     551              : 
     552           10 :       CALL get_group_dist(exchange_work%aux_func_dist, exchange_work%color_sub, sizes=my_aux_size)
     553              : 
     554           10 :       e_exchange_corr = 0.0_dp
     555           10 :       max_aux_size = maxsize(exchange_work%aux_func_dist)
     556              : 
     557              :       ! local_gemm_ctx has a very large footprint the first time this routine is
     558              :       ! called.
     559           10 :       CALL mp2_env%local_gemm_ctx%create(LOCAL_GEMM_PU_GPU)
     560           10 :       CALL mp2_env%local_gemm_ctx%set_op_threshold_gpu(128*128*128*2)
     561              : 
     562           22 :       DO ispin = 1, nspins
     563           12 :          hom = exchange_work%homo(ispin)
     564           12 :          virt = exchange_work%virtual(ispin)
     565           12 :          dimen_ia = hom*virt
     566           12 :          IF (hom < 1 .OR. virt < 1) CYCLE
     567              : 
     568           12 :          CALL cp_fm_get_info(fm_mat_S(ispin), para_env=para_env)
     569              : 
     570           12 :          CALL cp_fm_create(fm_mat_Gamma_3, fm_mat_S(ispin)%matrix_struct)
     571           12 :          CALL cp_fm_set_all(matrix=fm_mat_Gamma_3, alpha=0.0_dp)
     572              : 
     573              :          ! Update G with a new value of Omega: in practice, it is G*S
     574              : 
     575              :          ! Scale fm_work_iaP
     576              :          CALL calc_fm_mat_S_rpa(fm_mat_S(ispin), .TRUE., virt, eig(:, ispin), &
     577           12 :                                 hom, omega, 0.0_dp)
     578              : 
     579              :          ! Calculate Gamma_3: Gamma_3 = G*S*R^(1/2) = G*S*R^(1/2)
     580              :          CALL parallel_gemm(transa="T", transb="N", m=exchange_work%dimen_RI, n=dimen_ia, k=exchange_work%dimen_RI, alpha=1.0_dp, &
     581              :                             matrix_a=exchange_work%fm_mat_R_half_gemm, matrix_b=fm_mat_S(ispin), beta=0.0_dp, &
     582           12 :                             matrix_c=fm_mat_Gamma_3)
     583              : 
     584           12 :          CALL create_group_dist(virt_dist, exchange_work%para_env_sub%num_pe, virt)
     585              : 
     586              :          ! Remove extra factor from S after the multiplication (to return to the original matrix)
     587           12 :          CALL remove_scaling_factor_rpa(fm_mat_S(ispin), virt, eig(:, ispin), hom, omega)
     588              : 
     589           12 :          CALL exchange_work%redistribute_into_subgroups(fm_mat_Gamma_3, mat_Gamma_3_3D, ispin, virt_dist)
     590           12 :          CALL cp_fm_release(fm_mat_Gamma_3)
     591              : 
     592              :          ! We need only the pure matrix
     593           12 :          CALL remove_scaling_factor_rpa(fm_mat_S(ispin), virt, eig(:, ispin), hom, omega)
     594              : 
     595              :          ! Reorder matrix from (P, i*a) -> (a, i, P) with P being distributed within subgroups
     596           12 :          CALL exchange_work%redistribute_into_subgroups(fm_mat_S(ispin), mat_B_iaP, ispin, virt_dist)
     597              : 
     598              :          ! Return to the original tensor
     599           12 :          CALL calc_fm_mat_S_rpa(fm_mat_S(ispin), .TRUE., virt, eig(:, ispin), hom, omega, 0.0_dp)
     600              : 
     601           12 :          L_size_Gamma = SIZE(mat_Gamma_3_3D, 3)
     602           12 :          my_virt = SIZE(mat_Gamma_3_3D, 1)
     603           12 :          block_size = exchange_work%block_size
     604              : 
     605           12 :          mat_Gamma_3_1D(1:INT(my_virt, KIND=int_8)*hom*my_aux_size) => mat_Gamma_3_3D(:, :, 1:my_aux_size)
     606           12 :          mat_Gamma_3_2D(1:my_virt, 1:hom*my_aux_size) => mat_Gamma_3_1D(1:INT(my_virt, KIND=int_8)*hom*my_aux_size)
     607              : 
     608            0 :          ALLOCATE (product_matrix_1D(INT(hom*MIN(block_size, L_size_gamma), KIND=int_8)* &
     609           36 :                                      INT(hom*MIN(block_size, max_aux_size), KIND=int_8)))
     610           36 :          ALLOCATE (recv_buffer_1D(INT(virt, KIND=int_8)*hom*max_aux_size))
     611           12 :          recv_buffer_2D(1:my_virt, 1:hom*max_aux_size) => recv_buffer_1D(1:INT(virt, KIND=int_8)*hom*max_aux_size)
     612           12 :          recv_buffer_3D(1:my_virt, 1:hom, 1:max_aux_size) => recv_buffer_1D(1:INT(virt, KIND=int_8)*hom*max_aux_size)
     613           36 :          DO proc_shift = 0, para_env%num_pe - 1, exchange_work%para_env_sub%num_pe
     614           24 :             send_proc = MODULO(para_env%mepos + proc_shift, para_env%num_pe)
     615           24 :             recv_proc = MODULO(para_env%mepos - proc_shift, para_env%num_pe)
     616              : 
     617           24 :             CALL get_group_dist(exchange_work%aux_func_dist, recv_proc/exchange_work%para_env_sub%num_pe, sizes=recv_size)
     618              : 
     619           24 :             IF (recv_size == 0) recv_proc = mp_proc_null
     620              : 
     621           24 :             CALL para_env%sendrecv(mat_B_iaP, send_proc, recv_buffer_3D(:, :, 1:recv_size), recv_proc)
     622              : 
     623           24 :             IF (recv_size == 0) CYCLE
     624              : 
     625         1038 :             DO P_start = 1, L_size_Gamma, block_size
     626         1002 :                P_end = MIN(L_size_Gamma, P_start + block_size - 1)
     627         1002 :                P_size = P_end - P_start + 1
     628        43875 :                DO Q_start = 1, recv_size, block_size
     629        42849 :                   Q_end = MIN(recv_size, Q_start + block_size - 1)
     630        42849 :                   Q_size = Q_end - Q_start + 1
     631              : 
     632              :                   ! Reassign product_matrix pointers to enforce contiguity of target array
     633              :                   product_matrix_2D(1:hom*P_size, 1:hom*Q_size) => &
     634        42849 :                      product_matrix_1D(1:INT(hom*P_size, KIND=int_8)*INT(hom*Q_size, KIND=int_8))
     635              :                   product_matrix_4D(1:hom, 1:P_size, 1:hom, 1:Q_size) => &
     636        42849 :                      product_matrix_1D(1:INT(hom*P_size, KIND=int_8)*INT(hom*Q_size, KIND=int_8))
     637              : 
     638        42849 :                   CALL timeset(routineN//"_gemm", handle2)
     639              :                   CALL mp2_env%local_gemm_ctx%gemm("T", "N", hom*P_size, hom*Q_size, my_virt, 1.0_dp, &
     640              :                                                    mat_Gamma_3_2D(:, hom*(P_start - 1) + 1:hom*P_end), my_virt, &
     641              :                                                    recv_buffer_2D(:, hom*(Q_start - 1) + 1:hom*Q_end), my_virt, &
     642        42849 :                                                    0.0_dp, product_matrix_2D, hom*P_size)
     643        42849 :                   CALL timestop(handle2)
     644              : 
     645        42849 :                   CALL timeset(routineN//"_energy", handle2)
     646              : !$OMP PARALLEL DO DEFAULT(NONE) SHARED(P_size, Q_size, hom, product_matrix_4D) &
     647        42849 : !$OMP             COLLAPSE(3) REDUCTION(+: e_exchange_corr) PRIVATE(P, Q, i)
     648              :                   DO P = 1, P_size
     649              :                   DO Q = 1, Q_size
     650              :                   DO i = 1, hom
     651              :                      e_exchange_corr = e_exchange_corr + DOT_PRODUCT(product_matrix_4D(i, P, :, Q), product_matrix_4D(:, P, i, Q))
     652              :                   END DO
     653              :                   END DO
     654              :                   END DO
     655        86700 :                   CALL timestop(handle2)
     656              :                END DO
     657              :             END DO
     658              :          END DO
     659              : 
     660           12 :          CALL release_group_dist(virt_dist)
     661           12 :          IF (ALLOCATED(mat_B_iaP)) DEALLOCATE (mat_B_iaP)
     662           12 :          IF (ALLOCATED(mat_Gamma_3_3D)) DEALLOCATE (mat_Gamma_3_3D)
     663           12 :          IF (ALLOCATED(product_matrix_1D)) DEALLOCATE (product_matrix_1D)
     664           58 :          IF (ALLOCATED(recv_buffer_1D)) DEALLOCATE (recv_buffer_1D)
     665              :       END DO
     666              : 
     667           10 :       CALL mp2_env%local_gemm_ctx%destroy()
     668              : 
     669           10 :       IF (nspins == 2) e_exchange_corr = e_exchange_corr*2.0_dp
     670           10 :       IF (nspins == 1) e_exchange_corr = e_exchange_corr*4.0_dp
     671              : 
     672           10 :       CALL timestop(handle)
     673              : 
     674           20 :    END SUBROUTINE rpa_exchange_work_compute_fm
     675              : 
     676              : ! **************************************************************************************************
     677              : !> \brief Contract RPA-exchange density matrix with HF exchange integrals and evaluate the correction
     678              : !> \param exchange_work ...
     679              : !> \param fm_mat_S ...
     680              : !> \param eig ...
     681              : !> \param omega ...
     682              : !> \param e_exchange_corr ...
     683              : !> \author Vladimir Rybkin, 08/2016
     684              : ! **************************************************************************************************
     685            2 :    SUBROUTINE rpa_exchange_work_compute_hfx(exchange_work, fm_mat_S, eig, omega, e_exchange_corr)
     686              :       CLASS(rpa_exchange_work_type), INTENT(INOUT) :: exchange_work
     687              :       TYPE(cp_fm_type), DIMENSION(:), INTENT(INOUT) :: fm_mat_S
     688              :       REAL(KIND=dp), DIMENSION(:, :), INTENT(IN)         :: eig
     689              :       REAL(KIND=dp), INTENT(IN)                          :: omega
     690              :       REAL(KIND=dp), INTENT(OUT) :: e_exchange_corr
     691              : 
     692              :       CHARACTER(LEN=*), PARAMETER :: routineN = 'rpa_exchange_work_compute_hfx'
     693              : 
     694              :       INTEGER                                            :: handle, ispin, my_aux_start, my_aux_end, &
     695              :                                                             my_aux_size, nspins, L_counter, dimen_ia, hom, virt
     696              :       REAL(KIND=dp)                                      :: e_exchange_P
     697            2 :       TYPE(dbcsr_matrix_p_set), DIMENSION(:), ALLOCATABLE          :: dbcsr_Gamma_3
     698              :       TYPE(cp_fm_type) :: fm_mat_Gamma_3
     699              :       TYPE(mp_para_env_type), POINTER :: para_env
     700              : 
     701            2 :       CALL timeset(routineN, handle)
     702              : 
     703            2 :       e_exchange_corr = 0.0_dp
     704              : 
     705            2 :       nspins = SIZE(fm_mat_S)
     706              : 
     707            2 :       CALL get_group_dist(exchange_work%aux_func_dist, exchange_work%color_sub, my_aux_start, my_aux_end, my_aux_size)
     708              : 
     709            8 :       ALLOCATE (dbcsr_Gamma_3(nspins))
     710            4 :       DO ispin = 1, nspins
     711            2 :          hom = exchange_work%homo(ispin)
     712            2 :          virt = exchange_work%virtual(ispin)
     713            2 :          dimen_ia = hom*virt
     714            2 :          IF (hom < 1 .OR. virt < 1) CYCLE
     715              : 
     716            2 :          CALL cp_fm_get_info(fm_mat_S(ispin), para_env=para_env)
     717              : 
     718            2 :          CALL cp_fm_create(fm_mat_Gamma_3, exchange_work%exchange_env%struct_Gamma(ispin)%struct)
     719            2 :          CALL cp_fm_set_all(matrix=fm_mat_Gamma_3, alpha=0.0_dp)
     720              : 
     721              :          ! Update G with a new value of Omega: in practice, it is G*S
     722              : 
     723              :          ! Scale fm_work_iaP
     724              :          CALL calc_fm_mat_S_rpa(fm_mat_S(ispin), .TRUE., virt, eig(:, ispin), &
     725            2 :                                 hom, omega, 0.0_dp)
     726              : 
     727              :          ! Calculate Gamma_3: Gamma_3 = G*S*R^(1/2) = G*S*R^(1/2)
     728              :          CALL parallel_gemm(transa="T", transb="N", m=dimen_ia, n=exchange_work%dimen_RI, &
     729              :                             k=exchange_work%dimen_RI, alpha=1.0_dp, &
     730              :                             matrix_a=fm_mat_S(ispin), matrix_b=exchange_work%fm_mat_R_half_gemm, beta=0.0_dp, &
     731            2 :                             matrix_c=fm_mat_Gamma_3)
     732              : 
     733              :          ! Remove extra factor from S after the multiplication (to return to the original matrix)
     734            2 :          CALL remove_scaling_factor_rpa(fm_mat_S(ispin), virt, eig(:, ispin), hom, omega)
     735              : 
     736              :          ! Copy Gamma_ia_P^3 to dbcsr matrix set
     737              :          CALL gamma_fm_to_dbcsr(fm_mat_Gamma_3, dbcsr_Gamma_3(ispin)%matrix_set, &
     738              :                                 para_env, exchange_work%para_env_sub, hom, virt, &
     739              :                                 exchange_work%exchange_env%mo_coeff_o(ispin), &
     740            6 :                                 exchange_work%ngroup, my_aux_start, my_aux_end, my_aux_size)
     741              :       END DO
     742              : 
     743           85 :       DO L_counter = 1, my_aux_size
     744          166 :          DO ispin = 1, nspins
     745              :             ! Do dbcsr multiplication: transform the virtual index
     746              :             CALL dbcsr_multiply("N", "T", 1.0_dp, exchange_work%exchange_env%mo_coeff_v(ispin), &
     747              :                                 dbcsr_Gamma_3(ispin)%matrix_set(L_counter), &
     748              :                                 0.0_dp, exchange_work%exchange_env%dbcsr_Gamma_inu_P(ispin), &
     749           83 :                                 filter_eps=exchange_work%exchange_env%eps_filter)
     750              : 
     751           83 :             CALL dbcsr_release(dbcsr_Gamma_3(ispin)%matrix_set(L_counter))
     752              : 
     753              :             ! Do dbcsr multiplication: transform the occupied index
     754              :             CALL dbcsr_multiply("N", "T", 0.5_dp, exchange_work%exchange_env%dbcsr_Gamma_inu_P(ispin), &
     755              :                                 exchange_work%exchange_env%mo_coeff_o(ispin), &
     756              :                                 0.0_dp, exchange_work%exchange_env%dbcsr_Gamma_munu_P(ispin)%matrix, &
     757           83 :                                 filter_eps=exchange_work%exchange_env%eps_filter)
     758              :             CALL dbcsr_multiply("N", "T", 0.5_dp, exchange_work%exchange_env%mo_coeff_o(ispin), &
     759              :                                 exchange_work%exchange_env%dbcsr_Gamma_inu_P(ispin), &
     760              :                                 1.0_dp, exchange_work%exchange_env%dbcsr_Gamma_munu_P(ispin)%matrix, &
     761           83 :                                 filter_eps=exchange_work%exchange_env%eps_filter)
     762              : 
     763          166 :             CALL dbcsr_set(exchange_work%exchange_env%mat_hfx(ispin)%matrix, 0.0_dp)
     764              :          END DO
     765              : 
     766              :          CALL tddft_hfx_matrix(exchange_work%exchange_env%mat_hfx, exchange_work%exchange_env%dbcsr_Gamma_munu_P, &
     767              :                                exchange_work%exchange_env%qs_env, .FALSE., &
     768              :                                exchange_work%exchange_env%my_recalc_hfx_integrals, &
     769              :                                exchange_work%exchange_env%hfx_sections, exchange_work%exchange_env%x_data, &
     770           83 :                                exchange_work%exchange_env%para_env)
     771              : 
     772           83 :          exchange_work%exchange_env%my_recalc_hfx_integrals = .FALSE.
     773          168 :          DO ispin = 1, nspins
     774              :             CALL dbcsr_multiply("N", "T", 1.0_dp, exchange_work%exchange_env%mat_hfx(ispin)%matrix, &
     775              :                                 exchange_work%exchange_env%dbcsr_Gamma_munu_P(ispin)%matrix, &
     776           83 :                                 0.0_dp, exchange_work%exchange_env%work_ao, filter_eps=exchange_work%exchange_env%eps_filter)
     777           83 :             CALL dbcsr_trace(exchange_work%exchange_env%work_ao, e_exchange_P)
     778          166 :             e_exchange_corr = e_exchange_corr - e_exchange_P
     779              :          END DO
     780              :       END DO
     781              : 
     782              :       IF (nspins == 2) e_exchange_corr = e_exchange_corr
     783            2 :       IF (nspins == 1) e_exchange_corr = e_exchange_corr*4.0_dp
     784              : 
     785            2 :       CALL timestop(handle)
     786              : 
     787            6 :    END SUBROUTINE rpa_exchange_work_compute_hfx
     788              : 
     789              : ! **************************************************************************************************
     790              : !> \brief ...
     791              : !> \param exchange_work ...
     792              : !> \param fm_mat ...
     793              : !> \param mat ...
     794              : !> \param ispin ...
     795              : !> \param virt_dist ...
     796              : ! **************************************************************************************************
     797           24 :    SUBROUTINE redistribute_into_subgroups(exchange_work, fm_mat, mat, ispin, virt_dist)
     798              :       CLASS(rpa_exchange_work_type), INTENT(IN) :: exchange_work
     799              :       TYPE(cp_fm_type), INTENT(IN)                       :: fm_mat
     800              :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :, :), &
     801              :          INTENT(OUT)                                     :: mat
     802              :       INTEGER, INTENT(IN)                                :: ispin
     803              :       TYPE(group_dist_d1_type), INTENT(IN)               :: virt_dist
     804              : 
     805              :       CHARACTER(LEN=*), PARAMETER :: routineN = 'redistribute_into_subgroups'
     806              : 
     807              :       INTEGER :: aux_counter, aux_global, aux_local, aux_proc, avirt, dimen_RI, handle, handle2, &
     808              :                  ia_global, ia_local, iocc, max_number_recv, max_number_send, my_aux_end, my_aux_size, &
     809              :                  my_aux_start, my_process_column, my_process_row, my_virt_end, my_virt_size, &
     810              :                  my_virt_start, proc, proc_shift, recv_proc, send_proc, virt_counter, virt_proc, group_size
     811           24 :       INTEGER, ALLOCATABLE, DIMENSION(:) :: data2send, recv_col_indices, &
     812           24 :                                             recv_row_indices, send_aux_indices, send_virt_indices, virt2send
     813              :       INTEGER, DIMENSION(2)                              :: recv_shape
     814           24 :       INTEGER, DIMENSION(:), POINTER                     :: aux_distribution_fm, col_indices, &
     815           24 :                                                             ia_distribution_fm, row_indices
     816           24 :       INTEGER, DIMENSION(:, :), POINTER                  :: mpi2blacs
     817           24 :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:), TARGET   :: recv_buffer, send_buffer
     818              :       REAL(KIND=dp), CONTIGUOUS, DIMENSION(:, :), &
     819           24 :          POINTER                                         :: recv_ptr, send_ptr
     820              :       TYPE(cp_blacs_env_type), POINTER                   :: context
     821              :       TYPE(mp_para_env_type), POINTER                    :: para_env
     822              : 
     823           24 :       CALL timeset(routineN, handle)
     824              : 
     825              :       CALL cp_fm_get_info(matrix=fm_mat, &
     826              :                           nrow_locals=aux_distribution_fm, &
     827              :                           col_indices=col_indices, &
     828              :                           row_indices=row_indices, &
     829              :                           ncol_locals=ia_distribution_fm, &
     830              :                           context=context, &
     831              :                           nrow_global=dimen_RI, &
     832           24 :                           para_env=para_env)
     833              : 
     834           24 :       IF (exchange_work%homo(ispin) <= 0 .OR. exchange_work%virtual(ispin) <= 0) THEN
     835            0 :          CALL get_group_dist(virt_dist, exchange_work%para_env_sub%mepos, my_virt_start, my_virt_end, my_virt_size)
     836            0 :          ALLOCATE (mat(exchange_work%homo(ispin), my_virt_size, dimen_RI))
     837            0 :          CALL timestop(handle)
     838            0 :          RETURN
     839              :       END IF
     840              : 
     841           24 :       group_size = exchange_work%para_env_sub%num_pe
     842              : 
     843           24 :       CALL timeset(routineN//"_prep", handle2)
     844           24 :       CALL get_group_dist(exchange_work%aux_func_dist, exchange_work%color_sub, my_aux_start, my_aux_end, my_aux_size)
     845           24 :       CALL get_group_dist(virt_dist, exchange_work%para_env_sub%mepos, my_virt_start, my_virt_end, my_virt_size)
     846           24 :       CALL context%get(my_process_column=my_process_column, my_process_row=my_process_row, mpi2blacs=mpi2blacs)
     847              : 
     848              :       ! Determine the number of columns to send
     849          120 :       ALLOCATE (send_aux_indices(MAXVAL(exchange_work%aux2send)))
     850           72 :       ALLOCATE (virt2send(0:group_size - 1))
     851           48 :       virt2send = 0
     852         1924 :       DO ia_local = 1, ia_distribution_fm(my_process_column)
     853         1900 :          ia_global = col_indices(ia_local)
     854         1900 :          avirt = MOD(ia_global - 1, exchange_work%virtual(ispin)) + 1
     855         1900 :          proc = group_dist_proc(virt_dist, avirt)
     856         1924 :          virt2send(proc) = virt2send(proc) + 1
     857              :       END DO
     858              : 
     859           72 :       ALLOCATE (data2send(0:para_env%num_pe - 1))
     860           72 :       DO aux_proc = 0, exchange_work%ngroup - 1
     861          120 :       DO virt_proc = 0, group_size - 1
     862           96 :          data2send(aux_proc*group_size + virt_proc) = exchange_work%aux2send(aux_proc)*virt2send(virt_proc)
     863              :       END DO
     864              :       END DO
     865              : 
     866           96 :       ALLOCATE (send_virt_indices(MAXVAL(virt2send)))
     867           72 :       max_number_send = MAXVAL(data2send)
     868              : 
     869           72 :       ALLOCATE (send_buffer(INT(max_number_send, KIND=int_8)*exchange_work%homo(ispin)))
     870           24 :       max_number_recv = max_number_send
     871           24 :       CALL para_env%max(max_number_recv)
     872           72 :       ALLOCATE (recv_buffer(max_number_recv))
     873              : 
     874          120 :       ALLOCATE (mat(my_virt_size, exchange_work%homo(ispin), my_aux_size))
     875              : 
     876           24 :       CALL timestop(handle2)
     877              : 
     878           24 :       CALL timeset(routineN//"_own", handle2)
     879              :       ! Start with own data
     880         1026 :       DO aux_local = 1, aux_distribution_fm(my_process_row)
     881         1002 :          aux_global = row_indices(aux_local)
     882         1002 :          IF (aux_global < my_aux_start .OR. aux_global > my_aux_end) CYCLE
     883        40848 :          DO ia_local = 1, ia_distribution_fm(my_process_column)
     884        40318 :             ia_global = fm_mat%matrix_struct%col_indices(ia_local)
     885              : 
     886        40318 :             iocc = (ia_global - 1)/exchange_work%virtual(ispin) + 1
     887        40318 :             avirt = MOD(ia_global - 1, exchange_work%virtual(ispin)) + 1
     888              : 
     889        40318 :             IF (my_virt_start > avirt .OR. my_virt_end < avirt) CYCLE
     890              : 
     891        41320 :             mat(avirt - my_virt_start + 1, iocc, aux_global - my_aux_start + 1) = fm_mat%local_data(aux_local, ia_local)
     892              :          END DO
     893              :       END DO
     894           24 :       CALL timestop(handle2)
     895              : 
     896           48 :       DO proc_shift = 1, para_env%num_pe - 1
     897           24 :          send_proc = MODULO(para_env%mepos + proc_shift, para_env%num_pe)
     898           24 :          recv_proc = MODULO(para_env%mepos - proc_shift, para_env%num_pe)
     899              : 
     900           24 :          CALL timeset(routineN//"_pack_buffer", handle2)
     901              :          send_ptr(1:virt2send(MOD(send_proc, group_size)), &
     902              :                   1:exchange_work%aux2send(send_proc/group_size)) => &
     903              :             send_buffer(1:INT(virt2send(MOD(send_proc, group_size)), KIND=int_8)* &
     904           24 :                         exchange_work%aux2send(send_proc/group_size))
     905              : ! Pack send buffer
     906           24 :          aux_counter = 0
     907         1026 :          DO aux_local = 1, aux_distribution_fm(my_process_row)
     908         1002 :             aux_global = row_indices(aux_local)
     909         1002 :             proc = group_dist_proc(exchange_work%aux_func_dist, aux_global)
     910         1002 :             IF (proc /= send_proc/group_size) CYCLE
     911          496 :             aux_counter = aux_counter + 1
     912          496 :             virt_counter = 0
     913        40016 :             DO ia_local = 1, ia_distribution_fm(my_process_column)
     914        39520 :                ia_global = col_indices(ia_local)
     915        39520 :                avirt = MOD(ia_global - 1, exchange_work%virtual(ispin)) + 1
     916              : 
     917        39520 :                proc = group_dist_proc(virt_dist, avirt)
     918        39520 :                IF (proc /= MOD(send_proc, group_size)) CYCLE
     919        39520 :                virt_counter = virt_counter + 1
     920        39520 :                send_ptr(virt_counter, aux_counter) = fm_mat%local_data(aux_local, ia_local)
     921        40016 :                send_virt_indices(virt_counter) = ia_global
     922              :             END DO
     923         1026 :             send_aux_indices(aux_counter) = aux_global
     924              :          END DO
     925           24 :          CALL timestop(handle2)
     926              : 
     927           24 :          CALL timeset(routineN//"_ex_size", handle2)
     928           24 :          recv_shape = [1, 1]
     929           72 :          CALL para_env%sendrecv(SHAPE(send_ptr), send_proc, recv_shape, recv_proc)
     930           24 :          CALL timestop(handle2)
     931              : 
     932           72 :          IF (SIZE(send_ptr) == 0) send_proc = mp_proc_null
     933           72 :          IF (PRODUCT(recv_shape) == 0) recv_proc = mp_proc_null
     934              : 
     935           24 :          CALL timeset(routineN//"_ex_idx", handle2)
     936          120 :          ALLOCATE (recv_row_indices(recv_shape(1)), recv_col_indices(recv_shape(2)))
     937           24 :          CALL para_env%sendrecv(send_virt_indices(1:virt_counter), send_proc, recv_row_indices, recv_proc)
     938           24 :          CALL para_env%sendrecv(send_aux_indices(1:aux_counter), send_proc, recv_col_indices, recv_proc)
     939           24 :          CALL timestop(handle2)
     940              : 
     941              :          ! Prepare pointer to recv buffer (consider transposition while packing the send buffer)
     942           24 :          recv_ptr(1:recv_shape(1), 1:MAX(1, recv_shape(2))) => recv_buffer(1:recv_shape(1)*MAX(1, recv_shape(2)))
     943              : 
     944           24 :          CALL timeset(routineN//"_sendrecv", handle2)
     945              : ! Perform communication
     946           24 :          CALL para_env%sendrecv(send_ptr, send_proc, recv_ptr, recv_proc)
     947           24 :          CALL timestop(handle2)
     948              : 
     949           24 :          IF (recv_proc == mp_proc_null) THEN
     950            0 :             DEALLOCATE (recv_row_indices, recv_col_indices)
     951            0 :             CYCLE
     952              :          END IF
     953              : 
     954           24 :          CALL timeset(routineN//"_unpack", handle2)
     955              : ! Unpack receive buffer
     956          520 :          DO aux_local = 1, SIZE(recv_col_indices)
     957          496 :             aux_global = recv_col_indices(aux_local)
     958              : 
     959        40040 :             DO ia_local = 1, SIZE(recv_row_indices)
     960        39520 :                ia_global = recv_row_indices(ia_local)
     961              : 
     962        39520 :                iocc = (ia_global - 1)/exchange_work%virtual(ispin) + 1
     963        39520 :                avirt = MOD(ia_global - 1, exchange_work%virtual(ispin)) + 1
     964              : 
     965        40016 :                mat(avirt - my_virt_start + 1, iocc, aux_global - my_aux_start + 1) = recv_ptr(ia_local, aux_local)
     966              :             END DO
     967              :          END DO
     968           24 :          CALL timestop(handle2)
     969              : 
     970           24 :          IF (ALLOCATED(recv_row_indices)) DEALLOCATE (recv_row_indices)
     971          168 :          IF (ALLOCATED(recv_col_indices)) DEALLOCATE (recv_col_indices)
     972              :       END DO
     973              : 
     974           24 :       DEALLOCATE (send_aux_indices, send_virt_indices)
     975              : 
     976           24 :       CALL timestop(handle)
     977              : 
     978           96 :    END SUBROUTINE redistribute_into_subgroups
     979              : 
     980            0 : END MODULE rpa_exchange
        

Generated by: LCOV version 2.0-1