LCOV - code coverage report
Current view: top level - src/dbt - dbt_methods.F (source / functions) Coverage Total Hit
Test: CP2K Regtests (git:42dac4a) Lines: 94.2 % 895 843
Test Date: 2025-07-25 12:55:17 Functions: 95.7 % 23 22

            Line data    Source code
       1              : !--------------------------------------------------------------------------------------------------!
       2              : !   CP2K: A general program to perform molecular dynamics simulations                              !
       3              : !   Copyright 2000-2025 CP2K developers group <https://cp2k.org>                                   !
       4              : !                                                                                                  !
       5              : !   SPDX-License-Identifier: GPL-2.0-or-later                                                      !
       6              : !--------------------------------------------------------------------------------------------------!
       7              : 
       8              : ! **************************************************************************************************
       9              : !> \brief DBT tensor framework for block-sparse tensor contraction.
      10              : !>        Representation of n-rank tensors as DBT tall-and-skinny matrices.
      11              : !>        Support for arbitrary redistribution between different representations.
      12              : !>        Support for arbitrary tensor contractions
      13              : !> \todo implement checks and error messages
      14              : !> \author Patrick Seewald
      15              : ! **************************************************************************************************
      16              : MODULE dbt_methods
      17              :    #:include "dbt_macros.fypp"
      18              :    #:set maxdim = maxrank
      19              :    #:set ndims = range(2,maxdim+1)
      20              : 
      21              :    USE cp_dbcsr_api, ONLY: &
      22              :       dbcsr_type, dbcsr_release, &
      23              :       dbcsr_iterator_type, dbcsr_iterator_start, dbcsr_iterator_blocks_left, dbcsr_iterator_next_block, &
      24              :       dbcsr_has_symmetry, dbcsr_desymmetrize, dbcsr_put_block, dbcsr_clear, dbcsr_iterator_stop
      25              :    USE dbt_allocate_wrap, ONLY: &
      26              :       allocate_any
      27              :    USE dbt_array_list_methods, ONLY: &
      28              :       get_arrays, reorder_arrays, get_ith_array, array_list, array_sublist, check_equal, array_eq_i, &
      29              :       create_array_list, destroy_array_list, sizes_of_arrays
      30              :    USE dbm_api, ONLY: &
      31              :       dbm_clear
      32              :    USE dbt_tas_types, ONLY: &
      33              :       dbt_tas_split_info
      34              :    USE dbt_tas_base, ONLY: &
      35              :       dbt_tas_copy, dbt_tas_finalize, dbt_tas_get_info, dbt_tas_info
      36              :    USE dbt_tas_mm, ONLY: &
      37              :       dbt_tas_multiply, dbt_tas_batched_mm_init, dbt_tas_batched_mm_finalize, &
      38              :       dbt_tas_batched_mm_complete, dbt_tas_set_batched_state
      39              :    USE dbt_block, ONLY: &
      40              :       dbt_iterator_type, dbt_get_block, dbt_put_block, dbt_iterator_start, &
      41              :       dbt_iterator_blocks_left, dbt_iterator_stop, dbt_iterator_next_block, &
      42              :       ndims_iterator, dbt_reserve_blocks, block_nd, destroy_block, checker_tr
      43              :    USE dbt_index, ONLY: &
      44              :       dbt_get_mapping_info, nd_to_2d_mapping, dbt_inverse_order, permute_index, get_nd_indices_tensor, &
      45              :       ndims_mapping_row, ndims_mapping_column, ndims_mapping
      46              :    USE dbt_types, ONLY: &
      47              :       dbt_create, dbt_type, ndims_tensor, dims_tensor, &
      48              :       dbt_distribution_type, dbt_distribution, dbt_nd_mp_comm, dbt_destroy, &
      49              :       dbt_distribution_destroy, dbt_distribution_new_expert, dbt_get_stored_coordinates, &
      50              :       blk_dims_tensor, dbt_hold, dbt_pgrid_type, mp_environ_pgrid, dbt_filter, &
      51              :       dbt_clear, dbt_finalize, dbt_get_num_blocks, dbt_scale, &
      52              :       dbt_get_num_blocks_total, dbt_get_info, ndims_matrix_row, ndims_matrix_column, &
      53              :       dbt_max_nblks_local, dbt_default_distvec, dbt_contraction_storage, dbt_nblks_total, &
      54              :       dbt_distribution_new, dbt_copy_contraction_storage, dbt_pgrid_destroy
      55              :    USE kinds, ONLY: &
      56              :       dp, default_string_length, int_8, dp
      57              :    USE message_passing, ONLY: &
      58              :       mp_cart_type
      59              :    USE util, ONLY: &
      60              :       sort
      61              :    USE dbt_reshape_ops, ONLY: &
      62              :       dbt_reshape
      63              :    USE dbt_tas_split, ONLY: &
      64              :       dbt_tas_mp_comm, rowsplit, colsplit, dbt_tas_info_hold, dbt_tas_release_info, default_nsplit_accept_ratio, &
      65              :       default_pdims_accept_ratio, dbt_tas_create_split
      66              :    USE dbt_split, ONLY: &
      67              :       dbt_split_copyback, dbt_make_compatible_blocks, dbt_crop
      68              :    USE dbt_io, ONLY: &
      69              :       dbt_write_tensor_info, dbt_write_tensor_dist, prep_output_unit, dbt_write_split_info
      70              :    USE message_passing, ONLY: mp_comm_type
      71              : 
      72              : #include "../base/base_uses.f90"
      73              : 
      74              :    IMPLICIT NONE
      75              :    PRIVATE
      76              :    CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'dbt_methods'
      77              : 
      78              :    PUBLIC :: &
      79              :       dbt_contract, &
      80              :       dbt_copy, &
      81              :       dbt_get_block, &
      82              :       dbt_get_stored_coordinates, &
      83              :       dbt_inverse_order, &
      84              :       dbt_iterator_blocks_left, &
      85              :       dbt_iterator_next_block, &
      86              :       dbt_iterator_start, &
      87              :       dbt_iterator_stop, &
      88              :       dbt_iterator_type, &
      89              :       dbt_put_block, &
      90              :       dbt_reserve_blocks, &
      91              :       dbt_copy_matrix_to_tensor, &
      92              :       dbt_copy_tensor_to_matrix, &
      93              :       dbt_batched_contract_init, &
      94              :       dbt_batched_contract_finalize
      95              : 
      96              : CONTAINS
      97              : 
      98              : ! **************************************************************************************************
      99              : !> \brief Copy tensor data.
     100              : !>        Redistributes tensor data according to distributions of target and source tensor.
     101              : !>        Permutes tensor index according to `order` argument (if present).
     102              : !>        Source and target tensor formats are arbitrary as long as the following requirements are met:
     103              : !>        * source and target tensors have the same rank and the same sizes in each dimension in terms
     104              : !>          of tensor elements (block sizes don't need to be the same).
     105              : !>          If `order` argument is present, sizes must match after index permutation.
     106              : !>        OR
     107              : !>        * target tensor is not yet created, in this case an exact copy of source tensor is returned.
     108              : !> \param tensor_in Source
     109              : !> \param tensor_out Target
     110              : !> \param order Permutation of target tensor index.
     111              : !>              Exact same convention as order argument of RESHAPE intrinsic.
     112              : !> \param bounds crop tensor data: start and end index for each tensor dimension
     113              : !> \author Patrick Seewald
     114              : ! **************************************************************************************************
     115       897012 :    SUBROUTINE dbt_copy(tensor_in, tensor_out, order, summation, bounds, move_data, unit_nr)
     116              :       TYPE(dbt_type), INTENT(INOUT), TARGET      :: tensor_in, tensor_out
     117              :       INTEGER, DIMENSION(ndims_tensor(tensor_in)), &
     118              :          INTENT(IN), OPTIONAL                        :: order
     119              :       LOGICAL, INTENT(IN), OPTIONAL                  :: summation, move_data
     120              :       INTEGER, DIMENSION(2, ndims_tensor(tensor_in)), &
     121              :          INTENT(IN), OPTIONAL                        :: bounds
     122              :       INTEGER, INTENT(IN), OPTIONAL                  :: unit_nr
     123              :       INTEGER :: handle
     124              : 
     125       448506 :       CALL tensor_in%pgrid%mp_comm_2d%sync()
     126       448506 :       CALL timeset("dbt_total", handle)
     127              : 
     128              :       ! make sure that it is safe to use dbt_copy during a batched contraction
     129       448506 :       CALL dbt_tas_batched_mm_complete(tensor_in%matrix_rep, warn=.TRUE.)
     130       448506 :       CALL dbt_tas_batched_mm_complete(tensor_out%matrix_rep, warn=.TRUE.)
     131              : 
     132       448506 :       CALL dbt_copy_expert(tensor_in, tensor_out, order, summation, bounds, move_data, unit_nr)
     133       448506 :       CALL tensor_in%pgrid%mp_comm_2d%sync()
     134       448506 :       CALL timestop(handle)
     135       570666 :    END SUBROUTINE
     136              : 
     137              : ! **************************************************************************************************
     138              : !> \brief expert routine for copying a tensor. For internal use only.
     139              : !> \author Patrick Seewald
     140              : ! **************************************************************************************************
   SUBROUTINE dbt_copy_expert(tensor_in, tensor_out, order, summation, bounds, move_data, unit_nr)
      TYPE(dbt_type), INTENT(INOUT), TARGET      :: tensor_in, tensor_out
      INTEGER, DIMENSION(ndims_tensor(tensor_in)), &
         INTENT(IN), OPTIONAL                        :: order
      LOGICAL, INTENT(IN), OPTIONAL                  :: summation, move_data
      INTEGER, DIMENSION(2, ndims_tensor(tensor_in)), &
         INTENT(IN), OPTIONAL                        :: bounds
      INTEGER, INTENT(IN), OPTIONAL                  :: unit_nr

      ! Staged work tensors: each stage either aliases the previous stage
      ! (pointer association) or allocates a fresh intermediate; the new_in_*/
      ! new_out_1 flags record which intermediates must be destroyed at the end.
      TYPE(dbt_type), POINTER                    :: in_tmp_1, in_tmp_2, &
                                                    in_tmp_3, out_tmp_1
      INTEGER                                        :: handle, unit_nr_prv
      INTEGER, DIMENSION(:), ALLOCATABLE             :: map1_in_1, map1_in_2, map2_in_1, map2_in_2

      CHARACTER(LEN=*), PARAMETER :: routineN = 'dbt_copy'
      LOGICAL                                        :: dist_compatible_tas, dist_compatible_tensor, &
                                                        summation_prv, new_in_1, new_in_2, &
                                                        new_in_3, new_out_1, block_compatible, &
                                                        move_prv
      TYPE(array_list)                               :: blk_sizes_in

      CALL timeset(routineN, handle)

      CPASSERT(tensor_out%valid)

      unit_nr_prv = prep_output_unit(unit_nr)

      ! Default: keep the data in the source tensor unless a move was requested.
      IF (PRESENT(move_data)) THEN
         move_prv = move_data
      ELSE
         move_prv = .FALSE.
      END IF

      dist_compatible_tas = .FALSE.
      dist_compatible_tensor = .FALSE.
      block_compatible = .FALSE.
      new_in_1 = .FALSE.
      new_in_2 = .FALSE.
      new_in_3 = .FALSE.
      new_out_1 = .FALSE.

      IF (PRESENT(summation)) THEN
         summation_prv = summation
      ELSE
         summation_prv = .FALSE.
      END IF

      ! Stage 1: optionally crop the source to the requested index bounds.
      IF (PRESENT(bounds)) THEN
         ALLOCATE (in_tmp_1)
         CALL dbt_crop(tensor_in, in_tmp_1, bounds=bounds, move_data=move_prv)
         new_in_1 = .TRUE.
         ! The cropped tensor is a private intermediate, so later stages may
         ! consume (move) its data regardless of the caller's move_data.
         move_prv = .TRUE.
      ELSE
         in_tmp_1 => tensor_in
      END IF

      ! Check whether source and target agree on block sizes, comparing after
      ! applying the index permutation if one was given.
      IF (PRESENT(order)) THEN
         CALL reorder_arrays(in_tmp_1%blk_sizes, blk_sizes_in, order=order)
         block_compatible = check_equal(blk_sizes_in, tensor_out%blk_sizes)
      ELSE
         block_compatible = check_equal(in_tmp_1%blk_sizes, tensor_out%blk_sizes)
      END IF

      ! Stage 2: if block sizes differ, create intermediates with mutually
      ! compatible block sizes on both the input and output side.
      IF (.NOT. block_compatible) THEN
         ALLOCATE (in_tmp_2, out_tmp_1)
         ! nodata2: the output intermediate only needs existing data if we sum into it.
         CALL dbt_make_compatible_blocks(in_tmp_1, tensor_out, in_tmp_2, out_tmp_1, order=order, &
                                         nodata2=.NOT. summation_prv, move_data=move_prv)
         new_in_2 = .TRUE.; new_out_1 = .TRUE.
         move_prv = .TRUE.
      ELSE
         in_tmp_2 => in_tmp_1
         out_tmp_1 => tensor_out
      END IF

      ! Stage 3: apply the index permutation (dbt_permute_index; defined
      ! elsewhere in this module).
      IF (PRESENT(order)) THEN
         ALLOCATE (in_tmp_3)
         CALL dbt_permute_index(in_tmp_2, in_tmp_3, order)
         new_in_3 = .TRUE.
      ELSE
         in_tmp_3 => in_tmp_2
      END IF

      ! Retrieve the tensor-index -> 2d-matrix mappings of source and target.
      ALLOCATE (map1_in_1(ndims_matrix_row(in_tmp_3)))
      ALLOCATE (map1_in_2(ndims_matrix_column(in_tmp_3)))
      CALL dbt_get_mapping_info(in_tmp_3%nd_index, map1_2d=map1_in_1, map2_2d=map1_in_2)

      ALLOCATE (map2_in_1(ndims_matrix_row(out_tmp_1)))
      ALLOCATE (map2_in_2(ndims_matrix_column(out_tmp_1)))
      CALL dbt_get_mapping_info(out_tmp_1%nd_index, map1_2d=map2_in_1, map2_2d=map2_in_2)

      IF (.NOT. PRESENT(order)) THEN
         IF (array_eq_i(map1_in_1, map2_in_1) .AND. array_eq_i(map1_in_2, map2_in_2)) THEN
            ! identical row/column mapping: candidate for a direct TAS matrix copy
            dist_compatible_tas = check_equal(in_tmp_3%nd_dist, out_tmp_1%nd_dist)
         ELSEIF (array_eq_i([map1_in_1, map1_in_2], [map2_in_1, map2_in_2])) THEN
            ! same overall index order but different row/column split:
            ! candidate for a communication-free block-wise copy
            dist_compatible_tensor = check_equal(in_tmp_3%nd_dist, out_tmp_1%nd_dist)
         END IF
      END IF

      ! Dispatch to the cheapest copy path that is applicable.
      IF (dist_compatible_tas) THEN
         CALL dbt_tas_copy(out_tmp_1%matrix_rep, in_tmp_3%matrix_rep, summation)
         IF (move_prv) CALL dbt_clear(in_tmp_3)
      ELSEIF (dist_compatible_tensor) THEN
         CALL dbt_copy_nocomm(in_tmp_3, out_tmp_1, summation)
         IF (move_prv) CALL dbt_clear(in_tmp_3)
      ELSE
         ! general case: full redistribution of the data
         CALL dbt_reshape(in_tmp_3, out_tmp_1, summation, move_data=move_prv)
      END IF

      ! Release any intermediates that were allocated above.
      IF (new_in_1) THEN
         CALL dbt_destroy(in_tmp_1)
         DEALLOCATE (in_tmp_1)
      END IF

      IF (new_in_2) THEN
         CALL dbt_destroy(in_tmp_2)
         DEALLOCATE (in_tmp_2)
      END IF

      IF (new_in_3) THEN
         CALL dbt_destroy(in_tmp_3)
         DEALLOCATE (in_tmp_3)
      END IF

      IF (new_out_1) THEN
         IF (unit_nr_prv /= 0) THEN
            CALL dbt_write_tensor_dist(out_tmp_1, unit_nr)
         END IF
         ! Fold the block-split intermediate result back into the real target.
         CALL dbt_split_copyback(out_tmp_1, tensor_out, summation)
         CALL dbt_destroy(out_tmp_1)
         DEALLOCATE (out_tmp_1)
      END IF

      CALL timestop(handle)

   END SUBROUTINE
     276              : 
     277              : ! **************************************************************************************************
     278              : !> \brief copy without communication, requires that both tensors have same process grid and distribution
     279              : !> \param summation Whether to sum matrices b = a + b
     280              : !> \author Patrick Seewald
     281              : ! **************************************************************************************************
     282        11010 :    SUBROUTINE dbt_copy_nocomm(tensor_in, tensor_out, summation)
     283              :       TYPE(dbt_type), INTENT(INOUT) :: tensor_in
     284              :       TYPE(dbt_type), INTENT(INOUT) :: tensor_out
     285              :       LOGICAL, INTENT(IN), OPTIONAL                      :: summation
     286              :       TYPE(dbt_iterator_type) :: iter
     287        11010 :       INTEGER, DIMENSION(ndims_tensor(tensor_in))  :: ind_nd
     288        11010 :       TYPE(block_nd) :: blk_data
     289              :       LOGICAL :: found
     290              : 
     291              :       CHARACTER(LEN=*), PARAMETER :: routineN = 'dbt_copy_nocomm'
     292              :       INTEGER :: handle
     293              : 
     294        11010 :       CALL timeset(routineN, handle)
     295        11010 :       CPASSERT(tensor_out%valid)
     296              : 
     297        11010 :       IF (PRESENT(summation)) THEN
     298         5472 :          IF (.NOT. summation) CALL dbt_clear(tensor_out)
     299              :       ELSE
     300         5538 :          CALL dbt_clear(tensor_out)
     301              :       END IF
     302              : 
     303        11010 :       CALL dbt_reserve_blocks(tensor_in, tensor_out)
     304              : 
     305              : !$OMP PARALLEL DEFAULT(NONE) SHARED(tensor_in,tensor_out,summation) &
     306        11010 : !$OMP PRIVATE(iter,ind_nd,blk_data,found)
     307              :       CALL dbt_iterator_start(iter, tensor_in)
     308              :       DO WHILE (dbt_iterator_blocks_left(iter))
     309              :          CALL dbt_iterator_next_block(iter, ind_nd)
     310              :          CALL dbt_get_block(tensor_in, ind_nd, blk_data, found)
     311              :          CPASSERT(found)
     312              :          CALL dbt_put_block(tensor_out, ind_nd, blk_data, summation=summation)
     313              :          CALL destroy_block(blk_data)
     314              :       END DO
     315              :       CALL dbt_iterator_stop(iter)
     316              : !$OMP END PARALLEL
     317              : 
     318        11010 :       CALL timestop(handle)
     319        22020 :    END SUBROUTINE
     320              : 
     321              : ! **************************************************************************************************
     322              : !> \brief copy matrix to tensor.
     323              : !> \param summation tensor_out = tensor_out + matrix_in
     324              : !> \author Patrick Seewald
     325              : ! **************************************************************************************************
     326        76478 :    SUBROUTINE dbt_copy_matrix_to_tensor(matrix_in, tensor_out, summation)
     327              :       TYPE(dbcsr_type), TARGET, INTENT(IN)               :: matrix_in
     328              :       TYPE(dbt_type), INTENT(INOUT)             :: tensor_out
     329              :       LOGICAL, INTENT(IN), OPTIONAL                      :: summation
     330              :       TYPE(dbcsr_type), POINTER                          :: matrix_in_desym
     331              : 
     332              :       INTEGER, DIMENSION(2)                              :: ind_2d
     333        76478 :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :)    :: block_arr
     334        76478 :       REAL(KIND=dp), DIMENSION(:, :), POINTER        :: block
     335              :       TYPE(dbcsr_iterator_type)                          :: iter
     336              : 
     337              :       INTEGER                                            :: handle
     338              :       CHARACTER(LEN=*), PARAMETER :: routineN = 'dbt_copy_matrix_to_tensor'
     339              : 
     340        76478 :       CALL timeset(routineN, handle)
     341        76478 :       CPASSERT(tensor_out%valid)
     342              : 
     343        76478 :       NULLIFY (block)
     344              : 
     345        76478 :       IF (dbcsr_has_symmetry(matrix_in)) THEN
     346         5066 :          ALLOCATE (matrix_in_desym)
     347         5066 :          CALL dbcsr_desymmetrize(matrix_in, matrix_in_desym)
     348              :       ELSE
     349              :          matrix_in_desym => matrix_in
     350              :       END IF
     351              : 
     352        76478 :       IF (PRESENT(summation)) THEN
     353            0 :          IF (.NOT. summation) CALL dbt_clear(tensor_out)
     354              :       ELSE
     355        76478 :          CALL dbt_clear(tensor_out)
     356              :       END IF
     357              : 
     358        76478 :       CALL dbt_reserve_blocks(matrix_in_desym, tensor_out)
     359              : 
     360              : !$OMP PARALLEL DEFAULT(NONE) SHARED(matrix_in_desym,tensor_out,summation) &
     361        76478 : !$OMP PRIVATE(iter,ind_2d,block,block_arr)
     362              :       CALL dbcsr_iterator_start(iter, matrix_in_desym)
     363              :       DO WHILE (dbcsr_iterator_blocks_left(iter))
     364              :          CALL dbcsr_iterator_next_block(iter, ind_2d(1), ind_2d(2), block)
     365              :          CALL allocate_any(block_arr, source=block)
     366              :          CALL dbt_put_block(tensor_out, ind_2d, SHAPE(block_arr), block_arr, summation=summation)
     367              :          DEALLOCATE (block_arr)
     368              :       END DO
     369              :       CALL dbcsr_iterator_stop(iter)
     370              : !$OMP END PARALLEL
     371              : 
     372        76478 :       IF (dbcsr_has_symmetry(matrix_in)) THEN
     373         5066 :          CALL dbcsr_release(matrix_in_desym)
     374         5066 :          DEALLOCATE (matrix_in_desym)
     375              :       END IF
     376              : 
     377        76478 :       CALL timestop(handle)
     378              : 
     379       152956 :    END SUBROUTINE
     380              : 
     381              : ! **************************************************************************************************
     382              : !> \brief copy tensor to matrix
     383              : !> \param summation matrix_out = matrix_out + tensor_in
     384              : !> \author Patrick Seewald
     385              : ! **************************************************************************************************
     386        48090 :    SUBROUTINE dbt_copy_tensor_to_matrix(tensor_in, matrix_out, summation)
     387              :       TYPE(dbt_type), INTENT(INOUT)      :: tensor_in
     388              :       TYPE(dbcsr_type), INTENT(INOUT)             :: matrix_out
     389              :       LOGICAL, INTENT(IN), OPTIONAL          :: summation
     390              :       TYPE(dbt_iterator_type)            :: iter
     391              :       INTEGER                                :: handle
     392              :       INTEGER, DIMENSION(2)                  :: ind_2d
     393        48090 :       REAL(KIND=dp), DIMENSION(:, :), ALLOCATABLE :: block
     394              :       CHARACTER(LEN=*), PARAMETER :: routineN = 'dbt_copy_tensor_to_matrix'
     395              :       LOGICAL :: found
     396              : 
     397        48090 :       CALL timeset(routineN, handle)
     398              : 
     399        48090 :       IF (PRESENT(summation)) THEN
     400         5748 :          IF (.NOT. summation) CALL dbcsr_clear(matrix_out)
     401              :       ELSE
     402        42342 :          CALL dbcsr_clear(matrix_out)
     403              :       END IF
     404              : 
     405        48090 :       CALL dbt_reserve_blocks(tensor_in, matrix_out)
     406              : 
     407              : !$OMP PARALLEL DEFAULT(NONE) SHARED(tensor_in,matrix_out,summation) &
     408        48090 : !$OMP PRIVATE(iter,ind_2d,block,found)
     409              :       CALL dbt_iterator_start(iter, tensor_in)
     410              :       DO WHILE (dbt_iterator_blocks_left(iter))
     411              :          CALL dbt_iterator_next_block(iter, ind_2d)
     412              :          IF (dbcsr_has_symmetry(matrix_out) .AND. checker_tr(ind_2d(1), ind_2d(2))) CYCLE
     413              : 
     414              :          CALL dbt_get_block(tensor_in, ind_2d, block, found)
     415              :          CPASSERT(found)
     416              : 
     417              :          IF (dbcsr_has_symmetry(matrix_out) .AND. ind_2d(1) > ind_2d(2)) THEN
     418              :             CALL dbcsr_put_block(matrix_out, ind_2d(2), ind_2d(1), TRANSPOSE(block), summation=summation)
     419              :          ELSE
     420              :             CALL dbcsr_put_block(matrix_out, ind_2d(1), ind_2d(2), block, summation=summation)
     421              :          END IF
     422              :          DEALLOCATE (block)
     423              :       END DO
     424              :       CALL dbt_iterator_stop(iter)
     425              : !$OMP END PARALLEL
     426              : 
     427        48090 :       CALL timestop(handle)
     428              : 
     429        96180 :    END SUBROUTINE
     430              : 
     431              : ! **************************************************************************************************
     432              : !> \brief Contract tensors by multiplying matrix representations.
     433              : !>        tensor_3(map_1, map_2) := alpha * tensor_1(notcontract_1, contract_1)
     434              : !>        * tensor_2(contract_2, notcontract_2)
     435              : !>        + beta * tensor_3(map_1, map_2)
     436              : !>
     437              : !> \note
     438              : !>      note 1: block sizes of the corresponding indices need to be the same in all tensors.
     439              : !>
     440              : !>      note 2: for best performance the tensors should have been created in matrix layouts
     441              : !>      compatible with the contraction, e.g. tensor_1 should have been created with either
     442              : !>      map1_2d == contract_1 and map2_2d == notcontract_1 or map1_2d == notcontract_1 and
     443              : !>      map2_2d == contract_1 (the same with tensor_2 and contract_2 / notcontract_2 and with
     444              : !>      tensor_3 and map_1 / map_2).
     445              : !>      Furthermore the two largest tensors involved in the contraction should map both to either
     446              : !>      tall or short matrices: the largest matrix dimension should be "on the same side"
     447              : !>      and should have identical distribution (which is always the case if the distributions were
     448              : !>      obtained with dbt_default_distvec).
     449              : !>
     450              : !>      note 3: if the same tensor occurs in multiple contractions, a different tensor object should
     451              : !>      be created for each contraction and the data should be copied between the tensors by use of
     452              : !>      dbt_copy. If the same tensor object is used in multiple contractions,
     453              : !>       matrix layouts are not compatible for all contractions (see note 2).
     454              : !>
     455              : !>      note 4: automatic optimizations are enabled by using the feature of batched contraction, see
     456              : !>      dbt_batched_contract_init, dbt_batched_contract_finalize.
     457              : !>      The arguments bounds_1, bounds_2, bounds_3 give the index ranges of the batches.
     458              : !>
     459              : !> \param tensor_1 first tensor (in)
     460              : !> \param tensor_2 second tensor (in)
     461              : !> \param contract_1 indices of tensor_1 to contract
     462              : !> \param contract_2 indices of tensor_2 to contract (1:1 with contract_1)
     463              : !> \param map_1 which indices of tensor_3 map to non-contracted indices of tensor_1 (1:1 with notcontract_1)
     464              : !> \param map_2 which indices of tensor_3 map to non-contracted indices of tensor_2 (1:1 with notcontract_2)
     465              : !> \param notcontract_1 indices of tensor_1 not to contract
     466              : !> \param notcontract_2 indices of tensor_2 not to contract
     467              : !> \param tensor_3 contracted tensor (out)
     468              : !> \param bounds_1 bounds corresponding to contract_1 AKA contract_2:
     469              : !>                 start and end index of an index range over which to contract.
     470              : !>                 For use in batched contraction.
     471              : !> \param bounds_2 bounds corresponding to notcontract_1: start and end index of an index range.
     472              : !>                 For use in batched contraction.
     473              : !> \param bounds_3 bounds corresponding to notcontract_2: start and end index of an index range.
     474              : !>                 For use in batched contraction.
     475              : !> \param optimize_dist Whether distribution should be optimized internally. In the current
     476              : !>                      implementation this guarantees optimal parameters only for dense matrices.
     477              : !> \param pgrid_opt_1 Optionally return optimal process grid for tensor_1.
     478              : !>                    This can be used to choose optimal process grids for subsequent tensor
     479              : !>                    contractions with tensors of similar shape and sparsity. Under some conditions,
     480              : !>                    pgrid_opt_1 can not be returned, in this case the pointer is not associated.
     481              : !> \param pgrid_opt_2 Optionally return optimal process grid for tensor_2.
     482              : !> \param pgrid_opt_3 Optionally return optimal process grid for tensor_3.
     483              : !> \param filter_eps As in DBM mm
     484              : !> \param flop As in DBM mm
     485              : !> \param move_data memory optimization: transfer data such that tensor_1 and tensor_2 are empty on return
     486              : !> \param retain_sparsity enforce the sparsity pattern of the existing tensor_3; default is no
     487              : !> \param unit_nr output unit for logging
     488              : !>                       set it to -1 on ranks that should not write (and any valid unit number on
     489              : !>                       ranks that should write output). If 0 on ALL ranks, no output is written.
     490              : !> \param log_verbose verbose logging (for testing only)
     491              : !> \author Patrick Seewald
     492              : ! **************************************************************************************************
      493       336494 :    SUBROUTINE dbt_contract(alpha, tensor_1, tensor_2, beta, tensor_3, &
      494       168247 :                            contract_1, notcontract_1, &
      495       168247 :                            contract_2, notcontract_2, &
      496       168247 :                            map_1, map_2, &
      497       122974 :                            bounds_1, bounds_2, bounds_3, &
      498              :                            optimize_dist, pgrid_opt_1, pgrid_opt_2, pgrid_opt_3, &
      499              :                            filter_eps, flop, move_data, retain_sparsity, unit_nr, log_verbose)
      500              :       REAL(dp), INTENT(IN)            :: alpha ! scalar prefactor: tensor_3 = alpha*contract(tensor_1, tensor_2) + beta*tensor_3
      501              :       TYPE(dbt_type), INTENT(INOUT), TARGET      :: tensor_1
      502              :       TYPE(dbt_type), INTENT(INOUT), TARGET      :: tensor_2
      503              :       REAL(dp), INTENT(IN)            :: beta ! scalar prefactor for the existing data of tensor_3
      504              :       INTEGER, DIMENSION(:), INTENT(IN)              :: contract_1
      505              :       INTEGER, DIMENSION(:), INTENT(IN)              :: contract_2
      506              :       INTEGER, DIMENSION(:), INTENT(IN)              :: map_1
      507              :       INTEGER, DIMENSION(:), INTENT(IN)              :: map_2
      508              :       INTEGER, DIMENSION(:), INTENT(IN)              :: notcontract_1
      509              :       INTEGER, DIMENSION(:), INTENT(IN)              :: notcontract_2
      510              :       TYPE(dbt_type), INTENT(INOUT), TARGET      :: tensor_3
      511              :       INTEGER, DIMENSION(2, SIZE(contract_1)), &
      512              :          INTENT(IN), OPTIONAL                        :: bounds_1
      513              :       INTEGER, DIMENSION(2, SIZE(notcontract_1)), &
      514              :          INTENT(IN), OPTIONAL                        :: bounds_2
      515              :       INTEGER, DIMENSION(2, SIZE(notcontract_2)), &
      516              :          INTENT(IN), OPTIONAL                        :: bounds_3
      517              :       LOGICAL, INTENT(IN), OPTIONAL                  :: optimize_dist
      518              :       TYPE(dbt_pgrid_type), INTENT(OUT), &
      519              :          POINTER, OPTIONAL                           :: pgrid_opt_1
      520              :       TYPE(dbt_pgrid_type), INTENT(OUT), &
      521              :          POINTER, OPTIONAL                           :: pgrid_opt_2
      522              :       TYPE(dbt_pgrid_type), INTENT(OUT), &
      523              :          POINTER, OPTIONAL                           :: pgrid_opt_3
      524              :       REAL(KIND=dp), INTENT(IN), OPTIONAL        :: filter_eps
      525              :       INTEGER(KIND=int_8), INTENT(OUT), OPTIONAL     :: flop
      526              :       LOGICAL, INTENT(IN), OPTIONAL                  :: move_data
      527              :       LOGICAL, INTENT(IN), OPTIONAL                  :: retain_sparsity
      528              :       INTEGER, OPTIONAL, INTENT(IN)                  :: unit_nr
      529              :       LOGICAL, INTENT(IN), OPTIONAL                  :: log_verbose
      530              : 
      531              :       INTEGER                     :: handle ! timer handle for the "dbt_total" timing region
      532              : 
      533       168247 :       CALL tensor_1%pgrid%mp_comm_2d%sync() ! barrier so the dbt_total timer starts in sync on all ranks
      534       168247 :       CALL timeset("dbt_total", handle)
      535              :       CALL dbt_contract_expert(alpha, tensor_1, tensor_2, beta, tensor_3, &
      536              :                                contract_1, notcontract_1, &
      537              :                                contract_2, notcontract_2, &
      538              :                                map_1, map_2, &
      539              :                                bounds_1=bounds_1, &
      540              :                                bounds_2=bounds_2, &
      541              :                                bounds_3=bounds_3, &
      542              :                                optimize_dist=optimize_dist, &
      543              :                                pgrid_opt_1=pgrid_opt_1, &
      544              :                                pgrid_opt_2=pgrid_opt_2, &
      545              :                                pgrid_opt_3=pgrid_opt_3, &
      546              :                                filter_eps=filter_eps, &
      547              :                                flop=flop, &
      548              :                                move_data=move_data, &
      549              :                                retain_sparsity=retain_sparsity, &
      550              :                                unit_nr=unit_nr, &
      551       168247 :                                log_verbose=log_verbose)
      552       168247 :       CALL tensor_1%pgrid%mp_comm_2d%sync() ! barrier before stopping the timer so it captures collective completion
      553       168247 :       CALL timestop(handle)
      554              : 
      555       244245 :    END SUBROUTINE
     556              : 
     557              : ! **************************************************************************************************
     558              : !> \brief expert routine for tensor contraction. For internal use only.
     559              : !> \param nblks_local number of local blocks on this MPI rank
     560              : !> \author Patrick Seewald
     561              : ! **************************************************************************************************
     562       168247 :    SUBROUTINE dbt_contract_expert(alpha, tensor_1, tensor_2, beta, tensor_3, &
     563       168247 :                                   contract_1, notcontract_1, &
     564       168247 :                                   contract_2, notcontract_2, &
     565       168247 :                                   map_1, map_2, &
     566       168247 :                                   bounds_1, bounds_2, bounds_3, &
     567              :                                   optimize_dist, pgrid_opt_1, pgrid_opt_2, pgrid_opt_3, &
     568              :                                   filter_eps, flop, move_data, retain_sparsity, &
     569              :                                   nblks_local, unit_nr, log_verbose)
     570              :       REAL(dp), INTENT(IN)            :: alpha
     571              :       TYPE(dbt_type), INTENT(INOUT), TARGET      :: tensor_1
     572              :       TYPE(dbt_type), INTENT(INOUT), TARGET      :: tensor_2
     573              :       REAL(dp), INTENT(IN)            :: beta
     574              :       INTEGER, DIMENSION(:), INTENT(IN)              :: contract_1
     575              :       INTEGER, DIMENSION(:), INTENT(IN)              :: contract_2
     576              :       INTEGER, DIMENSION(:), INTENT(IN)              :: map_1
     577              :       INTEGER, DIMENSION(:), INTENT(IN)              :: map_2
     578              :       INTEGER, DIMENSION(:), INTENT(IN)              :: notcontract_1
     579              :       INTEGER, DIMENSION(:), INTENT(IN)              :: notcontract_2
     580              :       TYPE(dbt_type), INTENT(INOUT), TARGET      :: tensor_3
     581              :       INTEGER, DIMENSION(2, SIZE(contract_1)), &
     582              :          INTENT(IN), OPTIONAL                        :: bounds_1
     583              :       INTEGER, DIMENSION(2, SIZE(notcontract_1)), &
     584              :          INTENT(IN), OPTIONAL                        :: bounds_2
     585              :       INTEGER, DIMENSION(2, SIZE(notcontract_2)), &
     586              :          INTENT(IN), OPTIONAL                        :: bounds_3
     587              :       LOGICAL, INTENT(IN), OPTIONAL                  :: optimize_dist
     588              :       TYPE(dbt_pgrid_type), INTENT(OUT), &
     589              :          POINTER, OPTIONAL                           :: pgrid_opt_1
     590              :       TYPE(dbt_pgrid_type), INTENT(OUT), &
     591              :          POINTER, OPTIONAL                           :: pgrid_opt_2
     592              :       TYPE(dbt_pgrid_type), INTENT(OUT), &
     593              :          POINTER, OPTIONAL                           :: pgrid_opt_3
     594              :       REAL(KIND=dp), INTENT(IN), OPTIONAL        :: filter_eps
     595              :       INTEGER(KIND=int_8), INTENT(OUT), OPTIONAL     :: flop
     596              :       LOGICAL, INTENT(IN), OPTIONAL                  :: move_data
     597              :       LOGICAL, INTENT(IN), OPTIONAL                  :: retain_sparsity
     598              :       INTEGER, INTENT(OUT), OPTIONAL                 :: nblks_local
     599              :       INTEGER, OPTIONAL, INTENT(IN)                  :: unit_nr
     600              :       LOGICAL, INTENT(IN), OPTIONAL                  :: log_verbose
     601              : 
     602              :       TYPE(dbt_type), POINTER                    :: tensor_contr_1, tensor_contr_2, tensor_contr_3
     603      3196693 :       TYPE(dbt_type), TARGET                     :: tensor_algn_1, tensor_algn_2, tensor_algn_3
     604              :       TYPE(dbt_type), POINTER                    :: tensor_crop_1, tensor_crop_2
     605              :       TYPE(dbt_type), POINTER                    :: tensor_small, tensor_large
     606              : 
     607              :       LOGICAL                                        :: assert_stmt, tensors_remapped
     608              :       INTEGER                                        :: max_mm_dim, max_tensor, &
     609              :                                                         unit_nr_prv, ref_tensor, handle
     610       168247 :       TYPE(mp_cart_type) :: mp_comm_opt
     611       336494 :       INTEGER, DIMENSION(SIZE(contract_1))           :: contract_1_mod
     612       336494 :       INTEGER, DIMENSION(SIZE(notcontract_1))        :: notcontract_1_mod
     613       336494 :       INTEGER, DIMENSION(SIZE(contract_2))           :: contract_2_mod
     614       336494 :       INTEGER, DIMENSION(SIZE(notcontract_2))        :: notcontract_2_mod
     615       336494 :       INTEGER, DIMENSION(SIZE(map_1))                :: map_1_mod
     616       336494 :       INTEGER, DIMENSION(SIZE(map_2))                :: map_2_mod
     617              :       LOGICAL                                        :: trans_1, trans_2, trans_3
     618              :       LOGICAL                                        :: new_1, new_2, new_3, move_data_1, move_data_2
     619              :       INTEGER                                        :: ndims1, ndims2, ndims3
     620              :       INTEGER                                        :: occ_1, occ_2
     621       168247 :       INTEGER, DIMENSION(:), ALLOCATABLE             :: dims1, dims2, dims3
     622              : 
     623              :       CHARACTER(LEN=*), PARAMETER :: routineN = 'dbt_contract'
     624       168247 :       CHARACTER(LEN=1), DIMENSION(:), ALLOCATABLE    :: indchar1, indchar2, indchar3, indchar1_mod, &
     625       168247 :                                                         indchar2_mod, indchar3_mod
     626              :       CHARACTER(LEN=1), DIMENSION(15), SAVE :: alph = &
     627              :                                                ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o']
     628       336494 :       INTEGER, DIMENSION(2, ndims_tensor(tensor_1)) :: bounds_t1
     629       336494 :       INTEGER, DIMENSION(2, ndims_tensor(tensor_2)) :: bounds_t2
     630              :       LOGICAL                                        :: do_crop_1, do_crop_2, do_write_3, nodata_3, do_batched, pgrid_changed, &
     631              :                                                         pgrid_changed_any, do_change_pgrid(2)
     632      2187211 :       TYPE(dbt_tas_split_info)                     :: split_opt, split, split_opt_avg
     633              :       INTEGER, DIMENSION(2) :: pdims_2d_opt, pdims_sub, pdims_sub_opt
     634              :       REAL(dp) :: pdim_ratio, pdim_ratio_opt
     635              : 
     636       168247 :       NULLIFY (tensor_contr_1, tensor_contr_2, tensor_contr_3, tensor_crop_1, tensor_crop_2, &
     637       168247 :                tensor_small)
     638              : 
     639       168247 :       CALL timeset(routineN, handle)
     640              : 
     641       168247 :       CPASSERT(tensor_1%valid)
     642       168247 :       CPASSERT(tensor_2%valid)
     643       168247 :       CPASSERT(tensor_3%valid)
     644              : 
     645       168247 :       assert_stmt = SIZE(contract_1) .EQ. SIZE(contract_2)
     646       168247 :       CPASSERT(assert_stmt)
     647              : 
     648       168247 :       assert_stmt = SIZE(map_1) .EQ. SIZE(notcontract_1)
     649       168247 :       CPASSERT(assert_stmt)
     650              : 
     651       168247 :       assert_stmt = SIZE(map_2) .EQ. SIZE(notcontract_2)
     652       168247 :       CPASSERT(assert_stmt)
     653              : 
     654       168247 :       assert_stmt = SIZE(notcontract_1) + SIZE(contract_1) .EQ. ndims_tensor(tensor_1)
     655       168247 :       CPASSERT(assert_stmt)
     656              : 
     657       168247 :       assert_stmt = SIZE(notcontract_2) + SIZE(contract_2) .EQ. ndims_tensor(tensor_2)
     658       168247 :       CPASSERT(assert_stmt)
     659              : 
     660       168247 :       assert_stmt = SIZE(map_1) + SIZE(map_2) .EQ. ndims_tensor(tensor_3)
     661       168247 :       CPASSERT(assert_stmt)
     662              : 
     663       168247 :       unit_nr_prv = prep_output_unit(unit_nr)
     664              : 
     665       168247 :       IF (PRESENT(flop)) flop = 0
     666       168247 :       IF (PRESENT(nblks_local)) nblks_local = 0
     667              : 
     668       168247 :       IF (PRESENT(move_data)) THEN
     669        41295 :          move_data_1 = move_data
     670        41295 :          move_data_2 = move_data
     671              :       ELSE
     672       126952 :          move_data_1 = .FALSE.
     673       126952 :          move_data_2 = .FALSE.
     674              :       END IF
     675              : 
     676       168247 :       nodata_3 = .TRUE.
     677       168247 :       IF (PRESENT(retain_sparsity)) THEN
     678         4794 :          IF (retain_sparsity) nodata_3 = .FALSE.
     679              :       END IF
     680              : 
     681              :       CALL dbt_map_bounds_to_tensors(tensor_1, tensor_2, &
     682              :                                      contract_1, notcontract_1, &
     683              :                                      contract_2, notcontract_2, &
     684              :                                      bounds_t1, bounds_t2, &
     685              :                                      bounds_1=bounds_1, bounds_2=bounds_2, bounds_3=bounds_3, &
     686       168247 :                                      do_crop_1=do_crop_1, do_crop_2=do_crop_2)
     687              : 
     688       168247 :       IF (do_crop_1) THEN
     689       510006 :          ALLOCATE (tensor_crop_1)
     690        72858 :          CALL dbt_crop(tensor_1, tensor_crop_1, bounds_t1, move_data=move_data_1)
     691        72858 :          move_data_1 = .TRUE.
     692              :       ELSE
     693              :          tensor_crop_1 => tensor_1
     694              :       END IF
     695              : 
     696       168247 :       IF (do_crop_2) THEN
     697       491036 :          ALLOCATE (tensor_crop_2)
     698        70148 :          CALL dbt_crop(tensor_2, tensor_crop_2, bounds_t2, move_data=move_data_2)
     699        70148 :          move_data_2 = .TRUE.
     700              :       ELSE
     701              :          tensor_crop_2 => tensor_2
     702              :       END IF
     703              : 
     704              :       ! shortcut for empty tensors
     705              :       ! this is needed to avoid unnecessary work in case user contracts different portions of a
     706              :       ! tensor consecutively to save memory
     707              :       ASSOCIATE (mp_comm => tensor_crop_1%pgrid%mp_comm_2d)
     708       168247 :          occ_1 = dbt_get_num_blocks(tensor_crop_1)
     709       168247 :          CALL mp_comm%max(occ_1)
     710       168247 :          occ_2 = dbt_get_num_blocks(tensor_crop_2)
     711       168247 :          CALL mp_comm%max(occ_2)
     712              :       END ASSOCIATE
     713              : 
     714       168247 :       IF (occ_1 == 0 .OR. occ_2 == 0) THEN
     715        16431 :          CALL dbt_scale(tensor_3, beta)
     716        16431 :          IF (do_crop_1) THEN
     717         3075 :             CALL dbt_destroy(tensor_crop_1)
     718         3075 :             DEALLOCATE (tensor_crop_1)
     719              :          END IF
     720        16431 :          IF (do_crop_2) THEN
     721         3375 :             CALL dbt_destroy(tensor_crop_2)
     722         3375 :             DEALLOCATE (tensor_crop_2)
     723              :          END IF
     724              : 
     725        16431 :          CALL timestop(handle)
     726        16431 :          RETURN
     727              :       END IF
     728              : 
     729       151816 :       IF (unit_nr_prv /= 0) THEN
     730        52104 :          IF (unit_nr_prv > 0) THEN
     731           10 :             WRITE (unit_nr_prv, '(A)') repeat("-", 80)
     732           10 :             WRITE (unit_nr_prv, '(A,1X,A,1X,A,1X,A,1X,A,1X,A)') "DBT TENSOR CONTRACTION:", &
     733           20 :                TRIM(tensor_crop_1%name), 'x', TRIM(tensor_crop_2%name), '=', TRIM(tensor_3%name)
     734           10 :             WRITE (unit_nr_prv, '(A)') repeat("-", 80)
     735              :          END IF
     736        52104 :          CALL dbt_write_tensor_info(tensor_crop_1, unit_nr_prv, full_info=log_verbose)
     737        52104 :          CALL dbt_write_tensor_dist(tensor_crop_1, unit_nr_prv)
     738        52104 :          CALL dbt_write_tensor_info(tensor_crop_2, unit_nr_prv, full_info=log_verbose)
     739        52104 :          CALL dbt_write_tensor_dist(tensor_crop_2, unit_nr_prv)
     740              :       END IF
     741              : 
     742              :       ! align tensor index with data, tensor data is not modified
     743       151816 :       ndims1 = ndims_tensor(tensor_crop_1)
     744       151816 :       ndims2 = ndims_tensor(tensor_crop_2)
     745       151816 :       ndims3 = ndims_tensor(tensor_3)
     746       607264 :       ALLOCATE (indchar1(ndims1), indchar1_mod(ndims1))
     747       607264 :       ALLOCATE (indchar2(ndims2), indchar2_mod(ndims2))
     748       607264 :       ALLOCATE (indchar3(ndims3), indchar3_mod(ndims3))
     749              : 
     750              :       ! labeling tensor index with letters
     751              : 
     752      1295254 :       indchar1([notcontract_1, contract_1]) = alph(1:ndims1) ! arb. choice
     753       369186 :       indchar2(notcontract_2) = alph(ndims1 + 1:ndims1 + SIZE(notcontract_2)) ! arb. choice
     754       355229 :       indchar2(contract_2) = indchar1(contract_1)
     755       329549 :       indchar3(map_1) = indchar1(notcontract_1)
     756       369186 :       indchar3(map_2) = indchar2(notcontract_2)
     757              : 
     758       151816 :       IF (unit_nr_prv /= 0) CALL dbt_print_contraction_index(tensor_crop_1, indchar1, &
     759              :                                                              tensor_crop_2, indchar2, &
     760        52104 :                                                              tensor_3, indchar3, unit_nr_prv)
     761       151816 :       IF (unit_nr_prv > 0) THEN
     762           10 :          WRITE (unit_nr_prv, '(T2,A)') "aligning tensor index with data"
     763              :       END IF
     764              : 
     765              :       CALL align_tensor(tensor_crop_1, contract_1, notcontract_1, &
     766       151816 :                         tensor_algn_1, contract_1_mod, notcontract_1_mod, indchar1, indchar1_mod)
     767              : 
     768              :       CALL align_tensor(tensor_crop_2, contract_2, notcontract_2, &
     769       151816 :                         tensor_algn_2, contract_2_mod, notcontract_2_mod, indchar2, indchar2_mod)
     770              : 
     771              :       CALL align_tensor(tensor_3, map_1, map_2, &
     772       151816 :                         tensor_algn_3, map_1_mod, map_2_mod, indchar3, indchar3_mod)
     773              : 
     774       151816 :       IF (unit_nr_prv /= 0) CALL dbt_print_contraction_index(tensor_algn_1, indchar1_mod, &
     775              :                                                              tensor_algn_2, indchar2_mod, &
     776        52104 :                                                              tensor_algn_3, indchar3_mod, unit_nr_prv)
     777              : 
     778       455448 :       ALLOCATE (dims1(ndims1))
     779       455448 :       ALLOCATE (dims2(ndims2))
     780       455448 :       ALLOCATE (dims3(ndims3))
     781              : 
     782              :       ! ideally we should consider block sizes and occupancy to measure tensor sizes but current solution should work for most
     783              :       ! cases and is more elegant. Note that we can not easily consider occupancy since it is unknown for result tensor
     784       151816 :       CALL blk_dims_tensor(tensor_crop_1, dims1)
     785       151816 :       CALL blk_dims_tensor(tensor_crop_2, dims2)
     786       151816 :       CALL blk_dims_tensor(tensor_3, dims3)
     787              : 
     788              :       max_mm_dim = MAXLOC([PRODUCT(INT(dims1(notcontract_1), int_8)), &
     789              :                            PRODUCT(INT(dims1(contract_1), int_8)), &
     790      1205780 :                            PRODUCT(INT(dims2(notcontract_2), int_8))], DIM=1)
     791      1804296 :       max_tensor = MAXLOC([PRODUCT(INT(dims1, int_8)), PRODUCT(INT(dims2, int_8)), PRODUCT(INT(dims3, int_8))], DIM=1)
     792        34795 :       SELECT CASE (max_mm_dim)
     793              :       CASE (1)
     794        34795 :          IF (unit_nr_prv > 0) THEN
     795            3 :             WRITE (unit_nr_prv, '(T2,A)') "large tensors: 1, 3; small tensor: 2"
     796            3 :             WRITE (unit_nr_prv, '(T2,A)') "sorting contraction indices"
     797              :          END IF
     798        34795 :          CALL index_linked_sort(contract_1_mod, contract_2_mod)
     799        34795 :          CALL index_linked_sort(map_2_mod, notcontract_2_mod)
     800        34367 :          SELECT CASE (max_tensor)
     801              :          CASE (1)
     802        34367 :             CALL index_linked_sort(notcontract_1_mod, map_1_mod)
     803              :          CASE (3)
     804          428 :             CALL index_linked_sort(map_1_mod, notcontract_1_mod)
     805              :          CASE DEFAULT
     806        34795 :             CPABORT("should not happen")
     807              :          END SELECT
     808              : 
     809              :          CALL reshape_mm_compatible(tensor_algn_1, tensor_algn_3, tensor_contr_1, tensor_contr_3, &
     810              :                                     contract_1_mod, notcontract_1_mod, map_2_mod, map_1_mod, &
     811              :                                     trans_1, trans_3, new_1, new_3, ref_tensor, nodata2=nodata_3, optimize_dist=optimize_dist, &
     812        34795 :                                     move_data_1=move_data_1, unit_nr=unit_nr_prv)
     813              : 
     814              :          CALL reshape_mm_small(tensor_algn_2, contract_2_mod, notcontract_2_mod, tensor_contr_2, trans_2, &
     815        34795 :                                new_2, move_data=move_data_2, unit_nr=unit_nr_prv)
     816              : 
     817        34367 :          SELECT CASE (ref_tensor)
     818              :          CASE (1)
     819        34367 :             tensor_large => tensor_contr_1
     820              :          CASE (2)
     821        34795 :             tensor_large => tensor_contr_3
     822              :          END SELECT
     823        34795 :          tensor_small => tensor_contr_2
     824              : 
     825              :       CASE (2)
     826        52217 :          IF (unit_nr_prv > 0) THEN
     827            5 :             WRITE (unit_nr_prv, '(T2,A)') "large tensors: 1, 2; small tensor: 3"
     828            5 :             WRITE (unit_nr_prv, '(T2,A)') "sorting contraction indices"
     829              :          END IF
     830              : 
     831        52217 :          CALL index_linked_sort(notcontract_1_mod, map_1_mod)
     832        52217 :          CALL index_linked_sort(notcontract_2_mod, map_2_mod)
     833        51217 :          SELECT CASE (max_tensor)
     834              :          CASE (1)
     835        51217 :             CALL index_linked_sort(contract_1_mod, contract_2_mod)
     836              :          CASE (2)
     837         1000 :             CALL index_linked_sort(contract_2_mod, contract_1_mod)
     838              :          CASE DEFAULT
     839        52217 :             CPABORT("should not happen")
     840              :          END SELECT
     841              : 
     842              :          CALL reshape_mm_compatible(tensor_algn_1, tensor_algn_2, tensor_contr_1, tensor_contr_2, &
     843              :                                     notcontract_1_mod, contract_1_mod, notcontract_2_mod, contract_2_mod, &
     844              :                                     trans_1, trans_2, new_1, new_2, ref_tensor, optimize_dist=optimize_dist, &
     845        52217 :                                     move_data_1=move_data_1, move_data_2=move_data_2, unit_nr=unit_nr_prv)
     846        52217 :          trans_1 = .NOT. trans_1
     847              : 
     848              :          CALL reshape_mm_small(tensor_algn_3, map_1_mod, map_2_mod, tensor_contr_3, trans_3, &
     849        52217 :                                new_3, nodata=nodata_3, unit_nr=unit_nr_prv)
     850              : 
     851        51217 :          SELECT CASE (ref_tensor)
     852              :          CASE (1)
     853        51217 :             tensor_large => tensor_contr_1
     854              :          CASE (2)
     855        52217 :             tensor_large => tensor_contr_2
     856              :          END SELECT
     857        52217 :          tensor_small => tensor_contr_3
     858              : 
     859              :       CASE (3)
     860        64804 :          IF (unit_nr_prv > 0) THEN
     861            2 :             WRITE (unit_nr_prv, '(T2,A)') "large tensors: 2, 3; small tensor: 1"
     862            2 :             WRITE (unit_nr_prv, '(T2,A)') "sorting contraction indices"
     863              :          END IF
     864        64804 :          CALL index_linked_sort(map_1_mod, notcontract_1_mod)
     865        64804 :          CALL index_linked_sort(contract_2_mod, contract_1_mod)
     866        64420 :          SELECT CASE (max_tensor)
     867              :          CASE (2)
     868        64420 :             CALL index_linked_sort(notcontract_2_mod, map_2_mod)
     869              :          CASE (3)
     870          384 :             CALL index_linked_sort(map_2_mod, notcontract_2_mod)
     871              :          CASE DEFAULT
     872        64804 :             CPABORT("should not happen")
     873              :          END SELECT
     874              : 
     875              :          CALL reshape_mm_compatible(tensor_algn_2, tensor_algn_3, tensor_contr_2, tensor_contr_3, &
     876              :                                     contract_2_mod, notcontract_2_mod, map_1_mod, map_2_mod, &
     877              :                                     trans_2, trans_3, new_2, new_3, ref_tensor, nodata2=nodata_3, optimize_dist=optimize_dist, &
     878        64804 :                                     move_data_1=move_data_2, unit_nr=unit_nr_prv)
     879              : 
     880        64804 :          trans_2 = .NOT. trans_2
     881        64804 :          trans_3 = .NOT. trans_3
     882              : 
     883              :          CALL reshape_mm_small(tensor_algn_1, notcontract_1_mod, contract_1_mod, tensor_contr_1, &
     884        64804 :                                trans_1, new_1, move_data=move_data_1, unit_nr=unit_nr_prv)
     885              : 
     886        64420 :          SELECT CASE (ref_tensor)
     887              :          CASE (1)
     888        64420 :             tensor_large => tensor_contr_2
     889              :          CASE (2)
     890        64804 :             tensor_large => tensor_contr_3
     891              :          END SELECT
     892       216620 :          tensor_small => tensor_contr_1
     893              : 
     894              :       END SELECT
     895              : 
     896       151816 :       IF (unit_nr_prv /= 0) CALL dbt_print_contraction_index(tensor_contr_1, indchar1_mod, &
     897              :                                                              tensor_contr_2, indchar2_mod, &
     898        52104 :                                                              tensor_contr_3, indchar3_mod, unit_nr_prv)
     899       151816 :       IF (unit_nr_prv /= 0) THEN
     900        52104 :          IF (new_1) CALL dbt_write_tensor_info(tensor_contr_1, unit_nr_prv, full_info=log_verbose)
     901        52104 :          IF (new_1) CALL dbt_write_tensor_dist(tensor_contr_1, unit_nr_prv)
     902        52104 :          IF (new_2) CALL dbt_write_tensor_info(tensor_contr_2, unit_nr_prv, full_info=log_verbose)
     903        52104 :          IF (new_2) CALL dbt_write_tensor_dist(tensor_contr_2, unit_nr_prv)
     904              :       END IF
     905              : 
     906              :       CALL dbt_tas_multiply(trans_1, trans_2, trans_3, alpha, &
     907              :                             tensor_contr_1%matrix_rep, tensor_contr_2%matrix_rep, &
     908              :                             beta, &
     909              :                             tensor_contr_3%matrix_rep, filter_eps=filter_eps, flop=flop, &
     910              :                             unit_nr=unit_nr_prv, log_verbose=log_verbose, &
     911              :                             split_opt=split_opt, &
     912       151816 :                             move_data_a=move_data_1, move_data_b=move_data_2, retain_sparsity=retain_sparsity)
     913              : 
     914       151816 :       IF (PRESENT(pgrid_opt_1)) THEN
     915            0 :          IF (.NOT. new_1) THEN
     916            0 :             ALLOCATE (pgrid_opt_1)
     917            0 :             pgrid_opt_1 = opt_pgrid(tensor_1, split_opt)
     918              :          END IF
     919              :       END IF
     920              : 
     921       151816 :       IF (PRESENT(pgrid_opt_2)) THEN
     922            0 :          IF (.NOT. new_2) THEN
     923            0 :             ALLOCATE (pgrid_opt_2)
     924            0 :             pgrid_opt_2 = opt_pgrid(tensor_2, split_opt)
     925              :          END IF
     926              :       END IF
     927              : 
     928       151816 :       IF (PRESENT(pgrid_opt_3)) THEN
     929            0 :          IF (.NOT. new_3) THEN
     930            0 :             ALLOCATE (pgrid_opt_3)
     931            0 :             pgrid_opt_3 = opt_pgrid(tensor_3, split_opt)
     932              :          END IF
     933              :       END IF
     934              : 
     935       151816 :       do_batched = tensor_small%matrix_rep%do_batched > 0
     936              : 
     937       151816 :       tensors_remapped = .FALSE.
     938       151816 :       IF (new_1 .OR. new_2 .OR. new_3) tensors_remapped = .TRUE.
     939              : 
     940       151816 :       IF (tensors_remapped .AND. do_batched) THEN
     941              :          CALL cp_warn(__LOCATION__, &
     942            0 :                       "Internal process grid optimization disabled because tensors are not in contraction-compatible format")
     943              :       END IF
     944              : 
     945              :       ! optimize process grid during batched contraction
     946       151816 :       do_change_pgrid(:) = .FALSE.
     947       151816 :       IF ((.NOT. tensors_remapped) .AND. do_batched) THEN
     948              :          ASSOCIATE (storage => tensor_small%contraction_storage)
     949            0 :             CPASSERT(storage%static)
     950        93704 :             split = dbt_tas_info(tensor_large%matrix_rep)
     951              :             do_change_pgrid(:) = &
     952        93704 :                update_contraction_storage(storage, split_opt, split)
     953              : 
     954       373576 :             IF (ANY(do_change_pgrid)) THEN
     955          620 :                mp_comm_opt = dbt_tas_mp_comm(tensor_small%pgrid%mp_comm_2d, split_opt%split_rowcol, NINT(storage%nsplit_avg))
     956              :                CALL dbt_tas_create_split(split_opt_avg, mp_comm_opt, split_opt%split_rowcol, &
     957          620 :                                          NINT(storage%nsplit_avg), own_comm=.TRUE.)
     958         1860 :                pdims_2d_opt = split_opt_avg%mp_comm%num_pe_cart
     959              :             END IF
     960              : 
     961              :          END ASSOCIATE
     962              : 
     963        93704 :          IF (do_change_pgrid(1) .AND. .NOT. do_change_pgrid(2)) THEN
     964              :             ! check if new grid has better subgrid, if not there is no need to change process grid
     965         1860 :             pdims_sub_opt = split_opt_avg%mp_comm_group%num_pe_cart
     966         1860 :             pdims_sub = split%mp_comm_group%num_pe_cart
     967              : 
     968         3100 :             pdim_ratio = MAXVAL(REAL(pdims_sub, dp))/MINVAL(pdims_sub)
     969         3100 :             pdim_ratio_opt = MAXVAL(REAL(pdims_sub_opt, dp))/MINVAL(pdims_sub_opt)
     970          620 :             IF (pdim_ratio/pdim_ratio_opt <= default_pdims_accept_ratio**2) THEN
     971            0 :                do_change_pgrid(1) = .FALSE.
     972            0 :                CALL dbt_tas_release_info(split_opt_avg)
     973              :             END IF
     974              :          END IF
     975              :       END IF
     976              : 
     977       151816 :       IF (unit_nr_prv /= 0) THEN
     978        52104 :          do_write_3 = .TRUE.
     979        52104 :          IF (tensor_contr_3%matrix_rep%do_batched > 0) THEN
     980        22680 :             IF (tensor_contr_3%matrix_rep%mm_storage%batched_out) do_write_3 = .FALSE.
     981              :          END IF
     982              :          IF (do_write_3) THEN
     983        29462 :             CALL dbt_write_tensor_info(tensor_contr_3, unit_nr_prv, full_info=log_verbose)
     984        29462 :             CALL dbt_write_tensor_dist(tensor_contr_3, unit_nr_prv)
     985              :          END IF
     986              :       END IF
     987              : 
     988       151816 :       IF (new_3) THEN
     989              :          ! need redistribute if we created new tensor for tensor 3
     990        11144 :          CALL dbt_scale(tensor_algn_3, beta)
     991        11144 :          CALL dbt_copy_expert(tensor_contr_3, tensor_algn_3, summation=.TRUE., move_data=.TRUE.)
     992        11144 :          IF (PRESENT(filter_eps)) CALL dbt_filter(tensor_algn_3, filter_eps)
     993              :          ! tensor_3 automatically has correct data because tensor_algn_3 contains a matrix
     994              :          ! pointer to data of tensor_3
     995              :       END IF
     996              : 
     997              :       ! transfer contraction storage
     998       151816 :       CALL dbt_copy_contraction_storage(tensor_contr_1, tensor_1)
     999       151816 :       CALL dbt_copy_contraction_storage(tensor_contr_2, tensor_2)
    1000       151816 :       CALL dbt_copy_contraction_storage(tensor_contr_3, tensor_3)
    1001              : 
    1002       151816 :       IF (unit_nr_prv /= 0) THEN
    1003        52104 :          IF (new_3 .AND. do_write_3) CALL dbt_write_tensor_info(tensor_3, unit_nr_prv, full_info=log_verbose)
    1004        52104 :          IF (new_3 .AND. do_write_3) CALL dbt_write_tensor_dist(tensor_3, unit_nr_prv)
    1005              :       END IF
    1006              : 
    1007       151816 :       CALL dbt_destroy(tensor_algn_1)
    1008       151816 :       CALL dbt_destroy(tensor_algn_2)
    1009       151816 :       CALL dbt_destroy(tensor_algn_3)
    1010              : 
    1011       151816 :       IF (do_crop_1) THEN
    1012        69783 :          CALL dbt_destroy(tensor_crop_1)
    1013        69783 :          DEALLOCATE (tensor_crop_1)
    1014              :       END IF
    1015              : 
    1016       151816 :       IF (do_crop_2) THEN
    1017        66773 :          CALL dbt_destroy(tensor_crop_2)
    1018        66773 :          DEALLOCATE (tensor_crop_2)
    1019              :       END IF
    1020              : 
    1021       151816 :       IF (new_1) THEN
    1022        11214 :          CALL dbt_destroy(tensor_contr_1)
    1023        11214 :          DEALLOCATE (tensor_contr_1)
    1024              :       END IF
    1025       151816 :       IF (new_2) THEN
    1026         3333 :          CALL dbt_destroy(tensor_contr_2)
    1027         3333 :          DEALLOCATE (tensor_contr_2)
    1028              :       END IF
    1029       151816 :       IF (new_3) THEN
    1030        11144 :          CALL dbt_destroy(tensor_contr_3)
    1031        11144 :          DEALLOCATE (tensor_contr_3)
    1032              :       END IF
    1033              : 
    1034       151816 :       IF (PRESENT(move_data)) THEN
    1035        38325 :          IF (move_data) THEN
    1036        32893 :             CALL dbt_clear(tensor_1)
    1037        32893 :             CALL dbt_clear(tensor_2)
    1038              :          END IF
    1039              :       END IF
    1040              : 
    1041       151816 :       IF (unit_nr_prv > 0) THEN
    1042           10 :          WRITE (unit_nr_prv, '(A)') repeat("-", 80)
    1043           10 :          WRITE (unit_nr_prv, '(A)') "TENSOR CONTRACTION DONE"
    1044           10 :          WRITE (unit_nr_prv, '(A)') repeat("-", 80)
    1045              :       END IF
    1046              : 
    1047       454208 :       IF (ANY(do_change_pgrid)) THEN
    1048          620 :          pgrid_changed_any = .FALSE.
    1049            0 :          SELECT CASE (max_mm_dim)
    1050              :          CASE (1)
    1051            0 :             IF (ALLOCATED(tensor_1%contraction_storage) .AND. ALLOCATED(tensor_3%contraction_storage)) THEN
    1052              :                CALL dbt_change_pgrid_2d(tensor_1, tensor_1%pgrid%mp_comm_2d, pdims=pdims_2d_opt, &
    1053              :                                         nsplit=split_opt_avg%ngroup, dimsplit=split_opt_avg%split_rowcol, &
    1054              :                                         pgrid_changed=pgrid_changed, &
    1055            0 :                                         unit_nr=unit_nr_prv)
    1056            0 :                IF (pgrid_changed) pgrid_changed_any = .TRUE.
    1057              :                CALL dbt_change_pgrid_2d(tensor_3, tensor_3%pgrid%mp_comm_2d, pdims=pdims_2d_opt, &
    1058              :                                         nsplit=split_opt_avg%ngroup, dimsplit=split_opt_avg%split_rowcol, &
    1059              :                                         pgrid_changed=pgrid_changed, &
    1060            0 :                                         unit_nr=unit_nr_prv)
    1061            0 :                IF (pgrid_changed) pgrid_changed_any = .TRUE.
    1062              :             END IF
    1063            0 :             IF (pgrid_changed_any) THEN
    1064            0 :                IF (tensor_2%matrix_rep%do_batched == 3) THEN
    1065              :                   ! set flag that process grid has been optimized to make sure that no grid optimizations are done
    1066              :                   ! in TAS multiply algorithm
    1067            0 :                   CALL dbt_tas_batched_mm_complete(tensor_2%matrix_rep)
    1068              :                END IF
    1069              :             END IF
    1070              :          CASE (2)
    1071          172 :             IF (ALLOCATED(tensor_1%contraction_storage) .AND. ALLOCATED(tensor_2%contraction_storage)) THEN
    1072              :                CALL dbt_change_pgrid_2d(tensor_1, tensor_1%pgrid%mp_comm_2d, pdims=pdims_2d_opt, &
    1073              :                                         nsplit=split_opt_avg%ngroup, dimsplit=split_opt_avg%split_rowcol, &
    1074              :                                         pgrid_changed=pgrid_changed, &
    1075          172 :                                         unit_nr=unit_nr_prv)
    1076          172 :                IF (pgrid_changed) pgrid_changed_any = .TRUE.
    1077              :                CALL dbt_change_pgrid_2d(tensor_2, tensor_2%pgrid%mp_comm_2d, pdims=pdims_2d_opt, &
    1078              :                                         nsplit=split_opt_avg%ngroup, dimsplit=split_opt_avg%split_rowcol, &
    1079              :                                         pgrid_changed=pgrid_changed, &
    1080          172 :                                         unit_nr=unit_nr_prv)
    1081          172 :                IF (pgrid_changed) pgrid_changed_any = .TRUE.
    1082              :             END IF
    1083            8 :             IF (pgrid_changed_any) THEN
    1084          172 :                IF (tensor_3%matrix_rep%do_batched == 3) THEN
    1085          160 :                   CALL dbt_tas_batched_mm_complete(tensor_3%matrix_rep)
    1086              :                END IF
    1087              :             END IF
    1088              :          CASE (3)
    1089          448 :             IF (ALLOCATED(tensor_2%contraction_storage) .AND. ALLOCATED(tensor_3%contraction_storage)) THEN
    1090              :                CALL dbt_change_pgrid_2d(tensor_2, tensor_2%pgrid%mp_comm_2d, pdims=pdims_2d_opt, &
    1091              :                                         nsplit=split_opt_avg%ngroup, dimsplit=split_opt_avg%split_rowcol, &
    1092              :                                         pgrid_changed=pgrid_changed, &
    1093          218 :                                         unit_nr=unit_nr_prv)
    1094          218 :                IF (pgrid_changed) pgrid_changed_any = .TRUE.
    1095              :                CALL dbt_change_pgrid_2d(tensor_3, tensor_3%pgrid%mp_comm_2d, pdims=pdims_2d_opt, &
    1096              :                                         nsplit=split_opt_avg%ngroup, dimsplit=split_opt_avg%split_rowcol, &
    1097              :                                         pgrid_changed=pgrid_changed, &
    1098          218 :                                         unit_nr=unit_nr_prv)
    1099          218 :                IF (pgrid_changed) pgrid_changed_any = .TRUE.
    1100              :             END IF
    1101          620 :             IF (pgrid_changed_any) THEN
    1102          218 :                IF (tensor_1%matrix_rep%do_batched == 3) THEN
    1103          218 :                   CALL dbt_tas_batched_mm_complete(tensor_1%matrix_rep)
    1104              :                END IF
    1105              :             END IF
    1106              :          END SELECT
    1107          620 :          CALL dbt_tas_release_info(split_opt_avg)
    1108              :       END IF
    1109              : 
    1110       151816 :       IF ((.NOT. tensors_remapped) .AND. do_batched) THEN
    1111              :          ! freeze TAS process grids if tensor grids were optimized
    1112        93704 :          CALL dbt_tas_set_batched_state(tensor_1%matrix_rep, opt_grid=.TRUE.)
    1113        93704 :          CALL dbt_tas_set_batched_state(tensor_2%matrix_rep, opt_grid=.TRUE.)
    1114        93704 :          CALL dbt_tas_set_batched_state(tensor_3%matrix_rep, opt_grid=.TRUE.)
    1115              :       END IF
    1116              : 
    1117       151816 :       CALL dbt_tas_release_info(split_opt)
    1118              : 
    1119       151816 :       CALL timestop(handle)
    1120              : 
    1121       488310 :    END SUBROUTINE
    1122              : 
    1123              : ! **************************************************************************************************
    1124              : !> \brief align tensor index with data
    1125              : !> \author Patrick Seewald
    1126              : ! **************************************************************************************************
    1127      4099032 :    SUBROUTINE align_tensor(tensor_in, contract_in, notcontract_in, &
    1128       455448 :                            tensor_out, contract_out, notcontract_out, indp_in, indp_out)
    1129              :       TYPE(dbt_type), INTENT(INOUT)               :: tensor_in
    1130              :       INTEGER, DIMENSION(:), INTENT(IN)            :: contract_in, notcontract_in
    1131              :       TYPE(dbt_type), INTENT(OUT)              :: tensor_out
    1132              :       INTEGER, DIMENSION(SIZE(contract_in)), &
    1133              :          INTENT(OUT)                               :: contract_out
    1134              :       INTEGER, DIMENSION(SIZE(notcontract_in)), &
    1135              :          INTENT(OUT)                               :: notcontract_out
    1136              :       CHARACTER(LEN=1), DIMENSION(ndims_tensor(tensor_in)), INTENT(IN) :: indp_in
    1137              :       CHARACTER(LEN=1), DIMENSION(ndims_tensor(tensor_in)), INTENT(OUT) :: indp_out
    1138       455448 :       INTEGER, DIMENSION(ndims_tensor(tensor_in)) :: align
    1139              : 
    1140       455448 :       CALL dbt_align_index(tensor_in, tensor_out, order=align)
    1141      1040007 :       contract_out = align(contract_in)
    1142      1067921 :       notcontract_out = align(notcontract_in)
    1143      1652480 :       indp_out(align) = indp_in
    1144              : 
    1145       455448 :    END SUBROUTINE
    1146              : 
    1147              : ! **************************************************************************************************
    1148              : !> \brief Prepare tensor for contraction: redistribute to a 2d format which can be contracted by
    1149              : !>        matrix multiplication. This routine reshapes the two largest of the three tensors.
    1150              : !>        Redistribution is avoided if tensors already in a consistent layout.
    1151              : !> \param ind1_free indices of tensor 1 that are "free" (not linked to any index of tensor 2)
