LCOV - code coverage report
Current view: top level - src - hfx_ri_kp.F (source / functions) Hit Total Coverage
Test: CP2K Regtests (git:b1f098b) Lines: 2160 2190 98.6 %
Date: 2024-05-05 06:30:09 Functions: 38 40 95.0 %

          Line data    Source code
       1             : !--------------------------------------------------------------------------------------------------!
       2             : !   CP2K: A general program to perform molecular dynamics simulations                              !
       3             : !   Copyright 2000-2024 CP2K developers group <https://cp2k.org>                                   !
       4             : !                                                                                                  !
       5             : !   SPDX-License-Identifier: GPL-2.0-or-later                                                      !
       6             : !--------------------------------------------------------------------------------------------------!
       7             : 
       8             : ! **************************************************************************************************
       9             : !> \brief RI-methods for HFX and K-points.
      10             : !> \author Augustin Bussy (01.2023)
      11             : ! **************************************************************************************************
      12             : 
      13             : MODULE hfx_ri_kp
      14             :    USE admm_types,                      ONLY: get_admm_env
      15             :    USE atomic_kind_types,               ONLY: atomic_kind_type,&
      16             :                                               get_atomic_kind_set
      17             :    USE basis_set_types,                 ONLY: get_gto_basis_set,&
      18             :                                               gto_basis_set_p_type
      19             :    USE cell_types,                      ONLY: cell_type,&
      20             :                                               pbc,&
      21             :                                               real_to_scaled,&
      22             :                                               scaled_to_real
      23             :    USE cp_array_utils,                  ONLY: cp_1d_logical_p_type,&
      24             :                                               cp_2d_r_p_type,&
      25             :                                               cp_3d_r_p_type
      26             :    USE cp_blacs_env,                    ONLY: cp_blacs_env_create,&
      27             :                                               cp_blacs_env_release,&
      28             :                                               cp_blacs_env_type
      29             :    USE cp_control_types,                ONLY: dft_control_type
      30             :    USE cp_dbcsr_cholesky,               ONLY: cp_dbcsr_cholesky_decompose,&
      31             :                                               cp_dbcsr_cholesky_invert
      32             :    USE cp_dbcsr_cp2k_link,              ONLY: cp_dbcsr_alloc_block_from_nbl
      33             :    USE cp_dbcsr_diag,                   ONLY: cp_dbcsr_power
      34             :    USE cp_dbcsr_operations,             ONLY: cp_dbcsr_dist2d_to_dist
      35             :    USE dbcsr_api,                       ONLY: &
      36             :         dbcsr_add, dbcsr_clear, dbcsr_copy, dbcsr_create, dbcsr_distribution_get, &
      37             :         dbcsr_distribution_new, dbcsr_distribution_release, dbcsr_distribution_type, dbcsr_dot, &
      38             :         dbcsr_filter, dbcsr_finalize, dbcsr_get_block_p, dbcsr_get_info, &
      39             :         dbcsr_iterator_blocks_left, dbcsr_iterator_next_block, dbcsr_iterator_start, &
      40             :         dbcsr_iterator_stop, dbcsr_iterator_type, dbcsr_p_type, dbcsr_put_block, dbcsr_release, &
      41             :         dbcsr_type, dbcsr_type_no_symmetry, dbcsr_type_symmetric
      42             :    USE dbt_api,                         ONLY: &
      43             :         dbt_batched_contract_finalize, dbt_batched_contract_init, dbt_clear, dbt_contract, &
      44             :         dbt_copy, dbt_copy_matrix_to_tensor, dbt_copy_tensor_to_matrix, dbt_create, dbt_destroy, &
      45             :         dbt_distribution_destroy, dbt_distribution_new, dbt_distribution_type, dbt_filter, &
      46             :         dbt_finalize, dbt_get_block, dbt_get_info, dbt_get_stored_coordinates, &
      47             :         dbt_iterator_blocks_left, dbt_iterator_next_block, dbt_iterator_start, dbt_iterator_stop, &
      48             :         dbt_iterator_type, dbt_mp_environ_pgrid, dbt_pgrid_create, dbt_pgrid_destroy, &
      49             :         dbt_pgrid_type, dbt_put_block, dbt_scale, dbt_type
      50             :    USE distribution_2d_types,           ONLY: distribution_2d_release,&
      51             :                                               distribution_2d_type
      52             :    USE hfx_ri,                          ONLY: get_idx_to_atom,&
      53             :                                               hfx_ri_pre_scf_calc_tensors
      54             :    USE hfx_types,                       ONLY: hfx_ri_type
      55             :    USE input_constants,                 ONLY: do_potential_short,&
      56             :                                               hfx_ri_do_2c_cholesky,&
      57             :                                               hfx_ri_do_2c_diag,&
      58             :                                               hfx_ri_do_2c_iter
      59             :    USE input_cp2k_hfx,                  ONLY: ri_pmat
      60             :    USE input_section_types,             ONLY: section_vals_get_subs_vals,&
      61             :                                               section_vals_type,&
      62             :                                               section_vals_val_get,&
      63             :                                               section_vals_val_set
      64             :    USE iterate_matrix,                  ONLY: invert_hotelling
      65             :    USE kinds,                           ONLY: dp,&
      66             :                                               int_8
      67             :    USE kpoint_types,                    ONLY: get_kpoint_info,&
      68             :                                               kpoint_type
      69             :    USE libint_2c_3c,                    ONLY: cutoff_screen_factor
      70             :    USE machine,                         ONLY: m_flush,&
      71             :                                               m_walltime
      72             :    USE mathlib,                         ONLY: erfc_cutoff
      73             :    USE message_passing,                 ONLY: mp_cart_type,&
      74             :                                               mp_para_env_type,&
      75             :                                               mp_request_type,&
      76             :                                               mp_waitall
      77             :    USE particle_methods,                ONLY: get_particle_set
      78             :    USE particle_types,                  ONLY: particle_type
      79             :    USE physcon,                         ONLY: angstrom
      80             :    USE qs_environment_types,            ONLY: get_qs_env,&
      81             :                                               qs_environment_type
      82             :    USE qs_force_types,                  ONLY: qs_force_type
      83             :    USE qs_integral_utils,               ONLY: basis_set_list_setup
      84             :    USE qs_interactions,                 ONLY: init_interaction_radii_orb_basis
      85             :    USE qs_kind_types,                   ONLY: qs_kind_type
      86             :    USE qs_neighbor_list_types,          ONLY: get_iterator_info,&
      87             :                                               neighbor_list_iterate,&
      88             :                                               neighbor_list_iterator_create,&
      89             :                                               neighbor_list_iterator_p_type,&
      90             :                                               neighbor_list_iterator_release,&
      91             :                                               neighbor_list_set_p_type,&
      92             :                                               release_neighbor_list_sets
      93             :    USE qs_scf_types,                    ONLY: qs_scf_env_type
      94             :    USE qs_tensors,                      ONLY: &
      95             :         build_2c_derivatives, build_2c_neighbor_lists, build_3c_derivatives, &
      96             :         build_3c_neighbor_lists, get_3c_iterator_info, get_tensor_occupancy, &
      97             :         neighbor_list_3c_destroy, neighbor_list_3c_iterate, neighbor_list_3c_iterator_create, &
      98             :         neighbor_list_3c_iterator_destroy
      99             :    USE qs_tensors_types,                ONLY: create_2c_tensor,&
     100             :                                               create_3c_tensor,&
     101             :                                               create_tensor_batches,&
     102             :                                               distribution_2d_create,&
     103             :                                               distribution_3d_create,&
     104             :                                               distribution_3d_type,&
     105             :                                               neighbor_list_3c_iterator_type,&
     106             :                                               neighbor_list_3c_type
     107             :    USE util,                            ONLY: get_limit
     108             :    USE virial_types,                    ONLY: virial_type
     109             : #include "./base/base_uses.f90"
     110             : 
     111             : !$ USE OMP_LIB, ONLY: omp_get_num_threads
     112             : 
     113             :    IMPLICIT NONE
     114             :    PRIVATE
     115             : 
     116             :    PUBLIC :: hfx_ri_update_ks_kp, hfx_ri_update_forces_kp
     117             : 
     118             :    CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'hfx_ri_kp'
     119             : CONTAINS
     120             : 
     121             : ! NOTES: for a start, we do not seek performance, but accuracy. So in this first implementation,
     122             : !        we give little consideration to batching, load balance and such.
     123             : !        We also put everything here, even if there is some code replication with the original RI_HFX
     124             : !        We will only work in the RHO flavor
     125             : !        For now, we will also always assume that there is a single para_env, and that there is no
     126             : !        K-point subgroup. This might change in the future
     127             : 
     128             : ! **************************************************************************************************
     129             : !> \brief Initialize the ri_data for K-points. For now, we take the normal, usual existing ri_data
     130             : !>        and we adapt it to our needs
     131             : !> \param dbcsr_template DBCSR matrix used as the template from which the KS tensor ri_data%ks_t(1, 1) is created
     132             : !> \param ri_data RI-HFX data; its 3c/2c integral tensors, density/KS tensors and kp_cost array are allocated here
     133             : !> \param qs_env environment queried for dft_control, natom, para_env and nkind
     134             : ! **************************************************************************************************
     135          70 :    SUBROUTINE adapt_ri_data_to_kp(dbcsr_template, ri_data, qs_env)
     136             :       TYPE(dbcsr_type), INTENT(INOUT)                    :: dbcsr_template
     137             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
     138             :       TYPE(qs_environment_type), POINTER                 :: qs_env
     139             : 
     140             :       INTEGER                                            :: i_img, i_RI, i_spin, iatom, natom, &
     141             :                                                             nblks_RI, nimg, nkind, nspins
     142          70 :       INTEGER, ALLOCATABLE, DIMENSION(:)                 :: bsizes_RI_ext, dist1, dist2, dist3
     143             :       TYPE(dft_control_type), POINTER                    :: dft_control
     144             :       TYPE(mp_para_env_type), POINTER                    :: para_env
     145             : 
     146          70 :       NULLIFY (dft_control, para_env)
     147             : 
     148             :       !The main thing that we need to do is to allocate more space for the integrals, such that there
     149             :       !is room for each periodic image. Note that we only go in 1D, i.e. we store (mu^0 sigma^a|P^0),
     150             :       !and (P^0|Q^a) => the RI basis is always in the main cell.
     151             : 
     152             :       !Get general info from qs_env; the number of images (nimg) is read from ri_data
     153          70 :       CALL get_qs_env(qs_env, dft_control=dft_control, natom=natom, para_env=para_env, nkind=nkind)
     154          70 :       nimg = ri_data%nimg
     155             : 
     156             :       !Along the RI direction we have basis elements spread across ncell_RI images.
     157          70 :       nblks_RI = SIZE(ri_data%bsizes_RI_split)
     158         210 :       ALLOCATE (bsizes_RI_ext(nblks_RI*ri_data%ncell_RI))
     159         506 :       DO i_RI = 1, ri_data%ncell_RI
     160        2344 :          bsizes_RI_ext((i_RI - 1)*nblks_RI + 1:i_RI*nblks_RI) = ri_data%bsizes_RI_split(:)
     161             :       END DO
     162             : 
     163        4058 :       ALLOCATE (ri_data%t_3c_int_ctr_1(1, nimg))
     164             :       CALL create_3c_tensor(ri_data%t_3c_int_ctr_1(1, 1), dist1, dist2, dist3, &
     165             :                             ri_data%pgrid_1, ri_data%bsizes_AO_split, bsizes_RI_ext, &
     166          70 :                             ri_data%bsizes_AO_split, [1, 2], [3], name="(AO RI | AO)")
     167             : 
     168        1644 :       DO i_img = 2, nimg
     169        1644 :          CALL dbt_create(ri_data%t_3c_int_ctr_1(1, 1), ri_data%t_3c_int_ctr_1(1, i_img))
     170             :       END DO
     171          70 :       DEALLOCATE (dist1, dist2, dist3)
     172             : 
     173         770 :       ALLOCATE (ri_data%t_3c_int_ctr_2(1, 1))
     174             :       CALL create_3c_tensor(ri_data%t_3c_int_ctr_2(1, 1), dist1, dist2, dist3, &
     175             :                             ri_data%pgrid_1, ri_data%bsizes_AO_split, bsizes_RI_ext, &
     176          70 :                             ri_data%bsizes_AO_split, [1], [2, 3], name="(AO RI | AO)")
     177          70 :       DEALLOCATE (dist1, dist2, dist3)
     178             : 
     179             :       !We use full block sizes (ri_data%bsizes_RI, not bsizes_RI_split) for the 2c quantities, replicated over the ncell_RI cells
     180          70 :       DEALLOCATE (bsizes_RI_ext)
     181          70 :       nblks_RI = SIZE(ri_data%bsizes_RI)
     182         210 :       ALLOCATE (bsizes_RI_ext(nblks_RI*ri_data%ncell_RI))
     183         506 :       DO i_RI = 1, ri_data%ncell_RI
     184        1378 :          bsizes_RI_ext((i_RI - 1)*nblks_RI + 1:i_RI*nblks_RI) = ri_data%bsizes_RI(:)
     185             :       END DO
     186             : 
     187        3010 :       ALLOCATE (ri_data%t_2c_inv(1, natom), ri_data%t_2c_int(1, natom), ri_data%t_2c_pot(1, natom))
     188             :       CALL create_2c_tensor(ri_data%t_2c_inv(1, 1), dist1, dist2, ri_data%pgrid_2d, &
     189             :                             bsizes_RI_ext, bsizes_RI_ext, &
     190          70 :                             name="(RI | RI)")
     191          70 :       DEALLOCATE (dist1, dist2)
     192          70 :       CALL dbt_create(ri_data%t_2c_inv(1, 1), ri_data%t_2c_int(1, 1))
     193          70 :       CALL dbt_create(ri_data%t_2c_inv(1, 1), ri_data%t_2c_pot(1, 1))
     194         140 :       DO iatom = 2, natom
     195          70 :          CALL dbt_create(ri_data%t_2c_inv(1, 1), ri_data%t_2c_inv(1, iatom))
     196          70 :          CALL dbt_create(ri_data%t_2c_inv(1, 1), ri_data%t_2c_int(1, iatom))
     197         140 :          CALL dbt_create(ri_data%t_2c_inv(1, 1), ri_data%t_2c_pot(1, iatom))
     198             :       END DO
     199             : 
     200         350 :       ALLOCATE (ri_data%kp_cost(natom, natom, nimg))
     201       11578 :       ri_data%kp_cost = 0.0_dp
     202             : 
     203             :       !We store the density and KS matrix in tensor format, one tensor per spin and per image
     204          70 :       nspins = dft_control%nspins
     205        8930 :       ALLOCATE (ri_data%rho_ao_t(nspins, nimg), ri_data%ks_t(nspins, nimg))
     206             :       CALL create_2c_tensor(ri_data%rho_ao_t(1, 1), dist1, dist2, ri_data%pgrid_2d, &
     207             :                             ri_data%bsizes_AO_split, ri_data%bsizes_AO_split, &
     208          70 :                             name="(AO | AO)")
     209          70 :       DEALLOCATE (dist1, dist2)
     210             : 
     211          70 :       CALL dbt_create(dbcsr_template, ri_data%ks_t(1, 1))
     212             : 
     213          70 :       IF (nspins == 2) THEN
     214          24 :          CALL dbt_create(ri_data%rho_ao_t(1, 1), ri_data%rho_ao_t(2, 1))
     215          24 :          CALL dbt_create(ri_data%ks_t(1, 1), ri_data%ks_t(2, 1))
     216             :       END IF
     217        1644 :       DO i_img = 2, nimg
     218        3566 :          DO i_spin = 1, nspins
     219        1922 :             CALL dbt_create(ri_data%rho_ao_t(1, 1), ri_data%rho_ao_t(i_spin, i_img))
     220        3496 :             CALL dbt_create(ri_data%ks_t(1, 1), ri_data%ks_t(i_spin, i_img))
     221             :          END DO
     222             :       END DO
     223             : 
     224         210 :    END SUBROUTINE adapt_ri_data_to_kp
     225             : 
     226             : ! **************************************************************************************************
     227             : !> \brief The pre-scf steps for RI-HFX k-points calculation. Namely the calculation of the integrals
     228             : !> \param dbcsr_template DBCSR matrix forwarded to adapt_ri_data_to_kp
     229             : !> \param ri_data RI-HFX data; previous k-point data is cleaned up, then the integrals are recomputed and stored
     230             : !> \param qs_env environment queried for dft_control, natom and nkind, and passed to the integral routines
     231             : ! **************************************************************************************************
     232          70 :    SUBROUTINE hfx_ri_pre_scf_kp(dbcsr_template, ri_data, qs_env)
     233             :       TYPE(dbcsr_type), INTENT(INOUT)                    :: dbcsr_template
     234             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
     235             :       TYPE(qs_environment_type), POINTER                 :: qs_env
     236             : 
     237             :       CHARACTER(LEN=*), PARAMETER                        :: routineN = 'hfx_ri_pre_scf_kp'
     238             : 
     239             :       INTEGER                                            :: handle, i_img, iatom, natom, nimg, nkind
     240          70 :       TYPE(dbcsr_type), ALLOCATABLE, DIMENSION(:)        :: t_2c_op_pot, t_2c_op_RI
     241          70 :       TYPE(dbt_type), ALLOCATABLE, DIMENSION(:, :)       :: t_3c_int
     242             :       TYPE(dft_control_type), POINTER                    :: dft_control
     243             : 
     244          70 :       NULLIFY (dft_control)
     245             : 
     246          70 :       CALL timeset(routineN, handle)
     247             : 
     248          70 :       CALL get_qs_env(qs_env, dft_control=dft_control, natom=natom, nkind=nkind)
     249             : 
     250          70 :       CALL cleanup_kp(ri_data)
     251             : 
     252             :       !We do all the checks on what we allow in this initial implementation
     253          70 :       IF (ri_data%flavor .NE. ri_pmat) CPABORT("K-points RI-HFX only with RHO flavor")
     254          70 :       IF (ri_data%same_op) ri_data%same_op = .FALSE. !force the full calculation with RI metric
     255          70 :       IF (ABS(ri_data%eps_pgf_orb - dft_control%qs_control%eps_pgf_orb) > 1.0E-16_dp) &
     256           0 :          CPABORT("RI%EPS_PGF_ORB and QS%EPS_PGF_ORB must be identical for RI-HFX k-points")
     257             : 
     258          70 :       CALL get_kp_and_ri_images(ri_data, qs_env)
     259          70 :       nimg = ri_data%nimg
     260             : 
     261             :       !Calculate the integrals (one 2c matrix per operator and one 3c tensor for each of the nimg images)
     262        3568 :       ALLOCATE (t_2c_op_pot(nimg), t_2c_op_RI(nimg))
     263        4058 :       ALLOCATE (t_3c_int(1, nimg))
     264          70 :       CALL hfx_ri_pre_scf_calc_tensors(qs_env, ri_data, t_2c_op_RI, t_2c_op_pot, t_3c_int, do_kpoints=.TRUE.)
     265             : 
     266             :       !Make sure the internals have the k-point format
     267          70 :       CALL adapt_ri_data_to_kp(dbcsr_template, ri_data, qs_env)
     268             : 
     269             :       !For each atom i, we calculate the inverse RI metric (P^0 | Q^0)^-1 without external bumping yet
     270             :       !Also store the off-diagonal integrals of the RI metric in case of forces, bumped from the left
     271         210 :       DO iatom = 1, natom
     272             :          CALL get_ext_2c_int(ri_data%t_2c_inv(1, iatom), t_2c_op_RI, iatom, iatom, 1, ri_data, qs_env, &
     273         140 :                              do_inverse=.TRUE.)
     274             :          !for the forces:
     275             :          !off-diagonal RI metric bumped from the left
     276             :          CALL get_ext_2c_int(ri_data%t_2c_int(1, iatom), t_2c_op_RI, iatom, iatom, 1, ri_data, &
     277         140 :                              qs_env, off_diagonal=.TRUE.)
     278         140 :          CALL apply_bump(ri_data%t_2c_int(1, iatom), iatom, ri_data, qs_env, from_left=.TRUE., from_right=.FALSE.)
     279             : 
     280             :          !RI metric with bumped off-diagonal blocks (but not inverted), debumped from left and right
     281             :          CALL get_ext_2c_int(ri_data%t_2c_pot(1, iatom), t_2c_op_RI, iatom, iatom, 1, ri_data, qs_env, &
     282         140 :                              do_inverse=.TRUE., skip_inverse=.TRUE.)
     283             :          CALL apply_bump(ri_data%t_2c_pot(1, iatom), iatom, ri_data, qs_env, from_left=.TRUE., &
     284         210 :                          from_right=.TRUE., debump=.TRUE.)
     285             : 
     286             :       END DO
     287             : 
     288        1714 :       DO i_img = 1, nimg
     289        1714 :          CALL dbcsr_release(t_2c_op_RI(i_img))
     290             :       END DO
     291             : 
     292        3428 :       ALLOCATE (ri_data%kp_mat_2c_pot(1, nimg))
     293        1714 :       DO i_img = 1, nimg
     294        1644 :          CALL dbcsr_create(ri_data%kp_mat_2c_pot(1, i_img), template=t_2c_op_pot(i_img))
     295        1644 :          CALL dbcsr_copy(ri_data%kp_mat_2c_pot(1, i_img), t_2c_op_pot(i_img))
     296        1714 :          CALL dbcsr_release(t_2c_op_pot(i_img))
     297             :       END DO
     298             : 
     299             :       !Pre-contract all 3c integrals with the bumped inverse RI metric (P^0|Q^0)^-1,
     300             :       !and store in ri_data%t_3c_int_ctr_1
     301          70 :       CALL precontract_3c_ints(t_3c_int, ri_data, qs_env)
     302             : 
     303             :       !reorder the 3c integrals such that empty images are bunched up together
     304          70 :       CALL reorder_3c_ints(ri_data%t_3c_int_ctr_1(1, :), ri_data)
     305             : 
     306          70 :       CALL timestop(handle)
     307             : 
     308        1784 :    END SUBROUTINE hfx_ri_pre_scf_kp
     309             : 
     310             : ! **************************************************************************************************
     311             : !> \brief Clean up the k-point specific data from ri_data: every component is touched only if it is
     312             : !>        currently allocated; DBCSR matrices are released and DBT tensors destroyed before deallocation
     313             : !> \param ri_data RI-HFX data whose k-point specific allocatable components are freed
     314             : ! **************************************************************************************************
     315          70 :    SUBROUTINE cleanup_kp(ri_data)
     316             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
     317             : 
     318             :       INTEGER                                            :: i, j
     319             : 
     320          70 :       IF (ALLOCATED(ri_data%kp_cost)) DEALLOCATE (ri_data%kp_cost)
     321          70 :       IF (ALLOCATED(ri_data%idx_to_img)) DEALLOCATE (ri_data%idx_to_img)
     322          70 :       IF (ALLOCATED(ri_data%img_to_idx)) DEALLOCATE (ri_data%img_to_idx)
     323          70 :       IF (ALLOCATED(ri_data%present_images)) DEALLOCATE (ri_data%present_images)
     324          70 :       IF (ALLOCATED(ri_data%img_to_RI_cell)) DEALLOCATE (ri_data%img_to_RI_cell)
     325          70 :       IF (ALLOCATED(ri_data%RI_cell_to_img)) DEALLOCATE (ri_data%RI_cell_to_img)
     326             : 
     327          70 :       IF (ALLOCATED(ri_data%kp_mat_2c_pot)) THEN
     328         540 :          DO j = 1, SIZE(ri_data%kp_mat_2c_pot, 2)
     329        1060 :             DO i = 1, SIZE(ri_data%kp_mat_2c_pot, 1)
     330        1040 :                CALL dbcsr_release(ri_data%kp_mat_2c_pot(i, j))
     331             :             END DO
     332             :          END DO
     333          20 :          DEALLOCATE (ri_data%kp_mat_2c_pot)
     334             :       END IF
     335             : 
     336          70 :       IF (ALLOCATED(ri_data%kp_t_3c_int)) THEN
     337         540 :          DO i = 1, SIZE(ri_data%kp_t_3c_int)
     338         540 :             CALL dbt_destroy(ri_data%kp_t_3c_int(i))
     339             :          END DO
     340         540 :          DEALLOCATE (ri_data%kp_t_3c_int)
     341             :       END IF
     342             : 
     343          70 :       IF (ALLOCATED(ri_data%t_2c_inv)) THEN
     344         160 :          DO j = 1, SIZE(ri_data%t_2c_inv, 2)
     345         250 :             DO i = 1, SIZE(ri_data%t_2c_inv, 1)
     346         180 :                CALL dbt_destroy(ri_data%t_2c_inv(i, j))
     347             :             END DO
     348             :          END DO
     349         160 :          DEALLOCATE (ri_data%t_2c_inv)
     350             :       END IF
     351             : 
     352          70 :       IF (ALLOCATED(ri_data%t_2c_int)) THEN
     353         160 :          DO j = 1, SIZE(ri_data%t_2c_int, 2)
     354         250 :             DO i = 1, SIZE(ri_data%t_2c_int, 1)
     355         180 :                CALL dbt_destroy(ri_data%t_2c_int(i, j))
     356             :             END DO
     357             :          END DO
     358         160 :          DEALLOCATE (ri_data%t_2c_int)
     359             :       END IF
     360             : 
     361          70 :       IF (ALLOCATED(ri_data%t_2c_pot)) THEN
     362         160 :          DO j = 1, SIZE(ri_data%t_2c_pot, 2)
     363         250 :             DO i = 1, SIZE(ri_data%t_2c_pot, 1)
     364         180 :                CALL dbt_destroy(ri_data%t_2c_pot(i, j))
     365             :             END DO
     366             :          END DO
     367         160 :          DEALLOCATE (ri_data%t_2c_pot)
     368             :       END IF
     369             : 
     370          70 :       IF (ALLOCATED(ri_data%t_3c_int_ctr_1)) THEN
     371         640 :          DO j = 1, SIZE(ri_data%t_3c_int_ctr_1, 2)
     372        1210 :             DO i = 1, SIZE(ri_data%t_3c_int_ctr_1, 1)
     373        1140 :                CALL dbt_destroy(ri_data%t_3c_int_ctr_1(i, j))
     374             :             END DO
     375             :          END DO
     376         640 :          DEALLOCATE (ri_data%t_3c_int_ctr_1)
     377             :       END IF
     378             : 
     379          70 :       IF (ALLOCATED(ri_data%t_3c_int_ctr_2)) THEN
     380         140 :          DO j = 1, SIZE(ri_data%t_3c_int_ctr_2, 2)
     381         210 :             DO i = 1, SIZE(ri_data%t_3c_int_ctr_2, 1)
     382         140 :                CALL dbt_destroy(ri_data%t_3c_int_ctr_2(i, j))
     383             :             END DO
     384             :          END DO
     385         140 :          DEALLOCATE (ri_data%t_3c_int_ctr_2)
     386             :       END IF
     387             : 
     388          70 :       IF (ALLOCATED(ri_data%rho_ao_t)) THEN
     389         640 :          DO j = 1, SIZE(ri_data%rho_ao_t, 2)
     390        1428 :             DO i = 1, SIZE(ri_data%rho_ao_t, 1)
     391        1358 :                CALL dbt_destroy(ri_data%rho_ao_t(i, j))
     392             :             END DO
     393             :          END DO
     394         858 :          DEALLOCATE (ri_data%rho_ao_t)
     395             :       END IF
     396             : 
     397          70 :       IF (ALLOCATED(ri_data%ks_t)) THEN
     398         640 :          DO j = 1, SIZE(ri_data%ks_t, 2)
     399        1428 :             DO i = 1, SIZE(ri_data%ks_t, 1)
     400        1358 :                CALL dbt_destroy(ri_data%ks_t(i, j))
     401             :             END DO
     402             :          END DO
     403         858 :          DEALLOCATE (ri_data%ks_t)
     404             :       END IF
     405             : 
     406          70 :    END SUBROUTINE cleanup_kp
     406             : 
     407             : ! **************************************************************************************************
     408             : !> \brief Update the KS matrices for each real-space image with the RI-HFX contribution, and compute the HFX energy
     409             : !> \param qs_env the QS environment, queried here for para_env, natom and the input sections
     410             : !> \param ri_data the RI-HFX data; its stored tensors (rho_ao_t, ks_t, kp_t_3c_int) and counters (kp_cost, dbcsr_nflop, dbcsr_time) are updated
     411             : !> \param ks_matrix the KS matrices, indexed by (spin, image), to which the HFX contribution is added
     412             : !> \param ehfx the HFX energy: 0.5 times the dbcsr_dot of the desymmetrized KS and density matrices, summed over spins and images
     413             : !> \param rho_ao the AO density matrices, passed on to get_pmat_images to build the density tensors
     414             : !> \param geometry_did_change if .TRUE., hfx_ri_pre_scf_kp is called before the image counts are read from ri_data
     415             : !> \param nspins number of spins; the prefactor is 0.5*hf_fraction for nspins == 1, and hf_fraction otherwise
     416             : !> \param hf_fraction the scaling factor applied to the HFX contribution
     417             : ! **************************************************************************************************
     418         214 :    SUBROUTINE hfx_ri_update_ks_kp(qs_env, ri_data, ks_matrix, ehfx, rho_ao, &
     419             :                                   geometry_did_change, nspins, hf_fraction)
     420             : 
     421             :       TYPE(qs_environment_type), POINTER                 :: qs_env
     422             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
     423             :       TYPE(dbcsr_p_type), DIMENSION(:, :), POINTER       :: ks_matrix
     424             :       REAL(KIND=dp), INTENT(OUT)                         :: ehfx
     425             :       TYPE(dbcsr_p_type), DIMENSION(:, :), POINTER       :: rho_ao
     426             :       LOGICAL, INTENT(IN)                                :: geometry_did_change
     427             :       INTEGER, INTENT(IN)                                :: nspins
     428             :       REAL(KIND=dp), INTENT(IN)                          :: hf_fraction
     429             : 
     430             :       CHARACTER(LEN=*), PARAMETER :: routineN = 'hfx_ri_update_ks_kp'
     431             : 
     432             :       INTEGER :: b_img, batch_size, group_size, handle, handle2, i_batch, i_img, i_spin, iatom, &
     433             :          iblk, igroup, jatom, mb_img, n_batch_nze, natom, ngroups, nimg, nimg_nze
     434             :       INTEGER(int_8)                                     :: nflop, nze
     435         214 :       INTEGER, ALLOCATABLE, DIMENSION(:)                 :: batch_ranges_at, batch_ranges_nze, &
     436         214 :                                                             idx_to_at_AO
     437         214 :       INTEGER, ALLOCATABLE, DIMENSION(:, :)              :: iapc_pairs
     438         214 :       INTEGER, ALLOCATABLE, DIMENSION(:, :, :)           :: sparsity_pattern
     439             :       LOGICAL                                            :: use_delta_p
     440             :       REAL(dp)                                           :: etmp, fac, occ, pfac, pref, t1, t2, t3, &
     441             :                                                             t4
     442             :       TYPE(cp_blacs_env_type), POINTER                   :: blacs_env_sub
     443             :       TYPE(dbcsr_type)                                   :: ks_desymm, rho_desymm, tmp
     444         214 :       TYPE(dbcsr_type), ALLOCATABLE, DIMENSION(:)        :: mat_2c_pot
     445             :       TYPE(dbcsr_type), POINTER                          :: dbcsr_template
     446         214 :       TYPE(dbt_type), ALLOCATABLE, DIMENSION(:)          :: ks_t_split, t_2c_ao_tmp, t_2c_work, &
     447         214 :                                                             t_3c_int, t_3c_work_2, t_3c_work_3
     448         214 :       TYPE(dbt_type), ALLOCATABLE, DIMENSION(:, :)       :: ks_t, ks_t_sub, t_3c_apc, t_3c_apc_sub
     449             :       TYPE(mp_para_env_type), POINTER                    :: para_env, para_env_sub
     450             :       TYPE(section_vals_type), POINTER                   :: hfx_section
     451             : 
     452         214 :       NULLIFY (para_env, para_env_sub, blacs_env_sub, hfx_section, dbcsr_template)
     453             : 
     454         214 :       CALL timeset(routineN, handle)
     455             : 
     456         214 :       CALL get_qs_env(qs_env, para_env=para_env, natom=natom)
     457             : 
     458         214 :       IF (nspins == 1) THEN
     459         130 :          fac = 0.5_dp*hf_fraction
     460             :       ELSE
     461          84 :          fac = 1.0_dp*hf_fraction
     462             :       END IF
     463             : 
     464         214 :       IF (geometry_did_change) THEN
     465          70 :          CALL hfx_ri_pre_scf_kp(ks_matrix(1, 1)%matrix, ri_data, qs_env)
     466             :       END IF
     467         214 :       nimg = ri_data%nimg
     468         214 :       nimg_nze = ri_data%nimg_nze
     469             : 
     470             :       !We need to calculate the KS matrix for each periodic cell with index b: F_mu^0,nu^b
     471             :       !F_mu^0,nu^b = -0.5 sum_a,c P_sigma^0,lambda^c (mu^0, sigma^a| P^0) V_P^0,Q^b (Q^b| nu^b lambda^a+c)
     472             :       !with V_P^0,Q^b = (P^0|R^0)^-1 * (R^0|S^b) * (S^b|Q^b)^-1
     473             : 
     474             :       !We use a local RI basis set for each atom in the system, which includes RI basis elements for
     475             :       !each neighboring atom standing within the KIND radius (decay of Gaussian with smallest exponent)
     476             : 
     477             :       !We also limit the number of periodic images we consider according to the HFX potential in the
     478             :       !RI basis, because if V_P^0,Q^b is zero everywhere, then image b can be ignored (RI basis less diffuse)
     479             : 
     480             :       !We manage to calculate each KS matrix doing a double loop on images, and a double loop on atoms
     481             :       !First, we pre-contract and store P_sigma^0,lambda^c (mu^0, sigma^a| P^0) (P^0|R^0)^-1 into T_mu^0,lambda^a+c,P^0
     482             :       !Then, we loop over b_img, iatom, jatom to get (R^0|S^b)
     483             :       !Finally, we do an additional loop over a+c images where we do (R^0|S^b) (S^b|Q^b)^-1 (Q^b| nu^b lambda^a+c)
     484             :       !and the final contraction with T_mu^0,lambda^a+c,P^0
     485             : 
     486             :       !Note that the 3-center integrals are pre-contracted with the RI metric, and that the same tensor can be used
     487             :       !(mu^0, sigma^a| P^0) (P^0|R^0)  <===> (S^b|Q^b)^-1 (Q^b| nu^b lambda^a+c) by relabelling the images
     488             : 
     489         214 :       hfx_section => section_vals_get_subs_vals(qs_env%input, "DFT%XC%HF%RI")
     490         214 :       CALL section_vals_val_get(hfx_section, "KP_USE_DELTA_P", l_val=use_delta_p)
     491             : 
     492             :       !By default, build the density tensor based on the difference of this SCF P and that of the prev. SCF
     493         214 :       pfac = -1.0_dp
     494         214 :       IF (.NOT. use_delta_p) pfac = 0.0_dp
     495         214 :       CALL get_pmat_images(ri_data%rho_ao_t, rho_ao, pfac, ri_data, qs_env)
     496             : 
     497       14740 :       ALLOCATE (ks_t(nspins, nimg))
     498        5514 :       DO i_img = 1, nimg
     499       12386 :          DO i_spin = 1, nspins
     500       12172 :             CALL dbt_create(ri_data%ks_t(1, 1), ks_t(i_spin, i_img))
     501             :          END DO
     502             :       END DO
     503             : 
     504         642 :       ALLOCATE (idx_to_at_AO(SIZE(ri_data%bsizes_AO_split)))
     505         214 :       CALL get_idx_to_atom(idx_to_at_AO, ri_data%bsizes_AO_split, ri_data%bsizes_AO)
     506             : 
     507             :       !First we calculate and store T^1_mu^0,lambda^a+c,P = P_mu^0,lambda^c * (mu_0 sigma^a | P^0) (P^0|R^0)^-1
     508             :       !To avoid doing nimg**2 tiny contractions that do not scale well with a large number of CPUs,
     509             :       !we instead do a single loop over the a+c image index. For each a+c, we get a list of allowed
     510             :       !combination of a,c indices. Then we build TAS tensors P_mu^0,lambda^c with all concerned c's
     511             :       !and (mu^0 sigma^a | P^0)*(P^0|R^0)^-1 with all a's. Then we perform a single contraction with larger tensors,
     512             :       !where the sum over a,c is automatically taken care of
     513       14526 :       ALLOCATE (t_3c_apc(nspins, nimg))
     514        5514 :       DO i_img = 1, nimg
     515       12386 :          DO i_spin = 1, nspins
     516       12172 :             CALL dbt_create(ri_data%t_3c_int_ctr_2(1, 1), t_3c_apc(i_spin, i_img))
     517             :          END DO
     518             :       END DO
     519         214 :       CALL contract_pmat_3c(t_3c_apc, ri_data%rho_ao_t, ri_data, qs_env)
     520             : 
     521         214 :       hfx_section => section_vals_get_subs_vals(qs_env%input, "DFT%XC%HF%RI")
     522         214 :       CALL section_vals_val_get(hfx_section, "KP_NGROUPS", i_val=ngroups)
     523         214 :       CALL section_vals_val_get(hfx_section, "KP_STACK_SIZE", i_val=batch_size)
     524         214 :       ri_data%kp_stack_size = batch_size
     525             : 
     526         214 :       IF (MOD(para_env%num_pe, ngroups) .NE. 0) THEN
     527           0 :          CPWARN("KP_NGROUPS must be an integer divisor of the total number of MPI ranks. It was set to 1.")
     528           0 :          ngroups = 1
     529           0 :          CALL section_vals_val_set(hfx_section, "KP_NGROUPS", i_val=ngroups)
     530             :       END IF
     531         214 :       IF ((MOD(ngroups, natom) .NE. 0) .AND. (MOD(natom, ngroups) .NE. 0) .AND. geometry_did_change) THEN
     532           0 :          IF (ngroups > 1) &
     533           0 :             CPWARN("Better load balancing is reached if NGROUPS is a multiple/divisor of the number of atoms")
     534             :       END IF
     535         214 :       group_size = para_env%num_pe/ngroups
     536         214 :       igroup = para_env%mepos/group_size
     537             : 
     538         214 :       ALLOCATE (para_env_sub)
     539         214 :       CALL para_env_sub%from_split(para_env, igroup)
     540         214 :       CALL cp_blacs_env_create(blacs_env_sub, para_env_sub)
     541             : 
     542             :       ! The sparsity pattern of each iatom, jatom pair, on each b_img, and on which subgroup
     543        1070 :       ALLOCATE (sparsity_pattern(natom, natom, nimg))
     544         214 :       CALL get_sparsity_pattern(sparsity_pattern, ri_data, qs_env)
     545         214 :       CALL get_sub_dist(sparsity_pattern, ngroups, ri_data)
     546             : 
     547             :       !Get all the required tensors in the subgroups
     548       26674 :       ALLOCATE (mat_2c_pot(nimg), ks_t_sub(nspins, nimg), t_2c_ao_tmp(1), ks_t_split(2), t_2c_work(3))
     549             :       CALL get_subgroup_2c_tensors(mat_2c_pot, t_2c_work, t_2c_ao_tmp, ks_t_split, ks_t_sub, &
     550         214 :                                    group_size, ngroups, para_env, para_env_sub, ri_data)
     551             : 
     552       26888 :       ALLOCATE (t_3c_int(nimg), t_3c_apc_sub(nspins, nimg), t_3c_work_2(3), t_3c_work_3(3))
     553             :       CALL get_subgroup_3c_tensors(t_3c_int, t_3c_work_2, t_3c_work_3, t_3c_apc, t_3c_apc_sub, &
     554         214 :                                    group_size, ngroups, para_env, para_env_sub, ri_data)
     555             : 
     556             :       !We go atom by atom, therefore there is an automatic batching along that direction
     557             :       !Also, because we stack the 3c tensors nimg times, we naturally do some batching there too
     558         642 :       ALLOCATE (batch_ranges_at(natom + 1))
     559         214 :       batch_ranges_at(natom + 1) = SIZE(ri_data%bsizes_AO_split) + 1
     560         214 :       iatom = 0
     561        1042 :       DO iblk = 1, SIZE(ri_data%bsizes_AO_split)
     562        1042 :          IF (idx_to_at_AO(iblk) == iatom + 1) THEN
     563         428 :             iatom = iatom + 1
     564         428 :             batch_ranges_at(iatom) = iblk
     565             :          END IF
     566             :       END DO
     567             : 
     568         214 :       n_batch_nze = nimg_nze/batch_size
     569         214 :       IF (MODULO(nimg_nze, batch_size) .NE. 0) n_batch_nze = n_batch_nze + 1
     570         642 :       ALLOCATE (batch_ranges_nze(n_batch_nze + 1))
     571         452 :       DO i_batch = 1, n_batch_nze
     572         452 :          batch_ranges_nze(i_batch) = (i_batch - 1)*batch_size + 1
     573             :       END DO
     574         214 :       batch_ranges_nze(n_batch_nze + 1) = nimg_nze + 1
     575             : 
     576         214 :       CALL dbt_batched_contract_init(t_3c_work_3(1), batch_range_2=batch_ranges_at)
     577         214 :       CALL dbt_batched_contract_init(t_3c_work_3(2), batch_range_2=batch_ranges_at)
     578         214 :       CALL dbt_batched_contract_init(t_3c_work_2(1), batch_range_1=batch_ranges_at)
     579         214 :       CALL dbt_batched_contract_init(t_3c_work_2(2), batch_range_1=batch_ranges_at)
     580             : 
     581         214 :       t1 = m_walltime()
     582       37314 :       ri_data%kp_cost(:, :, :) = 0.0_dp
     583         642 :       ALLOCATE (iapc_pairs(nimg, 2))
     584        5514 :       DO b_img = 1, nimg
     585        5300 :          CALL dbt_batched_contract_init(ks_t_split(1))
     586        5300 :          CALL dbt_batched_contract_init(ks_t_split(2))
     587       15900 :          DO jatom = 1, natom
     588       37100 :             DO iatom = 1, natom
     589       21200 :                IF (.NOT. sparsity_pattern(iatom, jatom, b_img) == igroup) CYCLE
     590        3606 :                pref = 1.0_dp
     591        3606 :                IF (iatom == jatom .AND. b_img == 1) pref = 0.5_dp
     592             : 
     593             :                !measure the cost of the given i, j, b configuration
     594        3606 :                t3 = m_walltime()
     595             : 
     596             :                !Get the proper HFX potential 2c integrals (R_i^0|S_j^b)
     597        3606 :                CALL timeset(routineN//"_2c", handle2)
     598             :                CALL get_ext_2c_int(t_2c_work(1), mat_2c_pot, iatom, jatom, b_img, ri_data, qs_env, &
     599             :                                    blacs_env_ext=blacs_env_sub, para_env_ext=para_env_sub, &
     600        3606 :                                    dbcsr_template=dbcsr_template)
     601        3606 :                CALL dbt_copy(t_2c_work(1), t_2c_work(2), move_data=.TRUE.) !move to split blocks
     602        3606 :                CALL dbt_filter(t_2c_work(2), ri_data%filter_eps)
     603        3606 :                CALL timestop(handle2)
     604             : 
     605        3606 :                CALL dbt_batched_contract_init(t_2c_work(2))
     606        3606 :                CALL get_iapc_pairs(iapc_pairs, b_img, ri_data, qs_env)
     607        3606 :                CALL timeset(routineN//"_3c", handle2)
     608             : 
     609             :                !Stack the (S^b|Q^b)^-1 * (Q^b| nu^b lambda^a+c) integrals over a+c and multiply by (R_i^0|S_j^b)
     610        8079 :                DO i_batch = 1, n_batch_nze
     611             :                   CALL fill_3c_stack(t_3c_work_3(3), t_3c_int, iapc_pairs(:, 1), 3, ri_data, &
     612             :                                      filter_at=jatom, filter_dim=2, idx_to_at=idx_to_at_AO, &
     613       13419 :                                      img_bounds=[batch_ranges_nze(i_batch), batch_ranges_nze(i_batch + 1)])
     614        4473 :                   CALL dbt_copy(t_3c_work_3(3), t_3c_work_3(1), move_data=.TRUE.)
     615             : 
     616             :                   CALL dbt_contract(1.0_dp, t_2c_work(2), t_3c_work_3(1), &
     617             :                                     0.0_dp, t_3c_work_3(2), map_1=[1], map_2=[2, 3], &
     618             :                                     contract_1=[2], notcontract_1=[1], &
     619             :                                     contract_2=[1], notcontract_2=[2, 3], &
     620        4473 :                                     filter_eps=ri_data%filter_eps, flop=nflop)
     621        4473 :                   ri_data%dbcsr_nflop = ri_data%dbcsr_nflop + nflop
     622        4473 :                   CALL dbt_copy(t_3c_work_3(2), t_3c_work_2(2), order=[2, 1, 3], move_data=.TRUE.)
     623        4473 :                   CALL dbt_copy(t_3c_work_3(3), t_3c_work_3(1))
     624             : 
     625             :                   !Stack the P_sigma^a,lambda^a+c * (mu^0 sigma^a | P^0)*(P^0|R^0)^-1 integrals over a+c and contract
     626             :                   !to get the final block of the KS matrix
     627       13750 :                   DO i_spin = 1, nspins
     628             :                      CALL fill_3c_stack(t_3c_work_2(3), t_3c_apc_sub(i_spin, :), iapc_pairs(:, 2), 3, &
     629             :                                         ri_data, filter_at=iatom, filter_dim=1, idx_to_at=idx_to_at_AO, &
     630       17013 :                                         img_bounds=[batch_ranges_nze(i_batch), batch_ranges_nze(i_batch + 1)])
     631        5671 :                      CALL get_tensor_occupancy(t_3c_work_2(3), nze, occ)
     632        5671 :                      IF (nze == 0) CYCLE
     633        5651 :                      CALL dbt_copy(t_3c_work_2(3), t_3c_work_2(1), move_data=.TRUE.)
     634             :                      CALL dbt_contract(-pref*fac, t_3c_work_2(1), t_3c_work_2(2), &
     635             :                                        1.0_dp, ks_t_split(i_spin), map_1=[1], map_2=[2], &
     636             :                                        contract_1=[2, 3], notcontract_1=[1], &
     637             :                                        contract_2=[2, 3], notcontract_2=[1], &
     638             :                                        filter_eps=ri_data%filter_eps, &
     639        5651 :                                        move_data=i_spin == nspins, flop=nflop)
     640       15795 :                      ri_data%dbcsr_nflop = ri_data%dbcsr_nflop + nflop
     641             :                   END DO
     642             :                END DO !i_batch
     643        3606 :                CALL timestop(handle2)
     644        3606 :                CALL dbt_batched_contract_finalize(t_2c_work(2))
     645             : 
     646        3606 :                t4 = m_walltime()
     647       39012 :                ri_data%kp_cost(iatom, jatom, b_img) = t4 - t3
     648             :             END DO !iatom
     649             :          END DO !jatom
     650        5300 :          CALL dbt_batched_contract_finalize(ks_t_split(1))
     651        5300 :          CALL dbt_batched_contract_finalize(ks_t_split(2))
     652             : 
     653       12386 :          DO i_spin = 1, nspins
     654        6872 :             CALL dbt_copy(ks_t_split(i_spin), t_2c_ao_tmp(1), move_data=.TRUE.)
     655       12172 :             CALL dbt_copy(t_2c_ao_tmp(1), ks_t_sub(i_spin, b_img), summation=.TRUE.)
     656             :          END DO
     657             :       END DO !b_img
     658         214 :       CALL dbt_batched_contract_finalize(t_3c_work_3(1))
     659         214 :       CALL dbt_batched_contract_finalize(t_3c_work_3(2))
     660         214 :       CALL dbt_batched_contract_finalize(t_3c_work_2(1))
     661         214 :       CALL dbt_batched_contract_finalize(t_3c_work_2(2))
     662         214 :       CALL para_env%sync()
     663         214 :       CALL para_env%sum(ri_data%dbcsr_nflop)
     664         214 :       CALL para_env%sum(ri_data%kp_cost)
     665         214 :       t2 = m_walltime()
     666         214 :       ri_data%dbcsr_time = ri_data%dbcsr_time + t2 - t1
     667             : 
     668             :       !transfer KS tensor from subgroup to main group
     669         214 :       CALL gather_ks_matrix(ks_t, ks_t_sub, group_size, sparsity_pattern, para_env, ri_data)
     670             : 
     671             :       !Keep the 3c integrals on the subgroups to avoid communication at next SCF step
     672        5514 :       DO i_img = 1, nimg
     673        5514 :          CALL dbt_copy(t_3c_int(i_img), ri_data%kp_t_3c_int(i_img), move_data=.TRUE.)
     674             :       END DO
     675             : 
     676             :       !clean-up subgroup tensors
     677         214 :       CALL dbt_destroy(t_2c_ao_tmp(1))
     678         214 :       CALL dbt_destroy(ks_t_split(1))
     679         214 :       CALL dbt_destroy(ks_t_split(2))
     680         214 :       CALL dbt_destroy(t_2c_work(1))
     681         214 :       CALL dbt_destroy(t_2c_work(2))
     682         214 :       CALL dbt_destroy(t_3c_work_2(1))
     683         214 :       CALL dbt_destroy(t_3c_work_2(2))
     684         214 :       CALL dbt_destroy(t_3c_work_2(3))
     685         214 :       CALL dbt_destroy(t_3c_work_3(1))
     686         214 :       CALL dbt_destroy(t_3c_work_3(2))
     687         214 :       CALL dbt_destroy(t_3c_work_3(3))
     688        5514 :       DO i_img = 1, nimg
     689        5300 :          CALL dbt_destroy(t_3c_int(i_img))
     690        5300 :          CALL dbcsr_release(mat_2c_pot(i_img))
     691       12386 :          DO i_spin = 1, nspins
     692        6872 :             CALL dbt_destroy(t_3c_apc_sub(i_spin, i_img))
     693       12172 :             CALL dbt_destroy(ks_t_sub(i_spin, i_img))
     694             :          END DO
     695             :       END DO
     696         214 :       IF (ASSOCIATED(dbcsr_template)) THEN
     697         214 :          CALL dbcsr_release(dbcsr_template)
     698         214 :          DEALLOCATE (dbcsr_template)
     699             :       END IF
     700             : 
     701             :       !End of subgroup parallelization
     702         214 :       CALL cp_blacs_env_release(blacs_env_sub)
     703         214 :       CALL para_env_sub%free()
     704         214 :       DEALLOCATE (para_env_sub)
     705             : 
     706             :       !Currently, rho_ao_t holds the density difference (wrt the prev. SCF step).
     707             :       !ks_t also holds that diff, while only having half the blocks => need to add to prev ks_t and symmetrize
     708             :       !We need the full thing for the energy, on the next SCF step
     709         214 :       CALL get_pmat_images(ri_data%rho_ao_t, rho_ao, 0.0_dp, ri_data, qs_env)
     710         512 :       DO i_spin = 1, nspins
     711        7384 :          DO b_img = 1, nimg
     712        6872 :             CALL dbt_copy(ks_t(i_spin, b_img), ri_data%ks_t(i_spin, b_img), summation=.TRUE.)
     713             : 
     714             :             !desymmetrize
     715        6872 :             mb_img = get_opp_index(b_img, qs_env)
     716        7170 :             IF (mb_img > 0 .AND. mb_img .LE. nimg) THEN
     717        6194 :                CALL dbt_copy(ks_t(i_spin, mb_img), ri_data%ks_t(i_spin, b_img), order=[2, 1], summation=.TRUE.)
     718             :             END IF
     719             :          END DO
     720             :       END DO
     721        5514 :       DO b_img = 1, nimg
     722       12386 :          DO i_spin = 1, nspins
     723       12172 :             CALL dbt_destroy(ks_t(i_spin, b_img))
     724             :          END DO
     725             :       END DO
     726             : 
     727             :       !calculate the energy
     728         214 :       CALL dbt_create(ri_data%ks_t(1, 1), t_2c_ao_tmp(1))
     729         214 :       CALL dbcsr_create(tmp, template=ks_matrix(1, 1)%matrix, matrix_type=dbcsr_type_symmetric)
     730         214 :       CALL dbcsr_create(ks_desymm, template=ks_matrix(1, 1)%matrix, matrix_type=dbcsr_type_no_symmetry)
     731         214 :       CALL dbcsr_create(rho_desymm, template=ks_matrix(1, 1)%matrix, matrix_type=dbcsr_type_no_symmetry)
     732         214 :       ehfx = 0.0_dp
     733        5514 :       DO i_img = 1, nimg
     734       12386 :          DO i_spin = 1, nspins
     735        6872 :             CALL dbt_filter(ri_data%ks_t(i_spin, i_img), ri_data%filter_eps)
     736        6872 :             CALL dbt_copy(ri_data%ks_t(i_spin, i_img), t_2c_ao_tmp(1))
     737        6872 :             CALL dbt_copy_tensor_to_matrix(t_2c_ao_tmp(1), ks_desymm)
     738        6872 :             CALL dbt_copy_tensor_to_matrix(t_2c_ao_tmp(1), tmp)
     739        6872 :             CALL dbcsr_add(ks_matrix(i_spin, i_img)%matrix, tmp, 1.0_dp, 1.0_dp)
     740             : 
     741        6872 :             CALL dbt_copy(ri_data%rho_ao_t(i_spin, i_img), t_2c_ao_tmp(1))
     742        6872 :             CALL dbt_copy_tensor_to_matrix(t_2c_ao_tmp(1), rho_desymm)
     743             : 
     744        6872 :             CALL dbcsr_dot(ks_desymm, rho_desymm, etmp)
     745        6872 :             ehfx = ehfx + 0.5_dp*etmp
     746             : 
     747       12172 :             IF (.NOT. use_delta_p) CALL dbt_clear(ri_data%ks_t(i_spin, i_img))
     748             :          END DO
     749             :       END DO
     750         214 :       CALL dbcsr_release(rho_desymm)
     751         214 :       CALL dbcsr_release(ks_desymm)
     752         214 :       CALL dbcsr_release(tmp)
     753         214 :       CALL dbt_destroy(t_2c_ao_tmp(1))
     754             : 
     755         214 :       CALL timestop(handle)
     756             : 
     757       36426 :    END SUBROUTINE hfx_ri_update_ks_kp
     758             : 
     759             : ! **************************************************************************************************
     760             : !> \brief Update the K-points RI-HFX forces
     761             : !> \param qs_env ...
     762             : !> \param ri_data ...
     763             : !> \param nspins ...
     764             : !> \param hf_fraction ...
     765             : !> \param rho_ao ...
     766             : !> \param use_virial ...
     767             : !> \note Because this routine uses stored quantities calculated in the energy calculation, the energy
     768             : !>       and force routines should always be called in pairs, and with the same input densities
     769             : ! **************************************************************************************************
     770          42 :    SUBROUTINE hfx_ri_update_forces_kp(qs_env, ri_data, nspins, hf_fraction, rho_ao, use_virial)
     771             : 
     772             :       TYPE(qs_environment_type), POINTER                 :: qs_env
     773             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
     774             :       INTEGER, INTENT(IN)                                :: nspins
     775             :       REAL(KIND=dp), INTENT(IN)                          :: hf_fraction
     776             :       TYPE(dbcsr_p_type), DIMENSION(:, :), POINTER       :: rho_ao
     777             :       LOGICAL, INTENT(IN), OPTIONAL                      :: use_virial
     778             : 
     779             :       CHARACTER(LEN=*), PARAMETER :: routineN = 'hfx_ri_update_forces_kp'
     780             : 
     781             :       INTEGER :: b_img, batch_size, group_size, handle, handle2, i_batch, i_img, i_loop, i_spin, &
     782             :          i_xyz, iatom, iblk, igroup, j_xyz, jatom, k_xyz, n_batch, natom, ngroups, nimg, nimg_nze
     783             :       INTEGER(int_8)                                     :: nflop, nze
     784          42 :       INTEGER, ALLOCATABLE, DIMENSION(:)                 :: atom_of_kind, batch_ranges_at, &
     785          42 :                                                             batch_ranges_nze, dist1, dist2, &
     786          42 :                                                             i_images, idx_to_at_AO, idx_to_at_RI, &
     787          42 :                                                             kind_of
     788          42 :       INTEGER, ALLOCATABLE, DIMENSION(:, :)              :: iapc_pairs
     789          42 :       INTEGER, ALLOCATABLE, DIMENSION(:, :, :)           :: force_pattern, sparsity_pattern
     790             :       INTEGER, DIMENSION(2, 1)                           :: bounds_iat, bounds_jat
     791             :       LOGICAL                                            :: use_virial_prv
     792             :       REAL(dp)                                           :: fac, occ, pref, t1, t2
     793             :       REAL(dp), DIMENSION(3, 3)                          :: work_virial
     794          42 :       TYPE(atomic_kind_type), DIMENSION(:), POINTER      :: atomic_kind_set
     795             :       TYPE(cell_type), POINTER                           :: cell
     796             :       TYPE(cp_blacs_env_type), POINTER                   :: blacs_env_sub
     797          42 :       TYPE(dbcsr_type), ALLOCATABLE, DIMENSION(:)        :: mat_2c_pot
     798          42 :       TYPE(dbcsr_type), ALLOCATABLE, DIMENSION(:, :)     :: mat_der_pot, mat_der_pot_sub
     799             :       TYPE(dbcsr_type), POINTER                          :: dbcsr_template
     800         714 :       TYPE(dbt_type)                                     :: t_2c_R, t_2c_R_split
     801          42 :       TYPE(dbt_type), ALLOCATABLE, DIMENSION(:)          :: t_2c_bint, t_2c_binv, t_2c_der_pot, &
     802          84 :                                                             t_2c_inv, t_2c_metric, t_2c_work, &
     803          42 :                                                             t_3c_der_stack, t_3c_work_2, &
     804          42 :                                                             t_3c_work_3
     805          42 :       TYPE(dbt_type), ALLOCATABLE, DIMENSION(:, :) :: rho_ao_t, rho_ao_t_sub, t_2c_der_metric, &
     806          84 :          t_2c_der_metric_sub, t_3c_apc, t_3c_apc_sub, t_3c_der_AO, t_3c_der_AO_sub, t_3c_der_RI, &
     807          42 :          t_3c_der_RI_sub
     808             :       TYPE(mp_para_env_type), POINTER                    :: para_env, para_env_sub
     809          42 :       TYPE(particle_type), DIMENSION(:), POINTER         :: particle_set
     810          42 :       TYPE(qs_force_type), DIMENSION(:), POINTER         :: force
     811             :       TYPE(section_vals_type), POINTER                   :: hfx_section
     812             :       TYPE(virial_type), POINTER                         :: virial
     813             : 
     814          42 :       NULLIFY (para_env, para_env_sub, hfx_section, blacs_env_sub, dbcsr_template, force, atomic_kind_set, &
     815          42 :                virial, particle_set, cell)
     816             : 
     817          42 :       CALL timeset(routineN, handle)
     818             : 
     819          42 :       use_virial_prv = .FALSE.
     820          42 :       IF (PRESENT(use_virial)) use_virial_prv = use_virial
     821             : 
     822          42 :       IF (nspins == 1) THEN
     823          28 :          fac = 0.5_dp*hf_fraction
     824             :       ELSE
     825          14 :          fac = 1.0_dp*hf_fraction
     826             :       END IF
     827             : 
     828             :       CALL get_qs_env(qs_env, natom=natom, para_env=para_env, force=force, cell=cell, virial=virial, &
     829          42 :                       atomic_kind_set=atomic_kind_set, particle_set=particle_set)
     830          42 :       CALL get_atomic_kind_set(atomic_kind_set, kind_of=kind_of, atom_of_kind=atom_of_kind)
     831             : 
     832         126 :       ALLOCATE (idx_to_at_AO(SIZE(ri_data%bsizes_AO_split)))
     833          42 :       CALL get_idx_to_atom(idx_to_at_AO, ri_data%bsizes_AO_split, ri_data%bsizes_AO)
     834             : 
     835         126 :       ALLOCATE (idx_to_at_RI(SIZE(ri_data%bsizes_RI_split)))
     836          42 :       CALL get_idx_to_atom(idx_to_at_RI, ri_data%bsizes_RI_split, ri_data%bsizes_RI)
     837             : 
     838          42 :       nimg = ri_data%nimg
     839       10542 :       ALLOCATE (t_3c_der_RI(nimg, 3), t_3c_der_AO(nimg, 3), mat_der_pot(nimg, 3), t_2c_der_metric(natom, 3))
     840             : 
     841             :       !We assume that the integrals are available from the SCF
     842             :       !pre-calculate the derivs. 3c tensors as (P^0| sigma^a mu^0), with t_3c_der_AO holding deriv wrt mu^0
     843          42 :       CALL precalc_derivatives(t_3c_der_RI, t_3c_der_AO, mat_der_pot, t_2c_der_metric, ri_data, qs_env)
     844             : 
     845             :       !Calculate the density matrix at each image
     846        2546 :       ALLOCATE (rho_ao_t(nspins, nimg))
     847             :       CALL create_2c_tensor(rho_ao_t(1, 1), dist1, dist2, ri_data%pgrid_2d, &
     848             :                             ri_data%bsizes_AO_split, ri_data%bsizes_AO_split, &
     849          42 :                             name="(AO | AO)")
     850          42 :       DEALLOCATE (dist1, dist2)
     851          42 :       IF (nspins == 2) CALL dbt_create(rho_ao_t(1, 1), rho_ao_t(2, 1))
     852         938 :       DO i_img = 2, nimg
     853        1986 :          DO i_spin = 1, nspins
     854        1944 :             CALL dbt_create(rho_ao_t(1, 1), rho_ao_t(i_spin, i_img))
     855             :          END DO
     856             :       END DO
     857          42 :       CALL get_pmat_images(rho_ao_t, rho_ao, 0.0_dp, ri_data, qs_env)
     858             : 
     859             :       !Contract integrals with the density matrix
     860        2546 :       ALLOCATE (t_3c_apc(nspins, nimg))
     861         980 :       DO i_img = 1, nimg
     862        2084 :          DO i_spin = 1, nspins
     863        2042 :             CALL dbt_create(ri_data%t_3c_int_ctr_2(1, 1), t_3c_apc(i_spin, i_img))
     864             :          END DO
     865             :       END DO
     866          42 :       CALL contract_pmat_3c(t_3c_apc, rho_ao_t, ri_data, qs_env)
     867             : 
     868             :       !Setup the subgroups
     869          42 :       hfx_section => section_vals_get_subs_vals(qs_env%input, "DFT%XC%HF%RI")
     870          42 :       CALL section_vals_val_get(hfx_section, "KP_NGROUPS", i_val=ngroups)
     871          42 :       group_size = para_env%num_pe/ngroups
     872          42 :       igroup = para_env%mepos/group_size
     873             : 
     874          42 :       ALLOCATE (para_env_sub)
     875          42 :       CALL para_env_sub%from_split(para_env, igroup)
     876          42 :       CALL cp_blacs_env_create(blacs_env_sub, para_env_sub)
     877             : 
     878             :       !Get the usual sparsity pattern
     879         210 :       ALLOCATE (sparsity_pattern(natom, natom, nimg))
     880          42 :       CALL get_sparsity_pattern(sparsity_pattern, ri_data, qs_env)
     881          42 :       CALL get_sub_dist(sparsity_pattern, ngroups, ri_data)
     882             : 
     883             :       !Get the 2-center quantities in the subgroups (note: main group derivs are deleted within)
     884           0 :       ALLOCATE (t_2c_inv(natom), mat_2c_pot(nimg), rho_ao_t_sub(nspins, nimg), t_2c_work(5), &
     885           0 :                 t_2c_der_metric_sub(natom, 3), mat_der_pot_sub(nimg, 3), t_2c_bint(natom), &
     886        9868 :                 t_2c_metric(natom), t_2c_binv(natom))
     887             :       CALL get_subgroup_2c_derivs(t_2c_inv, t_2c_bint, t_2c_metric, mat_2c_pot, t_2c_work, rho_ao_t, &
     888             :                                   rho_ao_t_sub, t_2c_der_metric, t_2c_der_metric_sub, mat_der_pot, &
     889          42 :                                   mat_der_pot_sub, group_size, ngroups, para_env, para_env_sub, ri_data)
     890          42 :       CALL dbt_create(t_2c_work(1), t_2c_R) !nRI x nRI
     891          42 :       CALL dbt_create(t_2c_work(5), t_2c_R_split) !nRI x nRI with split blocks
     892             : 
     893         504 :       ALLOCATE (t_2c_der_pot(3))
     894         168 :       DO i_xyz = 1, 3
     895         168 :          CALL dbt_create(t_2c_R, t_2c_der_pot(i_xyz))
     896             :       END DO
     897             : 
     898             :       !Get the 3-center quantities in the subgroups. The integrals and t_3c_apc already there
     899           0 :       ALLOCATE (t_3c_work_2(3), t_3c_work_3(4), t_3c_der_stack(6), t_3c_der_AO_sub(nimg, 3), &
     900       10736 :                 t_3c_der_RI_sub(nimg, 3), t_3c_apc_sub(nspins, nimg))
     901             :       CALL get_subgroup_3c_derivs(t_3c_work_2, t_3c_work_3, t_3c_der_AO, t_3c_der_AO_sub, &
     902             :                                   t_3c_der_RI, t_3c_der_RI_sub, t_3c_apc, t_3c_apc_sub, t_3c_der_stack, &
     903          42 :                                   group_size, ngroups, para_env, para_env_sub, ri_data)
     904             : 
     905             :       !Set up batched contraction (go atom by atom)
     906         126 :       ALLOCATE (batch_ranges_at(natom + 1))
     907          42 :       batch_ranges_at(natom + 1) = SIZE(ri_data%bsizes_AO_split) + 1
     908          42 :       iatom = 0
     909         212 :       DO iblk = 1, SIZE(ri_data%bsizes_AO_split)
     910         212 :          IF (idx_to_at_AO(iblk) == iatom + 1) THEN
     911          84 :             iatom = iatom + 1
     912          84 :             batch_ranges_at(iatom) = iblk
     913             :          END IF
     914             :       END DO
     915             : 
     916          42 :       CALL dbt_batched_contract_init(t_3c_work_3(1), batch_range_2=batch_ranges_at)
     917          42 :       CALL dbt_batched_contract_init(t_3c_work_3(2), batch_range_2=batch_ranges_at)
     918          42 :       CALL dbt_batched_contract_init(t_3c_work_3(3), batch_range_2=batch_ranges_at)
     919          42 :       CALL dbt_batched_contract_init(t_3c_work_2(1), batch_range_1=batch_ranges_at)
     920          42 :       CALL dbt_batched_contract_init(t_3c_work_2(2), batch_range_1=batch_ranges_at)
     921             : 
     922             :       !Preparing for the stacking of 3c tensors
     923          42 :       nimg_nze = ri_data%nimg_nze
     924          42 :       batch_size = ri_data%kp_stack_size
     925          42 :       n_batch = nimg_nze/batch_size
     926          42 :       IF (MODULO(nimg_nze, batch_size) .NE. 0) n_batch = n_batch + 1
     927         126 :       ALLOCATE (batch_ranges_nze(n_batch + 1))
     928          90 :       DO i_batch = 1, n_batch
     929          90 :          batch_ranges_nze(i_batch) = (i_batch - 1)*batch_size + 1
     930             :       END DO
     931          42 :       batch_ranges_nze(n_batch + 1) = nimg_nze + 1
     932             : 
     933             :       !Applying the external bump to ((P|Q)_D + B*(P|Q)_OD*B)^-1 from left and right
     934             :       !And keep the bump on LHS only version as well, with B*M^-1 = (M^-1*B)^T
     935         126 :       DO iatom = 1, natom
     936          84 :          CALL dbt_create(t_2c_inv(iatom), t_2c_binv(iatom))
     937          84 :          CALL dbt_copy(t_2c_inv(iatom), t_2c_binv(iatom))
     938          84 :          CALL apply_bump(t_2c_binv(iatom), iatom, ri_data, qs_env, from_left=.TRUE., from_right=.FALSE.)
     939         126 :          CALL apply_bump(t_2c_inv(iatom), iatom, ri_data, qs_env, from_left=.TRUE., from_right=.TRUE.)
     940             :       END DO
     941             : 
     942          42 :       t1 = m_walltime()
     943          42 :       work_virial = 0.0_dp
     944         210 :       ALLOCATE (iapc_pairs(nimg, 2), i_images(nimg))
     945         210 :       ALLOCATE (force_pattern(natom, natom, nimg))
     946        6608 :       force_pattern(:, :, :) = -1
     947             :       !We proceed with 2 loops: one over the sparsity pattern from the SCF, one over the rest
     948             :       !We use the SCF cost model for the first loop, while we calculate the cost of the upcoming loop
     949         126 :       DO i_loop = 1, 2
     950        1960 :          DO b_img = 1, nimg
     951        5712 :             DO jatom = 1, natom
     952       13132 :                DO iatom = 1, natom
     953             : 
     954        7504 :                   pref = -0.5_dp*fac
     955        7504 :                   IF (i_loop == 1 .AND. (.NOT. sparsity_pattern(iatom, jatom, b_img) == igroup)) CYCLE
     956        4333 :                   IF (i_loop == 2 .AND. (.NOT. force_pattern(iatom, jatom, b_img) == igroup)) CYCLE
     957             : 
     958             :                   !Get the proper HFX potential 2c integrals (R_i^0|S_j^b), times (S_j^b|Q_j^b)^-1
     959        1102 :                   CALL timeset(routineN//"_2c_1", handle2)
     960             :                   CALL get_ext_2c_int(t_2c_work(1), mat_2c_pot, iatom, jatom, b_img, ri_data, qs_env, &
     961             :                                       blacs_env_ext=blacs_env_sub, para_env_ext=para_env_sub, &
     962        1102 :                                       dbcsr_template=dbcsr_template)
     963             :                   CALL dbt_contract(1.0_dp, t_2c_work(1), t_2c_inv(jatom), &
     964             :                                     0.0_dp, t_2c_work(2), map_1=[1], map_2=[2], &
     965             :                                     contract_1=[2], notcontract_1=[1], &
     966             :                                     contract_2=[1], notcontract_2=[2], &
     967        1102 :                                     filter_eps=ri_data%filter_eps, flop=nflop)
     968        1102 :                   CALL dbt_copy(t_2c_work(2), t_2c_work(5), move_data=.TRUE.) !move to split blocks
     969        1102 :                   CALL dbt_filter(t_2c_work(5), ri_data%filter_eps)
     970        1102 :                   CALL timestop(handle2)
     971             : 
     972        1102 :                   CALL timeset(routineN//"_3c", handle2)
     973        5504 :                   bounds_iat(:, 1) = [SUM(ri_data%bsizes_AO(1:iatom - 1)) + 1, SUM(ri_data%bsizes_AO(1:iatom))]
     974        5468 :                   bounds_jat(:, 1) = [SUM(ri_data%bsizes_AO(1:jatom - 1)) + 1, SUM(ri_data%bsizes_AO(1:jatom))]
     975        1102 :                   CALL dbt_clear(t_2c_R_split)
     976             : 
     977        2443 :                   DO i_spin = 1, nspins
     978        2443 :                      CALL dbt_batched_contract_init(rho_ao_t_sub(i_spin, b_img))
     979             :                   END DO
     980             : 
     981        1102 :                   CALL get_iapc_pairs(iapc_pairs, b_img, ri_data, qs_env, i_images) !i = a+c-b
     982        2499 :                   DO i_batch = 1, n_batch
     983             : 
     984             :                      !Stack the 3c derivatives to take the trace later on
     985        5588 :                      DO i_xyz = 1, 3
     986        4191 :                         CALL dbt_clear(t_3c_der_stack(i_xyz))
     987             :                         CALL fill_3c_stack(t_3c_der_stack(i_xyz), t_3c_der_RI_sub(:, i_xyz), &
     988             :                                            iapc_pairs(:, 1), 3, ri_data, filter_at=jatom, &
     989             :                                            filter_dim=2, idx_to_at=idx_to_at_AO, &
     990       12573 :                                            img_bounds=[batch_ranges_nze(i_batch), batch_ranges_nze(i_batch + 1)])
     991             : 
     992        4191 :                         CALL dbt_clear(t_3c_der_stack(3 + i_xyz))
     993             :                         CALL fill_3c_stack(t_3c_der_stack(3 + i_xyz), t_3c_der_AO_sub(:, i_xyz), &
     994             :                                            iapc_pairs(:, 1), 3, ri_data, filter_at=jatom, &
     995             :                                            filter_dim=2, idx_to_at=idx_to_at_AO, &
     996       13970 :                                            img_bounds=[batch_ranges_nze(i_batch), batch_ranges_nze(i_batch + 1)])
     997             :                      END DO
     998             : 
     999        4135 :                      DO i_spin = 1, nspins
    1000             :                         !stack the t_3c_apc tensors
    1001        1636 :                         CALL dbt_clear(t_3c_work_2(3))
    1002             :                         CALL fill_3c_stack(t_3c_work_2(3), t_3c_apc_sub(i_spin, :), iapc_pairs(:, 2), 3, &
    1003             :                                            ri_data, filter_at=iatom, filter_dim=1, idx_to_at=idx_to_at_AO, &
    1004        4908 :                                            img_bounds=[batch_ranges_nze(i_batch), batch_ranges_nze(i_batch + 1)])
    1005        1636 :                         CALL get_tensor_occupancy(t_3c_work_2(3), nze, occ)
    1006        1636 :                         IF (nze == 0) CYCLE
    1007        1636 :                         CALL dbt_copy(t_3c_work_2(3), t_3c_work_2(1), move_data=.TRUE.)
    1008             : 
    1009             :                         !Contract with the second density matrix: P_mu^0,nu^b * t_3c_apc,
    1010             :                         !where t_3c_apc = P_sigma^a,lambda^a+c (mu^0 P^0 sigma^a) *(P^0|R^0)^-1 (stacked along a+c)
    1011             :                         CALL dbt_contract(1.0_dp, rho_ao_t_sub(i_spin, b_img), t_3c_work_2(1), &
    1012             :                                           0.0_dp, t_3c_work_2(2), map_1=[1], map_2=[2, 3], &
    1013             :                                           contract_1=[1], notcontract_1=[2], &
    1014             :                                           contract_2=[1], notcontract_2=[2, 3], &
    1015             :                                           bounds_1=bounds_iat, bounds_2=bounds_jat, &
    1016        1636 :                                           filter_eps=ri_data%filter_eps, flop=nflop)
    1017        1636 :                         ri_data%dbcsr_nflop = ri_data%dbcsr_nflop + nflop
    1018             : 
    1019        1636 :                         CALL get_tensor_occupancy(t_3c_work_2(2), nze, occ)
    1020        1636 :                         IF (nze == 0) CYCLE
    1021             : 
    1022             :                         !Contract with V_PQ so that we can take the trace with (Q^b|nu^b lambda^a+c)^(x)
    1023        1525 :                         CALL dbt_copy(t_3c_work_2(2), t_3c_work_3(1), order=[2, 1, 3], move_data=.TRUE.)
    1024        1525 :                         CALL dbt_batched_contract_init(t_2c_work(5))
    1025             :                         CALL dbt_contract(1.0_dp, t_2c_work(5), t_3c_work_3(1), &
    1026             :                                           0.0_dp, t_3c_work_3(2), map_1=[1], map_2=[2, 3], &
    1027             :                                           contract_1=[1], notcontract_1=[2], &
    1028             :                                           contract_2=[1], notcontract_2=[2, 3], &
    1029        1525 :                                           filter_eps=ri_data%filter_eps, flop=nflop)
    1030        1525 :                         ri_data%dbcsr_nflop = ri_data%dbcsr_nflop + nflop
    1031        1525 :                         CALL dbt_batched_contract_finalize(t_2c_work(5))
    1032             : 
    1033             :                         !Contract with the 3c derivatives to get the force/virial
    1034        1525 :                         CALL dbt_copy(t_3c_work_3(2), t_3c_work_3(4), move_data=.TRUE.)
    1035        1525 :                         IF (use_virial_prv) THEN
    1036             :                            CALL get_force_from_3c_trace(force, t_3c_work_3(4), t_3c_der_stack(1:3), &
    1037             :                                                         t_3c_der_stack(4:6), atom_of_kind, kind_of, &
    1038             :                                                         idx_to_at_RI, idx_to_at_AO, i_images, &
    1039             :                                                         batch_ranges_nze(i_batch), 2.0_dp*pref, &
    1040         257 :                                                         ri_data, qs_env, work_virial, cell, particle_set)
    1041             :                         ELSE
    1042             :                            CALL get_force_from_3c_trace(force, t_3c_work_3(4), t_3c_der_stack(1:3), &
    1043             :                                                         t_3c_der_stack(4:6), atom_of_kind, kind_of, &
    1044             :                                                         idx_to_at_RI, idx_to_at_AO, i_images, &
    1045             :                                                         batch_ranges_nze(i_batch), 2.0_dp*pref, &
    1046        1268 :                                                         ri_data, qs_env)
    1047             :                         END IF
    1048        1525 :                         CALL dbt_clear(t_3c_work_3(4))
    1049             : 
    1050             :                         !Contract with the 3-center integrals in order to have a matrix R_PQ such that
    1051             :                         !we can take the trace sum_PQ R_PQ (P^0|Q^b)^(x)
    1052        1525 :                         IF (i_loop == 2) CYCLE
    1053             : 
    1054             :                         !Stack the 3c integrals
    1055             :                         CALL fill_3c_stack(t_3c_work_3(4), ri_data%kp_t_3c_int, iapc_pairs(:, 1), 3, ri_data, &
    1056             :                                            filter_at=jatom, filter_dim=2, idx_to_at=idx_to_at_AO, &
    1057        2415 :                                            img_bounds=[batch_ranges_nze(i_batch), batch_ranges_nze(i_batch + 1)])
    1058         805 :                         CALL dbt_copy(t_3c_work_3(4), t_3c_work_3(3), move_data=.TRUE.)
    1059             : 
    1060         805 :                         CALL dbt_batched_contract_init(t_2c_R_split)
    1061             :                         CALL dbt_contract(1.0_dp, t_3c_work_3(1), t_3c_work_3(3), &
    1062             :                                           1.0_dp, t_2c_R_split, map_1=[1], map_2=[2], &
    1063             :                                           contract_1=[2, 3], notcontract_1=[1], &
    1064             :                                           contract_2=[2, 3], notcontract_2=[1], &
    1065         805 :                                           filter_eps=ri_data%filter_eps, flop=nflop)
    1066         805 :                         ri_data%dbcsr_nflop = ri_data%dbcsr_nflop + nflop
    1067         805 :                         CALL dbt_batched_contract_finalize(t_2c_R_split)
    1068        5474 :                         CALL dbt_copy(t_3c_work_3(4), t_3c_work_3(1))
    1069             :                      END DO
    1070             :                   END DO
    1071        2443 :                   DO i_spin = 1, nspins
    1072        2443 :                      CALL dbt_batched_contract_finalize(rho_ao_t_sub(i_spin, b_img))
    1073             :                   END DO
    1074        1102 :                   CALL timestop(handle2)
    1075             : 
    1076        1102 :                   IF (i_loop == 2) CYCLE
    1077         581 :                   pref = 2.0_dp*pref
    1078         581 :                   IF (iatom == jatom .AND. b_img == 1) pref = 0.5_dp*pref
    1079             : 
    1080         581 :                   CALL timeset(routineN//"_2c_2", handle2)
    1081             :                   !Note that the derivatives are in atomic block format (not split)
    1082         581 :                   CALL dbt_copy(t_2c_R_split, t_2c_R, move_data=.TRUE.)
    1083             : 
    1084             :                   CALL get_ext_2c_int(t_2c_work(1), mat_2c_pot, iatom, jatom, b_img, ri_data, qs_env, &
    1085             :                                       blacs_env_ext=blacs_env_sub, para_env_ext=para_env_sub, &
    1086         581 :                                       dbcsr_template=dbcsr_template)
    1087             : 
    1088             :                   !We have to calculate: S^-1(iat) * R_PQ * S^-1(jat)    to trace with HFX pot der
    1089             :                   !                      + R_PQ * S^-1(jat) * pot^T      to trace with S^(x) (iat)
    1090             :                   !                      + pot^T * S^-1(iat) *R_PQ       to trace with S^(x) (jat)
    1091             : 
    1092             :                   !Because 3c tensors are all precontracted with the inverse RI metric,
    1093             :                   !t_2c_R is currently implicitly multiplied by S^-1(iat) from the left
    1094             :                   !and S^-1(jat) from the right, directly in the proper format for the trace
    1095             :                   !with the HFX potential derivative
    1096             : 
    1097             :                   !Trace with HFX pot deriv, that we need to build first
    1098        2324 :                   DO i_xyz = 1, 3
    1099             :                      CALL get_ext_2c_int(t_2c_der_pot(i_xyz), mat_der_pot_sub(:, i_xyz), iatom, jatom, &
    1100             :                                          b_img, ri_data, qs_env, blacs_env_ext=blacs_env_sub, &
    1101        2324 :                                          para_env_ext=para_env_sub, dbcsr_template=dbcsr_template)
    1102             :                   END DO
    1103             : 
    1104         581 :                   IF (use_virial_prv) THEN
    1105             :                      CALL get_2c_der_force(force, t_2c_R, t_2c_der_pot, atom_of_kind, kind_of, &
    1106         113 :                                            b_img, pref, ri_data, qs_env, work_virial, cell, particle_set)
    1107             :                   ELSE
    1108             :                      CALL get_2c_der_force(force, t_2c_R, t_2c_der_pot, atom_of_kind, kind_of, &
    1109         468 :                                            b_img, pref, ri_data, qs_env)
    1110             :                   END IF
    1111             : 
    1112        2324 :                   DO i_xyz = 1, 3
    1113        2324 :                      CALL dbt_clear(t_2c_der_pot(i_xyz))
    1114             :                   END DO
    1115             : 
    1116             :                   !R_PQ * S^-1(jat) * pot^T  (=A)
    1117             :                   CALL dbt_contract(1.0_dp, t_2c_metric(iatom), t_2c_R, & !get rid of implicit S^-1(iat)
    1118             :                                     0.0_dp, t_2c_work(2), map_1=[1], map_2=[2], &
    1119             :                                     contract_1=[2], notcontract_1=[1], &
    1120             :                                     contract_2=[1], notcontract_2=[2], &
    1121         581 :                                     filter_eps=ri_data%filter_eps, flop=nflop)
    1122         581 :                   ri_data%dbcsr_nflop = ri_data%dbcsr_nflop + nflop
    1123             :                   CALL dbt_contract(1.0_dp, t_2c_work(2), t_2c_work(1), &
    1124             :                                     0.0_dp, t_2c_work(3), map_1=[1], map_2=[2], &
    1125             :                                     contract_1=[2], notcontract_1=[1], &
    1126             :                                     contract_2=[2], notcontract_2=[1], &
    1127         581 :                                     filter_eps=ri_data%filter_eps, flop=nflop)
    1128         581 :                   ri_data%dbcsr_nflop = ri_data%dbcsr_nflop + nflop
    1129             : 
    1130             :                   !With the RI bump function, things get more complex. M = (S|P)_D + B*(S|P)_OD*B
    1131             :                   !Calculate M^-1*B*A + A*B*M^-1 to contract with B^x. A is in t_2c_work(3)
    1132             :                   CALL dbt_contract(1.0_dp, t_2c_work(3), t_2c_binv(iatom), &
    1133             :                                     0.0_dp, t_2c_work(2), map_1=[1], map_2=[2], &
    1134             :                                     contract_1=[2], notcontract_1=[1], &
    1135             :                                     contract_2=[1], notcontract_2=[2], &
    1136         581 :                                     filter_eps=ri_data%filter_eps, flop=nflop)
    1137         581 :                   ri_data%dbcsr_nflop = ri_data%dbcsr_nflop + nflop
    1138             : 
    1139             :                   CALL dbt_contract(1.0_dp, t_2c_binv(iatom), t_2c_work(3), & !use transpose of B*M^-1 = M^-1*B
    1140             :                                     0.0_dp, t_2c_work(4), map_1=[1], map_2=[2], &
    1141             :                                     contract_1=[1], notcontract_1=[2], &
    1142             :                                     contract_2=[1], notcontract_2=[2], &
    1143         581 :                                     filter_eps=ri_data%filter_eps, flop=nflop)
    1144         581 :                   ri_data%dbcsr_nflop = ri_data%dbcsr_nflop + nflop
    1145             : 
    1146         581 :                   CALL dbt_copy(t_2c_work(2), t_2c_work(4), summation=.TRUE.)
    1147             :                   CALL get_2c_bump_forces(force, t_2c_work(4), iatom, atom_of_kind, kind_of, pref, &
    1148         581 :                                           ri_data, qs_env, work_virial)
    1149             : 
    1150             :                   !Calculate -M^-1*B*A*B*M^-1 to contract with diagonal RI metric deriv. t_2c_work(2) holds A*B*M^-1
    1151             :                   CALL dbt_contract(1.0_dp, t_2c_binv(iatom), t_2c_work(2), &
    1152             :                                     0.0_dp, t_2c_work(4), map_1=[1], map_2=[2], &
    1153             :                                     contract_1=[1], notcontract_1=[2], &
    1154             :                                     contract_2=[1], notcontract_2=[2], &
    1155         581 :                                     filter_eps=ri_data%filter_eps, flop=nflop)
    1156         581 :                   ri_data%dbcsr_nflop = ri_data%dbcsr_nflop + nflop
    1157             : 
    1158         581 :                   IF (use_virial_prv) THEN
    1159             :                      CALL get_2c_der_force(force, t_2c_work(4), t_2c_der_metric_sub(iatom, :), atom_of_kind, &
    1160             :                                            kind_of, 1, -pref, ri_data, qs_env, work_virial, cell, particle_set, &
    1161         113 :                                            diag=.TRUE., offdiag=.FALSE.)
    1162             :                   ELSE
    1163             :                      CALL get_2c_der_force(force, t_2c_work(4), t_2c_der_metric_sub(iatom, :), atom_of_kind, &
    1164         468 :                                            kind_of, 1, -pref, ri_data, qs_env, diag=.TRUE., offdiag=.FALSE.)
    1165             :                   END IF
    1166             : 
    1167             :                   !Calculate -B*M^-1*B*A*B*M^-1*B to contract with off-diagonal RI metric derivs
    1168         581 :                   CALL dbt_copy(t_2c_work(4), t_2c_work(2))
    1169         581 :                   CALL apply_bump(t_2c_work(2), iatom, ri_data, qs_env, from_left=.TRUE., from_right=.TRUE.)
    1170             : 
    1171         581 :                   IF (use_virial_prv) THEN
    1172             :                      CALL get_2c_der_force(force, t_2c_work(2), t_2c_der_metric_sub(iatom, :), atom_of_kind, &
    1173             :                                            kind_of, 1, -pref, ri_data, qs_env, work_virial, cell, particle_set, &
    1174         113 :                                            diag=.FALSE., offdiag=.TRUE.)
    1175             :                   ELSE
    1176             :                      CALL get_2c_der_force(force, t_2c_work(2), t_2c_der_metric_sub(iatom, :), atom_of_kind, &
    1177         468 :                                            kind_of, 1, -pref, ri_data, qs_env, diag=.FALSE., offdiag=.TRUE.)
    1178             :                   END IF
    1179             : 
    1180             :                   !Calculate -O*B*M^-1*B*A*B*M^-1 - M^-1*B*A*B*M^-1*B*O, where O is off-diagonal integrals
    1181             :                   !t_2c_work(4) holds M^-1*B*A*B*M^-1, and exploit transpose of B*O (stored in t_2c_bint)
    1182             :                   CALL dbt_contract(1.0_dp, t_2c_work(4), t_2c_bint(iatom), &
    1183             :                                     0.0_dp, t_2c_work(2), map_1=[1], map_2=[2], &
    1184             :                                     contract_1=[2], notcontract_1=[1], &
    1185             :                                     contract_2=[1], notcontract_2=[2], &
    1186         581 :                                     filter_eps=ri_data%filter_eps, flop=nflop)
    1187         581 :                   ri_data%dbcsr_nflop = ri_data%dbcsr_nflop + nflop
    1188             : 
    1189             :                   CALL dbt_contract(1.0_dp, t_2c_bint(iatom), t_2c_work(4), &
    1190             :                                     1.0_dp, t_2c_work(2), map_1=[1], map_2=[2], &
    1191             :                                     contract_1=[1], notcontract_1=[2], &
    1192             :                                     contract_2=[1], notcontract_2=[2], &
    1193         581 :                                     filter_eps=ri_data%filter_eps, flop=nflop)
    1194         581 :                   ri_data%dbcsr_nflop = ri_data%dbcsr_nflop + nflop
    1195             : 
    1196             :                   CALL get_2c_bump_forces(force, t_2c_work(2), iatom, atom_of_kind, kind_of, -pref, &
    1197         581 :                                           ri_data, qs_env, work_virial)
    1198             : 
    1199             :                   ! pot^T * S^-1(iat) * R_PQ (=A)
    1200             :                   CALL dbt_contract(1.0_dp, t_2c_work(1), t_2c_R, &
    1201             :                                     0.0_dp, t_2c_work(2), map_1=[1], map_2=[2], &
    1202             :                                     contract_1=[1], notcontract_1=[2], &
    1203             :                                     contract_2=[1], notcontract_2=[2], &
    1204         581 :                                     filter_eps=ri_data%filter_eps, flop=nflop)
    1205         581 :                   ri_data%dbcsr_nflop = ri_data%dbcsr_nflop + nflop
    1206             : 
    1207             :                   CALL dbt_contract(1.0_dp, t_2c_work(2), t_2c_metric(jatom), & !get rid of implicit S^-1(jat)
    1208             :                                     0.0_dp, t_2c_work(3), map_1=[1], map_2=[2], &
    1209             :                                     contract_1=[2], notcontract_1=[1], &
    1210             :                                     contract_2=[1], notcontract_2=[2], &
    1211         581 :                                     filter_eps=ri_data%filter_eps, flop=nflop)
    1212         581 :                   ri_data%dbcsr_nflop = ri_data%dbcsr_nflop + nflop
    1213             : 
    1214             :                   !Do the same shenanigans with the S^(x) (jatom)
    1215             :                   !Calculate M^-1*B*A + A*B*M^-1 to contract with B^x. A is in t_2c_work(3)
    1216             :                   CALL dbt_contract(1.0_dp, t_2c_work(3), t_2c_binv(jatom), &
    1217             :                                     0.0_dp, t_2c_work(2), map_1=[1], map_2=[2], &
    1218             :                                     contract_1=[2], notcontract_1=[1], &
    1219             :                                     contract_2=[1], notcontract_2=[2], &
    1220         581 :                                     filter_eps=ri_data%filter_eps, flop=nflop)
    1221         581 :                   ri_data%dbcsr_nflop = ri_data%dbcsr_nflop + nflop
    1222             : 
    1223             :                   CALL dbt_contract(1.0_dp, t_2c_binv(jatom), t_2c_work(3), & !use transpose of B*M^-1 = M^-1*B
    1224             :                                     0.0_dp, t_2c_work(4), map_1=[1], map_2=[2], &
    1225             :                                     contract_1=[1], notcontract_1=[2], &
    1226             :                                     contract_2=[1], notcontract_2=[2], &
    1227         581 :                                     filter_eps=ri_data%filter_eps, flop=nflop)
    1228         581 :                   ri_data%dbcsr_nflop = ri_data%dbcsr_nflop + nflop
    1229             : 
    1230         581 :                   CALL dbt_copy(t_2c_work(2), t_2c_work(4), summation=.TRUE.)
    1231             :                   CALL get_2c_bump_forces(force, t_2c_work(4), jatom, atom_of_kind, kind_of, pref, &
    1232         581 :                                           ri_data, qs_env, work_virial)
    1233             : 
    1234             :                   !Calculate -M^-1*B*A*B*M^-1 to contract with diagonal RI metric deriv. t_2c_work(2) holds A*B*M^-1
    1235             :                   CALL dbt_contract(1.0_dp, t_2c_binv(jatom), t_2c_work(2), &
    1236             :                                     0.0_dp, t_2c_work(4), map_1=[1], map_2=[2], &
    1237             :                                     contract_1=[1], notcontract_1=[2], &
    1238             :                                     contract_2=[1], notcontract_2=[2], &
    1239         581 :                                     filter_eps=ri_data%filter_eps, flop=nflop)
    1240         581 :                   ri_data%dbcsr_nflop = ri_data%dbcsr_nflop + nflop
    1241             : 
    1242         581 :                   IF (use_virial_prv) THEN
    1243             :                      CALL get_2c_der_force(force, t_2c_work(4), t_2c_der_metric_sub(jatom, :), atom_of_kind, &
    1244             :                                            kind_of, 1, -pref, ri_data, qs_env, work_virial, cell, particle_set, &
    1245         113 :                                            diag=.TRUE., offdiag=.FALSE.)
    1246             :                   ELSE
    1247             :                      CALL get_2c_der_force(force, t_2c_work(4), t_2c_der_metric_sub(jatom, :), atom_of_kind, &
    1248         468 :                                            kind_of, 1, -pref, ri_data, qs_env, diag=.TRUE., offdiag=.FALSE.)
    1249             :                   END IF
    1250             : 
    1251             :                   !Calculate -B*M^-1*B*A*B*M^-1*B to contract with off-diagonal RI metric derivs
    1252         581 :                   CALL dbt_copy(t_2c_work(4), t_2c_work(2))
    1253         581 :                   CALL apply_bump(t_2c_work(2), jatom, ri_data, qs_env, from_left=.TRUE., from_right=.TRUE.)
    1254             : 
    1255         581 :                   IF (use_virial_prv) THEN
    1256             :                      CALL get_2c_der_force(force, t_2c_work(2), t_2c_der_metric_sub(jatom, :), atom_of_kind, &
    1257             :                                            kind_of, 1, -pref, ri_data, qs_env, work_virial, cell, particle_set, &
    1258         113 :                                            diag=.FALSE., offdiag=.TRUE.)
    1259             :                   ELSE
    1260             :                      CALL get_2c_der_force(force, t_2c_work(2), t_2c_der_metric_sub(jatom, :), atom_of_kind, &
    1261         468 :                                            kind_of, 1, -pref, ri_data, qs_env, diag=.FALSE., offdiag=.TRUE.)
    1262             :                   END IF
    1263             : 
    1264             :                   !Calculate -O*B*M^-1*B*A*B*M^-1 - M^-1*B*A*B*M^-1*B*O, where O is off-diagonal integrals
    1265             :                   !t_2c_work(4) holds M^-1*B*A*B*M^-1, and exploit transpose of B*O (stored in t_2c_bint)
    1266             :                   CALL dbt_contract(1.0_dp, t_2c_work(4), t_2c_bint(jatom), &
    1267             :                                     0.0_dp, t_2c_work(2), map_1=[1], map_2=[2], &
    1268             :                                     contract_1=[2], notcontract_1=[1], &
    1269             :                                     contract_2=[1], notcontract_2=[2], &
    1270         581 :                                     filter_eps=ri_data%filter_eps, flop=nflop)
    1271         581 :                   ri_data%dbcsr_nflop = ri_data%dbcsr_nflop + nflop
    1272             : 
    1273             :                   CALL dbt_contract(1.0_dp, t_2c_bint(jatom), t_2c_work(4), &
    1274             :                                     1.0_dp, t_2c_work(2), map_1=[1], map_2=[2], &
    1275             :                                     contract_1=[1], notcontract_1=[2], &
    1276             :                                     contract_2=[1], notcontract_2=[2], &
    1277         581 :                                     filter_eps=ri_data%filter_eps, flop=nflop)
    1278         581 :                   ri_data%dbcsr_nflop = ri_data%dbcsr_nflop + nflop
    1279             : 
    1280             :                   CALL get_2c_bump_forces(force, t_2c_work(2), jatom, atom_of_kind, kind_of, -pref, &
    1281         581 :                                           ri_data, qs_env, work_virial)
    1282             : 
    1283       13520 :                   CALL timestop(handle2)
    1284             :                END DO !iatom
    1285             :             END DO !jatom
    1286             :          END DO !b_img
    1287             : 
    1288         126 :          IF (i_loop == 1) THEN
    1289          42 :             CALL update_pattern_to_forces(force_pattern, sparsity_pattern, ngroups, ri_data, qs_env)
    1290             :          END IF
    1291             :       END DO !i_loop
    1292             : 
    1293          42 :       CALL dbt_batched_contract_finalize(t_3c_work_3(1))
    1294          42 :       CALL dbt_batched_contract_finalize(t_3c_work_3(2))
    1295          42 :       CALL dbt_batched_contract_finalize(t_3c_work_3(3))
    1296          42 :       CALL dbt_batched_contract_finalize(t_3c_work_2(1))
    1297          42 :       CALL dbt_batched_contract_finalize(t_3c_work_2(2))
    1298             : 
    1299          42 :       IF (use_virial_prv) THEN
    1300          32 :          DO k_xyz = 1, 3
    1301         104 :             DO j_xyz = 1, 3
    1302         312 :                DO i_xyz = 1, 3
    1303             :                   virial%pv_fock_4c(i_xyz, j_xyz) = virial%pv_fock_4c(i_xyz, j_xyz) &
    1304         288 :                                                     + work_virial(i_xyz, k_xyz)*cell%hmat(j_xyz, k_xyz)
    1305             :                END DO
    1306             :             END DO
    1307             :          END DO
    1308             :       END IF
    1309             : 
    1310             :       !End of subgroup parallelization
    1311          42 :       CALL cp_blacs_env_release(blacs_env_sub)
    1312          42 :       CALL para_env_sub%free()
    1313          42 :       DEALLOCATE (para_env_sub)
    1314             : 
    1315          42 :       CALL para_env%sync()
    1316          42 :       t2 = m_walltime()
    1317          42 :       ri_data%dbcsr_time = ri_data%dbcsr_time + t2 - t1
    1318             : 
    1319             :       !clean-up
    1320          42 :       IF (ASSOCIATED(dbcsr_template)) THEN
    1321          42 :          CALL dbcsr_release(dbcsr_template)
    1322          42 :          DEALLOCATE (dbcsr_template)
    1323             :       END IF
    1324          42 :       CALL dbt_destroy(t_2c_R)
    1325          42 :       CALL dbt_destroy(t_2c_R_split)
    1326          42 :       CALL dbt_destroy(t_2c_work(1))
    1327          42 :       CALL dbt_destroy(t_2c_work(2))
    1328          42 :       CALL dbt_destroy(t_2c_work(3))
    1329          42 :       CALL dbt_destroy(t_2c_work(4))
    1330          42 :       CALL dbt_destroy(t_2c_work(5))
    1331          42 :       CALL dbt_destroy(t_3c_work_2(1))
    1332          42 :       CALL dbt_destroy(t_3c_work_2(2))
    1333          42 :       CALL dbt_destroy(t_3c_work_2(3))
    1334          42 :       CALL dbt_destroy(t_3c_work_3(1))
    1335          42 :       CALL dbt_destroy(t_3c_work_3(2))
    1336          42 :       CALL dbt_destroy(t_3c_work_3(3))
    1337          42 :       CALL dbt_destroy(t_3c_work_3(4))
    1338          42 :       CALL dbt_destroy(t_3c_der_stack(1))
    1339          42 :       CALL dbt_destroy(t_3c_der_stack(2))
    1340          42 :       CALL dbt_destroy(t_3c_der_stack(3))
    1341          42 :       CALL dbt_destroy(t_3c_der_stack(4))
    1342          42 :       CALL dbt_destroy(t_3c_der_stack(5))
    1343          42 :       CALL dbt_destroy(t_3c_der_stack(6))
    1344         168 :       DO i_xyz = 1, 3
    1345         168 :          CALL dbt_destroy(t_2c_der_pot(i_xyz))
    1346             :       END DO
    1347         126 :       DO iatom = 1, natom
    1348          84 :          CALL dbt_destroy(t_2c_inv(iatom))
    1349          84 :          CALL dbt_destroy(t_2c_binv(iatom))
    1350          84 :          CALL dbt_destroy(t_2c_bint(iatom))
    1351          84 :          CALL dbt_destroy(t_2c_metric(iatom))
    1352         378 :          DO i_xyz = 1, 3
    1353         336 :             CALL dbt_destroy(t_2c_der_metric_sub(iatom, i_xyz))
    1354             :          END DO
    1355             :       END DO
    1356         980 :       DO i_img = 1, nimg
    1357         938 :          CALL dbcsr_release(mat_2c_pot(i_img))
    1358        2084 :          DO i_spin = 1, nspins
    1359        1104 :             CALL dbt_destroy(rho_ao_t_sub(i_spin, i_img))
    1360        2042 :             CALL dbt_destroy(t_3c_apc_sub(i_spin, i_img))
    1361             :          END DO
    1362             :       END DO
    1363         168 :       DO i_xyz = 1, 3
    1364        2982 :          DO i_img = 1, nimg
    1365        2814 :             CALL dbt_destroy(t_3c_der_RI_sub(i_img, i_xyz))
    1366        2814 :             CALL dbt_destroy(t_3c_der_AO_sub(i_img, i_xyz))
    1367        2940 :             CALL dbcsr_release(mat_der_pot_sub(i_img, i_xyz))
    1368             :          END DO
    1369             :       END DO
    1370             : 
    1371          42 :       CALL timestop(handle)
    1372             : 
    1373       17604 :    END SUBROUTINE hfx_ri_update_forces_kp
    1374             : 
    1375             : ! **************************************************************************************************
    1376             : !> \brief A routine that applies the RI bump matrix from the left and/or the right, given an input
    1377             : !>        matrix and the central RI atom. We assume atomic block sizes
    1378             : !> \param t_2c_inout 2-center tensor whose blocks are scaled in place and then filtered
    1379             : !> \param atom_i index of the central (reference) atom from which the bump distances are measured
    1380             : !> \param ri_data provides ncell_RI, RI_cell_to_img, kp_bump_rad (r0), kp_RI_range (r1) and filter_eps
    1381             : !> \param qs_env provides natom, the cell, the particle set and the kpoint cell indexing
    1382             : !> \param from_left scale each block by the bump value of its row atom (default .FALSE.)
    1383             : !> \param from_right scale each block by the bump value of its column atom (default .FALSE.)
    1384             : !> \param debump divide by the bump value instead of multiplying by it (default .FALSE.)
    1385             : ! **************************************************************************************************
    1386        1750 :    SUBROUTINE apply_bump(t_2c_inout, atom_i, ri_data, qs_env, from_left, from_right, debump)
    1387             :       TYPE(dbt_type), INTENT(INOUT)                      :: t_2c_inout
    1388             :       INTEGER, INTENT(IN)                                :: atom_i
    1389             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
    1390             :       TYPE(qs_environment_type), POINTER                 :: qs_env
    1391             :       LOGICAL, INTENT(IN), OPTIONAL                      :: from_left, from_right, debump
    1392             : 
    1393             :       INTEGER                                            :: i_img, i_RI, iatom, ind(2), j_img, j_RI, &
    1394             :                                                             jatom, natom, nblks(2), nimg, nkind
    1395        1750 :       INTEGER, DIMENSION(:, :), POINTER                  :: index_to_cell
    1396        1750 :       INTEGER, DIMENSION(:, :, :), POINTER               :: cell_to_index
    1397             :       LOGICAL                                            :: found, my_debump, my_left, my_right
    1398             :       REAL(dp)                                           :: bval, r0, r1, ri(3), rj(3), rref(3), &
    1399             :                                                             scoord(3)
    1400        1750 :       REAL(dp), ALLOCATABLE, DIMENSION(:, :)             :: blk
    1401             :       TYPE(cell_type), POINTER                           :: cell
    1402             :       TYPE(dbt_iterator_type)                            :: iter
    1403             :       TYPE(kpoint_type), POINTER                         :: kpoints
    1404        1750 :       TYPE(particle_type), DIMENSION(:), POINTER         :: particle_set
    1405        1750 :       TYPE(qs_kind_type), DIMENSION(:), POINTER          :: qs_kind_set
    1406             : 
    1407        1750 :       NULLIFY (qs_kind_set, particle_set, kpoints, index_to_cell, cell_to_index, cell)
    1408             : 
    1409             :       CALL get_qs_env(qs_env, natom=natom, nkind=nkind, qs_kind_set=qs_kind_set, cell=cell, &
    1410        1750 :                       kpoints=kpoints, particle_set=particle_set)
    1411        1750 :       CALL get_kpoint_info(kpoints, cell_to_index=cell_to_index, index_to_cell=index_to_cell)
    1412             :       !Optional flags default to .FALSE.; at least one of from_left/from_right must be .TRUE.
    1413        1750 :       my_debump = .FALSE.
    1414        1750 :       IF (PRESENT(debump)) my_debump = debump
    1415             : 
    1416        1750 :       my_left = .FALSE.
    1417        1750 :       IF (PRESENT(from_left)) my_left = from_left
    1418             : 
    1419        1750 :       my_right = .FALSE.
    1420        1750 :       IF (PRESENT(from_right)) my_right = from_right
    1421        1750 :       CPASSERT(my_left .OR. my_right)
    1422             :       !Both tensor dimensions must hold one block per atom and per RI cell
    1423        1750 :       CALL dbt_get_info(t_2c_inout, nblks_total=nblks)
    1424        1750 :       CPASSERT(nblks(1) == ri_data%ncell_RI*natom)
    1425        1750 :       CPASSERT(nblks(2) == ri_data%ncell_RI*natom)
    1426             : 
    1427        1750 :       nimg = ri_data%nimg
    1428             : 
    1429             :       !Loop over the RI cells and atoms, and apply bump accordingly
    1430        1750 :       r1 = ri_data%kp_RI_range
    1431        1750 :       r0 = ri_data%kp_bump_rad
    1432        1750 :       rref = pbc(particle_set(atom_i)%r, cell)
    1433             : 
    1434             : !$OMP PARALLEL DEFAULT(NONE) SHARED(t_2c_inout,natom,ri_data,cell,particle_set,index_to_cell,my_left, &
    1435             : !$OMP                               my_right,r0,r1,rref,my_debump) &
    1436        1750 : !$OMP PRIVATE(iter,ind,blk,found,i_RI,i_img,iatom,j_RI,j_img,jatom,scoord,ri,rj,bval)
    1437             :       CALL dbt_iterator_start(iter, t_2c_inout)
    1438             :       DO WHILE (dbt_iterator_blocks_left(iter))
    1439             :          CALL dbt_iterator_next_block(iter, ind)
    1440             :          CALL dbt_get_block(t_2c_inout, ind, blk, found)
    1441             :          IF (.NOT. found) CYCLE
    1442             :          !Row block index -> RI cell i_RI (natom blocks per cell), its image index, and the atom in it
    1443             :          i_RI = (ind(1) - 1)/natom + 1
    1444             :          i_img = ri_data%RI_cell_to_img(i_RI)
    1445             :          iatom = ind(1) - (i_RI - 1)*natom
    1446             :          !ri: pbc-wrapped position of iatom, shifted by the cell vector index_to_cell(:, i_img)
    1447             :          CALL real_to_scaled(scoord, pbc(particle_set(iatom)%r, cell), cell)
    1448             :          CALL scaled_to_real(ri, scoord(:) + index_to_cell(:, i_img), cell)
    1449             :          !Same decoding for the column block index
    1450             :          j_RI = (ind(2) - 1)/natom + 1
    1451             :          j_img = ri_data%RI_cell_to_img(j_RI)
    1452             :          jatom = ind(2) - (j_RI - 1)*natom
    1453             :          !rj: pbc-wrapped position of jatom, shifted by the cell vector index_to_cell(:, j_img)
    1454             :          CALL real_to_scaled(scoord, pbc(particle_set(jatom)%r, cell), cell)
    1455             :          CALL scaled_to_real(rj, scoord(:) + index_to_cell(:, j_img), cell)
    1456             :          !Bump: multiply by b(|r - rref|). Debump: divide by it, skipping values not above EPSILON
    1457             :          IF (.NOT. my_debump) THEN
    1458             :             IF (my_left) blk(:, :) = blk(:, :)*bump(NORM2(ri - rref), r0, r1)
    1459             :             IF (my_right) blk(:, :) = blk(:, :)*bump(NORM2(rj - rref), r0, r1)
    1460             :          ELSE
    1461             :             !Note: by construction, the bump function is never quite zero, as its range is the same
    1462             :             !      as that of the extended RI basis (but we are safe)
    1463             :             bval = bump(NORM2(ri - rref), r0, r1)
    1464             :             IF (my_left .AND. bval > EPSILON(1.0_dp)) blk(:, :) = blk(:, :)/bval
    1465             :             bval = bump(NORM2(rj - rref), r0, r1)
    1466             :             IF (my_right .AND. bval > EPSILON(1.0_dp)) blk(:, :) = blk(:, :)/bval
    1467             :          END IF
    1468             :          !Write the scaled block back into the tensor
    1469             :          CALL dbt_put_block(t_2c_inout, ind, SHAPE(blk), blk)
    1470             : 
    1471             :          DEALLOCATE (blk)
    1472             :       END DO
    1473             :       CALL dbt_iterator_stop(iter)
    1474             : !$OMP END PARALLEL
    1475        1750 :       CALL dbt_filter(t_2c_inout, ri_data%filter_eps)
    1476             : 
    1477        3500 :    END SUBROUTINE apply_bump
    1478             : 
    1479             : ! **************************************************************************************************
    1480             : !> \brief A routine that calculates the forces due to the derivative of the bump function
    1481             : !> \param force force array; the fock_4c entries of the block atom and of atom_i are updated
    1482             : !> \param t_2c_in 2-center tensor; only the trace of its diagonal blocks (ind(1) == ind(2)) is used
    1483             : !> \param atom_i index of the reference atom from which the bump distances are measured
    1484             : !> \param atom_of_kind for each atom, the second index used to address force(kind)%fock_4c
    1485             : !> \param kind_of for each atom, the kind index used to address the force array
    1486             : !> \param pref prefactor multiplying every force and virial contribution
    1487             : !> \param ri_data provides ncell_RI, RI_cell_to_img, kp_bump_rad (r0) and kp_RI_range (r1)
    1488             : !> \param qs_env provides natom, the cell, the particle set and the kpoint cell indexing
    1489             : !> \param work_virial 3x3 array to which the virial contributions are added (always updated)
    1490             : ! **************************************************************************************************
    1491        2324 :    SUBROUTINE get_2c_bump_forces(force, t_2c_in, atom_i, atom_of_kind, kind_of, pref, ri_data, &
    1492             :                                  qs_env, work_virial)
    1493             :       TYPE(qs_force_type), DIMENSION(:), POINTER         :: force
    1494             :       TYPE(dbt_type), INTENT(INOUT)                      :: t_2c_in
    1495             :       INTEGER, INTENT(IN)                                :: atom_i
    1496             :       INTEGER, DIMENSION(:), INTENT(IN)                  :: atom_of_kind, kind_of
    1497             :       REAL(dp), INTENT(IN)                               :: pref
    1498             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
    1499             :       TYPE(qs_environment_type), POINTER                 :: qs_env
    1500             :       REAL(dp), DIMENSION(3, 3), INTENT(INOUT)           :: work_virial
    1501             : 
    1502             :       INTEGER :: i, i_img, i_RI, i_xyz, iat_of_kind, iatom, ikind, ind(2), j_img, j_RI, j_xyz, &
    1503             :          jat_of_kind, jatom, jkind, natom, nblks(2), nimg, nkind
    1504        2324 :       INTEGER, DIMENSION(:, :), POINTER                  :: index_to_cell
    1505        2324 :       INTEGER, DIMENSION(:, :, :), POINTER               :: cell_to_index
    1506             :       LOGICAL                                            :: found
    1507             :       REAL(dp)                                           :: new_force, r0, r1, ri(3), rj(3), &
    1508             :                                                             rref(3), scoord(3), x
    1509        2324 :       REAL(dp), ALLOCATABLE, DIMENSION(:, :)             :: blk
    1510             :       TYPE(cell_type), POINTER                           :: cell
    1511             :       TYPE(dbt_iterator_type)                            :: iter
    1512             :       TYPE(kpoint_type), POINTER                         :: kpoints
    1513        2324 :       TYPE(particle_type), DIMENSION(:), POINTER         :: particle_set
    1514        2324 :       TYPE(qs_kind_type), DIMENSION(:), POINTER          :: qs_kind_set
    1515             : 
    1516        2324 :       NULLIFY (qs_kind_set, particle_set, kpoints, index_to_cell, cell_to_index, cell)
    1517             : 
    1518             :       CALL get_qs_env(qs_env, natom=natom, nkind=nkind, qs_kind_set=qs_kind_set, cell=cell, &
    1519        2324 :                       kpoints=kpoints, particle_set=particle_set)
    1520        2324 :       CALL get_kpoint_info(kpoints, cell_to_index=cell_to_index, index_to_cell=index_to_cell)
    1521             :       !Both tensor dimensions must hold one block per atom and per RI cell
    1522        2324 :       CALL dbt_get_info(t_2c_in, nblks_total=nblks)
    1523        2324 :       CPASSERT(nblks(1) == ri_data%ncell_RI*natom)
    1524        2324 :       CPASSERT(nblks(2) == ri_data%ncell_RI*natom)
    1525             : 
    1526        2324 :       nimg = ri_data%nimg
    1527             : 
    1528             :       !Bump radii (r0: kp_bump_rad, r1: kp_RI_range) and pbc-wrapped position of the reference atom
    1529        2324 :       r1 = ri_data%kp_RI_range
    1530        2324 :       r0 = ri_data%kp_bump_rad
    1531        2324 :       rref = pbc(particle_set(atom_i)%r, cell)
    1532             :       !Indices addressing the reference atom in the force array
    1533        2324 :       iat_of_kind = atom_of_kind(atom_i)
    1534        2324 :       ikind = kind_of(atom_i)
    1535             : 
    1536             : !$OMP PARALLEL DEFAULT(NONE) SHARED(t_2c_in,natom,ri_data,cell,particle_set,index_to_cell,pref, &
    1537             : !$OMP force,r0,r1,rref,atom_of_kind,kind_of,iat_of_kind,ikind,work_virial) &
    1538             : !$OMP PRIVATE(iter,ind,blk,found,i_RI,i_img,iatom,j_RI,j_img,jatom,scoord,ri,rj,jkind,jat_of_kind, &
    1539        2324 : !$OMP         new_force,i_xyz,i,x,j_xyz)
    1540             :       CALL dbt_iterator_start(iter, t_2c_in)
    1541             :       DO WHILE (dbt_iterator_blocks_left(iter))
    1542             :          CALL dbt_iterator_next_block(iter, ind)
    1543             :          IF (ind(1) .NE. ind(2)) CYCLE !bump matrix is diagonal
    1544             : 
    1545             :          CALL dbt_get_block(t_2c_in, ind, blk, found)
    1546             :          IF (.NOT. found) CYCLE
    1547             : 
    1548             :          !bump is a function of x = SQRT((R - Rref)^2). We refer to R as jatom, and Rref as atom_i
    1549             :          j_RI = (ind(2) - 1)/natom + 1
    1550             :          j_img = ri_data%RI_cell_to_img(j_RI)
    1551             :          jatom = ind(2) - (j_RI - 1)*natom
    1552             :          jat_of_kind = atom_of_kind(jatom)
    1553             :          jkind = kind_of(jatom)
    1554             :          !rj: pbc-wrapped position of jatom, shifted by the cell vector index_to_cell(:, j_img)
    1555             :          CALL real_to_scaled(scoord, pbc(particle_set(jatom)%r, cell), cell)
    1556             :          CALL scaled_to_real(rj, scoord(:) + index_to_cell(:, j_img), cell)
    1557             :          x = NORM2(rj - rref)
    1558             :          IF (x < r0 .OR. x > r1) CYCLE
    1559             :          !Trace of the diagonal block, scaled by pref and by the bump derivative at x
    1560             :          new_force = 0.0_dp
    1561             :          DO i = 1, SIZE(blk, 1)
    1562             :             new_force = new_force + blk(i, i)
    1563             :          END DO
    1564             :          new_force = pref*new_force*dbump(x, r0, r1)
    1565             : 
    1566             :          !x = SQRT((R - Rref)^2), so we multiply by dx/dR and dx/dRref
    1567             :          DO i_xyz = 1, 3
    1568             :             !Force acting on second atom
    1569             : !$OMP ATOMIC
    1570             :             force(jkind)%fock_4c(i_xyz, jat_of_kind) = force(jkind)%fock_4c(i_xyz, jat_of_kind) + &
    1571             :                                                        new_force*(rj(i_xyz) - rref(i_xyz))/x
    1572             : 
    1573             :             !virial acting on second atom
    1574             :             CALL real_to_scaled(scoord, rj, cell)
    1575             :             DO j_xyz = 1, 3
    1576             : !$OMP ATOMIC
    1577             :                work_virial(i_xyz, j_xyz) = work_virial(i_xyz, j_xyz) &
    1578             :                                            + new_force*scoord(j_xyz)*(rj(i_xyz) - rref(i_xyz))/x
    1579             :             END DO
    1580             : 
    1581             :             !Force acting on reference atom, defining the RI basis
    1582             : !$OMP ATOMIC
    1583             :             force(ikind)%fock_4c(i_xyz, iat_of_kind) = force(ikind)%fock_4c(i_xyz, iat_of_kind) - &
    1584             :                                                        new_force*(rj(i_xyz) - rref(i_xyz))/x
    1585             : 
    1586             :             !virial of ref atom
    1587             :             CALL real_to_scaled(scoord, rref, cell)
    1588             :             DO j_xyz = 1, 3
    1589             : !$OMP ATOMIC
    1590             :                work_virial(i_xyz, j_xyz) = work_virial(i_xyz, j_xyz) &
    1591             :                                            - new_force*scoord(j_xyz)*(rj(i_xyz) - rref(i_xyz))/x
    1592             :             END DO
    1593             :          END DO !i_xyz
    1594             : 
    1595             :          DEALLOCATE (blk)
    1596             :       END DO
    1597             :       CALL dbt_iterator_stop(iter)
    1598             : !$OMP END PARALLEL
    1599             : 
    1600        4648 :    END SUBROUTINE get_2c_bump_forces
    1601             : 
    1602             : ! **************************************************************************************************
    1603             : !> \brief The bump function as defined by Juerg: 1 for x <= r0, 0 for x >= r1, polynomial in between
    1604             : !> \param x argument of the bump function (callers in this file pass a distance)
    1605             : !> \param r0 radius up to which the function is exactly 1
    1606             : !> \param r1 radius from which the function is exactly 0
    1607             : !> \return the value of the bump function at x
    1608             : ! **************************************************************************************************
    1609       23891 :    FUNCTION bump(x, r0, r1) RESULT(b)
    1610             :       REAL(dp), INTENT(IN)                               :: x, r0, r1
    1611             :       REAL(dp)                                           :: b
    1612             : 
    1613             :       REAL(dp)                                           :: r
    1614             : 
    1615             :       !Head-Gordon (alternative form, commented out)
    1616             :       !b = 1.0_dp/(1.0_dp+EXP((r1-r0)/(r1-x)-(r1-r0)/(x-r0)))
    1617             :       !Juerg: quintic polynomial in the reduced coordinate r, overridden outside of ]r0, r1[
    1618       23891 :       r = (x - r0)/(r1 - r0)
    1619       23891 :       b = -6.0_dp*r**5 + 15.0_dp*r**4 - 10.0_dp*r**3 + 1.0_dp
    1620       23891 :       IF (x .GE. r1) b = 0.0_dp
    1621       23891 :       IF (x .LE. r0) b = 1.0_dp
    1622             : 
    1623       23891 :    END FUNCTION bump
    1624             : 
    1625             : ! **************************************************************************************************
    1626             : !> \brief The derivative of the bump function
    1627             : !> \param x argument at which the derivative is evaluated
    1628             : !> \param r0 lower radius; the derivative is set to 0 for x <= r0
    1629             : !> \param r1 upper radius; the derivative is set to 0 for x >= r1
    1630             : !> \return the derivative with respect to x of the polynomial used in bump, or 0 outside ]r0, r1[
    1631             : ! **************************************************************************************************
    1632         509 :    FUNCTION dbump(x, r0, r1) RESULT(b)
    1633             :       REAL(dp), INTENT(IN)                               :: x, r0, r1
    1634             :       REAL(dp)                                           :: b
    1635             : 
    1636             :       REAL(dp)                                           :: r
    1637             :       !Chain rule: d/dx = (d/dr)/(r1 - r0), with r the reduced coordinate
    1638         509 :       r = (x - r0)/(r1 - r0)
    1639         509 :       b = (-30.0_dp*r**4 + 60.0_dp*r**3 - 30.0_dp*r**2)/(r1 - r0)
    1640         509 :       IF (x .GE. r1) b = 0.0_dp
    1641         509 :       IF (x .LE. r0) b = 0.0_dp
    1642             : 
    1643         509 :    END FUNCTION dbump
    1644             : 
    1645             : ! **************************************************************************************************
    1646             : !> \brief return the cell index a+c corresponding to given cell index i and b, with i = a+c-b
    1647             : !> \param i_index index of cell i, as a column of index_to_cell
    1648             : !> \param b_index index of cell b, as a column of index_to_cell
    1649             : !> \param qs_env provides the kpoint info holding the cell_to_index and index_to_cell maps
    1650             : !> \return index of cell a+c, or 0 if that cell lies outside the bounds of cell_to_index
    1651             : ! **************************************************************************************************
    1652      158738 :    FUNCTION get_apc_index_from_ib(i_index, b_index, qs_env) RESULT(apc_index)
    1653             :       INTEGER, INTENT(IN)                                :: i_index, b_index
    1654             :       TYPE(qs_environment_type), POINTER                 :: qs_env
    1655             :       INTEGER                                            :: apc_index
    1656             : 
    1657             :       INTEGER, DIMENSION(3)                              :: cell_apc
    1658      158738 :       INTEGER, DIMENSION(:, :), POINTER                  :: index_to_cell
    1659      158738 :       INTEGER, DIMENSION(:, :, :), POINTER               :: cell_to_index
    1660             :       TYPE(kpoint_type), POINTER                         :: kpoints
    1661             : 
    1662      158738 :       CALL get_qs_env(qs_env, kpoints=kpoints)
    1663      158738 :       CALL get_kpoint_info(kpoints, cell_to_index=cell_to_index, index_to_cell=index_to_cell)
    1664             : 
    1665             :       !i = a+c-b => a+c = i+b
    1666      634952 :       cell_apc(:) = index_to_cell(:, i_index) + index_to_cell(:, b_index)
    1667             :       !Return 0 if the a+c cell lies outside the bounds of cell_to_index
    1668     1087392 :       IF (ANY([cell_apc(1), cell_apc(2), cell_apc(3)] < LBOUND(cell_to_index)) .OR. &
    1669             :           ANY([cell_apc(1), cell_apc(2), cell_apc(3)] > UBOUND(cell_to_index))) THEN
    1670             : 
    1671             :          apc_index = 0
    1672             :       ELSE
    1673      138666 :          apc_index = cell_to_index(cell_apc(1), cell_apc(2), cell_apc(3))
    1674             :       END IF
    1675             : 
    1676      158738 :    END FUNCTION get_apc_index_from_ib
    1677             : 
    1678             : ! **************************************************************************************************
    1679             : !> \brief return the cell index i corresponding to the sum of cell_a and cell_c
    1680             : !> \param a_index index of cell a, as a column of index_to_cell
    1681             : !> \param c_index index of cell c, as a column of index_to_cell
    1682             : !> \param qs_env provides the kpoint info holding the cell_to_index and index_to_cell maps
    1683             : !> \return index of cell a+c, or 0 if that cell lies outside the bounds of cell_to_index
    1684             : ! **************************************************************************************************
    1685           0 :    FUNCTION get_apc_index(a_index, c_index, qs_env) RESULT(i_index)
    1686             :       INTEGER, INTENT(IN)                                :: a_index, c_index
    1687             :       TYPE(qs_environment_type), POINTER                 :: qs_env
    1688             :       INTEGER                                            :: i_index
    1689             : 
    1690             :       INTEGER, DIMENSION(3)                              :: cell_i
    1691           0 :       INTEGER, DIMENSION(:, :), POINTER                  :: index_to_cell
    1692           0 :       INTEGER, DIMENSION(:, :, :), POINTER               :: cell_to_index
    1693             :       TYPE(kpoint_type), POINTER                         :: kpoints
    1694             : 
    1695           0 :       CALL get_qs_env(qs_env, kpoints=kpoints)
    1696           0 :       CALL get_kpoint_info(kpoints, cell_to_index=cell_to_index, index_to_cell=index_to_cell)
    1697             :       !Add the two cell vectors component-wise
    1698           0 :       cell_i(:) = index_to_cell(:, a_index) + index_to_cell(:, c_index)
    1699             :       !Return 0 if the summed cell lies outside the bounds of cell_to_index
    1700           0 :       IF (ANY([cell_i(1), cell_i(2), cell_i(3)] < LBOUND(cell_to_index)) .OR. &
    1701             :           ANY([cell_i(1), cell_i(2), cell_i(3)] > UBOUND(cell_to_index))) THEN
    1702             : 
    1703             :          i_index = 0
    1704             :       ELSE
    1705           0 :          i_index = cell_to_index(cell_i(1), cell_i(2), cell_i(3))
    1706             :       END IF
    1707             : 
    1708           0 :    END FUNCTION get_apc_index
    1709             : 
    1710             : ! **************************************************************************************************
    1711             : !> \brief return the cell index i corresponding to cell_a + cell_c - cell_b
    1712             : !> \param apc_index index of the cell a+c, used as a column of index_to_cell
    1713             : !> \param b_index index of cell b, used as a column of index_to_cell
    1714             : !> \param qs_env environment from which the kpoints cell_to_index/index_to_cell maps are taken
    1715             : !> \return index of cell a+c-b, or 0 if that cell lies outside the bounds of cell_to_index
    1716             : ! **************************************************************************************************
    1717      526636 :    FUNCTION get_i_index(apc_index, b_index, qs_env) RESULT(i_index)
    1718             :       INTEGER, INTENT(IN)                                :: apc_index, b_index
    1719             :       TYPE(qs_environment_type), POINTER                 :: qs_env
    1720             :       INTEGER                                            :: i_index
    1721             : 
    1722             :       INTEGER, DIMENSION(3)                              :: cell_i
    1723      526636 :       INTEGER, DIMENSION(:, :), POINTER                  :: index_to_cell
    1724      526636 :       INTEGER, DIMENSION(:, :, :), POINTER               :: cell_to_index
    1725             :       TYPE(kpoint_type), POINTER                         :: kpoints
    1726             : 
    1727      526636 :       CALL get_qs_env(qs_env, kpoints=kpoints)
    1728      526636 :       CALL get_kpoint_info(kpoints, cell_to_index=cell_to_index, index_to_cell=index_to_cell)
    1729             : 
    1730     2106544 :       cell_i(:) = index_to_cell(:, apc_index) - index_to_cell(:, b_index)
    1731             : 
    1732     3597572 :       IF (ANY([cell_i(1), cell_i(2), cell_i(3)] < LBOUND(cell_to_index)) .OR. &
    1733             :           ANY([cell_i(1), cell_i(2), cell_i(3)] > UBOUND(cell_to_index))) THEN
    1734             : 
    1735             :          i_index = 0
    1736             :       ELSE
    1737      450040 :          i_index = cell_to_index(cell_i(1), cell_i(2), cell_i(3))
    1738             :       END IF
    1739             : 
    1740      526636 :    END FUNCTION get_i_index
    1741             : 
    1742             : ! **************************************************************************************************
    1743             : !> \brief A routine that returns all allowed a,c pairs such that a+c images corresponds to the value
    1744             : !>        of the apc_index input. Takes into account that image a corresponds to 3c integrals, which
    1745             : !>        are ordered in their own way
    1746             : !> \param ac_pairs on exit, row a holds (a, c_index); c_index is the result of get_i_index, can be 0
    1747             : !> \param apc_index index of the cell a+c
    1748             : !> \param ri_data provides idx_to_img, mapping the loop index a to the actual image index
    1749             : !> \param qs_env environment passed on to get_i_index
    1750             : ! **************************************************************************************************
    1751       16412 :    SUBROUTINE get_ac_pairs(ac_pairs, apc_index, ri_data, qs_env)
    1752             :       INTEGER, DIMENSION(:, :), INTENT(INOUT)            :: ac_pairs
    1753             :       INTEGER, INTENT(IN)                                :: apc_index
    1754             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
    1755             :       TYPE(qs_environment_type), POINTER                 :: qs_env
    1756             : 
    1757             :       INTEGER                                            :: a_index, actual_img, c_index, nimg
    1758             : 
    1759       16412 :       nimg = SIZE(ac_pairs, 1)
    1760             : 
    1761     1102508 :       ac_pairs(:, :) = 0
    1762             : !$OMP PARALLEL DO DEFAULT(NONE) SHARED(ac_pairs,nimg,ri_data,qs_env,apc_index) &
    1763       16412 : !$OMP PRIVATE(a_index,actual_img,c_index)
    1764             :       DO a_index = 1, nimg
    1765             :          actual_img = ri_data%idx_to_img(a_index)
    1766             :          !c = (a+c) - a, using the actual image index of a
    1767             :          c_index = get_i_index(apc_index, actual_img, qs_env)
    1768             :          ac_pairs(a_index, 1) = a_index
    1769             :          ac_pairs(a_index, 2) = c_index
    1770             :       END DO
    1771             : !$OMP END PARALLEL DO
    1772             : 
    1773       16412 :    END SUBROUTINE get_ac_pairs
    1774             : 
    1775             : ! **************************************************************************************************
    1776             : !> \brief A routine that returns all allowed i,a+c pairs such that, for the given value of b, we have
    1777             : !>        i = a+c-b. Takes into account that image i corresponds to the 3c ints, which are ordered in
    1778             : !>        their own way
    1779             : !> \param iapc_pairs on exit, row i holds (i, apc_index); rows with apc_index == 0 are left at zero
    1780             : !> \param b_index index of cell b
    1781             : !> \param ri_data provides idx_to_img, mapping the loop index i to the actual image index
    1782             : !> \param qs_env environment passed on to get_apc_index_from_ib
    1783             : !> \param actual_i_img if present: zeroed, then set to idx_to_img(i) for every row i that is kept
    1784             : ! **************************************************************************************************
    1785        4708 :    SUBROUTINE get_iapc_pairs(iapc_pairs, b_index, ri_data, qs_env, actual_i_img)
    1786             :       INTEGER, DIMENSION(:, :), INTENT(INOUT)            :: iapc_pairs
    1787             :       INTEGER, INTENT(IN)                                :: b_index
    1788             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
    1789             :       TYPE(qs_environment_type), POINTER                 :: qs_env
    1790             :       INTEGER, DIMENSION(:), INTENT(INOUT), OPTIONAL     :: actual_i_img
    1791             : 
    1792             :       INTEGER                                            :: actual_img, apc_index, i_index, nimg
    1793             : 
    1794        4708 :       nimg = SIZE(iapc_pairs, 1)
    1795       40802 :       IF (PRESENT(actual_i_img)) actual_i_img(:) = 0
    1796             : 
    1797      331600 :       iapc_pairs(:, :) = 0
    1798             : !$OMP PARALLEL DO DEFAULT(NONE) SHARED(iapc_pairs,nimg,ri_data,qs_env,b_index,actual_i_img) &
    1799        4708 : !$OMP PRIVATE(i_index,actual_img,apc_index)
    1800             :       DO i_index = 1, nimg
    1801             :          actual_img = ri_data%idx_to_img(i_index)
    1802             :          apc_index = get_apc_index_from_ib(actual_img, b_index, qs_env)
    1803             :          IF (apc_index == 0) CYCLE
    1804             :          iapc_pairs(i_index, 1) = i_index
    1805             :          iapc_pairs(i_index, 2) = apc_index
    1806             :          IF (PRESENT(actual_i_img)) actual_i_img(i_index) = actual_img
    1807             :       END DO
    1808             : 
    1809        4708 :    END SUBROUTINE get_iapc_pairs
    1810             : 
    1811             : ! **************************************************************************************************
    1812             : !> \brief A function that, given a cell index a, returns the index corresponding to -a, and zero
    1813             : !>        if out of bounds
    1814             : !> \param a_index index of cell a, used as a column of index_to_cell
    1815             : !> \param qs_env environment from which the kpoints cell_to_index/index_to_cell maps are taken
    1816             : !> \return index of cell -a, or 0 if -a lies outside the bounds of cell_to_index
    1817             : ! **************************************************************************************************
    1818       66222 :    FUNCTION get_opp_index(a_index, qs_env) RESULT(opp_index)
    1819             :       INTEGER, INTENT(IN)                                :: a_index
    1820             :       TYPE(qs_environment_type), POINTER                 :: qs_env
    1821             :       INTEGER                                            :: opp_index
    1822             : 
    1823             :       INTEGER, DIMENSION(3)                              :: opp_cell
    1824       66222 :       INTEGER, DIMENSION(:, :), POINTER                  :: index_to_cell
    1825       66222 :       INTEGER, DIMENSION(:, :, :), POINTER               :: cell_to_index
    1826             :       TYPE(kpoint_type), POINTER                         :: kpoints
    1827             : 
    1828       66222 :       NULLIFY (kpoints, cell_to_index, index_to_cell)
    1829             : 
    1830       66222 :       CALL get_qs_env(qs_env, kpoints=kpoints)
    1831       66222 :       CALL get_kpoint_info(kpoints, cell_to_index=cell_to_index, index_to_cell=index_to_cell)
    1832             : 
    1833      264888 :       opp_cell(:) = -index_to_cell(:, a_index)
    1834             : 
    1835      463554 :       IF (ANY([opp_cell(1), opp_cell(2), opp_cell(3)] < LBOUND(cell_to_index)) .OR. &
    1836             :           ANY([opp_cell(1), opp_cell(2), opp_cell(3)] > UBOUND(cell_to_index))) THEN
    1837             : 
    1838             :          opp_index = 0
    1839             :       ELSE
    1840       66222 :          opp_index = cell_to_index(opp_cell(1), opp_cell(2), opp_cell(3))
    1841             :       END IF
    1842             : 
    1843       66222 :    END FUNCTION get_opp_index
    1844             : 
    1845             : ! **************************************************************************************************
    1846             : !> \brief A routine that returns the actual non-symmetric density matrix for each image, by Fourier
    1847             : !>        transforming the kpoint density matrix
    1848             : !> \param rho_ao_t per spin and image: scaled by scale_prev_p, then the desymmetrized P is added
    1849             : !> \param rho_ao density matrices indexed (spin, image) from which the blocks are read
    1850             : !> \param scale_prev_p factor applied to the content of rho_ao_t before the new data is added
    1851             : !> \param ri_data provides nimg and filter_eps
    1852             : !> \param qs_env environment providing kpoints, matrix_ks_kp, dft_control and, with ADMM, admm_env
    1853             : ! **************************************************************************************************
    1854         470 :    SUBROUTINE get_pmat_images(rho_ao_t, rho_ao, scale_prev_p, ri_data, qs_env)
    1855             :       TYPE(dbt_type), DIMENSION(:, :), INTENT(INOUT)     :: rho_ao_t
    1856             :       TYPE(dbcsr_p_type), DIMENSION(:, :), INTENT(INOUT) :: rho_ao
    1857             :       REAL(dp), INTENT(IN)                               :: scale_prev_p
    1858             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
    1859             :       TYPE(qs_environment_type), POINTER                 :: qs_env
    1860             : 
    1861             :       INTEGER                                            :: cell_j(3), i_img, i_spin, iatom, icol, &
    1862             :                                                             irow, j_img, jatom, mi_img, mj_img, &
    1863             :                                                             nimg, nspins
    1864         470 :       INTEGER, DIMENSION(:, :, :), POINTER               :: cell_to_index
    1865             :       LOGICAL                                            :: found
    1866             :       REAL(dp)                                           :: fac
    1867         470 :       REAL(dp), DIMENSION(:, :), POINTER                 :: pblock, pblock_desymm
    1868         470 :       TYPE(dbcsr_p_type), DIMENSION(:, :), POINTER       :: matrix_ks, rho_desymm
    1869        4230 :       TYPE(dbt_type)                                     :: tmp
    1870             :       TYPE(dft_control_type), POINTER                    :: dft_control
    1871             :       TYPE(kpoint_type), POINTER                         :: kpoints
    1872             :       TYPE(neighbor_list_iterator_p_type), &
    1873         470 :          DIMENSION(:), POINTER                           :: nl_iterator
    1874             :       TYPE(neighbor_list_set_p_type), DIMENSION(:), &
    1875         470 :          POINTER                                         :: sab_nl, sab_nl_nosym
    1876             :       TYPE(qs_scf_env_type), POINTER                     :: scf_env
    1877             : 
    1878         470 :       NULLIFY (rho_desymm, kpoints, sab_nl_nosym, scf_env, matrix_ks, dft_control, &
    1879         470 :                sab_nl, nl_iterator, cell_to_index, pblock, pblock_desymm)
    1880             : 
    1881         470 :       CALL get_qs_env(qs_env, kpoints=kpoints, scf_env=scf_env, matrix_ks_kp=matrix_ks, dft_control=dft_control)
    1882         470 :       CALL get_kpoint_info(kpoints, sab_nl_nosym=sab_nl_nosym, cell_to_index=cell_to_index, sab_nl=sab_nl)
    1883             : 
    1884         470 :       IF (dft_control%do_admm) THEN
    1885         252 :          CALL get_admm_env(qs_env%admm_env, matrix_ks_aux_fit_kp=matrix_ks)
    1886             :       END IF
    1887             : 
    1888         470 :       nspins = SIZE(matrix_ks, 1)
    1889         470 :       nimg = ri_data%nimg
    1890             : 
    1891       28266 :       ALLOCATE (rho_desymm(nspins, nimg))
    1892       12008 :       DO i_img = 1, nimg
    1893       26856 :          DO i_spin = 1, nspins
    1894       14848 :             ALLOCATE (rho_desymm(i_spin, i_img)%matrix)
    1895             :             CALL dbcsr_create(rho_desymm(i_spin, i_img)%matrix, template=matrix_ks(i_spin, i_img)%matrix, &
    1896       14848 :                               matrix_type=dbcsr_type_no_symmetry)
    1897       26386 :             CALL cp_dbcsr_alloc_block_from_nbl(rho_desymm(i_spin, i_img)%matrix, sab_nl_nosym)
    1898             :          END DO
    1899             :       END DO
    1900         470 :       CALL dbt_create(rho_desymm(1, 1)%matrix, tmp)
    1901             : 
    1902             :       !We transform the symmetric typed (but not actually symmetric: P_ab^i = P_ba^-i) real-spaced
    1903             :       !density matrix into proper non-symmetric ones (using the same nl for consistency)
    1904         470 :       CALL neighbor_list_iterator_create(nl_iterator, sab_nl)
    1905       20878 :       DO WHILE (neighbor_list_iterate(nl_iterator) == 0)
    1906       20408 :          CALL get_iterator_info(nl_iterator, iatom=iatom, jatom=jatom, cell=cell_j)
    1907       20408 :          j_img = cell_to_index(cell_j(1), cell_j(2), cell_j(3))
    1908       20408 :          IF (j_img > nimg .OR. j_img < 1) CYCLE
    1909             : 
    1910       13930 :          fac = 1.0_dp
    1911       13930 :          IF (iatom == jatom) fac = 0.5_dp
    1912       13930 :          mj_img = get_opp_index(j_img, qs_env)
    1913             :          !if no opposite image, then no sum of P^j + P^-j => need full diag
    1914       13930 :          IF (mj_img == 0) fac = 1.0_dp
    1915             : 
    1916       13930 :          irow = iatom
    1917       13930 :          icol = jatom
    1918       13930 :          IF (iatom > jatom) THEN
    1919             :             !because symmetric nl. Value for atom pair i,j is actually stored in j,i if i > j
    1920        4374 :             irow = jatom
    1921        4374 :             icol = iatom
    1922             :          END IF
    1923             : 
    1924       32618 :          DO i_spin = 1, nspins
    1925       18218 :             CALL dbcsr_get_block_p(rho_ao(i_spin, j_img)%matrix, irow, icol, pblock, found)
    1926       18218 :             IF (.NOT. found) CYCLE
    1927             : 
    1928             :             !distribution of symm and non-symm matrix match in that way
    1929       18218 :             CALL dbcsr_get_block_p(rho_desymm(i_spin, j_img)%matrix, iatom, jatom, pblock_desymm, found)
    1930       18218 :             IF (.NOT. found) CYCLE
    1931             : 
    1932       75062 :             IF (iatom > jatom) THEN
    1933      720396 :                pblock_desymm(:, :) = fac*TRANSPOSE(pblock(:, :))
    1934             :             ELSE
    1935     1709076 :                pblock_desymm(:, :) = fac*pblock(:, :)
    1936             :             END IF
    1937             :          END DO
    1938             :       END DO
    1939         470 :       CALL neighbor_list_iterator_release(nl_iterator)
    1940             : 
    1941       12008 :       DO i_img = 1, nimg
    1942       26856 :          DO i_spin = 1, nspins
    1943       14848 :             CALL dbt_scale(rho_ao_t(i_spin, i_img), scale_prev_p)
    1944             : 
    1945       14848 :             CALL dbt_copy_matrix_to_tensor(rho_desymm(i_spin, i_img)%matrix, tmp)
    1946       14848 :             CALL dbt_copy(tmp, rho_ao_t(i_spin, i_img), summation=.TRUE., move_data=.TRUE.)
    1947             : 
    1948             :             !symmetrize by adding the transpose of the opposite image
    1949       14848 :             mi_img = get_opp_index(i_img, qs_env)
    1950       14848 :             IF (mi_img > 0 .AND. mi_img .LE. nimg) THEN
    1951       13368 :                CALL dbt_copy_matrix_to_tensor(rho_desymm(i_spin, mi_img)%matrix, tmp)
    1952       13368 :                CALL dbt_copy(tmp, rho_ao_t(i_spin, i_img), order=[2, 1], summation=.TRUE., move_data=.TRUE.)
    1953             :             END IF
    1954       26386 :             CALL dbt_filter(rho_ao_t(i_spin, i_img), ri_data%filter_eps)
    1955             :          END DO
    1956             :       END DO
    1957             : 
    1958       12008 :       DO i_img = 1, nimg
    1959       26856 :          DO i_spin = 1, nspins
    1960       14848 :             CALL dbcsr_release(rho_desymm(i_spin, i_img)%matrix)
    1961       26386 :             DEALLOCATE (rho_desymm(i_spin, i_img)%matrix)
    1962             :          END DO
    1963             :       END DO
    1964             : 
    1965         470 :       CALL dbt_destroy(tmp)
    1966         470 :       DEALLOCATE (rho_desymm)
    1967             : 
    1968         940 :    END SUBROUTINE get_pmat_images
    1969             : 
    1970             : ! **************************************************************************************************
    1971             : !> \brief A routine that, given a cell index b and atom indices ij, returns a 2c tensor with the HFX
    1972             : !>        potential (P_i^0|Q_j^b), within the extended RI basis
    1973             : !> \param t_2c_pot ...
    1974             : !> \param mat_orig ...
    1975             : !> \param atom_i ...
    1976             : !> \param atom_j ...
    1977             : !> \param img_b ...
    1978             : !> \param ri_data ...
    1979             : !> \param qs_env ...
    1980             : !> \param do_inverse ...
    1981             : !> \param para_env_ext ...
    1982             : !> \param blacs_env_ext ...
    1983             : !> \param dbcsr_template ...
    1984             : !> \param off_diagonal ...
    1985             : !> \param skip_inverse ...
    1986             : ! **************************************************************************************************
    1987        7704 :    SUBROUTINE get_ext_2c_int(t_2c_pot, mat_orig, atom_i, atom_j, img_b, ri_data, qs_env, do_inverse, &
    1988             :                              para_env_ext, blacs_env_ext, dbcsr_template, off_diagonal, skip_inverse)
    1989             :       TYPE(dbt_type), INTENT(INOUT)                      :: t_2c_pot
    1990             :       TYPE(dbcsr_type), DIMENSION(:), INTENT(INOUT)      :: mat_orig
    1991             :       INTEGER, INTENT(IN)                                :: atom_i, atom_j, img_b
    1992             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
    1993             :       TYPE(qs_environment_type), POINTER                 :: qs_env
    1994             :       LOGICAL, INTENT(IN), OPTIONAL                      :: do_inverse
    1995             :       TYPE(mp_para_env_type), OPTIONAL, POINTER          :: para_env_ext
    1996             :       TYPE(cp_blacs_env_type), OPTIONAL, POINTER         :: blacs_env_ext
    1997             :       TYPE(dbcsr_type), OPTIONAL, POINTER                :: dbcsr_template
    1998             :       LOGICAL, INTENT(IN), OPTIONAL                      :: off_diagonal, skip_inverse
    1999             : 
    2000             :       CHARACTER(LEN=*), PARAMETER                        :: routineN = 'get_ext_2c_int'
    2001             : 
    2002             :       INTEGER :: blk, group, handle, handle2, i_img, i_RI, iatom, iblk, ikind, img_tot, j_img, &
    2003             :          j_RI, jatom, jblk, jkind, n_dependent, natom, nblks_RI, nimg, nkind
    2004        7704 :       INTEGER, ALLOCATABLE, DIMENSION(:)                 :: dist1, dist2
    2005        7704 :       INTEGER, ALLOCATABLE, DIMENSION(:, :)              :: present_atoms_i, present_atoms_j
    2006             :       INTEGER, DIMENSION(3)                              :: cell_b, cell_i, cell_j, cell_tot
    2007        7704 :       INTEGER, DIMENSION(:), POINTER                     :: col_dist, col_dist_ext, ri_blk_size_ext, &
    2008        7704 :                                                             row_dist, row_dist_ext
    2009        7704 :       INTEGER, DIMENSION(:, :), POINTER                  :: index_to_cell, pgrid
    2010        7704 :       INTEGER, DIMENSION(:, :, :), POINTER               :: cell_to_index
    2011             :       LOGICAL                                            :: do_inverse_prv, found, my_offd, &
    2012             :                                                             skip_inverse_prv, use_template
    2013             :       REAL(dp)                                           :: bfac, dij, r0, r1, threshold
    2014             :       REAL(dp), DIMENSION(3)                             :: ri, rij, rj, rref, scoord
    2015        7704 :       REAL(dp), DIMENSION(:, :), POINTER                 :: pblock
    2016             :       TYPE(cell_type), POINTER                           :: cell
    2017             :       TYPE(cp_blacs_env_type), POINTER                   :: blacs_env
    2018             :       TYPE(dbcsr_distribution_type)                      :: dbcsr_dist, dbcsr_dist_ext
    2019             :       TYPE(dbcsr_iterator_type)                          :: dbcsr_iter
    2020             :       TYPE(dbcsr_type)                                   :: work, work_tight, work_tight_inv
    2021       53928 :       TYPE(dbt_type)                                     :: t_2c_tmp
    2022             :       TYPE(distribution_2d_type), POINTER                :: dist_2d
    2023             :       TYPE(gto_basis_set_p_type), ALLOCATABLE, &
    2024        7704 :          DIMENSION(:), TARGET                            :: basis_set_RI
    2025             :       TYPE(kpoint_type), POINTER                         :: kpoints
    2026             :       TYPE(mp_para_env_type), POINTER                    :: para_env
    2027             :       TYPE(neighbor_list_iterator_p_type), &
    2028        7704 :          DIMENSION(:), POINTER                           :: nl_iterator
    2029             :       TYPE(neighbor_list_set_p_type), DIMENSION(:), &
    2030        7704 :          POINTER                                         :: nl_2c
    2031        7704 :       TYPE(particle_type), DIMENSION(:), POINTER         :: particle_set
    2032        7704 :       TYPE(qs_kind_type), DIMENSION(:), POINTER          :: qs_kind_set
    2033             : 
    2034        7704 :       NULLIFY (qs_kind_set, nl_2c, nl_iterator, cell, kpoints, cell_to_index, index_to_cell, dist_2d, &
    2035        7704 :                para_env, pblock, blacs_env, particle_set, col_dist, row_dist, pgrid, &
    2036        7704 :                col_dist_ext, row_dist_ext)
    2037             : 
    2038        7704 :       CALL timeset(routineN, handle)
    2039             : 
    2040             :       !Idea: run over the neighbor list once for i and once for j, and record in which cell the MIC
    2041             :       !      atoms are. Then loop over the atoms and only take the pairs the we need
    2042             : 
    2043             :       CALL get_qs_env(qs_env, natom=natom, nkind=nkind, qs_kind_set=qs_kind_set, cell=cell, &
    2044        7704 :                       kpoints=kpoints, para_env=para_env, blacs_env=blacs_env, particle_set=particle_set)
    2045        7704 :       CALL get_kpoint_info(kpoints, cell_to_index=cell_to_index, index_to_cell=index_to_cell)
    2046             : 
    2047        7704 :       do_inverse_prv = .FALSE.
    2048        7704 :       IF (PRESENT(do_inverse)) do_inverse_prv = do_inverse
    2049         280 :       IF (do_inverse_prv) THEN
    2050         280 :          CPASSERT(atom_i == atom_j)
    2051             :       END IF
    2052             : 
    2053        7704 :       skip_inverse_prv = .FALSE.
    2054        7704 :       IF (PRESENT(skip_inverse)) skip_inverse_prv = skip_inverse
    2055             : 
    2056        7704 :       my_offd = .FALSE.
    2057        7704 :       IF (PRESENT(off_diagonal)) my_offd = off_diagonal
    2058             : 
    2059        7704 :       IF (PRESENT(para_env_ext)) para_env => para_env_ext
    2060        7704 :       IF (PRESENT(blacs_env_ext)) blacs_env => blacs_env_ext
    2061             : 
    2062        7704 :       nimg = SIZE(mat_orig)
    2063             : 
    2064        7704 :       CALL timeset(routineN//"_nl_iter", handle2)
    2065             : 
    2066             :       !create our own dist_2d in the subgroup
    2067       30816 :       ALLOCATE (dist1(natom), dist2(natom))
    2068       23112 :       DO iatom = 1, natom
    2069       15408 :          dist1(iatom) = MOD(iatom, blacs_env%num_pe(1))
    2070       23112 :          dist2(iatom) = MOD(iatom, blacs_env%num_pe(2))
    2071             :       END DO
    2072        7704 :       CALL distribution_2d_create(dist_2d, dist1, dist2, nkind, particle_set, blacs_env_ext=blacs_env)
    2073             : 
    2074       34860 :       ALLOCATE (basis_set_RI(nkind))
    2075        7704 :       CALL basis_set_list_setup(basis_set_RI, ri_data%ri_basis_type, qs_kind_set)
    2076             : 
    2077             :       CALL build_2c_neighbor_lists(nl_2c, basis_set_RI, basis_set_RI, ri_data%ri_metric, &
    2078        7704 :                                    "HFX_2c_nl_RI", qs_env, sym_ij=.FALSE., dist_2d=dist_2d)
    2079             : 
    2080       46224 :       ALLOCATE (present_atoms_i(natom, nimg), present_atoms_j(natom, nimg))
    2081      748254 :       present_atoms_i = 0
    2082      748254 :       present_atoms_j = 0
    2083             : 
    2084        7704 :       CALL neighbor_list_iterator_create(nl_iterator, nl_2c)
    2085      308400 :       DO WHILE (neighbor_list_iterate(nl_iterator) == 0)
    2086             :          CALL get_iterator_info(nl_iterator, iatom=iatom, jatom=jatom, r=rij, cell=cell_j, &
    2087      300696 :                                 ikind=ikind, jkind=jkind)
    2088             : 
    2089     1202784 :          dij = NORM2(rij)
    2090             : 
    2091      300696 :          j_img = cell_to_index(cell_j(1), cell_j(2), cell_j(3))
    2092      300696 :          IF (j_img > nimg .OR. j_img < 1) CYCLE
    2093             : 
    2094      298349 :          IF (iatom == atom_i .AND. dij .LE. ri_data%kp_RI_range) present_atoms_i(jatom, j_img) = 1
    2095      306053 :          IF (iatom == atom_j .AND. dij .LE. ri_data%kp_RI_range) present_atoms_j(jatom, j_img) = 1
    2096             :       END DO
    2097        7704 :       CALL neighbor_list_iterator_release(nl_iterator)
    2098        7704 :       CALL release_neighbor_list_sets(nl_2c)
    2099        7704 :       CALL distribution_2d_release(dist_2d)
    2100        7704 :       CALL timestop(handle2)
    2101             : 
    2102        7704 :       CALL para_env%sum(present_atoms_i)
    2103        7704 :       CALL para_env%sum(present_atoms_j)
    2104             : 
    2105             :       !Need to build a work matrix with matching distribution to mat_orig
    2106             :       !If template is provided, use it. If not, we create it.
    2107        7704 :       use_template = .FALSE.
    2108        7704 :       IF (PRESENT(dbcsr_template)) THEN
    2109        7032 :          IF (ASSOCIATED(dbcsr_template)) use_template = .TRUE.
    2110             :       END IF
    2111             : 
    2112             :       IF (use_template) THEN
    2113        6776 :          CALL dbcsr_create(work, template=dbcsr_template)
    2114             :       ELSE
    2115         928 :          CALL dbcsr_get_info(mat_orig(1), distribution=dbcsr_dist)
    2116         928 :          CALL dbcsr_distribution_get(dbcsr_dist, row_dist=row_dist, col_dist=col_dist, group=group, pgrid=pgrid)
    2117        3712 :          ALLOCATE (row_dist_ext(ri_data%ncell_RI*natom), col_dist_ext(ri_data%ncell_RI*natom))
    2118        1856 :          ALLOCATE (ri_blk_size_ext(ri_data%ncell_RI*natom))
    2119        6602 :          DO i_RI = 1, ri_data%ncell_RI
    2120       28370 :             row_dist_ext((i_RI - 1)*natom + 1:i_RI*natom) = row_dist(:)
    2121       28370 :             col_dist_ext((i_RI - 1)*natom + 1:i_RI*natom) = col_dist(:)
    2122       17950 :             RI_blk_size_ext((i_RI - 1)*natom + 1:i_RI*natom) = ri_data%bsizes_RI(:)
    2123             :          END DO
    2124             : 
    2125             :          CALL dbcsr_distribution_new(dbcsr_dist_ext, group=group, pgrid=pgrid, &
    2126         928 :                                      row_dist=row_dist_ext, col_dist=col_dist_ext)
    2127             :          CALL dbcsr_create(work, dist=dbcsr_dist_ext, name="RI_ext", matrix_type=dbcsr_type_no_symmetry, &
    2128         928 :                            row_blk_size=RI_blk_size_ext, col_blk_size=RI_blk_size_ext)
    2129         928 :          CALL dbcsr_distribution_release(dbcsr_dist_ext)
    2130         928 :          DEALLOCATE (col_dist_ext, row_dist_ext, RI_blk_size_ext)
    2131             : 
    2132        2784 :          IF (PRESENT(dbcsr_template)) THEN
    2133         256 :             ALLOCATE (dbcsr_template)
    2134         256 :             CALL dbcsr_create(dbcsr_template, template=work)
    2135             :          END IF
    2136             :       END IF !use_template
    2137             : 
    2138       30816 :       cell_b(:) = index_to_cell(:, img_b)
    2139      254554 :       DO i_img = 1, nimg
    2140      246850 :          i_RI = ri_data%img_to_RI_cell(i_img)
    2141      246850 :          IF (i_RI == 0) CYCLE
    2142      207140 :          cell_i(:) = index_to_cell(:, i_img)
    2143     2043969 :          DO j_img = 1, nimg
    2144     1984480 :             j_RI = ri_data%img_to_RI_cell(j_img)
    2145     1984480 :             IF (j_RI == 0) CYCLE
    2146     1799436 :             cell_j(:) = index_to_cell(:, j_img)
    2147     1799436 :             cell_tot = cell_j - cell_i + cell_b
    2148             : 
    2149     3107547 :             IF (ANY([cell_tot(1), cell_tot(2), cell_tot(3)] < LBOUND(cell_to_index)) .OR. &
    2150             :                 ANY([cell_tot(1), cell_tot(2), cell_tot(3)] > UBOUND(cell_to_index))) CYCLE
    2151      412907 :             img_tot = cell_to_index(cell_tot(1), cell_tot(2), cell_tot(3))
    2152      412907 :             IF (img_tot > nimg .OR. img_tot < 1) CYCLE
    2153             : 
    2154      276021 :             CALL dbcsr_iterator_start(dbcsr_iter, mat_orig(img_tot))
    2155      788793 :             DO WHILE (dbcsr_iterator_blocks_left(dbcsr_iter))
    2156      512772 :                CALL dbcsr_iterator_next_block(dbcsr_iter, row=iatom, column=jatom, blk=blk)
    2157      512772 :                IF (present_atoms_i(iatom, i_img) == 0) CYCLE
    2158      191706 :                IF (present_atoms_j(jatom, j_img) == 0) CYCLE
    2159       83092 :                IF (my_offd .AND. (i_RI - 1)*natom + iatom == (j_RI - 1)*natom + jatom) CYCLE
    2160             : 
    2161       82805 :                CALL dbcsr_get_block_p(mat_orig(img_tot), iatom, jatom, pblock, found)
    2162       82805 :                IF (.NOT. found) CYCLE
    2163             : 
    2164      788793 :                CALL dbcsr_put_block(work, (i_RI - 1)*natom + iatom, (j_RI - 1)*natom + jatom, pblock)
    2165             : 
    2166             :             END DO
    2167     2470399 :             CALL dbcsr_iterator_stop(dbcsr_iter)
    2168             : 
    2169             :          END DO !j_img
    2170             :       END DO !i_img
    2171        7704 :       CALL dbcsr_finalize(work)
    2172             : 
    2173        7704 :       IF (do_inverse_prv) THEN
    2174             : 
    2175         280 :          r1 = ri_data%kp_RI_range
    2176         280 :          r0 = ri_data%kp_bump_rad
    2177             : 
    2178             :          !Because there are a lot of empty rows/cols in work, we need to get rid of them for inversion
    2179       20008 :          nblks_RI = SUM(present_atoms_i)
    2180        1400 :          ALLOCATE (col_dist_ext(nblks_RI), row_dist_ext(nblks_RI), RI_blk_size_ext(nblks_RI))
    2181         280 :          iblk = 0
    2182        6856 :          DO i_img = 1, nimg
    2183        6576 :             i_RI = ri_data%img_to_RI_cell(i_img)
    2184        6576 :             IF (i_RI == 0) CYCLE
    2185        5512 :             DO iatom = 1, natom
    2186        3488 :                IF (present_atoms_i(iatom, i_img) == 0) CYCLE
    2187        1148 :                iblk = iblk + 1
    2188        1148 :                col_dist_ext(iblk) = col_dist(iatom)
    2189        1148 :                row_dist_ext(iblk) = row_dist(iatom)
    2190       10064 :                RI_blk_size_ext(iblk) = ri_data%bsizes_RI(iatom)
    2191             :             END DO
    2192             :          END DO
    2193             : 
    2194             :          CALL dbcsr_distribution_new(dbcsr_dist_ext, group=group, pgrid=pgrid, &
    2195         280 :                                      row_dist=row_dist_ext, col_dist=col_dist_ext)
    2196             :          CALL dbcsr_create(work_tight, dist=dbcsr_dist_ext, name="RI_ext", matrix_type=dbcsr_type_no_symmetry, &
    2197         280 :                            row_blk_size=RI_blk_size_ext, col_blk_size=RI_blk_size_ext)
    2198             :          CALL dbcsr_create(work_tight_inv, dist=dbcsr_dist_ext, name="RI_ext", matrix_type=dbcsr_type_no_symmetry, &
    2199         280 :                            row_blk_size=RI_blk_size_ext, col_blk_size=RI_blk_size_ext)
    2200         280 :          CALL dbcsr_distribution_release(dbcsr_dist_ext)
    2201         280 :          DEALLOCATE (col_dist_ext, row_dist_ext, RI_blk_size_ext)
    2202             : 
    2203             :          !We apply a bump function to the RI metric inverse for smooth RI basis extension:
    2204             :          ! S^-1 = B * ((P|Q)_D + B*(P|Q)_OD*B)^-1 * B, with D block-diagonal blocks and OD off-diagonal
    2205         280 :          rref = pbc(particle_set(atom_i)%r, cell)
    2206             : 
    2207         280 :          iblk = 0
    2208        6856 :          DO i_img = 1, nimg
    2209        6576 :             i_RI = ri_data%img_to_RI_cell(i_img)
    2210        6576 :             IF (i_RI == 0) CYCLE
    2211        5512 :             DO iatom = 1, natom
    2212        3488 :                IF (present_atoms_i(iatom, i_img) == 0) CYCLE
    2213        1148 :                iblk = iblk + 1
    2214             : 
    2215        1148 :                CALL real_to_scaled(scoord, pbc(particle_set(iatom)%r, cell), cell)
    2216        4592 :                CALL scaled_to_real(ri, scoord(:) + index_to_cell(:, i_img), cell)
    2217             : 
    2218        1148 :                jblk = 0
    2219       40068 :                DO j_img = 1, nimg
    2220       32344 :                   j_RI = ri_data%img_to_RI_cell(j_img)
    2221       32344 :                   IF (j_RI == 0) CYCLE
    2222       29240 :                   DO jatom = 1, natom
    2223       17168 :                      IF (present_atoms_j(jatom, j_img) == 0) CYCLE
    2224        5476 :                      jblk = jblk + 1
    2225             : 
    2226        5476 :                      CALL real_to_scaled(scoord, pbc(particle_set(jatom)%r, cell), cell)
    2227       21904 :                      CALL scaled_to_real(rj, scoord(:) + index_to_cell(:, j_img), cell)
    2228             : 
    2229        5476 :                      CALL dbcsr_get_block_p(work, (i_RI - 1)*natom + iatom, (j_RI - 1)*natom + jatom, pblock, found)
    2230        5476 :                      IF (.NOT. found) CYCLE
    2231             : 
    2232        2460 :                      bfac = 1.0_dp
    2233       13776 :                      IF (iblk .NE. jblk) bfac = bump(NORM2(ri - rref), r0, r1)*bump(NORM2(rj - rref), r0, r1)
    2234     5054592 :                      CALL dbcsr_put_block(work_tight, iblk, jblk, bfac*pblock(:, :))
    2235             :                   END DO
    2236             :                END DO
    2237             :             END DO
    2238             :          END DO
    2239         280 :          CALL dbcsr_finalize(work_tight)
    2240         280 :          CALL dbcsr_clear(work)
    2241             : 
    2242         280 :          IF (.NOT. skip_inverse_prv) THEN
    2243         140 :             SELECT CASE (ri_data%t2c_method)
    2244             :             CASE (hfx_ri_do_2c_iter)
    2245           0 :                threshold = MAX(ri_data%filter_eps, 1.0e-12_dp)
    2246           0 :                CALL invert_hotelling(work_tight_inv, work_tight, threshold=threshold, silent=.FALSE.)
    2247             :             CASE (hfx_ri_do_2c_cholesky)
    2248         140 :                CALL dbcsr_copy(work_tight_inv, work_tight)
    2249         140 :                CALL cp_dbcsr_cholesky_decompose(work_tight_inv, para_env=para_env, blacs_env=blacs_env)
    2250             :                CALL cp_dbcsr_cholesky_invert(work_tight_inv, para_env=para_env, blacs_env=blacs_env, &
    2251         140 :                                              upper_to_full=.TRUE.)
    2252             :             CASE (hfx_ri_do_2c_diag)
    2253           0 :                CALL dbcsr_copy(work_tight_inv, work_tight)
    2254             :                CALL cp_dbcsr_power(work_tight_inv, -1.0_dp, ri_data%eps_eigval, n_dependent, &
    2255         140 :                                    para_env, blacs_env, verbose=ri_data%unit_nr_dbcsr > 0)
    2256             :             END SELECT
    2257             :          ELSE
    2258         140 :             CALL dbcsr_copy(work_tight_inv, work_tight)
    2259             :          END IF
    2260             : 
    2261             :          !move back data to standard extended RI pattern
    2262             :          !Note: we apply the external bump to ((P|Q)_D + B*(P|Q)_OD*B)^-1 later, because this matrix
    2263             :          !      is required for forces
    2264         280 :          iblk = 0
    2265        6856 :          DO i_img = 1, nimg
    2266        6576 :             i_RI = ri_data%img_to_RI_cell(i_img)
    2267        6576 :             IF (i_RI == 0) CYCLE
    2268        5512 :             DO iatom = 1, natom
    2269        3488 :                IF (present_atoms_i(iatom, i_img) == 0) CYCLE
    2270        1148 :                iblk = iblk + 1
    2271             : 
    2272        1148 :                jblk = 0
    2273       40068 :                DO j_img = 1, nimg
    2274       32344 :                   j_RI = ri_data%img_to_RI_cell(j_img)
    2275       32344 :                   IF (j_RI == 0) CYCLE
    2276       29240 :                   DO jatom = 1, natom
    2277       17168 :                      IF (present_atoms_j(jatom, j_img) == 0) CYCLE
    2278        5476 :                      jblk = jblk + 1
    2279             : 
    2280        5476 :                      CALL dbcsr_get_block_p(work_tight_inv, iblk, jblk, pblock, found)
    2281        5476 :                      IF (.NOT. found) CYCLE
    2282             : 
    2283       52111 :                      CALL dbcsr_put_block(work, (i_RI - 1)*natom + iatom, (j_RI - 1)*natom + jatom, pblock)
    2284             :                   END DO
    2285             :                END DO
    2286             :             END DO
    2287             :          END DO
    2288         280 :          CALL dbcsr_finalize(work)
    2289             : 
    2290         280 :          CALL dbcsr_release(work_tight)
    2291         560 :          CALL dbcsr_release(work_tight_inv)
    2292             :       END IF
    2293             : 
    2294        7704 :       CALL dbt_create(work, t_2c_tmp)
    2295        7704 :       CALL dbt_copy_matrix_to_tensor(work, t_2c_tmp)
    2296        7704 :       CALL dbt_copy(t_2c_tmp, t_2c_pot, move_data=.TRUE.)
    2297        7704 :       CALL dbt_filter(t_2c_pot, ri_data%filter_eps)
    2298             : 
    2299        7704 :       CALL dbt_destroy(t_2c_tmp)
    2300        7704 :       CALL dbcsr_release(work)
    2301             : 
    2302        7704 :       CALL timestop(handle)
    2303             : 
    2304       30816 :    END SUBROUTINE get_ext_2c_int
    2305             : 
    2306             : ! **************************************************************************************************
    2307             : !> \brief Pre-contract the density matrices with the 3-center integrals:
    2308             : !>        P_sigma^a,lambda^a+c (mu^0 sigma^a| P^0)
    2309             : !> \param t_3c_apc result tensors indexed (spin, image); contributions are summed into them
    2310             : !> \param rho_ao_t density matrix tensors indexed (spin, image)
    2311             : !> \param ri_data RI data: source of the 3c integrals; its dbcsr_nflop and dbcsr_time are updated
    2312             : !> \param qs_env used for dft_control (nspins) and passed on to the stacking helpers
    2313             : ! **************************************************************************************************
    2314         256 :    SUBROUTINE contract_pmat_3c(t_3c_apc, rho_ao_t, ri_data, qs_env)
    2315             :       TYPE(dbt_type), DIMENSION(:, :), INTENT(INOUT)     :: t_3c_apc, rho_ao_t
    2316             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
    2317             :       TYPE(qs_environment_type), POINTER                 :: qs_env
    2318             : 
    2319             :       CHARACTER(len=*), PARAMETER                        :: routineN = 'contract_pmat_3c'
    2320             : 
    2321             :       INTEGER                                            :: apc_img, batch_size, handle, i_batch, &
    2322             :                                                             i_img, i_spin, j_batch, n_batch_img, &
    2323             :                                                             n_batch_nze, nimg, nimg_nze, nspins
    2324             :       INTEGER(int_8)                                     :: nflop, nze
    2325         256 :       INTEGER, ALLOCATABLE, DIMENSION(:)                 :: batch_ranges_img, batch_ranges_nze, &
    2326         256 :                                                             int_indices
    2327         256 :       INTEGER, ALLOCATABLE, DIMENSION(:, :)              :: ac_pairs
    2328             :       REAL(dp)                                           :: occ, t1, t2
    2329        2304 :       TYPE(dbt_type)                                     :: t_3c_tmp
    2330         256 :       TYPE(dbt_type), ALLOCATABLE, DIMENSION(:)          :: ints_stack, res_stack, rho_stack
    2331             :       TYPE(dft_control_type), POINTER                    :: dft_control
    2332             : 
    2333         256 :       CALL timeset(routineN, handle)
    2334             : 
    2335         256 :       CALL get_qs_env(qs_env, dft_control=dft_control)
    2336             : 
    2337         256 :       nimg = ri_data%nimg
    2338         256 :       nimg_nze = ri_data%nimg_nze
    2339         256 :       nspins = dft_control%nspins
    2340             : 
    2341         256 :       CALL dbt_create(t_3c_apc(1, 1), t_3c_tmp)
    2342             :       !NOTE(review): integer division, batch_size is 0 if n_mem > nimg - presumably excluded upstream
    2343         256 :       batch_size = nimg/ri_data%n_mem
    2344             : 
    2345             :       !batching over all images: batch_ranges_img(i) is the first image of batch i, last entry is nimg + 1
    2346         256 :       n_batch_img = nimg/batch_size
    2347         256 :       IF (MODULO(nimg, batch_size) .NE. 0) n_batch_img = n_batch_img + 1
    2348         768 :       ALLOCATE (batch_ranges_img(n_batch_img + 1))
    2349         890 :       DO i_batch = 1, n_batch_img
    2350         890 :          batch_ranges_img(i_batch) = (i_batch - 1)*batch_size + 1
    2351             :       END DO
    2352         256 :       batch_ranges_img(n_batch_img + 1) = nimg + 1
    2353             : 
    2354             :       !batching over images with non-zero 3c integrals, same layout as batch_ranges_img
    2355         256 :       n_batch_nze = nimg_nze/batch_size
    2356         256 :       IF (MODULO(nimg_nze, batch_size) .NE. 0) n_batch_nze = n_batch_nze + 1
    2357         768 :       ALLOCATE (batch_ranges_nze(n_batch_nze + 1))
    2358         830 :       DO i_batch = 1, n_batch_nze
    2359         830 :          batch_ranges_nze(i_batch) = (i_batch - 1)*batch_size + 1
    2360             :       END DO
    2361         256 :       batch_ranges_nze(n_batch_nze + 1) = nimg_nze + 1
    2362             : 
    2363             :       !Create the stack tensors in the appropriate distribution
    2364        7936 :       ALLOCATE (rho_stack(2), ints_stack(2), res_stack(2))
    2365             :       CALL get_stack_tensors(res_stack, rho_stack, ints_stack, rho_ao_t(1, 1), &
    2366         256 :                              ri_data%t_3c_int_ctr_1(1, 1), batch_size, ri_data, qs_env)
    2367             :       !int_indices is the identity map over the images with non-zero integrals
    2368        1280 :       ALLOCATE (ac_pairs(nimg, 2), int_indices(nimg_nze))
    2369        4888 :       DO i_img = 1, nimg_nze
    2370        4888 :          int_indices(i_img) = i_img
    2371             :       END DO
    2372             :       !The wall time between t1 and t2 is accumulated in ri_data%dbcsr_time
    2373         256 :       t1 = m_walltime()
    2374         830 :       DO j_batch = 1, n_batch_nze
    2375             :          !First batch is over the integrals. They are always in the same order, consistent with get_ac_pairs
    2376             :          CALL fill_3c_stack(ints_stack(1), ri_data%t_3c_int_ctr_1(1, :), int_indices, 3, ri_data, &
    2377        1722 :                             img_bounds=[batch_ranges_nze(j_batch), batch_ranges_nze(j_batch + 1)])
    2378         574 :          CALL dbt_copy(ints_stack(1), ints_stack(2), move_data=.TRUE.)
    2379             : 
    2380        1640 :          DO i_spin = 1, nspins
    2381        3448 :             DO i_batch = 1, n_batch_img
    2382             :                !Second batch is over the P matrix. Here we fill the stacked rho tensors col by col
    2383       18476 :                DO apc_img = batch_ranges_img(i_batch), batch_ranges_img(i_batch + 1) - 1
    2384       16412 :                   CALL get_ac_pairs(ac_pairs, apc_img, ri_data, qs_env)
    2385             :                   CALL fill_2c_stack(rho_stack(1), rho_ao_t(i_spin, :), ac_pairs(:, 2), 1, ri_data, &
    2386             :                                      img_bounds=[batch_ranges_nze(j_batch), batch_ranges_nze(j_batch + 1)], &
    2387       51300 :                                      shift=apc_img - batch_ranges_img(i_batch) + 1)
    2388             : 
    2389             :                END DO !apc_img
    2390        2064 :                CALL get_tensor_occupancy(rho_stack(1), nze, occ)
    2391        2064 :                IF (nze == 0) CYCLE !nothing to contract when the stacked density tensor is empty
    2392        1806 :                CALL dbt_copy(rho_stack(1), rho_stack(2), move_data=.TRUE.)
    2393             : 
    2394             :                !The actual contraction: index 3 of the integral stack with index 1 of the density stack
    2395        1806 :                CALL dbt_batched_contract_init(rho_stack(2))
    2396             :                CALL dbt_contract(1.0_dp, ints_stack(2), rho_stack(2), &
    2397             :                                  0.0_dp, res_stack(2), map_1=[1, 2], map_2=[3], &
    2398             :                                  contract_1=[3], notcontract_1=[1, 2], &
    2399             :                                  contract_2=[1], notcontract_2=[2], &
    2400        1806 :                                  filter_eps=ri_data%filter_eps, flop=nflop)
    2401        1806 :                ri_data%dbcsr_nflop = ri_data%dbcsr_nflop + nflop
    2402        1806 :                CALL dbt_batched_contract_finalize(rho_stack(2))
    2403        1806 :                CALL dbt_copy(res_stack(2), res_stack(1), move_data=.TRUE.)
    2404             : 
    2405       20230 :                DO apc_img = batch_ranges_img(i_batch), batch_ranges_img(i_batch + 1) - 1
    2406             :                   !Destack the resulting tensor and put it in t_3c_apc with correct apc_img
    2407       15550 :                   CALL unstack_t_3c_apc(t_3c_tmp, res_stack(1), apc_img - batch_ranges_img(i_batch) + 1)
    2408       17614 :                   CALL dbt_copy(t_3c_tmp, t_3c_apc(i_spin, apc_img), summation=.TRUE., move_data=.TRUE.)
    2409             :                END DO
    2410             : 
    2411             :             END DO !i_batch
    2412             :          END DO !i_spin
    2413             :       END DO !j_batch
    2414         256 :       DEALLOCATE (batch_ranges_img)
    2415         256 :       DEALLOCATE (batch_ranges_nze)
    2416         256 :       t2 = m_walltime()
    2417         256 :       ri_data%dbcsr_time = ri_data%dbcsr_time + t2 - t1
    2418             : 
    2419         256 :       CALL dbt_destroy(rho_stack(1))
    2420         256 :       CALL dbt_destroy(rho_stack(2))
    2421         256 :       CALL dbt_destroy(ints_stack(1))
    2422         256 :       CALL dbt_destroy(ints_stack(2))
    2423         256 :       CALL dbt_destroy(res_stack(1))
    2424         256 :       CALL dbt_destroy(res_stack(2))
    2425         256 :       CALL dbt_destroy(t_3c_tmp)
    2426             : 
    2427         256 :       CALL timestop(handle)
    2428             : 
    2429        2560 :    END SUBROUTINE contract_pmat_3c
    2430             : 
    2431             : ! **************************************************************************************************
    2432             : !> \brief Pre-contract 3-center integrals with the bumped inverse RI metric, for each atom
    2433             : !> \param t_3c_int 3-center integral tensors, one per image in t_3c_int(1, :); destroyed on exit
    2434             : !> \param ri_data RI data: results are summed into t_3c_int_ctr_1(1, :); dbcsr_nflop is updated
    2435             : !> \param qs_env used for natom and passed on to apply_bump
    2436             : ! **************************************************************************************************
    2437          70 :    SUBROUTINE precontract_3c_ints(t_3c_int, ri_data, qs_env)
    2438             :       TYPE(dbt_type), DIMENSION(:, :), INTENT(INOUT)     :: t_3c_int
    2439             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
    2440             :       TYPE(qs_environment_type), POINTER                 :: qs_env
    2441             : 
    2442             :       CHARACTER(len=*), PARAMETER :: routineN = 'precontract_3c_ints'
    2443             : 
    2444             :       INTEGER                                            :: batch_size, handle, i_batch, i_img, &
    2445             :                                                             i_RI, iatom, is, n_batch, natom, &
    2446             :                                                             nblks, nblks_3c(3), nimg
    2447             :       INTEGER(int_8)                                     :: nflop
    2448          70 :       INTEGER, ALLOCATABLE, DIMENSION(:) :: batch_ranges, bsizes_RI_ext, bsizes_RI_ext_split, &
    2449          70 :          bsizes_stack, dist1, dist2, dist3, dist_stack3, idx_to_at_AO, int_indices
    2450         630 :       TYPE(dbt_distribution_type)                        :: t_dist
    2451       14630 :       TYPE(dbt_type)                                     :: t_2c_RI_tmp(2), t_3c_tmp(3)
    2452             : 
    2453          70 :       CALL timeset(routineN, handle)
    2454             : 
    2455          70 :       CALL get_qs_env(qs_env, natom=natom)
    2456             :       !int_indices is the identity map over all images
    2457          70 :       nimg = ri_data%nimg
    2458         210 :       ALLOCATE (int_indices(nimg))
    2459        1714 :       DO i_img = 1, nimg
    2460        1714 :          int_indices(i_img) = i_img
    2461             :       END DO
    2462             : 
    2463         210 :       ALLOCATE (idx_to_at_AO(SIZE(ri_data%bsizes_AO_split)))
    2464          70 :       CALL get_idx_to_atom(idx_to_at_AO, ri_data%bsizes_AO_split, ri_data%bsizes_AO)
    2465             :       !Extended RI block sizes: the per-cell RI block sizes repeated ncell_RI times
    2466          70 :       nblks = SIZE(ri_data%bsizes_RI_split)
    2467         210 :       ALLOCATE (bsizes_RI_ext(ri_data%ncell_RI*natom))
    2468         210 :       ALLOCATE (bsizes_RI_ext_split(ri_data%ncell_RI*nblks))
    2469         506 :       DO i_RI = 1, ri_data%ncell_RI
    2470        1308 :          bsizes_RI_ext((i_RI - 1)*natom + 1:i_RI*natom) = ri_data%bsizes_RI(:)
    2471        2344 :          bsizes_RI_ext_split((i_RI - 1)*nblks + 1:i_RI*nblks) = ri_data%bsizes_RI_split(:)
    2472             :       END DO
    2473             :       CALL create_2c_tensor(t_2c_RI_tmp(1), dist1, dist2, ri_data%pgrid_2d, &
    2474             :                             bsizes_RI_ext, bsizes_RI_ext, &
    2475             :                             name="(RI | RI)")
    2476          70 :       DEALLOCATE (dist1, dist2)
    2477             :       CALL create_2c_tensor(t_2c_RI_tmp(2), dist1, dist2, ri_data%pgrid_2d, &
    2478             :                             bsizes_RI_ext_split, bsizes_RI_ext_split, &
    2479             :                             name="(RI | RI)")
    2480          70 :       DEALLOCATE (dist1, dist2)
    2481             :       !NOTE(review): integer division, batch_size is 0 if n_mem > nimg - presumably excluded upstream
    2482             :       !For more efficiency, we stack multiple images of the 3-center integrals into a single tensor
    2483          70 :       batch_size = nimg/ri_data%n_mem
    2484          70 :       n_batch = nimg/batch_size
    2485          70 :       IF (MODULO(nimg, batch_size) .NE. 0) n_batch = n_batch + 1
    2486         210 :       ALLOCATE (batch_ranges(n_batch + 1))
    2487         246 :       DO i_batch = 1, n_batch
    2488         246 :          batch_ranges(i_batch) = (i_batch - 1)*batch_size + 1
    2489             :       END DO
    2490          70 :       batch_ranges(n_batch + 1) = nimg + 1
    2491             :       !Block sizes of the stacked (3rd) dimension: the AO split sizes repeated batch_size times
    2492          70 :       nblks = SIZE(ri_data%bsizes_AO_split)
    2493         210 :       ALLOCATE (bsizes_stack(batch_size*nblks))
    2494         874 :       DO is = 1, batch_size
    2495        3394 :          bsizes_stack((is - 1)*nblks + 1:is*nblks) = ri_data%bsizes_AO_split(:)
    2496             :       END DO
    2497             :       !The stack reuses the process distribution of t_3c_int(1, 1), with dist3 repeated batch_size times
    2498          70 :       CALL dbt_get_info(t_3c_int(1, 1), nblks_total=nblks_3c)
    2499         630 :       ALLOCATE (dist1(nblks_3c(1)), dist2(nblks_3c(2)), dist3(nblks_3c(3)), dist_stack3(batch_size*nblks_3c(3)))
    2500          70 :       CALL dbt_get_info(t_3c_int(1, 1), proc_dist_1=dist1, proc_dist_2=dist2, proc_dist_3=dist3)
    2501         874 :       DO is = 1, batch_size
    2502        3394 :          dist_stack3((is - 1)*nblks_3c(3) + 1:is*nblks_3c(3)) = dist3(:)
    2503             :       END DO
    2504             : 
    2505          70 :       CALL dbt_distribution_new(t_dist, ri_data%pgrid, dist1, dist2, dist_stack3)
    2506             :       CALL dbt_create(t_3c_tmp(1), "ints_stack", t_dist, [1], [2, 3], bsizes_RI_ext_split, &
    2507          70 :                       ri_data%bsizes_AO_split, bsizes_stack)
    2508          70 :       CALL dbt_distribution_destroy(t_dist)
    2509          70 :       DEALLOCATE (dist1, dist2, dist3, dist_stack3)
    2510             : 
    2511          70 :       CALL dbt_create(t_3c_tmp(1), t_3c_tmp(2))
    2512          70 :       CALL dbt_create(t_3c_int(1, 1), t_3c_tmp(3))
    2513             :       !Per atom: bump its 2c inverse on both sides, then contract with the integral stack (filter_at=iatom)
    2514         210 :       DO iatom = 1, natom
    2515         140 :          CALL dbt_copy(ri_data%t_2c_inv(1, iatom), t_2c_RI_tmp(1))
    2516         140 :          CALL apply_bump(t_2c_RI_tmp(1), iatom, ri_data, qs_env, from_left=.TRUE., from_right=.TRUE.)
    2517         140 :          CALL dbt_copy(t_2c_RI_tmp(1), t_2c_RI_tmp(2), move_data=.TRUE.)
    2518             : 
    2519         140 :          CALL dbt_batched_contract_init(t_2c_RI_tmp(2))
    2520         492 :          DO i_batch = 1, n_batch
    2521             : 
    2522             :             CALL fill_3c_stack(t_3c_tmp(1), t_3c_int(1, :), int_indices, 3, ri_data, &
    2523             :                                img_bounds=[batch_ranges(i_batch), batch_ranges(i_batch + 1)], &
    2524        1056 :                                filter_at=iatom, filter_dim=2, idx_to_at=idx_to_at_AO)
    2525             : 
    2526             :             CALL dbt_contract(1.0_dp, t_2c_RI_tmp(2), t_3c_tmp(1), &
    2527             :                               0.0_dp, t_3c_tmp(2), map_1=[1], map_2=[2, 3], &
    2528             :                               contract_1=[2], notcontract_1=[1], &
    2529             :                               contract_2=[1], notcontract_2=[2, 3], &
    2530         352 :                               filter_eps=ri_data%filter_eps, flop=nflop)
    2531         352 :             ri_data%dbcsr_nflop = ri_data%dbcsr_nflop + nflop
    2532             :             !Unstack image by image and sum into t_3c_int_ctr_1 with the first two indices swapped
    2533        3640 :             DO i_img = batch_ranges(i_batch), batch_ranges(i_batch + 1) - 1
    2534        3288 :                CALL unstack_t_3c_apc(t_3c_tmp(3), t_3c_tmp(2), i_img - batch_ranges(i_batch) + 1)
    2535             :                CALL dbt_copy(t_3c_tmp(3), ri_data%t_3c_int_ctr_1(1, i_img), summation=.TRUE., &
    2536        3640 :                              order=[2, 1, 3], move_data=.TRUE.)
    2537             :             END DO
    2538         492 :             CALL dbt_clear(t_3c_tmp(1))
    2539             :          END DO
    2540         210 :          CALL dbt_batched_contract_finalize(t_2c_RI_tmp(2))
    2541             : 
    2542             :       END DO
    2543          70 :       CALL dbt_destroy(t_2c_RI_tmp(1))
    2544          70 :       CALL dbt_destroy(t_2c_RI_tmp(2))
    2545          70 :       CALL dbt_destroy(t_3c_tmp(1))
    2546          70 :       CALL dbt_destroy(t_3c_tmp(2))
    2547          70 :       CALL dbt_destroy(t_3c_tmp(3))
    2548             :       !The input integral tensors are destroyed here, for all images
    2549        1714 :       DO i_img = 1, nimg
    2550        1714 :          CALL dbt_destroy(t_3c_int(1, i_img))
    2551             :       END DO
    2552             : 
    2553          70 :       CALL timestop(handle)
    2554             : 
    2555         350 :    END SUBROUTINE precontract_3c_ints
    2556             : 
    2557             : ! **************************************************************************************************
    2558             : !> \brief Copy the data of a 2D tensor living in the main MPI group to a sub-group, given the proc
    2559             : !>        mapping from one to the other (e.g. for a proc idx in the subgroup, we get the idx in the main)
    2560             : !> \param t2c_sub ...
    2561             : !> \param t2c_main ...
    2562             : !> \param group_size ...
    2563             : !> \param ngroups ...
    2564             : !> \param para_env ...
    2565             : ! **************************************************************************************************
    2566        8388 :    SUBROUTINE copy_2c_to_subgroup(t2c_sub, t2c_main, group_size, ngroups, para_env)
    2567             :       TYPE(dbt_type), INTENT(INOUT)                      :: t2c_sub, t2c_main
    2568             :       INTEGER, INTENT(IN)                                :: group_size, ngroups
    2569             :       TYPE(mp_para_env_type), POINTER                    :: para_env
    2570             : 
    2571             :       INTEGER                                            :: batch_size, i, i_batch, i_msg, iblk, &
    2572             :                                                             igroup, iproc, ir, is, jblk, n_batch, &
    2573             :                                                             nocc, tag
    2574        8388 :       INTEGER, ALLOCATABLE, DIMENSION(:)                 :: bsizes1, bsizes2
    2575        8388 :       INTEGER, ALLOCATABLE, DIMENSION(:, :)              :: block_dest, block_source
    2576        8388 :       INTEGER, ALLOCATABLE, DIMENSION(:, :, :)           :: current_dest
    2577             :       INTEGER, DIMENSION(2)                              :: ind, nblks
    2578             :       LOGICAL                                            :: found
    2579        8388 :       REAL(dp), ALLOCATABLE, DIMENSION(:, :)             :: blk
    2580        8388 :       TYPE(cp_2d_r_p_type), ALLOCATABLE, DIMENSION(:)    :: recv_buff, send_buff
    2581             :       TYPE(dbt_iterator_type)                            :: iter
    2582        8388 :       TYPE(mp_request_type), ALLOCATABLE, DIMENSION(:)   :: recv_req, send_req
    2583             : 
    2584             :       !Strategy: we loop over the main tensor, and send all the data. Then we loop over the sub tensor
    2585             :       !         and receive it. We do all of it with async MPI communication. The sub tensor needs
    2586             :       !         to have blocks pre-reserved though
    2587             : 
    2588        8388 :       CALL dbt_get_info(t2c_main, nblks_total=nblks)
    2589             : 
    2590             :       !Loop over the main tensor, count how many blocks are there, which ones, and on which proc
    2591       33552 :       ALLOCATE (block_source(nblks(1), nblks(2)))
    2592      169184 :       block_source = -1
    2593        8388 :       nocc = 0
    2594        8388 : !$OMP PARALLEL DEFAULT(NONE) SHARED(t2c_main,para_env,nocc,block_source) PRIVATE(iter,ind,blk,found)
    2595             :       CALL dbt_iterator_start(iter, t2c_main)
    2596             :       DO WHILE (dbt_iterator_blocks_left(iter))
    2597             :          CALL dbt_iterator_next_block(iter, ind)
    2598             :          CALL dbt_get_block(t2c_main, ind, blk, found)
    2599             :          IF (.NOT. found) CYCLE
    2600             : 
    2601             :          block_source(ind(1), ind(2)) = para_env%mepos
    2602             : !$OMP ATOMIC
    2603             :          nocc = nocc + 1
    2604             :          DEALLOCATE (blk)
    2605             :       END DO
    2606             :       CALL dbt_iterator_stop(iter)
    2607             : !$OMP END PARALLEL
    2608             : 
    2609        8388 :       CALL para_env%sum(nocc)
    2610        8388 :       CALL para_env%sum(block_source)
    2611      169184 :       block_source = block_source + para_env%num_pe - 1
    2612        8388 :       IF (nocc == 0) RETURN
    2613             : 
    2614             :       !Loop over the sub tensor, get the block destination
    2615        8268 :       igroup = para_env%mepos/group_size
    2616       24804 :       ALLOCATE (block_dest(nblks(1), nblks(2)))
    2617      168344 :       block_dest = -1
    2618       31184 :       DO jblk = 1, nblks(2)
    2619      168344 :          DO iblk = 1, nblks(1)
    2620      137160 :             IF (block_source(iblk, jblk) == -1) CYCLE
    2621             : 
    2622      101028 :             CALL dbt_get_stored_coordinates(t2c_sub, [iblk, jblk], iproc)
    2623      160076 :             block_dest(iblk, jblk) = igroup*group_size + iproc !mapping of iproc in subgroup to main group idx
    2624             :          END DO
    2625             :       END DO
    2626             : 
    2627       41340 :       ALLOCATE (bsizes1(nblks(1)), bsizes2(nblks(2)))
    2628        8268 :       CALL dbt_get_info(t2c_main, blk_size_1=bsizes1, blk_size_2=bsizes2)
    2629             : 
    2630       41340 :       ALLOCATE (current_dest(nblks(1), nblks(2), 0:ngroups - 1))
    2631       24804 :       DO igroup = 0, ngroups - 1
    2632             :          !for a given subgroup, need to make the destination available to everyone in the main group
    2633      336688 :          current_dest(:, :, igroup) = block_dest(:, :)
    2634       24804 :          CALL para_env%bcast(current_dest(:, :, igroup), source=igroup*group_size) !bcast from first proc in sub-group
    2635             :       END DO
    2636             : 
    2637             :       !We go by batches, which cannot be larger than the maximum MPI tag value
    2638        8268 :       batch_size = MIN(para_env%get_tag_ub(), 128000, nocc*ngroups)
    2639        8268 :       n_batch = (nocc*ngroups)/batch_size
    2640        8268 :       IF (MODULO(nocc*ngroups, batch_size) .NE. 0) n_batch = n_batch + 1
    2641             : 
    2642       16536 :       DO i_batch = 1, n_batch
    2643             :          !Loop over groups, blocks and send/receive
    2644      167776 :          ALLOCATE (send_buff(batch_size), recv_buff(batch_size))
    2645      167776 :          ALLOCATE (send_req(batch_size), recv_req(batch_size))
    2646             :          ir = 0
    2647             :          is = 0
    2648             :          i_msg = 0
    2649       31184 :          DO jblk = 1, nblks(2)
    2650      168344 :             DO iblk = 1, nblks(1)
    2651      434396 :                DO igroup = 0, ngroups - 1
    2652      274320 :                   IF (block_source(iblk, jblk) == -1) CYCLE
    2653             : 
    2654       67352 :                   i_msg = i_msg + 1
    2655       67352 :                   IF (i_msg < (i_batch - 1)*batch_size + 1 .OR. i_msg > i_batch*batch_size) CYCLE
    2656             : 
    2657             :                   !a unique tag per block, within this batch
    2658       67352 :                   tag = i_msg - (i_batch - 1)*batch_size
    2659             : 
    2660       67352 :                   found = .FALSE.
    2661       67352 :                   IF (para_env%mepos == block_source(iblk, jblk)) THEN
    2662      101028 :                      CALL dbt_get_block(t2c_main, [iblk, jblk], blk, found)
    2663             :                   END IF
    2664             : 
    2665             :                   !If blocks live on same proc, simply copy. Else MPI send/recv
    2666       67352 :                   IF (block_source(iblk, jblk) == current_dest(iblk, jblk, igroup)) THEN
    2667      101028 :                      IF (found) CALL dbt_put_block(t2c_sub, [iblk, jblk], SHAPE(blk), blk)
    2668             :                   ELSE
    2669       33676 :                      IF (para_env%mepos == block_source(iblk, jblk) .AND. found) THEN
    2670       67352 :                         ALLOCATE (send_buff(tag)%array(bsizes1(iblk), bsizes2(jblk)))
    2671    21791618 :                         send_buff(tag)%array(:, :) = blk(:, :)
    2672       16838 :                         is = is + 1
    2673             :                         CALL para_env%isend(msgin=send_buff(tag)%array, dest=current_dest(iblk, jblk, igroup), &
    2674       16838 :                                             request=send_req(is), tag=tag)
    2675             :                      END IF
    2676             : 
    2677       33676 :                      IF (para_env%mepos == current_dest(iblk, jblk, igroup)) THEN
    2678       67352 :                         ALLOCATE (recv_buff(tag)%array(bsizes1(iblk), bsizes2(jblk)))
    2679       16838 :                         ir = ir + 1
    2680             :                         CALL para_env%irecv(msgout=recv_buff(tag)%array, source=block_source(iblk, jblk), &
    2681       16838 :                                             request=recv_req(ir), tag=tag)
    2682             :                      END IF
    2683             :                   END IF
    2684             : 
    2685      204512 :                   IF (found) DEALLOCATE (blk)
    2686             :                END DO
    2687             :             END DO
    2688             :          END DO
    2689             : 
    2690        8268 :          CALL mp_waitall(send_req(1:is))
    2691        8268 :          CALL mp_waitall(recv_req(1:ir))
    2692             :          !clean-up
    2693       75620 :          DO i = 1, batch_size
    2694       75620 :             IF (ASSOCIATED(send_buff(i)%array)) DEALLOCATE (send_buff(i)%array)
    2695             :          END DO
    2696             : 
    2697             :          !Finally copy the data from the buffer to the sub-tensor
    2698             :          i_msg = 0
    2699       31184 :          DO jblk = 1, nblks(2)
    2700      168344 :             DO iblk = 1, nblks(1)
    2701      434396 :                DO igroup = 0, ngroups - 1
    2702      274320 :                   IF (block_source(iblk, jblk) == -1) CYCLE
    2703             : 
    2704       67352 :                   i_msg = i_msg + 1
    2705       67352 :                   IF (i_msg < (i_batch - 1)*batch_size + 1 .OR. i_msg > i_batch*batch_size) CYCLE
    2706             : 
    2707             :                   !a unique tag per block, within this batch
    2708       67352 :                   tag = i_msg - (i_batch - 1)*batch_size
    2709             : 
    2710       67352 :                   IF (para_env%mepos == current_dest(iblk, jblk, igroup) .AND. &
    2711      137160 :                       block_source(iblk, jblk) .NE. current_dest(iblk, jblk, igroup)) THEN
    2712             : 
    2713       67352 :                      ALLOCATE (blk(bsizes1(iblk), bsizes2(jblk)))
    2714    21791618 :                      blk(:, :) = recv_buff(tag)%array(:, :)
    2715       84190 :                      CALL dbt_put_block(t2c_sub, [iblk, jblk], SHAPE(blk), blk)
    2716       16838 :                      DEALLOCATE (blk)
    2717             :                   END IF
    2718             :                END DO
    2719             :             END DO
    2720             :          END DO
    2721             : 
    2722             :          !clean-up
    2723       75620 :          DO i = 1, batch_size
    2724       75620 :             IF (ASSOCIATED(recv_buff(i)%array)) DEALLOCATE (recv_buff(i)%array)
    2725             :          END DO
    2726       16536 :          DEALLOCATE (send_buff, recv_buff, send_req, recv_req)
    2727             :       END DO !i_batch
    2728        8268 :       CALL dbt_finalize(t2c_sub)
    2729             : 
    2730       16776 :    END SUBROUTINE copy_2c_to_subgroup
    2731             : 
    2732             : ! **************************************************************************************************
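    2733             : !> \brief Copy the data of a 3D tensor living in the main MPI group to the sub-groups, given the proc
    2734             : !>        mapping from one to the other (i.e. for a proc idx in the subgroup, we get the idx in the main)
    2735             : !> \param t3c_sub destination tensor living in the subgroup; the blocks are put into it and it is finalized
    2736             : !> \param t3c_main source tensor living in the main group
    2737             : !> \param group_size number of procs per subgroup (main group idx = igroup*group_size + idx in subgroup)
    2738             : !> \param ngroups number of subgroups
    2739             : !> \param para_env parallel environment of the main group, used for all the communication
    2740             : !> \param iatom_to_subgroup optional: for each atom, logical array telling which subgroups need its blocks
    2741             : !> \param dim_at optional: the tensor dimension along which the atom filter is applied
    2742             : !> \param idx_to_at optional: maps a block index along dim_at to an atom index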
    2743             : ! **************************************************************************************************
    2744       12220 :    SUBROUTINE copy_3c_to_subgroup(t3c_sub, t3c_main, group_size, ngroups, para_env, iatom_to_subgroup, &
    2745       12220 :                                   dim_at, idx_to_at)
    2746             :       TYPE(dbt_type), INTENT(INOUT)                      :: t3c_sub, t3c_main
    2747             :       INTEGER, INTENT(IN)                                :: group_size, ngroups
    2748             :       TYPE(mp_para_env_type), POINTER                    :: para_env
    2749             :       TYPE(cp_1d_logical_p_type), DIMENSION(:), &
    2750             :          INTENT(INOUT), OPTIONAL                         :: iatom_to_subgroup
    2751             :       INTEGER, INTENT(IN), OPTIONAL                      :: dim_at
    2752             :       INTEGER, DIMENSION(:), OPTIONAL                    :: idx_to_at
    2753             : 
    2754             :       INTEGER                                            :: batch_size, i, i_batch, i_msg, iatom, &
    2755             :                                                             iblk, igroup, iproc, ir, is, jblk, &
    2756             :                                                             kblk, n_batch, nocc, tag
    2757       12220 :       INTEGER, ALLOCATABLE, DIMENSION(:)                 :: bsizes1, bsizes2, bsizes3
    2758       12220 :       INTEGER, ALLOCATABLE, DIMENSION(:, :, :)           :: block_dest, block_source
    2759       12220 :       INTEGER, ALLOCATABLE, DIMENSION(:, :, :, :)        :: current_dest
    2760             :       INTEGER, DIMENSION(3)                              :: ind, nblks
    2761             :       LOGICAL                                            :: filter_at, found
    2762       12220 :       REAL(dp), ALLOCATABLE, DIMENSION(:, :, :)          :: blk
    2763       12220 :       TYPE(cp_3d_r_p_type), ALLOCATABLE, DIMENSION(:)    :: recv_buff, send_buff
    2764             :       TYPE(dbt_iterator_type)                            :: iter
    2765       12220 :       TYPE(mp_request_type), ALLOCATABLE, DIMENSION(:)   :: recv_req, send_req
    2766             : 
    2767             :       !Strategy: we loop over the main tensor, and send all the data. Then we loop over the sub tensor
    2768             :       !          and receive it. We do all of it with async MPI communication. The sub tensor needs
    2769             :       !          to have blocks pre-reserved though
    2770             : 
    2771       12220 :       CALL dbt_get_info(t3c_main, nblks_total=nblks)
    2772             : 
    2773             :       !in some cases, only copy a fraction of the 3c tensor to a given subgroup (corresponding to some atoms). The filter is only active if all 3 optional arguments are present
    2774       12220 :       filter_at = .FALSE.
    2775       12220 :       IF (PRESENT(iatom_to_subgroup) .AND. PRESENT(dim_at) .AND. PRESENT(idx_to_at)) THEN
    2776        7130 :          filter_at = .TRUE.
    2777        7130 :          CPASSERT(nblks(dim_at) == SIZE(idx_to_at))
    2778             :       END IF
    2779             : 
    2780             :       !Loop over the main tensor: count the blocks found, and record for each of them the rank of the proc holding it
    2781       61100 :       ALLOCATE (block_source(nblks(1), nblks(2), nblks(3)))
    2782      661340 :       block_source = -1
    2783       12220 :       nocc = 0
    2784       12220 : !$OMP PARALLEL DEFAULT(NONE) SHARED(t3c_main,para_env,nocc,block_source) PRIVATE(iter,ind,blk,found)
    2785             :       CALL dbt_iterator_start(iter, t3c_main)
    2786             :       DO WHILE (dbt_iterator_blocks_left(iter))
    2787             :          CALL dbt_iterator_next_block(iter, ind)
    2788             :          CALL dbt_get_block(t3c_main, ind, blk, found)
    2789             :          IF (.NOT. found) CYCLE
    2790             : 
    2791             :          block_source(ind(1), ind(2), ind(3)) = para_env%mepos
    2792             : !$OMP ATOMIC
    2793             :          nocc = nocc + 1
    2794             :          DEALLOCATE (blk)
    2795             :       END DO
    2796             :       CALL dbt_iterator_stop(iter)
    2797             : !$OMP END PARALLEL
    2798             : 
    2799       12220 :       CALL para_env%sum(nocc)
    2800       12220 :       CALL para_env%sum(block_source)
    2801      661340 :       block_source = block_source + para_env%num_pe - 1 !undo the -1 summed over the other procs: owner rank if the block exists (assuming a single owner), -1 otherwise
    2802       12220 :       IF (nocc == 0) RETURN !no block found on any proc: nothing to copy
    2803             : 
    2804             :       !For each existing block, get the proc storing it in the sub tensor, i.e. its destination in the own subgroup
    2805       12220 :       igroup = para_env%mepos/group_size
    2806       48880 :       ALLOCATE (block_dest(nblks(1), nblks(2), nblks(3)))
    2807      661340 :       block_dest = -1
    2808       36660 :       DO kblk = 1, nblks(3)
    2809      170956 :          DO jblk = 1, nblks(2)
    2810      649120 :             DO iblk = 1, nblks(1)
    2811      490384 :                IF (block_source(iblk, jblk, kblk) == -1) CYCLE
    2812             : 
    2813      516160 :                CALL dbt_get_stored_coordinates(t3c_sub, [iblk, jblk, kblk], iproc)
    2814      624680 :                block_dest(iblk, jblk, kblk) = igroup*group_size + iproc !mapping of iproc in subgroup to main group idx
    2815             :             END DO
    2816             :          END DO
    2817             :       END DO
    2818             : 
    2819       85540 :       ALLOCATE (bsizes1(nblks(1)), bsizes2(nblks(2)), bsizes3(nblks(3)))
    2820       12220 :       CALL dbt_get_info(t3c_main, blk_size_1=bsizes1, blk_size_2=bsizes2, blk_size_3=bsizes3)
    2821             : 
    2822       73320 :       ALLOCATE (current_dest(nblks(1), nblks(2), nblks(3), 0:ngroups - 1))
    2823       36660 :       DO igroup = 0, ngroups - 1
    2824             :          !for a given subgroup, need to make the destination available to everyone in the main group
    2825     1322680 :          current_dest(:, :, :, igroup) = block_dest(:, :, :)
    2826       36660 :          CALL para_env%bcast(current_dest(:, :, :, igroup), source=igroup*group_size) !bcast from first proc in subgroup
    2827             :       END DO
    2828             : 
    2829             :       !We go by batches, which cannot be larger than the maximum MPI tag value (and are capped at 128000 messages)
    2830       12220 :       batch_size = MIN(para_env%get_tag_ub(), 128000, nocc*ngroups)
    2831       12220 :       n_batch = (nocc*ngroups)/batch_size
    2832       12220 :       IF (MODULO(nocc*ngroups, batch_size) .NE. 0) n_batch = n_batch + 1
    2833             : 
    2834       24440 :       DO i_batch = 1, n_batch
    2835             :          !Loop over blocks and groups, and post the sends/receives of this batch
    2836      565040 :          ALLOCATE (send_buff(batch_size), recv_buff(batch_size))
    2837      565040 :          ALLOCATE (send_req(batch_size), recv_req(batch_size))
    2838             :          ir = 0
    2839             :          is = 0
    2840             :          i_msg = 0
    2841       36660 :          DO kblk = 1, nblks(3)
    2842      170956 :             DO jblk = 1, nblks(2)
    2843      649120 :                DO iblk = 1, nblks(1)
    2844     1605448 :                   DO igroup = 0, ngroups - 1
    2845      980768 :                      IF (block_source(iblk, jblk, kblk) == -1) CYCLE
    2846             : 
    2847      258080 :                      i_msg = i_msg + 1
    2848      258080 :                      IF (i_msg < (i_batch - 1)*batch_size + 1 .OR. i_msg > i_batch*batch_size) CYCLE !not in this batch
    2849             : 
    2850             :                      !a unique tag per (block, subgroup) message, within this batch
    2851      258080 :                      tag = i_msg - (i_batch - 1)*batch_size
    2852             : 
    2853      258080 :                      IF (filter_at) THEN
    2854      744464 :                         ind(:) = [iblk, jblk, kblk]
    2855      186116 :                         iatom = idx_to_at(ind(dim_at))
    2856      186116 :                         IF (.NOT. iatom_to_subgroup(iatom)%array(igroup + 1)) CYCLE !this subgroup does not take this atom
    2857             :                      END IF
    2858             : 
    2859      165022 :                      found = .FALSE.
    2860      165022 :                      IF (para_env%mepos == block_source(iblk, jblk, kblk)) THEN
    2861      330044 :                         CALL dbt_get_block(t3c_main, [iblk, jblk, kblk], blk, found)
    2862             :                      END IF
    2863             : 
    2864             :                      !If blocks live on same proc, simply copy. Else MPI send/recv
    2865      165022 :                      IF (block_source(iblk, jblk, kblk) == current_dest(iblk, jblk, kblk, igroup)) THEN
    2866      341656 :                         IF (found) CALL dbt_put_block(t3c_sub, [iblk, jblk, kblk], SHAPE(blk), blk)
    2867             :                      ELSE
    2868       79608 :                         IF (para_env%mepos == block_source(iblk, jblk, kblk) .AND. found) THEN
    2869      199020 :                            ALLOCATE (send_buff(tag)%array(bsizes1(iblk), bsizes2(jblk), bsizes3(kblk)))
    2870   197765851 :                            send_buff(tag)%array(:, :, :) = blk(:, :, :)
    2871       39804 :                            is = is + 1
    2872             :                            CALL para_env%isend(msgin=send_buff(tag)%array, &
    2873             :                                                dest=current_dest(iblk, jblk, kblk, igroup), &
    2874       39804 :                                                request=send_req(is), tag=tag)
    2875             :                         END IF
    2876             : 
    2877       79608 :                         IF (para_env%mepos == current_dest(iblk, jblk, kblk, igroup)) THEN
    2878      199020 :                            ALLOCATE (recv_buff(tag)%array(bsizes1(iblk), bsizes2(jblk), bsizes3(kblk)))
    2879       39804 :                            ir = ir + 1
    2880             :                            CALL para_env%irecv(msgout=recv_buff(tag)%array, source=block_source(iblk, jblk, kblk), &
    2881       39804 :                                                request=recv_req(ir), tag=tag)
    2882             :                         END IF
    2883             :                      END IF
    2884             : 
    2885      655406 :                      IF (found) DEALLOCATE (blk)
    2886             :                   END DO
    2887             :                END DO
    2888             :             END DO
    2889             :          END DO
    2890             : 
    2891       12220 :          CALL mp_waitall(send_req(1:is))
    2892       12220 :          CALL mp_waitall(recv_req(1:ir))
    2893             :          !clean-up of the send buffers, after waiting on the send requests
    2894      270300 :          DO i = 1, batch_size
    2895      270300 :             IF (ASSOCIATED(send_buff(i)%array)) DEALLOCATE (send_buff(i)%array)
    2896             :          END DO
    2897             : 
    2898             :          !Finally copy the data from the buffer to the sub-tensor (same loop order and filters as above, such that the tags match)
    2899             :          i_msg = 0
    2900       36660 :          DO kblk = 1, nblks(3)
    2901      170956 :             DO jblk = 1, nblks(2)
    2902      649120 :                DO iblk = 1, nblks(1)
    2903     1605448 :                   DO igroup = 0, ngroups - 1
    2904      980768 :                      IF (block_source(iblk, jblk, kblk) == -1) CYCLE
    2905             : 
    2906      258080 :                      i_msg = i_msg + 1
    2907      258080 :                      IF (i_msg < (i_batch - 1)*batch_size + 1 .OR. i_msg > i_batch*batch_size) CYCLE !not in this batch
    2908             : 
    2909             :                      !a unique tag per (block, subgroup) message, within this batch
    2910      258080 :                      tag = i_msg - (i_batch - 1)*batch_size
    2911             : 
    2912      258080 :                      IF (filter_at) THEN
    2913      744464 :                         ind(:) = [iblk, jblk, kblk]
    2914      186116 :                         iatom = idx_to_at(ind(dim_at))
    2915      186116 :                         IF (.NOT. iatom_to_subgroup(iatom)%array(igroup + 1)) CYCLE
    2916             :                      END IF
    2917             : 
    2918      165022 :                      IF (para_env%mepos == current_dest(iblk, jblk, kblk, igroup) .AND. &
    2919      490384 :                          block_source(iblk, jblk, kblk) .NE. current_dest(iblk, jblk, kblk, igroup)) THEN
    2920             : 
    2921      199020 :                         ALLOCATE (blk(bsizes1(iblk), bsizes2(jblk), bsizes3(kblk)))
    2922   197765851 :                         blk(:, :, :) = recv_buff(tag)%array(:, :, :)
    2923      278628 :                         CALL dbt_put_block(t3c_sub, [iblk, jblk, kblk], SHAPE(blk), blk)
    2924       39804 :                         DEALLOCATE (blk)
    2925             :                      END IF
    2926             :                   END DO
    2927             :                END DO
    2928             :             END DO
    2929             :          END DO
    2930             : 
    2931             :          !clean-up of the receive buffers and of the per-batch arrays
    2932      270300 :          DO i = 1, batch_size
    2933      270300 :             IF (ASSOCIATED(recv_buff(i)%array)) DEALLOCATE (recv_buff(i)%array)
    2934             :          END DO
    2935       24440 :          DEALLOCATE (send_buff, recv_buff, send_req, recv_req)
    2936             :       END DO !i_batch
    2937       12220 :       CALL dbt_finalize(t3c_sub)
    2938             : 
    2939       24440 :    END SUBROUTINE copy_3c_to_subgroup
    2940             : 
    2941             : ! **************************************************************************************************
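    2942             : !> \brief A routine that gathers the pieces of the KS matrix across the subgroups and puts them in the
    2943             : !>        main group. Each b_img, iatom, jatom tuple is on a single CPU
    2944             : !> \param ks_t the KS tensors of the main group, indexed by (spin, image), which receive the blocks
    2945             : !> \param ks_t_sub the KS tensors of the subgroup, indexed by (spin, image), from which the blocks are read
    2946             : !> \param group_size number of procs per subgroup (main group idx = igroup*group_size + idx in subgroup)
    2947             : !> \param sparsity_pattern for each (iatom, jatom, b_img), the index of the subgroup holding the block; negative if none
    2948             : !> \param para_env parallel environment of the main group, used for all the communication
    2949             : !> \param ri_data provides the AO block sizes (bsizes_AO) used to allocate the buffers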
    2950             : ! **************************************************************************************************
    2951         214 :    SUBROUTINE gather_ks_matrix(ks_t, ks_t_sub, group_size, sparsity_pattern, para_env, ri_data)
    2952             :       TYPE(dbt_type), DIMENSION(:, :), INTENT(INOUT)     :: ks_t, ks_t_sub
    2953             :       INTEGER, INTENT(IN)                                :: group_size
    2954             :       INTEGER, DIMENSION(:, :, :), INTENT(IN)            :: sparsity_pattern
    2955             :       TYPE(mp_para_env_type), POINTER                    :: para_env
    2956             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
    2957             : 
    2958             :       CHARACTER(len=*), PARAMETER                        :: routineN = 'gather_ks_matrix'
    2959             : 
    2960             :       INTEGER                                            :: b_img, dest, handle, i, i_spin, iatom, &
    2961             :                                                             igroup, ir, is, jatom, n_mess, natom, &
    2962             :                                                             nimg, nspins, source, tag
    2963             :       LOGICAL                                            :: found
    2964         214 :       REAL(dp), ALLOCATABLE, DIMENSION(:, :)             :: blk
    2965         214 :       TYPE(cp_2d_r_p_type), ALLOCATABLE, DIMENSION(:)    :: recv_buff, send_buff
    2966         214 :       TYPE(mp_request_type), ALLOCATABLE, DIMENSION(:)   :: recv_req, send_req
    2967             : 
    2968         214 :       CALL timeset(routineN, handle)
    2969             : 
    2970         214 :       nimg = SIZE(sparsity_pattern, 3)
    2971         214 :       natom = SIZE(sparsity_pattern, 2)
    2972         214 :       nspins = SIZE(ks_t, 1)
    2973             : 
    2974        5514 :       DO b_img = 1, nimg
    2975             :          n_mess = 0 !first pass: count the (spin, jatom, iatom) entries of this image with a non-negative pattern
    2976       12172 :          DO i_spin = 1, nspins
    2977       25916 :             DO jatom = 1, natom
    2978       48104 :                DO iatom = 1, natom
    2979       41232 :                   IF (sparsity_pattern(iatom, jatom, b_img) > -1) n_mess = n_mess + 1
    2980             :                END DO
    2981             :             END DO
    2982             :          END DO
    2983             : 
    2984       40880 :          ALLOCATE (send_buff(n_mess), recv_buff(n_mess))
    2985       46180 :          ALLOCATE (send_req(n_mess), recv_req(n_mess))
    2986        5300 :          ir = 0
    2987        5300 :          is = 0
    2988        5300 :          n_mess = 0
    2989        5300 :          tag = 0
    2990             : 
    2991       12172 :          DO i_spin = 1, nspins
    2992       25916 :             DO jatom = 1, natom
    2993       48104 :                DO iatom = 1, natom
    2994       27488 :                   IF (sparsity_pattern(iatom, jatom, b_img) < 0) CYCLE
    2995        9608 :                   n_mess = n_mess + 1
    2996        9608 :                   tag = tag + 1
    2997             : 
    2998             :                   !sending the message: from the proc holding the block in its subgroup to the proc storing it in the main group
    2999       28824 :                   CALL dbt_get_stored_coordinates(ks_t(i_spin, b_img), [iatom, jatom], dest)
    3000       28824 :                   CALL dbt_get_stored_coordinates(ks_t_sub(i_spin, b_img), [iatom, jatom], source) !source within sub
    3001        9608 :                   igroup = sparsity_pattern(iatom, jatom, b_img)
    3002        9608 :                   source = source + igroup*group_size !source within the main group
    3003        9608 :                   IF (para_env%mepos == source) THEN
    3004       14412 :                      CALL dbt_get_block(ks_t_sub(i_spin, b_img), [iatom, jatom], blk, found)
    3005        4804 :                      IF (source == dest) THEN
    3006        4105 :                         IF (found) CALL dbt_put_block(ks_t(i_spin, b_img), [iatom, jatom], SHAPE(blk), blk)
    3007             :                      ELSE
    3008       14620 :                         ALLOCATE (send_buff(n_mess)%array(ri_data%bsizes_AO(iatom), ri_data%bsizes_AO(jatom)))
    3009      309987 :                         send_buff(n_mess)%array(:, :) = 0.0_dp !zeros are sent if the block is not found
    3010        3655 :                         IF (found) THEN
    3011      208918 :                            send_buff(n_mess)%array(:, :) = blk(:, :)
    3012             :                         END IF
    3013        3655 :                         is = is + 1
    3014             :                         CALL para_env%isend(msgin=send_buff(n_mess)%array, dest=dest, &
    3015        3655 :                                             request=send_req(is), tag=tag)
    3016             :                      END IF
    3017        4804 :                      DEALLOCATE (blk)
    3018             :                   END IF
    3019             : 
    3020             :                   !receiving the message (only needed when the block was not copied locally above)
    3021       23352 :                   IF (para_env%mepos == dest .AND. source .NE. dest) THEN
    3022       14620 :                      ALLOCATE (recv_buff(n_mess)%array(ri_data%bsizes_AO(iatom), ri_data%bsizes_AO(jatom)))
    3023        3655 :                      ir = ir + 1
    3024             :                      CALL para_env%irecv(msgout=recv_buff(n_mess)%array, source=source, &
    3025        3655 :                                          request=recv_req(ir), tag=tag)
    3026             :                   END IF
    3027             :                END DO !iatom
    3028             :             END DO !jatom
    3029             :          END DO !ispin
    3030             : 
    3031        5300 :          CALL mp_waitall(send_req(1:is))
    3032        5300 :          CALL mp_waitall(recv_req(1:ir))
    3033             : 
    3034             :          !Copy the messages received into the KS matrix (same loop order as above, such that n_mess indexes the same buffer)
    3035        5300 :          n_mess = 0
    3036       12172 :          DO i_spin = 1, nspins
    3037       25916 :             DO jatom = 1, natom
    3038       48104 :                DO iatom = 1, natom
    3039       27488 :                   IF (sparsity_pattern(iatom, jatom, b_img) < 0) CYCLE
    3040        9608 :                   n_mess = n_mess + 1
    3041             : 
    3042       28824 :                   CALL dbt_get_stored_coordinates(ks_t(i_spin, b_img), [iatom, jatom], dest)
    3043       23352 :                   IF (para_env%mepos == dest) THEN
    3044        4804 :                      IF (.NOT. ASSOCIATED(recv_buff(n_mess)%array)) CYCLE !no receive was posted for this block (source == dest case)
    3045       14620 :                      ALLOCATE (blk(ri_data%bsizes_AO(iatom), ri_data%bsizes_AO(jatom)))
    3046      309987 :                      blk(:, :) = recv_buff(n_mess)%array(:, :)
    3047       18275 :                      CALL dbt_put_block(ks_t(i_spin, b_img), [iatom, jatom], SHAPE(blk), blk)
    3048        3655 :                      DEALLOCATE (blk)
    3049             :                   END IF
    3050             :                END DO
    3051             :             END DO
    3052             :          END DO
    3053             : 
    3054             :          !clean-up of the buffers and requests of this image
    3055       14908 :          DO i = 1, n_mess
    3056        9608 :             IF (ASSOCIATED(send_buff(i)%array)) DEALLOCATE (send_buff(i)%array)
    3057       14908 :             IF (ASSOCIATED(recv_buff(i)%array)) DEALLOCATE (recv_buff(i)%array)
    3058             :          END DO
    3059        5514 :          DEALLOCATE (send_buff, recv_buff, send_req, recv_req)
    3060             :       END DO !b_img
    3061             : 
    3062         214 :       CALL timestop(handle)
    3063             : 
    3064         214 :    END SUBROUTINE gather_ks_matrix
    3065             : 
    3066             : ! **************************************************************************************************
    3067             : !> \brief copy all required 2c tensors from the main MPI group to the subgroups
    3068             : !> \param mat_2c_pot created here, one matrix per image: filtered copy of ri_data%kp_mat_2c_pot(1, i_img)
    3069             : !> \param t_2c_work created here: (1) ncell_RI x bsizes_RI blocks, (2) ncell_RI x bsizes_RI_split blocks
    3070             : !> \param t_2c_ao_tmp (1) is created here as a (AO | AO) tensor with bsizes_AO blocks
    3071             : !> \param ks_t_split (1) and (2) are created here as (AO | AO) tensors with bsizes_AO_split blocks
    3072             : !> \param ks_t_sub all (i_spin, i_img) entries are created here from the t_2c_ao_tmp(1) template
    3073             : !> \param group_size forwarded to copy_2c_to_subgroup
    3074             : !> \param ngroups forwarded to copy_2c_to_subgroup
    3075             : !> \param para_env forwarded to copy_2c_to_subgroup (presumably the main MPI group, to be confirmed)
    3076             : !> \param para_env_sub environment on which the 2d pgrid and the DBCSR distribution are created
    3077             : !> \param ri_data source of the block sizes, ncell_RI, kp_mat_2c_pot and filter_eps used here
    3078             : ! **************************************************************************************************
    3079         214 :    SUBROUTINE get_subgroup_2c_tensors(mat_2c_pot, t_2c_work, t_2c_ao_tmp, ks_t_split, ks_t_sub, &
    3080             :                                       group_size, ngroups, para_env, para_env_sub, ri_data)
    3081             :       TYPE(dbcsr_type), DIMENSION(:), INTENT(INOUT)      :: mat_2c_pot
    3082             :       TYPE(dbt_type), DIMENSION(:), INTENT(INOUT)        :: t_2c_work, t_2c_ao_tmp, ks_t_split
    3083             :       TYPE(dbt_type), DIMENSION(:, :), INTENT(INOUT)     :: ks_t_sub
    3084             :       INTEGER, INTENT(IN)                                :: group_size, ngroups
    3085             :       TYPE(mp_para_env_type), POINTER                    :: para_env, para_env_sub
    3086             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
    3087             : 
    3088             :       CHARACTER(len=*), PARAMETER :: routineN = 'get_subgroup_2c_tensors'
    3089             : 
    3090             :       INTEGER                                            :: handle, i, i_img, i_RI, i_spin, iproc, &
    3091             :                                                             j, natom, nblks, nimg, nspins
    3092             :       INTEGER(int_8)                                     :: nze
    3093             :       INTEGER, ALLOCATABLE, DIMENSION(:)                 :: bsizes_RI_ext, bsizes_RI_ext_split, &
    3094         214 :                                                             dist1, dist2
    3095             :       INTEGER, DIMENSION(2)                              :: pdims_2d
    3096         428 :       INTEGER, DIMENSION(:), POINTER                     :: col_dist, RI_blk_size, row_dist
    3097         214 :       INTEGER, DIMENSION(:, :), POINTER                  :: dbcsr_pgrid
    3098             :       REAL(dp)                                           :: occ
    3099             :       TYPE(dbcsr_distribution_type)                      :: dbcsr_dist_sub
    3100         642 :       TYPE(dbt_pgrid_type)                               :: pgrid_2d
    3101        2782 :       TYPE(dbt_type)                                     :: work, work_sub
    3102             : 
    3103         214 :       CALL timeset(routineN, handle)
    3104             : 
    3105             :       !Create the 2d pgrid on para_env_sub (pdims_2d is reused below for the dbcsr_pgrid bounds)
    3106         214 :       pdims_2d = 0
    3107         214 :       CALL dbt_pgrid_create(para_env_sub, pdims_2d, pgrid_2d)
    3108             : 
    3109         214 :       natom = SIZE(ri_data%bsizes_RI)
    3110         214 :       nblks = SIZE(ri_data%bsizes_RI_split)
    3111         642 :       ALLOCATE (bsizes_RI_ext(ri_data%ncell_RI*natom))
    3112         642 :       ALLOCATE (bsizes_RI_ext_split(ri_data%ncell_RI*nblks))
    3113        1508 :       DO i_RI = 1, ri_data%ncell_RI
    3114        3882 :          bsizes_RI_ext((i_RI - 1)*natom + 1:i_RI*natom) = ri_data%bsizes_RI(:)
    3115        6686 :          bsizes_RI_ext_split((i_RI - 1)*nblks + 1:i_RI*nblks) = ri_data%bsizes_RI_split(:)
    3116             :       END DO
    3117             : 
    3118             :       !nRI x nRI 2c tensors
    3119             :       CALL create_2c_tensor(t_2c_work(1), dist1, dist2, pgrid_2d, &
    3120             :                             bsizes_RI_ext, bsizes_RI_ext, &
    3121             :                             name="(RI | RI)")
    3122         214 :       DEALLOCATE (dist1, dist2)
    3123             : 
    3124             :       CALL create_2c_tensor(t_2c_work(2), dist1, dist2, pgrid_2d, &
    3125             :                             bsizes_RI_ext_split, bsizes_RI_ext_split, &
    3126         214 :                             name="(RI | RI)")
    3127         214 :       DEALLOCATE (dist1, dist2)
    3128             : 
    3129             :       !the AO based tensors
    3130             :       CALL create_2c_tensor(ks_t_split(1), dist1, dist2, pgrid_2d, &
    3131             :                             ri_data%bsizes_AO_split, ri_data%bsizes_AO_split, &
    3132             :                             name="(AO | AO)")
    3133         214 :       DEALLOCATE (dist1, dist2)
    3134         214 :       CALL dbt_create(ks_t_split(1), ks_t_split(2))
    3135             : 
    3136             :       CALL create_2c_tensor(t_2c_ao_tmp(1), dist1, dist2, pgrid_2d, &
    3137             :                             ri_data%bsizes_AO, ri_data%bsizes_AO, &
    3138             :                             name="(AO | AO)")
    3139         214 :       DEALLOCATE (dist1, dist2)
    3140             : 
    3141         214 :       nspins = SIZE(ks_t_sub, 1)
    3142         214 :       nimg = SIZE(ks_t_sub, 2)
    3143        5514 :       DO i_img = 1, nimg
    3144       12386 :          DO i_spin = 1, nspins
    3145       12172 :             CALL dbt_create(t_2c_ao_tmp(1), ks_t_sub(i_spin, i_img))
    3146             :          END DO
    3147             :       END DO
    3148             : 
    3149             :       !Finally the HFX potential matrices
    3150             :       !For now, we do a convoluted thing where we go to tensors first, then back to matrices.
    3151             :       CALL create_2c_tensor(work_sub, dist1, dist2, pgrid_2d, &
    3152             :                             ri_data%bsizes_RI, ri_data%bsizes_RI, &
    3153             :                             name="(RI | RI)")
    3154         214 :       CALL dbt_create(ri_data%kp_mat_2c_pot(1, 1), work)
    3155             : 
    3156         856 :       ALLOCATE (dbcsr_pgrid(0:pdims_2d(1) - 1, 0:pdims_2d(2) - 1))
    3157         214 :       iproc = 0
    3158         428 :       DO i = 0, pdims_2d(1) - 1
    3159         642 :          DO j = 0, pdims_2d(2) - 1
    3160         214 :             dbcsr_pgrid(i, j) = iproc
    3161         428 :             iproc = iproc + 1
    3162             :          END DO
    3163             :       END DO
    3164             : 
    3165             :       !We need to have the same exact 2d block dist as the tensors (dist1/dist2 are those of work_sub)
    3166         856 :       ALLOCATE (col_dist(natom), row_dist(natom))
    3167         642 :       row_dist(:) = dist1(:)
    3168         642 :       col_dist(:) = dist2(:)
    3169             : 
    3170         428 :       ALLOCATE (RI_blk_size(natom))
    3171         642 :       RI_blk_size(:) = ri_data%bsizes_RI(:)
    3172             : 
    3173             :       CALL dbcsr_distribution_new(dbcsr_dist_sub, group=para_env_sub%get_handle(), pgrid=dbcsr_pgrid, &
    3174         214 :                                   row_dist=row_dist, col_dist=col_dist)
    3175             :       CALL dbcsr_create(mat_2c_pot(1), dist=dbcsr_dist_sub, name="sub", matrix_type=dbcsr_type_no_symmetry, &
    3176         214 :                         row_blk_size=RI_blk_size, col_blk_size=RI_blk_size)
    3177             : 
    3178        5514 :       DO i_img = 1, nimg
    3179        5300 :          IF (i_img > 1) CALL dbcsr_create(mat_2c_pot(i_img), template=mat_2c_pot(1))
    3180        5300 :          CALL dbt_copy_matrix_to_tensor(ri_data%kp_mat_2c_pot(1, i_img), work)
    3181        5300 :          CALL get_tensor_occupancy(work, nze, occ)
    3182        5300 :          IF (nze == 0) CYCLE !nothing to copy: mat_2c_pot(i_img) is left as created
    3183             : 
    3184        4132 :          CALL copy_2c_to_subgroup(work_sub, work, group_size, ngroups, para_env)
    3185        4132 :          CALL dbt_copy_tensor_to_matrix(work_sub, mat_2c_pot(i_img))
    3186        4132 :          CALL dbcsr_filter(mat_2c_pot(i_img), ri_data%filter_eps)
    3187        9646 :          CALL dbt_clear(work_sub)
    3188             :       END DO
    3189             : 
    3190         214 :       CALL dbt_destroy(work)
    3191         214 :       CALL dbt_destroy(work_sub)
    3192         214 :       CALL dbt_pgrid_destroy(pgrid_2d)
    3193         214 :       CALL dbcsr_distribution_release(dbcsr_dist_sub)
    3194         214 :       DEALLOCATE (col_dist, row_dist, RI_blk_size, dbcsr_pgrid)
    3195         214 :       CALL timestop(handle)
    3196             : 
    3197        1926 :    END SUBROUTINE get_subgroup_2c_tensors
    3198             : 
    3199             : ! **************************************************************************************************
    3200             : !> \brief copy all required 3c tensors from the main MPI group to the subgroups
    3201             : !> \param t_3c_int created here, one (RI | AO AO) tensor per image, and filled with the 3c integrals
    3202             : !> \param t_3c_work_2 (1:3) are created here as stacked work tensors matching the (AO RI | AO) layout
    3203             : !> \param t_3c_work_3 (1:3) are created here as stacked work tensors matching the (RI | AO AO) layout
    3204             : !> \param t_3c_apc tensors indexed (i_spin, i_img); copied to t_3c_apc_sub and then destroyed here
    3205             : !> \param t_3c_apc_sub created here for each (i_spin, i_img) and filled from t_3c_apc when not empty
    3206             : !> \param group_size forwarded to copy_3c_to_subgroup
    3207             : !> \param ngroups forwarded to copy_3c_to_subgroup
    3208             : !> \param para_env forwarded to copy_3c_to_subgroup (presumably the main MPI group, to be confirmed)
    3209             : !> \param para_env_sub environment on which the subgroup process grids are created
    3210             : !> \param ri_data source of block sizes, pgrids and kp_stack_size; kp_t_3c_int is allocated here if needed
    3211             : ! **************************************************************************************************
    3212         214 :    SUBROUTINE get_subgroup_3c_tensors(t_3c_int, t_3c_work_2, t_3c_work_3, t_3c_apc, t_3c_apc_sub, &
    3213             :                                       group_size, ngroups, para_env, para_env_sub, ri_data)
    3214             :       TYPE(dbt_type), DIMENSION(:), INTENT(INOUT)        :: t_3c_int, t_3c_work_2, t_3c_work_3
    3215             :       TYPE(dbt_type), DIMENSION(:, :), INTENT(INOUT)     :: t_3c_apc, t_3c_apc_sub
    3216             :       INTEGER, INTENT(IN)                                :: group_size, ngroups
    3217             :       TYPE(mp_para_env_type), POINTER                    :: para_env, para_env_sub
    3218             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
    3219             : 
    3220             :       CHARACTER(len=*), PARAMETER :: routineN = 'get_subgroup_3c_tensors'
    3221             : 
    3222             :       INTEGER                                            :: batch_size, bfac, bo(2), handle, &
    3223             :                                                             handle2, i_blk, i_img, i_RI, i_spin, &
    3224             :                                                             ib, natom, nblks_AO, nblks_RI, nimg, &
    3225             :                                                             nspins
    3226             :       INTEGER(int_8)                                     :: nze
    3227         214 :       INTEGER, ALLOCATABLE, DIMENSION(:)                 :: bsizes_RI_ext, bsizes_RI_ext_split, &
    3228         214 :                                                             bsizes_stack, bsizes_tmp, dist1, &
    3229         214 :                                                             dist2, dist3, dist_stack, idx_to_at
    3230             :       INTEGER, DIMENSION(3)                              :: pdims
    3231             :       REAL(dp)                                           :: occ
    3232        1926 :       TYPE(dbt_distribution_type)                        :: t_dist
    3233         642 :       TYPE(dbt_pgrid_type)                               :: pgrid
    3234        5350 :       TYPE(dbt_type)                                     :: tmp, work_atom_block, work_atom_block_sub
    3235             : 
    3236         214 :       CALL timeset(routineN, handle)
    3237             : 
    3238         214 :       nblks_RI = SIZE(ri_data%bsizes_RI_split)
    3239         642 :       ALLOCATE (bsizes_RI_ext_split(ri_data%ncell_RI*nblks_RI))
    3240        1508 :       DO i_RI = 1, ri_data%ncell_RI
    3241        6686 :          bsizes_RI_ext_split((i_RI - 1)*nblks_RI + 1:i_RI*nblks_RI) = ri_data%bsizes_RI_split(:)
    3242             :       END DO
    3243             : 
    3244             :       !Preparing larger block sizes for efficient communication (fewer, bigger messages)
    3245             :       !we put 2 atoms per RI block
    3246         214 :       bfac = 2
    3247         214 :       natom = SIZE(ri_data%bsizes_RI)
    3248         214 :       nblks_RI = MAX(1, natom/bfac) !nblks_RI now counts the merged blocks, no longer the split ones
    3249         642 :       ALLOCATE (bsizes_tmp(nblks_RI))
    3250         428 :       DO i_blk = 1, nblks_RI
    3251         214 :          bo = get_limit(natom, nblks_RI, i_blk - 1)
    3252         856 :          bsizes_tmp(i_blk) = SUM(ri_data%bsizes_RI(bo(1):bo(2)))
    3253             :       END DO
    3254         642 :       ALLOCATE (bsizes_RI_ext(ri_data%ncell_RI*nblks_RI))
    3255        1508 :       DO i_RI = 1, ri_data%ncell_RI
    3256        2802 :          bsizes_RI_ext((i_RI - 1)*nblks_RI + 1:i_RI*nblks_RI) = bsizes_tmp(:)
    3257             :       END DO
    3258             : 
    3259         214 :       batch_size = ri_data%kp_stack_size
    3260         214 :       nblks_AO = SIZE(ri_data%bsizes_AO_split)
    3261         642 :       ALLOCATE (bsizes_stack(batch_size*nblks_AO))
    3262        7062 :       DO ib = 1, batch_size
    3263       33558 :          bsizes_stack((ib - 1)*nblks_AO + 1:ib*nblks_AO) = ri_data%bsizes_AO_split(:)
    3264             :       END DO
    3265             : 
    3266             :       !Create the pgrid for the configuration corresponding to ri_data%t_3c_int_ctr_3
    3267         214 :       natom = SIZE(ri_data%bsizes_RI)
    3268         214 :       pdims = 0
    3269             :       CALL dbt_pgrid_create(para_env_sub, pdims, pgrid, &
    3270         856 :                             tensor_dims=[SIZE(bsizes_RI_ext_split), 1, batch_size*SIZE(ri_data%bsizes_AO_split)])
    3271             : 
    3272             :       !Create all required 3c tensors in that configuration
    3273             :       CALL create_3c_tensor(t_3c_int(1), dist1, dist2, dist3, &
    3274             :                             pgrid, bsizes_RI_ext_split, ri_data%bsizes_AO_split, &
    3275         214 :                             ri_data%bsizes_AO_split, [1], [2, 3], name="(RI | AO AO)")
    3276         214 :       nimg = SIZE(t_3c_int)
    3277        5300 :       DO i_img = 2, nimg
    3278        5300 :          CALL dbt_create(t_3c_int(1), t_3c_int(i_img))
    3279             :       END DO
    3280             : 
    3281             :       !The stacked work tensors, in a distribution that matches that of t_3c_int
    3282         428 :       ALLOCATE (dist_stack(batch_size*nblks_AO))
    3283        7062 :       DO ib = 1, batch_size
    3284       33558 :          dist_stack((ib - 1)*nblks_AO + 1:ib*nblks_AO) = dist3(:)
    3285             :       END DO
    3286             : 
    3287         214 :       CALL dbt_distribution_new(t_dist, pgrid, dist1, dist2, dist_stack)
    3288             :       CALL dbt_create(t_3c_work_3(1), "work_3_stack", t_dist, [1], [2, 3], &
    3289         214 :                       bsizes_RI_ext_split, ri_data%bsizes_AO_split, bsizes_stack)
    3290         214 :       CALL dbt_create(t_3c_work_3(1), t_3c_work_3(2))
    3291         214 :       CALL dbt_create(t_3c_work_3(1), t_3c_work_3(3))
    3292         214 :       CALL dbt_distribution_destroy(t_dist)
    3293         214 :       DEALLOCATE (dist1, dist2, dist3, dist_stack)
    3294             : 
    3295             :       !For more efficient communication, we use intermediate tensors with larger block size
    3296             :       CALL create_3c_tensor(work_atom_block_sub, dist1, dist2, dist3, &
    3297             :                             pgrid, bsizes_RI_ext, ri_data%bsizes_AO, &
    3298         214 :                             ri_data%bsizes_AO, [1], [2, 3], name="(RI | AO AO)")
    3299         214 :       DEALLOCATE (dist1, dist2, dist3)
    3300             : 
    3301             :       CALL create_3c_tensor(work_atom_block, dist1, dist2, dist3, &
    3302             :                             ri_data%pgrid, bsizes_RI_ext, ri_data%bsizes_AO, &
    3303         214 :                             ri_data%bsizes_AO, [1], [2, 3], name="(RI | AO AO)")
    3304         214 :       DEALLOCATE (dist1, dist2, dist3)
    3305             : 
    3306             :       !Finally copy the integrals into the subgroups (if not there already)
    3307         214 :       CALL timeset(routineN//"_ints", handle2)
    3308         214 :       IF (ALLOCATED(ri_data%kp_t_3c_int)) THEN !integrals already held in ri_data%kp_t_3c_int
    3309        3800 :          DO i_img = 1, nimg
    3310        3800 :             CALL dbt_copy(ri_data%kp_t_3c_int(i_img), t_3c_int(i_img), move_data=.TRUE.)
    3311             :          END DO
    3312             :       ELSE !otherwise take them from ri_data%t_3c_int_ctr_1(1, :)
    3313        2414 :          ALLOCATE (ri_data%kp_t_3c_int(nimg))
    3314        1714 :          DO i_img = 1, nimg
    3315        1644 :             CALL dbt_create(t_3c_int(i_img), ri_data%kp_t_3c_int(i_img))
    3316        1644 :             CALL get_tensor_occupancy(ri_data%t_3c_int_ctr_1(1, i_img), nze, occ)
    3317        1644 :             IF (nze == 0) CYCLE
    3318        1238 :             CALL dbt_copy(ri_data%t_3c_int_ctr_1(1, i_img), work_atom_block, order=[2, 1, 3])
    3319        1238 :             CALL copy_3c_to_subgroup(work_atom_block_sub, work_atom_block, group_size, ngroups, para_env)
    3320        2952 :             CALL dbt_copy(work_atom_block_sub, t_3c_int(i_img), move_data=.TRUE.)
    3321             :          END DO
    3322             :       END IF
    3323         214 :       CALL timestop(handle2)
    3324         214 :       CALL dbt_pgrid_destroy(pgrid)
    3325         214 :       CALL dbt_destroy(work_atom_block)
    3326         214 :       CALL dbt_destroy(work_atom_block_sub)
    3327             : 
    3328             :       !Do the same for the t_3c_ctr_2 configuration
    3329         214 :       pdims = 0
    3330             :       CALL dbt_pgrid_create(para_env_sub, pdims, pgrid, &
    3331         856 :                             tensor_dims=[1, SIZE(bsizes_RI_ext_split), batch_size*SIZE(ri_data%bsizes_AO_split)])
    3332             : 
    3333             :       !For more efficient communication, we use intermediate tensors with larger block size
    3334             :       CALL create_3c_tensor(work_atom_block_sub, dist1, dist2, dist3, &
    3335             :                             pgrid, ri_data%bsizes_AO, bsizes_RI_ext, &
    3336         214 :                             ri_data%bsizes_AO, [1], [2, 3], name="(AO RI | AO)")
    3337         214 :       DEALLOCATE (dist1, dist2, dist3)
    3338             : 
    3339             :       CALL create_3c_tensor(work_atom_block, dist1, dist2, dist3, &
    3340             :                             ri_data%pgrid_1, ri_data%bsizes_AO, bsizes_RI_ext, &
    3341         214 :                             ri_data%bsizes_AO, [1], [2, 3], name="(AO RI | AO)")
    3342         214 :       DEALLOCATE (dist1, dist2, dist3)
    3343             : 
    3344             :       !template for t_3c_apc_sub
    3345             :       CALL create_3c_tensor(tmp, dist1, dist2, dist3, &
    3346             :                             pgrid, ri_data%bsizes_AO_split, bsizes_RI_ext_split, &
    3347         214 :                             ri_data%bsizes_AO_split, [1], [2, 3], name="(AO RI | AO)")
    3348             : 
    3349             :       !create t_3c_work_2 tensors in a distribution that matches the above
    3350         428 :       ALLOCATE (dist_stack(batch_size*nblks_AO))
    3351        7062 :       DO ib = 1, batch_size
    3352       33558 :          dist_stack((ib - 1)*nblks_AO + 1:ib*nblks_AO) = dist3(:)
    3353             :       END DO
    3354             : 
    3355         214 :       CALL dbt_distribution_new(t_dist, pgrid, dist1, dist2, dist_stack)
    3356             :       CALL dbt_create(t_3c_work_2(1), "work_2_stack", t_dist, [1], [2, 3], &
    3357         214 :                       ri_data%bsizes_AO_split, bsizes_RI_ext_split, bsizes_stack)
    3358         214 :       CALL dbt_create(t_3c_work_2(1), t_3c_work_2(2))
    3359         214 :       CALL dbt_create(t_3c_work_2(1), t_3c_work_2(3))
    3360         214 :       CALL dbt_distribution_destroy(t_dist)
    3361         214 :       DEALLOCATE (dist1, dist2, dist3, dist_stack)
    3362             : 
    3363             :       !Finally copy data from t_3c_apc to the subgroups (each t_3c_apc tensor is destroyed afterwards)
    3364         642 :       ALLOCATE (idx_to_at(SIZE(ri_data%bsizes_AO)))
    3365         214 :       CALL get_idx_to_atom(idx_to_at, ri_data%bsizes_AO, ri_data%bsizes_AO)
    3366         214 :       nspins = SIZE(t_3c_apc, 1)
    3367         214 :       CALL timeset(routineN//"_apc", handle2)
    3368        5514 :       DO i_img = 1, nimg
    3369       12172 :          DO i_spin = 1, nspins
    3370        6872 :             CALL dbt_create(tmp, t_3c_apc_sub(i_spin, i_img))
    3371        6872 :             CALL get_tensor_occupancy(t_3c_apc(i_spin, i_img), nze, occ)
    3372        6872 :             IF (nze == 0) CYCLE !nothing to copy: t_3c_apc_sub(i_spin, i_img) is left as created
    3373        6032 :             CALL dbt_copy(t_3c_apc(i_spin, i_img), work_atom_block, move_data=.TRUE.)
    3374             :             CALL copy_3c_to_subgroup(work_atom_block_sub, work_atom_block, group_size, &
    3375        6032 :                                      ngroups, para_env, ri_data%iatom_to_subgroup, 1, idx_to_at)
    3376       18204 :             CALL dbt_copy(work_atom_block_sub, t_3c_apc_sub(i_spin, i_img), move_data=.TRUE.)
    3377             :          END DO
    3378       12386 :          DO i_spin = 1, nspins
    3379       12172 :             CALL dbt_destroy(t_3c_apc(i_spin, i_img))
    3380             :          END DO
    3381             :       END DO
    3382         214 :       CALL timestop(handle2)
    3383         214 :       CALL dbt_pgrid_destroy(pgrid)
    3384         214 :       CALL dbt_destroy(tmp)
    3385         214 :       CALL dbt_destroy(work_atom_block)
    3386         214 :       CALL dbt_destroy(work_atom_block_sub)
    3387             : 
    3388         214 :       CALL timestop(handle)
    3389             : 
    3390         856 :    END SUBROUTINE get_subgroup_3c_tensors
    3391             : 
    3392             : ! **************************************************************************************************
    3393             : !> \brief copy all required 2c force tensors from the main MPI group to the subgroups
    3394             : !> \param t_2c_inv ...
    3395             : !> \param t_2c_bint ...
    3396             : !> \param t_2c_metric ...
    3397             : !> \param mat_2c_pot ...
    3398             : !> \param t_2c_work ...
    3399             : !> \param rho_ao_t ...
    3400             : !> \param rho_ao_t_sub ...
    3401             : !> \param t_2c_der_metric ...
    3402             : !> \param t_2c_der_metric_sub ...
    3403             : !> \param mat_der_pot ...
    3404             : !> \param mat_der_pot_sub ...
    3405             : !> \param group_size ...
    3406             : !> \param ngroups ...
    3407             : !> \param para_env ...
    3408             : !> \param para_env_sub ...
    3409             : !> \param ri_data ...
    3410             : !> \note Main MPI group tensors are deleted within this routine, for memory optimization
    3411             : ! **************************************************************************************************
    3412          84 :    SUBROUTINE get_subgroup_2c_derivs(t_2c_inv, t_2c_bint, t_2c_metric, mat_2c_pot, t_2c_work, rho_ao_t, &
    3413          42 :                                      rho_ao_t_sub, t_2c_der_metric, t_2c_der_metric_sub, mat_der_pot, &
    3414          42 :                                      mat_der_pot_sub, group_size, ngroups, para_env, para_env_sub, ri_data)
    3415             :       TYPE(dbt_type), DIMENSION(:), INTENT(INOUT)        :: t_2c_inv, t_2c_bint, t_2c_metric
    3416             :       TYPE(dbcsr_type), DIMENSION(:), INTENT(INOUT)      :: mat_2c_pot
    3417             :       TYPE(dbt_type), DIMENSION(:), INTENT(INOUT)        :: t_2c_work
    3418             :       TYPE(dbt_type), DIMENSION(:, :), INTENT(INOUT)     :: rho_ao_t, rho_ao_t_sub, t_2c_der_metric, &
    3419             :                                                             t_2c_der_metric_sub
    3420             :       TYPE(dbcsr_type), DIMENSION(:, :), INTENT(INOUT)   :: mat_der_pot, mat_der_pot_sub
    3421             :       INTEGER, INTENT(IN)                                :: group_size, ngroups
    3422             :       TYPE(mp_para_env_type), POINTER                    :: para_env, para_env_sub
    3423             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
    3424             : 
    3425             :       CHARACTER(len=*), PARAMETER :: routineN = 'get_subgroup_2c_derivs'
    3426             : 
    3427             :       INTEGER                                            :: handle, i, i_img, i_RI, i_spin, i_xyz, &
    3428             :                                                             iatom, iproc, j, natom, nblks, nimg, &
    3429             :                                                             nspins
    3430             :       INTEGER(int_8)                                     :: nze
    3431             :       INTEGER, ALLOCATABLE, DIMENSION(:)                 :: bsizes_RI_ext, bsizes_RI_ext_split, &
    3432          42 :                                                             dist1, dist2
    3433             :       INTEGER, DIMENSION(2)                              :: pdims_2d
    3434          84 :       INTEGER, DIMENSION(:), POINTER                     :: col_dist, RI_blk_size, row_dist
    3435          42 :       INTEGER, DIMENSION(:, :), POINTER                  :: dbcsr_pgrid
    3436             :       REAL(dp)                                           :: occ
    3437             :       TYPE(dbcsr_distribution_type)                      :: dbcsr_dist_sub
    3438         126 :       TYPE(dbt_pgrid_type)                               :: pgrid_2d
    3439         546 :       TYPE(dbt_type)                                     :: work, work_sub
    3440             : 
    3441          42 :       CALL timeset(routineN, handle)
    3442             : 
    3443             :       !Note: a fair portion of this routine is copied from the energy version of it
    3444             :       !Create the 2d pgrid
    3445          42 :       pdims_2d = 0
    3446          42 :       CALL dbt_pgrid_create(para_env_sub, pdims_2d, pgrid_2d)
    3447             : 
    3448          42 :       natom = SIZE(ri_data%bsizes_RI)
    3449          42 :       nblks = SIZE(ri_data%bsizes_RI_split)
    3450         126 :       ALLOCATE (bsizes_RI_ext(ri_data%ncell_RI*natom))
    3451         126 :       ALLOCATE (bsizes_RI_ext_split(ri_data%ncell_RI*nblks))
    3452         294 :       DO i_RI = 1, ri_data%ncell_RI
    3453         756 :          bsizes_RI_ext((i_RI - 1)*natom + 1:i_RI*natom) = ri_data%bsizes_RI(:)
    3454        1334 :          bsizes_RI_ext_split((i_RI - 1)*nblks + 1:i_RI*nblks) = ri_data%bsizes_RI_split(:)
    3455             :       END DO
    3456             : 
    3457             :       !nRI x nRI 2c tensors
    3458             :       CALL create_2c_tensor(t_2c_inv(1), dist1, dist2, pgrid_2d, &
    3459             :                             bsizes_RI_ext, bsizes_RI_ext, &
    3460             :                             name="(RI | RI)")
    3461          42 :       DEALLOCATE (dist1, dist2)
    3462             : 
    3463          42 :       CALL dbt_create(t_2c_inv(1), t_2c_bint(1))
    3464          42 :       CALL dbt_create(t_2c_inv(1), t_2c_metric(1))
    3465          84 :       DO iatom = 2, natom
    3466          42 :          CALL dbt_create(t_2c_inv(1), t_2c_inv(iatom))
    3467          42 :          CALL dbt_create(t_2c_inv(1), t_2c_bint(iatom))
    3468          84 :          CALL dbt_create(t_2c_inv(1), t_2c_metric(iatom))
    3469             :       END DO
    3470          42 :       CALL dbt_create(t_2c_inv(1), t_2c_work(1))
    3471          42 :       CALL dbt_create(t_2c_inv(1), t_2c_work(2))
    3472          42 :       CALL dbt_create(t_2c_inv(1), t_2c_work(3))
    3473          42 :       CALL dbt_create(t_2c_inv(1), t_2c_work(4))
    3474             : 
    3475             :       CALL create_2c_tensor(t_2c_work(5), dist1, dist2, pgrid_2d, &
    3476             :                             bsizes_RI_ext_split, bsizes_RI_ext_split, &
    3477          42 :                             name="(RI | RI)")
    3478          42 :       DEALLOCATE (dist1, dist2)
    3479             : 
    3480             :       !copy the data from the main group.
    3481         126 :       DO iatom = 1, natom
    3482          84 :          CALL copy_2c_to_subgroup(t_2c_inv(iatom), ri_data%t_2c_inv(1, iatom), group_size, ngroups, para_env)
    3483          84 :          CALL copy_2c_to_subgroup(t_2c_bint(iatom), ri_data%t_2c_int(1, iatom), group_size, ngroups, para_env)
    3484         126 :          CALL copy_2c_to_subgroup(t_2c_metric(iatom), ri_data%t_2c_pot(1, iatom), group_size, ngroups, para_env)
    3485             :       END DO
    3486             : 
    3487             :       !This includes the derivatives of the RI metric, for which there is one per atom
    3488         168 :       DO i_xyz = 1, 3
    3489         420 :          DO iatom = 1, natom
    3490         252 :             CALL dbt_create(t_2c_inv(1), t_2c_der_metric_sub(iatom, i_xyz))
    3491             :             CALL copy_2c_to_subgroup(t_2c_der_metric_sub(iatom, i_xyz), t_2c_der_metric(iatom, i_xyz), &
    3492         252 :                                      group_size, ngroups, para_env)
    3493         378 :             CALL dbt_destroy(t_2c_der_metric(iatom, i_xyz))
    3494             :          END DO
    3495             :       END DO
    3496             : 
    3497             :       !AO x AO 2c tensors
    3498             :       CALL create_2c_tensor(rho_ao_t_sub(1, 1), dist1, dist2, pgrid_2d, &
    3499             :                             ri_data%bsizes_AO_split, ri_data%bsizes_AO_split, &
    3500             :                             name="(AO | AO)")
    3501          42 :       DEALLOCATE (dist1, dist2)
    3502          42 :       nspins = SIZE(rho_ao_t, 1)
    3503          42 :       nimg = SIZE(rho_ao_t, 2)
    3504             : 
    3505         980 :       DO i_img = 1, nimg
    3506        2084 :          DO i_spin = 1, nspins
    3507        1104 :             IF (.NOT. (i_img == 1 .AND. i_spin == 1)) &
    3508        1062 :                CALL dbt_create(rho_ao_t_sub(1, 1), rho_ao_t_sub(i_spin, i_img))
    3509             :             CALL copy_2c_to_subgroup(rho_ao_t_sub(i_spin, i_img), rho_ao_t(i_spin, i_img), &
    3510        1104 :                                      group_size, ngroups, para_env)
    3511        2042 :             CALL dbt_destroy(rho_ao_t(i_spin, i_img))
    3512             :          END DO
    3513             :       END DO
    3514             : 
    3515             :       !The RIxRI matrices, going through tensors
    3516             :       CALL create_2c_tensor(work_sub, dist1, dist2, pgrid_2d, &
    3517             :                             ri_data%bsizes_RI, ri_data%bsizes_RI, &
    3518             :                             name="(RI | RI)")
    3519          42 :       CALL dbt_create(ri_data%kp_mat_2c_pot(1, 1), work)
    3520             : 
    3521         168 :       ALLOCATE (dbcsr_pgrid(0:pdims_2d(1) - 1, 0:pdims_2d(2) - 1))
    3522          42 :       iproc = 0
    3523          84 :       DO i = 0, pdims_2d(1) - 1
    3524         126 :          DO j = 0, pdims_2d(2) - 1
    3525          42 :             dbcsr_pgrid(i, j) = iproc
    3526          84 :             iproc = iproc + 1
    3527             :          END DO
    3528             :       END DO
    3529             : 
    3530             :       !We need to have the same exact 2d block dist as the tensors
    3531         168 :       ALLOCATE (col_dist(natom), row_dist(natom))
    3532         126 :       row_dist(:) = dist1(:)
    3533         126 :       col_dist(:) = dist2(:)
    3534             : 
    3535          84 :       ALLOCATE (RI_blk_size(natom))
    3536         126 :       RI_blk_size(:) = ri_data%bsizes_RI(:)
    3537             : 
    3538             :       CALL dbcsr_distribution_new(dbcsr_dist_sub, group=para_env_sub%get_handle(), pgrid=dbcsr_pgrid, &
    3539          42 :                                   row_dist=row_dist, col_dist=col_dist)
    3540             :       CALL dbcsr_create(mat_2c_pot(1), dist=dbcsr_dist_sub, name="sub", matrix_type=dbcsr_type_no_symmetry, &
    3541          42 :                         row_blk_size=RI_blk_size, col_blk_size=RI_blk_size)
    3542             : 
    3543             :       !The HFX potential
    3544         980 :       DO i_img = 1, nimg
    3545         938 :          IF (i_img > 1) CALL dbcsr_create(mat_2c_pot(i_img), template=mat_2c_pot(1))
    3546         938 :          CALL dbt_copy_matrix_to_tensor(ri_data%kp_mat_2c_pot(1, i_img), work)
    3547         938 :          CALL get_tensor_occupancy(work, nze, occ)
    3548         938 :          IF (nze == 0) CYCLE
    3549             : 
    3550         662 :          CALL copy_2c_to_subgroup(work_sub, work, group_size, ngroups, para_env)
    3551         662 :          CALL dbt_copy_tensor_to_matrix(work_sub, mat_2c_pot(i_img))
    3552         662 :          CALL dbcsr_filter(mat_2c_pot(i_img), ri_data%filter_eps)
    3553        1642 :          CALL dbt_clear(work_sub)
    3554             :       END DO
    3555             : 
    3556             :       !The derivatives of the HFX potential
    3557         168 :       DO i_xyz = 1, 3
    3558        2982 :          DO i_img = 1, nimg
    3559        2814 :             CALL dbcsr_create(mat_der_pot_sub(i_img, i_xyz), template=mat_2c_pot(1))
    3560        2814 :             CALL dbt_copy_matrix_to_tensor(mat_der_pot(i_img, i_xyz), work)
    3561        2814 :             CALL dbcsr_release(mat_der_pot(i_img, i_xyz))
    3562        2814 :             CALL get_tensor_occupancy(work, nze, occ)
    3563        2814 :             IF (nze == 0) CYCLE
    3564             : 
    3565        1986 :             CALL copy_2c_to_subgroup(work_sub, work, group_size, ngroups, para_env)
    3566        1986 :             CALL dbt_copy_tensor_to_matrix(work_sub, mat_der_pot_sub(i_img, i_xyz))
    3567        1986 :             CALL dbcsr_filter(mat_der_pot_sub(i_img, i_xyz), ri_data%filter_eps)
    3568        4926 :             CALL dbt_clear(work_sub)
    3569             :          END DO
    3570             :       END DO
    3571             : 
    3572          42 :       CALL dbt_destroy(work)
    3573          42 :       CALL dbt_destroy(work_sub)
    3574          42 :       CALL dbt_pgrid_destroy(pgrid_2d)
    3575          42 :       CALL dbcsr_distribution_release(dbcsr_dist_sub)
    3576          42 :       DEALLOCATE (col_dist, row_dist, RI_blk_size, dbcsr_pgrid)
    3577             : 
    3578          42 :       CALL timestop(handle)
    3579             : 
    3580         336 :    END SUBROUTINE get_subgroup_2c_derivs
    3581             : 
    3582             : ! **************************************************************************************************
    3583             : !> \brief copy the 3c derivative and t_3c_apc tensors from the main MPI group to the subgroups
    3584             : !> \param t_3c_work_2 created here (3 tensors), t_3c_apc_sub layout with a stacked 3rd dimension
    3585             : !> \param t_3c_work_3 created here (4 tensors), kp_t_3c_int layout with a stacked 3rd dimension
    3586             : !> \param t_3c_der_AO (i_img, i_xyz) tensors on the main group, destroyed on exit
    3587             : !> \param t_3c_der_AO_sub (i_img, i_xyz) subgroup copies of t_3c_der_AO, created here
    3588             : !> \param t_3c_der_RI (i_img, i_xyz) tensors on the main group, destroyed on exit
    3589             : !> \param t_3c_der_RI_sub (i_img, i_xyz) subgroup copies of t_3c_der_RI, created here
    3590             : !> \param t_3c_apc (i_spin, i_img) tensors on the main group, destroyed on exit
    3591             : !> \param t_3c_apc_sub (i_spin, i_img) subgroup copies of t_3c_apc, created here
    3592             : !> \param t_3c_der_stack created here (6 tensors) with the same structure as t_3c_work_3
    3593             : !> \param group_size forwarded to copy_3c_to_subgroup
    3594             : !> \param ngroups forwarded to copy_3c_to_subgroup
    3595             : !> \param para_env forwarded to copy_3c_to_subgroup
    3596             : !> \param para_env_sub used to create the process grids of the subgroup tensors
    3597             : !> \param ri_data source of block sizes, process grids, kp_t_3c_int(1), nimg and kp_stack_size
    3598             : !> \note the tensors containing the derivatives in the main MPI group are destroyed to save memory
    3599             : ! **************************************************************************************************
    3600          42 :    SUBROUTINE get_subgroup_3c_derivs(t_3c_work_2, t_3c_work_3, t_3c_der_AO, t_3c_der_AO_sub, &
    3601          42 :                                      t_3c_der_RI, t_3c_der_RI_sub, t_3c_apc, t_3c_apc_sub, &
    3602          42 :                                      t_3c_der_stack, group_size, ngroups, para_env, para_env_sub, &
    3603             :                                      ri_data)
    3604             :       TYPE(dbt_type), DIMENSION(:), INTENT(INOUT)        :: t_3c_work_2, t_3c_work_3
    3605             :       TYPE(dbt_type), DIMENSION(:, :), INTENT(INOUT)     :: t_3c_der_AO, t_3c_der_AO_sub, &
    3606             :                                                             t_3c_der_RI, t_3c_der_RI_sub, &
    3607             :                                                             t_3c_apc, t_3c_apc_sub
    3608             :       TYPE(dbt_type), DIMENSION(:), INTENT(INOUT)        :: t_3c_der_stack
    3609             :       INTEGER, INTENT(IN)                                :: group_size, ngroups
    3610             :       TYPE(mp_para_env_type), POINTER                    :: para_env, para_env_sub
    3611             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
    3612             : 
    3613             :       CHARACTER(len=*), PARAMETER :: routineN = 'get_subgroup_3c_derivs'
    3614             : 
    3615             :       INTEGER                                            :: batch_size, handle, i_img, i_RI, i_spin, &
    3616             :                                                             i_xyz, ib, nblks_AO, nblks_RI, nimg, &
    3617             :                                                             nspins, pdims(3)
    3618             :       INTEGER(int_8)                                     :: nze
    3619          42 :       INTEGER, ALLOCATABLE, DIMENSION(:)                 :: bsizes_RI_ext, bsizes_RI_ext_split, &
    3620          42 :                                                             bsizes_stack, dist1, dist2, dist3, &
    3621          42 :                                                             dist_stack, idx_to_at
    3622             :       REAL(dp)                                           :: occ
    3623         378 :       TYPE(dbt_distribution_type)                        :: t_dist
    3624         126 :       TYPE(dbt_pgrid_type)                               :: pgrid
    3625        1050 :       TYPE(dbt_type)                                     :: tmp, work_atom_block, work_atom_block_sub
    3626             : 
    3627          42 :       CALL timeset(routineN, handle)
    3628             : 
    3629             :       !We use intermediate tensors with larger block size for more optimized communication
    3630          42 :       nblks_RI = SIZE(ri_data%bsizes_RI)
    3631         126 :       ALLOCATE (bsizes_RI_ext(ri_data%ncell_RI*nblks_RI))
    3632         294 :       DO i_RI = 1, ri_data%ncell_RI
    3633         798 :          bsizes_RI_ext((i_RI - 1)*nblks_RI + 1:i_RI*nblks_RI) = ri_data%bsizes_RI(:)
    3634             :       END DO
    3635             : 
    3636          42 :       CALL dbt_get_info(ri_data%kp_t_3c_int(1), pdims=pdims)
    3637          42 :       CALL dbt_pgrid_create(para_env_sub, pdims, pgrid)
    3638             : 
    3639             :       CALL create_3c_tensor(work_atom_block_sub, dist1, dist2, dist3, &
    3640             :                             pgrid, bsizes_RI_ext, ri_data%bsizes_AO, &
    3641          42 :                             ri_data%bsizes_AO, [1], [2, 3], name="(RI | AO AO)")
    3642          42 :       DEALLOCATE (dist1, dist2, dist3)
    3643             : 
    3644             :       CALL create_3c_tensor(work_atom_block, dist1, dist2, dist3, &
    3645             :                             ri_data%pgrid_2, bsizes_RI_ext, ri_data%bsizes_AO, &
    3646          42 :                             ri_data%bsizes_AO, [1], [2, 3], name="(RI | AO AO)")
    3647          42 :       DEALLOCATE (dist1, dist2, dist3)
    3648          42 :       CALL dbt_pgrid_destroy(pgrid)
    3649             : 
    3650             :       !We use the 3c integrals on the subgroup as template for the derivatives
    3651          42 :       nimg = ri_data%nimg
    3652         168 :       DO i_xyz = 1, 3
    3653        2940 :          DO i_img = 1, nimg
    3654        2814 :             CALL dbt_create(ri_data%kp_t_3c_int(1), t_3c_der_AO_sub(i_img, i_xyz))
    3655        2814 :             CALL get_tensor_occupancy(t_3c_der_AO(i_img, i_xyz), nze, occ)
    3656        2814 :             IF (nze == 0) CYCLE
    3657             : 
    3658        1932 :             CALL dbt_copy(t_3c_der_AO(i_img, i_xyz), work_atom_block, move_data=.TRUE.)
    3659             :             CALL copy_3c_to_subgroup(work_atom_block_sub, work_atom_block, &
    3660        1932 :                                      group_size, ngroups, para_env)
    3661        4872 :             CALL dbt_copy(work_atom_block_sub, t_3c_der_AO_sub(i_img, i_xyz), move_data=.TRUE.)
    3662             :          END DO
    3663             : 
    3664        2940 :          DO i_img = 1, nimg
    3665        2814 :             CALL dbt_create(ri_data%kp_t_3c_int(1), t_3c_der_RI_sub(i_img, i_xyz))
    3666        2814 :             CALL get_tensor_occupancy(t_3c_der_RI(i_img, i_xyz), nze, occ)
    3667        2814 :             IF (nze == 0) CYCLE
    3668             : 
    3669        1920 :             CALL dbt_copy(t_3c_der_RI(i_img, i_xyz), work_atom_block, move_data=.TRUE.)
    3670             :             CALL copy_3c_to_subgroup(work_atom_block_sub, work_atom_block, &
    3671        1920 :                                      group_size, ngroups, para_env)
    3672        4860 :             CALL dbt_copy(work_atom_block_sub, t_3c_der_RI_sub(i_img, i_xyz), move_data=.TRUE.)
    3673             :          END DO
    3674             : 
    3675        2982 :          DO i_img = 1, nimg
    3676        2814 :             CALL dbt_destroy(t_3c_der_RI(i_img, i_xyz))
    3677        2940 :             CALL dbt_destroy(t_3c_der_AO(i_img, i_xyz))
    3678             :          END DO
    3679             :       END DO
    3680          42 :       CALL dbt_destroy(work_atom_block_sub)
    3681          42 :       CALL dbt_destroy(work_atom_block)
    3682             : 
    3683             :       !Deal with t_3c_apc
    3684          42 :       nblks_RI = SIZE(ri_data%bsizes_RI_split)
    3685         126 :       ALLOCATE (bsizes_RI_ext_split(ri_data%ncell_RI*nblks_RI))
    3686         294 :       DO i_RI = 1, ri_data%ncell_RI
    3687        1334 :          bsizes_RI_ext_split((i_RI - 1)*nblks_RI + 1:i_RI*nblks_RI) = ri_data%bsizes_RI_split(:)
    3688             :       END DO
    3689             :       !batch_size is only set further down, hence the stack size is read from ri_data directly
    3690          42 :       pdims = 0
    3691             :       CALL dbt_pgrid_create(para_env_sub, pdims, pgrid, tensor_dims=[1, SIZE(bsizes_RI_ext_split), &
    3692         168 :                             ri_data%kp_stack_size*SIZE(ri_data%bsizes_AO_split)])
    3693             : 
    3694             :       CALL create_3c_tensor(work_atom_block_sub, dist1, dist2, dist3, &
    3695             :                             pgrid, ri_data%bsizes_AO, bsizes_RI_ext, &
    3696          42 :                             ri_data%bsizes_AO, [1], [2, 3], name="(AO RI | AO)")
    3697          42 :       DEALLOCATE (dist1, dist2, dist3)
    3698             : 
    3699             :       CALL create_3c_tensor(work_atom_block, dist1, dist2, dist3, &
    3700             :                             ri_data%pgrid_1, ri_data%bsizes_AO, bsizes_RI_ext, &
    3701          42 :                             ri_data%bsizes_AO, [1], [2, 3], name="(AO RI | AO)")
    3702          42 :       DEALLOCATE (dist1, dist2, dist3)
    3703             : 
    3704             :       CALL create_3c_tensor(tmp, dist1, dist2, dist3, &
    3705             :                             pgrid, ri_data%bsizes_AO_split, bsizes_RI_ext_split, &
    3706          42 :                             ri_data%bsizes_AO_split, [1], [2, 3], name="(AO RI | AO)")
    3707          42 :       DEALLOCATE (dist1, dist2, dist3)
    3708             : 
    3709         126 :       ALLOCATE (idx_to_at(SIZE(ri_data%bsizes_AO)))
    3710          42 :       CALL get_idx_to_atom(idx_to_at, ri_data%bsizes_AO, ri_data%bsizes_AO)
    3711          42 :       nspins = SIZE(t_3c_apc, 1)
    3712         980 :       DO i_img = 1, nimg
    3713        2042 :          DO i_spin = 1, nspins
    3714        1104 :             CALL dbt_create(tmp, t_3c_apc_sub(i_spin, i_img))
    3715        1104 :             CALL get_tensor_occupancy(t_3c_apc(i_spin, i_img), nze, occ)
    3716        1104 :             IF (nze == 0) CYCLE
    3717        1098 :             CALL dbt_copy(t_3c_apc(i_spin, i_img), work_atom_block, move_data=.TRUE.)
    3718             :             CALL copy_3c_to_subgroup(work_atom_block_sub, work_atom_block, group_size, &
    3719        1098 :                                      ngroups, para_env, ri_data%iatom_to_subgroup, 1, idx_to_at)
    3720        3140 :             CALL dbt_copy(work_atom_block_sub, t_3c_apc_sub(i_spin, i_img), move_data=.TRUE.)
    3721             :          END DO
    3722        2084 :          DO i_spin = 1, nspins
    3723        2042 :             CALL dbt_destroy(t_3c_apc(i_spin, i_img))
    3724             :          END DO
    3725             :       END DO
    3726          42 :       CALL dbt_destroy(tmp)
    3727          42 :       CALL dbt_destroy(work_atom_block)
    3728          42 :       CALL dbt_destroy(work_atom_block_sub)
    3729          42 :       CALL dbt_pgrid_destroy(pgrid)
    3730             : 
    3731             :       !t_3c_work_3 based on structure of 3c integrals/derivs
    3732          42 :       batch_size = ri_data%kp_stack_size
    3733          42 :       nblks_AO = SIZE(ri_data%bsizes_AO_split)
    3734         126 :       ALLOCATE (bsizes_stack(batch_size*nblks_AO))
    3735        1386 :       DO ib = 1, batch_size
    3736        6826 :          bsizes_stack((ib - 1)*nblks_AO + 1:ib*nblks_AO) = ri_data%bsizes_AO_split(:)
    3737             :       END DO
    3738             : 
    3739         294 :       ALLOCATE (dist1(ri_data%ncell_RI*nblks_RI), dist2(nblks_AO), dist3(nblks_AO))
    3740             :       CALL dbt_get_info(ri_data%kp_t_3c_int(1), proc_dist_1=dist1, proc_dist_2=dist2, &
    3741          42 :                         proc_dist_3=dist3, pdims=pdims)
    3742             : 
    3743         126 :       ALLOCATE (dist_stack(batch_size*nblks_AO))
    3744        1386 :       DO ib = 1, batch_size
    3745        6826 :          dist_stack((ib - 1)*nblks_AO + 1:ib*nblks_AO) = dist3(:)
    3746             :       END DO
    3747             : 
    3748          42 :       CALL dbt_pgrid_create(para_env_sub, pdims, pgrid)
    3749          42 :       CALL dbt_distribution_new(t_dist, pgrid, dist1, dist2, dist_stack)
    3750             :       CALL dbt_create(t_3c_work_3(1), "work_3_stack", t_dist, [1], [2, 3], &
    3751          42 :                       bsizes_RI_ext_split, ri_data%bsizes_AO_split, bsizes_stack)
    3752          42 :       CALL dbt_create(t_3c_work_3(1), t_3c_work_3(2))
    3753          42 :       CALL dbt_create(t_3c_work_3(1), t_3c_work_3(3))
    3754          42 :       CALL dbt_create(t_3c_work_3(1), t_3c_work_3(4))
    3755          42 :       CALL dbt_distribution_destroy(t_dist)
    3756          42 :       CALL dbt_pgrid_destroy(pgrid)
    3757          42 :       DEALLOCATE (dist1, dist2, dist3, dist_stack)
    3758             : 
    3759             :       !the derivatives are stacked in the same way
    3760          42 :       CALL dbt_create(t_3c_work_3(1), t_3c_der_stack(1))
    3761          42 :       CALL dbt_create(t_3c_work_3(1), t_3c_der_stack(2))
    3762          42 :       CALL dbt_create(t_3c_work_3(1), t_3c_der_stack(3))
    3763          42 :       CALL dbt_create(t_3c_work_3(1), t_3c_der_stack(4))
    3764          42 :       CALL dbt_create(t_3c_work_3(1), t_3c_der_stack(5))
    3765          42 :       CALL dbt_create(t_3c_work_3(1), t_3c_der_stack(6))
    3766             : 
    3767             :       !t_3c_work_2 based on structure of t_3c_apc
    3768         294 :       ALLOCATE (dist1(nblks_AO), dist2(ri_data%ncell_RI*nblks_RI), dist3(nblks_AO))
    3769             :       CALL dbt_get_info(t_3c_apc_sub(1, 1), proc_dist_1=dist1, proc_dist_2=dist2, &
    3770          42 :                         proc_dist_3=dist3, pdims=pdims)
    3771             : 
    3772         126 :       ALLOCATE (dist_stack(batch_size*nblks_AO))
    3773        1386 :       DO ib = 1, batch_size
    3774        6826 :          dist_stack((ib - 1)*nblks_AO + 1:ib*nblks_AO) = dist3(:)
    3775             :       END DO
    3776             : 
    3777          42 :       CALL dbt_pgrid_create(para_env_sub, pdims, pgrid)
    3778          42 :       CALL dbt_distribution_new(t_dist, pgrid, dist1, dist2, dist_stack)
    3779             :       CALL dbt_create(t_3c_work_2(1), "work_3_stack", t_dist, [1], [2, 3], &
    3780          42 :                       ri_data%bsizes_AO_split, bsizes_RI_ext_split, bsizes_stack)
    3781          42 :       CALL dbt_create(t_3c_work_2(1), t_3c_work_2(2))
    3782          42 :       CALL dbt_create(t_3c_work_2(1), t_3c_work_2(3))
    3783          42 :       CALL dbt_distribution_destroy(t_dist)
    3784          42 :       CALL dbt_pgrid_destroy(pgrid)
    3785          42 :       DEALLOCATE (dist1, dist2, dist3, dist_stack)
    3786             : 
    3787          42 :       CALL timestop(handle)
    3788             : 
    3789          84 :    END SUBROUTINE get_subgroup_3c_derivs
    3790             : 
    3791             : ! **************************************************************************************************
    3792             : !> \brief A routine that reorders the t_3c_int tensors such that all items which are fully empty
    3793             : !>        are bunched together. This way, we can get much more efficient screening based on NZE
    3794             : !> \param t_3c_ints one tensor per image; on exit the non-empty ones come first, the empty ones last
    3795             : !> \param ri_data provides nimg; idx_to_img, img_to_idx and nimg_nze are set here
    3796             : ! **************************************************************************************************
    3797          70 :    SUBROUTINE reorder_3c_ints(t_3c_ints, ri_data)
    3798             :       TYPE(dbt_type), DIMENSION(:), INTENT(INOUT)        :: t_3c_ints
    3799             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
    3800             : 
    3801             :       CHARACTER(LEN=*), PARAMETER                        :: routineN = 'reorder_3c_ints'
    3802             : 
    3803             :       INTEGER                                            :: handle, i_img, idx, idx_empty, idx_full, &
    3804             :                                                             nimg
    3805             :       INTEGER(int_8)                                     :: nze
    3806             :       REAL(dp)                                           :: occ
    3807          70 :       TYPE(dbt_type), ALLOCATABLE, DIMENSION(:)          :: t_3c_tmp
    3808             : 
    3809          70 :       CALL timeset(routineN, handle)
    3810             :       !Move every tensor into a temporary one, so that t_3c_ints can be refilled in the new order
    3811          70 :       nimg = ri_data%nimg
    3812        2414 :       ALLOCATE (t_3c_tmp(nimg))
    3813        1714 :       DO i_img = 1, nimg
    3814        1644 :          CALL dbt_create(t_3c_ints(i_img), t_3c_tmp(i_img))
    3815        1714 :          CALL dbt_copy(t_3c_ints(i_img), t_3c_tmp(i_img), move_data=.TRUE.)
    3816             :       END DO
    3817             : 
    3818             :       !Loop over the images, check if ints have NZE == 0, and put them at the end (if empty) or at
    3819             :       !the start (otherwise) of the initial tensor array. idx_to_img(idx) keeps the original image
    3820         210 :       ALLOCATE (ri_data%idx_to_img(nimg))
    3821          70 :       idx_full = 0
    3822          70 :       idx_empty = nimg + 1
    3823             : 
    3824        1714 :       DO i_img = 1, nimg
    3825        1644 :          CALL get_tensor_occupancy(t_3c_tmp(i_img), nze, occ)
    3826        1644 :          IF (nze == 0) THEN
    3827         406 :             idx_empty = idx_empty - 1
    3828         406 :             CALL dbt_copy(t_3c_tmp(i_img), t_3c_ints(idx_empty), move_data=.TRUE.)
    3829         406 :             ri_data%idx_to_img(idx_empty) = i_img
    3830             :          ELSE
    3831        1238 :             idx_full = idx_full + 1
    3832        1238 :             CALL dbt_copy(t_3c_tmp(i_img), t_3c_ints(idx_full), move_data=.TRUE.)
    3833        1238 :             ri_data%idx_to_img(idx_full) = i_img
    3834             :          END IF
    3835        3358 :          CALL dbt_destroy(t_3c_tmp(i_img))
    3836             :       END DO
    3837             : 
    3838             :       !store the highest index of the reordered array that holds non-zero integrals
    3839          70 :       ri_data%nimg_nze = idx_full
    3840             :       !Inverse mapping: from the original image index to the position in the reordered array
    3841         140 :       ALLOCATE (ri_data%img_to_idx(nimg))
    3842        1714 :       DO idx = 1, nimg
    3843        1714 :          ri_data%img_to_idx(ri_data%idx_to_img(idx)) = idx
    3844             :       END DO
    3845             : 
    3846          70 :       CALL timestop(handle)
    3847             : 
    3848        1784 :    END SUBROUTINE reorder_3c_ints
    3849             : 
    3850             : ! **************************************************************************************************
    3851             : !> \brief A routine that reorders the 3c derivatives, the same way that the integrals are, also to
    3852             : !>        increase efficiency of screening
    3853             : !> \param t_3c_derivs indexed (i_img, i_xyz); reordered in place following ri_data%img_to_idx
    3854             : !> \param ri_data provides nimg and img_to_idx; nimg_nze is raised if a non-empty item lies above it
    3855             : ! **************************************************************************************************
    3856          84 :    SUBROUTINE reorder_3c_derivs(t_3c_derivs, ri_data)
    3857             :       TYPE(dbt_type), DIMENSION(:, :), INTENT(INOUT)     :: t_3c_derivs
    3858             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
    3859             : 
    3860             :       CHARACTER(LEN=*), PARAMETER                        :: routineN = 'reorder_3c_derivs'
    3861             : 
    3862             :       INTEGER                                            :: handle, i_img, i_xyz, idx, nimg
    3863             :       INTEGER(int_8)                                     :: nze
    3864             :       REAL(dp)                                           :: occ
    3865          84 :       TYPE(dbt_type), ALLOCATABLE, DIMENSION(:)          :: t_3c_tmp
    3866             : 
    3867          84 :       CALL timeset(routineN, handle)
    3868             : 
    3869          84 :       nimg = ri_data%nimg
    3870        2800 :       ALLOCATE (t_3c_tmp(nimg))
    3871        1960 :       DO i_img = 1, nimg
    3872        1960 :          CALL dbt_create(t_3c_derivs(1, 1), t_3c_tmp(i_img))
    3873             :       END DO
    3874             :       !For each direction: move all images to the temporaries, then back at their reordered index
    3875         336 :       DO i_xyz = 1, 3
    3876        5880 :          DO i_img = 1, nimg
    3877        5880 :             CALL dbt_copy(t_3c_derivs(i_img, i_xyz), t_3c_tmp(i_img), move_data=.TRUE.)
    3878             :          END DO
    3879        5964 :          DO i_img = 1, nimg
    3880        5628 :             idx = ri_data%img_to_idx(i_img)
    3881        5628 :             CALL dbt_copy(t_3c_tmp(i_img), t_3c_derivs(idx, i_xyz), move_data=.TRUE.)
    3882        5628 :             CALL get_tensor_occupancy(t_3c_derivs(idx, i_xyz), nze, occ)
    3883        5880 :             IF (nze > 0) ri_data%nimg_nze = MAX(idx, ri_data%nimg_nze)
    3884             :          END DO
    3885             :       END DO
    3886             : 
    3887        1960 :       DO i_img = 1, nimg
    3888        1960 :          CALL dbt_destroy(t_3c_tmp(i_img))
    3889             :       END DO
    3890             : 
    3891          84 :       CALL timestop(handle)
    3892             : 
    3893        2044 :    END SUBROUTINE reorder_3c_derivs
    3894             : 
    3895             : ! **************************************************************************************************
    3896             : !> \brief Get the sparsity pattern related to the non-symmetric AO basis overlap neighbor list
    3897             : !> \param pattern indexed (iatom, jatom, image); on exit 0 marks an occupied block, -1 an empty one
    3898             : !> \param ri_data provides nimg and present_images
    3899             : !> \param qs_env provides kpoints, para_env and natom; also passed on to get_opp_index
    3900             : ! **************************************************************************************************
    3901         256 :    SUBROUTINE get_sparsity_pattern(pattern, ri_data, qs_env)
    3902             :       INTEGER, DIMENSION(:, :, :), INTENT(INOUT)         :: pattern
    3903             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
    3904             :       TYPE(qs_environment_type), POINTER                 :: qs_env
    3905             : 
    3906             :       INTEGER                                            :: iatom, j_img, jatom, mj_img, natom, nimg
    3907         256 :       INTEGER, ALLOCATABLE, DIMENSION(:)                 :: bins
    3908         256 :       INTEGER, ALLOCATABLE, DIMENSION(:, :, :)           :: tmp_pattern
    3909             :       INTEGER, DIMENSION(3)                              :: cell_j
    3910         256 :       INTEGER, DIMENSION(:, :), POINTER                  :: index_to_cell
    3911         256 :       INTEGER, DIMENSION(:, :, :), POINTER               :: cell_to_index
    3912             :       TYPE(dft_control_type), POINTER                    :: dft_control
    3913             :       TYPE(kpoint_type), POINTER                         :: kpoints
    3914             :       TYPE(mp_para_env_type), POINTER                    :: para_env
    3915             :       TYPE(neighbor_list_iterator_p_type), &
    3916         256 :          DIMENSION(:), POINTER                           :: nl_iterator
    3917             :       TYPE(neighbor_list_set_p_type), DIMENSION(:), &
    3918         256 :          POINTER                                         :: nl_2c
    3919             : 
    3920         256 :       NULLIFY (nl_2c, nl_iterator, kpoints, cell_to_index, dft_control, index_to_cell, para_env)
    3921             : 
    3922         256 :       CALL get_qs_env(qs_env, kpoints=kpoints, dft_control=dft_control, para_env=para_env, natom=natom)
    3923         256 :       CALL get_kpoint_info(kpoints, cell_to_index=cell_to_index, index_to_cell=index_to_cell, sab_nl=nl_2c)
    3924             : 
    3925         256 :       nimg = ri_data%nimg
    3926       43922 :       pattern(:, :, :) = 0
    3927             : 
    3928             :       !We use the symmetric nl for all images that have an opposite cell within the nimg images
    3929         256 :       CALL neighbor_list_iterator_create(nl_iterator, nl_2c)
    3930       11318 :       DO WHILE (neighbor_list_iterate(nl_iterator) == 0)
    3931       11062 :          CALL get_iterator_info(nl_iterator, iatom=iatom, jatom=jatom, cell=cell_j)
    3932             : 
    3933       11062 :          j_img = cell_to_index(cell_j(1), cell_j(2), cell_j(3))
    3934       11062 :          IF (j_img > nimg .OR. j_img < 1) CYCLE
    3935             : 
    3936        7515 :          mj_img = get_opp_index(j_img, qs_env)
    3937        7515 :          IF (mj_img > nimg .OR. mj_img < 1) CYCLE
    3938             : 
    3939        7206 :          IF (ri_data%present_images(j_img) == 0) CYCLE
    3940             : 
    3941       11062 :          pattern(iatom, jatom, j_img) = 1
    3942             :       END DO
    3943         256 :       CALL neighbor_list_iterator_release(nl_iterator)
    3944             : 
    3945             :       !If there is no opposite cell present, then we take into account the non-symmetric nl
    3946         256 :       CALL get_kpoint_info(kpoints, sab_nl_nosym=nl_2c)
    3947             : 
    3948         256 :       CALL neighbor_list_iterator_create(nl_iterator, nl_2c)
    3949       14800 :       DO WHILE (neighbor_list_iterate(nl_iterator) == 0)
    3950       14544 :          CALL get_iterator_info(nl_iterator, iatom=iatom, jatom=jatom, cell=cell_j)
    3951             : 
    3952       14544 :          j_img = cell_to_index(cell_j(1), cell_j(2), cell_j(3))
    3953       14544 :          IF (j_img > nimg .OR. j_img < 1) CYCLE
    3954             : 
    3955        9597 :          mj_img = get_opp_index(j_img, qs_env)
    3956        9597 :          IF (mj_img .LE. nimg .AND. mj_img > 0) CYCLE
    3957             : 
    3958         321 :          IF (ri_data%present_images(j_img) == 0) CYCLE
    3959             : 
    3960       14544 :          pattern(iatom, jatom, j_img) = 1
    3961             :       END DO
    3962         256 :       CALL neighbor_list_iterator_release(nl_iterator)
    3963             : 
    3964       87588 :       CALL para_env%sum(pattern)
    3965             : 
    3966             :       !If the opposite image is considered, then there is no need to compute the diagonal twice
    3967        6238 :       DO j_img = 2, nimg
    3968       18202 :          DO iatom = 1, natom
    3969       17946 :             IF (pattern(iatom, iatom, j_img) .NE. 0) THEN
    3970        4148 :                mj_img = get_opp_index(j_img, qs_env)
    3971        4148 :                IF (mj_img > nimg .OR. mj_img < 1) CYCLE
    3972        4148 :                pattern(iatom, iatom, mj_img) = 0
    3973             :             END IF
    3974             :          END DO
    3975             :       END DO
    3976             : 
    3977             :       ! We want to equilibrate the sparsity pattern such that there is the same amount of blocks
    3978             :       ! for each atom i of i,j pairs. bins(i) counts the blocks assigned to atom i so far
    3979         768 :       ALLOCATE (bins(natom))
    3980         768 :       bins(:) = 0
    3981             : 
    3982        1280 :       ALLOCATE (tmp_pattern(natom, natom, nimg))
    3983       43922 :       tmp_pattern(:, :, :) = 0
    3984        6494 :       DO j_img = 1, nimg
    3985       18970 :          DO jatom = 1, natom
    3986       43666 :             DO iatom = 1, natom
    3987       24952 :                IF (pattern(iatom, jatom, j_img) == 0) CYCLE
    3988        8374 :                mj_img = get_opp_index(j_img, qs_env)
    3989             : 
    3990             :                !Should we take the i,j,b or the j,i,-b atomic block?
    3991       20850 :                IF (mj_img > nimg .OR. mj_img < 1) THEN
    3992             :                   !No opposite image, no choice
    3993         214 :                   bins(iatom) = bins(iatom) + 1
    3994         214 :                   tmp_pattern(iatom, jatom, j_img) = 1
    3995             :                ELSE
    3996             :                   !Assign the block to whichever of the two atoms currently has fewer blocks
    3997        8160 :                   IF (bins(iatom) > bins(jatom)) THEN
    3998        1646 :                      bins(jatom) = bins(jatom) + 1
    3999        1646 :                      tmp_pattern(jatom, iatom, mj_img) = 1
    4000             :                   ELSE
    4001        6514 :                      bins(iatom) = bins(iatom) + 1
    4002        6514 :                      tmp_pattern(iatom, jatom, j_img) = 1
    4003             :                   END IF
    4004             :                END IF
    4005             :             END DO
    4006             :          END DO
    4007             :       END DO
    4008             : 
    4009             :       ! -1 => unoccupied, 0 => occupied
    4010       43922 :       pattern(:, :, :) = tmp_pattern(:, :, :) - 1
    4011             : 
    4012         512 :    END SUBROUTINE get_sparsity_pattern
    4013             : 
    4014             : ! **************************************************************************************************
    4015             : !> \brief Distribute the iatom, jatom, b_img triplets over the subgroups to spread the load.
    4016             : !>        The group id for each triplet is passed as the value of sparsity_pattern(i, j, b),
    4017             : !>        with -1 being an unoccupied block
    4018             : !> \param sparsity_pattern in: -1 marks unoccupied blocks, out: occupied blocks hold their 0-based group id
    4019             : !> \param ngroups number of subgroups to distribute the triplets over
    4020             : !> \param ri_data provides kp_cost/bsizes_AO for the costs; iatom_to_subgroup is built here if not yet allocated
    4021             : ! **************************************************************************************************
    4022         256 :    SUBROUTINE get_sub_dist(sparsity_pattern, ngroups, ri_data)
    4023             :       INTEGER, DIMENSION(:, :, :), INTENT(INOUT)         :: sparsity_pattern
    4024             :       INTEGER, INTENT(IN)                                :: ngroups
    4025             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
    4026             : 
    4027             :       INTEGER                                            :: b_img, ctr, iat, iatom, igroup, jatom, &
    4028             :                                                             natom, nimg, ub
    4029         256 :       INTEGER, ALLOCATABLE, DIMENSION(:)                 :: max_at_per_group
    4030             :       REAL(dp)                                           :: cost
    4031         256 :       REAL(dp), ALLOCATABLE, DIMENSION(:)                :: bins
    4032             : 
    4033         256 :       natom = SIZE(sparsity_pattern, 2)
    4034         256 :       nimg = SIZE(sparsity_pattern, 3)
    4035             : 
    4036             :       !To avoid unnecessary data replication across the subgroups, we want to have a limited number
    4037             :       !of subgroups with the data of a given iatom. At the minimum, all groups have 1 atom
    4038             :       !We assume that the cost associated to each iatom is roughly the same
    4039         256 :       IF (.NOT. ALLOCATED(ri_data%iatom_to_subgroup)) THEN
    4040         350 :          ALLOCATE (ri_data%iatom_to_subgroup(natom), max_at_per_group(ngroups))
    4041         150 :          DO iatom = 1, natom
    4042         100 :             NULLIFY (ri_data%iatom_to_subgroup(iatom)%array)
    4043         200 :             ALLOCATE (ri_data%iatom_to_subgroup(iatom)%array(ngroups))
    4044         350 :             ri_data%iatom_to_subgroup(iatom)%array(:) = .FALSE.
    4045             :          END DO
    4046             : 
    4047          50 :          ub = natom/ngroups
    4048          50 :          IF (ub*ngroups < natom) ub = ub + 1
    4049         150 :          max_at_per_group(:) = MAX(1, ub)
    4050             : 
    4051             :          !We want each atom to be present the same amount of times. Some groups might have more atoms
    4052             :          !than others to achieve this (groups are enlarged round-robin until the total is a multiple of natom).
    4053             :          ctr = 0
    4054         150 :          DO WHILE (MODULO(SUM(max_at_per_group), natom) .NE. 0)
    4055           0 :             igroup = MODULO(ctr, ngroups) + 1
    4056           0 :             max_at_per_group(igroup) = max_at_per_group(igroup) + 1
    4057          50 :             ctr = ctr + 1
    4058             :          END DO
    4059             : 
    4060             :          ctr = 0
    4061         150 :          DO igroup = 1, ngroups
    4062         250 :             DO iat = 1, max_at_per_group(igroup)
    4063         100 :                iatom = MODULO(ctr, natom) + 1
    4064         100 :                ri_data%iatom_to_subgroup(iatom)%array(igroup) = .TRUE.
    4065         200 :                ctr = ctr + 1
    4066             :             END DO
    4067             :          END DO
    4068             :       END IF
    4069             : 
    4070         768 :       ALLOCATE (bins(ngroups))
    4071         768 :       bins = 0.0_dp
    4072        6494 :       DO b_img = 1, nimg
    4073       18970 :          DO jatom = 1, natom
    4074       43666 :             DO iatom = 1, natom
    4075       24952 :                IF (sparsity_pattern(iatom, jatom, b_img) == -1) CYCLE
    4076       41870 :                igroup = MINLOC(bins, 1, MASK=ri_data%iatom_to_subgroup(iatom)%array) - 1
    4077             : 
    4078             :                !Use cost information from previous SCF if available, else estimate it from the AO block sizes
    4079      533682 :                IF (ANY(ri_data%kp_cost > EPSILON(0.0_dp))) THEN
    4080        6152 :                   cost = ri_data%kp_cost(iatom, jatom, b_img)
    4081             :                ELSE
    4082        2222 :                   cost = REAL(ri_data%bsizes_AO(iatom)*ri_data%bsizes_AO(jatom), dp)
    4083             :                END IF
    4084        8374 :                bins(igroup + 1) = bins(igroup + 1) + cost
    4085       37428 :                sparsity_pattern(iatom, jatom, b_img) = igroup
    4086             :             END DO
    4087             :          END DO
    4088             :       END DO
    4089             : 
    4090         256 :    END SUBROUTINE get_sub_dist
    4091             : 
    4092             : ! **************************************************************************************************
    4093             : !> \brief A routine that updates the sparsity pattern for force calculation, where all i,j,b combinations
    4094             : !>        are visited.
    4095             : !> \param force_pattern receives the 0-based group id of blocks untreated in the SCF whose j,i,-b block was treated
    4096             : !> \param scf_pattern pattern used during the SCF, -1 marks untreated blocks (only read in this routine)
    4097             : !> \param ngroups number of subgroups, i.e. the number of cost bins
    4098             : !> \param ri_data provides the iatom_to_subgroup masks and the kp_cost values used as bin weights
    4099             : !> \param qs_env passed to get_opp_index to obtain the image index mb_img paired with b_img
    4100             : ! **************************************************************************************************
    4101          42 :    SUBROUTINE update_pattern_to_forces(force_pattern, scf_pattern, ngroups, ri_data, qs_env)
    4102             :       INTEGER, DIMENSION(:, :, :), INTENT(INOUT)         :: force_pattern, scf_pattern
    4103             :       INTEGER, INTENT(IN)                                :: ngroups
    4104             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
    4105             :       TYPE(qs_environment_type), POINTER                 :: qs_env
    4106             : 
    4107             :       INTEGER                                            :: b_img, iatom, igroup, jatom, mb_img, &
    4108             :                                                             natom, nimg
    4109          42 :       REAL(dp), ALLOCATABLE, DIMENSION(:)                :: bins
    4110             : 
    4111          42 :       natom = SIZE(scf_pattern, 2)
    4112          42 :       nimg = SIZE(scf_pattern, 3)
    4113             : 
    4114         126 :       ALLOCATE (bins(ngroups))
    4115         126 :       bins = 0.0_dp
    4116             : 
    4117         980 :       DO b_img = 1, nimg
    4118         938 :          mb_img = get_opp_index(b_img, qs_env)
    4119        2856 :          DO jatom = 1, natom
    4120        6566 :             DO iatom = 1, natom
    4121             :                !Important: same distribution as KS matrix, because reuse t_3c_apc
    4122       18760 :                igroup = MINLOC(bins, 1, MASK=ri_data%iatom_to_subgroup(iatom)%array) - 1
    4123             : 
    4124             :                !check that block not already treated (any value > -1 means a group was assigned in the SCF)
    4125        3752 :                IF (scf_pattern(iatom, jatom, b_img) > -1) CYCLE
    4126             : 
    4127             :                !If not, take the cost of block j, i, -b (same energy contribution)
    4128        4466 :                IF (mb_img > 0 .AND. mb_img .LE. nimg) THEN
    4129        2210 :                   IF (scf_pattern(jatom, iatom, mb_img) == -1) CYCLE
    4130        1042 :                   bins(igroup + 1) = bins(igroup + 1) + ri_data%kp_cost(jatom, iatom, mb_img)
    4131        1042 :                   force_pattern(iatom, jatom, b_img) = igroup
    4132             :                END IF
    4133             :             END DO
    4134             :          END DO
    4135             :       END DO
    4136             : 
    4137          42 :    END SUBROUTINE update_pattern_to_forces
    4138             : 
    4139             : ! **************************************************************************************************
    4140             : !> \brief A routine that determines the extent of the KP RI-HFX periodic images, including for the
    4141             : !>        extension of the RI basis
    4142             : !> \param ri_data on exit: kp_RI_range, kp_image_range, kp_bump_rad, nimg, present_images and RI cell maps are set
    4143             : !> \param qs_env source of kinds, particles, k-point cell_to_index, para_env and the DFT%XC%HF%RI input section
    4144             : ! **************************************************************************************************
    4145          70 :    SUBROUTINE get_kp_and_ri_images(ri_data, qs_env)
    4146             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
    4147             :       TYPE(qs_environment_type), POINTER                 :: qs_env
    4148             : 
    4149             :       CHARACTER(LEN=*), PARAMETER :: routineN = 'get_kp_and_ri_images'
    4150             : 
    4151             :       INTEGER :: cell_j(3), cell_k(3), handle, i_img, iatom, ikind, j_img, jatom, jcell, katom, &
    4152             :          kcell, kp_index_lbounds(3), kp_index_ubounds(3), natom, ngroups, nimg, nkind, pcoord(3), &
    4153             :          pdims(3)
    4154          70 :       INTEGER, ALLOCATABLE, DIMENSION(:)                 :: dist_AO_1, dist_AO_2, dist_RI, &
    4155          70 :                                                             nRI_per_atom, present_img, RI_cells
    4156          70 :       INTEGER, DIMENSION(:, :, :), POINTER               :: cell_to_index
    4157             :       REAL(dp)                                           :: bump_fact, dij, dik, image_range, &
    4158             :                                                             RI_range, rij(3), rik(3)
    4159         490 :       TYPE(dbt_type)                                     :: t_dummy
    4160             :       TYPE(dft_control_type), POINTER                    :: dft_control
    4161             :       TYPE(distribution_2d_type), POINTER                :: dist_2d
    4162             :       TYPE(distribution_3d_type)                         :: dist_3d
    4163             :       TYPE(gto_basis_set_p_type), ALLOCATABLE, &
    4164          70 :          DIMENSION(:), TARGET                            :: basis_set_AO, basis_set_RI
    4165             :       TYPE(kpoint_type), POINTER                         :: kpoints
    4166          70 :       TYPE(mp_cart_type)                                 :: mp_comm_t3c
    4167             :       TYPE(mp_para_env_type), POINTER                    :: para_env
    4168             :       TYPE(neighbor_list_3c_iterator_type)               :: nl_3c_iter
    4169             :       TYPE(neighbor_list_3c_type)                        :: nl_3c
    4170             :       TYPE(neighbor_list_iterator_p_type), &
    4171          70 :          DIMENSION(:), POINTER                           :: nl_iterator
    4172             :       TYPE(neighbor_list_set_p_type), DIMENSION(:), &
    4173          70 :          POINTER                                         :: nl_2c
    4174          70 :       TYPE(particle_type), DIMENSION(:), POINTER         :: particle_set
    4175          70 :       TYPE(qs_kind_type), DIMENSION(:), POINTER          :: qs_kind_set
    4176             :       TYPE(section_vals_type), POINTER                   :: hfx_section
    4177             : 
    4178          70 :       NULLIFY (qs_kind_set, dist_2d, nl_2c, nl_iterator, dft_control, &
    4179          70 :                particle_set, kpoints, para_env, cell_to_index, hfx_section)
    4180             : 
    4181          70 :       CALL timeset(routineN, handle)
    4182             : 
    4183             :       CALL get_qs_env(qs_env, nkind=nkind, qs_kind_set=qs_kind_set, distribution_2d=dist_2d, &
    4184             :                       dft_control=dft_control, particle_set=particle_set, kpoints=kpoints, &
    4185          70 :                       para_env=para_env, natom=natom)
    4186          70 :       nimg = dft_control%nimages
    4187          70 :       CALL get_kpoint_info(kpoints, cell_to_index=cell_to_index)
    4188         280 :       kp_index_lbounds = LBOUND(cell_to_index)
    4189         280 :       kp_index_ubounds = UBOUND(cell_to_index)
    4190             : 
    4191          70 :       hfx_section => section_vals_get_subs_vals(qs_env%input, "DFT%XC%HF%RI")
    4192          70 :       CALL section_vals_val_get(hfx_section, "KP_NGROUPS", i_val=ngroups)
    4193             : 
    4194         496 :       ALLOCATE (basis_set_RI(nkind), basis_set_AO(nkind))
    4195          70 :       CALL basis_set_list_setup(basis_set_RI, ri_data%ri_basis_type, qs_kind_set)
    4196          70 :       CALL basis_set_list_setup(basis_set_AO, ri_data%orb_basis_type, qs_kind_set)
    4197             : 
    4198             :       !In case of shortrange HFX potential, it is important to be consistent with the rest of the KP
    4199             :       !code, and use EPS_SCHWARZ to determine the range (rather than eps_filter_2c in normal RI-HFX)
    4200          70 :       IF (ri_data%hfx_pot%potential_type == do_potential_short) THEN
    4201           0 :          CALL erfc_cutoff(ri_data%eps_schwarz, ri_data%hfx_pot%omega, ri_data%hfx_pot%cutoff_radius)
    4202             :       END IF
    4203             : 
    4204             :       !Determine the range for contributing periodic images, and for the RI basis extension
    4205          70 :       ri_data%kp_RI_range = 0.0_dp
    4206          70 :       ri_data%kp_image_range = 0.0_dp
    4207         178 :       DO ikind = 1, nkind
    4208             : 
    4209         108 :          CALL init_interaction_radii_orb_basis(basis_set_AO(ikind)%gto_basis_set, ri_data%eps_pgf_orb)
    4210         108 :          CALL get_gto_basis_set(basis_set_AO(ikind)%gto_basis_set, kind_radius=RI_range)
    4211         108 :          ri_data%kp_RI_range = MAX(RI_range, ri_data%kp_RI_range)
    4212             : 
    4213         108 :          CALL init_interaction_radii_orb_basis(basis_set_AO(ikind)%gto_basis_set, ri_data%eps_pgf_orb)
    4214         108 :          CALL init_interaction_radii_orb_basis(basis_set_RI(ikind)%gto_basis_set, ri_data%eps_pgf_orb)
    4215         108 :          CALL get_gto_basis_set(basis_set_RI(ikind)%gto_basis_set, kind_radius=image_range)
    4216             : 
    4217         108 :          image_range = 2.0_dp*image_range + cutoff_screen_factor*ri_data%hfx_pot%cutoff_radius
    4218         286 :          ri_data%kp_image_range = MAX(image_range, ri_data%kp_image_range)
    4219             :       END DO
    4220             : 
    4221          70 :       CALL section_vals_val_get(hfx_section, "KP_RI_BUMP_FACTOR", r_val=bump_fact)
    4222          70 :       ri_data%kp_bump_rad = bump_fact*ri_data%kp_RI_range
    4223             : 
    4224             :       !For the extent of the KP RI-HFX images, we are limited by the RI-HFX potential in
    4225             :       !(mu^0 sigma^a|P^0) (P^0|Q^b) (Q^b|nu^b lambda^a+c), if there is no contact between
    4226             :       !any P^0 and Q^b, then image b does not contribute
    4227             :       CALL build_2c_neighbor_lists(nl_2c, basis_set_RI, basis_set_RI, ri_data%hfx_pot, &
    4228          70 :                                    "HFX_2c_nl_RI", qs_env, sym_ij=.FALSE., dist_2d=dist_2d)
    4229             : 
    4230         210 :       ALLOCATE (present_img(nimg))
    4231        3120 :       present_img = 0
    4232          70 :       ri_data%nimg = 0
    4233          70 :       CALL neighbor_list_iterator_create(nl_iterator, nl_2c)
    4234        1568 :       DO WHILE (neighbor_list_iterate(nl_iterator) == 0)
    4235        1498 :          CALL get_iterator_info(nl_iterator, r=rij, cell=cell_j)
    4236             : 
    4237        5992 :          dij = NORM2(rij)
    4238             : 
    4239        1498 :          j_img = cell_to_index(cell_j(1), cell_j(2), cell_j(3))
    4240        1498 :          IF (j_img > nimg .OR. j_img < 1) CYCLE
    4241             : 
    4242        1466 :          IF (dij > ri_data%kp_image_range) CYCLE
    4243             : 
    4244        1466 :          ri_data%nimg = MAX(j_img, ri_data%nimg)
    4245        1498 :          present_img(j_img) = 1
    4246             : 
    4247             :       END DO
    4248          70 :       CALL neighbor_list_iterator_release(nl_iterator)
    4249          70 :       CALL release_neighbor_list_sets(nl_2c)
    4250          70 :       CALL para_env%max(ri_data%nimg)
    4251          70 :       IF (ri_data%nimg > nimg) &
    4252           0 :          CPABORT("Make sure the smallest exponent of the RI-HFX basis is larger than that of the ORB basis.")
    4253             : 
    4254             :       !Keep track of which images will not contribute, so that they can be ignored before calculation
    4255          70 :       CALL para_env%sum(present_img)
    4256         210 :       ALLOCATE (ri_data%present_images(ri_data%nimg))
    4257        1714 :       ri_data%present_images = 0
    4258        1714 :       DO i_img = 1, ri_data%nimg
    4259        1714 :          IF (present_img(i_img) > 0) ri_data%present_images(i_img) = 1
    4260             :       END DO
    4261             : 
    4262             :       CALL create_3c_tensor(t_dummy, dist_AO_1, dist_AO_2, dist_RI, &
    4263             :                             ri_data%pgrid, ri_data%bsizes_AO, ri_data%bsizes_AO, ri_data%bsizes_RI, &
    4264          70 :                             map1=[1, 2], map2=[3], name="(AO AO | RI)")
    4265             : 
    4266          70 :       CALL dbt_mp_environ_pgrid(ri_data%pgrid, pdims, pcoord)
    4267          70 :       CALL mp_comm_t3c%create(ri_data%pgrid%mp_comm_2d, 3, pdims)
    4268             :       CALL distribution_3d_create(dist_3d, dist_AO_1, dist_AO_2, dist_RI, &
    4269          70 :                                   nkind, particle_set, mp_comm_t3c, own_comm=.TRUE.)
    4270          70 :       DEALLOCATE (dist_RI, dist_AO_1, dist_AO_2)
    4271          70 :       CALL dbt_destroy(t_dummy)
    4272             : 
    4273             :       !For the extension of the RI basis P in (mu^0 sigma^a |P^i), we consider an atom if the distance
    4274             :       !between mu^0 and P^i is smaller or equal to the kind radius of mu^0
    4275             :       CALL build_3c_neighbor_lists(nl_3c, basis_set_AO, basis_set_AO, basis_set_RI, dist_3d, &
    4276             :                                    ri_data%ri_metric, "HFX_3c_nl", qs_env, op_pos=2, sym_ij=.FALSE., &
    4277          70 :                                    own_dist=.TRUE.)
    4278             : 
    4279         140 :       ALLOCATE (RI_cells(nimg))
    4280        3120 :       RI_cells = 0
    4281             : 
    4282         210 :       ALLOCATE (nRI_per_atom(natom))
    4283         210 :       nRI_per_atom = 0
    4284             : 
    4285          70 :       CALL neighbor_list_3c_iterator_create(nl_3c_iter, nl_3c)
    4286       58584 :       DO WHILE (neighbor_list_3c_iterate(nl_3c_iter) == 0)
    4287             :          CALL get_3c_iterator_info(nl_3c_iter, cell_k=cell_k, rik=rik, cell_j=cell_j, &
    4288       58514 :                                    iatom=iatom, jatom=jatom, katom=katom)
    4289      234056 :          dik = NORM2(rik)
    4290             : 
    4291      409598 :          IF (ANY([cell_j(1), cell_j(2), cell_j(3)] < kp_index_lbounds) .OR. &
    4292             :              ANY([cell_j(1), cell_j(2), cell_j(3)] > kp_index_ubounds)) CYCLE
    4293             : 
    4294       58514 :          jcell = cell_to_index(cell_j(1), cell_j(2), cell_j(3))
    4295       58514 :          IF (jcell > nimg .OR. jcell < 1) CYCLE
    4296             : 
    4297      386199 :          IF (ANY([cell_k(1), cell_k(2), cell_k(3)] < kp_index_lbounds) .OR. &
    4298             :              ANY([cell_k(1), cell_k(2), cell_k(3)] > kp_index_ubounds)) CYCLE
    4299             : 
    4300       51245 :          kcell = cell_to_index(cell_k(1), cell_k(2), cell_k(3))
    4301       51245 :          IF (kcell > nimg .OR. kcell < 1) CYCLE
    4302             : 
    4303       43587 :          IF (dik > ri_data%kp_RI_range) CYCLE
    4304        5751 :          RI_cells(kcell) = 1
    4305             : 
    4306        5821 :          IF (jcell == 1 .AND. iatom == jatom) nRI_per_atom(iatom) = nRI_per_atom(iatom) + ri_data%bsizes_RI(katom)
    4307             :       END DO
    4308          70 :       CALL neighbor_list_3c_iterator_destroy(nl_3c_iter)
    4309          70 :       CALL neighbor_list_3c_destroy(nl_3c)
    4310          70 :       CALL para_env%sum(RI_cells)
    4311          70 :       CALL para_env%sum(nRI_per_atom)
    4312             : 
    4313         140 :       ALLOCATE (ri_data%img_to_RI_cell(nimg))
    4314          70 :       ri_data%ncell_RI = 0
    4315        3120 :       ri_data%img_to_RI_cell = 0
    4316        3120 :       DO i_img = 1, nimg
    4317        3120 :          IF (RI_cells(i_img) > 0) THEN
    4318         436 :             ri_data%ncell_RI = ri_data%ncell_RI + 1
    4319         436 :             ri_data%img_to_RI_cell(i_img) = ri_data%ncell_RI
    4320             :          END IF
    4321             :       END DO
    4322             : 
    4323         210 :       ALLOCATE (ri_data%RI_cell_to_img(ri_data%ncell_RI))
    4324        3120 :       DO i_img = 1, nimg
    4325        3120 :          IF (ri_data%img_to_RI_cell(i_img) > 0) ri_data%RI_cell_to_img(ri_data%img_to_RI_cell(i_img)) = i_img
    4326             :       END DO
    4327             : 
    4328             :       !Print some info (only where ri_data%unit_nr is a valid output unit)
    4329          70 :       IF (ri_data%unit_nr > 0) THEN
    4330             :          WRITE (ri_data%unit_nr, FMT="(/T3,A,I29)") &
    4331          35 :             "KP-HFX_RI_INFO| Number of RI-KP parallel groups:", ngroups
    4332             :          WRITE (ri_data%unit_nr, FMT="(T3,A,F31.3,A)") &
    4333          35 :             "KP-HFX_RI_INFO| RI basis extension radius:", ri_data%kp_RI_range*angstrom, " Ang"
    4334             :          WRITE (ri_data%unit_nr, FMT="(T3,A,F12.3,A, F6.3, A)") &
    4335          35 :             "KP-HFX_RI_INFO| RI basis bump factor and bump radius:", bump_fact, " /", &
    4336          70 :             ri_data%kp_bump_rad*angstrom, " Ang"
    4337             :          WRITE (ri_data%unit_nr, FMT="(T3,A,I16,A)") &
    4338          35 :             "KP-HFX_RI_INFO| The extended RI bases cover up to ", ri_data%ncell_RI, " unit cells"
    4339             :          WRITE (ri_data%unit_nr, FMT="(T3,A,I18)") &
    4340         105 :             "KP-HFX_RI_INFO| Average number of sgf in extended RI bases:", SUM(nRI_per_atom)/natom
    4341             :          WRITE (ri_data%unit_nr, FMT="(T3,A,F13.3,A)") &
    4342          35 :             "KP-HFX_RI_INFO| Consider all image cells within a radius of ", ri_data%kp_image_range*angstrom, " Ang"
    4343             :          WRITE (ri_data%unit_nr, FMT="(T3,A,I27/)") &
    4344          35 :             "KP-HFX_RI_INFO| Number of image cells considered: ", ri_data%nimg
    4345          35 :          CALL m_flush(ri_data%unit_nr)
    4346             :       END IF
    4347             : 
    4348          70 :       CALL timestop(handle)
    4349             : 
    4350         840 :    END SUBROUTINE get_kp_and_ri_images
    4351             : 
    4352             : ! **************************************************************************************************
    4353             : !> \brief A routine that creates tensor structures for rho_ao and 3c_ints in a stacked format for
    4354             : !>        the efficient contractions of rho_sigma^0,lambda^c * (mu^0 sigma^a | P) => TAS tensors
    4355             : !> \param res_stack result tensors, (1) and (2) are created from ints_stack(1) and ints_stack(2) respectively
    4356             : !> \param rho_stack 2c tensors stacked in both dimensions: (1) with the rho_template dist, (2) from create_2c_tensor
    4357             : !> \param ints_stack 3c tensors stacked along dimension 3: (1) with the ints_template dist, (2) on a new pgrid
    4358             : !> \param rho_template 2c tensor whose process distributions are replicated stack_size times
    4359             : !> \param ints_template 3c tensor providing block counts, process distributions and the block sizes of dimension 2
    4360             : !> \param stack_size number of times the AO block sizes are repeated along a stacked dimension
    4361             : !> \param ri_data provides bsizes_AO_split, pgrid_1 and pgrid_2d
    4362             : !> \param qs_env only used to retrieve para_env
    4363             : !> \note The result tensor has the exact same shape and distribution as the integral tensor
    4364             : ! **************************************************************************************************
    4365         256 :    SUBROUTINE get_stack_tensors(res_stack, rho_stack, ints_stack, rho_template, ints_template, &
    4366             :                                 stack_size, ri_data, qs_env)
    4367             :       TYPE(dbt_type), DIMENSION(:), INTENT(INOUT)        :: res_stack, rho_stack, ints_stack
    4368             :       TYPE(dbt_type), INTENT(INOUT)                      :: rho_template, ints_template
    4369             :       INTEGER, INTENT(IN)                                :: stack_size
    4370             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
    4371             :       TYPE(qs_environment_type), POINTER                 :: qs_env
    4372             : 
    4373             :       INTEGER                                            :: is, nblks, nblks_3c(3), pdims_3d(3)
    4374         256 :       INTEGER, ALLOCATABLE, DIMENSION(:)                 :: bsizes_RI_ext, bsizes_stack, dist1, &
    4375         256 :                                                             dist2, dist3, dist_stack1, &
    4376         256 :                                                             dist_stack2, dist_stack3
    4377        2304 :       TYPE(dbt_distribution_type)                        :: t_dist
    4378         768 :       TYPE(dbt_pgrid_type)                               :: pgrid
    4379             :       TYPE(mp_para_env_type), POINTER                    :: para_env
    4380             : 
    4381         256 :       NULLIFY (para_env)
    4382             : 
    4383         256 :       CALL get_qs_env(qs_env, para_env=para_env)
    4384             : 
    4385         256 :       nblks = SIZE(ri_data%bsizes_AO_split)
    4386         768 :       ALLOCATE (bsizes_stack(stack_size*nblks))
    4387        3314 :       DO is = 1, stack_size
    4388       12476 :          bsizes_stack((is - 1)*nblks + 1:is*nblks) = ri_data%bsizes_AO_split(:)
    4389             :       END DO
    4390             : 
    4391        2304 :       ALLOCATE (dist1(nblks), dist2(nblks), dist_stack1(stack_size*nblks), dist_stack2(stack_size*nblks))
    4392         256 :       CALL dbt_get_info(rho_template, proc_dist_1=dist1, proc_dist_2=dist2)
    4393        3314 :       DO is = 1, stack_size
    4394       12220 :          dist_stack1((is - 1)*nblks + 1:is*nblks) = dist1(:)
    4395       12476 :          dist_stack2((is - 1)*nblks + 1:is*nblks) = dist2(:)
    4396             :       END DO
    4397             : 
    4398             :       !First 2c tensor matches the distribution of template
    4399             :       !It is stacked in both directions
    4400         256 :       CALL dbt_distribution_new(t_dist, ri_data%pgrid_2d, dist_stack1, dist_stack2)
    4401         256 :       CALL dbt_create(rho_stack(1), "RHO_stack", t_dist, [1], [2], bsizes_stack, bsizes_stack)
    4402         256 :       CALL dbt_distribution_destroy(t_dist)
    4403         256 :       DEALLOCATE (dist1, dist2, dist_stack1, dist_stack2)
    4404             : 
    4405             :       !Second 2c tensor has optimal distribution on the 2d pgrid
    4406         256 :       CALL create_2c_tensor(rho_stack(2), dist1, dist2, ri_data%pgrid_2d, bsizes_stack, bsizes_stack, name="RHO_stack")
    4407         256 :       DEALLOCATE (dist1, dist2)
    4408             : 
    4409         256 :       CALL dbt_get_info(ints_template, nblks_total=nblks_3c)
    4410        1792 :       ALLOCATE (dist1(nblks_3c(1)), dist2(nblks_3c(2)), dist3(nblks_3c(3)))
    4411        1280 :       ALLOCATE (dist_stack3(stack_size*nblks_3c(3)), bsizes_RI_ext(nblks_3c(2)))
    4412             :       CALL dbt_get_info(ints_template, proc_dist_1=dist1, proc_dist_2=dist2, &
    4413         256 :                         proc_dist_3=dist3, blk_size_2=bsizes_RI_ext)
    4414        3314 :       DO is = 1, stack_size
    4415       12476 :          dist_stack3((is - 1)*nblks_3c(3) + 1:is*nblks_3c(3)) = dist3(:)
    4416             :       END DO
    4417             : 
    4418             :       !First 3c tensor matches the distribution of template (only the third dimension is stacked)
    4419         256 :       CALL dbt_distribution_new(t_dist, ri_data%pgrid_1, dist1, dist2, dist_stack3)
    4420             :       CALL dbt_create(ints_stack(1), "ints_stack", t_dist, [1, 2], [3], ri_data%bsizes_AO_split, &
    4421         256 :                       bsizes_RI_ext, bsizes_stack)
    4422         256 :       CALL dbt_distribution_destroy(t_dist)
    4423         256 :       DEALLOCATE (dist1, dist2, dist3, dist_stack3)
    4424             : 
    4425             :       !Second 3c tensor has optimal pgrid
    4426         256 :       pdims_3d = 0
    4427        1024 :       CALL dbt_pgrid_create(para_env, pdims_3d, pgrid, tensor_dims=[nblks_3c(1), nblks_3c(2), stack_size*nblks_3c(3)])
    4428             :       CALL create_3c_tensor(ints_stack(2), dist1, dist2, dist3, pgrid, ri_data%bsizes_AO_split, &
    4429         256 :                             bsizes_RI_ext, bsizes_stack, [1, 2], [3], name="ints_stack")
    4430         256 :       DEALLOCATE (dist1, dist2, dist3)
    4431         256 :       CALL dbt_pgrid_destroy(pgrid)
    4432             : 
    4433             :       !The result tensor has the same shape and dist as the integral tensor
    4434         256 :       CALL dbt_create(ints_stack(1), res_stack(1))
    4435         256 :       CALL dbt_create(ints_stack(2), res_stack(2))
    4436             : 
    4437         512 :    END SUBROUTINE get_stack_tensors
    4438             : 
    4439             : ! **************************************************************************************************
    4440             : !> \brief Fill the stack of 3c tensors according to the order in the images input
    4441             : !> \param t_3c_stack ...
    4442             : !> \param t_3c_in ...
    4443             : !> \param images ...
    4444             : !> \param stack_dim ...
    4445             : !> \param ri_data ...
    4446             : !> \param filter_at ...
    4447             : !> \param filter_dim ...
    4448             : !> \param idx_to_at ...
    4449             : !> \param img_bounds ...
    4450             : ! **************************************************************************************************
    4451       21893 :    SUBROUTINE fill_3c_stack(t_3c_stack, t_3c_in, images, stack_dim, ri_data, filter_at, filter_dim, &
    4452       21893 :                             idx_to_at, img_bounds)
    4453             :       TYPE(dbt_type), INTENT(INOUT)                      :: t_3c_stack
    4454             :       TYPE(dbt_type), DIMENSION(:), INTENT(INOUT)        :: t_3c_in
    4455             :       INTEGER, DIMENSION(:), INTENT(INOUT)               :: images
    4456             :       INTEGER, INTENT(IN)                                :: stack_dim
    4457             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
    4458             :       INTEGER, INTENT(IN), OPTIONAL                      :: filter_at, filter_dim
    4459             :       INTEGER, DIMENSION(:), INTENT(INOUT), OPTIONAL     :: idx_to_at
    4460             :       INTEGER, INTENT(IN), OPTIONAL                      :: img_bounds(2)
    4461             : 
    4462             :       INTEGER                                            :: dest(3), i_img, idx, ind(3), lb, nblks, &
    4463             :                                                             nimg, offset, ub
    4464             :       LOGICAL                                            :: do_filter, found
    4465       21893 :       REAL(dp), ALLOCATABLE, DIMENSION(:, :, :)          :: blk
    4466             :       TYPE(dbt_iterator_type)                            :: iter
    4467             : 
    4468             :       !We loop over the a images from the ac_pairs, then copy the 3c ints to the correct spot in
    4469             :       !in the stack tensor (corresponding to pair index). Distributions match by construction
    4470       21893 :       nimg = ri_data%nimg
    4471       21893 :       nblks = SIZE(ri_data%bsizes_AO_split)
    4472             : 
    4473       21893 :       do_filter = .FALSE.
    4474       21319 :       IF (PRESENT(filter_at) .AND. PRESENT(filter_dim) .AND. PRESENT(idx_to_at)) do_filter = .TRUE.
    4475             : 
    4476       21893 :       lb = 1
    4477       21893 :       ub = nimg
    4478       21893 :       offset = 0
    4479       21893 :       IF (PRESENT(img_bounds)) THEN
    4480       21893 :          lb = img_bounds(1)
    4481       21893 :          ub = img_bounds(2) - 1
    4482       21893 :          offset = lb - 1
    4483             :       END IF
    4484             : 
    4485      433835 :       DO idx = lb, ub
    4486      411942 :          i_img = images(idx)
    4487      411942 :          IF (i_img == 0 .OR. i_img > nimg) CYCLE
    4488             : 
    4489             : !$OMP PARALLEL DEFAULT(NONE) &
    4490             : !$OMP SHARED(idx,i_img,t_3c_in,t_3c_stack,nblks,stack_dim,filter_at,filter_dim,idx_to_at,do_filter,offset) &
    4491      433835 : !$OMP PRIVATE(iter,ind,blk,found,dest)
    4492             :          CALL dbt_iterator_start(iter, t_3c_in(i_img))
    4493             :          DO WHILE (dbt_iterator_blocks_left(iter))
    4494             :             CALL dbt_iterator_next_block(iter, ind)
    4495             :             CALL dbt_get_block(t_3c_in(i_img), ind, blk, found)
    4496             :             IF (.NOT. found) CYCLE
    4497             : 
    4498             :             IF (do_filter) THEN
    4499             :                IF (.NOT. idx_to_at(ind(filter_dim)) == filter_at) CYCLE
    4500             :             END IF
    4501             : 
    4502             :             IF (stack_dim == 1) THEN
    4503             :                dest = [(idx - offset - 1)*nblks + ind(1), ind(2), ind(3)]
    4504             :             ELSE IF (stack_dim == 2) THEN
    4505             :                dest = [ind(1), (idx - offset - 1)*nblks + ind(2), ind(3)]
    4506             :             ELSE
    4507             :                dest = [ind(1), ind(2), (idx - offset - 1)*nblks + ind(3)]
    4508             :             END IF
    4509             : 
    4510             :             CALL dbt_put_block(t_3c_stack, dest, SHAPE(blk), blk)
    4511             :             DEALLOCATE (blk)
    4512             :          END DO
    4513             :          CALL dbt_iterator_stop(iter)
    4514             : !$OMP END PARALLEL
    4515             :       END DO !i_img
    4516       21893 :       CALL dbt_finalize(t_3c_stack)
    4517             : 
    4518       43786 :    END SUBROUTINE fill_3c_stack
    4519             : 
    4520             : ! **************************************************************************************************
    4521             : !> \brief Fill the stack of 2c tensors based on the content of images input
    4522             : !> \param t_2c_stack the stacked 2c tensor receiving the copied blocks; it is finalized on exit
    4523             : !> \param t_2c_in array of 2c tensors, addressed by image index (the values stored in images)
    4524             : !> \param images list of image indices; entries that are 0 or larger than ri_data%nimg are skipped
    4525             : !> \param stack_dim block dimension along which to stack: 1, any other value means the 2nd
    4526             : !> \param ri_data provides nimg and bsizes_AO_split, whose size is the block stride between slots
    4527             : !> \param img_bounds if present, only images(img_bounds(1):img_bounds(2)-1) is stacked, from slot 1 on
    4528             : !> \param shift slot (default 1) in the other dimension: blocks are offset there by (shift-1)*nblks
    4529             : ! **************************************************************************************************
    4530       16412 :    SUBROUTINE fill_2c_stack(t_2c_stack, t_2c_in, images, stack_dim, ri_data, img_bounds, shift)
    4531             :       TYPE(dbt_type), INTENT(INOUT)                      :: t_2c_stack
    4532             :       TYPE(dbt_type), DIMENSION(:), INTENT(INOUT)        :: t_2c_in
    4533             :       INTEGER, DIMENSION(:), INTENT(INOUT)               :: images
    4534             :       INTEGER, INTENT(IN)                                :: stack_dim
    4535             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
    4536             :       INTEGER, INTENT(IN), OPTIONAL                      :: img_bounds(2), shift
    4537             : 
    4538             :       INTEGER                                            :: dest(2), i_img, idx, ind(2), lb, &
    4539             :                                                             my_shift, nblks, nimg, offset, ub
    4540             :       LOGICAL                                            :: found
    4541       16412 :       REAL(dp), ALLOCATABLE, DIMENSION(:, :)             :: blk
    4542             :       TYPE(dbt_iterator_type)                            :: iter
    4543             : 
    4544             :       !We loop over the entries of the images list, then copy the 2c tensor blocks to the correct spot
    4545             :       !in the stack tensor (corresponding to the position idx in the list)
    4546       16412 :       nimg = ri_data%nimg
    4547       16412 :       nblks = SIZE(ri_data%bsizes_AO_split)
    4548             : 
    4549       16412 :       lb = 1
    4550       16412 :       ub = nimg
    4551       16412 :       offset = 0
    4552       16412 :       IF (PRESENT(img_bounds)) THEN
    4553       16412 :          lb = img_bounds(1)
    4554       16412 :          ub = img_bounds(2) - 1
    4555       16412 :          offset = lb - 1
    4556             :       END IF
    4557             : 
    4558       16412 :       my_shift = 1
    4559       16412 :       IF (PRESENT(shift)) my_shift = shift
    4560             : 
    4561      204304 :       DO idx = lb, ub
    4562      187892 :          i_img = images(idx)
    4563      187892 :          IF (i_img == 0 .OR. i_img > nimg) CYCLE
    4564             : 
    4565             : !$OMP PARALLEL DEFAULT(NONE) SHARED(idx,i_img,t_2c_in,t_2c_stack,nblks,stack_dim,offset,my_shift) &
    4566      204304 : !$OMP PRIVATE(iter,ind,blk,found,dest)
    4567             :          CALL dbt_iterator_start(iter, t_2c_in(i_img))
    4568             :          DO WHILE (dbt_iterator_blocks_left(iter))
    4569             :             CALL dbt_iterator_next_block(iter, ind)
    4570             :             CALL dbt_get_block(t_2c_in(i_img), ind, blk, found)
    4571             :             IF (.NOT. found) CYCLE
    4572             : 
    4573             :             IF (stack_dim == 1) THEN
    4574             :                dest = [(idx - offset - 1)*nblks + ind(1), (my_shift - 1)*nblks + ind(2)]
    4575             :             ELSE
    4576             :                dest = [(my_shift - 1)*nblks + ind(1), (idx - offset - 1)*nblks + ind(2)]
    4577             :             END IF
    4578             : 
    4579             :             CALL dbt_put_block(t_2c_stack, dest, SHAPE(blk), blk)
    4580             :             DEALLOCATE (blk)
    4581             :          END DO
    4582             :          CALL dbt_iterator_stop(iter)
    4583             : !$OMP END PARALLEL
    4584             :       END DO !idx
    4585       16412 :       CALL dbt_finalize(t_2c_stack)
    4586             : 
    4587       32824 :    END SUBROUTINE fill_2c_stack
    4588             : 
    4589             : ! **************************************************************************************************
    4590             : !> \brief Unstacks a stacked 3c tensor containing t_3c_apc: copies the blocks of slot idx into t_3c_apc
    4591             : !> \param t_3c_apc receives the blocks, 3rd block index shifted back by (idx-1)*nblks; not finalized here
    4592             : !> \param t_stacked the tensor stacked along its 3rd dimension, from which the blocks are read
    4593             : !> \param idx the slot of the stack to extract; blocks belonging to other slots are ignored
    4594             : ! **************************************************************************************************
    4595       18838 :    SUBROUTINE unstack_t_3c_apc(t_3c_apc, t_stacked, idx)
    4596             :       TYPE(dbt_type), INTENT(INOUT)                      :: t_3c_apc, t_stacked
    4597             :       INTEGER, INTENT(IN)                                :: idx
    4598             : 
    4599             :       INTEGER                                            :: current_idx
    4600             :       INTEGER, DIMENSION(3)                              :: ind, nblks_3c
    4601             :       LOGICAL                                            :: found
    4602       18838 :       REAL(dp), ALLOCATABLE, DIMENSION(:, :, :)          :: blk
    4603             :       TYPE(dbt_iterator_type)                            :: iter
    4604             : 
    4605             :       !Note: t_3c_apc and t_stacked must have the same distribution
    4606       18838 :       CALL dbt_get_info(t_3c_apc, nblks_total=nblks_3c)
    4607             : 
    4608       18838 : !$OMP PARALLEL DEFAULT(NONE) SHARED(t_3c_apc,t_stacked,idx,nblks_3c) PRIVATE(iter,ind,blk,found,current_idx)
    4609             :       CALL dbt_iterator_start(iter, t_stacked)
    4610             :       DO WHILE (dbt_iterator_blocks_left(iter))
    4611             :          CALL dbt_iterator_next_block(iter, ind)
    4612             : 
    4613             :          !tensor is stacked along the 3rd dimension: find the slot this block belongs to (integer division)
    4614             :          current_idx = (ind(3) - 1)/nblks_3c(3) + 1
    4615             :          IF (.NOT. idx == current_idx) CYCLE
    4616             : 
    4617             :          CALL dbt_get_block(t_stacked, ind, blk, found)
    4618             :          IF (.NOT. found) CYCLE
    4619             : 
    4620             :          CALL dbt_put_block(t_3c_apc, [ind(1), ind(2), ind(3) - (idx - 1)*nblks_3c(3)], SHAPE(blk), blk)
    4621             :          DEALLOCATE (blk)
    4622             :       END DO
    4623             :       CALL dbt_iterator_stop(iter)
    4624             : !$OMP END PARALLEL
    4625             : 
    4626       18838 :    END SUBROUTINE unstack_t_3c_apc
    4627             : 
    4628             : ! **************************************************************************************************
    4629             : !> \brief copies the 3c integrals corresponding to a single atom mu from the general (P^0| mu^0 sigma^a)
    4630             : !> \param t_3c_at receives the selected blocks at unchanged block indices; it is finalized on exit
    4631             : !> \param t_3c_ints the tensor from which the blocks are read
    4632             : !> \param iatom only blocks with idx_to_at(ind(dim_at)) == iatom are copied
    4633             : !> \param dim_at block dimension whose index is mapped through idx_to_at for the selection
    4634             : !> \param idx_to_at map from block index to the value compared with iatom (presumably the atom index)
    4635             : ! **************************************************************************************************
    4636           0 :    SUBROUTINE get_atom_3c_ints(t_3c_at, t_3c_ints, iatom, dim_at, idx_to_at)
    4637             :       TYPE(dbt_type), INTENT(INOUT)                      :: t_3c_at, t_3c_ints
    4638             :       INTEGER, INTENT(IN)                                :: iatom, dim_at
    4639             :       INTEGER, DIMENSION(:), INTENT(IN)                  :: idx_to_at
    4640             : 
    4641             :       INTEGER, DIMENSION(3)                              :: ind
    4642             :       LOGICAL                                            :: found
    4643           0 :       REAL(dp), ALLOCATABLE, DIMENSION(:, :, :)          :: blk
    4644             :       TYPE(dbt_iterator_type)                            :: iter
    4645             : 
    4646           0 : !$OMP PARALLEL DEFAULT(NONE) SHARED(t_3c_ints,t_3c_at,iatom,idx_to_at,dim_at) PRIVATE(iter,ind,blk,found)
    4647             :       CALL dbt_iterator_start(iter, t_3c_ints)
    4648             :       DO WHILE (dbt_iterator_blocks_left(iter))
    4649             :          CALL dbt_iterator_next_block(iter, ind)
    4650             :          IF (.NOT. idx_to_at(ind(dim_at)) == iatom) CYCLE
    4651             : 
    4652             :          CALL dbt_get_block(t_3c_ints, ind, blk, found)
    4653             :          IF (.NOT. found) CYCLE
    4654             : 
    4655             :          CALL dbt_put_block(t_3c_at, ind, SHAPE(blk), blk)
    4656             :          DEALLOCATE (blk)
    4657             :       END DO
    4658             :       CALL dbt_iterator_stop(iter)
    4659             : !$OMP END PARALLEL
    4660           0 :       CALL dbt_finalize(t_3c_at)
    4661             : 
    4662           0 :    END SUBROUTINE get_atom_3c_ints
    4663             : 
    4664             : ! **************************************************************************************************
    4665             : !> \brief Precalculate the 3c and 2c derivatives tensors
    4666             : !> \param t_3c_der_RI (i_img, i_xyz) tensors created here; hold the 3c derivatives wrt P^i as (RI | AO AO)
    4667             : !> \param t_3c_der_AO (i_img, i_xyz) tensors created here; hold the 3c derivatives wrt mu^0 as (RI | AO AO)
    4668             : !> \param mat_der_pot (i_img, i_xyz) matrices created here; 2c derivatives built with ri_data%hfx_pot
    4669             : !> \param t_2c_der_metric (iatom, i_xyz) tensors created here; RI metric derivatives via get_ext_2c_int
    4670             : !> \param ri_data read for block sizes, nimg, ncell_RI, n_mem, pgrid_2, filter thresholds and potentials
    4671             : !> \param qs_env source of kinds, particles, 2d distribution, dft_control and the parallel environment
    4672             : ! **************************************************************************************************
    4673          42 :    SUBROUTINE precalc_derivatives(t_3c_der_RI, t_3c_der_AO, mat_der_pot, t_2c_der_metric, ri_data, qs_env)
    4674             :       TYPE(dbt_type), DIMENSION(:, :), INTENT(INOUT)     :: t_3c_der_RI, t_3c_der_AO
    4675             :       TYPE(dbcsr_type), DIMENSION(:, :), INTENT(INOUT)   :: mat_der_pot
    4676             :       TYPE(dbt_type), DIMENSION(:, :), INTENT(INOUT)     :: t_2c_der_metric
    4677             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
    4678             :       TYPE(qs_environment_type), POINTER                 :: qs_env
    4679             : 
    4680             :       CHARACTER(LEN=*), PARAMETER :: routineN = 'precalc_derivatives'
    4681             : 
    4682             :       INTEGER                                            :: handle, handle2, i_img, i_mem, i_RI, &
    4683             :                                                             i_xyz, iatom, n_mem, natom, nblks_RI, &
    4684             :                                                             ncell_RI, nimg, nkind, nthreads
    4685             :       INTEGER(int_8)                                     :: nze
    4686          42 :       INTEGER, ALLOCATABLE, DIMENSION(:) :: bsizes_RI_ext, bsizes_RI_ext_split, dist_AO_1, &
    4687          84 :          dist_AO_2, dist_RI, dist_RI_ext, dummy_end, dummy_start, end_blocks, start_blocks
    4688             :       INTEGER, DIMENSION(3)                              :: pcoord, pdims
    4689          84 :       INTEGER, DIMENSION(:), POINTER                     :: col_bsize, row_bsize
    4690             :       REAL(dp)                                           :: occ
    4691             :       TYPE(dbcsr_distribution_type)                      :: dbcsr_dist
    4692             :       TYPE(dbcsr_type)                                   :: dbcsr_template
    4693          42 :       TYPE(dbcsr_type), ALLOCATABLE, DIMENSION(:, :)     :: mat_der_metric
    4694         378 :       TYPE(dbt_distribution_type)                        :: t_dist
    4695         126 :       TYPE(dbt_pgrid_type)                               :: pgrid
    4696         378 :       TYPE(dbt_type)                                     :: t_3c_template
    4697          42 :       TYPE(dbt_type), ALLOCATABLE, DIMENSION(:, :, :)    :: t_3c_der_AO_prv, t_3c_der_RI_prv
    4698             :       TYPE(dft_control_type), POINTER                    :: dft_control
    4699             :       TYPE(distribution_2d_type), POINTER                :: dist_2d
    4700             :       TYPE(distribution_3d_type)                         :: dist_3d
    4701             :       TYPE(gto_basis_set_p_type), ALLOCATABLE, &
    4702          42 :          DIMENSION(:), TARGET                            :: basis_set_AO, basis_set_RI
    4703          42 :       TYPE(mp_cart_type)                                 :: mp_comm_t3c
    4704             :       TYPE(mp_para_env_type), POINTER                    :: para_env
    4705             :       TYPE(neighbor_list_3c_type)                        :: nl_3c
    4706             :       TYPE(neighbor_list_set_p_type), DIMENSION(:), &
    4707          42 :          POINTER                                         :: nl_2c
    4708          42 :       TYPE(particle_type), DIMENSION(:), POINTER         :: particle_set
    4709          42 :       TYPE(qs_kind_type), DIMENSION(:), POINTER          :: qs_kind_set
    4710             : 
    4711          42 :       NULLIFY (qs_kind_set, dist_2d, nl_2c, particle_set, dft_control, para_env, row_bsize, col_bsize)
    4712             : 
    4713          42 :       CALL timeset(routineN, handle)
    4714             : 
    4715             :       CALL get_qs_env(qs_env, nkind=nkind, qs_kind_set=qs_kind_set, distribution_2d=dist_2d, natom=natom, &
    4716          42 :                       particle_set=particle_set, dft_control=dft_control, para_env=para_env)
    4717             : 
    4718          42 :       nimg = ri_data%nimg
    4719          42 :       ncell_RI = ri_data%ncell_RI
    4720             : 
    4721         300 :       ALLOCATE (basis_set_RI(nkind), basis_set_AO(nkind))
    4722          42 :       CALL basis_set_list_setup(basis_set_RI, ri_data%ri_basis_type, qs_kind_set)
    4723          42 :       CALL get_particle_set(particle_set, qs_kind_set, basis=basis_set_RI)
    4724          42 :       CALL basis_set_list_setup(basis_set_AO, ri_data%orb_basis_type, qs_kind_set)
    4725          42 :       CALL get_particle_set(particle_set, qs_kind_set, basis=basis_set_AO)
    4726             : 
    4727             :       !Dealing with the 3c derivatives. NOTE(review): omp_get_num_threads() is called outside of any parallel region of this routine, where it returns 1 - confirm that omp_get_max_threads() was not intended
    4728          42 :       nthreads = 1
    4729          42 : !$    nthreads = omp_get_num_threads()
    4730          42 :       pdims = 0
    4731         168 :       CALL dbt_pgrid_create(para_env, pdims, pgrid, tensor_dims=[MAX(1, natom/(ri_data%n_mem*nthreads)), natom, natom])
    4732             : 
    4733             :       CALL create_3c_tensor(t_3c_template, dist_AO_1, dist_AO_2, dist_RI, pgrid, &
    4734             :                             ri_data%bsizes_AO, ri_data%bsizes_AO, ri_data%bsizes_RI, &
    4735          42 :                             map1=[1, 2], map2=[3], name="tmp")
    4736          42 :       CALL dbt_destroy(t_3c_template)
    4737             : 
    4738             :       !We stack the RI basis images: block sizes and distribution are repeated ncell_RI times. Keep consistent distribution
    4739          42 :       nblks_RI = SIZE(ri_data%bsizes_RI_split)
    4740         126 :       ALLOCATE (dist_RI_ext(natom*ncell_RI))
    4741          84 :       ALLOCATE (bsizes_RI_ext(natom*ncell_RI))
    4742         126 :       ALLOCATE (bsizes_RI_ext_split(nblks_RI*ncell_RI))
    4743         294 :       DO i_RI = 1, ncell_RI
    4744         756 :          bsizes_RI_ext((i_RI - 1)*natom + 1:i_RI*natom) = ri_data%bsizes_RI(:)
    4745         756 :          dist_RI_ext((i_RI - 1)*natom + 1:i_RI*natom) = dist_RI(:)
    4746        1334 :          bsizes_RI_ext_split((i_RI - 1)*nblks_RI + 1:i_RI*nblks_RI) = ri_data%bsizes_RI_split(:)
    4747             :       END DO
    4748             : 
    4749          42 :       CALL dbt_distribution_new(t_dist, pgrid, dist_AO_1, dist_AO_2, dist_RI_ext)
    4750             :       CALL dbt_create(t_3c_template, "KP_3c_der", t_dist, [1, 2], [3], &
    4751          42 :                       ri_data%bsizes_AO, ri_data%bsizes_AO, bsizes_RI_ext)
    4752          42 :       CALL dbt_distribution_destroy(t_dist)
    4753             : 
    4754        6972 :       ALLOCATE (t_3c_der_RI_prv(nimg, 1, 3), t_3c_der_AO_prv(nimg, 1, 3))
    4755         168 :       DO i_xyz = 1, 3
    4756        2982 :          DO i_img = 1, nimg
    4757        2814 :             CALL dbt_create(t_3c_template, t_3c_der_RI_prv(i_img, 1, i_xyz))
    4758        2940 :             CALL dbt_create(t_3c_template, t_3c_der_AO_prv(i_img, 1, i_xyz))
    4759             :          END DO
    4760             :       END DO
    4761          42 :       CALL dbt_destroy(t_3c_template)
    4762             : 
    4763          42 :       CALL dbt_mp_environ_pgrid(pgrid, pdims, pcoord)
    4764          42 :       CALL mp_comm_t3c%create(pgrid%mp_comm_2d, 3, pdims)
    4765             :       CALL distribution_3d_create(dist_3d, dist_AO_1, dist_AO_2, dist_RI, &
    4766          42 :                                   nkind, particle_set, mp_comm_t3c, own_comm=.TRUE.)
    4767          42 :       DEALLOCATE (dist_RI, dist_AO_1, dist_AO_2)
    4768          42 :       CALL dbt_pgrid_destroy(pgrid)
    4769             : 
    4770             :       CALL build_3c_neighbor_lists(nl_3c, basis_set_AO, basis_set_AO, basis_set_RI, dist_3d, ri_data%ri_metric, &
    4771          42 :                                    "HFX_3c_nl", qs_env, op_pos=2, sym_jk=.FALSE., own_dist=.TRUE.)
    4772             : 
    4773          42 :       n_mem = ri_data%n_mem
    4774             :       CALL create_tensor_batches(ri_data%bsizes_RI, n_mem, dummy_start, dummy_end, &
    4775             :                                  start_blocks, end_blocks)
    4776          42 :       DEALLOCATE (dummy_start, dummy_end)
    4777             : 
    4778             :       CALL create_3c_tensor(t_3c_template, dist_RI, dist_AO_1, dist_AO_2, ri_data%pgrid_2, &
    4779             :                             bsizes_RI_ext_split, ri_data%bsizes_AO_split, ri_data%bsizes_AO_split, &
    4780          42 :                             map1=[1], map2=[2, 3], name="der (RI | AO AO)")
    4781         168 :       DO i_xyz = 1, 3
    4782        2982 :          DO i_img = 1, nimg
    4783        2814 :             CALL dbt_create(t_3c_template, t_3c_der_RI(i_img, i_xyz))
    4784        2940 :             CALL dbt_create(t_3c_template, t_3c_der_AO(i_img, i_xyz))
    4785             :          END DO
    4786             :       END DO
    4787             : 
    4788         116 :       DO i_mem = 1, n_mem
    4789             :          CALL build_3c_derivatives(t_3c_der_AO_prv, t_3c_der_RI_prv, ri_data%filter_eps, qs_env, &
    4790             :                                    nl_3c, basis_set_AO, basis_set_AO, basis_set_RI, &
    4791             :                                    ri_data%ri_metric, der_eps=ri_data%eps_schwarz_forces, op_pos=2, &
    4792             :                                    do_kpoints=.TRUE., do_hfx_kpoints=.TRUE., &
    4793             :                                    bounds_k=[start_blocks(i_mem), end_blocks(i_mem)], &
    4794         222 :                                    RI_range=ri_data%kp_RI_range, img_to_RI_cell=ri_data%img_to_RI_cell)
    4795             : 
    4796          74 :          CALL timeset(routineN//"_cpy", handle2)
    4797             :          !We go from (mu^0 sigma^i | P^j) to (P^i| sigma^j mu^0) and finally to (P^i| mu^0 sigma^j)
    4798        1850 :          DO i_img = 1, nimg
    4799        7178 :             DO i_xyz = 1, 3
    4800             :                !derivative wrt mu^0 (skipped when the batch produced no non-zero element)
    4801        5328 :                CALL get_tensor_occupancy(t_3c_der_AO_prv(i_img, 1, i_xyz), nze, occ)
    4802        5328 :                IF (nze > 0) THEN
    4803             :                   CALL dbt_copy(t_3c_der_AO_prv(i_img, 1, i_xyz), t_3c_template, &
    4804        3454 :                                 order=[3, 2, 1], move_data=.TRUE.)
    4805        3454 :                   CALL dbt_filter(t_3c_template, ri_data%filter_eps)
    4806             :                   CALL dbt_copy(t_3c_template, t_3c_der_AO(i_img, i_xyz), &
    4807        3454 :                                 order=[1, 3, 2], move_data=.TRUE., summation=.TRUE.)
    4808             :                END IF
    4809             : 
    4810             :                !derivative wrt P^i (skipped when the batch produced no non-zero element)
    4811        5328 :                CALL get_tensor_occupancy(t_3c_der_RI_prv(i_img, 1, i_xyz), nze, occ)
    4812       12432 :                IF (nze > 0) THEN
    4813             :                   CALL dbt_copy(t_3c_der_RI_prv(i_img, 1, i_xyz), t_3c_template, &
    4814        3452 :                                 order=[3, 2, 1], move_data=.TRUE.)
    4815        3452 :                   CALL dbt_filter(t_3c_template, ri_data%filter_eps)
    4816             :                   CALL dbt_copy(t_3c_template, t_3c_der_RI(i_img, i_xyz), &
    4817        3452 :                                 order=[1, 3, 2], move_data=.TRUE., summation=.TRUE.)
    4818             :                END IF
    4819             :             END DO
    4820             :          END DO
    4821         190 :          CALL timestop(handle2)
    4822             :       END DO
    4823          42 :       CALL dbt_destroy(t_3c_template)
    4824             : 
    4825          42 :       CALL neighbor_list_3c_destroy(nl_3c)
    4826         168 :       DO i_xyz = 1, 3
    4827        2982 :          DO i_img = 1, nimg
    4828        2814 :             CALL dbt_destroy(t_3c_der_RI_prv(i_img, 1, i_xyz))
    4829        2940 :             CALL dbt_destroy(t_3c_der_AO_prv(i_img, 1, i_xyz))
    4830             :          END DO
    4831             :       END DO
    4832        5670 :       DEALLOCATE (t_3c_der_RI_prv, t_3c_der_AO_prv)
    4833             : 
    4834             :       !Reorder 3c derivatives to be consistent with ints
    4835          42 :       CALL reorder_3c_derivs(t_3c_der_RI, ri_data)
    4836          42 :       CALL reorder_3c_derivs(t_3c_der_AO, ri_data)
    4837             : 
    4838          42 :       CALL timeset(routineN//"_2c", handle2)
    4839             :       !The 2-center derivatives
    4840          42 :       CALL cp_dbcsr_dist2d_to_dist(dist_2d, dbcsr_dist)
    4841         126 :       ALLOCATE (row_bsize(SIZE(ri_data%bsizes_RI)))
    4842          84 :       ALLOCATE (col_bsize(SIZE(ri_data%bsizes_RI)))
    4843         126 :       row_bsize(:) = ri_data%bsizes_RI
    4844         126 :       col_bsize(:) = ri_data%bsizes_RI
    4845             : 
    4846             :       CALL dbcsr_create(dbcsr_template, "2c_der", dbcsr_dist, dbcsr_type_no_symmetry, &
    4847          42 :                         row_bsize, col_bsize)
    4848          42 :       CALL dbcsr_distribution_release(dbcsr_dist)
    4849          42 :       DEALLOCATE (col_bsize, row_bsize)
    4850             : 
    4851        3066 :       ALLOCATE (mat_der_metric(nimg, 3))
    4852         168 :       DO i_xyz = 1, 3
    4853        2982 :          DO i_img = 1, nimg
    4854        2814 :             CALL dbcsr_create(mat_der_pot(i_img, i_xyz), template=dbcsr_template)
    4855        2940 :             CALL dbcsr_create(mat_der_metric(i_img, i_xyz), template=dbcsr_template)
    4856             :          END DO
    4857             :       END DO
    4858          42 :       CALL dbcsr_release(dbcsr_template)
    4859             : 
    4860             :       !HFX potential derivatives
    4861             :       CALL build_2c_neighbor_lists(nl_2c, basis_set_RI, basis_set_RI, ri_data%hfx_pot, &
    4862          42 :                                    "HFX_2c_nl_pot", qs_env, sym_ij=.FALSE., dist_2d=dist_2d)
    4863             :       CALL build_2c_derivatives(mat_der_pot, ri_data%filter_eps_2c, qs_env, nl_2c, &
    4864          42 :                                 basis_set_RI, basis_set_RI, ri_data%hfx_pot, do_kpoints=.TRUE.)
    4865          42 :       CALL release_neighbor_list_sets(nl_2c)
    4866             : 
    4867             :       !RI metric derivatives. NOTE(review): the neighbor list below reuses the "HFX_2c_nl_pot" label - presumably a copy-paste, confirm
    4868             :       CALL build_2c_neighbor_lists(nl_2c, basis_set_RI, basis_set_RI, ri_data%ri_metric, &
    4869          42 :                                    "HFX_2c_nl_pot", qs_env, sym_ij=.FALSE., dist_2d=dist_2d)
    4870             :       CALL build_2c_derivatives(mat_der_metric, ri_data%filter_eps_2c, qs_env, nl_2c, &
    4871          42 :                                 basis_set_RI, basis_set_RI, ri_data%ri_metric, do_kpoints=.TRUE.)
    4872          42 :       CALL release_neighbor_list_sets(nl_2c)
    4873             : 
    4874             :       !Get into extended RI basis and tensor format; the temporary metric matrices are released afterwards
    4875         168 :       DO i_xyz = 1, 3
    4876         378 :          DO iatom = 1, natom
    4877         252 :             CALL dbt_create(ri_data%t_2c_inv(1, 1), t_2c_der_metric(iatom, i_xyz))
    4878             :             CALL get_ext_2c_int(t_2c_der_metric(iatom, i_xyz), mat_der_metric(:, i_xyz), &
    4879         378 :                                 iatom, iatom, 1, ri_data, qs_env)
    4880             :          END DO
    4881        2982 :          DO i_img = 1, nimg
    4882        2940 :             CALL dbcsr_release(mat_der_metric(i_img, i_xyz))
    4883             :          END DO
    4884             :       END DO
    4885          42 :       CALL timestop(handle2)
    4886             : 
    4887          42 :       CALL timestop(handle)
    4888             : 
    4889         252 :    END SUBROUTINE precalc_derivatives
    4890             : 
    4891             : ! **************************************************************************************************
    4892             : !> \brief Update the forces due to the derivative of a 2-center product d/dR (Q|R)
    4893             : !> \param force ...
    4894             : !> \param t_2c_contr A precontracted tensor containing sum_abcdPS (ab|P)(P|Q)^-1 (R|S)^-1 (S|cd) P_ac P_bd
    4895             : !> \param t_2c_der the d/dR (Q|R) tensor, in all 3 cartesian directions
    4896             : !> \param atom_of_kind ...
    4897             : !> \param kind_of ...
    4898             : !> \param img in which periodic image the second center of the tensor is
    4899             : !> \param pref ...
    4900             : !> \param ri_data ...
    4901             : !> \param qs_env ...
    4902             : !> \param work_virial ...
    4903             : !> \param cell ...
    4904             : !> \param particle_set ...
    4905             : !> \param diag ...
    4906             : !> \param offdiag ...
    4907             : !> \note IMPORTANT: t_2c_contr and t_2c_der need to have the same distribution. Atomic block sizes are
    4908             : !>                  assumed
    4909             : ! **************************************************************************************************
    4910        2905 :    SUBROUTINE get_2c_der_force(force, t_2c_contr, t_2c_der, atom_of_kind, kind_of, img, pref, &
    4911             :                                ri_data, qs_env, work_virial, cell, particle_set, diag, offdiag)
    4912             : 
    4913             :       TYPE(qs_force_type), DIMENSION(:), POINTER         :: force
    4914             :       TYPE(dbt_type), INTENT(INOUT)                      :: t_2c_contr
    4915             :       TYPE(dbt_type), DIMENSION(:), INTENT(INOUT)        :: t_2c_der
    4916             :       INTEGER, DIMENSION(:), INTENT(IN)                  :: atom_of_kind, kind_of
    4917             :       INTEGER, INTENT(IN)                                :: img
    4918             :       REAL(dp), INTENT(IN)                               :: pref
    4919             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
    4920             :       TYPE(qs_environment_type), POINTER                 :: qs_env
    4921             :       REAL(dp), DIMENSION(3, 3), INTENT(INOUT), OPTIONAL :: work_virial
    4922             :       TYPE(cell_type), OPTIONAL, POINTER                 :: cell
    4923             :       TYPE(particle_type), DIMENSION(:), OPTIONAL, &
    4924             :          POINTER                                         :: particle_set
    4925             :       LOGICAL, INTENT(IN), OPTIONAL                      :: diag, offdiag
    4926             : 
    4927             :       CHARACTER(LEN=*), PARAMETER                        :: routineN = 'get_2c_der_force'
    4928             : 
    4929             :       INTEGER                                            :: handle, i_img, i_RI, i_xyz, iat, &
    4930             :                                                             iat_of_kind, ikind, j_img, j_RI, &
    4931             :                                                             j_xyz, jat, jat_of_kind, jkind, natom
    4932             :       INTEGER, DIMENSION(2)                              :: ind
    4933        2905 :       INTEGER, DIMENSION(:, :), POINTER                  :: index_to_cell
    4934             :       LOGICAL                                            :: found, my_diag, my_offdiag, use_virial
    4935             :       REAL(dp)                                           :: new_force
    4936        2905 :       REAL(dp), ALLOCATABLE, DIMENSION(:, :), TARGET     :: contr_blk, der_blk
    4937             :       REAL(dp), DIMENSION(3)                             :: scoord
    4938             :       TYPE(dbt_iterator_type)                            :: iter
    4939             :       TYPE(kpoint_type), POINTER                         :: kpoints
    4940             : 
    4941        2905 :       NULLIFY (kpoints, index_to_cell)
    4942             : 
    4943             :       !Loop over the blocks of d/dR (Q|R), contract with the corresponding block of t_2c_contr and
    4944             :       !update the relevant force
    4945             : 
    4946        2905 :       CALL timeset(routineN, handle)
    4947             :       !The virial is only accumulated when work_virial, cell and particle_set are all present
    4948        2905 :       use_virial = .FALSE.
    4949        2905 :       IF (PRESENT(work_virial) .AND. PRESENT(cell) .AND. PRESENT(particle_set)) use_virial = .TRUE.
    4950             : 
    4951        2905 :       my_diag = .FALSE.
    4952        2905 :       IF (PRESENT(diag)) my_diag = diag
    4953             :       !offdiag is OPTIONAL independently of diag: only read it when it is itself present
    4954        2324 :       my_offdiag = .FALSE.
    4955        2324 :       IF (PRESENT(offdiag)) my_offdiag = offdiag
    4956             : 
    4957        2905 :       CALL get_qs_env(qs_env, kpoints=kpoints, natom=natom)
    4958        2905 :       CALL get_kpoint_info(kpoints, index_to_cell=index_to_cell)
    4959             : 
    4960             : !$OMP PARALLEL DEFAULT(NONE) &
    4961             : !$OMP SHARED(t_2c_der,t_2c_contr,work_virial,force,use_virial,natom,index_to_cell,ri_data,img) &
    4962             : !$OMP SHARED(pref,atom_of_kind,kind_of,particle_set,cell,my_diag,my_offdiag) &
    4963             : !$OMP PRIVATE(i_xyz,j_xyz,iter,ind,der_blk,contr_blk,found,new_force,i_RI,i_img,j_RI,j_img) &
    4964        2905 : !$OMP PRIVATE(iat,jat,iat_of_kind,jat_of_kind,ikind,jkind,scoord)
    4965             :       DO i_xyz = 1, 3
    4966             :          CALL dbt_iterator_start(iter, t_2c_der(i_xyz))
    4967             :          DO WHILE (dbt_iterator_blocks_left(iter))
    4968             :             CALL dbt_iterator_next_block(iter, ind)
    4969             :             !The block filter only applies when exactly one of my_diag/my_offdiag is .TRUE.
    4970             :             !Only take forces due to block diagonal or block off-diagonal, depending on arguments
    4971             :             IF ((my_diag .AND. .NOT. my_offdiag) .OR. (.NOT. my_diag .AND. my_offdiag)) THEN
    4972             :                IF (my_diag .AND. (ind(1) .NE. ind(2))) CYCLE
    4973             :                IF (my_offdiag .AND. (ind(1) == ind(2))) CYCLE
    4974             :             END IF
    4975             : 
    4976             :             CALL dbt_get_block(t_2c_der(i_xyz), ind, der_blk, found)
    4977             :             CPASSERT(found)
    4978             :             CALL dbt_get_block(t_2c_contr, ind, contr_blk, found)
    4979             : 
    4980             :             IF (found) THEN
    4981             : 
    4982             :                !an element of d/dR (Q|R) corresponds to 2 things because of translational invariance
    4983             :                !(Q'| R) = - (Q| R'), once wrt the center on Q, and once on R
    4984             :                new_force = pref*SUM(der_blk(:, :)*contr_blk(:, :))
    4985             :                !Block indices are decoded as (RI cell index - 1)*natom + atom index
    4986             :                i_RI = (ind(1) - 1)/natom + 1
    4987             :                i_img = ri_data%RI_cell_to_img(i_RI)
    4988             :                iat = ind(1) - (i_RI - 1)*natom
    4989             :                iat_of_kind = atom_of_kind(iat)
    4990             :                ikind = kind_of(iat)
    4991             : 
    4992             :                j_RI = (ind(2) - 1)/natom + 1
    4993             :                j_img = ri_data%RI_cell_to_img(j_RI)
    4994             :                jat = ind(2) - (j_RI - 1)*natom
    4995             :                jat_of_kind = atom_of_kind(jat)
    4996             :                jkind = kind_of(jat)
    4997             : 
    4998             :                !Force on iatom (first center)
    4999             : !$OMP ATOMIC
    5000             :                force(ikind)%fock_4c(i_xyz, iat_of_kind) = force(ikind)%fock_4c(i_xyz, iat_of_kind) &
    5001             :                                                           + new_force
    5002             : 
    5003             :                IF (use_virial) THEN
    5004             : 
    5005             :                   CALL real_to_scaled(scoord, pbc(particle_set(iat)%r, cell), cell)
    5006             :                   scoord(:) = scoord(:) + REAL(index_to_cell(:, i_img), dp)
    5007             : 
    5008             :                   DO j_xyz = 1, 3
    5009             : !$OMP ATOMIC
    5010             :                      work_virial(i_xyz, j_xyz) = work_virial(i_xyz, j_xyz) + new_force*scoord(j_xyz)
    5011             :                   END DO
    5012             :                END IF
    5013             : 
    5014             :                !Force on jatom (second center)
    5015             : !$OMP ATOMIC
    5016             :                force(jkind)%fock_4c(i_xyz, jat_of_kind) = force(jkind)%fock_4c(i_xyz, jat_of_kind) &
    5017             :                                                           - new_force
    5018             : 
    5019             :                IF (use_virial) THEN
    5020             : 
    5021             :                   CALL real_to_scaled(scoord, pbc(particle_set(jat)%r, cell), cell)
    5022             :                   scoord(:) = scoord(:) + REAL(index_to_cell(:, j_img) + index_to_cell(:, img), dp)
    5023             : 
    5024             :                   DO j_xyz = 1, 3
    5025             : !$OMP ATOMIC
    5026             :                      work_virial(i_xyz, j_xyz) = work_virial(i_xyz, j_xyz) - new_force*scoord(j_xyz)
    5027             :                   END DO
    5028             :                END IF
    5029             : 
    5030             :                DEALLOCATE (contr_blk)
    5031             :             END IF
    5032             : 
    5033             :             DEALLOCATE (der_blk)
    5034             :          END DO !iter
    5035             :          CALL dbt_iterator_stop(iter)
    5036             : 
    5037             :       END DO !i_xyz
    5038             : !$OMP END PARALLEL
    5039        2905 :       CALL timestop(handle)
    5040             : 
    5041        5810 :    END SUBROUTINE get_2c_der_force
    5042             : 
    5043             : ! **************************************************************************************************
    5044             : !> \brief This routine calculates the force contribution from a trace over 3D tensors, i.e.
    5045             : !>        force = sum_ijk A_ijk B_ijk. The B tensor is (P^0| sigma^0 lambda^img), with P in the
    5046             : !>        extended RI basis. Note that all tensors are stacked along the 3rd dimension
    5047             : !> \param force ...
    5048             : !> \param t_3c_contr ...
    5049             : !> \param t_3c_der_1 ...
    5050             : !> \param t_3c_der_2 ...
    5051             : !> \param atom_of_kind ...
    5052             : !> \param kind_of ...
    5053             : !> \param idx_to_at_RI ...
    5054             : !> \param idx_to_at_AO ...
    5055             : !> \param i_images ...
    5056             : !> \param lb_img ...
    5057             : !> \param pref ...
    5058             : !> \param ri_data ...
    5059             : !> \param qs_env ...
    5060             : !> \param work_virial ...
    5061             : !> \param cell ...
    5062             : !> \param particle_set ...
    5063             : ! **************************************************************************************************
    5064        1525 :    SUBROUTINE get_force_from_3c_trace(force, t_3c_contr, t_3c_der_1, t_3c_der_2, atom_of_kind, kind_of, &
    5065        3050 :                                       idx_to_at_RI, idx_to_at_AO, i_images, lb_img, pref, &
    5066             :                                       ri_data, qs_env, work_virial, cell, particle_set)
    5067             : 
    5068             :       TYPE(qs_force_type), DIMENSION(:), POINTER         :: force
    5069             :       TYPE(dbt_type), INTENT(INOUT)                      :: t_3c_contr
    5070             :       TYPE(dbt_type), DIMENSION(3), INTENT(INOUT)        :: t_3c_der_1, t_3c_der_2
    5071             :       INTEGER, DIMENSION(:), INTENT(IN)                  :: atom_of_kind, kind_of, idx_to_at_RI, &
    5072             :                                                             idx_to_at_AO, i_images
    5073             :       INTEGER, INTENT(IN)                                :: lb_img
    5074             :       REAL(dp), INTENT(IN)                               :: pref
    5075             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
    5076             :       TYPE(qs_environment_type), POINTER                 :: qs_env
    5077             :       REAL(dp), DIMENSION(3, 3), INTENT(INOUT), OPTIONAL :: work_virial
    5078             :       TYPE(cell_type), OPTIONAL, POINTER                 :: cell
    5079             :       TYPE(particle_type), DIMENSION(:), OPTIONAL, &
    5080             :          POINTER                                         :: particle_set
    5081             : 
    5082             :       CHARACTER(LEN=*), PARAMETER :: routineN = 'get_force_from_3c_trace'
    5083             : 
    5084             :       INTEGER :: handle, i_RI, i_xyz, iat, iat_of_kind, idx, ikind, j_xyz, jat, jat_of_kind, &
    5085             :          jkind, kat, kat_of_kind, kkind, nblks_AO, nblks_RI, RI_img
    5086             :       INTEGER, DIMENSION(3)                              :: ind
    5087        1525 :       INTEGER, DIMENSION(:, :), POINTER                  :: index_to_cell
    5088             :       LOGICAL                                            :: found, found_1, found_2, use_virial
    5089             :       REAL(dp)                                           :: new_force
    5090        1525 :       REAL(dp), ALLOCATABLE, DIMENSION(:, :, :), TARGET  :: contr_blk, der_blk_1, der_blk_2, &
    5091        1525 :                                                             der_blk_3
    5092             :       REAL(dp), DIMENSION(3)                             :: scoord
    5093             :       TYPE(dbt_iterator_type)                            :: iter
    5094             :       TYPE(kpoint_type), POINTER                         :: kpoints
    5095             : 
    5096        1525 :       NULLIFY (kpoints, index_to_cell)
    5097             : 
    5098        1525 :       CALL timeset(routineN, handle)
    5099             : 
    5100        1525 :       CALL get_qs_env(qs_env, kpoints=kpoints)
    5101        1525 :       CALL get_kpoint_info(kpoints, index_to_cell=index_to_cell)
    5102             :       !Block counts of the split RI and AO bases, used below to decode the stacked block indices
    5103        1525 :       nblks_RI = SIZE(ri_data%bsizes_RI_split)
    5104        1525 :       nblks_AO = SIZE(ri_data%bsizes_AO_split)
    5105             :       !The virial is only accumulated when work_virial, cell and particle_set are all present
    5106        1525 :       use_virial = .FALSE.
    5107        1525 :       IF (PRESENT(work_virial) .AND. PRESENT(cell) .AND. PRESENT(particle_set)) use_virial = .TRUE.
    5108             : 
    5109             : !$OMP PARALLEL DEFAULT(NONE) &
    5110             : !$OMP SHARED(t_3c_der_1, t_3c_der_2,t_3c_contr,work_virial,force,use_virial,index_to_cell,i_images,lb_img) &
    5111             : !$OMP SHARED(pref,idx_to_at_AO,atom_of_kind,kind_of,particle_set,cell,idx_to_at_RI,ri_data,nblks_RI,nblks_AO) &
    5112             : !$OMP PRIVATE(i_xyz,j_xyz,iter,ind,der_blk_1,contr_blk,found,new_force,iat,iat_of_kind,ikind,scoord) &
    5113        1525 : !$OMP PRIVATE(jat,kat,jat_of_kind,kat_of_kind,jkind,kkind,i_RI,RI_img,der_blk_2,der_blk_3,found_1,found_2,idx)
    5114             :       CALL dbt_iterator_start(iter, t_3c_contr)
    5115             :       DO WHILE (dbt_iterator_blocks_left(iter))
    5116             :          CALL dbt_iterator_next_block(iter, ind)
    5117             :          !Only blocks found in t_3c_contr contribute to the forces
    5118             :          CALL dbt_get_block(t_3c_contr, ind, contr_blk, found)
    5119             :          IF (found) THEN
    5120             : 
    5121             :             DO i_xyz = 1, 3
    5122             :                CALL dbt_get_block(t_3c_der_1(i_xyz), ind, der_blk_1, found_1)
    5123             :                IF (.NOT. found_1) THEN
    5124             :                   DEALLOCATE (der_blk_1)
    5125             :                   ALLOCATE (der_blk_1(SIZE(contr_blk, 1), SIZE(contr_blk, 2), SIZE(contr_blk, 3)))
    5126             :                   der_blk_1(:, :, :) = 0.0_dp
    5127             :                END IF
    5128             :                CALL dbt_get_block(t_3c_der_2(i_xyz), ind, der_blk_2, found_2)
    5129             :                IF (.NOT. found_2) THEN
    5130             :                   DEALLOCATE (der_blk_2)
    5131             :                   ALLOCATE (der_blk_2(SIZE(contr_blk, 1), SIZE(contr_blk, 2), SIZE(contr_blk, 3)))
    5132             :                   der_blk_2(:, :, :) = 0.0_dp
    5133             :                END IF
    5134             :                !Missing derivative blocks are zero-filled above; third-center derivative is -(first + second)
    5135             :                ALLOCATE (der_blk_3(SIZE(contr_blk, 1), SIZE(contr_blk, 2), SIZE(contr_blk, 3)))
    5136             :                der_blk_3(:, :, :) = -(der_blk_1(:, :, :) + der_blk_2(:, :, :))
    5137             : 
    5138             :                !We assume the tensors are in the format (P^0| sigma^0 mu^a+c-b), with P a member of the
    5139             :                !extended RI basis set
    5140             : 
    5141             :                !Force for the first center (RI extended basis, zero cell)
    5142             :                new_force = pref*SUM(der_blk_1(:, :, :)*contr_blk(:, :, :))
    5143             :                !First index is decoded as (RI cell index - 1)*nblks_RI + RI block index
    5144             :                i_RI = (ind(1) - 1)/nblks_RI + 1
    5145             :                RI_img = ri_data%RI_cell_to_img(i_RI)
    5146             :                iat = idx_to_at_RI(ind(1) - (i_RI - 1)*nblks_RI)
    5147             :                iat_of_kind = atom_of_kind(iat)
    5148             :                ikind = kind_of(iat)
    5149             : 
    5150             : !$OMP ATOMIC
    5151             :                force(ikind)%fock_4c(i_xyz, iat_of_kind) = force(ikind)%fock_4c(i_xyz, iat_of_kind) &
    5152             :                                                           + new_force
    5153             : 
    5154             :                IF (use_virial) THEN
    5155             : 
    5156             :                   CALL real_to_scaled(scoord, pbc(particle_set(iat)%r, cell), cell)
    5157             :                   scoord(:) = scoord(:) + REAL(index_to_cell(:, RI_img), dp)
    5158             : 
    5159             :                   DO j_xyz = 1, 3
    5160             : !$OMP ATOMIC
    5161             :                      work_virial(i_xyz, j_xyz) = work_virial(i_xyz, j_xyz) + new_force*scoord(j_xyz)
    5162             :                   END DO
    5163             :                END IF
    5164             : 
    5165             :                !Force with respect to the second center (AO basis, zero cell)
    5166             :                new_force = pref*SUM(der_blk_2(:, :, :)*contr_blk(:, :, :))
    5167             :                jat = idx_to_at_AO(ind(2))
    5168             :                jat_of_kind = atom_of_kind(jat)
    5169             :                jkind = kind_of(jat)
    5170             : 
    5171             : !$OMP ATOMIC
    5172             :                force(jkind)%fock_4c(i_xyz, jat_of_kind) = force(jkind)%fock_4c(i_xyz, jat_of_kind) &
    5173             :                                                           + new_force
    5174             : 
    5175             :                IF (use_virial) THEN
    5176             :                   !No cell shift is added to scoord for this center
    5177             :                   CALL real_to_scaled(scoord, pbc(particle_set(jat)%r, cell), cell)
    5178             : 
    5179             :                   DO j_xyz = 1, 3
    5180             : !$OMP ATOMIC
    5181             :                      work_virial(i_xyz, j_xyz) = work_virial(i_xyz, j_xyz) + new_force*scoord(j_xyz)
    5182             :                   END DO
    5183             :                END IF
    5184             : 
    5185             :                !Force with respect to the third center (AO basis, apc_img - b_img)
    5186             :                !Note: tensors are stacked along the 3rd direction
    5187             :                new_force = pref*SUM(der_blk_3(:, :, :)*contr_blk(:, :, :))
    5188             :                idx = (ind(3) - 1)/nblks_AO + 1
    5189             :                kat = idx_to_at_AO(ind(3) - (idx - 1)*nblks_AO)
    5190             :                kat_of_kind = atom_of_kind(kat)
    5191             :                kkind = kind_of(kat)
    5192             : 
    5193             : !$OMP ATOMIC
    5194             :                force(kkind)%fock_4c(i_xyz, kat_of_kind) = force(kkind)%fock_4c(i_xyz, kat_of_kind) &
    5195             :                                                           + new_force
    5196             : 
    5197             :                IF (use_virial) THEN
    5198             :                   CALL real_to_scaled(scoord, pbc(particle_set(kat)%r, cell), cell)
    5199             :                   scoord(:) = scoord(:) + REAL(index_to_cell(:, i_images(lb_img - 1 + idx)), dp)
    5200             : 
    5201             :                   DO j_xyz = 1, 3
    5202             : !$OMP ATOMIC
    5203             :                      work_virial(i_xyz, j_xyz) = work_virial(i_xyz, j_xyz) + new_force*scoord(j_xyz)
    5204             :                   END DO
    5205             :                END IF
    5206             : 
    5207             :                DEALLOCATE (der_blk_1, der_blk_2, der_blk_3)
    5208             :             END DO !i_xyz
    5209             :             DEALLOCATE (contr_blk)
    5210             :          END IF !found
    5211             :       END DO !iter
    5212             :       CALL dbt_iterator_stop(iter)
    5213             : !$OMP END PARALLEL
    5214        1525 :       CALL timestop(handle)
    5215             : 
    5216        3050 :    END SUBROUTINE get_force_from_3c_trace
    5217             : 
    5218             : END MODULE

Generated by: LCOV version 1.15