LCOV - code coverage report
Current view: top level - src/dbt - dbt_methods.F (source / functions) Hit Total Coverage
Test: CP2K Regtests (git:34ef472) Lines: 843 893 94.4 %
Date: 2024-04-26 08:30:29 Functions: 22 23 95.7 %

          Line data    Source code
       1             : !--------------------------------------------------------------------------------------------------!
       2             : !   CP2K: A general program to perform molecular dynamics simulations                              !
       3             : !   Copyright 2000-2024 CP2K developers group <https://cp2k.org>                                   !
       4             : !                                                                                                  !
       5             : !   SPDX-License-Identifier: GPL-2.0-or-later                                                      !
       6             : !--------------------------------------------------------------------------------------------------!
       7             : 
       8             : ! **************************************************************************************************
       9             : !> \brief DBT tensor framework for block-sparse tensor contraction.
      10             : !>        Representation of n-rank tensors as DBT tall-and-skinny matrices.
      11             : !>        Support for arbitrary redistribution between different representations.
      12             : !>        Support for arbitrary tensor contractions
      13             : !> \todo implement checks and error messages
      14             : !> \author Patrick Seewald
      15             : ! **************************************************************************************************
      16             : MODULE dbt_methods
      17             :    #:include "dbt_macros.fypp"
      18             :    #:set maxdim = maxrank
      19             :    #:set ndims = range(2,maxdim+1)
      20             : 
      21             :    USE dbcsr_api, ONLY: &
      22             :       dbcsr_type, dbcsr_release, &
      23             :       dbcsr_iterator_type, dbcsr_iterator_start, dbcsr_iterator_blocks_left, dbcsr_iterator_next_block, &
      24             :       dbcsr_has_symmetry, dbcsr_desymmetrize, dbcsr_put_block, dbcsr_clear, dbcsr_iterator_stop
      25             :    USE dbt_allocate_wrap, ONLY: &
      26             :       allocate_any
      27             :    USE dbt_array_list_methods, ONLY: &
      28             :       get_arrays, reorder_arrays, get_ith_array, array_list, array_sublist, check_equal, array_eq_i, &
      29             :       create_array_list, destroy_array_list, sizes_of_arrays
      30             :    USE dbm_api, ONLY: &
      31             :       dbm_clear
      32             :    USE dbt_tas_types, ONLY: &
      33             :       dbt_tas_split_info
      34             :    USE dbt_tas_base, ONLY: &
      35             :       dbt_tas_copy, dbt_tas_finalize, dbt_tas_get_info, dbt_tas_info
      36             :    USE dbt_tas_mm, ONLY: &
      37             :       dbt_tas_multiply, dbt_tas_batched_mm_init, dbt_tas_batched_mm_finalize, &
      38             :       dbt_tas_batched_mm_complete, dbt_tas_set_batched_state
      39             :    USE dbt_block, ONLY: &
      40             :       dbt_iterator_type, dbt_get_block, dbt_put_block, dbt_iterator_start, &
      41             :       dbt_iterator_blocks_left, dbt_iterator_stop, dbt_iterator_next_block, &
      42             :       ndims_iterator, dbt_reserve_blocks, block_nd, destroy_block, checker_tr
      43             :    USE dbt_index, ONLY: &
      44             :       dbt_get_mapping_info, nd_to_2d_mapping, dbt_inverse_order, permute_index, get_nd_indices_tensor, &
      45             :       ndims_mapping_row, ndims_mapping_column, ndims_mapping
      46             :    USE dbt_types, ONLY: &
      47             :       dbt_create, dbt_type, ndims_tensor, dims_tensor, &
      48             :       dbt_distribution_type, dbt_distribution, dbt_nd_mp_comm, dbt_destroy, &
      49             :       dbt_distribution_destroy, dbt_distribution_new_expert, dbt_get_stored_coordinates, &
      50             :       blk_dims_tensor, dbt_hold, dbt_pgrid_type, mp_environ_pgrid, dbt_filter, &
      51             :       dbt_clear, dbt_finalize, dbt_get_num_blocks, dbt_scale, &
      52             :       dbt_get_num_blocks_total, dbt_get_info, ndims_matrix_row, ndims_matrix_column, &
      53             :       dbt_max_nblks_local, dbt_default_distvec, dbt_contraction_storage, dbt_nblks_total, &
      54             :       dbt_distribution_new, dbt_copy_contraction_storage, dbt_pgrid_destroy
      55             :    USE kinds, ONLY: &
      56             :       dp, default_string_length, int_8, dp
      57             :    USE message_passing, ONLY: &
      58             :       mp_cart_type
      59             :    USE util, ONLY: &
      60             :       sort
      61             :    USE dbt_reshape_ops, ONLY: &
      62             :       dbt_reshape
      63             :    USE dbt_tas_split, ONLY: &
      64             :       dbt_tas_mp_comm, rowsplit, colsplit, dbt_tas_info_hold, dbt_tas_release_info, default_nsplit_accept_ratio, &
      65             :       default_pdims_accept_ratio, dbt_tas_create_split
      66             :    USE dbt_split, ONLY: &
      67             :       dbt_split_copyback, dbt_make_compatible_blocks, dbt_crop
      68             :    USE dbt_io, ONLY: &
      69             :       dbt_write_tensor_info, dbt_write_tensor_dist, prep_output_unit, dbt_write_split_info
      70             :    USE message_passing, ONLY: mp_comm_type
      71             : 
      72             : #include "../base/base_uses.f90"
      73             : 
      74             :    IMPLICIT NONE
      75             :    PRIVATE
      76             :    CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'dbt_methods'
      77             : 
      78             :    PUBLIC :: &
      79             :       dbt_contract, &
      80             :       dbt_copy, &
      81             :       dbt_get_block, &
      82             :       dbt_get_stored_coordinates, &
      83             :       dbt_inverse_order, &
      84             :       dbt_iterator_blocks_left, &
      85             :       dbt_iterator_next_block, &
      86             :       dbt_iterator_start, &
      87             :       dbt_iterator_stop, &
      88             :       dbt_iterator_type, &
      89             :       dbt_put_block, &
      90             :       dbt_reserve_blocks, &
      91             :       dbt_copy_matrix_to_tensor, &
      92             :       dbt_copy_tensor_to_matrix, &
      93             :       dbt_batched_contract_init, &
      94             :       dbt_batched_contract_finalize
      95             : 
      96             : CONTAINS
      97             : 
       98             : ! **************************************************************************************************
       99             : !> \brief Copy tensor data.
      100             : !>        Redistributes tensor data according to distributions of target and source tensor.
      101             : !>        Permutes tensor index according to `order` argument (if present).
      102             : !>        Source and target tensor formats are arbitrary as long as the following requirements are met:
      103             : !>        * source and target tensors have the same rank and the same sizes in each dimension in terms
      104             : !>          of tensor elements (block sizes don't need to be the same).
      105             : !>          If `order` argument is present, sizes must match after index permutation.
      106             : !>        OR
      107             : !>        * target tensor is not yet created, in this case an exact copy of source tensor is returned.
      108             : !> \param tensor_in Source
      109             : !> \param tensor_out Target
      110             : !> \param order Permutation of target tensor index.
      111             : !>              Exact same convention as order argument of RESHAPE intrinsic.
                     : !> \param summation if .TRUE., sum into the target tensor instead of overwriting it
      112             : !> \param bounds crop tensor data: start and end index for each tensor dimension
                     : !> \param move_data memory optimization: if .TRUE., tensor_in may be emptied during the copy
                     : !> \param unit_nr output unit for logging (forwarded to dbt_copy_expert)
      113             : !> \author Patrick Seewald
      114             : ! **************************************************************************************************
      115      710234 :    SUBROUTINE dbt_copy(tensor_in, tensor_out, order, summation, bounds, move_data, unit_nr)
      116             :       TYPE(dbt_type), INTENT(INOUT), TARGET      :: tensor_in, tensor_out
      117             :       INTEGER, DIMENSION(ndims_tensor(tensor_in)), &
      118             :          INTENT(IN), OPTIONAL                        :: order
      119             :       LOGICAL, INTENT(IN), OPTIONAL                  :: summation, move_data
      120             :       INTEGER, DIMENSION(2, ndims_tensor(tensor_in)), &
      121             :          INTENT(IN), OPTIONAL                        :: bounds
      122             :       INTEGER, INTENT(IN), OPTIONAL                  :: unit_nr
      123             :       INTEGER :: handle
      124             : 
                     :       ! NOTE(review): barriers bracket the timed region — presumably so the
                     :       ! "dbt_total" timer excludes load imbalance between ranks; confirm intent.
      125      355117 :       CALL tensor_in%pgrid%mp_comm_2d%sync()
      126      355117 :       CALL timeset("dbt_total", handle)
      127             : 
      128             :       ! make sure that it is safe to use dbt_copy during a batched contraction
      129      355117 :       CALL dbt_tas_batched_mm_complete(tensor_in%matrix_rep, warn=.TRUE.)
      130      355117 :       CALL dbt_tas_batched_mm_complete(tensor_out%matrix_rep, warn=.TRUE.)
      131             : 
      132      355117 :       CALL dbt_copy_expert(tensor_in, tensor_out, order, summation, bounds, move_data, unit_nr)
      133      355117 :       CALL tensor_in%pgrid%mp_comm_2d%sync()
      134      355117 :       CALL timestop(handle)
      135      355117 :    END SUBROUTINE
     136             : 
      137             : ! **************************************************************************************************
      138             : !> \brief expert routine for copying a tensor. For internal use only.
                     : !> \param tensor_in Source
                     : !> \param tensor_out Target (must be valid)
                     : !> \param order Permutation of target tensor index (RESHAPE convention)
                     : !> \param summation if .TRUE., sum into tensor_out instead of overwriting
                     : !> \param bounds crop tensor data: start and end index per tensor dimension
                     : !> \param move_data if .TRUE., tensor_in may be emptied during the copy
                     : !> \param unit_nr output unit for logging (see prep_output_unit)
      139             : !> \author Patrick Seewald
      140             : ! **************************************************************************************************
      141      356351 :    SUBROUTINE dbt_copy_expert(tensor_in, tensor_out, order, summation, bounds, move_data, unit_nr)
      142             :       TYPE(dbt_type), INTENT(INOUT), TARGET      :: tensor_in, tensor_out
      143             :       INTEGER, DIMENSION(ndims_tensor(tensor_in)), &
      144             :          INTENT(IN), OPTIONAL                        :: order
      145             :       LOGICAL, INTENT(IN), OPTIONAL                  :: summation, move_data
      146             :       INTEGER, DIMENSION(2, ndims_tensor(tensor_in)), &
      147             :          INTENT(IN), OPTIONAL                        :: bounds
      148             :       INTEGER, INTENT(IN), OPTIONAL                  :: unit_nr
      149             : 
                     :       ! in_tmp_1..3/out_tmp_1 either alias the caller's tensors or point to
                     :       ! freshly ALLOCATEd copies; the new_in_*/new_out_1 flags record which,
                     :       ! so only owned temporaries are destroyed/deallocated at the end.
      150             :       TYPE(dbt_type), POINTER                    :: in_tmp_1, in_tmp_2, &
      151             :                                                     in_tmp_3, out_tmp_1
      152             :       INTEGER                                        :: handle, unit_nr_prv
      153      356351 :       INTEGER, DIMENSION(:), ALLOCATABLE             :: map1_in_1, map1_in_2, map2_in_1, map2_in_2
      154             : 
      155             :       CHARACTER(LEN=*), PARAMETER :: routineN = 'dbt_copy'
      156             :       LOGICAL                                        :: dist_compatible_tas, dist_compatible_tensor, &
      157             :                                                         summation_prv, new_in_1, new_in_2, &
      158             :                                                         new_in_3, new_out_1, block_compatible, &
      159             :                                                         move_prv
      160      356351 :       TYPE(array_list)                               :: blk_sizes_in
      161             : 
      162      356351 :       CALL timeset(routineN, handle)
      163             : 
      164      356351 :       CPASSERT(tensor_out%valid)
      165             : 
      166      356351 :       unit_nr_prv = prep_output_unit(unit_nr)
      167             : 
      168      356351 :       IF (PRESENT(move_data)) THEN
      169      276339 :          move_prv = move_data
      170             :       ELSE
      171       80012 :          move_prv = .FALSE.
      172             :       END IF
      173             : 
      174      356351 :       dist_compatible_tas = .FALSE.
      175      356351 :       dist_compatible_tensor = .FALSE.
      176      356351 :       block_compatible = .FALSE.
      177      356351 :       new_in_1 = .FALSE.
      178      356351 :       new_in_2 = .FALSE.
      179      356351 :       new_in_3 = .FALSE.
      180      356351 :       new_out_1 = .FALSE.
      181             : 
      182      356351 :       IF (PRESENT(summation)) THEN
      183       96396 :          summation_prv = summation
      184             :       ELSE
      185             :          summation_prv = .FALSE.
      186             :       END IF
      187             : 
                     :       ! Stage 1: optionally crop the input to the requested bounds. The cropped
                     :       ! copy is a temporary, so subsequent stages may always move its data.
      188      356351 :       IF (PRESENT(bounds)) THEN
      189       39424 :          ALLOCATE (in_tmp_1)
      190        5632 :          CALL dbt_crop(tensor_in, in_tmp_1, bounds=bounds, move_data=move_prv)
      191        5632 :          new_in_1 = .TRUE.
      192        5632 :          move_prv = .TRUE.
      193             :       ELSE
      194             :          in_tmp_1 => tensor_in
      195             :       END IF
      196             : 
                     :       ! Stage 2: block sizes must agree (after permutation by `order`, if given)
      197      356351 :       IF (PRESENT(order)) THEN
      198       90634 :          CALL reorder_arrays(in_tmp_1%blk_sizes, blk_sizes_in, order=order)
      199       90634 :          block_compatible = check_equal(blk_sizes_in, tensor_out%blk_sizes)
      200             :       ELSE
      201      265717 :          block_compatible = check_equal(in_tmp_1%blk_sizes, tensor_out%blk_sizes)
      202             :       END IF
      203             : 
                     :       ! If blockings differ, split both tensors into a common finer blocking;
                     :       ! the result is copied back into tensor_out at the end (dbt_split_copyback).
      204      356351 :       IF (.NOT. block_compatible) THEN
      205      958243 :          ALLOCATE (in_tmp_2, out_tmp_1)
      206             :          CALL dbt_make_compatible_blocks(in_tmp_1, tensor_out, in_tmp_2, out_tmp_1, order=order, &
      207       73711 :                                          nodata2=.NOT. summation_prv, move_data=move_prv)
      208       73711 :          new_in_2 = .TRUE.; new_out_1 = .TRUE.
      209       73711 :          move_prv = .TRUE.
      210             :       ELSE
      211             :          in_tmp_2 => in_tmp_1
      212             :          out_tmp_1 => tensor_out
      213             :       END IF
      214             : 
                     :       ! Stage 3: apply the index permutation (if any)
      215      356351 :       IF (PRESENT(order)) THEN
      216      634438 :          ALLOCATE (in_tmp_3)
      217       90634 :          CALL dbt_permute_index(in_tmp_2, in_tmp_3, order)
      218       90634 :          new_in_3 = .TRUE.
      219             :       ELSE
      220             :          in_tmp_3 => in_tmp_2
      221             :       END IF
      222             : 
                     :       ! Stage 4: compare the tensor-to-matrix (row/column) index mappings of
                     :       ! source and target to pick the cheapest copy path below.
      223     1069053 :       ALLOCATE (map1_in_1(ndims_matrix_row(in_tmp_3)))
      224     1069053 :       ALLOCATE (map1_in_2(ndims_matrix_column(in_tmp_3)))
      225      356351 :       CALL dbt_get_mapping_info(in_tmp_3%nd_index, map1_2d=map1_in_1, map2_2d=map1_in_2)
      226             : 
      227     1069053 :       ALLOCATE (map2_in_1(ndims_matrix_row(out_tmp_1)))
      228     1069053 :       ALLOCATE (map2_in_2(ndims_matrix_column(out_tmp_1)))
      229      356351 :       CALL dbt_get_mapping_info(out_tmp_1%nd_index, map1_2d=map2_in_1, map2_2d=map2_in_2)
      230             : 
      231      356351 :       IF (.NOT. PRESENT(order)) THEN
      232      265717 :          IF (array_eq_i(map1_in_1, map2_in_1) .AND. array_eq_i(map1_in_2, map2_in_2)) THEN
      233      254561 :             dist_compatible_tas = check_equal(in_tmp_3%nd_dist, out_tmp_1%nd_dist)
      234      145308 :          ELSEIF (array_eq_i([map1_in_1, map1_in_2], [map2_in_1, map2_in_2])) THEN
      235       10566 :             dist_compatible_tensor = check_equal(in_tmp_3%nd_dist, out_tmp_1%nd_dist)
      236             :          END IF
      237             :       END IF
      238             : 
                     :       ! Stage 5: three copy paths, fastest first:
                     :       !   (a) identical mapping + distribution -> copy the TAS matrix representation
                     :       !   (b) identical distribution, compatible flat mapping -> block-wise local copy
                     :       !   (c) general case -> dbt_reshape with communication
      239      254561 :       IF (dist_compatible_tas) THEN
      240      205997 :          CALL dbt_tas_copy(out_tmp_1%matrix_rep, in_tmp_3%matrix_rep, summation)
      241      205997 :          IF (move_prv) CALL dbt_clear(in_tmp_3)
      242      150354 :       ELSEIF (dist_compatible_tensor) THEN
      243        2962 :          CALL dbt_copy_nocomm(in_tmp_3, out_tmp_1, summation)
      244        2962 :          IF (move_prv) CALL dbt_clear(in_tmp_3)
      245             :       ELSE
      246      147392 :          CALL dbt_reshape(in_tmp_3, out_tmp_1, summation, move_data=move_prv)
      247             :       END IF
      248             : 
                     :       ! Stage 6: release only the temporaries this routine allocated
      249      356351 :       IF (new_in_1) THEN
      250        5632 :          CALL dbt_destroy(in_tmp_1)
      251        5632 :          DEALLOCATE (in_tmp_1)
      252             :       END IF
      253             : 
      254      356351 :       IF (new_in_2) THEN
      255       73711 :          CALL dbt_destroy(in_tmp_2)
      256       73711 :          DEALLOCATE (in_tmp_2)
      257             :       END IF
      258             : 
      259      356351 :       IF (new_in_3) THEN
      260       90634 :          CALL dbt_destroy(in_tmp_3)
      261       90634 :          DEALLOCATE (in_tmp_3)
      262             :       END IF
      263             : 
      264      356351 :       IF (new_out_1) THEN
      265       73711 :          IF (unit_nr_prv /= 0) THEN
      266           0 :             CALL dbt_write_tensor_dist(out_tmp_1, unit_nr)
      267             :          END IF
      268       73711 :          CALL dbt_split_copyback(out_tmp_1, tensor_out, summation)
      269       73711 :          CALL dbt_destroy(out_tmp_1)
      270       73711 :          DEALLOCATE (out_tmp_1)
      271             :       END IF
      272             : 
      273      356351 :       CALL timestop(handle)
      274             : 
      275      712702 :    END SUBROUTINE
     276             : 
      277             : ! **************************************************************************************************
      278             : !> \brief copy without communication, requires that both tensors have same process grid and distribution
                     : !> \param tensor_in Source tensor
                     : !> \param tensor_out Target tensor (must be valid); cleared first unless summation=.TRUE.
      279             : !> \param summation Whether to sum matrices b = a + b
      280             : !> \author Patrick Seewald
      281             : ! **************************************************************************************************
      282        2962 :    SUBROUTINE dbt_copy_nocomm(tensor_in, tensor_out, summation)
      283             :       TYPE(dbt_type), INTENT(INOUT) :: tensor_in
      284             :       TYPE(dbt_type), INTENT(INOUT) :: tensor_out
      285             :       LOGICAL, INTENT(IN), OPTIONAL                      :: summation
      286             :       TYPE(dbt_iterator_type) :: iter
      287        2962 :       INTEGER, DIMENSION(ndims_tensor(tensor_in))  :: ind_nd
      288        2962 :       TYPE(block_nd) :: blk_data
      289             :       LOGICAL :: found
      290             : 
      291             :       CHARACTER(LEN=*), PARAMETER :: routineN = 'dbt_copy_nocomm'
      292             :       INTEGER :: handle
      293             : 
      294        2962 :       CALL timeset(routineN, handle)
      295        2962 :       CPASSERT(tensor_out%valid)
      296             : 
                     :       ! overwrite semantics unless the caller explicitly asks for summation
      297        2962 :       IF (PRESENT(summation)) THEN
      298          24 :          IF (.NOT. summation) CALL dbt_clear(tensor_out)
      299             :       ELSE
      300        2938 :          CALL dbt_clear(tensor_out)
      301             :       END IF
      302             : 
                     :       ! reserve all target blocks up front, before the threaded copy loop
      303        2962 :       CALL dbt_reserve_blocks(tensor_in, tensor_out)
      304             : 
      305             : !$OMP PARALLEL DEFAULT(NONE) SHARED(tensor_in,tensor_out,summation) &
      306        2962 : !$OMP PRIVATE(iter,ind_nd,blk_data,found)
      307             :       CALL dbt_iterator_start(iter, tensor_in)
      308             :       DO WHILE (dbt_iterator_blocks_left(iter))
      309             :          CALL dbt_iterator_next_block(iter, ind_nd)
      310             :          CALL dbt_get_block(tensor_in, ind_nd, blk_data, found)
      311             :          CPASSERT(found)
      312             :          CALL dbt_put_block(tensor_out, ind_nd, blk_data, summation=summation)
      313             :          CALL destroy_block(blk_data)
      314             :       END DO
      315             :       CALL dbt_iterator_stop(iter)
      316             : !$OMP END PARALLEL
      317             : 
      318        2962 :       CALL timestop(handle)
      319        5924 :    END SUBROUTINE
     320             : 
      321             : ! **************************************************************************************************
      322             : !> \brief copy matrix to tensor.
                     : !> \param matrix_in Source DBCSR matrix (may carry symmetry; desymmetrized internally)
                     : !> \param tensor_out Target rank-2 tensor (must be valid)
      323             : !> \param summation tensor_out = tensor_out + matrix_in
      324             : !> \author Patrick Seewald
      325             : ! **************************************************************************************************
      326       56972 :    SUBROUTINE dbt_copy_matrix_to_tensor(matrix_in, tensor_out, summation)
      327             :       TYPE(dbcsr_type), TARGET, INTENT(IN)               :: matrix_in
      328             :       TYPE(dbt_type), INTENT(INOUT)             :: tensor_out
      329             :       LOGICAL, INTENT(IN), OPTIONAL                      :: summation
      330             :       TYPE(dbcsr_type), POINTER                          :: matrix_in_desym
      331             : 
      332             :       INTEGER, DIMENSION(2)                              :: ind_2d
      333       56972 :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :)    :: block_arr
      334       56972 :       REAL(KIND=dp), DIMENSION(:, :), POINTER        :: block
      335             :       TYPE(dbcsr_iterator_type)                          :: iter
      336             :       LOGICAL                                            :: tr
      337             : 
      338             :       INTEGER                                            :: handle
      339             :       CHARACTER(LEN=*), PARAMETER :: routineN = 'dbt_copy_matrix_to_tensor'
      340             : 
      341       56972 :       CALL timeset(routineN, handle)
      342       56972 :       CPASSERT(tensor_out%valid)
      343             : 
      344       56972 :       NULLIFY (block)
      345             : 
                     :       ! a symmetric DBCSR matrix is expanded to an explicit non-symmetric copy,
                     :       ! since the tensor stores all blocks; released again at the end
      346       56972 :       IF (dbcsr_has_symmetry(matrix_in)) THEN
      347        4530 :          ALLOCATE (matrix_in_desym)
      348        4530 :          CALL dbcsr_desymmetrize(matrix_in, matrix_in_desym)
      349             :       ELSE
      350             :          matrix_in_desym => matrix_in
      351             :       END IF
      352             : 
                     :       ! overwrite semantics unless the caller explicitly asks for summation
      353       56972 :       IF (PRESENT(summation)) THEN
      354           0 :          IF (.NOT. summation) CALL dbt_clear(tensor_out)
      355             :       ELSE
      356       56972 :          CALL dbt_clear(tensor_out)
      357             :       END IF
      358             : 
                     :       ! reserve all target blocks up front, before the threaded copy loop
      359       56972 :       CALL dbt_reserve_blocks(matrix_in_desym, tensor_out)
      360             : 
      361             : !$OMP PARALLEL DEFAULT(NONE) SHARED(matrix_in_desym,tensor_out,summation) &
      362       56972 : !$OMP PRIVATE(iter,ind_2d,block,tr,block_arr)
      363             :       CALL dbcsr_iterator_start(iter, matrix_in_desym)
      364             :       DO WHILE (dbcsr_iterator_blocks_left(iter))
      365             :          CALL dbcsr_iterator_next_block(iter, ind_2d(1), ind_2d(2), block, tr)
      366             :          CALL allocate_any(block_arr, source=block)
      367             :          CALL dbt_put_block(tensor_out, ind_2d, SHAPE(block_arr), block_arr, summation=summation)
      368             :          DEALLOCATE (block_arr)
      369             :       END DO
      370             :       CALL dbcsr_iterator_stop(iter)
      371             : !$OMP END PARALLEL
      372             : 
      373       56972 :       IF (dbcsr_has_symmetry(matrix_in)) THEN
      374        4530 :          CALL dbcsr_release(matrix_in_desym)
      375        4530 :          DEALLOCATE (matrix_in_desym)
      376             :       END IF
      377             : 
      378       56972 :       CALL timestop(handle)
      379             : 
      380      113944 :    END SUBROUTINE
     381             : 
      382             : ! **************************************************************************************************
      383             : !> \brief copy tensor to matrix
                     : !> \param tensor_in Source rank-2 tensor
                     : !> \param matrix_out Target DBCSR matrix (may carry symmetry); cleared first unless summation=.TRUE.
      384             : !> \param summation matrix_out = matrix_out + tensor_in
      385             : !> \author Patrick Seewald
      386             : ! **************************************************************************************************
      387       35282 :    SUBROUTINE dbt_copy_tensor_to_matrix(tensor_in, matrix_out, summation)
      388             :       TYPE(dbt_type), INTENT(INOUT)      :: tensor_in
      389             :       TYPE(dbcsr_type), INTENT(INOUT)             :: matrix_out
      390             :       LOGICAL, INTENT(IN), OPTIONAL          :: summation
      391             :       TYPE(dbt_iterator_type)            :: iter
      392             :       INTEGER                                :: handle
      393             :       INTEGER, DIMENSION(2)                  :: ind_2d
      394       35282 :       REAL(KIND=dp), DIMENSION(:, :), ALLOCATABLE :: block
      395             :       CHARACTER(LEN=*), PARAMETER :: routineN = 'dbt_copy_tensor_to_matrix'
      396             :       LOGICAL :: found
      397             : 
      398       35282 :       CALL timeset(routineN, handle)
      399             : 
                     :       ! overwrite semantics unless the caller explicitly asks for summation
      400       35282 :       IF (PRESENT(summation)) THEN
      401        5840 :          IF (.NOT. summation) CALL dbcsr_clear(matrix_out)
      402             :       ELSE
      403       29442 :          CALL dbcsr_clear(matrix_out)
      404             :       END IF
      405             : 
      406       35282 :       CALL dbt_reserve_blocks(tensor_in, matrix_out)
      407             : 
      408             : !$OMP PARALLEL DEFAULT(NONE) SHARED(tensor_in,matrix_out,summation) &
      409       35282 : !$OMP PRIVATE(iter,ind_2d,block,found)
      410             :       CALL dbt_iterator_start(iter, tensor_in)
      411             :       DO WHILE (dbt_iterator_blocks_left(iter))
      412             :          CALL dbt_iterator_next_block(iter, ind_2d)
                     :          ! for a symmetric matrix, skip blocks filtered by checker_tr — presumably
                     :          ! the redundant copies of symmetry-equivalent blocks; convention defined
                     :          ! in dbt_block (confirm there)
      413             :          IF (dbcsr_has_symmetry(matrix_out) .AND. checker_tr(ind_2d(1), ind_2d(2))) CYCLE
      414             : 
      415             :          CALL dbt_get_block(tensor_in, ind_2d, block, found)
      416             :          CPASSERT(found)
      417             : 
                     :          ! symmetric storage: a block with row > col is transposed into its
                     :          ! (col, row) counterpart so only one triangle is written
      418             :          IF (dbcsr_has_symmetry(matrix_out) .AND. ind_2d(1) > ind_2d(2)) THEN
      419             :             CALL dbcsr_put_block(matrix_out, ind_2d(2), ind_2d(1), TRANSPOSE(block), summation=summation)
      420             :          ELSE
      421             :             CALL dbcsr_put_block(matrix_out, ind_2d(1), ind_2d(2), block, summation=summation)
      422             :          END IF
      423             :          DEALLOCATE (block)
      424             :       END DO
      425             :       CALL dbt_iterator_stop(iter)
      426             : !$OMP END PARALLEL
      427             : 
      428       35282 :       CALL timestop(handle)
      429             : 
      430       70564 :    END SUBROUTINE
     431             : 
     432             : ! **************************************************************************************************
     433             : !> \brief Contract tensors by multiplying matrix representations.
     434             : !>        tensor_3(map_1, map_2) := alpha * tensor_1(notcontract_1, contract_1)
     435             : !>        * tensor_2(contract_2, notcontract_2)
     436             : !>        + beta * tensor_3(map_1, map_2)
     437             : !>
     438             : !> \note
     439             : !>      note 1: block sizes of the corresponding indices need to be the same in all tensors.
     440             : !>
     441             : !>      note 2: for best performance the tensors should have been created in matrix layouts
     442             : !>      compatible with the contraction, e.g. tensor_1 should have been created with either
     443             : !>      map1_2d == contract_1 and map2_2d == notcontract_1 or map1_2d == notcontract_1 and
     444             : !>      map2_2d == contract_1 (the same with tensor_2 and contract_2 / notcontract_2 and with
     445             : !>      tensor_3 and map_1 / map_2).
     446             : !>      Furthermore the two largest tensors involved in the contraction should map both to either
     447             : !>      tall or short matrices: the largest matrix dimension should be "on the same side"
     448             : !>      and should have identical distribution (which is always the case if the distributions were
     449             : !>      obtained with dbt_default_distvec).
     450             : !>
     451             : !>      note 3: if the same tensor occurs in multiple contractions, a different tensor object should
     452             : !>      be created for each contraction and the data should be copied between the tensors by use of
     453             : !>      dbt_copy. If the same tensor object is used in multiple contractions,
     454             : !>       matrix layouts are not compatible for all contractions (see note 2).
     455             : !>
     456             : !>      note 4: automatic optimizations are enabled by using the feature of batched contraction, see
     457             : !>      dbt_batched_contract_init, dbt_batched_contract_finalize.
     458             : !>      The arguments bounds_1, bounds_2, bounds_3 give the index ranges of the batches.
     459             : !>
     460             : !> \param tensor_1 first tensor (in)
     461             : !> \param tensor_2 second tensor (in)
     462             : !> \param contract_1 indices of tensor_1 to contract
     463             : !> \param contract_2 indices of tensor_2 to contract (1:1 with contract_1)
     464             : !> \param map_1 which indices of tensor_3 map to non-contracted indices of tensor_1 (1:1 with notcontract_1)
     465             : !> \param map_2 which indices of tensor_3 map to non-contracted indices of tensor_2 (1:1 with notcontract_2)
     466             : !> \param notcontract_1 indices of tensor_1 not to contract
     467             : !> \param notcontract_2 indices of tensor_2 not to contract
     468             : !> \param tensor_3 contracted tensor (out)
     469             : !> \param bounds_1 bounds corresponding to contract_1 AKA contract_2:
     470             : !>                 start and end index of an index range over which to contract.
     471             : !>                 For use in batched contraction.
     472             : !> \param bounds_2 bounds corresponding to notcontract_1: start and end index of an index range.
     473             : !>                 For use in batched contraction.
     474             : !> \param bounds_3 bounds corresponding to notcontract_2: start and end index of an index range.
     475             : !>                 For use in batched contraction.
     476             : !> \param optimize_dist Whether distribution should be optimized internally. In the current
     477             : !>                      implementation this guarantees optimal parameters only for dense matrices.
     478             : !> \param pgrid_opt_1 Optionally return optimal process grid for tensor_1.
     479             : !>                    This can be used to choose optimal process grids for subsequent tensor
     480             : !>                    contractions with tensors of similar shape and sparsity. Under some conditions,
     481             : !>                    pgrid_opt_1 can not be returned, in this case the pointer is not associated.
     482             : !> \param pgrid_opt_2 Optionally return optimal process grid for tensor_2.
     483             : !> \param pgrid_opt_3 Optionally return optimal process grid for tensor_3.
     484             : !> \param filter_eps As in DBM mm
     485             : !> \param flop As in DBM mm
     486             : !> \param move_data memory optimization: transfer data such that tensor_1 and tensor_2 are empty on return
     487             : !> \param retain_sparsity enforce the sparsity pattern of the existing tensor_3; default is no
     488             : !> \param unit_nr output unit for logging
     489             : !>                       set it to -1 on ranks that should not write (and any valid unit number on
     490             : !>                       ranks that should write output) if 0 on ALL ranks, no output is written
     491             : !> \param log_verbose verbose logging (for testing only)
     492             : !> \author Patrick Seewald
     493             : ! **************************************************************************************************
     494      222472 :    SUBROUTINE dbt_contract(alpha, tensor_1, tensor_2, beta, tensor_3, &
     495      111236 :                            contract_1, notcontract_1, &
     496      111236 :                            contract_2, notcontract_2, &
     497      111236 :                            map_1, map_2, &
     498      114954 :                            bounds_1, bounds_2, bounds_3, &
     499             :                            optimize_dist, pgrid_opt_1, pgrid_opt_2, pgrid_opt_3, &
     500             :                            filter_eps, flop, move_data, retain_sparsity, unit_nr, log_verbose)
     501             :       REAL(dp), INTENT(IN)            :: alpha
     502             :       TYPE(dbt_type), INTENT(INOUT), TARGET      :: tensor_1
     503             :       TYPE(dbt_type), INTENT(INOUT), TARGET      :: tensor_2
     504             :       REAL(dp), INTENT(IN)            :: beta
     505             :       INTEGER, DIMENSION(:), INTENT(IN)              :: contract_1
     506             :       INTEGER, DIMENSION(:), INTENT(IN)              :: contract_2
     507             :       INTEGER, DIMENSION(:), INTENT(IN)              :: map_1
     508             :       INTEGER, DIMENSION(:), INTENT(IN)              :: map_2
     509             :       INTEGER, DIMENSION(:), INTENT(IN)              :: notcontract_1
     510             :       INTEGER, DIMENSION(:), INTENT(IN)              :: notcontract_2
     511             :       TYPE(dbt_type), INTENT(INOUT), TARGET      :: tensor_3
     512             :       INTEGER, DIMENSION(2, SIZE(contract_1)), &
     513             :          INTENT(IN), OPTIONAL                        :: bounds_1
     514             :       INTEGER, DIMENSION(2, SIZE(notcontract_1)), &
     515             :          INTENT(IN), OPTIONAL                        :: bounds_2
     516             :       INTEGER, DIMENSION(2, SIZE(notcontract_2)), &
     517             :          INTENT(IN), OPTIONAL                        :: bounds_3
     518             :       LOGICAL, INTENT(IN), OPTIONAL                  :: optimize_dist
     519             :       TYPE(dbt_pgrid_type), INTENT(OUT), &
     520             :          POINTER, OPTIONAL                           :: pgrid_opt_1
     521             :       TYPE(dbt_pgrid_type), INTENT(OUT), &
     522             :          POINTER, OPTIONAL                           :: pgrid_opt_2
     523             :       TYPE(dbt_pgrid_type), INTENT(OUT), &
     524             :          POINTER, OPTIONAL                           :: pgrid_opt_3
     525             :       REAL(KIND=dp), INTENT(IN), OPTIONAL        :: filter_eps
     526             :       INTEGER(KIND=int_8), INTENT(OUT), OPTIONAL     :: flop
     527             :       LOGICAL, INTENT(IN), OPTIONAL                  :: move_data
     528             :       LOGICAL, INTENT(IN), OPTIONAL                  :: retain_sparsity
     529             :       INTEGER, OPTIONAL, INTENT(IN)                  :: unit_nr
     530             :       LOGICAL, INTENT(IN), OPTIONAL                  :: log_verbose
     531             : 
     532             :       INTEGER                     :: handle
     533             : 
     534      111236 :       CALL tensor_1%pgrid%mp_comm_2d%sync()
     535      111236 :       CALL timeset("dbt_total", handle)
     536             :       CALL dbt_contract_expert(alpha, tensor_1, tensor_2, beta, tensor_3, &
     537             :                                contract_1, notcontract_1, &
     538             :                                contract_2, notcontract_2, &
     539             :                                map_1, map_2, &
     540             :                                bounds_1=bounds_1, &
     541             :                                bounds_2=bounds_2, &
     542             :                                bounds_3=bounds_3, &
     543             :                                optimize_dist=optimize_dist, &
     544             :                                pgrid_opt_1=pgrid_opt_1, &
     545             :                                pgrid_opt_2=pgrid_opt_2, &
     546             :                                pgrid_opt_3=pgrid_opt_3, &
     547             :                                filter_eps=filter_eps, &
     548             :                                flop=flop, &
     549             :                                move_data=move_data, &
     550             :                                retain_sparsity=retain_sparsity, &
     551             :                                unit_nr=unit_nr, &
     552      111236 :                                log_verbose=log_verbose)
     553      111236 :       CALL tensor_1%pgrid%mp_comm_2d%sync()
     554      111236 :       CALL timestop(handle)
     555             : 
     556      181396 :    END SUBROUTINE
     557             : 
     558             : ! **************************************************************************************************
     559             : !> \brief expert routine for tensor contraction. For internal use only.
     560             : !> \param nblks_local number of local blocks on this MPI rank
     561             : !> \author Patrick Seewald
     562             : ! **************************************************************************************************
     563      111236 :    SUBROUTINE dbt_contract_expert(alpha, tensor_1, tensor_2, beta, tensor_3, &
     564      111236 :                                   contract_1, notcontract_1, &
     565      111236 :                                   contract_2, notcontract_2, &
     566      111236 :                                   map_1, map_2, &
     567      111236 :                                   bounds_1, bounds_2, bounds_3, &
     568             :                                   optimize_dist, pgrid_opt_1, pgrid_opt_2, pgrid_opt_3, &
     569             :                                   filter_eps, flop, move_data, retain_sparsity, &
     570             :                                   nblks_local, unit_nr, log_verbose)
     571             :       REAL(dp), INTENT(IN)            :: alpha
     572             :       TYPE(dbt_type), INTENT(INOUT), TARGET      :: tensor_1
     573             :       TYPE(dbt_type), INTENT(INOUT), TARGET      :: tensor_2
     574             :       REAL(dp), INTENT(IN)            :: beta
     575             :       INTEGER, DIMENSION(:), INTENT(IN)              :: contract_1
     576             :       INTEGER, DIMENSION(:), INTENT(IN)              :: contract_2
     577             :       INTEGER, DIMENSION(:), INTENT(IN)              :: map_1
     578             :       INTEGER, DIMENSION(:), INTENT(IN)              :: map_2
     579             :       INTEGER, DIMENSION(:), INTENT(IN)              :: notcontract_1
     580             :       INTEGER, DIMENSION(:), INTENT(IN)              :: notcontract_2
     581             :       TYPE(dbt_type), INTENT(INOUT), TARGET      :: tensor_3
     582             :       INTEGER, DIMENSION(2, SIZE(contract_1)), &
     583             :          INTENT(IN), OPTIONAL                        :: bounds_1
     584             :       INTEGER, DIMENSION(2, SIZE(notcontract_1)), &
     585             :          INTENT(IN), OPTIONAL                        :: bounds_2
     586             :       INTEGER, DIMENSION(2, SIZE(notcontract_2)), &
     587             :          INTENT(IN), OPTIONAL                        :: bounds_3
     588             :       LOGICAL, INTENT(IN), OPTIONAL                  :: optimize_dist
     589             :       TYPE(dbt_pgrid_type), INTENT(OUT), &
     590             :          POINTER, OPTIONAL                           :: pgrid_opt_1
     591             :       TYPE(dbt_pgrid_type), INTENT(OUT), &
     592             :          POINTER, OPTIONAL                           :: pgrid_opt_2
     593             :       TYPE(dbt_pgrid_type), INTENT(OUT), &
     594             :          POINTER, OPTIONAL                           :: pgrid_opt_3
     595             :       REAL(KIND=dp), INTENT(IN), OPTIONAL        :: filter_eps
     596             :       INTEGER(KIND=int_8), INTENT(OUT), OPTIONAL     :: flop
     597             :       LOGICAL, INTENT(IN), OPTIONAL                  :: move_data
     598             :       LOGICAL, INTENT(IN), OPTIONAL                  :: retain_sparsity
     599             :       INTEGER, INTENT(OUT), OPTIONAL                 :: nblks_local
     600             :       INTEGER, OPTIONAL, INTENT(IN)                  :: unit_nr
     601             :       LOGICAL, INTENT(IN), OPTIONAL                  :: log_verbose
     602             : 
     603             :       TYPE(dbt_type), POINTER                    :: tensor_contr_1, tensor_contr_2, tensor_contr_3
     604     2113484 :       TYPE(dbt_type), TARGET                     :: tensor_algn_1, tensor_algn_2, tensor_algn_3
     605             :       TYPE(dbt_type), POINTER                    :: tensor_crop_1, tensor_crop_2
     606             :       TYPE(dbt_type), POINTER                    :: tensor_small, tensor_large
     607             : 
     608             :       LOGICAL                                        :: assert_stmt, tensors_remapped
     609             :       INTEGER                                        :: max_mm_dim, max_tensor, &
     610             :                                                         unit_nr_prv, ref_tensor, handle
     611      111236 :       TYPE(mp_cart_type) :: mp_comm_opt
     612      222472 :       INTEGER, DIMENSION(SIZE(contract_1))           :: contract_1_mod
     613      222472 :       INTEGER, DIMENSION(SIZE(notcontract_1))        :: notcontract_1_mod
     614      222472 :       INTEGER, DIMENSION(SIZE(contract_2))           :: contract_2_mod
     615      222472 :       INTEGER, DIMENSION(SIZE(notcontract_2))        :: notcontract_2_mod
     616      222472 :       INTEGER, DIMENSION(SIZE(map_1))                :: map_1_mod
     617      222472 :       INTEGER, DIMENSION(SIZE(map_2))                :: map_2_mod
     618             :       LOGICAL                                        :: trans_1, trans_2, trans_3
     619             :       LOGICAL                                        :: new_1, new_2, new_3, move_data_1, move_data_2
     620             :       INTEGER                                        :: ndims1, ndims2, ndims3
     621             :       INTEGER                                        :: occ_1, occ_2
     622      111236 :       INTEGER, DIMENSION(:), ALLOCATABLE             :: dims1, dims2, dims3
     623             : 
     624             :       CHARACTER(LEN=*), PARAMETER :: routineN = 'dbt_contract'
     625      111236 :       CHARACTER(LEN=1), DIMENSION(:), ALLOCATABLE    :: indchar1, indchar2, indchar3, indchar1_mod, &
     626      111236 :                                                         indchar2_mod, indchar3_mod
     627             :       CHARACTER(LEN=1), DIMENSION(15), SAVE :: alph = &
     628             :                                                ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o']
     629      222472 :       INTEGER, DIMENSION(2, ndims_tensor(tensor_1)) :: bounds_t1
     630      222472 :       INTEGER, DIMENSION(2, ndims_tensor(tensor_2)) :: bounds_t2
     631             :       LOGICAL                                        :: do_crop_1, do_crop_2, do_write_3, nodata_3, do_batched, pgrid_changed, &
     632             :                                                         pgrid_changed_any, do_change_pgrid(2)
     633     1446068 :       TYPE(dbt_tas_split_info)                     :: split_opt, split, split_opt_avg
     634             :       INTEGER, DIMENSION(2) :: pdims_2d_opt, pdims_sub, pdims_sub_opt
     635             :       REAL(dp) :: pdim_ratio, pdim_ratio_opt
     636             : 
     637      111236 :       NULLIFY (tensor_contr_1, tensor_contr_2, tensor_contr_3, tensor_crop_1, tensor_crop_2, &
     638      111236 :                tensor_small)
     639             : 
     640      111236 :       CALL timeset(routineN, handle)
     641             : 
     642      111236 :       CPASSERT(tensor_1%valid)
     643      111236 :       CPASSERT(tensor_2%valid)
     644      111236 :       CPASSERT(tensor_3%valid)
     645             : 
     646      111236 :       assert_stmt = SIZE(contract_1) .EQ. SIZE(contract_2)
     647      111236 :       CPASSERT(assert_stmt)
     648             : 
     649      111236 :       assert_stmt = SIZE(map_1) .EQ. SIZE(notcontract_1)
     650      111236 :       CPASSERT(assert_stmt)
     651             : 
     652      111236 :       assert_stmt = SIZE(map_2) .EQ. SIZE(notcontract_2)
     653      111236 :       CPASSERT(assert_stmt)
     654             : 
     655      111236 :       assert_stmt = SIZE(notcontract_1) + SIZE(contract_1) .EQ. ndims_tensor(tensor_1)
     656      111236 :       CPASSERT(assert_stmt)
     657             : 
     658      111236 :       assert_stmt = SIZE(notcontract_2) + SIZE(contract_2) .EQ. ndims_tensor(tensor_2)
     659      111236 :       CPASSERT(assert_stmt)
     660             : 
     661      111236 :       assert_stmt = SIZE(map_1) + SIZE(map_2) .EQ. ndims_tensor(tensor_3)
     662      111236 :       CPASSERT(assert_stmt)
     663             : 
     664      111236 :       unit_nr_prv = prep_output_unit(unit_nr)
     665             : 
     666      111236 :       IF (PRESENT(flop)) flop = 0
     667      111236 :       IF (PRESENT(nblks_local)) nblks_local = 0
     668             : 
     669      111236 :       IF (PRESENT(move_data)) THEN
     670       30817 :          move_data_1 = move_data
     671       30817 :          move_data_2 = move_data
     672             :       ELSE
     673       80419 :          move_data_1 = .FALSE.
     674       80419 :          move_data_2 = .FALSE.
     675             :       END IF
     676             : 
     677      111236 :       nodata_3 = .TRUE.
     678      111236 :       IF (PRESENT(retain_sparsity)) THEN
     679        4762 :          IF (retain_sparsity) nodata_3 = .FALSE.
     680             :       END IF
     681             : 
     682             :       CALL dbt_map_bounds_to_tensors(tensor_1, tensor_2, &
     683             :                                      contract_1, notcontract_1, &
     684             :                                      contract_2, notcontract_2, &
     685             :                                      bounds_t1, bounds_t2, &
     686             :                                      bounds_1=bounds_1, bounds_2=bounds_2, bounds_3=bounds_3, &
     687      111236 :                                      do_crop_1=do_crop_1, do_crop_2=do_crop_2)
     688             : 
     689      111236 :       IF (do_crop_1) THEN
     690      476588 :          ALLOCATE (tensor_crop_1)
     691       68084 :          CALL dbt_crop(tensor_1, tensor_crop_1, bounds_t1, move_data=move_data_1)
     692       68084 :          move_data_1 = .TRUE.
     693             :       ELSE
     694             :          tensor_crop_1 => tensor_1
     695             :       END IF
     696             : 
     697      111236 :       IF (do_crop_2) THEN
     698      461986 :          ALLOCATE (tensor_crop_2)
     699       65998 :          CALL dbt_crop(tensor_2, tensor_crop_2, bounds_t2, move_data=move_data_2)
     700       65998 :          move_data_2 = .TRUE.
     701             :       ELSE
     702             :          tensor_crop_2 => tensor_2
     703             :       END IF
     704             : 
     705             :       ! shortcut for empty tensors
     706             :       ! this is needed to avoid unnecessary work in case user contracts different portions of a
     707             :       ! tensor consecutively to save memory
     708             :       ASSOCIATE (mp_comm => tensor_crop_1%pgrid%mp_comm_2d)
     709      111236 :          occ_1 = dbt_get_num_blocks(tensor_crop_1)
     710      111236 :          CALL mp_comm%max(occ_1)
     711      111236 :          occ_2 = dbt_get_num_blocks(tensor_crop_2)
     712      111236 :          CALL mp_comm%max(occ_2)
     713             :       END ASSOCIATE
     714             : 
     715      111236 :       IF (occ_1 == 0 .OR. occ_2 == 0) THEN
     716        7447 :          CALL dbt_scale(tensor_3, beta)
     717        7447 :          IF (do_crop_1) THEN
     718        2778 :             CALL dbt_destroy(tensor_crop_1)
     719        2778 :             DEALLOCATE (tensor_crop_1)
     720             :          END IF
     721        7447 :          IF (do_crop_2) THEN
     722        2850 :             CALL dbt_destroy(tensor_crop_2)
     723        2850 :             DEALLOCATE (tensor_crop_2)
     724             :          END IF
     725             : 
     726        7447 :          CALL timestop(handle)
     727        7447 :          RETURN
     728             :       END IF
     729             : 
     730      103789 :       IF (unit_nr_prv /= 0) THEN
     731       46146 :          IF (unit_nr_prv > 0) THEN
     732          10 :             WRITE (unit_nr_prv, '(A)') repeat("-", 80)
     733          10 :             WRITE (unit_nr_prv, '(A,1X,A,1X,A,1X,A,1X,A,1X,A)') "DBT TENSOR CONTRACTION:", &
     734          20 :                TRIM(tensor_crop_1%name), 'x', TRIM(tensor_crop_2%name), '=', TRIM(tensor_3%name)
     735          10 :             WRITE (unit_nr_prv, '(A)') repeat("-", 80)
     736             :          END IF
     737       46146 :          CALL dbt_write_tensor_info(tensor_crop_1, unit_nr_prv, full_info=log_verbose)
     738       46146 :          CALL dbt_write_tensor_dist(tensor_crop_1, unit_nr_prv)
     739       46146 :          CALL dbt_write_tensor_info(tensor_crop_2, unit_nr_prv, full_info=log_verbose)
     740       46146 :          CALL dbt_write_tensor_dist(tensor_crop_2, unit_nr_prv)
     741             :       END IF
     742             : 
     743             :       ! align tensor index with data, tensor data is not modified
     744      103789 :       ndims1 = ndims_tensor(tensor_crop_1)
     745      103789 :       ndims2 = ndims_tensor(tensor_crop_2)
     746      103789 :       ndims3 = ndims_tensor(tensor_3)
     747      415156 :       ALLOCATE (indchar1(ndims1), indchar1_mod(ndims1))
     748      415156 :       ALLOCATE (indchar2(ndims2), indchar2_mod(ndims2))
     749      415156 :       ALLOCATE (indchar3(ndims3), indchar3_mod(ndims3))
     750             : 
     751             :       ! labeling tensor index with letters
     752             : 
     753      884995 :       indchar1([notcontract_1, contract_1]) = alph(1:ndims1) ! arb. choice
     754      250775 :       indchar2(notcontract_2) = alph(ndims1 + 1:ndims1 + SIZE(notcontract_2)) ! arb. choice
     755      244164 :       indchar2(contract_2) = indchar1(contract_1)
     756      223816 :       indchar3(map_1) = indchar1(notcontract_1)
     757      250775 :       indchar3(map_2) = indchar2(notcontract_2)
     758             : 
     759      103789 :       IF (unit_nr_prv /= 0) CALL dbt_print_contraction_index(tensor_crop_1, indchar1, &
     760             :                                                              tensor_crop_2, indchar2, &
     761       46146 :                                                              tensor_3, indchar3, unit_nr_prv)
     762      103789 :       IF (unit_nr_prv > 0) THEN
     763          10 :          WRITE (unit_nr_prv, '(T2,A)') "aligning tensor index with data"
     764             :       END IF
     765             : 
     766             :       CALL align_tensor(tensor_crop_1, contract_1, notcontract_1, &
     767      103789 :                         tensor_algn_1, contract_1_mod, notcontract_1_mod, indchar1, indchar1_mod)
     768             : 
     769             :       CALL align_tensor(tensor_crop_2, contract_2, notcontract_2, &
     770      103789 :                         tensor_algn_2, contract_2_mod, notcontract_2_mod, indchar2, indchar2_mod)
     771             : 
     772             :       CALL align_tensor(tensor_3, map_1, map_2, &
     773      103789 :                         tensor_algn_3, map_1_mod, map_2_mod, indchar3, indchar3_mod)
     774             : 
     775      103789 :       IF (unit_nr_prv /= 0) CALL dbt_print_contraction_index(tensor_algn_1, indchar1_mod, &
     776             :                                                              tensor_algn_2, indchar2_mod, &
     777       46146 :                                                              tensor_algn_3, indchar3_mod, unit_nr_prv)
     778             : 
     779      311367 :       ALLOCATE (dims1(ndims1))
     780      311367 :       ALLOCATE (dims2(ndims2))
     781      311367 :       ALLOCATE (dims3(ndims3))
     782             : 
     783             :       ! ideally we should consider block sizes and occupancy to measure tensor sizes but current solution should work for most
     784             :       ! cases and is more elegant. Note that we can not easily consider occupancy since it is unknown for result tensor
     785      103789 :       CALL blk_dims_tensor(tensor_crop_1, dims1)
     786      103789 :       CALL blk_dims_tensor(tensor_crop_2, dims2)
     787      103789 :       CALL blk_dims_tensor(tensor_3, dims3)
     788             : 
     789             :       max_mm_dim = MAXLOC([PRODUCT(INT(dims1(notcontract_1), int_8)), &
     790             :                            PRODUCT(INT(dims1(contract_1), int_8)), &
     791      822544 :                            PRODUCT(INT(dims2(notcontract_2), int_8))], DIM=1)
     792     1229932 :       max_tensor = MAXLOC([PRODUCT(INT(dims1, int_8)), PRODUCT(INT(dims2, int_8)), PRODUCT(INT(dims3, int_8))], DIM=1)
     793       24126 :       SELECT CASE (max_mm_dim)
     794             :       CASE (1)
     795       24126 :          IF (unit_nr_prv > 0) THEN
     796           3 :             WRITE (unit_nr_prv, '(T2,A)') "large tensors: 1, 3; small tensor: 2"
     797           3 :             WRITE (unit_nr_prv, '(T2,A)') "sorting contraction indices"
     798             :          END IF
     799       24126 :          CALL index_linked_sort(contract_1_mod, contract_2_mod)
     800       24126 :          CALL index_linked_sort(map_2_mod, notcontract_2_mod)
     801       23698 :          SELECT CASE (max_tensor)
     802             :          CASE (1)
     803       23698 :             CALL index_linked_sort(notcontract_1_mod, map_1_mod)
     804             :          CASE (3)
     805         428 :             CALL index_linked_sort(map_1_mod, notcontract_1_mod)
     806             :          CASE DEFAULT
     807       24126 :             CPABORT("should not happen")
     808             :          END SELECT
     809             : 
     810             :          CALL reshape_mm_compatible(tensor_algn_1, tensor_algn_3, tensor_contr_1, tensor_contr_3, &
     811             :                                     contract_1_mod, notcontract_1_mod, map_2_mod, map_1_mod, &
     812             :                                     trans_1, trans_3, new_1, new_3, ref_tensor, nodata2=nodata_3, optimize_dist=optimize_dist, &
     813       24126 :                                     move_data_1=move_data_1, unit_nr=unit_nr_prv)
     814             : 
     815             :          CALL reshape_mm_small(tensor_algn_2, contract_2_mod, notcontract_2_mod, tensor_contr_2, trans_2, &
     816       24126 :                                new_2, move_data=move_data_2, unit_nr=unit_nr_prv)
     817             : 
     818       23698 :          SELECT CASE (ref_tensor)
     819             :          CASE (1)
     820       23698 :             tensor_large => tensor_contr_1
     821             :          CASE (2)
     822       24126 :             tensor_large => tensor_contr_3
     823             :          END SELECT
     824       24126 :          tensor_small => tensor_contr_2
     825             : 
     826             :       CASE (2)
     827       36988 :          IF (unit_nr_prv > 0) THEN
     828           5 :             WRITE (unit_nr_prv, '(T2,A)') "large tensors: 1, 2; small tensor: 3"
     829           5 :             WRITE (unit_nr_prv, '(T2,A)') "sorting contraction indices"
     830             :          END IF
     831             : 
     832       36988 :          CALL index_linked_sort(notcontract_1_mod, map_1_mod)
     833       36988 :          CALL index_linked_sort(notcontract_2_mod, map_2_mod)
     834       36238 :          SELECT CASE (max_tensor)
     835             :          CASE (1)
     836       36238 :             CALL index_linked_sort(contract_1_mod, contract_2_mod)
     837             :          CASE (2)
     838         750 :             CALL index_linked_sort(contract_2_mod, contract_1_mod)
     839             :          CASE DEFAULT
     840       36988 :             CPABORT("should not happen")
     841             :          END SELECT
     842             : 
     843             :          CALL reshape_mm_compatible(tensor_algn_1, tensor_algn_2, tensor_contr_1, tensor_contr_2, &
     844             :                                     notcontract_1_mod, contract_1_mod, notcontract_2_mod, contract_2_mod, &
     845             :                                     trans_1, trans_2, new_1, new_2, ref_tensor, optimize_dist=optimize_dist, &
     846       36988 :                                     move_data_1=move_data_1, move_data_2=move_data_2, unit_nr=unit_nr_prv)
     847       36988 :          trans_1 = .NOT. trans_1
     848             : 
     849             :          CALL reshape_mm_small(tensor_algn_3, map_1_mod, map_2_mod, tensor_contr_3, trans_3, &
     850       36988 :                                new_3, nodata=nodata_3, unit_nr=unit_nr_prv)
     851             : 
     852       36238 :          SELECT CASE (ref_tensor)
     853             :          CASE (1)
     854       36238 :             tensor_large => tensor_contr_1
     855             :          CASE (2)
     856       36988 :             tensor_large => tensor_contr_2
     857             :          END SELECT
     858       36988 :          tensor_small => tensor_contr_3
     859             : 
     860             :       CASE (3)
     861       42675 :          IF (unit_nr_prv > 0) THEN
     862           2 :             WRITE (unit_nr_prv, '(T2,A)') "large tensors: 2, 3; small tensor: 1"
     863           2 :             WRITE (unit_nr_prv, '(T2,A)') "sorting contraction indices"
     864             :          END IF
     865       42675 :          CALL index_linked_sort(map_1_mod, notcontract_1_mod)
     866       42675 :          CALL index_linked_sort(contract_2_mod, contract_1_mod)
     867       42293 :          SELECT CASE (max_tensor)
     868             :          CASE (2)
     869       42293 :             CALL index_linked_sort(notcontract_2_mod, map_2_mod)
     870             :          CASE (3)
     871         382 :             CALL index_linked_sort(map_2_mod, notcontract_2_mod)
     872             :          CASE DEFAULT
     873       42675 :             CPABORT("should not happen")
     874             :          END SELECT
     875             : 
     876             :          CALL reshape_mm_compatible(tensor_algn_2, tensor_algn_3, tensor_contr_2, tensor_contr_3, &
     877             :                                     contract_2_mod, notcontract_2_mod, map_1_mod, map_2_mod, &
     878             :                                     trans_2, trans_3, new_2, new_3, ref_tensor, nodata2=nodata_3, optimize_dist=optimize_dist, &
     879       42675 :                                     move_data_1=move_data_2, unit_nr=unit_nr_prv)
     880             : 
     881       42675 :          trans_2 = .NOT. trans_2
     882       42675 :          trans_3 = .NOT. trans_3
     883             : 
     884             :          CALL reshape_mm_small(tensor_algn_1, notcontract_1_mod, contract_1_mod, tensor_contr_1, &
     885       42675 :                                trans_1, new_1, move_data=move_data_1, unit_nr=unit_nr_prv)
     886             : 
     887       42293 :          SELECT CASE (ref_tensor)
     888             :          CASE (1)
     889       42293 :             tensor_large => tensor_contr_2
     890             :          CASE (2)
     891       42675 :             tensor_large => tensor_contr_3
     892             :          END SELECT
     893      146464 :          tensor_small => tensor_contr_1
     894             : 
     895             :       END SELECT
     896             : 
     897      103789 :       IF (unit_nr_prv /= 0) CALL dbt_print_contraction_index(tensor_contr_1, indchar1_mod, &
     898             :                                                              tensor_contr_2, indchar2_mod, &
     899       46146 :                                                              tensor_contr_3, indchar3_mod, unit_nr_prv)
     900      103789 :       IF (unit_nr_prv /= 0) THEN
     901       46146 :          IF (new_1) CALL dbt_write_tensor_info(tensor_contr_1, unit_nr_prv, full_info=log_verbose)
     902       46146 :          IF (new_1) CALL dbt_write_tensor_dist(tensor_contr_1, unit_nr_prv)
     903       46146 :          IF (new_2) CALL dbt_write_tensor_info(tensor_contr_2, unit_nr_prv, full_info=log_verbose)
     904       46146 :          IF (new_2) CALL dbt_write_tensor_dist(tensor_contr_2, unit_nr_prv)
     905             :       END IF
     906             : 
     907             :       CALL dbt_tas_multiply(trans_1, trans_2, trans_3, alpha, &
     908             :                             tensor_contr_1%matrix_rep, tensor_contr_2%matrix_rep, &
     909             :                             beta, &
     910             :                             tensor_contr_3%matrix_rep, filter_eps=filter_eps, flop=flop, &
     911             :                             unit_nr=unit_nr_prv, log_verbose=log_verbose, &
     912             :                             split_opt=split_opt, &
     913      103789 :                             move_data_a=move_data_1, move_data_b=move_data_2, retain_sparsity=retain_sparsity)
     914             : 
     915      103789 :       IF (PRESENT(pgrid_opt_1)) THEN
     916           0 :          IF (.NOT. new_1) THEN
     917           0 :             ALLOCATE (pgrid_opt_1)
     918           0 :             pgrid_opt_1 = opt_pgrid(tensor_1, split_opt)
     919             :          END IF
     920             :       END IF
     921             : 
     922      103789 :       IF (PRESENT(pgrid_opt_2)) THEN
     923           0 :          IF (.NOT. new_2) THEN
     924           0 :             ALLOCATE (pgrid_opt_2)
     925           0 :             pgrid_opt_2 = opt_pgrid(tensor_2, split_opt)
     926             :          END IF
     927             :       END IF
     928             : 
     929      103789 :       IF (PRESENT(pgrid_opt_3)) THEN
     930           0 :          IF (.NOT. new_3) THEN
     931           0 :             ALLOCATE (pgrid_opt_3)
     932           0 :             pgrid_opt_3 = opt_pgrid(tensor_3, split_opt)
     933             :          END IF
     934             :       END IF
     935             : 
     936      103789 :       do_batched = tensor_small%matrix_rep%do_batched > 0
     937             : 
     938      103789 :       tensors_remapped = .FALSE.
     939      103789 :       IF (new_1 .OR. new_2 .OR. new_3) tensors_remapped = .TRUE.
     940             : 
     941      103789 :       IF (tensors_remapped .AND. do_batched) THEN
     942             :          CALL cp_warn(__LOCATION__, &
     943           0 :                       "Internal process grid optimization disabled because tensors are not in contraction-compatible format")
     944             :       END IF
     945             : 
     946             :       ! optimize process grid during batched contraction
     947      103789 :       do_change_pgrid(:) = .FALSE.
     948      103789 :       IF ((.NOT. tensors_remapped) .AND. do_batched) THEN
     949      559419 :          ASSOCIATE (storage => tensor_small%contraction_storage)
     950           0 :             CPASSERT(storage%static)
     951       79917 :             split = dbt_tas_info(tensor_large%matrix_rep)
     952             :             do_change_pgrid(:) = &
     953       79917 :                update_contraction_storage(storage, split_opt, split)
     954             : 
     955      798206 :             IF (ANY(do_change_pgrid)) THEN
     956         964 :                mp_comm_opt = dbt_tas_mp_comm(tensor_small%pgrid%mp_comm_2d, split_opt%split_rowcol, NINT(storage%nsplit_avg))
     957             :                CALL dbt_tas_create_split(split_opt_avg, mp_comm_opt, split_opt%split_rowcol, &
     958         964 :                                          NINT(storage%nsplit_avg), own_comm=.TRUE.)
     959        2892 :                pdims_2d_opt = split_opt_avg%mp_comm%num_pe_cart
     960             :             END IF
     961             : 
     962             :          END ASSOCIATE
     963             : 
     964       79917 :          IF (do_change_pgrid(1) .AND. .NOT. do_change_pgrid(2)) THEN
     965             :             ! check if new grid has better subgrid, if not there is no need to change process grid
     966        2892 :             pdims_sub_opt = split_opt_avg%mp_comm_group%num_pe_cart
     967        2892 :             pdims_sub = split%mp_comm_group%num_pe_cart
     968             : 
     969        4820 :             pdim_ratio = MAXVAL(REAL(pdims_sub, dp))/MINVAL(pdims_sub)
     970        4820 :             pdim_ratio_opt = MAXVAL(REAL(pdims_sub_opt, dp))/MINVAL(pdims_sub_opt)
     971         964 :             IF (pdim_ratio/pdim_ratio_opt <= default_pdims_accept_ratio**2) THEN
     972           0 :                do_change_pgrid(1) = .FALSE.
     973           0 :                CALL dbt_tas_release_info(split_opt_avg)
     974             :             END IF
     975             :          END IF
     976             :       END IF
     977             : 
     978      103789 :       IF (unit_nr_prv /= 0) THEN
     979       46146 :          do_write_3 = .TRUE.
     980       46146 :          IF (tensor_contr_3%matrix_rep%do_batched > 0) THEN
     981       20424 :             IF (tensor_contr_3%matrix_rep%mm_storage%batched_out) do_write_3 = .FALSE.
     982             :          END IF
     983             :          IF (do_write_3) THEN
     984       25760 :             CALL dbt_write_tensor_info(tensor_contr_3, unit_nr_prv, full_info=log_verbose)
     985       25760 :             CALL dbt_write_tensor_dist(tensor_contr_3, unit_nr_prv)
     986             :          END IF
     987             :       END IF
     988             : 
     989      103789 :       IF (new_3) THEN
     990             :          ! need redistribute if we created new tensor for tensor 3
     991         188 :          CALL dbt_scale(tensor_algn_3, beta)
     992         188 :          CALL dbt_copy_expert(tensor_contr_3, tensor_algn_3, summation=.TRUE., move_data=.TRUE.)
     993         188 :          IF (PRESENT(filter_eps)) CALL dbt_filter(tensor_algn_3, filter_eps)
     994             :          ! tensor_3 automatically has correct data because tensor_algn_3 contains a matrix
     995             :          ! pointer to data of tensor_3
     996             :       END IF
     997             : 
     998             :       ! transfer contraction storage
     999      103789 :       CALL dbt_copy_contraction_storage(tensor_contr_1, tensor_1)
    1000      103789 :       CALL dbt_copy_contraction_storage(tensor_contr_2, tensor_2)
    1001      103789 :       CALL dbt_copy_contraction_storage(tensor_contr_3, tensor_3)
    1002             : 
    1003      103789 :       IF (unit_nr_prv /= 0) THEN
    1004       46146 :          IF (new_3 .AND. do_write_3) CALL dbt_write_tensor_info(tensor_3, unit_nr_prv, full_info=log_verbose)
    1005       46146 :          IF (new_3 .AND. do_write_3) CALL dbt_write_tensor_dist(tensor_3, unit_nr_prv)
    1006             :       END IF
    1007             : 
    1008      103789 :       CALL dbt_destroy(tensor_algn_1)
    1009      103789 :       CALL dbt_destroy(tensor_algn_2)
    1010      103789 :       CALL dbt_destroy(tensor_algn_3)
    1011             : 
    1012      103789 :       IF (do_crop_1) THEN
    1013       65306 :          CALL dbt_destroy(tensor_crop_1)
    1014       65306 :          DEALLOCATE (tensor_crop_1)
    1015             :       END IF
    1016             : 
    1017      103789 :       IF (do_crop_2) THEN
    1018       63148 :          CALL dbt_destroy(tensor_crop_2)
    1019       63148 :          DEALLOCATE (tensor_crop_2)
    1020             :       END IF
    1021             : 
    1022      103789 :       IF (new_1) THEN
    1023         202 :          CALL dbt_destroy(tensor_contr_1)
    1024         202 :          DEALLOCATE (tensor_contr_1)
    1025             :       END IF
    1026      103789 :       IF (new_2) THEN
    1027          80 :          CALL dbt_destroy(tensor_contr_2)
    1028          80 :          DEALLOCATE (tensor_contr_2)
    1029             :       END IF
    1030      103789 :       IF (new_3) THEN
    1031         188 :          CALL dbt_destroy(tensor_contr_3)
    1032         188 :          DEALLOCATE (tensor_contr_3)
    1033             :       END IF
    1034             : 
    1035      103789 :       IF (PRESENT(move_data)) THEN
    1036       29876 :          IF (move_data) THEN
    1037       26052 :             CALL dbt_clear(tensor_1)
    1038       26052 :             CALL dbt_clear(tensor_2)
    1039             :          END IF
    1040             :       END IF
    1041             : 
    1042      103789 :       IF (unit_nr_prv > 0) THEN
    1043          10 :          WRITE (unit_nr_prv, '(A)') repeat("-", 80)
    1044          10 :          WRITE (unit_nr_prv, '(A)') "TENSOR CONTRACTION DONE"
    1045          10 :          WRITE (unit_nr_prv, '(A)') repeat("-", 80)
    1046             :       END IF
    1047             : 
    1048      309439 :       IF (ANY(do_change_pgrid)) THEN
    1049         964 :          pgrid_changed_any = .FALSE.
    1050         264 :          SELECT CASE (max_mm_dim)
    1051             :          CASE (1)
    1052         264 :             IF (ALLOCATED(tensor_1%contraction_storage) .AND. ALLOCATED(tensor_3%contraction_storage)) THEN
    1053             :                CALL dbt_change_pgrid_2d(tensor_1, tensor_1%pgrid%mp_comm_2d, pdims=pdims_2d_opt, &
    1054             :                                         nsplit=split_opt_avg%ngroup, dimsplit=split_opt_avg%split_rowcol, &
    1055             :                                         pgrid_changed=pgrid_changed, &
    1056           0 :                                         unit_nr=unit_nr_prv)
    1057           0 :                IF (pgrid_changed) pgrid_changed_any = .TRUE.
    1058             :                CALL dbt_change_pgrid_2d(tensor_3, tensor_3%pgrid%mp_comm_2d, pdims=pdims_2d_opt, &
    1059             :                                         nsplit=split_opt_avg%ngroup, dimsplit=split_opt_avg%split_rowcol, &
    1060             :                                         pgrid_changed=pgrid_changed, &
    1061           0 :                                         unit_nr=unit_nr_prv)
    1062           0 :                IF (pgrid_changed) pgrid_changed_any = .TRUE.
    1063             :             END IF
    1064           0 :             IF (pgrid_changed_any) THEN
    1065           0 :                IF (tensor_2%matrix_rep%do_batched == 3) THEN
    1066             :                   ! set flag that process grid has been optimized to make sure that no grid optimizations are done
    1067             :                   ! in TAS multiply algorithm
    1068           0 :                   CALL dbt_tas_batched_mm_complete(tensor_2%matrix_rep)
    1069             :                END IF
    1070             :             END IF
    1071             :          CASE (2)
    1072         172 :             IF (ALLOCATED(tensor_1%contraction_storage) .AND. ALLOCATED(tensor_2%contraction_storage)) THEN
    1073             :                CALL dbt_change_pgrid_2d(tensor_1, tensor_1%pgrid%mp_comm_2d, pdims=pdims_2d_opt, &
    1074             :                                         nsplit=split_opt_avg%ngroup, dimsplit=split_opt_avg%split_rowcol, &
    1075             :                                         pgrid_changed=pgrid_changed, &
    1076         172 :                                         unit_nr=unit_nr_prv)
    1077         172 :                IF (pgrid_changed) pgrid_changed_any = .TRUE.
    1078             :                CALL dbt_change_pgrid_2d(tensor_2, tensor_2%pgrid%mp_comm_2d, pdims=pdims_2d_opt, &
    1079             :                                         nsplit=split_opt_avg%ngroup, dimsplit=split_opt_avg%split_rowcol, &
    1080             :                                         pgrid_changed=pgrid_changed, &
    1081         172 :                                         unit_nr=unit_nr_prv)
    1082         172 :                IF (pgrid_changed) pgrid_changed_any = .TRUE.
    1083             :             END IF
    1084           8 :             IF (pgrid_changed_any) THEN
    1085         172 :                IF (tensor_3%matrix_rep%do_batched == 3) THEN
    1086         160 :                   CALL dbt_tas_batched_mm_complete(tensor_3%matrix_rep)
    1087             :                END IF
    1088             :             END IF
    1089             :          CASE (3)
    1090         528 :             IF (ALLOCATED(tensor_2%contraction_storage) .AND. ALLOCATED(tensor_3%contraction_storage)) THEN
    1091             :                CALL dbt_change_pgrid_2d(tensor_2, tensor_2%pgrid%mp_comm_2d, pdims=pdims_2d_opt, &
    1092             :                                         nsplit=split_opt_avg%ngroup, dimsplit=split_opt_avg%split_rowcol, &
    1093             :                                         pgrid_changed=pgrid_changed, &
    1094         214 :                                         unit_nr=unit_nr_prv)
    1095         214 :                IF (pgrid_changed) pgrid_changed_any = .TRUE.
    1096             :                CALL dbt_change_pgrid_2d(tensor_3, tensor_3%pgrid%mp_comm_2d, pdims=pdims_2d_opt, &
    1097             :                                         nsplit=split_opt_avg%ngroup, dimsplit=split_opt_avg%split_rowcol, &
    1098             :                                         pgrid_changed=pgrid_changed, &
    1099         214 :                                         unit_nr=unit_nr_prv)
    1100         214 :                IF (pgrid_changed) pgrid_changed_any = .TRUE.
    1101             :             END IF
    1102         964 :             IF (pgrid_changed_any) THEN
    1103         214 :                IF (tensor_1%matrix_rep%do_batched == 3) THEN
    1104         214 :                   CALL dbt_tas_batched_mm_complete(tensor_1%matrix_rep)
    1105             :                END IF
    1106             :             END IF
    1107             :          END SELECT
    1108         964 :          CALL dbt_tas_release_info(split_opt_avg)
    1109             :       END IF
    1110             : 
    1111      103789 :       IF ((.NOT. tensors_remapped) .AND. do_batched) THEN
    1112             :          ! freeze TAS process grids if tensor grids were optimized
    1113       79917 :          CALL dbt_tas_set_batched_state(tensor_1%matrix_rep, opt_grid=.TRUE.)
    1114       79917 :          CALL dbt_tas_set_batched_state(tensor_2%matrix_rep, opt_grid=.TRUE.)
    1115       79917 :          CALL dbt_tas_set_batched_state(tensor_3%matrix_rep, opt_grid=.TRUE.)
    1116             :       END IF
    1117             : 
    1118      103789 :       CALL dbt_tas_release_info(split_opt)
    1119             : 
    1120      103789 :       CALL timestop(handle)
    1121             : 
    1122      222472 :    END SUBROUTINE
    1123             : 
    1124             : ! **************************************************************************************************
    1125             : !> \brief align tensor index with data
    1126             : !> \author Patrick Seewald
    1127             : ! **************************************************************************************************
    1128     2490936 :    SUBROUTINE align_tensor(tensor_in, contract_in, notcontract_in, &
    1129      311367 :                            tensor_out, contract_out, notcontract_out, indp_in, indp_out)
    1130             :       TYPE(dbt_type), INTENT(INOUT)               :: tensor_in
    1131             :       INTEGER, DIMENSION(:), INTENT(IN)            :: contract_in, notcontract_in
    1132             :       TYPE(dbt_type), INTENT(OUT)              :: tensor_out
    1133             :       INTEGER, DIMENSION(SIZE(contract_in)), &
    1134             :          INTENT(OUT)                               :: contract_out
    1135             :       INTEGER, DIMENSION(SIZE(notcontract_in)), &
    1136             :          INTENT(OUT)                               :: notcontract_out
    1137             :       CHARACTER(LEN=1), DIMENSION(ndims_tensor(tensor_in)), INTENT(IN) :: indp_in
    1138             :       CHARACTER(LEN=1), DIMENSION(ndims_tensor(tensor_in)), INTENT(OUT) :: indp_out
    1139      311367 :       INTEGER, DIMENSION(ndims_tensor(tensor_in)) :: align
    1140             : 
    1141      311367 :       CALL dbt_align_index(tensor_in, tensor_out, order=align)
    1142      712144 :       contract_out = align(contract_in)
    1143      725366 :       notcontract_out = align(notcontract_in)
    1144     1126143 :       indp_out(align) = indp_in
    1145             : 
    1146      311367 :    END SUBROUTINE
    1147             : 
    1148             : ! **************************************************************************************************
    1149             : !> \brief Prepare tensor for contraction: redistribute to a 2d format which can be contracted by
    1150             : !>        matrix multiplication. This routine reshapes the two largest of the three tensors.
    1151             : !>        Redistribution is avoided if tensors already in a consistent layout.
    1152             : !> \param ind1_free indices of tensor 1 that are "free" (not linked to any index of tensor 2)
    1153             : !> \param ind1_linked indices of tensor 1 that are linked to indices of tensor 2
    1154             : !>                    1:1 correspondence with ind1_linked
    1155             : !> \param trans1 transpose flag of matrix rep. tensor 1
    1156             : !> \param trans2 transpose flag of matrix rep. tensor 2
    1157             : !> \param new1 whether a new tensor 1 was created
    1158             : !> \param new2 whether a new tensor 2 was created
    1159             : !> \param nodata1 don't copy data of tensor 1
    1160             : !> \param nodata2 don't copy data of tensor 2
    1161             : !> \param move_data_1 memory optimization: transfer data s.t. tensor1 may be empty on return
    1162             : !> \param move_data_2 memory optimization: transfer data s.t. tensor2 may be empty on return
    1163             : !> \param optimize_dist experimental: optimize distribution
    1164             : !> \param unit_nr output unit
    1165             : !> \author Patrick Seewald
    1166             : ! **************************************************************************************************
    1167      103789 :    SUBROUTINE reshape_mm_compatible(tensor1, tensor2, tensor1_out, tensor2_out, ind1_free, ind1_linked, &
    1168      103789 :                                     ind2_free, ind2_linked, trans1, trans2, new1, new2, ref_tensor, &
    1169             :                                     nodata1, nodata2, move_data_1, &
    1170             :                                     move_data_2, optimize_dist, unit_nr)
    1171             :       TYPE(dbt_type), TARGET, INTENT(INOUT)   :: tensor1
    1172             :       TYPE(dbt_type), TARGET, INTENT(INOUT)   :: tensor2
    1173             :       TYPE(dbt_type), POINTER, INTENT(OUT)    :: tensor1_out, tensor2_out
    1174             :       INTEGER, DIMENSION(:), INTENT(IN)           :: ind1_free, ind2_free
    1175             :       INTEGER, DIMENSION(:), INTENT(IN)           :: ind1_linked, ind2_linked
    1176             :       LOGICAL, INTENT(OUT)                        :: trans1, trans2
    1177             :       LOGICAL, INTENT(OUT)                        :: new1, new2
    1178             :       INTEGER, INTENT(OUT) :: ref_tensor
    1179             :       LOGICAL, INTENT(IN), OPTIONAL               :: nodata1, nodata2
    1180             :       LOGICAL, INTENT(INOUT), OPTIONAL            :: move_data_1, move_data_2
    1181             :       LOGICAL, INTENT(IN), OPTIONAL               :: optimize_dist
    1182             :       INTEGER, INTENT(IN), OPTIONAL               :: unit_nr
    1183             :       INTEGER                                     :: compat1, compat1_old, compat2, compat2_old, &
    1184             :                                                      unit_nr_prv
    1185      103789 :       TYPE(mp_cart_type)                          :: comm_2d
    1186      103789 :       TYPE(array_list)                            :: dist_list
    1187      103789 :       INTEGER, DIMENSION(:), ALLOCATABLE          :: mp_dims
    1188      726523 :       TYPE(dbt_distribution_type)             :: dist_in
    1189             :       INTEGER(KIND=int_8)                         :: nblkrows, nblkcols
    1190             :       LOGICAL                                     :: optimize_dist_prv
    1191      207578 :       INTEGER, DIMENSION(ndims_tensor(tensor1)) :: dims1
    1192      207578 :       INTEGER, DIMENSION(ndims_tensor(tensor2)) :: dims2
    1193             : 
    1194      103789 :       NULLIFY (tensor1_out, tensor2_out)
    1195             : 
    1196      103789 :       unit_nr_prv = prep_output_unit(unit_nr)
    1197             : 
    1198      103789 :       CALL blk_dims_tensor(tensor1, dims1)
    1199      103789 :       CALL blk_dims_tensor(tensor2, dims2)
    1200             : 
    1201      709535 :       IF (PRODUCT(int(dims1, int_8)) .GE. PRODUCT(int(dims2, int_8))) THEN
    1202      102229 :          ref_tensor = 1
    1203             :       ELSE
    1204        1560 :          ref_tensor = 2
    1205             :       END IF
    1206             : 
    1207      103789 :       IF (PRESENT(optimize_dist)) THEN
    1208         182 :          optimize_dist_prv = optimize_dist
    1209             :       ELSE
    1210             :          optimize_dist_prv = .FALSE.
    1211             :       END IF
    1212             : 
    1213      103789 :       compat1 = compat_map(tensor1%nd_index, ind1_linked)
    1214      103789 :       compat2 = compat_map(tensor2%nd_index, ind2_linked)
    1215      103789 :       compat1_old = compat1
    1216      103789 :       compat2_old = compat2
    1217             : 
    1218      103789 :       IF (unit_nr_prv > 0) THEN
    1219          10 :          WRITE (unit_nr_prv, '(T2,A,1X,A,A,1X)', advance='no') "compatibility of", TRIM(tensor1%name), ":"
    1220           6 :          SELECT CASE (compat1)
    1221             :          CASE (0)
    1222           6 :             WRITE (unit_nr_prv, '(A)') "Not compatible"
    1223             :          CASE (1)
    1224           3 :             WRITE (unit_nr_prv, '(A)') "Normal"
    1225             :          CASE (2)
    1226          10 :             WRITE (unit_nr_prv, '(A)') "Transposed"
    1227             :          END SELECT
    1228          10 :          WRITE (unit_nr_prv, '(T2,A,1X,A,A,1X)', advance='no') "compatibility of", TRIM(tensor2%name), ":"
    1229           5 :          SELECT CASE (compat2)
    1230             :          CASE (0)
    1231           5 :             WRITE (unit_nr_prv, '(A)') "Not compatible"
    1232             :          CASE (1)
    1233           4 :             WRITE (unit_nr_prv, '(A)') "Normal"
    1234             :          CASE (2)
    1235          10 :             WRITE (unit_nr_prv, '(A)') "Transposed"
    1236             :          END SELECT
    1237             :       END IF
    1238             : 
    1239      103789 :       new1 = .FALSE.
    1240      103789 :       new2 = .FALSE.
    1241             : 
    1242      103789 :       IF (compat1 == 0 .OR. optimize_dist_prv) THEN
    1243         194 :          new1 = .TRUE.
    1244             :       END IF
    1245             : 
    1246      103789 :       IF (compat2 == 0 .OR. optimize_dist_prv) THEN
    1247         254 :          new2 = .TRUE.
    1248             :       END IF
    1249             : 
    1250      103789 :       IF (ref_tensor == 1) THEN ! tensor 1 is reference and tensor 2 is reshaped compatible with tensor 1
    1251      102229 :          IF (compat1 == 0 .OR. optimize_dist_prv) THEN ! tensor 1 is not contraction compatible --> reshape
    1252         114 :             IF (unit_nr_prv > 0) WRITE (unit_nr_prv, '(T2,A,1X,A)') "Redistribution of", TRIM(tensor1%name)
    1253         342 :             nblkrows = PRODUCT(INT(dims1(ind1_linked), KIND=int_8))
    1254         230 :             nblkcols = PRODUCT(INT(dims1(ind1_free), KIND=int_8))
    1255         114 :             comm_2d = dbt_tas_mp_comm(tensor1%pgrid%mp_comm_2d, nblkrows, nblkcols)
    1256         798 :             ALLOCATE (tensor1_out)
    1257             :             CALL dbt_remap(tensor1, ind1_linked, ind1_free, tensor1_out, comm_2d=comm_2d, &
    1258         114 :                            nodata=nodata1, move_data=move_data_1)
    1259         114 :             CALL comm_2d%free()
    1260         114 :             compat1 = 1
    1261             :          ELSE
    1262      102115 :             IF (unit_nr_prv > 0) WRITE (unit_nr_prv, '(T2,A,1X,A)') "No redistribution of", TRIM(tensor1%name)
    1263      102115 :             tensor1_out => tensor1
    1264             :          END IF
    1265      102229 :          IF (compat2 == 0 .OR. optimize_dist_prv) THEN ! tensor 2 is not contraction compatible --> reshape
    1266         174 :             IF (unit_nr_prv > 0) WRITE (unit_nr_prv, '(T2,A,1X,A,1X,A,1X,A)') "Redistribution of", &
    1267           8 :                TRIM(tensor2%name), "compatible with", TRIM(tensor1%name)
    1268         170 :             dist_in = dbt_distribution(tensor1_out)
    1269         170 :             dist_list = array_sublist(dist_in%nd_dist, ind1_linked)
    1270         170 :             IF (compat1 == 1) THEN ! linked index is first 2d dimension
    1271             :                ! get distribution of linked index, tensor 2 must adopt this distribution
    1272             :                ! get grid dimensions of linked index
    1273         480 :                ALLOCATE (mp_dims(ndims_mapping_row(dist_in%pgrid%nd_index_grid)))
    1274         160 :                CALL dbt_get_mapping_info(dist_in%pgrid%nd_index_grid, dims1_2d=mp_dims)
    1275        1120 :                ALLOCATE (tensor2_out)
    1276             :                CALL dbt_remap(tensor2, ind2_linked, ind2_free, tensor2_out, comm_2d=dist_in%pgrid%mp_comm_2d, &
    1277         160 :                               dist1=dist_list, mp_dims_1=mp_dims, nodata=nodata2, move_data=move_data_2)
    1278          10 :             ELSEIF (compat1 == 2) THEN ! linked index is second 2d dimension
    1279             :                ! get distribution of linked index, tensor 2 must adopt this distribution
    1280             :                ! get grid dimensions of linked index
    1281          30 :                ALLOCATE (mp_dims(ndims_mapping_column(dist_in%pgrid%nd_index_grid)))
    1282          10 :                CALL dbt_get_mapping_info(dist_in%pgrid%nd_index_grid, dims2_2d=mp_dims)
    1283          70 :                ALLOCATE (tensor2_out)
    1284             :                CALL dbt_remap(tensor2, ind2_free, ind2_linked, tensor2_out, comm_2d=dist_in%pgrid%mp_comm_2d, &
    1285          10 :                               dist2=dist_list, mp_dims_2=mp_dims, nodata=nodata2, move_data=move_data_2)
    1286             :             ELSE
    1287           0 :                CPABORT("should not happen")
    1288             :             END IF
    1289             :             compat2 = compat1
    1290             :          ELSE
    1291      102059 :             IF (unit_nr_prv > 0) WRITE (unit_nr_prv, '(T2,A,1X,A)') "No redistribution of", TRIM(tensor2%name)
    1292      102059 :             tensor2_out => tensor2
    1293             :          END IF
    1294             :       ELSE ! tensor 2 is reference and tensor 1 is reshaped compatible with tensor 2
    1295        1560 :          IF (compat2 == 0 .OR. optimize_dist_prv) THEN ! tensor 2 is not contraction compatible --> reshape
    1296          84 :             IF (unit_nr_prv > 0) WRITE (unit_nr_prv, '(T2,A,1X,A)') "Redistribution of", TRIM(tensor2%name)
    1297         178 :             nblkrows = PRODUCT(INT(dims2(ind2_linked), KIND=int_8))
    1298         170 :             nblkcols = PRODUCT(INT(dims2(ind2_free), KIND=int_8))
    1299          84 :             comm_2d = dbt_tas_mp_comm(tensor2%pgrid%mp_comm_2d, nblkrows, nblkcols)
    1300         588 :             ALLOCATE (tensor2_out)
    1301          84 :             CALL dbt_remap(tensor2, ind2_linked, ind2_free, tensor2_out, nodata=nodata2, move_data=move_data_2)
    1302          84 :             CALL comm_2d%free()
    1303          84 :             compat2 = 1
    1304             :          ELSE
    1305        1476 :             IF (unit_nr_prv > 0) WRITE (unit_nr_prv, '(T2,A,1X,A)') "No redistribution of", TRIM(tensor2%name)
    1306        1476 :             tensor2_out => tensor2
    1307             :          END IF
    1308        1560 :          IF (compat1 == 0 .OR. optimize_dist_prv) THEN ! tensor 1 is not contraction compatible --> reshape
    1309          83 :             IF (unit_nr_prv > 0) WRITE (unit_nr_prv, '(T2,A,1X,A,1X,A,1X,A)') "Redistribution of", TRIM(tensor1%name), &
    1310           6 :                "compatible with", TRIM(tensor2%name)
    1311          80 :             dist_in = dbt_distribution(tensor2_out)
    1312          80 :             dist_list = array_sublist(dist_in%nd_dist, ind2_linked)
    1313          80 :             IF (compat2 == 1) THEN
    1314         234 :                ALLOCATE (mp_dims(ndims_mapping_row(dist_in%pgrid%nd_index_grid)))
    1315          78 :                CALL dbt_get_mapping_info(dist_in%pgrid%nd_index_grid, dims1_2d=mp_dims)
    1316         546 :                ALLOCATE (tensor1_out)
    1317             :                CALL dbt_remap(tensor1, ind1_linked, ind1_free, tensor1_out, comm_2d=dist_in%pgrid%mp_comm_2d, &
    1318          78 :                               dist1=dist_list, mp_dims_1=mp_dims, nodata=nodata1, move_data=move_data_1)
    1319           2 :             ELSEIF (compat2 == 2) THEN
    1320           6 :                ALLOCATE (mp_dims(ndims_mapping_column(dist_in%pgrid%nd_index_grid)))
    1321           2 :                CALL dbt_get_mapping_info(dist_in%pgrid%nd_index_grid, dims2_2d=mp_dims)
    1322          14 :                ALLOCATE (tensor1_out)
    1323             :                CALL dbt_remap(tensor1, ind1_free, ind1_linked, tensor1_out, comm_2d=dist_in%pgrid%mp_comm_2d, &
    1324           2 :                               dist2=dist_list, mp_dims_2=mp_dims, nodata=nodata1, move_data=move_data_1)
    1325             :             ELSE
    1326           0 :                CPABORT("should not happen")
    1327             :             END IF
    1328             :             compat1 = compat2
    1329             :          ELSE
    1330        1480 :             IF (unit_nr_prv > 0) WRITE (unit_nr_prv, '(T2,A,1X,A)') "No redistribution of", TRIM(tensor1%name)
    1331        1480 :             tensor1_out => tensor1
    1332             :          END IF
    1333             :       END IF
    1334             : 
    1335       66965 :       SELECT CASE (compat1)
    1336             :       CASE (1)
    1337       66965 :          trans1 = .FALSE.
    1338             :       CASE (2)
    1339       36824 :          trans1 = .TRUE.
    1340             :       CASE DEFAULT
    1341      103789 :          CPABORT("should not happen")
    1342             :       END SELECT
    1343             : 
    1344       69776 :       SELECT CASE (compat2)
    1345             :       CASE (1)
    1346       69776 :          trans2 = .FALSE.
    1347             :       CASE (2)
    1348       34013 :          trans2 = .TRUE.
    1349             :       CASE DEFAULT
    1350      103789 :          CPABORT("should not happen")
    1351             :       END SELECT
    1352             : 
    1353      103789 :       IF (unit_nr_prv > 0) THEN
    1354          10 :          IF (compat1 .NE. compat1_old) THEN
    1355           6 :             WRITE (unit_nr_prv, '(T2,A,1X,A,A,1X)', advance='no') "compatibility of", TRIM(tensor1_out%name), ":"
    1356           0 :             SELECT CASE (compat1)
    1357             :             CASE (0)
    1358           0 :                WRITE (unit_nr_prv, '(A)') "Not compatible"
    1359             :             CASE (1)
    1360           5 :                WRITE (unit_nr_prv, '(A)') "Normal"
    1361             :             CASE (2)
    1362           6 :                WRITE (unit_nr_prv, '(A)') "Transposed"
    1363             :             END SELECT
    1364             :          END IF
    1365          10 :          IF (compat2 .NE. compat2_old) THEN
    1366           5 :             WRITE (unit_nr_prv, '(T2,A,1X,A,A,1X)', advance='no') "compatibility of", TRIM(tensor2_out%name), ":"
    1367           0 :             SELECT CASE (compat2)
    1368             :             CASE (0)
    1369           0 :                WRITE (unit_nr_prv, '(A)') "Not compatible"
    1370             :             CASE (1)
    1371           4 :                WRITE (unit_nr_prv, '(A)') "Normal"
    1372             :             CASE (2)
    1373           5 :                WRITE (unit_nr_prv, '(A)') "Transposed"
    1374             :             END SELECT
    1375             :          END IF
    1376             :       END IF
    1377             : 
    1378      103789 :       IF (new1 .AND. PRESENT(move_data_1)) move_data_1 = .TRUE.
    1379      103789 :       IF (new2 .AND. PRESENT(move_data_2)) move_data_2 = .TRUE.
    1380             : 
    1381      103789 :    END SUBROUTINE
    1382             : 
    1383             : ! **************************************************************************************************
    1384             : !> \brief Prepare tensor for contraction: redistribute to a 2d format which can be contracted by
    1385             : !>        matrix multiplication. This routine reshapes the smallest of the three tensors.
    1386             : !> \param ind1 index that should be mapped to first matrix dimension
    1387             : !> \param ind2 index that should be mapped to second matrix dimension
    1388             : !> \param trans transpose flag of matrix rep.
    1389             : !> \param new whether a new tensor was created for tensor_out
    1390             : !> \param nodata don't copy tensor data
    1391             : !> \param move_data memory optimization: transfer data s.t. tensor_in may be empty on return
    1392             : !> \param unit_nr output unit
    1393             : !> \author Patrick Seewald
    1394             : ! **************************************************************************************************
    1395      103789 :    SUBROUTINE reshape_mm_small(tensor_in, ind1, ind2, tensor_out, trans, new, nodata, move_data, unit_nr)
    1396             :       TYPE(dbt_type), TARGET, INTENT(INOUT)   :: tensor_in
    1397             :       INTEGER, DIMENSION(:), INTENT(IN)           :: ind1, ind2
    1398             :       TYPE(dbt_type), POINTER, INTENT(OUT)    :: tensor_out
    1399             :       LOGICAL, INTENT(OUT)                        :: trans
    1400             :       LOGICAL, INTENT(OUT)                        :: new
    1401             :       LOGICAL, INTENT(IN), OPTIONAL               :: nodata, move_data
    1402             :       INTEGER, INTENT(IN), OPTIONAL               :: unit_nr
    1403             :       INTEGER                                     :: compat1, compat2, compat1_old, compat2_old, unit_nr_prv
    1404             :       LOGICAL                                     :: nodata_prv
    1405             : 
    1406      103789 :       NULLIFY (tensor_out)
    1407             :       IF (PRESENT(nodata)) THEN
    1408      103789 :          nodata_prv = nodata
    1409             :       ELSE
    1410             :          nodata_prv = .FALSE.
    1411             :       END IF
    1412             : 
    1413      103789 :       unit_nr_prv = prep_output_unit(unit_nr)
    1414             : 
    1415      103789 :       new = .FALSE.
    1416      103789 :       compat1 = compat_map(tensor_in%nd_index, ind1)
    1417      103789 :       compat2 = compat_map(tensor_in%nd_index, ind2)
    1418      103789 :       compat1_old = compat1; compat2_old = compat2
    1419      103789 :       IF (unit_nr_prv > 0) THEN
    1420          10 :          WRITE (unit_nr_prv, '(T2,A,1X,A,A,1X)', advance='no') "compatibility of", TRIM(tensor_in%name), ":"
    1421          10 :          IF (compat1 == 1 .AND. compat2 == 2) THEN
    1422           4 :             WRITE (unit_nr_prv, '(A)') "Normal"
    1423           6 :          ELSEIF (compat1 == 2 .AND. compat2 == 1) THEN
    1424           2 :             WRITE (unit_nr_prv, '(A)') "Transposed"
    1425             :          ELSE
    1426           4 :             WRITE (unit_nr_prv, '(A)') "Not compatible"
    1427             :          END IF
    1428             :       END IF
    1429      103789 :       IF (compat1 == 0 .or. compat2 == 0) THEN ! index mapping not compatible with contract index
    1430             : 
    1431          22 :          IF (unit_nr_prv > 0) WRITE (unit_nr_prv, '(T2,A,1X,A)') "Redistribution of", TRIM(tensor_in%name)
    1432             : 
    1433         154 :          ALLOCATE (tensor_out)
    1434          22 :          CALL dbt_remap(tensor_in, ind1, ind2, tensor_out, nodata=nodata, move_data=move_data)
    1435          22 :          CALL dbt_copy_contraction_storage(tensor_in, tensor_out)
    1436          22 :          compat1 = 1
    1437          22 :          compat2 = 2
    1438          22 :          new = .TRUE.
    1439             :       ELSE
    1440      103767 :          IF (unit_nr_prv > 0) WRITE (unit_nr_prv, '(T2,A,1X,A)') "No redistribution of", TRIM(tensor_in%name)
    1441      103767 :          tensor_out => tensor_in
    1442             :       END IF
    1443             : 
    1444      103789 :       IF (compat1 == 1 .AND. compat2 == 2) THEN
    1445       82117 :          trans = .FALSE.
    1446       21672 :       ELSEIF (compat1 == 2 .AND. compat2 == 1) THEN
    1447       21672 :          trans = .TRUE.
    1448             :       ELSE
    1449           0 :          CPABORT("this should not happen")
    1450             :       END IF
    1451             : 
    1452      103789 :       IF (unit_nr_prv > 0) THEN
    1453          10 :          IF (compat1_old .NE. compat1 .OR. compat2_old .NE. compat2) THEN
    1454           4 :             WRITE (unit_nr_prv, '(T2,A,1X,A,A,1X)', advance='no') "compatibility of", TRIM(tensor_out%name), ":"
    1455           4 :             IF (compat1 == 1 .AND. compat2 == 2) THEN
    1456           4 :                WRITE (unit_nr_prv, '(A)') "Normal"
    1457           0 :             ELSEIF (compat1 == 2 .AND. compat2 == 1) THEN
    1458           0 :                WRITE (unit_nr_prv, '(A)') "Transposed"
    1459             :             ELSE
    1460           0 :                WRITE (unit_nr_prv, '(A)') "Not compatible"
    1461             :             END IF
    1462             :          END IF
    1463             :       END IF
    1464             : 
    1465      103789 :    END SUBROUTINE
    1466             : 
    1467             : ! **************************************************************************************************
    1468             : !> \brief update contraction storage that keeps track of process grids during a batched contraction
    1469             : !>        and decide if tensor process grid needs to be optimized
    1470             : !> \param split_opt optimized TAS process grid
    1471             : !> \param split current TAS process grid
    1472             : !> \author Patrick Seewald
    1473             : ! **************************************************************************************************
    1474       79917 :    FUNCTION update_contraction_storage(storage, split_opt, split) RESULT(do_change_pgrid)
    1475             :       TYPE(dbt_contraction_storage), INTENT(INOUT) :: storage
    1476             :       TYPE(dbt_tas_split_info), INTENT(IN)           :: split_opt
    1477             :       TYPE(dbt_tas_split_info), INTENT(IN)           :: split
    1478             :       INTEGER, DIMENSION(2) :: pdims, pdims_sub
    1479             :       LOGICAL, DIMENSION(2) :: do_change_pgrid
    1480             :       REAL(kind=dp) :: change_criterion, pdims_ratio
    1481             :       INTEGER :: nsplit_opt, nsplit
    1482             : 
    1483       79917 :       CPASSERT(ALLOCATED(split_opt%ngroup_opt))
    1484       79917 :       nsplit_opt = split_opt%ngroup_opt
    1485       79917 :       nsplit = split%ngroup
    1486             : 
    1487      239751 :       pdims = split%mp_comm%num_pe_cart
    1488             : 
    1489       79917 :       storage%ibatch = storage%ibatch + 1
    1490             : 
    1491             :       storage%nsplit_avg = (storage%nsplit_avg*REAL(storage%ibatch - 1, dp) + REAL(nsplit_opt, dp)) &
    1492       79917 :                            /REAL(storage%ibatch, dp)
    1493             : 
    1494       79917 :       SELECT CASE (split_opt%split_rowcol)
    1495             :       CASE (rowsplit)
    1496       79917 :          pdims_ratio = REAL(pdims(1), dp)/pdims(2)
    1497             :       CASE (colsplit)
    1498       79917 :          pdims_ratio = REAL(pdims(2), dp)/pdims(1)
    1499             :       END SELECT
    1500             : 
    1501      239751 :       do_change_pgrid(:) = .FALSE.
    1502             : 
    1503             :       ! check for process grid dimensions
    1504      239751 :       pdims_sub = split%mp_comm_group%num_pe_cart
    1505      479502 :       change_criterion = MAXVAL(REAL(pdims_sub, dp))/MINVAL(pdims_sub)
    1506       79917 :       IF (change_criterion > default_pdims_accept_ratio**2) do_change_pgrid(1) = .TRUE.
    1507             : 
    1508             :       ! check for split factor
    1509       79917 :       change_criterion = MAX(REAL(nsplit, dp)/storage%nsplit_avg, REAL(storage%nsplit_avg, dp)/nsplit)
    1510       79917 :       IF (change_criterion > default_nsplit_accept_ratio) do_change_pgrid(2) = .TRUE.
    1511             : 
    1512       79917 :    END FUNCTION
    1513             : 
    1514             : ! **************************************************************************************************
    1515             : !> \brief Check if 2d index is compatible with tensor index
    1516             : !> \author Patrick Seewald
    1517             : ! **************************************************************************************************
    1518      415156 :    FUNCTION compat_map(nd_index, compat_ind)
    1519             :       TYPE(nd_to_2d_mapping), INTENT(IN) :: nd_index
    1520             :       INTEGER, DIMENSION(:), INTENT(IN)  :: compat_ind
    1521      830312 :       INTEGER, DIMENSION(ndims_mapping_row(nd_index)) :: map1
    1522      830312 :       INTEGER, DIMENSION(ndims_mapping_column(nd_index)) :: map2
    1523             :       INTEGER                            :: compat_map
    1524             : 
    1525      415156 :       CALL dbt_get_mapping_info(nd_index, map1_2d=map1, map2_2d=map2)
    1526             : 
    1527      415156 :       compat_map = 0
    1528      415156 :       IF (array_eq_i(map1, compat_ind)) THEN
    1529             :          compat_map = 1
    1530      175008 :       ELSEIF (array_eq_i(map2, compat_ind)) THEN
    1531      174890 :          compat_map = 2
    1532             :       END IF
    1533             : 
    1534      415156 :    END FUNCTION
    1535             : 
    1536             : ! **************************************************************************************************
    1537             : !> \brief
    1538             : !> \author Patrick Seewald
    1539             : ! **************************************************************************************************
    1540      311367 :    SUBROUTINE index_linked_sort(ind_ref, ind_dep)
    1541             :       INTEGER, DIMENSION(:), INTENT(INOUT) :: ind_ref, ind_dep
    1542      622734 :       INTEGER, DIMENSION(SIZE(ind_ref))    :: sort_indices
    1543             : 
    1544      311367 :       CALL sort(ind_ref, SIZE(ind_ref), sort_indices)
    1545     1437510 :       ind_dep(:) = ind_dep(sort_indices)
    1546             : 
    1547      311367 :    END SUBROUTINE
    1548             : 
    1549             : ! **************************************************************************************************
    1550             : !> \brief
    1551             : !> \author Patrick Seewald
    1552             : ! **************************************************************************************************
    1553           0 :    FUNCTION opt_pgrid(tensor, tas_split_info)
    1554             :       TYPE(dbt_type), INTENT(IN) :: tensor
    1555             :       TYPE(dbt_tas_split_info), INTENT(IN) :: tas_split_info
    1556           0 :       INTEGER, DIMENSION(ndims_matrix_row(tensor)) :: map1
    1557           0 :       INTEGER, DIMENSION(ndims_matrix_column(tensor)) :: map2
    1558             :       TYPE(dbt_pgrid_type) :: opt_pgrid
    1559           0 :       INTEGER, DIMENSION(ndims_tensor(tensor)) :: dims
    1560             : 
    1561           0 :       CALL dbt_get_mapping_info(tensor%pgrid%nd_index_grid, map1_2d=map1, map2_2d=map2)
    1562           0 :       CALL blk_dims_tensor(tensor, dims)
    1563           0 :       opt_pgrid = dbt_nd_mp_comm(tas_split_info%mp_comm, map1, map2, tdims=dims)
    1564             : 
    1565           0 :       ALLOCATE (opt_pgrid%tas_split_info, SOURCE=tas_split_info)
    1566           0 :       CALL dbt_tas_info_hold(opt_pgrid%tas_split_info)
    1567           0 :    END FUNCTION
    1568             : 
    1569             : ! **************************************************************************************************
    1570             : !> \brief Copy tensor to tensor with modified index mapping
    1571             : !> \param map1_2d new index mapping
    1572             : !> \param map2_2d new index mapping
    1573             : !> \author Patrick Seewald
    1574             : ! **************************************************************************************************
    1575        3760 :    SUBROUTINE dbt_remap(tensor_in, map1_2d, map2_2d, tensor_out, comm_2d, dist1, dist2, &
    1576         470 :                         mp_dims_1, mp_dims_2, name, nodata, move_data)
    1577             :       TYPE(dbt_type), INTENT(INOUT)      :: tensor_in
    1578             :       INTEGER, DIMENSION(:), INTENT(IN)      :: map1_2d, map2_2d
    1579             :       TYPE(dbt_type), INTENT(OUT)        :: tensor_out
    1580             :       CHARACTER(len=*), INTENT(IN), OPTIONAL :: name
    1581             :       LOGICAL, INTENT(IN), OPTIONAL          :: nodata, move_data
    1582             :       CLASS(mp_comm_type), INTENT(IN), OPTIONAL          :: comm_2d
    1583             :       TYPE(array_list), INTENT(IN), OPTIONAL :: dist1, dist2
    1584             :       INTEGER, DIMENSION(SIZE(map1_2d)), OPTIONAL :: mp_dims_1
    1585             :       INTEGER, DIMENSION(SIZE(map2_2d)), OPTIONAL :: mp_dims_2
    1586             :       CHARACTER(len=default_string_length)   :: name_tmp
    1587         470 :       INTEGER, DIMENSION(:), ALLOCATABLE     :: ${varlist("blk_sizes")}$, &
    1588         470 :                                                 ${varlist("nd_dist")}$
    1589        3290 :       TYPE(dbt_distribution_type)        :: dist
    1590         470 :       TYPE(mp_cart_type) :: comm_2d_prv
    1591             :       INTEGER                                :: handle, i
    1592         470 :       INTEGER, DIMENSION(ndims_tensor(tensor_in)) :: pdims, myploc
    1593             :       CHARACTER(LEN=*), PARAMETER :: routineN = 'dbt_remap'
    1594             :       LOGICAL                               :: nodata_prv
    1595        1410 :       TYPE(dbt_pgrid_type)              :: comm_nd
    1596             : 
    1597         470 :       CALL timeset(routineN, handle)
    1598             : 
    1599         470 :       IF (PRESENT(name)) THEN
    1600           0 :          name_tmp = name
    1601             :       ELSE
    1602         470 :          name_tmp = tensor_in%name
    1603             :       END IF
    1604         470 :       IF (PRESENT(dist1)) THEN
    1605         238 :          CPASSERT(PRESENT(mp_dims_1))
    1606             :       END IF
    1607             : 
    1608         470 :       IF (PRESENT(dist2)) THEN
    1609          12 :          CPASSERT(PRESENT(mp_dims_2))
    1610             :       END IF
    1611             : 
    1612         470 :       IF (PRESENT(comm_2d)) THEN
    1613         364 :          comm_2d_prv = comm_2d
    1614             :       ELSE
    1615         106 :          comm_2d_prv = tensor_in%pgrid%mp_comm_2d
    1616             :       END IF
    1617             : 
    1618         470 :       comm_nd = dbt_nd_mp_comm(comm_2d_prv, map1_2d, map2_2d, dims1_nd=mp_dims_1, dims2_nd=mp_dims_2)
    1619         470 :       CALL mp_environ_pgrid(comm_nd, pdims, myploc)
    1620             : 
    1621             :       #:for ndim in ndims
    1622         866 :          IF (ndims_tensor(tensor_in) == ${ndim}$) THEN
    1623         396 :             CALL get_arrays(tensor_in%blk_sizes, ${varlist("blk_sizes", nmax=ndim)}$)
    1624             :          END IF
    1625             :       #:endfor
    1626             : 
    1627             :       #:for ndim in ndims
    1628         936 :          IF (ndims_tensor(tensor_in) == ${ndim}$) THEN
    1629             :             #:for idim in range(1, ndim+1)
    1630        1340 :                IF (PRESENT(dist1)) THEN
    1631        1594 :                   IF (ANY(map1_2d == ${idim}$)) THEN
    1632        1132 :                      i = MINLOC(map1_2d, dim=1, mask=map1_2d == ${idim}$) ! i is location of idim in map1_2d
    1633         240 :                      CALL get_ith_array(dist1, i, nd_dist_${idim}$)
    1634             :                   END IF
    1635             :                END IF
    1636             : 
    1637        1340 :                IF (PRESENT(dist2)) THEN
    1638          80 :                   IF (ANY(map2_2d == ${idim}$)) THEN
    1639          40 :                      i = MINLOC(map2_2d, dim=1, mask=map2_2d == ${idim}$) ! i is location of idim in map2_2d
    1640          16 :                      CALL get_ith_array(dist2, i, nd_dist_${idim}$)
    1641             :                   END IF
    1642             :                END IF
    1643             : 
    1644        1340 :                IF (.NOT. ALLOCATED(nd_dist_${idim}$)) THEN
    1645        2766 :                   ALLOCATE (nd_dist_${idim}$ (SIZE(blk_sizes_${idim}$)))
    1646         922 :                   CALL dbt_default_distvec(SIZE(blk_sizes_${idim}$), pdims(${idim}$), blk_sizes_${idim}$, nd_dist_${idim}$)
    1647             :                END IF
    1648             :             #:endfor
    1649             :             CALL dbt_distribution_new_expert(dist, comm_nd, map1_2d, map2_2d, &
    1650         470 :                                              ${varlist("nd_dist", nmax=ndim)}$, own_comm=.TRUE.)
    1651             :          END IF
    1652             :       #:endfor
    1653             : 
    1654             :       #:for ndim in ndims
    1655         936 :          IF (ndims_tensor(tensor_in) == ${ndim}$) THEN
    1656             :             CALL dbt_create(tensor_out, name_tmp, dist, map1_2d, map2_2d, &
    1657         470 :                             ${varlist("blk_sizes", nmax=ndim)}$)
    1658             :          END IF
    1659             :       #:endfor
    1660             : 
    1661         470 :       IF (PRESENT(nodata)) THEN
    1662         188 :          nodata_prv = nodata
    1663             :       ELSE
    1664             :          nodata_prv = .FALSE.
    1665             :       END IF
    1666             : 
    1667         470 :       IF (.NOT. nodata_prv) CALL dbt_copy_expert(tensor_in, tensor_out, move_data=move_data)
    1668         470 :       CALL dbt_distribution_destroy(dist)
    1669             : 
    1670         470 :       CALL timestop(handle)
    1671         940 :    END SUBROUTINE
    1672             : 
    1673             : ! **************************************************************************************************
    1674             : !> \brief Align index with data
    1675             : !> \param order permutation resulting from alignment
    1676             : !> \author Patrick Seewald
    1677             : ! **************************************************************************************************
    1678     2179569 :    SUBROUTINE dbt_align_index(tensor_in, tensor_out, order)
    1679             :       TYPE(dbt_type), INTENT(INOUT)               :: tensor_in
    1680             :       TYPE(dbt_type), INTENT(OUT)                 :: tensor_out
    1681      622734 :       INTEGER, DIMENSION(ndims_matrix_row(tensor_in)) :: map1_2d
    1682      622734 :       INTEGER, DIMENSION(ndims_matrix_column(tensor_in)) :: map2_2d
    1683             :       INTEGER, DIMENSION(ndims_tensor(tensor_in)), &
    1684             :          INTENT(OUT), OPTIONAL                        :: order
    1685      311367 :       INTEGER, DIMENSION(ndims_tensor(tensor_in))     :: order_prv
    1686             :       CHARACTER(LEN=*), PARAMETER :: routineN = 'dbt_align_index'
    1687             :       INTEGER                                         :: handle
    1688             : 
    1689      311367 :       CALL timeset(routineN, handle)
    1690             : 
    1691      311367 :       CALL dbt_get_mapping_info(tensor_in%nd_index_blk, map1_2d=map1_2d, map2_2d=map2_2d)
    1692     1940919 :       order_prv = dbt_inverse_order([map1_2d, map2_2d])
    1693      311367 :       CALL dbt_permute_index(tensor_in, tensor_out, order=order_prv)
    1694             : 
    1695     1126143 :       IF (PRESENT(order)) order = order_prv
    1696             : 
    1697      311367 :       CALL timestop(handle)
    1698      311367 :    END SUBROUTINE
    1699             : 
    1700             : ! **************************************************************************************************
    1701             : !> \brief Create new tensor by reordering index, data is copied exactly (shallow copy)
    1702             : !> \author Patrick Seewald
    1703             : ! **************************************************************************************************
    1704     3216008 :    SUBROUTINE dbt_permute_index(tensor_in, tensor_out, order)
    1705             :       TYPE(dbt_type), INTENT(INOUT)                  :: tensor_in
    1706             :       TYPE(dbt_type), INTENT(OUT)                 :: tensor_out
    1707             :       INTEGER, DIMENSION(ndims_tensor(tensor_in)), &
    1708             :          INTENT(IN)                                   :: order
    1709             : 
    1710     2010005 :       TYPE(nd_to_2d_mapping)                          :: nd_index_blk_rs, nd_index_rs
    1711             :       CHARACTER(LEN=*), PARAMETER :: routineN = 'dbt_permute_index'
    1712             :       INTEGER                                         :: handle
    1713             :       INTEGER                                         :: ndims
    1714             : 
    1715      402001 :       CALL timeset(routineN, handle)
    1716             : 
    1717      402001 :       ndims = ndims_tensor(tensor_in)
    1718             : 
    1719      402001 :       CALL permute_index(tensor_in%nd_index, nd_index_rs, order)
    1720      402001 :       CALL permute_index(tensor_in%nd_index_blk, nd_index_blk_rs, order)
    1721      402001 :       CALL permute_index(tensor_in%pgrid%nd_index_grid, tensor_out%pgrid%nd_index_grid, order)
    1722             : 
    1723      402001 :       tensor_out%matrix_rep => tensor_in%matrix_rep
    1724      402001 :       tensor_out%owns_matrix = .FALSE.
    1725             : 
    1726      402001 :       tensor_out%nd_index = nd_index_rs
    1727      402001 :       tensor_out%nd_index_blk = nd_index_blk_rs
    1728      402001 :       tensor_out%pgrid%mp_comm_2d = tensor_in%pgrid%mp_comm_2d
    1729      402001 :       IF (ALLOCATED(tensor_in%pgrid%tas_split_info)) THEN
    1730      402001 :          ALLOCATE (tensor_out%pgrid%tas_split_info, SOURCE=tensor_in%pgrid%tas_split_info)
    1731             :       END IF
    1732      402001 :       tensor_out%refcount => tensor_in%refcount
    1733      402001 :       CALL dbt_hold(tensor_out)
    1734             : 
    1735      402001 :       CALL reorder_arrays(tensor_in%blk_sizes, tensor_out%blk_sizes, order)
    1736      402001 :       CALL reorder_arrays(tensor_in%blk_offsets, tensor_out%blk_offsets, order)
    1737      402001 :       CALL reorder_arrays(tensor_in%nd_dist, tensor_out%nd_dist, order)
    1738      402001 :       CALL reorder_arrays(tensor_in%blks_local, tensor_out%blks_local, order)
    1739     1206003 :       ALLOCATE (tensor_out%nblks_local(ndims))
    1740     1206003 :       ALLOCATE (tensor_out%nfull_local(ndims))
    1741     1469099 :       tensor_out%nblks_local(order) = tensor_in%nblks_local(:)
    1742     1469099 :       tensor_out%nfull_local(order) = tensor_in%nfull_local(:)
    1743      402001 :       tensor_out%name = tensor_in%name
    1744      402001 :       tensor_out%valid = .TRUE.
    1745             : 
    1746      402001 :       IF (ALLOCATED(tensor_in%contraction_storage)) THEN
    1747      282191 :          ALLOCATE (tensor_out%contraction_storage, SOURCE=tensor_in%contraction_storage)
    1748      282191 :          CALL destroy_array_list(tensor_out%contraction_storage%batch_ranges)
    1749      282191 :          CALL reorder_arrays(tensor_in%contraction_storage%batch_ranges, tensor_out%contraction_storage%batch_ranges, order)
    1750             :       END IF
    1751             : 
    1752      402001 :       CALL timestop(handle)
    1753      804002 :    END SUBROUTINE
    1754             : 
    1755             : ! **************************************************************************************************
    1756             : !> \brief Map contraction bounds to bounds referring to tensor indices
    1757             : !>        see dbt_contract for docu of dummy arguments
    1758             : !> \param bounds_t1 bounds mapped to tensor_1
    1759             : !> \param bounds_t2 bounds mapped to tensor_2
    1760             : !> \param do_crop_1 whether tensor 1 should be cropped
    1761             : !> \param do_crop_2 whether tensor 2 should be cropped
    1762             : !> \author Patrick Seewald
    1763             : ! **************************************************************************************************
    1764      111236 :    SUBROUTINE dbt_map_bounds_to_tensors(tensor_1, tensor_2, &
    1765      111236 :                                         contract_1, notcontract_1, &
    1766      222472 :                                         contract_2, notcontract_2, &
    1767      111236 :                                         bounds_t1, bounds_t2, &
    1768      114954 :                                         bounds_1, bounds_2, bounds_3, &
    1769             :                                         do_crop_1, do_crop_2)
    1770             : 
    1771             :       TYPE(dbt_type), INTENT(IN)      :: tensor_1, tensor_2
    1772             :       INTEGER, DIMENSION(:), INTENT(IN)   :: contract_1, contract_2, &
    1773             :                                              notcontract_1, notcontract_2
    1774             :       INTEGER, DIMENSION(2, ndims_tensor(tensor_1)), &
    1775             :          INTENT(OUT)                                 :: bounds_t1
    1776             :       INTEGER, DIMENSION(2, ndims_tensor(tensor_2)), &
    1777             :          INTENT(OUT)                                 :: bounds_t2
    1778             :       INTEGER, DIMENSION(2, SIZE(contract_1)), &
    1779             :          INTENT(IN), OPTIONAL                        :: bounds_1
    1780             :       INTEGER, DIMENSION(2, SIZE(notcontract_1)), &
    1781             :          INTENT(IN), OPTIONAL                        :: bounds_2
    1782             :       INTEGER, DIMENSION(2, SIZE(notcontract_2)), &
    1783             :          INTENT(IN), OPTIONAL                        :: bounds_3
    1784             :       LOGICAL, INTENT(OUT), OPTIONAL                 :: do_crop_1, do_crop_2
    1785             :       LOGICAL, DIMENSION(2)                          :: do_crop
    1786             : 
    1787      111236 :       do_crop = .FALSE.
    1788             : 
    1789      390168 :       bounds_t1(1, :) = 1
    1790      390168 :       CALL dbt_get_info(tensor_1, nfull_total=bounds_t1(2, :))
    1791             : 
    1792      415102 :       bounds_t2(1, :) = 1
    1793      415102 :       CALL dbt_get_info(tensor_2, nfull_total=bounds_t2(2, :))
    1794             : 
    1795      111236 :       IF (PRESENT(bounds_1)) THEN
    1796      168742 :          bounds_t1(:, contract_1) = bounds_1
    1797       24958 :          do_crop(1) = .TRUE.
    1798      168742 :          bounds_t2(:, contract_2) = bounds_1
    1799      111236 :          do_crop(2) = .TRUE.
    1800             :       END IF
    1801             : 
    1802      111236 :       IF (PRESENT(bounds_2)) THEN
    1803      227752 :          bounds_t1(:, notcontract_1) = bounds_2
    1804      111236 :          do_crop(1) = .TRUE.
    1805             :       END IF
    1806             : 
    1807      111236 :       IF (PRESENT(bounds_3)) THEN
    1808      252958 :          bounds_t2(:, notcontract_2) = bounds_3
    1809      111236 :          do_crop(2) = .TRUE.
    1810             :       END IF
    1811             : 
    1812      111236 :       IF (PRESENT(do_crop_1)) do_crop_1 = do_crop(1)
    1813      111236 :       IF (PRESENT(do_crop_2)) do_crop_2 = do_crop(2)
    1814             : 
    1815      267674 :    END SUBROUTINE
    1816             : 
    1817             : ! **************************************************************************************************
    1818             : !> \brief print tensor contraction indices in a human readable way
    1819             : !> \param indchar1 characters printed for index of tensor 1
    1820             : !> \param indchar2 characters printed for index of tensor 2
    1821             : !> \param indchar3 characters printed for index of tensor 3
    1822             : !> \param unit_nr output unit
    1823             : !> \author Patrick Seewald
    1824             : ! **************************************************************************************************
    1825      138438 :    SUBROUTINE dbt_print_contraction_index(tensor_1, indchar1, tensor_2, indchar2, tensor_3, indchar3, unit_nr)
    1826             :       TYPE(dbt_type), INTENT(IN) :: tensor_1, tensor_2, tensor_3
    1827             :       CHARACTER(LEN=1), DIMENSION(ndims_tensor(tensor_1)), INTENT(IN) :: indchar1
    1828             :       CHARACTER(LEN=1), DIMENSION(ndims_tensor(tensor_2)), INTENT(IN) :: indchar2
    1829             :       CHARACTER(LEN=1), DIMENSION(ndims_tensor(tensor_3)), INTENT(IN) :: indchar3
    1830             :       INTEGER, INTENT(IN) :: unit_nr
                     :       ! mapX1 / mapX2 hold, for tensor X, the tensor-dimension indices that are mapped
                     :       ! to the rows (X1) resp. columns (X2) of its 2d matrix representation
    1831      276876 :       INTEGER, DIMENSION(ndims_matrix_row(tensor_1)) :: map11
    1832      276876 :       INTEGER, DIMENSION(ndims_matrix_column(tensor_1)) :: map12
    1833      276876 :       INTEGER, DIMENSION(ndims_matrix_row(tensor_2)) :: map21
    1834      276876 :       INTEGER, DIMENSION(ndims_matrix_column(tensor_2)) :: map22
    1835      276876 :       INTEGER, DIMENSION(ndims_matrix_row(tensor_3)) :: map31
    1836      276876 :       INTEGER, DIMENSION(ndims_matrix_column(tensor_3)) :: map32
    1837             :       INTEGER :: ichar1, ichar2, ichar3, unit_nr_prv
    1838             : 
                     :       ! NOTE(review): prep_output_unit presumably returns 0 for ranks that should stay
                     :       ! silent — confirm against its definition elsewhere in this module
    1839      138438 :       unit_nr_prv = prep_output_unit(unit_nr)
    1840             : 
                     :       ! fetch the tensor-to-matrix index mappings (skipped entirely when output is inactive)
    1841      138438 :       IF (unit_nr_prv /= 0) THEN
    1842      138438 :          CALL dbt_get_mapping_info(tensor_1%nd_index_blk, map1_2d=map11, map2_2d=map12)
    1843      138438 :          CALL dbt_get_mapping_info(tensor_2%nd_index_blk, map1_2d=map21, map2_2d=map22)
    1844      138438 :          CALL dbt_get_mapping_info(tensor_3%nd_index_blk, map1_2d=map31, map2_2d=map32)
    1845             :       END IF
    1846             : 
                     :       ! only a strictly positive unit actually writes; first line shows the n-d tensor
                     :       ! indices, e.g. "(abc) x (cd) = (abd)"
    1847      138438 :       IF (unit_nr_prv > 0) THEN
    1848          30 :          WRITE (unit_nr_prv, '(T2,A)') "INDEX INFO"
    1849          30 :          WRITE (unit_nr_prv, '(T15,A)', advance='no') "tensor index: ("
    1850         123 :          DO ichar1 = 1, SIZE(indchar1)
    1851         123 :             WRITE (unit_nr_prv, '(A1)', advance='no') indchar1(ichar1)
    1852             :          END DO
    1853          30 :          WRITE (unit_nr_prv, '(A)', advance='no') ") x ("
    1854         120 :          DO ichar2 = 1, SIZE(indchar2)
    1855         120 :             WRITE (unit_nr_prv, '(A1)', advance='no') indchar2(ichar2)
    1856             :          END DO
    1857          30 :          WRITE (unit_nr_prv, '(A)', advance='no') ") = ("
    1858         123 :          DO ichar3 = 1, SIZE(indchar3)
    1859         123 :             WRITE (unit_nr_prv, '(A1)', advance='no') indchar3(ichar3)
    1860             :          END DO
    1861          30 :          WRITE (unit_nr_prv, '(A)') ")"
    1862             : 
                     :          ! second line shows the same index characters regrouped as "(row|column)" of the
                     :          ! matrix representation of each tensor, "|" separating row from column indices
    1863          30 :          WRITE (unit_nr_prv, '(T15,A)', advance='no') "matrix index: ("
    1864          82 :          DO ichar1 = 1, SIZE(map11)
    1865          82 :             WRITE (unit_nr_prv, '(A1)', advance='no') indchar1(map11(ichar1))
    1866             :          END DO
    1867          30 :          WRITE (unit_nr_prv, '(A1)', advance='no') "|"
    1868          71 :          DO ichar1 = 1, SIZE(map12)
    1869          71 :             WRITE (unit_nr_prv, '(A1)', advance='no') indchar1(map12(ichar1))
    1870             :          END DO
    1871          30 :          WRITE (unit_nr_prv, '(A)', advance='no') ") x ("
    1872          76 :          DO ichar2 = 1, SIZE(map21)
    1873          76 :             WRITE (unit_nr_prv, '(A1)', advance='no') indchar2(map21(ichar2))
    1874             :          END DO
    1875          30 :          WRITE (unit_nr_prv, '(A1)', advance='no') "|"
    1876          74 :          DO ichar2 = 1, SIZE(map22)
    1877          74 :             WRITE (unit_nr_prv, '(A1)', advance='no') indchar2(map22(ichar2))
    1878             :          END DO
    1879          30 :          WRITE (unit_nr_prv, '(A)', advance='no') ") = ("
    1880          79 :          DO ichar3 = 1, SIZE(map31)
    1881          79 :             WRITE (unit_nr_prv, '(A1)', advance='no') indchar3(map31(ichar3))
    1882             :          END DO
    1883          30 :          WRITE (unit_nr_prv, '(A1)', advance='no') "|"
    1884          74 :          DO ichar3 = 1, SIZE(map32)
    1885          74 :             WRITE (unit_nr_prv, '(A1)', advance='no') indchar3(map32(ichar3))
    1886             :          END DO
    1887          30 :          WRITE (unit_nr_prv, '(A)') ")"
    1888             :       END IF
    1889             : 
    1890      138438 :    END SUBROUTINE
    1891             : 
    1892             : ! **************************************************************************************************
    1893             : !> \brief Initialize batched contraction for this tensor.
    1894             : !>
    1895             : !>        Explanation: A batched contraction is a contraction performed in several consecutive steps
    1896             : !>        by specification of bounds in dbt_contract. This can be used to reduce memory by
    1897             : !>        a large factor. The routines dbt_batched_contract_init and
    1898             : !>        dbt_batched_contract_finalize should be called to define the scope of a batched
    1899             : !>        contraction as this enables important optimizations (adapting communication scheme to
    1900             : !>        batches and adapting process grid to multiplication algorithm). The routines
    1901             : !>        dbt_batched_contract_init and dbt_batched_contract_finalize must be
    1902             : !>        called before the first and after the last contraction step on all 3 tensors.
    1903             : !>
    1904             : !>        Requirements:
    1905             : !>        - the tensors are in a compatible matrix layout (see documentation of
    1906             : !>          `dbt_contract`, note 2 & 3). If they are not, process grid optimizations are
    1907             : !>          disabled and a warning is issued.
    1908             : !>        - within the scope of a batched contraction, it is not allowed to access or change tensor
    1909             : !>          data except by calling the routines dbt_contract & dbt_copy.
    1910             : !>        - the bounds affecting indices of the smallest tensor must not change in the course of a
    1911             : !>          batched contraction (todo: get rid of this requirement).
    1912             : !>
    1913             : !>        Side effects:
    1914             : !>        - the parallel layout (process grid and distribution) of all tensors may change. In order
    1915             : !>          to disable the process grid optimization including this side effect, call this routine
    1916             : !>          only on the smallest of the 3 tensors.
    1917             : !>
    1918             : !> \note
    1919             : !>        Note 1: for an example of batched contraction see `examples/dbt_example.F`.
    1920             : !>        (todo: the example is outdated and should be updated).
    1921             : !>
    1922             : !>        Note 2: it is meaningful to use this feature if the contraction consists of one batch only
    1923             : !>        but if multiple contractions involving the same 3 tensors are performed
    1924             : !>        (batched_contract_init and batched_contract_finalize must then be called before/after each
    1925             : !>        contraction call). The process grid is then optimized after the first contraction
    1926             : !>        and future contractions may profit from this optimization.
    1927             : !>
    1928             : !> \param batch_range_i refers to the ith tensor dimension and contains all block indices starting
    1929             : !>                      a new range. The size should be the number of ranges plus one, the last
    1930             : !>                      element being the block index plus one of the last block in the last range.
    1931             : !>                      For internal load balancing optimizations, optionally specify the index
    1932             : !>                      ranges of batched contraction.
    1933             : !> \author Patrick Seewald
    1934             : ! **************************************************************************************************
    1935       99749 :    SUBROUTINE dbt_batched_contract_init(tensor, ${varlist("batch_range")}$)
    1936             :       TYPE(dbt_type), INTENT(INOUT) :: tensor
    1937             :       INTEGER, DIMENSION(:), OPTIONAL, INTENT(IN)        :: ${varlist("batch_range")}$
    1938      199498 :       INTEGER, DIMENSION(ndims_tensor(tensor)) :: tdims
    1939       99749 :       INTEGER, DIMENSION(:), ALLOCATABLE                 :: ${varlist("batch_range_prv")}$
    1940             :       LOGICAL :: static_range
    1941             : 
    1942       99749 :       CALL dbt_get_info(tensor, nblks_total=tdims)
    1943             : 
                     :       ! static_range stays .TRUE. only if no batch range was supplied for any dimension;
                     :       ! a missing range defaults to a single batch spanning the whole dimension,
                     :       ! i.e. block indices [1, nblks_total + 1)
    1944       99749 :       static_range = .TRUE.
    1945             :       #:for idim in range(1, maxdim+1)
    1946       99749 :          IF (ndims_tensor(tensor) >= ${idim}$) THEN
    1947      233720 :             IF (PRESENT(batch_range_${idim}$)) THEN
    1948      370776 :                ALLOCATE (batch_range_prv_${idim}$, source=batch_range_${idim}$)
    1949      233720 :                static_range = .FALSE.
    1950             :             ELSE
    1951      176926 :                ALLOCATE (batch_range_prv_${idim}$ (2))
    1952      176926 :                batch_range_prv_${idim}$ (1) = 1
    1953      176926 :                batch_range_prv_${idim}$ (2) = tdims(${idim}$) + 1
    1954             :             END IF
    1955             :          END IF
    1956             :       #:endfor
    1957             : 
    1958       99749 :       ALLOCATE (tensor%contraction_storage)
    1959       99749 :       tensor%contraction_storage%static = static_range
                     :       ! with a trivial (static) range the batched multiplication can be set up immediately
                     :       ! on the matrix representation; otherwise setup is deferred
    1960       99749 :       IF (static_range) THEN
    1961       68117 :          CALL dbt_tas_batched_mm_init(tensor%matrix_rep)
    1962             :       END IF
    1963       99749 :       tensor%contraction_storage%nsplit_avg = 0.0_dp
    1964       99749 :       tensor%contraction_storage%ibatch = 0
    1965             : 
                     :       ! store the (possibly defaulted) batch ranges for all tensor dimensions
    1966             :       #:for ndim in range(1, maxdim+1)
    1967      199498 :          IF (ndims_tensor(tensor) == ${ndim}$) THEN
    1968             :             CALL create_array_list(tensor%contraction_storage%batch_ranges, ${ndim}$, &
    1969       99749 :                                    ${varlist("batch_range_prv", nmax=ndim)}$)
    1970             :          END IF
    1971             :       #:endfor
    1972             : 
    1973       99749 :    END SUBROUTINE
    1974             : 
    1975             : ! **************************************************************************************************
    1976             : !> \brief finalize batched contraction. This performs all communication that has been postponed in
    1977             : !>         the contraction calls.
    1978             : !> \author Patrick Seewald
    1979             : ! **************************************************************************************************
    1980      199498 :    SUBROUTINE dbt_batched_contract_finalize(tensor, unit_nr)
    1981             :       TYPE(dbt_type), INTENT(INOUT) :: tensor
    1982             :       INTEGER, INTENT(IN), OPTIONAL :: unit_nr
    1983             :       LOGICAL :: do_write
    1984             :       INTEGER :: unit_nr_prv, handle
    1985             : 
                     :       ! synchronize the 2d process grid communicator before (and again after) finalization
                     :       ! so the timed region is consistently delimited across all ranks
    1986       99749 :       CALL tensor%pgrid%mp_comm_2d%sync()
    1987       99749 :       CALL timeset("dbt_total", handle)
    1988       99749 :       unit_nr_prv = prep_output_unit(unit_nr)
    1989             : 
    1990       99749 :       do_write = .FALSE.
    1991             : 
                     :       ! batched matrix multiplication was only initialized for static batch ranges
                     :       ! (see dbt_batched_contract_init); report only if batched output was produced
    1992       99749 :       IF (tensor%contraction_storage%static) THEN
    1993       68117 :          IF (tensor%matrix_rep%do_batched > 0) THEN
    1994       68117 :             IF (tensor%matrix_rep%mm_storage%batched_out) do_write = .TRUE.
    1995             :          END IF
    1996       68117 :          CALL dbt_tas_batched_mm_finalize(tensor%matrix_rep)
    1997             :       END IF
    1998             : 
    1999       99749 :       IF (do_write .AND. unit_nr_prv /= 0) THEN
    2000       15666 :          IF (unit_nr_prv > 0) THEN
    2001             :             WRITE (unit_nr_prv, "(T2,A)") &
    2002           0 :                "FINALIZING BATCHED PROCESSING OF MATMUL"
    2003             :          END IF
    2004       15666 :          CALL dbt_write_tensor_info(tensor, unit_nr_prv)
    2005       15666 :          CALL dbt_write_tensor_dist(tensor, unit_nr_prv)
    2006             :       END IF
    2007             : 
                     :       ! release batch-range storage; this ends the batched-contraction scope for this tensor
    2008       99749 :       CALL destroy_array_list(tensor%contraction_storage%batch_ranges)
    2009       99749 :       DEALLOCATE (tensor%contraction_storage)
    2010       99749 :       CALL tensor%pgrid%mp_comm_2d%sync()
    2011       99749 :       CALL timestop(handle)
    2012             : 
    2013       99749 :    END SUBROUTINE
    2014             : 
    2015             : ! **************************************************************************************************
    2016             : !> \brief change the process grid of a tensor
    2017             : !> \param nodata optionally don't copy the tensor data (then the tensor is empty on return)
    2018             : !> \param batch_range_i refers to the ith tensor dimension and contains all block indices starting
    2019             : !>                      a new range. The size should be the number of ranges plus one, the last
    2020             : !>                      element being the block index plus one of the last block in the last range.
    2021             : !>                      For internal load balancing optimizations, optionally specify the index
    2022             : !>                      ranges of batched contraction.
    2023             : !> \author Patrick Seewald
    2024             : ! **************************************************************************************************
    2025         772 :    SUBROUTINE dbt_change_pgrid(tensor, pgrid, ${varlist("batch_range")}$, &
    2026             :                                nodata, pgrid_changed, unit_nr)
    2027             :       TYPE(dbt_type), INTENT(INOUT)                  :: tensor
    2028             :       TYPE(dbt_pgrid_type), INTENT(IN)               :: pgrid
    2029             :       INTEGER, DIMENSION(:), OPTIONAL, INTENT(IN)        :: ${varlist("batch_range")}$
    2030             :       !!
    2031             :       LOGICAL, INTENT(IN), OPTIONAL                      :: nodata
    2032             :       LOGICAL, INTENT(OUT), OPTIONAL                     :: pgrid_changed
    2033             :       INTEGER, INTENT(IN), OPTIONAL                      :: unit_nr
    2034             :       CHARACTER(LEN=*), PARAMETER :: routineN = 'dbt_change_pgrid'
    2035             :       CHARACTER(default_string_length)                   :: name
    2036             :       INTEGER                                            :: handle
    2037         772 :       INTEGER, ALLOCATABLE, DIMENSION(:)                 :: ${varlist("bs")}$, &
    2038         772 :                                                             ${varlist("dist")}$
    2039        1544 :       INTEGER, DIMENSION(ndims_tensor(tensor))           :: pcoord, pcoord_ref, pdims, pdims_ref, &
    2040        1544 :                                                             tdims
    2041        5404 :       TYPE(dbt_type)                                 :: t_tmp
    2042        5404 :       TYPE(dbt_distribution_type)                    :: dist
    2043        1544 :       INTEGER, DIMENSION(ndims_matrix_row(tensor)) :: map1
    2044             :       INTEGER, &
    2045        1544 :          DIMENSION(ndims_matrix_column(tensor))    :: map2
    2046        1544 :       LOGICAL, DIMENSION(ndims_tensor(tensor))             :: mem_aware
    2047         772 :       INTEGER, DIMENSION(ndims_tensor(tensor)) :: nbatch
    2048             :       INTEGER :: ind1, ind2, batch_size, ibatch
    2049             : 
    2050         772 :       IF (PRESENT(pgrid_changed)) pgrid_changed = .FALSE.
    2051         772 :       CALL mp_environ_pgrid(pgrid, pdims, pcoord)
    2052         772 :       CALL mp_environ_pgrid(tensor%pgrid, pdims_ref, pcoord_ref)
    2053             : 
                     :       ! nothing to do if the target grid has identical dimensions AND the same TAS split
                     :       ! (pgrid_changed then remains .FALSE.)
    2054         796 :       IF (ALL(pdims == pdims_ref)) THEN
    2055           8 :          IF (ALLOCATED(pgrid%tas_split_info) .AND. ALLOCATED(tensor%pgrid%tas_split_info)) THEN
    2056           8 :             IF (pgrid%tas_split_info%ngroup == tensor%pgrid%tas_split_info%ngroup) THEN
    2057             :                RETURN
    2058             :             END IF
    2059             :          END IF
    2060             :       END IF
    2061             : 
    2062         764 :       CALL timeset(routineN, handle)
    2063             : 
                     :       ! mem_aware(idim): a batch range was supplied for this dimension, so the block
                     :       ! distribution is computed per batch rather than over the whole dimension
    2064             :       #:for idim in range(1, maxdim+1)
    2065        3056 :          IF (ndims_tensor(tensor) >= ${idim}$) THEN
    2066        2292 :             mem_aware(${idim}$) = PRESENT(batch_range_${idim}$)
    2067        2292 :             IF (mem_aware(${idim}$)) nbatch(${idim}$) = SIZE(batch_range_${idim}$) - 1
    2068             :          END IF
    2069             :       #:endfor
    2070             : 
    2071         764 :       CALL dbt_get_info(tensor, nblks_total=tdims, name=name)
    2072             : 
                     :       ! per dimension: gather block sizes and build a default block distribution for the
                     :       ! new grid (batch-wise if mem_aware, otherwise over the full dimension at once)
    2073             :       #:for idim in range(1, maxdim+1)
    2074        3056 :          IF (ndims_tensor(tensor) >= ${idim}$) THEN
    2075        6876 :             ALLOCATE (bs_${idim}$ (dbt_nblks_total(tensor, ${idim}$)))
    2076        2292 :             CALL get_ith_array(tensor%blk_sizes, ${idim}$, bs_${idim}$)
    2077        6876 :             ALLOCATE (dist_${idim}$ (tdims(${idim}$)))
    2078       16764 :             dist_${idim}$ = 0
    2079        2292 :             IF (mem_aware(${idim}$)) THEN
    2080        6276 :                DO ibatch = 1, nbatch(${idim}$)
    2081        3984 :                   ind1 = batch_range_${idim}$ (ibatch)
    2082        3984 :                   ind2 = batch_range_${idim}$ (ibatch + 1) - 1
    2083        3984 :                   batch_size = ind2 - ind1 + 1
    2084             :                   CALL dbt_default_distvec(batch_size, pdims(${idim}$), &
    2085        6276 :                                            bs_${idim}$ (ind1:ind2), dist_${idim}$ (ind1:ind2))
    2086             :                END DO
    2087             :             ELSE
    2088           0 :                CALL dbt_default_distvec(tdims(${idim}$), pdims(${idim}$), bs_${idim}$, dist_${idim}$)
    2089             :             END IF
    2090             :          END IF
    2091             :       #:endfor
    2092             : 
                     :       ! create a temporary tensor t_tmp on the new grid with the same matrix mapping
    2093         764 :       CALL dbt_get_mapping_info(tensor%nd_index_blk, map1_2d=map1, map2_2d=map2)
    2094             :       #:for ndim in ndims
    2095        1528 :          IF (ndims_tensor(tensor) == ${ndim}$) THEN
    2096         764 :             CALL dbt_distribution_new(dist, pgrid, ${varlist("dist", nmax=ndim)}$)
    2097         764 :             CALL dbt_create(t_tmp, name, dist, map1, map2, ${varlist("bs", nmax=ndim)}$)
    2098             :          END IF
    2099             :       #:endfor
    2100         764 :       CALL dbt_distribution_destroy(dist)
    2101             : 
                     :       ! move (or, if nodata, drop) the block data into the newly distributed tensor
    2102         764 :       IF (PRESENT(nodata)) THEN
    2103           0 :          IF (.NOT. nodata) CALL dbt_copy_expert(tensor, t_tmp, move_data=.TRUE.)
    2104             :       ELSE
    2105         764 :          CALL dbt_copy_expert(tensor, t_tmp, move_data=.TRUE.)
    2106             :       END IF
    2107             : 
    2108         764 :       CALL dbt_copy_contraction_storage(tensor, t_tmp)
    2109             : 
                     :       ! replace the input tensor by the redistributed copy
    2110         764 :       CALL dbt_destroy(tensor)
    2111         764 :       tensor = t_tmp
    2112             : 
    2113         764 :       IF (PRESENT(unit_nr)) THEN
    2114         764 :          IF (unit_nr > 0) THEN
    2115           0 :             WRITE (unit_nr, "(T2,A,1X,A)") "OPTIMIZED PGRID INFO FOR", TRIM(tensor%name)
    2116           0 :             WRITE (unit_nr, "(T4,A,1X,3I6)") "process grid dimensions:", pdims
    2117           0 :             CALL dbt_write_split_info(pgrid, unit_nr)
    2118             :          END IF
    2119             :       END IF
    2120             : 
    2121         764 :       IF (PRESENT(pgrid_changed)) pgrid_changed = .TRUE.
    2122             : 
    2123         764 :       CALL timestop(handle)
    2124         772 :    END SUBROUTINE
    2125             : 
    2126             : ! **************************************************************************************************
    2127             : !> \brief map tensor to a new 2d process grid for the matrix representation.
    2128             : !> \author Patrick Seewald
    2129             : ! **************************************************************************************************
    2130         772 :    SUBROUTINE dbt_change_pgrid_2d(tensor, mp_comm, pdims, nodata, nsplit, dimsplit, pgrid_changed, unit_nr)
    2131             :       TYPE(dbt_type), INTENT(INOUT)                  :: tensor
    2132             :       TYPE(mp_cart_type), INTENT(IN)               :: mp_comm
    2133             :       INTEGER, DIMENSION(2), INTENT(IN), OPTIONAL :: pdims
    2134             :       LOGICAL, INTENT(IN), OPTIONAL                      :: nodata
    2135             :       INTEGER, INTENT(IN), OPTIONAL :: nsplit, dimsplit
    2136             :       LOGICAL, INTENT(OUT), OPTIONAL :: pgrid_changed
    2137             :       INTEGER, INTENT(IN), OPTIONAL                      :: unit_nr
    2138        1544 :       INTEGER, DIMENSION(ndims_matrix_row(tensor)) :: map1
    2139        1544 :       INTEGER, DIMENSION(ndims_matrix_column(tensor)) :: map2
    2140        1544 :       INTEGER, DIMENSION(ndims_tensor(tensor)) :: dims, nbatches
    2141        2316 :       TYPE(dbt_pgrid_type) :: pgrid
    2142         772 :       INTEGER, DIMENSION(:), ALLOCATABLE :: ${varlist("batch_range")}$
    2143         772 :       INTEGER, DIMENSION(:), ALLOCATABLE :: array
    2144             :       INTEGER :: idim
    2145             : 
    2146         772 :       CALL dbt_get_mapping_info(tensor%pgrid%nd_index_grid, map1_2d=map1, map2_2d=map2)
    2147         772 :       CALL blk_dims_tensor(tensor, dims)
    2148             : 
                     :       ! inside a batched-contraction scope: rescale the effective tensor dimensions
                     :       ! before deriving the new grid
    2149         772 :       IF (ALLOCATED(tensor%contraction_storage)) THEN
    2150             :          ASSOCIATE (batch_ranges => tensor%contraction_storage%batch_ranges)
    2151        3088 :             nbatches = sizes_of_arrays(tensor%contraction_storage%batch_ranges) - 1
    2152             :             ! for good load balancing the process grid dimensions should be chosen adapted to the
    2153             :             ! tensor dimensions. For batched contraction the tensor dimensions should be divided by
    2154             :             ! the number of batches (number of index ranges).
    2155        3860 :             DO idim = 1, ndims_tensor(tensor)
    2156        2316 :                CALL get_ith_array(tensor%contraction_storage%batch_ranges, idim, array)
    2157        2316 :                dims(idim) = array(nbatches(idim) + 1) - array(1)
    2158        2316 :                DEALLOCATE (array)
    2159        2316 :                dims(idim) = dims(idim)/nbatches(idim)
    2160        5404 :                IF (dims(idim) <= 0) dims(idim) = 1
    2161             :             END DO
    2162             :          END ASSOCIATE
    2163             :       END IF
    2164             : 
                     :       ! derive an n-d process grid from the 2d communicator and delegate the actual
                     :       ! redistribution to dbt_change_pgrid (with batch ranges if available)
    2165         772 :       pgrid = dbt_nd_mp_comm(mp_comm, map1, map2, pdims_2d=pdims, tdims=dims, nsplit=nsplit, dimsplit=dimsplit)
    2166         772 :       IF (ALLOCATED(tensor%contraction_storage)) THEN
    2167             :          #:for ndim in range(1, maxdim+1)
    2168        1544 :             IF (ndims_tensor(tensor) == ${ndim}$) THEN
    2169         772 :                CALL get_arrays(tensor%contraction_storage%batch_ranges, ${varlist("batch_range", nmax=ndim)}$)
    2170             :                CALL dbt_change_pgrid(tensor, pgrid, ${varlist("batch_range", nmax=ndim)}$, &
    2171         772 :                                      nodata=nodata, pgrid_changed=pgrid_changed, unit_nr=unit_nr)
    2172             :             END IF
    2173             :          #:endfor
    2174             :       ELSE
    2175           0 :          CALL dbt_change_pgrid(tensor, pgrid, nodata=nodata, pgrid_changed=pgrid_changed, unit_nr=unit_nr)
    2176             :       END IF
    2177         772 :       CALL dbt_pgrid_destroy(pgrid)
    2178             : 
    2179         772 :    END SUBROUTINE
    2180             : 
    2181      184948 : END MODULE

Generated by: LCOV version 1.15