!> \param ind1_linked indices of tensor 1 that are linked to indices of tensor 2
!> \param ind2_free indices of tensor 2 that are "free" (not linked to any index of tensor 1)
!> \param ind2_linked indices of tensor 2 that are linked to indices of tensor 1,
!>                    1:1 correspondence with ind1_linked
    1154              : !> \param trans1 transpose flag of matrix rep. tensor 1
    1155              : !> \param trans2 transpose flag of matrix rep. tensor 2
    1156              : !> \param new1 whether a new tensor 1 was created
    1157              : !> \param new2 whether a new tensor 2 was created
    1158              : !> \param nodata1 don't copy data of tensor 1
    1159              : !> \param nodata2 don't copy data of tensor 2
    1160              : !> \param move_data_1 memory optimization: transfer data s.t. tensor1 may be empty on return
    1161              : !> \param move_data_2 memory optimization: transfer data s.t. tensor2 may be empty on return
    1162              : !> \param optimize_dist experimental: optimize distribution
    1163              : !> \param unit_nr output unit
    1164              : !> \author Patrick Seewald
    1165              : ! **************************************************************************************************
    1166       151816 :    SUBROUTINE reshape_mm_compatible(tensor1, tensor2, tensor1_out, tensor2_out, ind1_free, ind1_linked, &
    1167       151816 :                                     ind2_free, ind2_linked, trans1, trans2, new1, new2, ref_tensor, &
    1168              :                                     nodata1, nodata2, move_data_1, &
    1169              :                                     move_data_2, optimize_dist, unit_nr)
    1170              :       TYPE(dbt_type), TARGET, INTENT(INOUT)   :: tensor1
    1171              :       TYPE(dbt_type), TARGET, INTENT(INOUT)   :: tensor2
    1172              :       TYPE(dbt_type), POINTER, INTENT(OUT)    :: tensor1_out, tensor2_out
    1173              :       INTEGER, DIMENSION(:), INTENT(IN)           :: ind1_free, ind2_free
    1174              :       INTEGER, DIMENSION(:), INTENT(IN)           :: ind1_linked, ind2_linked
    1175              :       LOGICAL, INTENT(OUT)                        :: trans1, trans2
    1176              :       LOGICAL, INTENT(OUT)                        :: new1, new2
    1177              :       INTEGER, INTENT(OUT) :: ref_tensor
    1178              :       LOGICAL, INTENT(IN), OPTIONAL               :: nodata1, nodata2
    1179              :       LOGICAL, INTENT(INOUT), OPTIONAL            :: move_data_1, move_data_2
    1180              :       LOGICAL, INTENT(IN), OPTIONAL               :: optimize_dist
    1181              :       INTEGER, INTENT(IN), OPTIONAL               :: unit_nr
    1182              :       INTEGER                                     :: compat1, compat1_old, compat2, compat2_old, &
    1183              :                                                      unit_nr_prv
    1184       151816 :       TYPE(mp_cart_type)                          :: comm_2d
    1185       151816 :       TYPE(array_list)                            :: dist_list
    1186       151816 :       INTEGER, DIMENSION(:), ALLOCATABLE          :: mp_dims
    1187      1062712 :       TYPE(dbt_distribution_type)             :: dist_in
    1188              :       INTEGER(KIND=int_8)                         :: nblkrows, nblkcols
    1189              :       LOGICAL                                     :: optimize_dist_prv
    1190       303632 :       INTEGER, DIMENSION(ndims_tensor(tensor1)) :: dims1
    1191       151816 :       INTEGER, DIMENSION(ndims_tensor(tensor2)) :: dims2
    1192              : 
    1193       151816 :       NULLIFY (tensor1_out, tensor2_out)
    1194              : 
    1195       151816 :       unit_nr_prv = prep_output_unit(unit_nr)
    1196              : 
    1197       151816 :       CALL blk_dims_tensor(tensor1, dims1)
    1198       151816 :       CALL blk_dims_tensor(tensor2, dims2)
    1199              : 
    1200      1043254 :       IF (PRODUCT(int(dims1, int_8)) .GE. PRODUCT(int(dims2, int_8))) THEN
    1201       150004 :          ref_tensor = 1
    1202              :       ELSE
    1203         1812 :          ref_tensor = 2
    1204              :       END IF
    1205              : 
    1206       151816 :       IF (PRESENT(optimize_dist)) THEN
    1207          298 :          optimize_dist_prv = optimize_dist
    1208              :       ELSE
    1209              :          optimize_dist_prv = .FALSE.
    1210              :       END IF
    1211              : 
    1212       151816 :       compat1 = compat_map(tensor1%nd_index, ind1_linked)
    1213       151816 :       compat2 = compat_map(tensor2%nd_index, ind2_linked)
    1214       151816 :       compat1_old = compat1
    1215       151816 :       compat2_old = compat2
    1216              : 
    1217       151816 :       IF (unit_nr_prv > 0) THEN
    1218           10 :          WRITE (unit_nr_prv, '(T2,A,1X,A,A,1X)', advance='no') "compatibility of", TRIM(tensor1%name), ":"
    1219            6 :          SELECT CASE (compat1)
    1220              :          CASE (0)
    1221            6 :             WRITE (unit_nr_prv, '(A)') "Not compatible"
    1222              :          CASE (1)
    1223            3 :             WRITE (unit_nr_prv, '(A)') "Normal"
    1224              :          CASE (2)
    1225           10 :             WRITE (unit_nr_prv, '(A)') "Transposed"
    1226              :          END SELECT
    1227           10 :          WRITE (unit_nr_prv, '(T2,A,1X,A,A,1X)', advance='no') "compatibility of", TRIM(tensor2%name), ":"
    1228            5 :          SELECT CASE (compat2)
    1229              :          CASE (0)
    1230            5 :             WRITE (unit_nr_prv, '(A)') "Not compatible"
    1231              :          CASE (1)
    1232            4 :             WRITE (unit_nr_prv, '(A)') "Normal"
    1233              :          CASE (2)
    1234           10 :             WRITE (unit_nr_prv, '(A)') "Transposed"
    1235              :          END SELECT
    1236              :       END IF
    1237              : 
    1238       151816 :       new1 = .FALSE.
    1239       151816 :       new2 = .FALSE.
    1240              : 
    1241       151816 :       IF (compat1 == 0 .OR. optimize_dist_prv) THEN
    1242        14403 :          new1 = .TRUE.
    1243              :       END IF
    1244              : 
    1245       151816 :       IF (compat2 == 0 .OR. optimize_dist_prv) THEN
    1246        11266 :          new2 = .TRUE.
    1247              :       END IF
    1248              : 
    1249       151816 :       IF (ref_tensor == 1) THEN ! tensor 1 is reference and tensor 2 is reshaped compatible with tensor 1
    1250       150004 :          IF (compat1 == 0 .OR. optimize_dist_prv) THEN ! tensor 1 is not contraction compatible --> reshape
    1251        14267 :             IF (unit_nr_prv > 0) WRITE (unit_nr_prv, '(T2,A,1X,A)') "Redistribution of", TRIM(tensor1%name)
    1252        42801 :             nblkrows = PRODUCT(INT(dims1(ind1_linked), KIND=int_8))
    1253        28536 :             nblkcols = PRODUCT(INT(dims1(ind1_free), KIND=int_8))
    1254        14267 :             comm_2d = dbt_tas_mp_comm(tensor1%pgrid%mp_comm_2d, nblkrows, nblkcols)
    1255        99869 :             ALLOCATE (tensor1_out)
    1256              :             CALL dbt_remap(tensor1, ind1_linked, ind1_free, tensor1_out, comm_2d=comm_2d, &
    1257        14267 :                            nodata=nodata1, move_data=move_data_1)
    1258        14267 :             CALL comm_2d%free()
    1259        14267 :             compat1 = 1
    1260              :          ELSE
    1261       135737 :             IF (unit_nr_prv > 0) WRITE (unit_nr_prv, '(T2,A,1X,A)') "No redistribution of", TRIM(tensor1%name)
    1262       135737 :             tensor1_out => tensor1
    1263              :          END IF
    1264       150004 :          IF (compat2 == 0 .OR. optimize_dist_prv) THEN ! tensor 2 is not contraction compatible --> reshape
    1265        11130 :             IF (unit_nr_prv > 0) WRITE (unit_nr_prv, '(T2,A,1X,A,1X,A,1X,A)') "Redistribution of", &
    1266            8 :                TRIM(tensor2%name), "compatible with", TRIM(tensor1%name)
    1267        11126 :             dist_in = dbt_distribution(tensor1_out)
    1268        11126 :             dist_list = array_sublist(dist_in%nd_dist, ind1_linked)
    1269        11126 :             IF (compat1 == 1) THEN ! linked index is first 2d dimension
    1270              :                ! get distribution of linked index, tensor 2 must adopt this distribution
    1271              :                ! get grid dimensions of linked index
    1272        17004 :                ALLOCATE (mp_dims(ndims_mapping_row(dist_in%pgrid%nd_index_grid)))
    1273         5668 :                CALL dbt_get_mapping_info(dist_in%pgrid%nd_index_grid, dims1_2d=mp_dims)
    1274        39676 :                ALLOCATE (tensor2_out)
    1275              :                CALL dbt_remap(tensor2, ind2_linked, ind2_free, tensor2_out, comm_2d=dist_in%pgrid%mp_comm_2d, &
    1276         5668 :                               dist1=dist_list, mp_dims_1=mp_dims, nodata=nodata2, move_data=move_data_2)
    1277         5458 :             ELSEIF (compat1 == 2) THEN ! linked index is second 2d dimension
    1278              :                ! get distribution of linked index, tensor 2 must adopt this distribution
    1279              :                ! get grid dimensions of linked index
    1280        16374 :                ALLOCATE (mp_dims(ndims_mapping_column(dist_in%pgrid%nd_index_grid)))
    1281         5458 :                CALL dbt_get_mapping_info(dist_in%pgrid%nd_index_grid, dims2_2d=mp_dims)
    1282        38206 :                ALLOCATE (tensor2_out)
    1283              :                CALL dbt_remap(tensor2, ind2_free, ind2_linked, tensor2_out, comm_2d=dist_in%pgrid%mp_comm_2d, &
    1284         5458 :                               dist2=dist_list, mp_dims_2=mp_dims, nodata=nodata2, move_data=move_data_2)
    1285              :             ELSE
    1286            0 :                CPABORT("should not happen")
    1287              :             END IF
    1288        11126 :             compat2 = compat1
    1289              :          ELSE
    1290       138878 :             IF (unit_nr_prv > 0) WRITE (unit_nr_prv, '(T2,A,1X,A)') "No redistribution of", TRIM(tensor2%name)
    1291       138878 :             tensor2_out => tensor2
    1292              :          END IF
    1293              :       ELSE ! tensor 2 is reference and tensor 1 is reshaped compatible with tensor 2
    1294         1812 :          IF (compat2 == 0 .OR. optimize_dist_prv) THEN ! tensor 2 is not contraction compatible --> reshape
    1295          140 :             IF (unit_nr_prv > 0) WRITE (unit_nr_prv, '(T2,A,1X,A)') "Redistribution of", TRIM(tensor2%name)
    1296          290 :             nblkrows = PRODUCT(INT(dims2(ind2_linked), KIND=int_8))
    1297          282 :             nblkcols = PRODUCT(INT(dims2(ind2_free), KIND=int_8))
    1298          140 :             comm_2d = dbt_tas_mp_comm(tensor2%pgrid%mp_comm_2d, nblkrows, nblkcols)
    1299          980 :             ALLOCATE (tensor2_out)
    1300          140 :             CALL dbt_remap(tensor2, ind2_linked, ind2_free, tensor2_out, nodata=nodata2, move_data=move_data_2)
    1301          140 :             CALL comm_2d%free()
    1302          140 :             compat2 = 1
    1303              :          ELSE
    1304         1672 :             IF (unit_nr_prv > 0) WRITE (unit_nr_prv, '(T2,A,1X,A)') "No redistribution of", TRIM(tensor2%name)
    1305         1672 :             tensor2_out => tensor2
    1306              :          END IF
    1307         1812 :          IF (compat1 == 0 .OR. optimize_dist_prv) THEN ! tensor 1 is not contraction compatible --> reshape
    1308          139 :             IF (unit_nr_prv > 0) WRITE (unit_nr_prv, '(T2,A,1X,A,1X,A,1X,A)') "Redistribution of", TRIM(tensor1%name), &
    1309            6 :                "compatible with", TRIM(tensor2%name)
    1310          136 :             dist_in = dbt_distribution(tensor2_out)
    1311          136 :             dist_list = array_sublist(dist_in%nd_dist, ind2_linked)
    1312          136 :             IF (compat2 == 1) THEN
    1313          402 :                ALLOCATE (mp_dims(ndims_mapping_row(dist_in%pgrid%nd_index_grid)))
    1314          134 :                CALL dbt_get_mapping_info(dist_in%pgrid%nd_index_grid, dims1_2d=mp_dims)
    1315          938 :                ALLOCATE (tensor1_out)
    1316              :                CALL dbt_remap(tensor1, ind1_linked, ind1_free, tensor1_out, comm_2d=dist_in%pgrid%mp_comm_2d, &
    1317          134 :                               dist1=dist_list, mp_dims_1=mp_dims, nodata=nodata1, move_data=move_data_1)
    1318            2 :             ELSEIF (compat2 == 2) THEN
    1319            6 :                ALLOCATE (mp_dims(ndims_mapping_column(dist_in%pgrid%nd_index_grid)))
    1320            2 :                CALL dbt_get_mapping_info(dist_in%pgrid%nd_index_grid, dims2_2d=mp_dims)
    1321           14 :                ALLOCATE (tensor1_out)
    1322              :                CALL dbt_remap(tensor1, ind1_free, ind1_linked, tensor1_out, comm_2d=dist_in%pgrid%mp_comm_2d, &
    1323            2 :                               dist2=dist_list, mp_dims_2=mp_dims, nodata=nodata1, move_data=move_data_1)
    1324              :             ELSE
    1325            0 :                CPABORT("should not happen")
    1326              :             END IF
    1327          136 :             compat1 = compat2
    1328              :          ELSE
    1329         1676 :             IF (unit_nr_prv > 0) WRITE (unit_nr_prv, '(T2,A,1X,A)') "No redistribution of", TRIM(tensor1%name)
    1330         1676 :             tensor1_out => tensor1
    1331              :          END IF
    1332              :       END IF
    1333              : 
    1334        97229 :       SELECT CASE (compat1)
    1335              :       CASE (1)
    1336        97229 :          trans1 = .FALSE.
    1337              :       CASE (2)
    1338        54587 :          trans1 = .TRUE.
    1339              :       CASE DEFAULT
    1340       151816 :          CPABORT("should not happen")
    1341              :       END SELECT
    1342              : 
    1343        97447 :       SELECT CASE (compat2)
    1344              :       CASE (1)
    1345        97447 :          trans2 = .FALSE.
    1346              :       CASE (2)
    1347        54369 :          trans2 = .TRUE.
    1348              :       CASE DEFAULT
    1349       151816 :          CPABORT("should not happen")
    1350              :       END SELECT
    1351              : 
    1352       151816 :       IF (unit_nr_prv > 0) THEN
    1353           10 :          IF (compat1 .NE. compat1_old) THEN
    1354            6 :             WRITE (unit_nr_prv, '(T2,A,1X,A,A,1X)', advance='no') "compatibility of", TRIM(tensor1_out%name), ":"
    1355            0 :             SELECT CASE (compat1)
    1356              :             CASE (0)
    1357            0 :                WRITE (unit_nr_prv, '(A)') "Not compatible"
    1358              :             CASE (1)
    1359            5 :                WRITE (unit_nr_prv, '(A)') "Normal"
    1360              :             CASE (2)
    1361            6 :                WRITE (unit_nr_prv, '(A)') "Transposed"
    1362              :             END SELECT
    1363              :          END IF
    1364           10 :          IF (compat2 .NE. compat2_old) THEN
    1365            5 :             WRITE (unit_nr_prv, '(T2,A,1X,A,A,1X)', advance='no') "compatibility of", TRIM(tensor2_out%name), ":"
    1366            0 :             SELECT CASE (compat2)
    1367              :             CASE (0)
    1368            0 :                WRITE (unit_nr_prv, '(A)') "Not compatible"
    1369              :             CASE (1)
    1370            4 :                WRITE (unit_nr_prv, '(A)') "Normal"
    1371              :             CASE (2)
    1372            5 :                WRITE (unit_nr_prv, '(A)') "Transposed"
    1373              :             END SELECT
    1374              :          END IF
    1375              :       END IF
    1376              : 
    1377       151816 :       IF (new1 .AND. PRESENT(move_data_1)) move_data_1 = .TRUE.
    1378       151816 :       IF (new2 .AND. PRESENT(move_data_2)) move_data_2 = .TRUE.
    1379              : 
    1380       151816 :    END SUBROUTINE
    1381              : 
    1382              : ! **************************************************************************************************
    1383              : !> \brief Prepare tensor for contraction: redistribute to a 2d format which can be contracted by
    1384              : !>        matrix multiplication. This routine reshapes the smallest of the three tensors.
    1385              : !> \param ind1 index that should be mapped to first matrix dimension
    1386              : !> \param ind2 index that should be mapped to second matrix dimension
    1387              : !> \param trans transpose flag of matrix rep.
    1388              : !> \param new whether a new tensor was created for tensor_out
    1389              : !> \param nodata don't copy tensor data
    1390              : !> \param move_data memory optimization: transfer data s.t. tensor_in may be empty on return
    1391              : !> \param unit_nr output unit
    1392              : !> \author Patrick Seewald
    1393              : ! **************************************************************************************************
    1394       151816 :    SUBROUTINE reshape_mm_small(tensor_in, ind1, ind2, tensor_out, trans, new, nodata, move_data, unit_nr)
    1395              :       TYPE(dbt_type), TARGET, INTENT(INOUT)   :: tensor_in
    1396              :       INTEGER, DIMENSION(:), INTENT(IN)           :: ind1, ind2
    1397              :       TYPE(dbt_type), POINTER, INTENT(OUT)    :: tensor_out
    1398              :       LOGICAL, INTENT(OUT)                        :: trans
    1399              :       LOGICAL, INTENT(OUT)                        :: new
    1400              :       LOGICAL, INTENT(IN), OPTIONAL               :: nodata, move_data
    1401              :       INTEGER, INTENT(IN), OPTIONAL               :: unit_nr
    1402              :       INTEGER                                     :: compat1, compat2, compat1_old, compat2_old, unit_nr_prv
    1403              :       LOGICAL                                     :: nodata_prv
    1404              : 
    1405       151816 :       NULLIFY (tensor_out)
    1406              :       IF (PRESENT(nodata)) THEN
    1407       151816 :          nodata_prv = nodata
    1408              :       ELSE
    1409       151816 :          nodata_prv = .FALSE.  ! NOTE(review): nodata_prv is never read below; dbt_remap receives the raw optional nodata instead
    1410              :       END IF
    1411              : 
    1412       151816 :       unit_nr_prv = prep_output_unit(unit_nr)
    1413              : 
    1414       151816 :       new = .FALSE.
    1415       151816 :       compat1 = compat_map(tensor_in%nd_index, ind1)  ! 1: ind1 equals the row mapping, 2: equals the column mapping, 0: neither
    1416       151816 :       compat2 = compat_map(tensor_in%nd_index, ind2)
    1417       151816 :       compat1_old = compat1; compat2_old = compat2  ! remember pre-remap state for the log output below
    1418       151816 :       IF (unit_nr_prv > 0) THEN
    1419           10 :          WRITE (unit_nr_prv, '(T2,A,1X,A,A,1X)', advance='no') "compatibility of", TRIM(tensor_in%name), ":"
    1420           10 :          IF (compat1 == 1 .AND. compat2 == 2) THEN
    1421            4 :             WRITE (unit_nr_prv, '(A)') "Normal"
    1422            6 :          ELSEIF (compat1 == 2 .AND. compat2 == 1) THEN
    1423            2 :             WRITE (unit_nr_prv, '(A)') "Transposed"
    1424              :          ELSE
    1425            4 :             WRITE (unit_nr_prv, '(A)') "Not compatible"
    1426              :          END IF
    1427              :       END IF
    1428       151816 :       IF (compat1 == 0 .or. compat2 == 0) THEN ! index mapping not compatible with contract index
    1429              : 
    1430           22 :          IF (unit_nr_prv > 0) WRITE (unit_nr_prv, '(T2,A,1X,A)') "Redistribution of", TRIM(tensor_in%name)
    1431              : 
    1432          154 :          ALLOCATE (tensor_out)
    1433           22 :          CALL dbt_remap(tensor_in, ind1, ind2, tensor_out, nodata=nodata, move_data=move_data)  ! reshape to matrix layout: ind1 -> rows, ind2 -> columns
    1434           22 :          CALL dbt_copy_contraction_storage(tensor_in, tensor_out)
    1435           22 :          compat1 = 1
    1436           22 :          compat2 = 2
    1437           22 :          new = .TRUE.  ! a fresh tensor was allocated; caller must deallocate it
    1438              :       ELSE
    1439       151794 :          IF (unit_nr_prv > 0) WRITE (unit_nr_prv, '(T2,A,1X,A)') "No redistribution of", TRIM(tensor_in%name)
    1440       151794 :          tensor_out => tensor_in  ! no copy needed, just alias the input tensor
    1441              :       END IF
    1442              : 
    1443       151816 :       IF (compat1 == 1 .AND. compat2 == 2) THEN
    1444       113162 :          trans = .FALSE.
    1445        38654 :       ELSEIF (compat1 == 2 .AND. compat2 == 1) THEN
    1446        38654 :          trans = .TRUE.  ! matrix representation is transposed w.r.t. the requested (ind1, ind2) order
    1447              :       ELSE
    1448            0 :          CPABORT("this should not happen")
    1449              :       END IF
    1450              : 
    1451       151816 :       IF (unit_nr_prv > 0) THEN
    1452           10 :          IF (compat1_old .NE. compat1 .OR. compat2_old .NE. compat2) THEN
    1453            4 :             WRITE (unit_nr_prv, '(T2,A,1X,A,A,1X)', advance='no') "compatibility of", TRIM(tensor_out%name), ":"
    1454            4 :             IF (compat1 == 1 .AND. compat2 == 2) THEN
    1455            4 :                WRITE (unit_nr_prv, '(A)') "Normal"
    1456            0 :             ELSEIF (compat1 == 2 .AND. compat2 == 1) THEN
    1457            0 :                WRITE (unit_nr_prv, '(A)') "Transposed"
    1458              :             ELSE
    1459            0 :                WRITE (unit_nr_prv, '(A)') "Not compatible"
    1460              :             END IF
    1461              :          END IF
    1462              :       END IF
    1463              : 
    1464       151816 :    END SUBROUTINE
    1465              : 
    1466              : ! **************************************************************************************************
    1467              : !> \brief update contraction storage that keeps track of process grids during a batched contraction
    1468              : !>        and decide if tensor process grid needs to be optimized
    1469              : !> \param split_opt optimized TAS process grid
    1470              : !> \param split current TAS process grid
    1471              : !> \author Patrick Seewald
    1472              : ! **************************************************************************************************
    1473        93704 :    FUNCTION update_contraction_storage(storage, split_opt, split) RESULT(do_change_pgrid)
    1474              :       TYPE(dbt_contraction_storage), INTENT(INOUT) :: storage
    1475              :       TYPE(dbt_tas_split_info), INTENT(IN)           :: split_opt
    1476              :       TYPE(dbt_tas_split_info), INTENT(IN)           :: split
    1477              :       INTEGER, DIMENSION(2) :: pdims, pdims_sub
    1478              :       LOGICAL, DIMENSION(2) :: do_change_pgrid
    1479              :       REAL(kind=dp) :: change_criterion, pdims_ratio
    1480              :       INTEGER :: nsplit_opt, nsplit
    1481              : 
    1482        93704 :       CPASSERT(ALLOCATED(split_opt%ngroup_opt))
    1483        93704 :       nsplit_opt = split_opt%ngroup_opt
    1484        93704 :       nsplit = split%ngroup
    1485              : 
    1486       281112 :       pdims = split%mp_comm%num_pe_cart
    1487              : 
    1488        93704 :       storage%ibatch = storage%ibatch + 1  ! count one more contraction batch
    1489              : 
    1490              :       storage%nsplit_avg = (storage%nsplit_avg*REAL(storage%ibatch - 1, dp) + REAL(nsplit_opt, dp)) &
    1491        93704 :                            /REAL(storage%ibatch, dp)  ! running mean of the optimal split factor over all batches so far
    1492              : 
    1493        93704 :       SELECT CASE (split_opt%split_rowcol)
    1494              :       CASE (rowsplit)
    1495        93704 :          pdims_ratio = REAL(pdims(1), dp)/pdims(2)
    1496              :       CASE (colsplit)
    1497        93704 :          pdims_ratio = REAL(pdims(2), dp)/pdims(1)  ! NOTE(review): pdims_ratio (and pdims) is computed but never used below -- dead code?
    1498              :       END SELECT
    1499              : 
    1500       281112 :       do_change_pgrid(:) = .FALSE.
    1501              : 
    1502              :       ! check for process grid dimensions
    1503       281112 :       pdims_sub = split%mp_comm_group%num_pe_cart
    1504       562224 :       change_criterion = MAXVAL(REAL(pdims_sub, dp))/MINVAL(pdims_sub)  ! aspect ratio of the subgroup process grid
    1505        93704 :       IF (change_criterion > default_pdims_accept_ratio**2) do_change_pgrid(1) = .TRUE.  ! grid too elongated
    1506              : 
    1507              :       ! check for split factor
    1508        93704 :       change_criterion = MAX(REAL(nsplit, dp)/storage%nsplit_avg, REAL(storage%nsplit_avg, dp)/nsplit)
    1509        93704 :       IF (change_criterion > default_nsplit_accept_ratio) do_change_pgrid(2) = .TRUE.  ! current split too far from the running optimum
    1510              : 
    1511        93704 :    END FUNCTION
    1512              : 
    1513              : ! **************************************************************************************************
    1514              : !> \brief Check if 2d index is compatible with tensor index
    1515              : !> \author Patrick Seewald
    1516              : ! **************************************************************************************************
    1517       607264 :    FUNCTION compat_map(nd_index, compat_ind)
    1518              :       TYPE(nd_to_2d_mapping), INTENT(IN) :: nd_index
    1519              :       INTEGER, DIMENSION(:), INTENT(IN)  :: compat_ind
    1520      1214528 :       INTEGER, DIMENSION(ndims_mapping_row(nd_index)) :: map1
    1521      1214528 :       INTEGER, DIMENSION(ndims_mapping_column(nd_index)) :: map2
    1522              :       INTEGER                            :: compat_map
    1523              : 
    1524       607264 :       CALL dbt_get_mapping_info(nd_index, map1_2d=map1, map2_2d=map2)
    1525              : 
    1526       607264 :       compat_map = 0  ! 0: compat_ind matches neither matrix dimension mapping
    1527       607264 :       IF (array_eq_i(map1, compat_ind)) THEN
    1528              :          compat_map = 1  ! 1: compat_ind equals the row (first matrix dimension) mapping
    1529       280871 :       ELSEIF (array_eq_i(map2, compat_ind)) THEN
    1530       255764 :          compat_map = 2  ! 2: compat_ind equals the column (second matrix dimension) mapping
    1531              :       END IF
    1532              : 
    1533       607264 :    END FUNCTION
    1534              : 
    1535              : ! **************************************************************************************************
    1536              : !> \brief Sort the reference index array in place and apply the same permutation to the linked index array
    1537              : !> \author Patrick Seewald
    1538              : ! **************************************************************************************************
    1539       455448 :    SUBROUTINE index_linked_sort(ind_ref, ind_dep)
    1540              :       INTEGER, DIMENSION(:), INTENT(INOUT) :: ind_ref, ind_dep
    1541       910896 :       INTEGER, DIMENSION(SIZE(ind_ref))    :: sort_indices
    1542              : 
    1543       455448 :       CALL sort(ind_ref, SIZE(ind_ref), sort_indices)  ! sort ind_ref in place, recording the permutation in sort_indices
    1544      2107928 :       ind_dep(:) = ind_dep(sort_indices)  ! apply the same permutation to the linked (dependent) index array
    1545              : 
    1546       455448 :    END SUBROUTINE
    1547              : 
    1548              : ! **************************************************************************************************
    1549              : !> \brief Create a new n-dim process grid for a tensor from a TAS process grid, sized to the tensor's block dimensions
    1550              : !> \author Patrick Seewald
    1551              : ! **************************************************************************************************
    1552            0 :    FUNCTION opt_pgrid(tensor, tas_split_info)
    1553              :       TYPE(dbt_type), INTENT(IN) :: tensor
    1554              :       TYPE(dbt_tas_split_info), INTENT(IN) :: tas_split_info
    1555            0 :       INTEGER, DIMENSION(ndims_matrix_row(tensor)) :: map1
    1556            0 :       INTEGER, DIMENSION(ndims_matrix_column(tensor)) :: map2
    1557              :       TYPE(dbt_pgrid_type) :: opt_pgrid
    1558            0 :       INTEGER, DIMENSION(ndims_tensor(tensor)) :: dims
    1559              : 
    1560            0 :       CALL dbt_get_mapping_info(tensor%pgrid%nd_index_grid, map1_2d=map1, map2_2d=map2)
    1561            0 :       CALL blk_dims_tensor(tensor, dims)
    1562            0 :       opt_pgrid = dbt_nd_mp_comm(tas_split_info%mp_comm, map1, map2, tdims=dims)  ! new n-dim grid sized to the tensor block dims
    1563              : 
    1564            0 :       ALLOCATE (opt_pgrid%tas_split_info, SOURCE=tas_split_info)
    1565            0 :       CALL dbt_tas_info_hold(opt_pgrid%tas_split_info)  ! take a reference on the copied split info
    1566            0 :    END FUNCTION
    1567              : 
    1568              : ! **************************************************************************************************
    1569              : !> \brief Copy tensor to tensor with modified index mapping
    1570              : !> \param map1_2d new index mapping
    1571              : !> \param map2_2d new index mapping
    1572              : !> \author Patrick Seewald
    1573              : ! **************************************************************************************************
    1574       231219 :    SUBROUTINE dbt_remap(tensor_in, map1_2d, map2_2d, tensor_out, comm_2d, dist1, dist2, &
    1575        25691 :                         mp_dims_1, mp_dims_2, name, nodata, move_data)
    1576              :       TYPE(dbt_type), INTENT(INOUT)      :: tensor_in
    1577              :       INTEGER, DIMENSION(:), INTENT(IN)      :: map1_2d, map2_2d
    1578              :       TYPE(dbt_type), INTENT(OUT)        :: tensor_out
    1579              :       CHARACTER(len=*), INTENT(IN), OPTIONAL :: name
    1580              :       LOGICAL, INTENT(IN), OPTIONAL          :: nodata, move_data
    1581              :       CLASS(mp_comm_type), INTENT(IN), OPTIONAL          :: comm_2d
    1582              :       TYPE(array_list), INTENT(IN), OPTIONAL :: dist1, dist2
    1583              :       INTEGER, DIMENSION(SIZE(map1_2d)), OPTIONAL :: mp_dims_1
    1584              :       INTEGER, DIMENSION(SIZE(map2_2d)), OPTIONAL :: mp_dims_2
    1585              :       CHARACTER(len=default_string_length)   :: name_tmp
    1586        25691 :       INTEGER, DIMENSION(:), ALLOCATABLE     :: ${varlist("blk_sizes")}$, &
    1587        25691 :                                                 ${varlist("nd_dist")}$
    1588       179837 :       TYPE(dbt_distribution_type)        :: dist
    1589        25691 :       TYPE(mp_cart_type) :: comm_2d_prv
    1590              :       INTEGER                                :: handle, i
    1591        25691 :       INTEGER, DIMENSION(ndims_tensor(tensor_in)) :: pdims, myploc
    1592              :       CHARACTER(LEN=*), PARAMETER :: routineN = 'dbt_remap'
    1593              :       LOGICAL                               :: nodata_prv
    1594        77073 :       TYPE(dbt_pgrid_type)              :: comm_nd
    1595              : 
    1596        25691 :       CALL timeset(routineN, handle)
    1597              : 
    1598        25691 :       IF (PRESENT(name)) THEN
    1599            0 :          name_tmp = name
    1600              :       ELSE
    1601        25691 :          name_tmp = tensor_in%name  ! default: keep the input tensor's name
    1602              :       END IF
    1603        25691 :       IF (PRESENT(dist1)) THEN
    1604         5802 :          CPASSERT(PRESENT(mp_dims_1))
    1605              :       END IF
    1606              : 
    1607        25691 :       IF (PRESENT(dist2)) THEN
    1608         5460 :          CPASSERT(PRESENT(mp_dims_2))
    1609              :       END IF
    1610              : 
    1611        25691 :       IF (PRESENT(comm_2d)) THEN
    1612        25529 :          comm_2d_prv = comm_2d
    1613              :       ELSE
    1614          162 :          comm_2d_prv = tensor_in%pgrid%mp_comm_2d  ! fall back to the grid the input tensor lives on
    1615              :       END IF
    1616              : 
    1617        25691 :       comm_nd = dbt_nd_mp_comm(comm_2d_prv, map1_2d, map2_2d, dims1_nd=mp_dims_1, dims2_nd=mp_dims_2)  ! n-dim process grid for the new mapping
    1618        25691 :       CALL mp_environ_pgrid(comm_nd, pdims, myploc)
    1619              : 
    1620              :       #:for ndim in ndims
    1621        51252 :          IF (ndims_tensor(tensor_in) == ${ndim}$) THEN
    1622        25561 :             CALL get_arrays(tensor_in%blk_sizes, ${varlist("blk_sizes", nmax=ndim)}$)
    1623              :          END IF
    1624              :       #:endfor
    1625              : 
    1626              :       #:for ndim in ndims
    1627        51378 :          IF (ndims_tensor(tensor_in) == ${ndim}$) THEN
    1628              :             #:for idim in range(1, ndim+1)
    1629        76947 :                IF (PRESENT(dist1)) THEN
    1630        40486 :                   IF (ANY(map1_2d == ${idim}$)) THEN
    1631        34292 :                      i = MINLOC(map1_2d, dim=1, mask=map1_2d == ${idim}$) ! i is location of idim in map1_2d
    1632         5804 :                      CALL get_ith_array(dist1, i, nd_dist_${idim}$)
    1633              :                   END IF
    1634              :                END IF
    1635              : 
    1636        76947 :                IF (PRESENT(dist2)) THEN
    1637        43664 :                   IF (ANY(map2_2d == ${idim}$)) THEN
    1638        32728 :                      i = MINLOC(map2_2d, dim=1, mask=map2_2d == ${idim}$) ! i is location of idim in map2_2d
    1639        10912 :                      CALL get_ith_array(dist2, i, nd_dist_${idim}$)
    1640              :                   END IF
    1641              :                END IF
    1642              : 
    1643        76947 :                IF (.NOT. ALLOCATED(nd_dist_${idim}$)) THEN
    1644       163683 :                   ALLOCATE (nd_dist_${idim}$ (SIZE(blk_sizes_${idim}$)))
    1645        54561 :                   CALL dbt_default_distvec(SIZE(blk_sizes_${idim}$), pdims(${idim}$), blk_sizes_${idim}$, nd_dist_${idim}$)  ! default distribution for dimensions not covered by dist1/dist2
    1646              :                END IF
    1647              :             #:endfor
    1648              :             CALL dbt_distribution_new_expert(dist, comm_nd, map1_2d, map2_2d, &
    1649        25691 :                                              ${varlist("nd_dist", nmax=ndim)}$, own_comm=.TRUE.)
    1650              :          END IF
    1651              :       #:endfor
    1652              : 
    1653              :       #:for ndim in ndims
    1654        51378 :          IF (ndims_tensor(tensor_in) == ${ndim}$) THEN
    1655              :             CALL dbt_create(tensor_out, name_tmp, dist, map1_2d, map2_2d, &
    1656        25691 :                             ${varlist("blk_sizes", nmax=ndim)}$)
    1657              :          END IF
    1658              :       #:endfor
    1659              : 
    1660        25691 :       IF (PRESENT(nodata)) THEN
    1661        11144 :          nodata_prv = nodata
    1662              :       ELSE
    1663              :          nodata_prv = .FALSE.
    1664              :       END IF
    1665              : 
    1666        25691 :       IF (.NOT. nodata_prv) CALL dbt_copy_expert(tensor_in, tensor_out, move_data=move_data)  ! transfer block data into the remapped tensor unless caller wants structure only
    1667        25691 :       CALL dbt_distribution_destroy(dist)
    1668              : 
    1669        25691 :       CALL timestop(handle)
    1670        77073 :    END SUBROUTINE
    1671              : 
    1672              : ! **************************************************************************************************
    1673              : !> \brief Align index with data
    1674              : !> \param order permutation resulting from alignment
    1675              : !> \author Patrick Seewald
    1676              : ! **************************************************************************************************
    1677      3643584 :    SUBROUTINE dbt_align_index(tensor_in, tensor_out, order)
    1678              :       TYPE(dbt_type), INTENT(INOUT)               :: tensor_in
    1679              :       TYPE(dbt_type), INTENT(OUT)                 :: tensor_out
    1680       910896 :       INTEGER, DIMENSION(ndims_matrix_row(tensor_in)) :: map1_2d
    1681       910896 :       INTEGER, DIMENSION(ndims_matrix_column(tensor_in)) :: map2_2d
    1682              :       INTEGER, DIMENSION(ndims_tensor(tensor_in)), &
    1683              :          INTENT(OUT), OPTIONAL                        :: order
    1684       455448 :       INTEGER, DIMENSION(ndims_tensor(tensor_in))     :: order_prv
    1685              :       CHARACTER(LEN=*), PARAMETER :: routineN = 'dbt_align_index'
    1686              :       INTEGER                                         :: handle
    1687              : 
    1688       455448 :       CALL timeset(routineN, handle)
    1689              : 
    1690       455448 :       CALL dbt_get_mapping_info(tensor_in%nd_index_blk, map1_2d=map1_2d, map2_2d=map2_2d)
    1691      2849512 :       order_prv = dbt_inverse_order([map1_2d, map2_2d])  ! permutation that aligns the tensor index order with the 2d (row, column) mapping
    1692       455448 :       CALL dbt_permute_index(tensor_in, tensor_out, order=order_prv)
    1693              : 
    1694      1652480 :       IF (PRESENT(order)) order = order_prv  ! report the applied permutation to the caller
    1695              : 
    1696       455448 :       CALL timestop(handle)
    1697       455448 :    END SUBROUTINE
    1698              : 
    1699              : ! **************************************************************************************************
    1700              : !> \brief Create new tensor by reordering index, data is copied exactly (shallow copy)
    1701              : !> \author Patrick Seewald
    1702              : ! **************************************************************************************************
    1703      5198472 :    SUBROUTINE dbt_permute_index(tensor_in, tensor_out, order)
    1704              :       TYPE(dbt_type), INTENT(INOUT)                  :: tensor_in
    1705              :       TYPE(dbt_type), INTENT(OUT)                 :: tensor_out
    1706              :       INTEGER, DIMENSION(ndims_tensor(tensor_in)), &
    1707              :          INTENT(IN)                                   :: order
    1708              : 
    1709      2888040 :       TYPE(nd_to_2d_mapping)                          :: nd_index_blk_rs, nd_index_rs
    1710              :       CHARACTER(LEN=*), PARAMETER :: routineN = 'dbt_permute_index'
    1711              :       INTEGER                                         :: handle
    1712              :       INTEGER                                         :: ndims
    1713              : 
    1714       577608 :       CALL timeset(routineN, handle)
    1715              : 
    1716       577608 :       ndims = ndims_tensor(tensor_in)
    1717              : 
    1718       577608 :       CALL permute_index(tensor_in%nd_index, nd_index_rs, order)
    1719       577608 :       CALL permute_index(tensor_in%nd_index_blk, nd_index_blk_rs, order)
    1720       577608 :       CALL permute_index(tensor_in%pgrid%nd_index_grid, tensor_out%pgrid%nd_index_grid, order)
    1721              : 
    1722       577608 :       tensor_out%matrix_rep => tensor_in%matrix_rep  ! shallow copy: matrix data is shared, not duplicated
    1723       577608 :       tensor_out%owns_matrix = .FALSE.  ! the output must not deallocate the shared matrix
    1724              : 
    1725       577608 :       tensor_out%nd_index = nd_index_rs
    1726       577608 :       tensor_out%nd_index_blk = nd_index_blk_rs
    1727       577608 :       tensor_out%pgrid%mp_comm_2d = tensor_in%pgrid%mp_comm_2d
    1728       577608 :       IF (ALLOCATED(tensor_in%pgrid%tas_split_info)) THEN
    1729       577608 :          ALLOCATE (tensor_out%pgrid%tas_split_info, SOURCE=tensor_in%pgrid%tas_split_info)
    1730              :       END IF
    1731       577608 :       tensor_out%refcount => tensor_in%refcount  ! share the reference counter with the input
    1732       577608 :       CALL dbt_hold(tensor_out)
    1733              : 
    1734       577608 :       CALL reorder_arrays(tensor_in%blk_sizes, tensor_out%blk_sizes, order)
    1735       577608 :       CALL reorder_arrays(tensor_in%blk_offsets, tensor_out%blk_offsets, order)
    1736       577608 :       CALL reorder_arrays(tensor_in%nd_dist, tensor_out%nd_dist, order)
    1737       577608 :       CALL reorder_arrays(tensor_in%blks_local, tensor_out%blks_local, order)
    1738      1732824 :       ALLOCATE (tensor_out%nblks_local(ndims))
    1739      1155216 :       ALLOCATE (tensor_out%nfull_local(ndims))
    1740      2117676 :       tensor_out%nblks_local(order) = tensor_in%nblks_local(:)  ! scatter per-dimension counts into permuted positions
    1741      2117676 :       tensor_out%nfull_local(order) = tensor_in%nfull_local(:)
    1742       577608 :       tensor_out%name = tensor_in%name
    1743       577608 :       tensor_out%valid = .TRUE.
    1744              : 
    1745       577608 :       IF (ALLOCATED(tensor_in%contraction_storage)) THEN
    1746       332308 :          ALLOCATE (tensor_out%contraction_storage, SOURCE=tensor_in%contraction_storage)
    1747       332308 :          CALL destroy_array_list(tensor_out%contraction_storage%batch_ranges)  ! drop the source-copied ranges before rebuilding them in permuted order
    1748       332308 :          CALL reorder_arrays(tensor_in%contraction_storage%batch_ranges, tensor_out%contraction_storage%batch_ranges, order)
    1749              :       END IF
    1750              : 
    1751       577608 :       CALL timestop(handle)
    1752      1155216 :    END SUBROUTINE
    1753              : 
    1754              : ! **************************************************************************************************
    1755              : !> \brief Map contraction bounds to bounds referring to tensor indices
    1756              : !>        see dbt_contract for docu of dummy arguments
    1757              : !> \param bounds_t1 bounds mapped to tensor_1
    1758              : !> \param bounds_t2 bounds mapped to tensor_2
    1759              : !> \param do_crop_1 whether tensor 1 should be cropped
    1760              : !> \param do_crop_2 whether tensor 2 should be cropped
    1761              : !> \author Patrick Seewald
    1762              : ! **************************************************************************************************
    1763       168247 :    SUBROUTINE dbt_map_bounds_to_tensors(tensor_1, tensor_2, &
    1764       168247 :                                         contract_1, notcontract_1, &
    1765       336494 :                                         contract_2, notcontract_2, &
    1766       168247 :                                         bounds_t1, bounds_t2, &
    1767       122974 :                                         bounds_1, bounds_2, bounds_3, &
    1768              :                                         do_crop_1, do_crop_2)
    1769              : 
    1770              :       TYPE(dbt_type), INTENT(IN)      :: tensor_1, tensor_2
    1771              :       INTEGER, DIMENSION(:), INTENT(IN)   :: contract_1, contract_2, &
    1772              :                                              notcontract_1, notcontract_2
    1773              :       INTEGER, DIMENSION(2, ndims_tensor(tensor_1)), &
    1774              :          INTENT(OUT)                                 :: bounds_t1
    1775              :       INTEGER, DIMENSION(2, ndims_tensor(tensor_2)), &
    1776              :          INTENT(OUT)                                 :: bounds_t2
    1777              :       INTEGER, DIMENSION(2, SIZE(contract_1)), &
    1778              :          INTENT(IN), OPTIONAL                        :: bounds_1
    1779              :       INTEGER, DIMENSION(2, SIZE(notcontract_1)), &
    1780              :          INTENT(IN), OPTIONAL                        :: bounds_2
    1781              :       INTEGER, DIMENSION(2, SIZE(notcontract_2)), &
    1782              :          INTENT(IN), OPTIONAL                        :: bounds_3
    1783              :       LOGICAL, INTENT(OUT), OPTIONAL                 :: do_crop_1, do_crop_2
    1784              :       LOGICAL, DIMENSION(2)                          :: do_crop
    1785              : 
    1786       168247 :       do_crop = .FALSE.
    1787              : 
    1788       590978 :       bounds_t1(1, :) = 1  ! default bounds: the full index range of tensor 1
    1789       590978 :       CALL dbt_get_info(tensor_1, nfull_total=bounds_t1(2, :))
    1790              : 
    1791       631621 :       bounds_t2(1, :) = 1  ! default bounds: the full index range of tensor 2
    1792       631621 :       CALL dbt_get_info(tensor_2, nfull_total=bounds_t2(2, :))
    1793              : 
    1794       168247 :       IF (PRESENT(bounds_1)) THEN
    1795       180838 :          bounds_t1(:, contract_1) = bounds_1  ! bounds on contracted indices restrict both tensors
    1796        78374 :          do_crop(1) = .TRUE.
    1797       180838 :          bounds_t2(:, contract_2) = bounds_1
    1798       168247 :          do_crop(2) = .TRUE.
    1799              :       END IF
    1800              : 
    1801       168247 :       IF (PRESENT(bounds_2)) THEN
    1802       243412 :          bounds_t1(:, notcontract_1) = bounds_2  ! bounds on free indices of tensor 1 only
    1803       168247 :          do_crop(1) = .TRUE.
    1804              :       END IF
    1805              : 
    1806       168247 :       IF (PRESENT(bounds_3)) THEN
    1807       267092 :          bounds_t2(:, notcontract_2) = bounds_3  ! bounds on free indices of tensor 2 only
    1808       168247 :          do_crop(2) = .TRUE.
    1809              :       END IF
    1810              : 
    1811       168247 :       IF (PRESENT(do_crop_1)) do_crop_1 = do_crop(1)
    1812       168247 :       IF (PRESENT(do_crop_2)) do_crop_2 = do_crop(2)
    1813              : 
    1814       385350 :    END SUBROUTINE
    1815              : 
    1816              : ! **************************************************************************************************
    1817              : !> \brief print tensor contraction indices in a human readable way
    1818              : !> \param indchar1 characters printed for index of tensor 1
    1819              : !> \param indchar2 characters printed for index of tensor 2
    1820              : !> \param indchar3 characters printed for index of tensor 3
    1821              : !> \param unit_nr output unit
    1822              : !> \author Patrick Seewald
    1823              : ! **************************************************************************************************
    1824       156312 :    SUBROUTINE dbt_print_contraction_index(tensor_1, indchar1, tensor_2, indchar2, tensor_3, indchar3, unit_nr)
    1825              :       TYPE(dbt_type), INTENT(IN) :: tensor_1, tensor_2, tensor_3
    1826              :       CHARACTER(LEN=1), DIMENSION(ndims_tensor(tensor_1)), INTENT(IN) :: indchar1
    1827              :       CHARACTER(LEN=1), DIMENSION(ndims_tensor(tensor_2)), INTENT(IN) :: indchar2
    1828              :       CHARACTER(LEN=1), DIMENSION(ndims_tensor(tensor_3)), INTENT(IN) :: indchar3
    1829              :       INTEGER, INTENT(IN) :: unit_nr
    1830       312624 :       INTEGER, DIMENSION(ndims_matrix_row(tensor_1)) :: map11
    1831       312624 :       INTEGER, DIMENSION(ndims_matrix_column(tensor_1)) :: map12
    1832       312624 :       INTEGER, DIMENSION(ndims_matrix_row(tensor_2)) :: map21
    1833       312624 :       INTEGER, DIMENSION(ndims_matrix_column(tensor_2)) :: map22
    1834       312624 :       INTEGER, DIMENSION(ndims_matrix_row(tensor_3)) :: map31
    1835       312624 :       INTEGER, DIMENSION(ndims_matrix_column(tensor_3)) :: map32
    1836              :       INTEGER :: ichar1, ichar2, ichar3, unit_nr_prv
    1837              : 
    1838       156312 :       unit_nr_prv = prep_output_unit(unit_nr)
    1839              : 
    1840       156312 :       IF (unit_nr_prv /= 0) THEN
    1841       156312 :          CALL dbt_get_mapping_info(tensor_1%nd_index_blk, map1_2d=map11, map2_2d=map12)
    1842       156312 :          CALL dbt_get_mapping_info(tensor_2%nd_index_blk, map1_2d=map21, map2_2d=map22)
    1843       156312 :          CALL dbt_get_mapping_info(tensor_3%nd_index_blk, map1_2d=map31, map2_2d=map32)
    1844              :       END IF
    1845              : 
    1846       156312 :       IF (unit_nr_prv > 0) THEN
    1847           30 :          WRITE (unit_nr_prv, '(T2,A)') "INDEX INFO"
    1848           30 :          WRITE (unit_nr_prv, '(T15,A)', advance='no') "tensor index: ("
    1849          123 :          DO ichar1 = 1, SIZE(indchar1)
    1850          123 :             WRITE (unit_nr_prv, '(A1)', advance='no') indchar1(ichar1)
    1851              :          END DO
    1852           30 :          WRITE (unit_nr_prv, '(A)', advance='no') ") x ("
    1853          120 :          DO ichar2 = 1, SIZE(indchar2)
    1854          120 :             WRITE (unit_nr_prv, '(A1)', advance='no') indchar2(ichar2)
    1855              :          END DO
    1856           30 :          WRITE (unit_nr_prv, '(A)', advance='no') ") = ("
    1857          123 :          DO ichar3 = 1, SIZE(indchar3)
    1858          123 :             WRITE (unit_nr_prv, '(A1)', advance='no') indchar3(ichar3)
    1859              :          END DO
    1860           30 :          WRITE (unit_nr_prv, '(A)') ")"
    1861              : 
    1862           30 :          WRITE (unit_nr_prv, '(T15,A)', advance='no') "matrix index: ("
    1863           82 :          DO ichar1 = 1, SIZE(map11)
    1864           82 :             WRITE (unit_nr_prv, '(A1)', advance='no') indchar1(map11(ichar1))
    1865              :          END DO
    1866           30 :          WRITE (unit_nr_prv, '(A1)', advance='no') "|"
    1867           71 :          DO ichar1 = 1, SIZE(map12)
    1868           71 :             WRITE (unit_nr_prv, '(A1)', advance='no') indchar1(map12(ichar1))
    1869              :          END DO
    1870           30 :          WRITE (unit_nr_prv, '(A)', advance='no') ") x ("
    1871           76 :          DO ichar2 = 1, SIZE(map21)
    1872           76 :             WRITE (unit_nr_prv, '(A1)', advance='no') indchar2(map21(ichar2))
    1873              :          END DO
    1874           30 :          WRITE (unit_nr_prv, '(A1)', advance='no') "|"
    1875           74 :          DO ichar2 = 1, SIZE(map22)
    1876           74 :             WRITE (unit_nr_prv, '(A1)', advance='no') indchar2(map22(ichar2))
    1877              :          END DO
    1878           30 :          WRITE (unit_nr_prv, '(A)', advance='no') ") = ("
    1879           79 :          DO ichar3 = 1, SIZE(map31)
    1880           79 :             WRITE (unit_nr_prv, '(A1)', advance='no') indchar3(map31(ichar3))
    1881              :          END DO
    1882           30 :          WRITE (unit_nr_prv, '(A1)', advance='no') "|"
    1883           74 :          DO ichar3 = 1, SIZE(map32)
    1884           74 :             WRITE (unit_nr_prv, '(A1)', advance='no') indchar3(map32(ichar3))
    1885              :          END DO
    1886           30 :          WRITE (unit_nr_prv, '(A)') ")"
    1887              :       END IF
    1888              : 
    1889       156312 :    END SUBROUTINE
    1890              : 
    1891              : ! **************************************************************************************************
    1892              : !> \brief Initialize batched contraction for this tensor.
    1893              : !>
    1894              : !>        Explanation: A batched contraction is a contraction performed in several consecutive steps
    1895              : !>        by specification of bounds in dbt_contract. This can be used to reduce memory by
    1896              : !>        a large factor. The routines dbt_batched_contract_init and
    1897              : !>        dbt_batched_contract_finalize should be called to define the scope of a batched
    1898              : !>        contraction as this enables important optimizations (adapting communication scheme to
    1899              : !>        batches and adapting process grid to multiplication algorithm). The routines
    1900              : !>        dbt_batched_contract_init and dbt_batched_contract_finalize must be
    1901              : !>        called before the first and after the last contraction step on all 3 tensors.
    1902              : !>
    1903              : !>        Requirements:
    1904              : !>        - the tensors are in a compatible matrix layout (see documentation of
    1905              : !>          `dbt_contract`, note 2 & 3). If they are not, process grid optimizations are
    1906              : !>          disabled and a warning is issued.
    1907              : !>        - within the scope of a batched contraction, it is not allowed to access or change tensor
    1908              : !>          data except by calling the routines dbt_contract & dbt_copy.
    1909              : !>        - the bounds affecting indices of the smallest tensor must not change in the course of a
    1910              : !>          batched contraction (todo: get rid of this requirement).
    1911              : !>
    1912              : !>        Side effects:
    1913              : !>        - the parallel layout (process grid and distribution) of all tensors may change. In order
    1914              : !>          to disable the process grid optimization including this side effect, call this routine
    1915              : !>          only on the smallest of the 3 tensors.
    1916              : !>
    1917              : !> \note
    1918              : !>        Note 1: for an example of batched contraction see `examples/dbt_example.F`.
    1919              : !>        (todo: the example is outdated and should be updated).
    1920              : !>
    1921              : !>        Note 2: it is meaningful to use this feature if the contraction consists of one batch only
    1922              : !>        but if multiple contractions involving the same 3 tensors are performed
    1923              : !>        (batched_contract_init and batched_contract_finalize must then be called before/after each
    1924              : !>        contraction call). The process grid is then optimized after the first contraction
    1925              : !>        and future contraction may profit from this optimization.
    1926              : !>
    1927              : !> \param batch_range_i refers to the ith tensor dimension and contains all block indices starting
    1928              : !>                      a new range. The size should be the number of ranges plus one, the last
    1929              : !>                      element being the block index plus one of the last block in the last range.
    1930              : !>                      For internal load balancing optimizations, optionally specify the index
    1931              : !>                      ranges of batched contraction.
    1932              : !> \author Patrick Seewald
    1933              : ! **************************************************************************************************
   SUBROUTINE dbt_batched_contract_init(tensor, ${varlist("batch_range")}$)
      TYPE(dbt_type), INTENT(INOUT) :: tensor
      INTEGER, DIMENSION(:), OPTIONAL, INTENT(IN)        :: ${varlist("batch_range")}$
      ! total number of blocks per tensor dimension
      INTEGER, DIMENSION(ndims_tensor(tensor)) :: tdims
      ! private copies of the batch ranges (defaulted where not provided)
      INTEGER, DIMENSION(:), ALLOCATABLE                 :: ${varlist("batch_range_prv")}$
      LOGICAL :: static_range

      CALL dbt_get_info(tensor, nblks_total=tdims)

      ! the range is "static" if no explicit batch range was given for any dimension;
      ! in that case each dimension is treated as a single batch covering all blocks
      static_range = .TRUE.
      #:for idim in range(1, maxdim+1)
         IF (ndims_tensor(tensor) >= ${idim}$) THEN
            IF (PRESENT(batch_range_${idim}$)) THEN
               ALLOCATE (batch_range_prv_${idim}$, source=batch_range_${idim}$)
               static_range = .FALSE.
            ELSE
               ! default range: one batch spanning blocks 1 .. nblks
               ! (last entry is the block index past the end, consistent with batch_range convention)
               ALLOCATE (batch_range_prv_${idim}$ (2))
               batch_range_prv_${idim}$ (1) = 1
               batch_range_prv_${idim}$ (2) = tdims(${idim}$) + 1
            END IF
         END IF
      #:endfor

      ALLOCATE (tensor%contraction_storage)
      tensor%contraction_storage%static = static_range
      IF (static_range) THEN
         ! ranges are known not to change, so batched multiplication can be set up
         ! on the matrix representation immediately
         CALL dbt_tas_batched_mm_init(tensor%matrix_rep)
      END IF
      tensor%contraction_storage%nsplit_avg = 0.0_dp
      tensor%contraction_storage%ibatch = 0

      ! store the (possibly defaulted) batch ranges for use during contraction
      #:for ndim in range(1, maxdim+1)
         IF (ndims_tensor(tensor) == ${ndim}$) THEN
            CALL create_array_list(tensor%contraction_storage%batch_ranges, ${ndim}$, &
                                   ${varlist("batch_range_prv", nmax=ndim)}$)
         END IF
      #:endfor

   END SUBROUTINE
    1973              : 
    1974              : ! **************************************************************************************************
    1975              : !> \brief finalize batched contraction. This performs all communication that has been postponed in
    1976              : !>         the contraction calls.
    1977              : !> \author Patrick Seewald
    1978              : ! **************************************************************************************************
    1979       235886 :    SUBROUTINE dbt_batched_contract_finalize(tensor, unit_nr)
    1980              :       TYPE(dbt_type), INTENT(INOUT) :: tensor
    1981              :       INTEGER, INTENT(IN), OPTIONAL :: unit_nr
    1982              :       LOGICAL :: do_write
    1983              :       INTEGER :: unit_nr_prv, handle
    1984              : 
    1985       117943 :       CALL tensor%pgrid%mp_comm_2d%sync()
    1986       117943 :       CALL timeset("dbt_total", handle)
    1987       117943 :       unit_nr_prv = prep_output_unit(unit_nr)
    1988              : 
    1989       117943 :       do_write = .FALSE.
    1990              : 
    1991       117943 :       IF (tensor%contraction_storage%static) THEN
    1992        79575 :          IF (tensor%matrix_rep%do_batched > 0) THEN
    1993        79575 :             IF (tensor%matrix_rep%mm_storage%batched_out) do_write = .TRUE.
    1994              :          END IF
    1995        79575 :          CALL dbt_tas_batched_mm_finalize(tensor%matrix_rep)
    1996              :       END IF
    1997              : 
    1998       117943 :       IF (do_write .AND. unit_nr_prv /= 0) THEN
    1999        17242 :          IF (unit_nr_prv > 0) THEN
    2000              :             WRITE (unit_nr_prv, "(T2,A)") &
    2001            0 :                "FINALIZING BATCHED PROCESSING OF MATMUL"
    2002              :          END IF
    2003        17242 :          CALL dbt_write_tensor_info(tensor, unit_nr_prv)
    2004        17242 :          CALL dbt_write_tensor_dist(tensor, unit_nr_prv)
    2005              :       END IF
    2006              : 
    2007       117943 :       CALL destroy_array_list(tensor%contraction_storage%batch_ranges)
    2008       117943 :       DEALLOCATE (tensor%contraction_storage)
    2009       117943 :       CALL tensor%pgrid%mp_comm_2d%sync()
    2010       117943 :       CALL timestop(handle)
    2011              : 
    2012       117943 :    END SUBROUTINE
    2013              : 
    2014              : ! **************************************************************************************************
    2015              : !> \brief change the process grid of a tensor
!> \param nodata optionally don't copy the tensor data (then tensor is empty on return)
    2017              : !> \param batch_range_i refers to the ith tensor dimension and contains all block indices starting
    2018              : !>                      a new range. The size should be the number of ranges plus one, the last
    2019              : !>                      element being the block index plus one of the last block in the last range.
    2020              : !>                      For internal load balancing optimizations, optionally specify the index
    2021              : !>                      ranges of batched contraction.
    2022              : !> \author Patrick Seewald
    2023              : ! **************************************************************************************************
   SUBROUTINE dbt_change_pgrid(tensor, pgrid, ${varlist("batch_range")}$, &
                               nodata, pgrid_changed, unit_nr)
      TYPE(dbt_type), INTENT(INOUT)                  :: tensor
      TYPE(dbt_pgrid_type), INTENT(IN)               :: pgrid
      INTEGER, DIMENSION(:), OPTIONAL, INTENT(IN)        :: ${varlist("batch_range")}$
      ! if nodata is .TRUE. the tensor data is not copied (tensor is returned empty)
      LOGICAL, INTENT(IN), OPTIONAL                      :: nodata
      LOGICAL, INTENT(OUT), OPTIONAL                     :: pgrid_changed
      INTEGER, INTENT(IN), OPTIONAL                      :: unit_nr
      CHARACTER(LEN=*), PARAMETER :: routineN = 'dbt_change_pgrid'
      CHARACTER(default_string_length)                   :: name
      INTEGER                                            :: handle
      ! block sizes and new distribution vectors, one per tensor dimension
      INTEGER, ALLOCATABLE, DIMENSION(:)                 :: ${varlist("bs")}$, &
                                                            ${varlist("dist")}$
      INTEGER, DIMENSION(ndims_tensor(tensor))           :: pcoord, pcoord_ref, pdims, pdims_ref, &
                                                            tdims
      TYPE(dbt_type)                                 :: t_tmp
      TYPE(dbt_distribution_type)                    :: dist
      INTEGER, DIMENSION(ndims_matrix_row(tensor)) :: map1
      INTEGER, &
         DIMENSION(ndims_matrix_column(tensor))    :: map2
      ! per dimension: whether batch ranges were given (enables batch-aware distribution)
      LOGICAL, DIMENSION(ndims_tensor(tensor))             :: mem_aware
      INTEGER, DIMENSION(ndims_tensor(tensor)) :: nbatch
      INTEGER :: ind1, ind2, batch_size, ibatch

      IF (PRESENT(pgrid_changed)) pgrid_changed = .FALSE.
      CALL mp_environ_pgrid(pgrid, pdims, pcoord)
      CALL mp_environ_pgrid(tensor%pgrid, pdims_ref, pcoord_ref)

      ! nothing to do if the new grid has the same dimensions and the same TAS split
      ! (pgrid_changed then stays .FALSE.)
      IF (ALL(pdims == pdims_ref)) THEN
         IF (ALLOCATED(pgrid%tas_split_info) .AND. ALLOCATED(tensor%pgrid%tas_split_info)) THEN
            IF (pgrid%tas_split_info%ngroup == tensor%pgrid%tas_split_info%ngroup) THEN
               RETURN
            END IF
         END IF
      END IF

      CALL timeset(routineN, handle)

      #:for idim in range(1, maxdim+1)
         IF (ndims_tensor(tensor) >= ${idim}$) THEN
            mem_aware(${idim}$) = PRESENT(batch_range_${idim}$)
            IF (mem_aware(${idim}$)) nbatch(${idim}$) = SIZE(batch_range_${idim}$) - 1
         END IF
      #:endfor

      CALL dbt_get_info(tensor, nblks_total=tdims, name=name)

      ! build a default block distribution for each dimension on the new grid
      #:for idim in range(1, maxdim+1)
         IF (ndims_tensor(tensor) >= ${idim}$) THEN
            ALLOCATE (bs_${idim}$ (dbt_nblks_total(tensor, ${idim}$)))
            CALL get_ith_array(tensor%blk_sizes, ${idim}$, bs_${idim}$)
            ALLOCATE (dist_${idim}$ (tdims(${idim}$)))
            dist_${idim}$ = 0
            IF (mem_aware(${idim}$)) THEN
               ! distribute each batch independently so that load balance is retained
               ! within any single contraction batch
               DO ibatch = 1, nbatch(${idim}$)
                  ind1 = batch_range_${idim}$ (ibatch)
                  ind2 = batch_range_${idim}$ (ibatch + 1) - 1
                  batch_size = ind2 - ind1 + 1
                  CALL dbt_default_distvec(batch_size, pdims(${idim}$), &
                                           bs_${idim}$ (ind1:ind2), dist_${idim}$ (ind1:ind2))
               END DO
            ELSE
               CALL dbt_default_distvec(tdims(${idim}$), pdims(${idim}$), bs_${idim}$, dist_${idim}$)
            END IF
         END IF
      #:endfor

      ! create a temporary tensor with identical index mapping on the new grid
      CALL dbt_get_mapping_info(tensor%nd_index_blk, map1_2d=map1, map2_2d=map2)
      #:for ndim in ndims
         IF (ndims_tensor(tensor) == ${ndim}$) THEN
            CALL dbt_distribution_new(dist, pgrid, ${varlist("dist", nmax=ndim)}$)
            CALL dbt_create(t_tmp, name, dist, map1, map2, ${varlist("bs", nmax=ndim)}$)
         END IF
      #:endfor
      CALL dbt_distribution_destroy(dist)

      IF (PRESENT(nodata)) THEN
         IF (.NOT. nodata) CALL dbt_copy_expert(tensor, t_tmp, move_data=.TRUE.)
      ELSE
         CALL dbt_copy_expert(tensor, t_tmp, move_data=.TRUE.)
      END IF

      CALL dbt_copy_contraction_storage(tensor, t_tmp)

      ! replace the input tensor by the temporary tensor living on the new grid
      CALL dbt_destroy(tensor)
      tensor = t_tmp

      IF (PRESENT(unit_nr)) THEN
         IF (unit_nr > 0) THEN
            WRITE (unit_nr, "(T2,A,1X,A)") "OPTIMIZED PGRID INFO FOR", TRIM(tensor%name)
            WRITE (unit_nr, "(T4,A,1X,3I6)") "process grid dimensions:", pdims
            CALL dbt_write_split_info(pgrid, unit_nr)
         END IF
      END IF

      IF (PRESENT(pgrid_changed)) pgrid_changed = .TRUE.

      CALL timestop(handle)
   END SUBROUTINE
    2124              : 
    2125              : ! **************************************************************************************************
    2126              : !> \brief map tensor to a new 2d process grid for the matrix representation.
    2127              : !> \author Patrick Seewald
    2128              : ! **************************************************************************************************
   SUBROUTINE dbt_change_pgrid_2d(tensor, mp_comm, pdims, nodata, nsplit, dimsplit, pgrid_changed, unit_nr)
      TYPE(dbt_type), INTENT(INOUT)                  :: tensor
      TYPE(mp_cart_type), INTENT(IN)               :: mp_comm
      INTEGER, DIMENSION(2), INTENT(IN), OPTIONAL :: pdims
      LOGICAL, INTENT(IN), OPTIONAL                      :: nodata
      INTEGER, INTENT(IN), OPTIONAL :: nsplit, dimsplit
      LOGICAL, INTENT(OUT), OPTIONAL :: pgrid_changed
      INTEGER, INTENT(IN), OPTIONAL                      :: unit_nr
      INTEGER, DIMENSION(ndims_matrix_row(tensor)) :: map1
      INTEGER, DIMENSION(ndims_matrix_column(tensor)) :: map2
      ! dims: effective tensor dimensions used to shape the new grid;
      ! nbatches: number of batch index ranges per dimension
      INTEGER, DIMENSION(ndims_tensor(tensor)) :: dims, nbatches
      TYPE(dbt_pgrid_type) :: pgrid
      INTEGER, DIMENSION(:), ALLOCATABLE :: ${varlist("batch_range")}$
      INTEGER, DIMENSION(:), ALLOCATABLE :: array
      INTEGER :: idim

      CALL dbt_get_mapping_info(tensor%pgrid%nd_index_grid, map1_2d=map1, map2_2d=map2)
      CALL blk_dims_tensor(tensor, dims)

      IF (ALLOCATED(tensor%contraction_storage)) THEN
         ASSOCIATE (batch_ranges => tensor%contraction_storage%batch_ranges)
            nbatches = sizes_of_arrays(tensor%contraction_storage%batch_ranges) - 1
            ! for good load balancing the process grid dimensions should be chosen adapted to the
            ! tensor dimensions. For batched contraction the tensor dimensions should be divided by
            ! the number of batches (number of index ranges).
            DO idim = 1, ndims_tensor(tensor)
               CALL get_ith_array(tensor%contraction_storage%batch_ranges, idim, array)
               ! total extent covered by all batches of this dimension
               dims(idim) = array(nbatches(idim) + 1) - array(1)
               DEALLOCATE (array)
               ! average batch extent; clamp to at least 1
               dims(idim) = dims(idim)/nbatches(idim)
               IF (dims(idim) <= 0) dims(idim) = 1
            END DO
         END ASSOCIATE
      END IF

      ! derive an nd process grid from the 2d communicator adapted to the (batched) dims
      pgrid = dbt_nd_mp_comm(mp_comm, map1, map2, pdims_2d=pdims, tdims=dims, nsplit=nsplit, dimsplit=dimsplit)
      IF (ALLOCATED(tensor%contraction_storage)) THEN
         ! forward the stored batch ranges so the new distribution stays batch-aware
         #:for ndim in range(1, maxdim+1)
            IF (ndims_tensor(tensor) == ${ndim}$) THEN
               CALL get_arrays(tensor%contraction_storage%batch_ranges, ${varlist("batch_range", nmax=ndim)}$)
               CALL dbt_change_pgrid(tensor, pgrid, ${varlist("batch_range", nmax=ndim)}$, &
                                     nodata=nodata, pgrid_changed=pgrid_changed, unit_nr=unit_nr)
            END IF
         #:endfor
      ELSE
         CALL dbt_change_pgrid(tensor, pgrid, nodata=nodata, pgrid_changed=pgrid_changed, unit_nr=unit_nr)
      END IF
      CALL dbt_pgrid_destroy(pgrid)

   END SUBROUTINE
    2179              : 
    2180       131437 : END MODULE
        

Generated by: LCOV version 2.0-1