LCOV - code coverage report
Current view: top level - src/dbt - dbt_types.F (source / functions) Hit Total Coverage
Test: CP2K Regtests (git:1f285aa) Lines: 525 583 90.1 %
Date: 2024-04-23 06:49:27 Functions: 49 63 77.8 %

          Line data    Source code
       1             : !--------------------------------------------------------------------------------------------------!
       2             : !   CP2K: A general program to perform molecular dynamics simulations                              !
       3             : !   Copyright 2000-2024 CP2K developers group <https://cp2k.org>                                   !
       4             : !                                                                                                  !
       5             : !   SPDX-License-Identifier: GPL-2.0-or-later                                                      !
       6             : !--------------------------------------------------------------------------------------------------!
       7             : 
       8             : ! **************************************************************************************************
       9             : !> \brief DBT tensor framework for block-sparse tensor contraction: Types and create/destroy routines.
      10             : !> \author Patrick Seewald
      11             : ! **************************************************************************************************
      12             : MODULE dbt_types
      13             :    #:include "dbt_macros.fypp"
      14             :    #:set maxdim = maxrank
      15             :    #:set ndims = range(2,maxdim+1)
      16             : 
      17             :    USE dbcsr_api, ONLY: dbcsr_type, dbcsr_get_info, dbcsr_distribution_type, dbcsr_distribution_get
      18             :    USE dbt_array_list_methods, ONLY: &
      19             :       array_list, array_offsets, create_array_list, destroy_array_list, get_array_elements, &
      20             :       sizes_of_arrays, sum_of_arrays, array_sublist, get_arrays, get_ith_array, array_eq_i
      21             :    USE dbm_api, ONLY: &
      22             :       dbm_distribution_obj, dbm_type
      23             :    USE kinds, ONLY: dp, dp, default_string_length
      24             :    USE dbt_tas_base, ONLY: &
      25             :       dbt_tas_create, dbt_tas_distribution_new, &
      26             :       dbt_tas_distribution_destroy, dbt_tas_finalize, dbt_tas_get_info, &
      27             :       dbt_tas_destroy, dbt_tas_get_stored_coordinates, dbt_tas_filter, &
      28             :       dbt_tas_get_num_blocks, dbt_tas_get_num_blocks_total, dbt_tas_get_nze, &
      29             :       dbt_tas_get_nze_total, dbt_tas_clear
      30             :    USE dbt_tas_types, ONLY: &
      31             :       dbt_tas_type, dbt_tas_distribution_type, dbt_tas_split_info, dbt_tas_mm_storage
      32             :    USE dbt_tas_mm, ONLY: dbt_tas_set_batched_state
      33             :    USE dbt_index, ONLY: &
      34             :       get_2d_indices_tensor, get_nd_indices_pgrid, create_nd_to_2d_mapping, destroy_nd_to_2d_mapping, &
      35             :       dbt_get_mapping_info, nd_to_2d_mapping, split_tensor_index, combine_tensor_index, combine_pgrid_index, &
      36             :       split_pgrid_index, ndims_mapping, ndims_mapping_row, ndims_mapping_column
      37             :    USE dbt_tas_split, ONLY: &
      38             :       dbt_tas_create_split_rows_or_cols, dbt_tas_release_info, dbt_tas_info_hold, &
      39             :       dbt_tas_create_split, dbt_tas_get_split_info, dbt_tas_set_strict_split
      40             :    USE kinds, ONLY: default_string_length, int_8, dp
      41             :    USE message_passing, ONLY: &
      42             :       mp_cart_type, mp_dims_create, mp_comm_type
      43             :    USE dbt_tas_global, ONLY: dbt_tas_distribution, dbt_tas_rowcol_data, dbt_tas_default_distvec
      44             :    USE dbt_allocate_wrap, ONLY: allocate_any
      45             :    USE dbm_api, ONLY: dbm_scale
      46             :    USE util, ONLY: sort
      47             : #include "../base/base_uses.f90"
      48             : 
      49             :    IMPLICIT NONE
      50             :    PRIVATE
      51             :    CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'dbt_types'
      52             : 
      53             :    PUBLIC  :: &
      54             :       blk_dims_tensor, &
      55             :       dbt_blk_offsets, &
      56             :       dbt_blk_sizes, &
      57             :       dbt_clear, &
      58             :       dbt_create, &
      59             :       dbt_destroy, &
      60             :       dbt_distribution, &
      61             :       dbt_distribution_destroy, &
      62             :       dbt_distribution_new, &
      63             :       dbt_distribution_new_expert, &
      64             :       dbt_distribution_type, &
      65             :       dbt_filter, &
      66             :       dbt_finalize, &
      67             :       dbt_get_info, &
      68             :       dbt_get_num_blocks, &
      69             :       dbt_get_num_blocks_total, &
      70             :       dbt_get_nze, &
      71             :       dbt_get_nze_total, &
      72             :       dbt_get_stored_coordinates, &
      73             :       dbt_hold, &
      74             :       dbt_mp_dims_create, &
      75             :       dbt_nd_mp_comm, &
      76             :       dbt_nd_mp_free, &
      77             :       dbt_pgrid_change_dims, &
      78             :       dbt_pgrid_create, &
      79             :       dbt_pgrid_create_expert, &
      80             :       dbt_pgrid_destroy, &
      81             :       dbt_pgrid_type, &
      82             :       dbt_pgrid_set_strict_split, &
      83             :       dbt_scale, &
      84             :       dbt_type, &
      85             :       dims_tensor, &
      86             :       mp_environ_pgrid, &
      87             :       ndims_tensor, &
      88             :       ndims_matrix_row, &
      89             :       ndims_matrix_column, &
      90             :       dbt_nblks_local, &
      91             :       dbt_nblks_total, &
      92             :       dbt_blk_size, &
      93             :       dbt_max_nblks_local, &
      94             :       dbt_default_distvec, &
      95             :       dbt_contraction_storage, &
      96             :       dbt_copy_contraction_storage
      97             : 
      98             :    TYPE dbt_pgrid_type ! process grid object of the tensor framework
      99             :       TYPE(nd_to_2d_mapping)                  :: nd_index_grid ! mapping object for the grid; presumably n-dim. <-> 2d grid index - TODO confirm
     100             :       TYPE(mp_cart_type)                      :: mp_comm_2d ! Cartesian communicator; presumably the 2d reference grid - TODO confirm
     101             :       TYPE(dbt_tas_split_info), ALLOCATABLE   :: tas_split_info ! allocatable, so it may be absent
     102             :       INTEGER                                 :: nproc = -1 ! -1 by default; presumably means not yet set - TODO confirm
     103             :    END TYPE
     104             : 
     105             :    TYPE dbt_contraction_storage ! stored as optional component of dbt_type; NOTE(review): field meanings are not visible here - confirm with users
     106             :       REAL(dp)         :: nsplit_avg = 0.0_dp ! starts at 0.0
     107             :       INTEGER          :: ibatch = -1 ! starts at -1; presumably index of the current batch - TODO confirm
     108             :       TYPE(array_list) :: batch_ranges ! no default initialisation: must be set before use
     109             :       LOGICAL          :: static = .FALSE. ! off by default
     110             :    END TYPE
     111             : 
     112             :    TYPE dbt_type ! block-sparse tensor object
     113             :       TYPE(dbt_tas_type), POINTER                :: matrix_rep => NULL() ! matrix representation of the tensor; unassociated until set
     114             :       TYPE(nd_to_2d_mapping)                     :: nd_index_blk ! presumably tensor <-> matrix mapping of block indices - TODO confirm
     115             :       TYPE(nd_to_2d_mapping)                     :: nd_index ! presumably the same mapping for element indices - TODO confirm
     116             :       TYPE(array_list)                           :: blk_sizes ! one array per tensor dimension
     117             :       TYPE(array_list)                           :: blk_offsets ! one array per tensor dimension
     118             :       TYPE(array_list)                           :: nd_dist ! distribution vectors, one array per tensor dimension
     119             :       TYPE(dbt_pgrid_type)                       :: pgrid ! process grid the tensor is distributed on
     120             :       TYPE(array_list)                           :: blks_local
     121             :       INTEGER, DIMENSION(:), ALLOCATABLE         :: nblks_local
     122             :       INTEGER, DIMENSION(:), ALLOCATABLE         :: nfull_local
     123             :       LOGICAL                                    :: valid = .FALSE. ! .FALSE. until set by code outside this excerpt
     124             :       LOGICAL                                    :: owns_matrix = .FALSE. ! presumably whether matrix_rep is freed together with the tensor - TODO confirm
     125             :       CHARACTER(LEN=default_string_length)       :: name = ""
     126             :       ! lightweight reference counting for communicators:
     127             :       INTEGER, POINTER                           :: refcount => NULL() ! unassociated until set
     128             :       TYPE(dbt_contraction_storage), ALLOCATABLE :: contraction_storage ! allocatable, so it may be absent
     129             :    END TYPE dbt_type
     130             : 
     131             :    TYPE dbt_distribution_type ! distribution of a tensor: matrix-level distribution plus the n-dim. data it was built from
     132             :       TYPE(dbt_tas_distribution_type) :: dist ! distribution object of the matrix layer (dbt_tas_types)
     133             :       TYPE(dbt_pgrid_type)            :: pgrid ! process grid
     134             :       TYPE(array_list)                :: nd_dist ! distribution vectors, one array per tensor dimension
     135             :       ! lightweight reference counting for communicators:
     136             :       INTEGER, POINTER                :: refcount => NULL() ! unassociated until set
     137             :    END TYPE
     138             : 
     139             : ! **************************************************************************************************
     140             : !> \brief tas matrix distribution function object for one matrix dimension (rows or columns)
     141             : !> \var dims            tensor dimensions belonging to this matrix dimension
     142             : !> \var dims_grid       process grid dimensions belonging to this matrix dimension
     143             : !> \var nd_dist         distribution vectors of the tensor dimensions belonging to this matrix dimension
     144             : !> \var dist            (bound to tas_dist_t) maps a matrix index to its process grid index
     145             : !> \var rowcols         (bound to tas_rowcols_t) maps a process grid index to the matrix indices it holds
     146             : ! **************************************************************************************************
     147             :    TYPE, EXTENDS(dbt_tas_distribution) :: dbt_tas_dist_t
     148             :       INTEGER, DIMENSION(:), ALLOCATABLE :: dims
     149             :       INTEGER, DIMENSION(:), ALLOCATABLE :: dims_grid
     150             :       TYPE(array_list)                   :: nd_dist
     151             :    CONTAINS
     152             :       PROCEDURE                          :: dist => tas_dist_t
     153             :       PROCEDURE                          :: rowcols => tas_rowcols_t
     154             :    END TYPE
     155             : 
     156             : ! **************************************************************************************************
     157             : !> \brief  block size object for one matrix dimension (rows or columns)
     158             : !> \var dims     tensor dimensions belonging to this matrix dimension
     159             : !> \var blk_size block sizes of the tensor dimensions belonging to this matrix dimension
     160             : ! **************************************************************************************************
     161             :    TYPE, EXTENDS(dbt_tas_rowcol_data) :: dbt_tas_blk_size_t
     162             :       INTEGER, DIMENSION(:), ALLOCATABLE :: dims
     163             :       TYPE(array_list)                   :: blk_size
     164             :    CONTAINS
     165             :       PROCEDURE                          :: data => tas_blk_size_t ! block size of one matrix row/column
     166             :    END TYPE
     167             : 
     168             :    INTERFACE dbt_create ! generic name: resolves to one of the three module procedures below by argument list
     169             :       MODULE PROCEDURE dbt_create_new
     170             :       MODULE PROCEDURE dbt_create_template
     171             :       MODULE PROCEDURE dbt_create_matrix
     172             :    END INTERFACE
     173             : 
     174             :    INTERFACE dbt_tas_dist_t ! same name as the derived type: new_dbt_tas_dist_t acts as its constructor
     175             :       MODULE PROCEDURE new_dbt_tas_dist_t
     176             :    END INTERFACE
     177             : 
     178             :    INTERFACE dbt_tas_blk_size_t ! same name as the derived type: new_dbt_tas_blk_size_t acts as its constructor
     179             :       MODULE PROCEDURE new_dbt_tas_blk_size_t
     180             :    END INTERFACE
     181             : 
     182             : CONTAINS
     183             : 
     184             : ! **************************************************************************************************
     185             : !> \brief Create the distribution object for one dimension (rows or columns) of the matrix representation
     186             : !> \param nd_dist arrays for distribution vectors along all tensor dimensions
     187             : !> \param map_blks tensor to matrix mapping object for blocks
     188             : !> \param map_grid tensor to matrix mapping object for process grid
     189             : !> \param which_dim matrix dimension to build for: 1 = rows, 2 = columns; any other value aborts
     190             : !> \return distribution object with dims, dims_grid, nd_dist, nprowcol and nmrowcol set
     191             : !> \author Patrick Seewald
     192             : ! **************************************************************************************************
     193      742224 :    FUNCTION new_dbt_tas_dist_t(nd_dist, map_blks, map_grid, which_dim)
     194             :       TYPE(array_list), INTENT(IN)       :: nd_dist
     195             :       TYPE(nd_to_2d_mapping), INTENT(IN) :: map_blks, map_grid
     196             :       INTEGER, INTENT(IN)                :: which_dim
     197             : 
     198             :       TYPE(dbt_tas_dist_t)               :: new_dbt_tas_dist_t
     199             :       INTEGER, DIMENSION(2)              :: grid_dims
     200             :       INTEGER(KIND=int_8), DIMENSION(2)  :: matrix_dims
     201      742224 :       INTEGER, DIMENSION(:), ALLOCATABLE :: index_map ! tensor dimensions mapped to the requested matrix dimension
     202             : 
     203      742224 :       IF (which_dim == 1) THEN ! matrix rows: use the row part of both mappings
     204     1113336 :          ALLOCATE (new_dbt_tas_dist_t%dims(ndims_mapping_row(map_blks)))
     205     1113336 :          ALLOCATE (index_map(ndims_mapping_row(map_blks)))
     206             :          CALL dbt_get_mapping_info(map_blks, &
     207             :                                    dims_2d_i8=matrix_dims, &
     208             :                                    map1_2d=index_map, &
     209      371112 :                                    dims1_2d=new_dbt_tas_dist_t%dims)
     210     1113336 :          ALLOCATE (new_dbt_tas_dist_t%dims_grid(ndims_mapping_row(map_grid)))
     211             :          CALL dbt_get_mapping_info(map_grid, &
     212             :                                    dims_2d=grid_dims, &
     213      371112 :                                    dims1_2d=new_dbt_tas_dist_t%dims_grid)
     214      371112 :       ELSEIF (which_dim == 2) THEN ! matrix columns: use the column part of both mappings
     215     1113336 :          ALLOCATE (new_dbt_tas_dist_t%dims(ndims_mapping_column(map_blks)))
     216     1113336 :          ALLOCATE (index_map(ndims_mapping_column(map_blks)))
     217             :          CALL dbt_get_mapping_info(map_blks, &
     218             :                                    dims_2d_i8=matrix_dims, &
     219             :                                    map2_2d=index_map, &
     220      371112 :                                    dims2_2d=new_dbt_tas_dist_t%dims)
     221     1113336 :          ALLOCATE (new_dbt_tas_dist_t%dims_grid(ndims_mapping_column(map_grid)))
     222             :          CALL dbt_get_mapping_info(map_grid, &
     223             :                                    dims_2d=grid_dims, &
     224      371112 :                                    dims2_2d=new_dbt_tas_dist_t%dims_grid)
     225             :       ELSE
     226           0 :          CPABORT("Unknown value for which_dim")
     227             :       END IF
     228             : 
     229      742224 :       new_dbt_tas_dist_t%nd_dist = array_sublist(nd_dist, index_map) ! keep only the vectors of the mapped tensor dimensions
     230      742224 :       new_dbt_tas_dist_t%nprowcol = grid_dims(which_dim) ! 2d grid extent along this matrix dimension
     231      742224 :       new_dbt_tas_dist_t%nmrowcol = matrix_dims(which_dim) ! 2d matrix extent (from the block mapping) along this matrix dimension
     232     1484448 :    END FUNCTION
     233             : 
     234             : ! **************************************************************************************************
     235             : !> \author Patrick Seewald
     236             : ! **************************************************************************************************
     237    28271256 :    FUNCTION tas_dist_t(t, rowcol) ! process grid index (along this matrix dimension) of matrix row/column rowcol
     238             :       CLASS(dbt_tas_dist_t), INTENT(IN) :: t
     239             :       INTEGER(KIND=int_8), INTENT(IN) :: rowcol
     240             :       INTEGER, DIMENSION(${maxrank}$) :: ind_blk ! fixed-size scratch: only the first SIZE(t%dims) entries are used
     241             :       INTEGER, DIMENSION(${maxrank}$) :: dist_blk ! fixed-size scratch: only the first SIZE(t%dims) entries are used
     242             :       INTEGER :: tas_dist_t
     243             : 
     244    28271256 :       ind_blk(:SIZE(t%dims)) = split_tensor_index(rowcol, t%dims) ! matrix index -> one index per tensor dimension
     245    28271256 :       dist_blk(:SIZE(t%dims)) = get_array_elements(t%nd_dist, ind_blk(:SIZE(t%dims))) ! look up the distribution entry in each dimension
     246    28271256 :       tas_dist_t = combine_pgrid_index(dist_blk(:SIZE(t%dims)), t%dims_grid) ! per-dimension grid coordinates -> single grid index
     247    28271256 :    END FUNCTION
     248             : 
     249             : ! **************************************************************************************************
     250             : !> \author Patrick Seewald
     251             : ! **************************************************************************************************
     252      766484 :    FUNCTION tas_rowcols_t(t, dist) ! all matrix rows/columns whose per-dimension distribution entries match grid index dist
     253             :       CLASS(dbt_tas_dist_t), INTENT(IN) :: t
     254             :       INTEGER, INTENT(IN) :: dist
     255             :       INTEGER(KIND=int_8), DIMENSION(:), ALLOCATABLE :: tas_rowcols_t
     256             :       INTEGER, DIMENSION(${maxrank}$) :: dist_blk ! fixed-size scratch: only the first SIZE(t%dims) entries are used
     257      766484 :       INTEGER, DIMENSION(:), ALLOCATABLE :: ${varlist("dist")}$, ${varlist("blks")}$, blks_tmp, nd_ind
     258             :       INTEGER :: ${varlist("i")}$, i, iblk, iblk_count, nblks
     259             :       INTEGER(KIND=int_8) :: nrowcols
     260      766484 :       TYPE(array_list) :: blks
     261             : 
     262      766484 :       dist_blk(:SIZE(t%dims)) = split_pgrid_index(dist, t%dims_grid) ! single grid index -> per-dimension grid coordinates
     263             : 
     264             :       #:for ndim in range(1, maxdim+1)
     265      996928 :          IF (SIZE(t%dims) == ${ndim}$) THEN
     266      230444 :             CALL get_arrays(t%nd_dist, ${varlist("dist", nmax=ndim)}$) ! unpack the distribution vectors into the dist_* locals
     267             :          END IF
     268             :       #:endfor
     269             : 
     270             :       #:for idim in range(1, maxdim+1)
     271     1763608 :          IF (SIZE(t%dims) .GE. ${idim}$) THEN
     272      997124 :             nblks = SIZE(dist_${idim}$)
     273     2991372 :             ALLOCATE (blks_tmp(nblks)) ! upper bound: every index along this dimension could match
     274    12828060 :             iblk_count = 0
     275    12828060 :             DO iblk = 1, nblks
     276    12828060 :                IF (dist_${idim}$ (iblk) == dist_blk(${idim}$)) THEN ! index iblk is assigned to the requested grid coordinate
     277    10507490 :                   iblk_count = iblk_count + 1
     278    10507490 :                   blks_tmp(iblk_count) = iblk
     279             :                END IF
     280             :             END DO
     281     2987206 :             ALLOCATE (blks_${idim}$ (iblk_count))
     282    11504614 :             blks_${idim}$ (:) = blks_tmp(:iblk_count) ! keep only the matches that were found
     283      997124 :             DEALLOCATE (blks_tmp)
     284             :          END IF
     285             :       #:endfor
     286             : 
     287             :       #:for ndim in range(1, maxdim+1)
     288     1532968 :          IF (SIZE(t%dims) == ${ndim}$) THEN
     289      766484 :             CALL create_array_list(blks, ${ndim}$, ${varlist("blks", nmax=ndim)}$) ! bundle the per-dimension match lists
     290             :          END IF
     291             :       #:endfor
     292             : 
     293     1763608 :       nrowcols = PRODUCT(INT(sizes_of_arrays(blks), int_8)) ! number of combinations of the per-dimension matches
     294     2295286 :       ALLOCATE (tas_rowcols_t(nrowcols))
     295             : 
     296             :       #:for ndim in range(1, maxdim+1)
     297     1532968 :          IF (SIZE(t%dims) == ${ndim}$) THEN
     298      766484 :             ALLOCATE (nd_ind(${ndim}$))
     299      766484 :             i = 0 ! running position in the result array
     300             :             #:for idim in range(1,ndim+1)
     301    25505524 :                DO i_${idim}$ = 1, SIZE(blks_${idim}$)
     302             :                   #:endfor
     303    21554912 :                   i = i + 1
     304             : 
     305    58661050 :                   nd_ind(:) = get_array_elements(blks, [${varlist("i", nmax=ndim)}$]) ! loop counters -> matching index in each dimension
     306    25273884 :                   tas_rowcols_t(i) = combine_tensor_index(nd_ind, t%dims) ! per-dimension indices -> matrix row/column index
     307             :                   #:for idim in range(1,ndim+1)
     308             :                      END DO
     309             :                   #:endfor
     310             :                END IF
     311             :             #:endfor
     312             : 
     313             :          END FUNCTION
     314             : 
     315             : ! **************************************************************************************************
     316             : !> \brief Create the block size object for one dimension (rows or columns) of the matrix representation
     317             : !> \param blk_size arrays for block sizes along all tensor dimensions
     318             : !> \param map_blks tensor to matrix mapping object for blocks
     319             : !> \param which_dim matrix dimension to build for: 1 = rows, 2 = columns; any other value aborts
     320             : !> \return block size object with dims, blk_size, nmrowcol and nfullrowcol set
     321             : !> \author Patrick Seewald
     322             : ! **************************************************************************************************
     323      371512 :          FUNCTION new_dbt_tas_blk_size_t(blk_size, map_blks, which_dim)
     324             :             TYPE(array_list), INTENT(IN)                   :: blk_size
     325             :             TYPE(nd_to_2d_mapping), INTENT(IN)             :: map_blks
     326             :             INTEGER, INTENT(IN) :: which_dim
     327             :             INTEGER(KIND=int_8), DIMENSION(2) :: matrix_dims
     328      371512 :             INTEGER, DIMENSION(:), ALLOCATABLE :: index_map ! tensor dimensions mapped to the requested matrix dimension
     329             :             TYPE(dbt_tas_blk_size_t) :: new_dbt_tas_blk_size_t
     330             : 
     331      371512 :             IF (which_dim == 1) THEN ! matrix rows: use the row part of the mapping
     332      557268 :                ALLOCATE (index_map(ndims_mapping_row(map_blks)))
     333      557268 :                ALLOCATE (new_dbt_tas_blk_size_t%dims(ndims_mapping_row(map_blks)))
     334             :                CALL dbt_get_mapping_info(map_blks, &
     335             :                                          dims_2d_i8=matrix_dims, &
     336             :                                          map1_2d=index_map, &
     337      185756 :                                          dims1_2d=new_dbt_tas_blk_size_t%dims)
     338      185756 :             ELSEIF (which_dim == 2) THEN ! matrix columns: use the column part of the mapping
     339      557268 :                ALLOCATE (index_map(ndims_mapping_column(map_blks)))
     340      557268 :                ALLOCATE (new_dbt_tas_blk_size_t%dims(ndims_mapping_column(map_blks)))
     341             :                CALL dbt_get_mapping_info(map_blks, &
     342             :                                          dims_2d_i8=matrix_dims, &
     343             :                                          map2_2d=index_map, &
     344      185756 :                                          dims2_2d=new_dbt_tas_blk_size_t%dims)
     345             :             ELSE
     346           0 :                CPABORT("Unknown value for which_dim")
     347             :             END IF
     348             : 
     349      371512 :             new_dbt_tas_blk_size_t%blk_size = array_sublist(blk_size, index_map) ! keep only the arrays of the mapped tensor dimensions
     350      371512 :             new_dbt_tas_blk_size_t%nmrowcol = matrix_dims(which_dim) ! 2d matrix extent (from the block mapping) along this matrix dimension
     351             : 
     352             :             new_dbt_tas_blk_size_t%nfullrowcol = PRODUCT(INT(sum_of_arrays(new_dbt_tas_blk_size_t%blk_size), &
     353      828846 :                                                              KIND=int_8)) ! product over dimensions of the summed block sizes, in 64 bit
     354      743024 :          END FUNCTION
     355             : 
     356             : ! **************************************************************************************************
     357             : !> \author Patrick Seewald
     358             : ! **************************************************************************************************
     359    13257186 :          FUNCTION tas_blk_size_t(t, rowcol) ! block size of matrix row/column rowcol
     360             :             CLASS(dbt_tas_blk_size_t), INTENT(IN) :: t
     361             :             INTEGER(KIND=int_8), INTENT(IN) :: rowcol
     362             :             INTEGER :: tas_blk_size_t
     363    26514372 :             INTEGER, DIMENSION(SIZE(t%dims)) :: ind_blk ! automatic array: one entry per tensor dimension of this matrix dimension
     364    26514372 :             INTEGER, DIMENSION(SIZE(t%dims)) :: blk_size ! automatic array: one entry per tensor dimension of this matrix dimension
     365             : 
     366    13257186 :             ind_blk(:) = split_tensor_index(rowcol, t%dims) ! matrix index -> one index per tensor dimension
     367    13257186 :             blk_size(:) = get_array_elements(t%blk_size, ind_blk) ! block size in each of those dimensions
     368    31015025 :             tas_blk_size_t = PRODUCT(blk_size) ! matrix block size is the product over the dimensions
     369             : 
     370    13257186 :          END FUNCTION
     371             : 
     372             : ! **************************************************************************************************
     373             : !> \brief load balancing criterion whether to accept a process grid dimension for a tensor dimension
     374             : !> \param pdims_avail available process grid dimensions (total number of cores); pdim must divide it
     375             : !> \param pdim process grid dimension to test; NOTE(review): must be nonzero since it is used as MOD divisor
     376             : !> \param tdim tensor dimension corresponding to pdim
     377             : !> \param lb_ratio load imbalance acceptance factor: if pdim > tdim*lb_ratio, pdim must also divide tdim
     378             : !> \return .TRUE. if pdim is accepted
     379             : !> \author Patrick Seewald
     380             : ! **************************************************************************************************
     381       13838 :          PURE FUNCTION accept_pdims_loadbalancing(pdims_avail, pdim, tdim, lb_ratio)
     382             :             INTEGER, INTENT(IN) :: pdims_avail
     383             :             INTEGER, INTENT(IN) :: pdim
     384             :             INTEGER, INTENT(IN) :: tdim
     385             :             REAL(dp), INTENT(IN) :: lb_ratio
     386             :             LOGICAL :: accept_pdims_loadbalancing
     387             : 
     388       13838 :             accept_pdims_loadbalancing = .FALSE. ! rejected unless one of the branches below accepts
     389       13838 :             IF (MOD(pdims_avail, pdim) == 0) THEN ! pdim has to divide the available process count
     390       11298 :                IF (REAL(tdim, dp)*lb_ratio < REAL(pdim, dp)) THEN ! pdim is large relative to tdim ...
     391        8902 :                   IF (MOD(tdim, pdim) == 0) accept_pdims_loadbalancing = .TRUE. ! ... so only an even split of tdim is accepted
     392             :                ELSE
     393             :                   accept_pdims_loadbalancing = .TRUE. ! pdim <= tdim*lb_ratio: accepted without requiring divisibility
     394             :                END IF
     395             :             END IF
     396             : 
     397       13838 :          END FUNCTION
     398             : 
     399             : ! **************************************************************************************************
     400             : !> \brief Create process grid dimensions corresponding to one dimension of the matrix representation
     401             : !>        of a tensor, adjusting the default grid until every dimension passes the load balancing
     402             : !>        criterion (accept_pdims_loadbalancing) for its tensor dimension.
     403             : !> \param nodes Total number of nodes available for this matrix dimension
     404             : !> \param dims process grid dimensions corresponding to tensor_dims; if any entry is 0, mp_dims_create is called first
     405             : !> \param tensor_dims tensor dimensions; NOTE(review): entries < 1 look unsupported (pdim search starts there) - confirm
     406             : !> \param lb_ratio load imbalance acceptance factor (default 0.1); on failure retried with 0.5, then mp_dims_create
     407             : !> \author Patrick Seewald
     408             : ! **************************************************************************************************
     409        3532 :          RECURSIVE SUBROUTINE dbt_mp_dims_create(nodes, dims, tensor_dims, lb_ratio)
     410             :             INTEGER, INTENT(IN) :: nodes
     411             :             INTEGER, DIMENSION(:), INTENT(INOUT) :: dims
     412             :             INTEGER, DIMENSION(:), INTENT(IN) :: tensor_dims
     413             :             REAL(dp), INTENT(IN), OPTIONAL :: lb_ratio
     414             : 
     415        3532 :             INTEGER, DIMENSION(:), ALLOCATABLE :: tensor_dims_sorted, sort_indices, dims_store
     416        3532 :             REAL(dp), DIMENSION(:), ALLOCATABLE :: sort_key
     417             :             INTEGER :: pdims_rem, idim, pdim
     418             :             REAL(dp) :: lb_ratio_prv
     419             : 
     420        3532 :             IF (PRESENT(lb_ratio)) THEN
     421         550 :                lb_ratio_prv = lb_ratio
     422             :             ELSE
     423        2982 :                lb_ratio_prv = 0.1_dp
     424             :             END IF
     425             : 
     426       18488 :             ALLOCATE (dims_store, source=dims) ! keep the input so that the fallbacks below can restart from it
     427             : 
     428             :             ! get default process grid dimensions
     429        3532 :             IF (any(dims == 0)) THEN
     430        3532 :                CALL mp_dims_create(nodes, dims)
     431             :             END IF
     432             : 
     433             :             ! sort dimensions such that problematic grid dimensions (those that should be corrected) come first
     434       10596 :             ALLOCATE (sort_key(SIZE(tensor_dims)))
     435       11424 :             sort_key(:) = REAL(tensor_dims, dp)/dims ! tensor extent per process along each dimension
     436             : 
     437       18488 :             ALLOCATE (tensor_dims_sorted, source=tensor_dims)
     438       10596 :             ALLOCATE (sort_indices(SIZE(sort_key)))
     439        3532 :             CALL sort(sort_key, SIZE(sort_key), sort_indices)
     440       19316 :             tensor_dims_sorted(:) = tensor_dims_sorted(sort_indices) ! apply the sort permutation to the tensor dims ...
     441       19316 :             dims(:) = dims(sort_indices) ! ... and to the grid dims, so both stay aligned
     442             : 
     443             :             ! remaining number of nodes
     444        3532 :             pdims_rem = nodes
     445             : 
     446       10530 :             DO idim = 1, SIZE(tensor_dims_sorted)
     447       10530 :                IF (.NOT. accept_pdims_loadbalancing(pdims_rem, dims(idim), tensor_dims_sorted(idim), lb_ratio_prv)) THEN
     448        2238 :                   pdim = tensor_dims_sorted(idim) ! start the search at the tensor dimension itself
     449        5946 :                   DO WHILE (.NOT. accept_pdims_loadbalancing(pdims_rem, pdim, tensor_dims_sorted(idim), lb_ratio_prv))
     450        3708 :                      pdim = pdim - 1 ! pdim = 1 is always accepted (1 divides everything), which ends the search for tensor dims >= 1
     451             :                   END DO
     452        2238 :                   dims(idim) = pdim
     453        2238 :                   pdims_rem = pdims_rem/dims(idim)
     454             : 
     455        2238 :                   IF (idim .NE. SIZE(tensor_dims_sorted)) THEN
     456        3238 :                      dims(idim + 1:) = 0 ! reset the not yet processed dimensions ...
     457        1344 :                      CALL mp_dims_create(pdims_rem, dims(idim + 1:)) ! ... and recreate them for the remaining nodes
     458         894 :                   ELSEIF (lb_ratio_prv < 0.5_dp) THEN
     459             :                      ! resort to a less strict load imbalance factor
     460        1910 :                      dims(:) = dims_store ! restart from the input dims
     461         550 :                      CALL dbt_mp_dims_create(nodes, dims, tensor_dims, 0.5_dp)
     462         550 :                      RETURN
     463             :                   ELSE
     464             :                      ! resort to default process grid dimensions
     465        1222 :                      dims(:) = dims_store ! restart from the input dims
     466         344 :                      CALL mp_dims_create(nodes, dims)
     467         344 :                      RETURN
     468             :                   END IF
     469             : 
     470             :                ELSE
     471        5654 :                   pdims_rem = pdims_rem/dims(idim) ! accepted as is: only update the remaining node count
     472             :                END IF
     473             :             END DO
     474             : 
     475       13946 :             dims(sort_indices) = dims ! undo the sort permutation: back to the order of tensor_dims
     476             : 
     477        3532 :          END SUBROUTINE
     478             : 
     479             : ! **************************************************************************************************
     480             : !> \brief Create an n-dimensional process grid.
     481             : !>        We cannot use an n-dimensional MPI Cartesian grid for tensors since the mapping between
     482             : !>        n-dim. and 2-dim. index allows for an arbitrary reordering of tensor indices. Therefore an
     483             : !>        n-dim. MPI Cartesian grid can not be used: it may not be consistent with the respective
     484             : !>        2d grid. The 2d Cartesian MPI grid is the reference grid (since tensor data is stored as
     485             : !>        DBM matrix) and this routine creates an object that is an n-dim. interface to this grid.
     486             : !>        map1_2d and map2_2d do not need to be specified (correctly), grid may be redefined in
     487             : !>        dbt_distribution_new. Note that pgrid is equivalent to an MPI Cartesian grid only
     488             : !>        if map1_2d and map2_2d do not reorder indices (which is the case if
     489             : !>        [map1_2d, map2_2d] == [1, 2, ..., ndims]). Otherwise the mapping of grid coordinates to
     490             : !>        processes depends on the ordering of the indices and is not equivalent to an MPI Cartesian
     491             : !>        grid.
     492             : !> \param mp_comm MPI communicator; its num_pe is used as the number of processes
     493             : !> \param dims grid dimensions - if any entry is 0, dimensions are chosen automatically.
     494             : !> \param pgrid n-dimensional grid object
     495             : !> \param map1_2d which nd-indices map to first matrix index and in which order
     496             : !> \param map2_2d which nd-indices map to second matrix index and in which order
     497             : !> \param tensor_dims tensor block dimensions. If present, process grid dimensions are created such
     498             : !>                    that good load balancing is ensured even if some of the tensor dimensions are
     499             : !>                    small (i.e. on the same order or smaller than nproc**(1/ndim) where ndim is
     500             : !>                    the tensor rank)
     501             : !> \param nsplit impose a constant split factor (dimsplit must then be present as well)
     502             : !> \param dimsplit which matrix dimension to split
     503             : !> \author Patrick Seewald
     504             : ! **************************************************************************************************
     505     1506472 :          SUBROUTINE dbt_pgrid_create_expert(mp_comm, dims, pgrid, map1_2d, map2_2d, tensor_dims, nsplit, dimsplit)
     506             :             CLASS(mp_comm_type), INTENT(IN) :: mp_comm
     507             :             INTEGER, DIMENSION(:), INTENT(INOUT) :: dims
     508             :             TYPE(dbt_pgrid_type), INTENT(OUT) :: pgrid
     509             :             INTEGER, DIMENSION(:), INTENT(IN) :: map1_2d, map2_2d
     510             :             INTEGER, DIMENSION(:), INTENT(IN), OPTIONAL :: tensor_dims
     511             :             INTEGER, INTENT(IN), OPTIONAL :: nsplit, dimsplit
     512             :             INTEGER, DIMENSION(2) :: pdims_2d
     513             :             INTEGER :: nproc, ndims, handle
     514     2259708 :             TYPE(dbt_tas_split_info) :: info
     515             : 
     516             :             CHARACTER(LEN=*), PARAMETER :: routineN = 'dbt_pgrid_create_expert'
     517             : 
     518      376618 :             CALL timeset(routineN, handle)
     519             : 
     520      376618 :             ndims = SIZE(dims) ! NOTE(review): ndims is not read again in this routine
     521             : 
     522      376618 :             nproc = mp_comm%num_pe
     523     1298356 :             IF (ANY(dims == 0)) THEN ! at least one grid dimension was left open by the caller
     524        2220 :                IF (.NOT. PRESENT(tensor_dims)) THEN
     525         904 :                   CALL mp_dims_create(nproc, dims) ! tensor dimensions unknown
     526             :                ELSE
     527        1316 :                   CALL dbt_mp_dims_create(nproc, dims, tensor_dims) ! takes tensor dimensions into account
     528             :                END IF
     529             :             END IF
     530      376618 :             CALL create_nd_to_2d_mapping(pgrid%nd_index_grid, dims, map1_2d, map2_2d, base=0, col_major=.FALSE.)
     531      376618 :             CALL dbt_get_mapping_info(pgrid%nd_index_grid, dims_2d=pdims_2d) ! 2d dims of the mapping just created
     532      376618 :             CALL pgrid%mp_comm_2d%create(mp_comm, 2, pdims_2d) ! presumably a 2d Cartesian communicator - confirm
     533             :             ! optional: impose the split factor nsplit along matrix dimension dimsplit
     534      376618 :             IF (PRESENT(nsplit)) THEN
     535         772 :                CPASSERT(PRESENT(dimsplit))
     536         772 :                CALL dbt_tas_create_split(info, pgrid%mp_comm_2d, dimsplit, nsplit, opt_nsplit=.FALSE.)
     537         772 :                ALLOCATE (pgrid%tas_split_info, SOURCE=info) ! pgrid stores a copy of info
     538             :             END IF
     539             : 
     540             :             ! store number of MPI ranks because we need it for PURE function dbt_max_nblks_local
     541      376618 :             pgrid%nproc = nproc
     542             : 
     543      376618 :             CALL timestop(handle)
     544      376618 :          END SUBROUTINE
     545             : 
     546             : ! **************************************************************************************************
     547             : !> \brief Create a default nd process topology that is consistent with a given 2d topology.
     548             : !>        Purpose: an nd tensor defined on the returned process grid can be represented as a DBM
     549             : !>        matrix with the given 2d topology.
     550             : !>        This is needed to enable contraction of 2 tensors (must have the same 2d process grid).
     551             : !> \param comm_2d communicator with 2-dimensional topology
     552             : !> \param map1_2d which nd-indices map to first matrix index and in which order
     553             : !> \param map2_2d which nd-indices map to second matrix index and in which order
     554             : !> \param dims_nd nd grid dimensions; if present, asserted to be consistent with the 2d dimensions
     555             : !> \param pdims_2d 2d grid dimensions; must be given unless comm_2d is of class mp_cart_type
     556             : !>                 (the routine aborts otherwise)
     557             : !> \param tdims tensor block dimensions. If present, process grid dimensions are created such that
     558             : !>              good load balancing is ensured even if some of the tensor dimensions are small
     559             : !>              (i.e. on the same order or smaller than nproc**(1/ndim) where ndim is the tensor rank)
     560             : !> \return nd process grid (dbt_pgrid_type) created on comm_2d
     561             : !> \author Patrick Seewald
     562             : ! **************************************************************************************************
     563       25066 :          FUNCTION dbt_nd_mp_comm(comm_2d, map1_2d, map2_2d, dims_nd, dims1_nd, dims2_nd, pdims_2d, tdims, &
     564             :                                  nsplit, dimsplit)
     565             :             CLASS(mp_comm_type), INTENT(IN)                               :: comm_2d
     566             :             INTEGER, DIMENSION(:), INTENT(IN)                 :: map1_2d, map2_2d
     567             :             INTEGER, DIMENSION(SIZE(map1_2d) + SIZE(map2_2d)), &
     568             :                INTENT(IN), OPTIONAL                           :: dims_nd
     569             :             INTEGER, DIMENSION(SIZE(map1_2d)), INTENT(IN), OPTIONAL :: dims1_nd ! grid dims for the indices in map1_2d
     570             :             INTEGER, DIMENSION(SIZE(map2_2d)), INTENT(IN), OPTIONAL :: dims2_nd ! grid dims for the indices in map2_2d
     571             :             INTEGER, DIMENSION(2), INTENT(IN), OPTIONAL           :: pdims_2d
     572             :             INTEGER, DIMENSION(SIZE(map1_2d) + SIZE(map2_2d)), &
     573             :                INTENT(IN), OPTIONAL                           :: tdims
     574             :             INTEGER, INTENT(IN), OPTIONAL :: nsplit, dimsplit ! forwarded to dbt_pgrid_create_expert
     575             :             INTEGER                                           :: ndim1, ndim2
     576             :             INTEGER, DIMENSION(2)                             :: dims_2d
     577             : 
     578       48088 :             INTEGER, DIMENSION(SIZE(map1_2d)) :: dims1_nd_prv
     579       48088 :             INTEGER, DIMENSION(SIZE(map2_2d)) :: dims2_nd_prv
     580       48088 :             INTEGER, DIMENSION(SIZE(map1_2d) + SIZE(map2_2d)) :: dims_nd_prv
     581             :             INTEGER                                           :: handle
     582             :             CHARACTER(LEN=*), PARAMETER :: routineN = 'dbt_nd_mp_comm'
     583             :             TYPE(dbt_pgrid_type)                          :: dbt_nd_mp_comm
     584             : 
     585       24044 :             CALL timeset(routineN, handle)
     586             : 
     587       24044 :             ndim1 = SIZE(map1_2d); ndim2 = SIZE(map2_2d) ! NOTE(review): ndim1 / ndim2 are not read again here
     588             :             ! determine the dimensions of the 2d process grid
     589       24044 :             IF (PRESENT(pdims_2d)) THEN
     590       23574 :                dims_2d(:) = pdims_2d
     591             :             ELSE
     592             : ! This branch allows us to call this routine with a plain mp_comm_type without actually requiring an mp_cart_type
     593             : ! In a few cases in CP2K, this prevents erroneous calls to mpi_cart_get with a non-cartesian communicator
     594             :                SELECT TYPE (comm_2d)
     595             :                CLASS IS (mp_cart_type)
     596        1410 :                   dims_2d = comm_2d%num_pe_cart
     597             :                CLASS DEFAULT
     598             :                   CALL cp_abort(__LOCATION__, "If the argument pdims_2d is not given, the "// &
     599           0 :                                 "communicator comm_2d must be of class mp_cart_type.")
     600             :                END SELECT
     601             :             END IF
     602             :             ! nd grid dimensions: built separately per matrix dimension unless dims_nd is given
     603       24044 :             IF (.NOT. PRESENT(dims_nd)) THEN
     604       97348 :                dims1_nd_prv = 0; dims2_nd_prv = 0 ! 0 marks a dimension that is chosen automatically
     605       24044 :                IF (PRESENT(dims1_nd)) THEN
     606         640 :                   dims1_nd_prv(:) = dims1_nd
     607             :                ELSE
     608             : 
     609       23806 :                   IF (PRESENT(tdims)) THEN
     610        1544 :                      CALL dbt_mp_dims_create(dims_2d(1), dims1_nd_prv, tdims(map1_2d)) ! tensor dims of map1_2d indices
     611             :                   ELSE
     612       23034 :                      CALL mp_dims_create(dims_2d(1), dims1_nd_prv)
     613             :                   END IF
     614             :                END IF
     615             : 
     616       24044 :                IF (PRESENT(dims2_nd)) THEN
     617          28 :                   dims2_nd_prv(:) = dims2_nd
     618             :                ELSE
     619       24032 :                   IF (PRESENT(tdims)) THEN
     620        2316 :                      CALL dbt_mp_dims_create(dims_2d(2), dims2_nd_prv, tdims(map2_2d)) ! tensor dims of map2_2d indices
     621             :                   ELSE
     622       23260 :                      CALL mp_dims_create(dims_2d(2), dims2_nd_prv)
     623             :                   END IF
     624             :                END IF
     625       48404 :                dims_nd_prv(map1_2d) = dims1_nd_prv ! place both parts at their nd index positions
     626       48944 :                dims_nd_prv(map2_2d) = dims2_nd_prv
     627             :             ELSE
     628           0 :                CPASSERT(PRODUCT(dims_nd(map1_2d)) == dims_2d(1))
     629           0 :                CPASSERT(PRODUCT(dims_nd(map2_2d)) == dims_2d(2))
     630           0 :                dims_nd_prv = dims_nd
     631             :             END IF
     632             :             ! tdims, nsplit and dimsplit are optional here and are passed on whether present or not
     633             :             CALL dbt_pgrid_create_expert(comm_2d, dims_nd_prv, dbt_nd_mp_comm, &
     634       47316 :                                          tensor_dims=tdims, map1_2d=map1_2d, map2_2d=map2_2d, nsplit=nsplit, dimsplit=dimsplit)
     635             : 
     636       24044 :             CALL timestop(handle)
     637             : 
     638       97186 :          END FUNCTION
     639             : 
     640             : ! **************************************************************************************************
     641             : !> \brief Release the MPI communicator by calling its free() procedure.
     642             : !> \author Patrick Seewald
     643             : ! **************************************************************************************************
     644           0 :          SUBROUTINE dbt_nd_mp_free(mp_comm)
     645             :             TYPE(mp_comm_type), INTENT(INOUT)                               :: mp_comm ! communicator to free
     646             : 
     647           0 :             CALL mp_comm%free()
     648           0 :          END SUBROUTINE dbt_nd_mp_free
     649             : 
     650             : ! **************************************************************************************************
     651             : !> \brief remap a process grid (needed when mapping between tensor and matrix index is changed)
     652             : !> \param map1_2d new mapping: nd-indices that map to the first matrix index
     653             : !> \param map2_2d new mapping: nd-indices that map to the second matrix index
     654             : !> \author Patrick Seewald
     655             : ! **************************************************************************************************
     656     1396808 :          SUBROUTINE dbt_pgrid_remap(pgrid_in, map1_2d, map2_2d, pgrid_out)
     657             :             TYPE(dbt_pgrid_type), INTENT(IN) :: pgrid_in ! grid providing nd dimensions and 2d communicator
     658             :             INTEGER, DIMENSION(:), INTENT(IN) :: map1_2d, map2_2d
     659             :             TYPE(dbt_pgrid_type), INTENT(OUT) :: pgrid_out ! grid created with the new mapping
     660      349202 :             INTEGER, DIMENSION(:), ALLOCATABLE :: dims
     661      698404 :             INTEGER, DIMENSION(ndims_mapping_row(pgrid_in%nd_index_grid)) :: map1_2d_old
     662      349202 :             INTEGER, DIMENSION(ndims_mapping_column(pgrid_in%nd_index_grid)) :: map2_2d_old
     663             :             ! nd grid dimensions and old mapping are read from pgrid_in
     664     1047606 :             ALLOCATE (dims(SIZE(map1_2d) + SIZE(map2_2d)))
     665      349202 :             CALL dbt_get_mapping_info(pgrid_in%nd_index_grid, dims_nd=dims, map1_2d=map1_2d_old, map2_2d=map2_2d_old)
     666      349202 :             CALL dbt_pgrid_create_expert(pgrid_in%mp_comm_2d, dims, pgrid_out, map1_2d=map1_2d, map2_2d=map2_2d)
     667      349202 :             IF (array_eq_i(map1_2d_old, map1_2d) .AND. array_eq_i(map2_2d_old, map2_2d)) THEN ! mapping unchanged
     668      346166 :                IF (ALLOCATED(pgrid_in%tas_split_info)) THEN
     669      331452 :                   ALLOCATE (pgrid_out%tas_split_info, SOURCE=pgrid_in%tas_split_info) ! carry over split info
     670      331452 :                   CALL dbt_tas_info_hold(pgrid_out%tas_split_info) ! presumably registers the extra reference - confirm
     671             :                END IF
     672             :             END IF
     673      349202 :          END SUBROUTINE
     674             : 
     675             : ! **************************************************************************************************
     676             : !> \brief as mp_environ but for special pgrid type: returns nd grid dimensions and nd coordinates
     677             : !> \author Patrick Seewald
     678             : ! **************************************************************************************************
     679      705818 :          SUBROUTINE mp_environ_pgrid(pgrid, dims, task_coor)
     680             :             TYPE(dbt_pgrid_type), INTENT(IN) :: pgrid
     681             :             INTEGER, DIMENSION(ndims_mapping(pgrid%nd_index_grid)), INTENT(OUT) :: dims ! nd grid dimensions
     682             :             INTEGER, DIMENSION(ndims_mapping(pgrid%nd_index_grid)), INTENT(OUT) :: task_coor ! nd grid coordinates
     683             :             INTEGER, DIMENSION(2)                                          :: task_coor_2d
     684             :             ! mepos_cart: presumably the 2d coordinates of the calling process - TODO confirm
     685     2117454 :             task_coor_2d = pgrid%mp_comm_2d%mepos_cart
     686      705818 :             CALL dbt_get_mapping_info(pgrid%nd_index_grid, dims_nd=dims)
     687      705818 :             task_coor = get_nd_indices_pgrid(pgrid%nd_index_grid, task_coor_2d) ! 2d coordinates -> nd coordinates
     688      705818 :          END SUBROUTINE
     689             : 
     690             : ! **************************************************************************************************
     691             : !> \brief Create a tensor distribution for a given tensor index to matrix index mapping.
     692             : !> \param pgrid process grid
     693             : !> \param map1_2d which nd-indices map to first matrix index and in which order
     694             : !> \param map2_2d which nd-indices map to second matrix index and in which order
     695             : !> \param own_comm if .TRUE., pgrid is used as is and its mapping must equal map1_2d / map2_2d (else abort)
     696             : !> \author Patrick Seewald
     697             : ! **************************************************************************************************
     698     2968896 :          SUBROUTINE dbt_distribution_new_expert(dist, pgrid, map1_2d, map2_2d, ${varlist("nd_dist")}$, own_comm)
     699             :             TYPE(dbt_distribution_type), INTENT(OUT)    :: dist
     700             :             TYPE(dbt_pgrid_type), INTENT(IN)            :: pgrid
     701             :             INTEGER, DIMENSION(:), INTENT(IN)               :: map1_2d
     702             :             INTEGER, DIMENSION(:), INTENT(IN)               :: map2_2d
     703             :             INTEGER, DIMENSION(:), INTENT(IN), OPTIONAL     :: ${varlist("nd_dist")}$
     704             :             LOGICAL, INTENT(IN), OPTIONAL                   :: own_comm
     705             :             INTEGER                                         :: ndims
     706      371112 :             TYPE(mp_cart_type)                              :: comm_2d
     707             :             INTEGER, DIMENSION(2)                           :: pdims_2d_check, &
     708             :                                                                pdims_2d
     709     2226672 :             INTEGER, DIMENSION(SIZE(map1_2d) + SIZE(map2_2d)) :: dims, nblks_nd, task_coor
     710      371112 :             TYPE(array_list)                                :: nd_dist
     711     1855560 :             TYPE(nd_to_2d_mapping)                          :: map_blks, map_grid
     712             :             INTEGER                                         :: handle
     713      742224 :             TYPE(dbt_tas_dist_t)                          :: row_dist_obj, col_dist_obj
     714     1113336 :             TYPE(dbt_pgrid_type)                        :: pgrid_prv
     715             :             LOGICAL                                         :: need_pgrid_remap
     716      742224 :             INTEGER, DIMENSION(ndims_mapping_row(pgrid%nd_index_grid)) :: map1_2d_check
     717      371112 :             INTEGER, DIMENSION(ndims_mapping_column(pgrid%nd_index_grid)) :: map2_2d_check
     718             :             CHARACTER(LEN=*), PARAMETER :: routineN = 'dbt_distribution_new_expert'
     719             : 
     720      371112 :             CALL timeset(routineN, handle)
     721      371112 :             ndims = SIZE(map1_2d) + SIZE(map2_2d)
     722      371112 :             CPASSERT(ndims .GE. 2 .AND. ndims .LE. ${maxdim}$)
     723             :             ! collect the per-dimension distribution vectors in one array list
     724      942092 :             CALL create_array_list(nd_dist, ndims, ${varlist("nd_dist")}$)
     725             : 
     726     1284580 :             nblks_nd(:) = sizes_of_arrays(nd_dist) ! vector lengths, used as dims of the blk index mapping below
     727             :             ! pgrid is remapped to map1_2d / map2_2d unless own_comm is present and .TRUE.
     728      371112 :             need_pgrid_remap = .TRUE.
     729      371112 :             IF (PRESENT(own_comm)) THEN
     730       21910 :                CALL dbt_get_mapping_info(pgrid%nd_index_grid, map1_2d=map1_2d_check, map2_2d=map2_2d_check)
     731       21910 :                IF (own_comm) THEN
     732       21910 :                   IF (.NOT. array_eq_i(map1_2d_check, map1_2d) .OR. .NOT. array_eq_i(map2_2d_check, map2_2d)) THEN
     733           0 :                      CPABORT("map1_2d / map2_2d are not consistent with pgrid")
     734             :                   END IF
     735       21910 :                   pgrid_prv = pgrid ! plain assignment, no new process grid is created
     736             :                   need_pgrid_remap = .FALSE.
     737             :                END IF
     738             :             END IF
     739             : 
     740      349202 :             IF (need_pgrid_remap) CALL dbt_pgrid_remap(pgrid, map1_2d, map2_2d, pgrid_prv)
     741             : 
     742             :             ! nd grid dimensions; consistency with the 2d process topology is checked further below.
     743      371112 :             CALL mp_environ_pgrid(pgrid_prv, dims, task_coor) ! NOTE(review): task_coor is not read afterwards
     744             : 
     745             :             ! process grid index mapping
     746      371112 :             CALL create_nd_to_2d_mapping(map_grid, dims, map1_2d, map2_2d, base=0, col_major=.FALSE.)
     747             : 
     748             :             ! blk index mapping
     749      371112 :             CALL create_nd_to_2d_mapping(map_blks, nblks_nd, map1_2d, map2_2d)
     750             :             ! last argument: presumably selects the matrix dimension (1 = row, 2 = column) - confirm
     751      371112 :             row_dist_obj = dbt_tas_dist_t(nd_dist, map_blks, map_grid, 1)
     752      371112 :             col_dist_obj = dbt_tas_dist_t(nd_dist, map_blks, map_grid, 2)
     753             : 
     754      371112 :             CALL dbt_get_mapping_info(map_grid, dims_2d=pdims_2d)
     755             : 
     756      371112 :             comm_2d = pgrid_prv%mp_comm_2d
     757             :             ! 2d dimensions of the communicator must equal those implied by the nd grid mapping
     758     1113336 :             pdims_2d_check = comm_2d%num_pe_cart
     759     1113336 :             IF (ANY(pdims_2d_check .NE. pdims_2d)) THEN
     760           0 :                CPABORT("inconsistent process grid dimensions")
     761             :             END IF
     762             :             ! reuse the split info of the grid if there is one, else store the one of the new distribution
     763      371112 :             IF (ALLOCATED(pgrid_prv%tas_split_info)) THEN
     764      331452 :                CALL dbt_tas_distribution_new(dist%dist, comm_2d, row_dist_obj, col_dist_obj, split_info=pgrid_prv%tas_split_info)
     765             :             ELSE
     766       39660 :                CALL dbt_tas_distribution_new(dist%dist, comm_2d, row_dist_obj, col_dist_obj)
     767       39660 :                ALLOCATE (pgrid_prv%tas_split_info, SOURCE=dist%dist%info)
     768       39660 :                CALL dbt_tas_info_hold(pgrid_prv%tas_split_info)
     769             :             END IF
     770             : 
     771      371112 :             dist%nd_dist = nd_dist
     772      371112 :             dist%pgrid = pgrid_prv
     773             :             ! new reference counter, starting at 1
     774      371112 :             ALLOCATE (dist%refcount)
     775      371112 :             dist%refcount = 1
     776      371112 :             CALL timestop(handle)
     777             :             ! NOTE(review): the internal array_eq_i below hides the array_eq_i imported at module level
     778             :          CONTAINS
     779       43820 :             PURE FUNCTION array_eq_i(arr1, arr2)
     780             :                INTEGER, INTENT(IN), DIMENSION(:) :: arr1
     781             :                INTEGER, INTENT(IN), DIMENSION(:) :: arr2
     782             :                LOGICAL                           :: array_eq_i
     783             :                ! .TRUE. only if both arrays have the same size and all elements are equal
     784       43820 :                array_eq_i = .FALSE.
     785      131860 :                IF (SIZE(arr1) .EQ. SIZE(arr2)) array_eq_i = ALL(arr1 == arr2)
     786             : 
     787       43820 :             END FUNCTION
     788             : 
     789             :          END SUBROUTINE
     790             : 
     791             : ! **************************************************************************************************
     792             : !> \brief Create a tensor distribution using the index mapping stored in pgrid.
     793             : !> \param pgrid process grid
     794             : !> \param nd_dist_i distribution vectors for all tensor dimensions
     795             : !> \author Patrick Seewald
     796             : ! **************************************************************************************************
     797      123824 :          SUBROUTINE dbt_distribution_new(dist, pgrid, ${varlist("nd_dist")}$)
     798             :             TYPE(dbt_distribution_type), INTENT(OUT)    :: dist
     799             :             TYPE(dbt_pgrid_type), INTENT(IN)            :: pgrid
     800             :             INTEGER, DIMENSION(:), INTENT(IN), OPTIONAL     :: ${varlist("nd_dist")}$
     801       30956 :             INTEGER, DIMENSION(ndims_mapping_row(pgrid%nd_index_grid)) :: map1_2d
     802       15478 :             INTEGER, DIMENSION(ndims_mapping_column(pgrid%nd_index_grid)) :: map2_2d
     803             :             INTEGER :: ndims ! NOTE(review): set by dbt_get_mapping_info below but not read afterwards
     804             :             ! take the mapping from pgrid; own_comm is not passed to the expert routine
     805       15478 :             CALL dbt_get_mapping_info(pgrid%nd_index_grid, map1_2d=map1_2d, map2_2d=map2_2d, ndim_nd=ndims)
     806             : 
     807       38788 :             CALL dbt_distribution_new_expert(dist, pgrid, map1_2d, map2_2d, ${varlist("nd_dist")}$)
     808             : 
     809       15478 :          END SUBROUTINE
     810             : 
     811             : ! **************************************************************************************************
     812             : !> \brief destroy process grid
     813             : !> \param keep_comm  if .TRUE. communicator is not freed and split info is not released (default .FALSE.)
     814             : !> \author Patrick Seewald
     815             : ! **************************************************************************************************
     816     1214721 :          SUBROUTINE dbt_pgrid_destroy(pgrid, keep_comm)
     817             :             TYPE(dbt_pgrid_type), INTENT(INOUT) :: pgrid ! process grid to destroy
     818             :             LOGICAL, INTENT(IN), OPTIONAL           :: keep_comm
     819             :             LOGICAL :: keep_comm_prv
     820     1214721 :             IF (PRESENT(keep_comm)) THEN
     821      838103 :                keep_comm_prv = keep_comm
     822             :             ELSE
     823             :                keep_comm_prv = .FALSE. ! default: the communicator is freed
     824             :             END IF
     825     1214721 :             IF (.NOT. keep_comm_prv) CALL pgrid%mp_comm_2d%free()
     826     1214721 :             CALL destroy_nd_to_2d_mapping(pgrid%nd_index_grid) ! done regardless of keep_comm
     827     1214721 :             IF (ALLOCATED(pgrid%tas_split_info) .AND. .NOT. keep_comm_prv) THEN ! split info is kept with the comm
     828      371884 :                CALL dbt_tas_release_info(pgrid%tas_split_info)
     829      371884 :                DEALLOCATE (pgrid%tas_split_info)
     830             :             END IF
     831     1214721 :          END SUBROUTINE
     832             : 
     833             : ! **************************************************************************************************
     834             : !> \brief Destroy tensor distribution; the communicator is only freed with the last reference
     835             : !> \author Patrick Seewald
     836             : ! **************************************************************************************************
     837      371112 :          SUBROUTINE dbt_distribution_destroy(dist)
     838             :             TYPE(dbt_distribution_type), INTENT(INOUT) :: dist
     839             :             INTEGER                                   :: handle
     840             :             CHARACTER(LEN=*), PARAMETER :: routineN = 'dbt_distribution_destroy'
     841             :             LOGICAL :: abort
     842             : 
     843      371112 :             CALL timeset(routineN, handle)
     844      371112 :             CALL dbt_tas_distribution_destroy(dist%dist)
     845      371112 :             CALL destroy_array_list(dist%nd_dist)
     846             :             ! NOTE(review): the reference counter is validated only after the two destroy calls above
     847      371112 :             abort = .FALSE.
     848      371112 :             IF (.NOT. ASSOCIATED(dist%refcount)) THEN
     849             :                abort = .TRUE.
     850      371112 :             ELSEIF (dist%refcount < 1) THEN
     851             :                abort = .TRUE.
     852             :             END IF
     853             : 
     854             :             IF (abort) THEN
     855           0 :                CPABORT("can not destroy non-existing tensor distribution")
     856             :             END IF
     857             :             ! drop the reference held by this distribution object
     858      371112 :             dist%refcount = dist%refcount - 1
     859             :             ! last reference: destroy pgrid including communicator; otherwise keep the communicator
     860      371112 :             IF (dist%refcount == 0) THEN
     861      185356 :                CALL dbt_pgrid_destroy(dist%pgrid)
     862      185356 :                DEALLOCATE (dist%refcount)
     863             :             ELSE
     864      185756 :                CALL dbt_pgrid_destroy(dist%pgrid, keep_comm=.TRUE.)
     865             :             END IF
     866             : 
     867      371112 :             CALL timestop(handle)
     868      371112 :          END SUBROUTINE
     869             : 
     870             : ! **************************************************************************************************
     871             : !> \brief reference counting for distribution: increments the counter by one
     872             : !>        (only needed for communicator handle that must be freed when no longer needed)
     873             : !> \author Patrick Seewald
     874             : ! **************************************************************************************************
     875      185756 :          SUBROUTINE dbt_distribution_hold(dist)
     876             :             TYPE(dbt_distribution_type), INTENT(IN) :: dist
     877             :             INTEGER, POINTER                            :: ref ! local alias used to update the counter
     878             :             ! NOTE(review): refcount is dereferenced without a preceding ASSOCIATED check
     879      185756 :             IF (dist%refcount < 1) THEN
     880           0 :                CPABORT("can not hold non-existing tensor distribution")
     881             :             END IF
     882      185756 :             ref => dist%refcount ! dist is INTENT(IN): the target is modified through the local pointer
     883      185756 :             ref = ref + 1
     884      185756 :          END SUBROUTINE
     885             : 
     886             : ! **************************************************************************************************
     887             : !> \brief get distribution from tensor
     888             : !> \return distribution assembled from the matrix representation, pgrid and nd_dist of the tensor
     889             : !> \author Patrick Seewald
     890             : ! **************************************************************************************************
     891      148558 :          FUNCTION dbt_distribution(tensor)
     892             :             TYPE(dbt_type), INTENT(IN)  :: tensor
     893             :             TYPE(dbt_distribution_type) :: dbt_distribution
     894             :             ! plain component assignments; no new process grid or reference counter is created here
     895      148558 :             CALL dbt_tas_get_info(tensor%matrix_rep, distribution=dbt_distribution%dist)
     896      148558 :             dbt_distribution%pgrid = tensor%pgrid
     897      148558 :             dbt_distribution%nd_dist = tensor%nd_dist
     898             :             dbt_distribution%refcount => dbt_distribution%refcount ! NOTE(review): self-assignment; tensor%refcount meant?
     899     1039906 :          END FUNCTION
     900             : 
     901             : ! **************************************************************************************************
     902             : !> \author Patrick Seewald
     903             : ! **************************************************************************************************
     904     1857560 :          SUBROUTINE dbt_distribution_remap(dist_in, map1_2d, map2_2d, dist_out)
     905             :             TYPE(dbt_distribution_type), INTENT(IN)    :: dist_in
     906             :             INTEGER, DIMENSION(:), INTENT(IN) :: map1_2d, map2_2d
     907             :             TYPE(dbt_distribution_type), INTENT(OUT)    :: dist_out
     908      185756 :             INTEGER, DIMENSION(:), ALLOCATABLE :: ${varlist("dist")}$
     909             :             INTEGER :: ndims
     910      185756 :             ndims = SIZE(map1_2d) + SIZE(map2_2d)
     911             :             #:for ndim in range(1, maxdim+1)
     912      556964 :                IF (ndims == ${ndim}$) THEN
     913      185756 :                   CALL get_arrays(dist_in%nd_dist, ${varlist("dist", nmax=ndim)}$)
     914      185756 :                   CALL dbt_distribution_new_expert(dist_out, dist_in%pgrid, map1_2d, map2_2d, ${varlist("dist", nmax=ndim)}$)
     915             :                END IF
     916             :             #:endfor
     917      185756 :          END SUBROUTINE
     918             : 
     919             : ! **************************************************************************************************
     920             : !> \brief create a tensor.
     921             : !>        For performance, the arguments map1_2d and map2_2d (controlling matrix representation of
     922             : !>        tensor) should be consistent with the the contraction to be performed (see documentation
     923             : !>        of dbt_contract).
     924             : !> \param map1_2d which nd-indices to map to first 2d index and in which order
     925             : !> \param map2_2d which nd-indices to map to first 2d index and in which order
     926             : !> \param blk_size_i blk sizes in each dimension
     927             : !> \author Patrick Seewald
     928             : ! **************************************************************************************************
     929     1857560 :          SUBROUTINE dbt_create_new(tensor, name, dist, map1_2d, map2_2d, &
     930      185756 :                                    ${varlist("blk_size")}$)
     931             :             TYPE(dbt_type), INTENT(OUT)                   :: tensor
     932             :             CHARACTER(len=*), INTENT(IN)                      :: name
     933             :             TYPE(dbt_distribution_type), INTENT(INOUT)    :: dist
     934             :             INTEGER, DIMENSION(:), INTENT(IN)                 :: map1_2d
     935             :             INTEGER, DIMENSION(:), INTENT(IN)                 :: map2_2d
     936             :             INTEGER, DIMENSION(:), INTENT(IN), OPTIONAL       :: ${varlist("blk_size")}$
     937             :             INTEGER                                           :: ndims
     938             :             INTEGER(KIND=int_8), DIMENSION(2)                             :: dims_2d
     939      928780 :             INTEGER, DIMENSION(SIZE(map1_2d) + SIZE(map2_2d)) :: dims, pdims, task_coor
     940      371512 :             TYPE(dbt_tas_blk_size_t)                        :: col_blk_size_obj, row_blk_size_obj
     941     1671804 :             TYPE(dbt_distribution_type)                   :: dist_new
     942      185756 :             TYPE(array_list)                                  :: blk_size, blks_local
     943      557268 :             TYPE(nd_to_2d_mapping)                            :: map
     944             :             INTEGER                                   :: handle
     945             :             CHARACTER(LEN=*), PARAMETER :: routineN = 'dbt_create_new'
     946      185756 :             INTEGER, DIMENSION(:), ALLOCATABLE              :: ${varlist("blks_local")}$
     947      185756 :             INTEGER, DIMENSION(:), ALLOCATABLE              :: ${varlist("dist")}$
     948             :             INTEGER                                         :: iblk_count, iblk
     949      185756 :             INTEGER, DIMENSION(:), ALLOCATABLE              :: nblks_local, nfull_local
     950             : 
     951      185756 :             CALL timeset(routineN, handle)
     952      185756 :             ndims = SIZE(map1_2d) + SIZE(map2_2d)
     953      471446 :             CALL create_array_list(blk_size, ndims, ${varlist("blk_size")}$)
     954      643090 :             dims = sizes_of_arrays(blk_size)
     955             : 
     956      185756 :             CALL create_nd_to_2d_mapping(map, dims, map1_2d, map2_2d)
     957      185756 :             CALL dbt_get_mapping_info(map, dims_2d_i8=dims_2d)
     958             : 
     959      185756 :             row_blk_size_obj = dbt_tas_blk_size_t(blk_size, map, 1)
     960      185756 :             col_blk_size_obj = dbt_tas_blk_size_t(blk_size, map, 2)
     961             : 
     962      185756 :             CALL dbt_distribution_remap(dist, map1_2d, map2_2d, dist_new)
     963             : 
     964     1300292 :             ALLOCATE (tensor%matrix_rep)
     965             :             CALL dbt_tas_create(matrix=tensor%matrix_rep, &
     966             :                                 name=TRIM(name)//" matrix", &
     967             :                                 dist=dist_new%dist, &
     968             :                                 row_blk_size=row_blk_size_obj, &
     969      185756 :                                 col_blk_size=col_blk_size_obj)
     970             : 
     971      185756 :             tensor%owns_matrix = .TRUE.
     972             : 
     973      185756 :             tensor%nd_index_blk = map
     974      185756 :             tensor%name = name
     975             : 
     976      185756 :             CALL dbt_tas_finalize(tensor%matrix_rep)
     977      185756 :             CALL destroy_nd_to_2d_mapping(map)
     978             : 
     979             :             ! map element-wise tensor index
     980      185756 :             CALL create_nd_to_2d_mapping(map, sum_of_arrays(blk_size), map1_2d, map2_2d)
     981      185756 :             tensor%nd_index = map
     982      185756 :             tensor%blk_sizes = blk_size
     983             : 
     984      185756 :             CALL mp_environ_pgrid(dist_new%pgrid, pdims, task_coor)
     985             : 
     986             :             #:for ndim in range(1, maxdim+1)
     987      371512 :                IF (ndims == ${ndim}$) THEN
     988      185756 :                   CALL get_arrays(dist_new%nd_dist, ${varlist("dist", nmax=ndim)}$)
     989             :                END IF
     990             :             #:endfor
     991             : 
     992      557268 :             ALLOCATE (nblks_local(ndims))
     993      557268 :             ALLOCATE (nfull_local(ndims))
     994      643090 :             nfull_local(:) = 0
     995             :             #:for idim in range(1, maxdim+1)
     996      642786 :                IF (ndims .GE. ${idim}$) THEN
     997     4542846 :                   nblks_local(${idim}$) = COUNT(dist_${idim}$ == task_coor(${idim}$))
     998     1370838 :                   ALLOCATE (blks_local_${idim}$ (nblks_local(${idim}$)))
     999      457334 :                   iblk_count = 0
    1000     4542846 :                   DO iblk = 1, SIZE(dist_${idim}$)
    1001     4542846 :                      IF (dist_${idim}$ (iblk) == task_coor(${idim}$)) THEN
    1002     3685111 :                         iblk_count = iblk_count + 1
    1003     3685111 :                         blks_local_${idim}$ (iblk_count) = iblk
    1004     3685111 :                         nfull_local(${idim}$) = nfull_local(${idim}$) + blk_size_${idim}$ (iblk)
    1005             :                      END IF
    1006             :                   END DO
    1007             :                END IF
    1008             :             #:endfor
    1009             : 
    1010             :             #:for ndim in range(1, maxdim+1)
    1011      371208 :                IF (ndims == ${ndim}$) THEN
    1012      185756 :                   CALL create_array_list(blks_local, ${ndim}$, ${varlist("blks_local", nmax=ndim)}$)
    1013             :                END IF
    1014             :             #:endfor
    1015             : 
    1016      557268 :             ALLOCATE (tensor%nblks_local(ndims))
    1017      557268 :             ALLOCATE (tensor%nfull_local(ndims))
    1018      643090 :             tensor%nblks_local(:) = nblks_local
    1019      643090 :             tensor%nfull_local(:) = nfull_local
    1020             : 
    1021      185756 :             tensor%blks_local = blks_local
    1022             : 
    1023      185756 :             tensor%nd_dist = dist_new%nd_dist
    1024      185756 :             tensor%pgrid = dist_new%pgrid
    1025             : 
    1026      185756 :             CALL dbt_distribution_hold(dist_new)
    1027      185756 :             tensor%refcount => dist_new%refcount
    1028      185756 :             CALL dbt_distribution_destroy(dist_new)
    1029             : 
    1030      185756 :             CALL array_offsets(tensor%blk_sizes, tensor%blk_offsets)
    1031             : 
    1032      185756 :             tensor%valid = .TRUE.
    1033      185756 :             CALL timestop(handle)
    1034      557268 :          END SUBROUTINE
    1035             : 
    1036             : ! **************************************************************************************************
    1037             : !> \brief reference counting for tensors
    1038             : !>        (only needed for communicator handle that must be freed when no longer needed)
    1039             : !> \author Patrick Seewald
    1040             : ! **************************************************************************************************
    1041      652347 :          SUBROUTINE dbt_hold(tensor)
    1042             :             TYPE(dbt_type), INTENT(IN) :: tensor
    1043             :             INTEGER, POINTER :: ref
    1044             : 
    1045      652347 :             IF (tensor%refcount < 1) THEN
    1046           0 :                CPABORT("can not hold non-existing tensor")
    1047             :             END IF
    1048      652347 :             ref => tensor%refcount
    1049      652347 :             ref = ref + 1
    1050             : 
    1051      652347 :          END SUBROUTINE
    1052             : 
    1053             : ! **************************************************************************************************
    1054             : !> \brief how many tensor dimensions are mapped to matrix row
    1055             : !> \author Patrick Seewald
    1056             : ! **************************************************************************************************
    1057     1588895 :          PURE FUNCTION ndims_matrix_row(tensor)
    1058             :             TYPE(dbt_type), INTENT(IN) :: tensor
    1059             :             INTEGER(int_8) :: ndims_matrix_row
    1060             : 
    1061     1588895 :             ndims_matrix_row = ndims_mapping_row(tensor%nd_index_blk)
    1062             : 
    1063     1588895 :          END FUNCTION
    1064             : 
    1065             : ! **************************************************************************************************
    1066             : !> \brief how many tensor dimensions are mapped to matrix column
    1067             : !> \author Patrick Seewald
    1068             : ! **************************************************************************************************
    1069     1588895 :          PURE FUNCTION ndims_matrix_column(tensor)
    1070             :             TYPE(dbt_type), INTENT(IN) :: tensor
    1071             :             INTEGER(int_8) :: ndims_matrix_column
    1072             : 
    1073     1588895 :             ndims_matrix_column = ndims_mapping_column(tensor%nd_index_blk)
    1074     1588895 :          END FUNCTION
    1075             : 
    1076             : ! **************************************************************************************************
    1077             : !> \brief tensor rank
    1078             : !> \author Patrick Seewald
    1079             : ! **************************************************************************************************
    1080    83623359 :          PURE FUNCTION ndims_tensor(tensor)
    1081             :             TYPE(dbt_type), INTENT(IN) :: tensor
    1082             :             INTEGER                        :: ndims_tensor
    1083             : 
    1084    83623359 :             ndims_tensor = tensor%nd_index%ndim_nd
    1085    83623359 :          END FUNCTION
    1086             : 
    1087             : ! **************************************************************************************************
    1088             : !> \brief tensor dimensions
    1089             : !> \author Patrick Seewald
    1090             : ! **************************************************************************************************
    1091        3580 :          SUBROUTINE dims_tensor(tensor, dims)
    1092             :             TYPE(dbt_type), INTENT(IN)              :: tensor
    1093             :             INTEGER, DIMENSION(ndims_tensor(tensor)), &
    1094             :                INTENT(OUT)                              :: dims
    1095             : 
    1096        3580 :             CPASSERT(tensor%valid)
    1097       17176 :             dims = tensor%nd_index%dims_nd
    1098        3580 :          END SUBROUTINE
    1099             : 
    1100             : ! **************************************************************************************************
    1101             : !> \brief create a tensor from template
    1102             : !> \author Patrick Seewald
    1103             : ! **************************************************************************************************
    1104     2005488 :          SUBROUTINE dbt_create_template(tensor_in, tensor, name, dist, map1_2d, map2_2d)
    1105             :             TYPE(dbt_type), INTENT(INOUT)      :: tensor_in
    1106             :             TYPE(dbt_type), INTENT(OUT)        :: tensor
    1107             :             CHARACTER(len=*), INTENT(IN), OPTIONAL :: name
    1108             :             TYPE(dbt_distribution_type), &
    1109             :                INTENT(INOUT), OPTIONAL             :: dist
    1110             :             INTEGER, DIMENSION(:), INTENT(IN), &
    1111             :                OPTIONAL                            :: map1_2d, map2_2d
    1112             :             INTEGER                                :: handle
    1113             :             CHARACTER(LEN=*), PARAMETER :: routineN = 'dbt_create_template'
    1114      250686 :             INTEGER, DIMENSION(:), ALLOCATABLE     :: ${varlist("bsize")}$
    1115      250686 :             INTEGER, DIMENSION(:), ALLOCATABLE     :: map1_2d_prv, map2_2d_prv
    1116             :             CHARACTER(len=default_string_length)   :: name_prv
    1117     1754802 :             TYPE(dbt_distribution_type)        :: dist_prv
    1118             : 
    1119      250686 :             CALL timeset(routineN, handle)
    1120             : 
    1121      250686 :             IF (PRESENT(dist) .OR. PRESENT(map1_2d) .OR. PRESENT(map2_2d)) THEN
    1122             :                ! need to create matrix representation from scratch
    1123         340 :                IF (PRESENT(dist)) THEN
    1124           0 :                   dist_prv = dist
    1125             :                ELSE
    1126         340 :                   dist_prv = dbt_distribution(tensor_in)
    1127             :                END IF
    1128         340 :                IF (PRESENT(map1_2d) .AND. PRESENT(map2_2d)) THEN
    1129        1360 :                   ALLOCATE (map1_2d_prv, source=map1_2d)
    1130        1700 :                   ALLOCATE (map2_2d_prv, source=map2_2d)
    1131             :                ELSE
    1132           0 :                   ALLOCATE (map1_2d_prv(ndims_matrix_row(tensor_in)))
    1133           0 :                   ALLOCATE (map2_2d_prv(ndims_matrix_column(tensor_in)))
    1134           0 :                   CALL dbt_get_mapping_info(tensor_in%nd_index_blk, map1_2d=map1_2d_prv, map2_2d=map2_2d_prv)
    1135             :                END IF
    1136         340 :                IF (PRESENT(name)) THEN
    1137           0 :                   name_prv = name
    1138             :                ELSE
    1139         340 :                   name_prv = tensor_in%name
    1140             :                END IF
    1141             : 
    1142             :                #:for ndim in range(1, maxdim+1)
    1143        1020 :                   IF (ndims_tensor(tensor_in) == ${ndim}$) THEN
    1144         340 :                      CALL get_arrays(tensor_in%blk_sizes, ${varlist("bsize", nmax=ndim)}$)
    1145             :                      CALL dbt_create(tensor, name_prv, dist_prv, map1_2d_prv, map2_2d_prv, &
    1146         340 :                                      ${varlist("bsize", nmax=ndim)}$)
    1147             :                   END IF
    1148             :                #:endfor
    1149             :             ELSE
    1150             :                ! create matrix representation from template
    1151     1251730 :                ALLOCATE (tensor%matrix_rep)
    1152      250346 :                IF (.NOT. PRESENT(name)) THEN
    1153             :                   CALL dbt_tas_create(tensor_in%matrix_rep, tensor%matrix_rep, &
    1154      234250 :                                       name=TRIM(tensor_in%name)//" matrix")
    1155             :                ELSE
    1156       16096 :                   CALL dbt_tas_create(tensor_in%matrix_rep, tensor%matrix_rep, name=TRIM(name)//" matrix")
    1157             :                END IF
    1158      250346 :                tensor%owns_matrix = .TRUE.
    1159      250346 :                CALL dbt_tas_finalize(tensor%matrix_rep)
    1160             : 
    1161      250346 :                tensor%nd_index_blk = tensor_in%nd_index_blk
    1162      250346 :                tensor%nd_index = tensor_in%nd_index
    1163      250346 :                tensor%blk_sizes = tensor_in%blk_sizes
    1164      250346 :                tensor%blk_offsets = tensor_in%blk_offsets
    1165      250346 :                tensor%nd_dist = tensor_in%nd_dist
    1166      250346 :                tensor%blks_local = tensor_in%blks_local
    1167      751038 :                ALLOCATE (tensor%nblks_local(ndims_tensor(tensor_in)))
    1168      924532 :                tensor%nblks_local(:) = tensor_in%nblks_local
    1169      751038 :                ALLOCATE (tensor%nfull_local(ndims_tensor(tensor_in)))
    1170      924532 :                tensor%nfull_local(:) = tensor_in%nfull_local
    1171      250346 :                tensor%pgrid = tensor_in%pgrid
    1172             : 
    1173      250346 :                tensor%refcount => tensor_in%refcount
    1174      250346 :                CALL dbt_hold(tensor)
    1175             : 
    1176      250346 :                tensor%valid = .TRUE.
    1177      250346 :                IF (PRESENT(name)) THEN
    1178       16096 :                   tensor%name = name
    1179             :                ELSE
    1180      234250 :                   tensor%name = tensor_in%name
    1181             :                END IF
    1182             :             END IF
    1183      250686 :             CALL timestop(handle)
    1184      501372 :          END SUBROUTINE
    1185             : 
    1186             : ! **************************************************************************************************
    1187             : !> \brief Create 2-rank tensor from matrix.
    1188             : !> \author Patrick Seewald
    1189             : ! **************************************************************************************************
    1190      192960 :          SUBROUTINE dbt_create_matrix(matrix_in, tensor, order, name)
    1191             :             TYPE(dbcsr_type), INTENT(IN)                :: matrix_in
    1192             :             TYPE(dbt_type), INTENT(OUT)        :: tensor
    1193             :             INTEGER, DIMENSION(2), INTENT(IN), OPTIONAL :: order
    1194             :             CHARACTER(len=*), INTENT(IN), OPTIONAL      :: name
    1195             : 
    1196             :             CHARACTER(len=default_string_length)        :: name_in
    1197             :             INTEGER, DIMENSION(2)                       :: order_in
    1198             :             TYPE(mp_comm_type)                          :: comm_2d
    1199             :             TYPE(dbcsr_distribution_type)               :: matrix_dist
    1200      192960 :             TYPE(dbt_distribution_type)                 :: dist
    1201       42880 :             INTEGER, DIMENSION(:), POINTER              :: row_blk_size, col_blk_size
    1202       42880 :             INTEGER, DIMENSION(:), POINTER              :: col_dist, row_dist
    1203             :             INTEGER                                   :: handle, comm_2d_handle
    1204             :             CHARACTER(LEN=*), PARAMETER :: routineN = 'dbt_create_matrix'
    1205       64320 :             TYPE(dbt_pgrid_type)                  :: comm_nd
    1206             :             INTEGER, DIMENSION(2)                     :: pdims_2d
    1207             : 
    1208       21440 :             CALL timeset(routineN, handle)
    1209             : 
    1210       21440 :             NULLIFY (row_blk_size, col_blk_size, col_dist, row_dist)
    1211       21440 :             IF (PRESENT(name)) THEN
    1212         652 :                name_in = name
    1213             :             ELSE
    1214       20788 :                CALL dbcsr_get_info(matrix_in, name=name_in)
    1215             :             END IF
    1216             : 
    1217       21440 :             IF (PRESENT(order)) THEN
    1218           0 :                order_in = order
    1219             :             ELSE
    1220       21440 :                order_in = [1, 2]
    1221             :             END IF
    1222             : 
    1223       21440 :             CALL dbcsr_get_info(matrix_in, distribution=matrix_dist)
    1224             :             CALL dbcsr_distribution_get(matrix_dist, group=comm_2d_handle, row_dist=row_dist, col_dist=col_dist, &
    1225       21440 :                                         nprows=pdims_2d(1), npcols=pdims_2d(2))
    1226       21440 :             CALL comm_2d%set_handle(comm_2d_handle)
    1227       64320 :             comm_nd = dbt_nd_mp_comm(comm_2d, [order_in(1)], [order_in(2)], pdims_2d=pdims_2d)
    1228             : 
    1229             :             CALL dbt_distribution_new_expert( &
    1230             :                dist, &
    1231             :                comm_nd, &
    1232             :                [order_in(1)], [order_in(2)], &
    1233       64320 :                row_dist, col_dist, own_comm=.TRUE.)
    1234             : 
    1235       21440 :             CALL dbcsr_get_info(matrix_in, row_blk_size=row_blk_size, col_blk_size=col_blk_size)
    1236             : 
    1237             :             CALL dbt_create_new(tensor, name_in, dist, &
    1238             :                                 [order_in(1)], [order_in(2)], &
    1239             :                                 row_blk_size, &
    1240       64320 :                                 col_blk_size)
    1241             : 
    1242       21440 :             CALL dbt_distribution_destroy(dist)
    1243       21440 :             CALL timestop(handle)
    1244       42880 :          END SUBROUTINE
    1245             : 
    1246             : ! **************************************************************************************************
    1247             : !> \brief Destroy a tensor
    1248             : !> \author Patrick Seewald
    1249             : ! **************************************************************************************************
    1250      838103 :          SUBROUTINE dbt_destroy(tensor)
    1251             :             TYPE(dbt_type), INTENT(INOUT) :: tensor
    1252             :             INTEGER                                   :: handle
    1253             :             CHARACTER(LEN=*), PARAMETER :: routineN = 'dbt_destroy'
    1254             :             LOGICAL :: abort
    1255             : 
    1256      838103 :             CALL timeset(routineN, handle)
    1257      838103 :             IF (tensor%owns_matrix) THEN
    1258      436102 :                CALL dbt_tas_destroy(tensor%matrix_rep)
    1259      436102 :                DEALLOCATE (tensor%matrix_rep)
    1260             :             ELSE
    1261      402001 :                NULLIFY (tensor%matrix_rep)
    1262             :             END IF
    1263      838103 :             tensor%owns_matrix = .FALSE.
    1264             : 
    1265      838103 :             CALL destroy_nd_to_2d_mapping(tensor%nd_index_blk)
    1266      838103 :             CALL destroy_nd_to_2d_mapping(tensor%nd_index)
    1267             :             !CALL destroy_nd_to_2d_mapping(tensor%nd_index_grid)
    1268      838103 :             CALL destroy_array_list(tensor%blk_sizes)
    1269      838103 :             CALL destroy_array_list(tensor%blk_offsets)
    1270      838103 :             CALL destroy_array_list(tensor%nd_dist)
    1271      838103 :             CALL destroy_array_list(tensor%blks_local)
    1272             : 
    1273      838103 :             DEALLOCATE (tensor%nblks_local, tensor%nfull_local)
    1274             : 
    1275      838103 :             abort = .FALSE.
    1276      838103 :             IF (.NOT. ASSOCIATED(tensor%refcount)) THEN
    1277             :                abort = .TRUE.
    1278      838103 :             ELSEIF (tensor%refcount < 1) THEN
    1279             :                abort = .TRUE.
    1280             :             END IF
    1281             : 
    1282             :             IF (abort) THEN
    1283           0 :                CPABORT("can not destroy non-existing tensor")
    1284             :             END IF
    1285             : 
    1286      838103 :             tensor%refcount = tensor%refcount - 1
    1287             : 
    1288      838103 :             IF (tensor%refcount == 0) THEN
    1289      185756 :                CALL dbt_pgrid_destroy(tensor%pgrid)
    1290             :                !CALL tensor%comm_2d%free()
    1291             :                !CALL tensor%comm_nd%free()
    1292      185756 :                DEALLOCATE (tensor%refcount)
    1293             :             ELSE
    1294      652347 :                CALL dbt_pgrid_destroy(tensor%pgrid, keep_comm=.TRUE.)
    1295             :             END IF
    1296             : 
    1297      838103 :             tensor%valid = .FALSE.
    1298      838103 :             tensor%name = ""
    1299      838103 :             CALL timestop(handle)
    1300      838103 :          END SUBROUTINE
    1301             : 
    1302             : ! **************************************************************************************************
    1303             : !> \brief tensor block dimensions
    1304             : !> \author Patrick Seewald
    1305             : ! **************************************************************************************************
    1306      653465 :          SUBROUTINE blk_dims_tensor(tensor, dims)
    1307             :             TYPE(dbt_type), INTENT(IN)              :: tensor
    1308             :             INTEGER, DIMENSION(ndims_tensor(tensor)), &
    1309             :                INTENT(OUT)                              :: dims
    1310             : 
    1311      653465 :             CPASSERT(tensor%valid)
    1312     2435413 :             dims = tensor%nd_index_blk%dims_nd
    1313      653465 :          END SUBROUTINE
    1314             : 
    1315             : ! **************************************************************************************************
    1316             : !> \brief Size of tensor block
    1317             : !> \author Patrick Seewald
    1318             : ! **************************************************************************************************
    1319    22270410 :          SUBROUTINE dbt_blk_sizes(tensor, ind, blk_size)
    1320             :             TYPE(dbt_type), INTENT(IN)              :: tensor
    1321             :             INTEGER, DIMENSION(ndims_tensor(tensor)), &
    1322             :                INTENT(IN)                               :: ind
    1323             :             INTEGER, DIMENSION(ndims_tensor(tensor)), &
    1324             :                INTENT(OUT)                              :: blk_size
    1325             : 
    1326    22270410 :             blk_size(:) = get_array_elements(tensor%blk_sizes, ind)
    1327    22270410 :          END SUBROUTINE
    1328             : 
    1329             : ! **************************************************************************************************
    1330             : !> \brief offset of tensor block
    1331             : !> \param ind block index
    1332             : !> \param blk_offset block offset
    1333             : !> \author Patrick Seewald
    1334             : ! **************************************************************************************************
    1335           0 :          SUBROUTINE dbt_blk_offsets(tensor, ind, blk_offset)
    1336             :             TYPE(dbt_type), INTENT(IN)              :: tensor
    1337             :             INTEGER, DIMENSION(ndims_tensor(tensor)), &
    1338             :                INTENT(IN)                               :: ind
    1339             :             INTEGER, DIMENSION(ndims_tensor(tensor)), &
    1340             :                INTENT(OUT)                              :: blk_offset
    1341             : 
    1342           0 :             CPASSERT(tensor%valid)
    1343           0 :             blk_offset(:) = get_array_elements(tensor%blk_offsets, ind)
    1344           0 :          END SUBROUTINE
    1345             : 
    1346             : ! **************************************************************************************************
    1347             : !> \brief Generalization of block_get_stored_coordinates for tensors (nd block index in, processor out).
    1348             : !> \author Patrick Seewald
    1349             : ! **************************************************************************************************
    1350     6039506 :          SUBROUTINE dbt_get_stored_coordinates(tensor, ind_nd, processor)
    1351             :             TYPE(dbt_type), INTENT(IN)               :: tensor
    1352             :             INTEGER, DIMENSION(ndims_tensor(tensor)), &
    1353             :                INTENT(IN)                                :: ind_nd
    1354             :             INTEGER, INTENT(OUT)                         :: processor
    1355             : 
    1356             :             INTEGER(KIND=int_8), DIMENSION(2)                        :: ind_2d
    1357             : ! map the nd block index to the 2d (int_8) index of tensor%matrix_rep, then query that matrix
    1358     6039506 :             ind_2d(:) = get_2d_indices_tensor(tensor%nd_index_blk, ind_nd)
    1359     6039506 :             CALL dbt_tas_get_stored_coordinates(tensor%matrix_rep, ind_2d(1), ind_2d(2), processor)
    1360     6039506 :          END SUBROUTINE
    1361             : 
    1362             : ! **************************************************************************************************
    1363             : !> \author Patrick Seewald
    1364             : ! **************************************************************************************************
    1365       13488 :          SUBROUTINE dbt_pgrid_create(mp_comm, dims, pgrid, tensor_dims)
    1366             :             CLASS(mp_comm_type), INTENT(IN) :: mp_comm
    1367             :             INTEGER, DIMENSION(:), INTENT(INOUT) :: dims
    1368             :             TYPE(dbt_pgrid_type), INTENT(OUT) :: pgrid
    1369             :             INTEGER, DIMENSION(:), INTENT(IN), OPTIONAL :: tensor_dims
    1370        3372 :             INTEGER, DIMENSION(:), ALLOCATABLE :: map1_2d, map2_2d
    1371             :             INTEGER :: i, ndims
    1372             : ! creates pgrid through dbt_pgrid_create_expert using a default nd -> 2d mapping built below
    1373        3372 :             ndims = SIZE(dims)
    1374             : ! default mapping: map1_2d = (1, ..., ndims/2), map2_2d = (ndims/2 + 1, ..., ndims)
    1375       10116 :             ALLOCATE (map1_2d(ndims/2))
    1376       10116 :             ALLOCATE (map2_2d(ndims - ndims/2))
    1377       13496 :             map1_2d(:) = (/(i, i=1, SIZE(map1_2d))/)
    1378       18536 :             map2_2d(:) = (/(i, i=SIZE(map1_2d) + 1, SIZE(map1_2d) + SIZE(map2_2d))/)
    1379             : ! the optional tensor_dims is passed on as is (present or absent)
    1380        5372 :             CALL dbt_pgrid_create_expert(mp_comm, dims, pgrid, map1_2d, map2_2d, tensor_dims)
    1381             : 
    1382        3372 :          END SUBROUTINE
    1383             : 
    1384             : ! **************************************************************************************************
    1385             : !> \brief freeze current split factor so that contraction never changes it (no-op without tas_split_info)
    1386             : !> \author Patrick Seewald
    1387             : ! **************************************************************************************************
    1388           0 :          SUBROUTINE dbt_pgrid_set_strict_split(pgrid)
    1389             :             TYPE(dbt_pgrid_type), INTENT(INOUT) :: pgrid
    1390           0 :             IF (ALLOCATED(pgrid%tas_split_info)) CALL dbt_tas_set_strict_split(pgrid%tas_split_info)
    1391           0 :          END SUBROUTINE
    1392             : 
    1393             : ! **************************************************************************************************
    1394             : !> \brief change dimensions of an existing process grid (it is destroyed and replaced by a new grid).
    1395             : !> \param pgrid process grid to be changed
    1396             : !> \param pdims new process grid dimensions, should all be set > 0 (checked by an assertion)
    1397             : !> \author Patrick Seewald
    1398             : ! **************************************************************************************************
    1399           0 :          SUBROUTINE dbt_pgrid_change_dims(pgrid, pdims)
    1400             :             TYPE(dbt_pgrid_type), INTENT(INOUT) :: pgrid
    1401             :             INTEGER, DIMENSION(:), INTENT(INOUT)    :: pdims
    1402           0 :             TYPE(dbt_pgrid_type)                :: pgrid_tmp
    1403             :             INTEGER                                 :: nsplit, dimsplit
    1404           0 :             INTEGER, DIMENSION(ndims_mapping_row(pgrid%nd_index_grid)) :: map1_2d
    1405           0 :             INTEGER, DIMENSION(ndims_mapping_column(pgrid%nd_index_grid)) :: map2_2d
    1406           0 :             TYPe(nd_to_2d_mapping)                  :: nd_index_grid
    1407             :             INTEGER, DIMENSION(2)                   :: pdims_2d
    1408             : ! the nd -> 2d mapping, split factor and split dimension of the existing grid are reused for pdims
    1409           0 :             CPASSERT(ALL(pdims > 0))
    1410           0 :             CALL dbt_tas_get_split_info(pgrid%tas_split_info, nsplit=nsplit, split_rowcol=dimsplit)
    1411           0 :             CALL dbt_get_mapping_info(pgrid%nd_index_grid, map1_2d=map1_2d, map2_2d=map2_2d)
    1412           0 :             CALL create_nd_to_2d_mapping(nd_index_grid, pdims, map1_2d, map2_2d, base=0, col_major=.FALSE.)
    1413           0 :             CALL dbt_get_mapping_info(nd_index_grid, dims_2d=pdims_2d)
    1414           0 :             IF (MOD(pdims_2d(dimsplit), nsplit) == 0) THEN
    1415             :                CALL dbt_pgrid_create_expert(pgrid%mp_comm_2d, pdims, pgrid_tmp, map1_2d=map1_2d, map2_2d=map2_2d, &
    1416           0 :                                             nsplit=nsplit, dimsplit=dimsplit)
    1417             :             ELSE  ! old nsplit does not divide the new 2d dimension: nsplit and dimsplit are not passed on
    1418           0 :                CALL dbt_pgrid_create_expert(pgrid%mp_comm_2d, pdims, pgrid_tmp, map1_2d=map1_2d, map2_2d=map2_2d)
    1419             :             END IF
    1420           0 :             CALL dbt_pgrid_destroy(pgrid)
    1421           0 :             pgrid = pgrid_tmp
    1422           0 :          END SUBROUTINE
    1423             : 
    1424             : ! **************************************************************************************************
    1425             : !> \brief As block_filter: passes the matrix representation and eps on to dbt_tas_filter
    1426             : !> \author Patrick Seewald
    1427             : ! **************************************************************************************************
    1428      169109 :          SUBROUTINE dbt_filter(tensor, eps)
    1429             :             TYPE(dbt_type), INTENT(INOUT)    :: tensor
    1430             :             REAL(dp), INTENT(IN)                :: eps
    1431             : 
    1432      169109 :             CALL dbt_tas_filter(tensor%matrix_rep, eps)
    1433             : 
    1434      169109 :          END SUBROUTINE
    1435             : 
    1436             : ! **************************************************************************************************
    1437             : !> \brief local number of blocks along dimension idim (0 if idim > ndims_tensor(tensor))
    1438             : !> \author Patrick Seewald
    1439             : ! **************************************************************************************************
    1440      359110 :          PURE FUNCTION dbt_nblks_local(tensor, idim)
    1441             :             TYPE(dbt_type), INTENT(IN) :: tensor
    1442             :             INTEGER, INTENT(IN) :: idim
    1443             :             INTEGER :: dbt_nblks_local
    1444             : ! returning 0 for idim beyond the tensor dimensions lets callers use the result as an array size
    1445      359110 :             IF (idim > ndims_tensor(tensor)) THEN
    1446             :                dbt_nblks_local = 0
    1447             :             ELSE
    1448      359110 :                dbt_nblks_local = tensor%nblks_local(idim)
    1449             :             END IF
    1450             : 
    1451      359110 :          END FUNCTION
    1452             : 
    1453             : ! **************************************************************************************************
    1454             : !> \brief total number of blocks along dimension idim (0 if idim > ndims_tensor(tensor))
    1455             : !> \author Patrick Seewald
    1456             : ! **************************************************************************************************
    1457     2386074 :          PURE FUNCTION dbt_nblks_total(tensor, idim)
    1458             :             TYPE(dbt_type), INTENT(IN) :: tensor
    1459             :             INTEGER, INTENT(IN) :: idim
    1460             :             INTEGER :: dbt_nblks_total
    1461             : ! the count is read from the nd dimensions stored in the block index mapping tensor%nd_index_blk
    1462     2386074 :             IF (idim > ndims_tensor(tensor)) THEN
    1463             :                dbt_nblks_total = 0
    1464             :             ELSE
    1465     1858428 :                dbt_nblks_total = tensor%nd_index_blk%dims_nd(idim)
    1466             :             END IF
    1467     2386074 :          END FUNCTION
    1468             : 
    1469             : ! **************************************************************************************************
    1470             : !> \brief As block_get_info but for tensors; every output is optional and only set if present
    1471             : !> \param nblks_total number of blocks along each dimension
    1472             : !> \param nfull_total number of elements along each dimension
    1473             : !> \param nblks_local local number of blocks along each dimension
    1474             : !> \param nfull_local local number of elements along each dimension
    1475             : !> \param my_ploc process coordinates in process grid
    1476             : !> \param pdims process grid dimensions
    1477             : !> \param blks_local_${idim}$ local blocks along dimension ${idim}$
    1478             : !> \param proc_dist_${idim}$ distribution along dimension ${idim}$
    1479             : !> \param blk_size_${idim}$ block sizes along dimension ${idim}$
    1480             : !> \param blk_offset_${idim}$ block offsets along dimension ${idim}$
    1481             : !> \param distribution distribution object
    1482             : !> \param name name of tensor
    1483             : !> \author Patrick Seewald
    1484             : ! **************************************************************************************************
    1485           0 :          SUBROUTINE dbt_get_info(tensor, nblks_total, &
    1486             :                                  nfull_total, &
    1487      133748 :                                  nblks_local, &
    1488      133748 :                                  nfull_local, &
    1489             :                                  pdims, &
    1490             :                                  my_ploc, &
    1491             :                                  ${varlist("blks_local")}$, &
    1492             :                                  ${varlist("proc_dist")}$, &
    1493             :                                  ${varlist("blk_size")}$, &
    1494             :                                  ${varlist("blk_offset")}$, &
    1495             :                                  distribution, &
    1496             :                                  name)
    1497             :             TYPE(dbt_type), INTENT(IN) :: tensor
    1498             :             INTEGER, INTENT(OUT), OPTIONAL, DIMENSION(ndims_tensor(tensor)) :: nblks_total
    1499             :             INTEGER, INTENT(OUT), OPTIONAL, DIMENSION(ndims_tensor(tensor)) :: nfull_total
    1500             :             INTEGER, INTENT(OUT), OPTIONAL, DIMENSION(ndims_tensor(tensor)) :: nblks_local
    1501             :             INTEGER, INTENT(OUT), OPTIONAL, DIMENSION(ndims_tensor(tensor)) :: nfull_local
    1502             :             INTEGER, INTENT(OUT), OPTIONAL, DIMENSION(ndims_tensor(tensor)) :: my_ploc
    1503             :             INTEGER, INTENT(OUT), OPTIONAL, DIMENSION(ndims_tensor(tensor)) :: pdims
    1504             :             #:for idim in range(1, maxdim+1)
    1505             :                INTEGER, DIMENSION(dbt_nblks_local(tensor, ${idim}$)), INTENT(OUT), OPTIONAL :: blks_local_${idim}$
    1506             :                INTEGER, DIMENSION(dbt_nblks_total(tensor, ${idim}$)), INTENT(OUT), OPTIONAL :: proc_dist_${idim}$
    1507             :                INTEGER, DIMENSION(dbt_nblks_total(tensor, ${idim}$)), INTENT(OUT), OPTIONAL :: blk_size_${idim}$
    1508             :                INTEGER, DIMENSION(dbt_nblks_total(tensor, ${idim}$)), INTENT(OUT), OPTIONAL :: blk_offset_${idim}$
    1509             :             #:endfor
    1510             :             TYPE(dbt_distribution_type), INTENT(OUT), OPTIONAL    :: distribution
    1511             :             CHARACTER(len=*), INTENT(OUT), OPTIONAL                   :: name
    1512     1312800 :             INTEGER, DIMENSION(ndims_tensor(tensor))                  :: pdims_tmp, my_ploc_tmp
    1513             : ! block and element counts: totals come from the index mappings, local counts from the tensor itself
    1514      656400 :             IF (PRESENT(nblks_total)) CALL dbt_get_mapping_info(tensor%nd_index_blk, dims_nd=nblks_total)
    1515      656400 :             IF (PRESENT(nfull_total)) CALL dbt_get_mapping_info(tensor%nd_index, dims_nd=nfull_total)
    1516     1015510 :             IF (PRESENT(nblks_local)) nblks_local(:) = tensor%nblks_local
    1517     1015510 :             IF (PRESENT(nfull_local)) nfull_local(:) = tensor%nfull_local
    1518             : ! grid dimensions and coordinates are both obtained from a single mp_environ_pgrid call on tensor%pgrid
    1519      656400 :             IF (PRESENT(my_ploc) .OR. PRESENT(pdims)) CALL mp_environ_pgrid(tensor%pgrid, pdims_tmp, my_ploc_tmp)
    1520     1015510 :             IF (PRESENT(my_ploc)) my_ploc = my_ploc_tmp
    1521     1016080 :             IF (PRESENT(pdims)) pdims = pdims_tmp
    1522             : ! per-dimension arrays: dimensions beyond ndims_tensor(tensor) are skipped (their dummies have size 0)
    1523             :             #:for idim in range(1, maxdim+1)
    1524     2395153 :                IF (${idim}$ <= ndims_tensor(tensor)) THEN
    1525     1738805 :                   IF (PRESENT(blks_local_${idim}$)) CALL get_ith_array(tensor%blks_local, ${idim}$, &
    1526             :                                                                        dbt_nblks_local(tensor, ${idim}$), &
    1527      359110 :                                                                        blks_local_${idim}$)
    1528     1738805 :                   IF (PRESENT(proc_dist_${idim}$)) CALL get_ith_array(tensor%nd_dist, ${idim}$, &
    1529             :                                                                       dbt_nblks_total(tensor, ${idim}$), &
    1530      361004 :                                                                       proc_dist_${idim}$)
    1531     1738805 :                   IF (PRESENT(blk_size_${idim}$)) CALL get_ith_array(tensor%blk_sizes, ${idim}$, &
    1532             :                                                                      dbt_nblks_total(tensor, ${idim}$), &
    1533      415272 :                                                                      blk_size_${idim}$)
    1534     1738805 :                   IF (PRESENT(blk_offset_${idim}$)) CALL get_ith_array(tensor%blk_offsets, ${idim}$, &
    1535             :                                                                        dbt_nblks_total(tensor, ${idim}$), &
    1536          56 :                                                                        blk_offset_${idim}$)
    1537             :                END IF
    1538             :             #:endfor
    1539             : ! distribution is produced by dbt_distribution(tensor); name is a copy of tensor%name
    1540      656400 :             IF (PRESENT(distribution)) distribution = dbt_distribution(tensor)
    1541      656400 :             IF (PRESENT(name)) name = tensor%name
    1542             : 
    1543     1446548 :          END SUBROUTINE
    1544             : 
    1545             : ! **************************************************************************************************
    1546             : !> \brief As block_get_num_blocks: get number of local blocks (taken from the matrix representation)
    1547             : !> \author Patrick Seewald
    1548             : ! **************************************************************************************************
    1549      412048 :          PURE FUNCTION dbt_get_num_blocks(tensor) RESULT(num_blocks)
    1550             :             TYPE(dbt_type), INTENT(IN)    :: tensor
    1551             :             INTEGER                           :: num_blocks
    1552      412048 :             num_blocks = dbt_tas_get_num_blocks(tensor%matrix_rep)
    1553      412048 :          END FUNCTION
    1554             : 
    1555             : ! **************************************************************************************************
    1556             : !> \brief Get total number of blocks (int_8 result, taken from the matrix representation)
    1557             : !> \author Patrick Seewald
    1558             : ! **************************************************************************************************
    1559      135796 :          FUNCTION dbt_get_num_blocks_total(tensor) RESULT(num_blocks)
    1560             :             TYPE(dbt_type), INTENT(IN)    :: tensor
    1561             :             INTEGER(KIND=int_8)               :: num_blocks
    1562      135796 :             num_blocks = dbt_tas_get_num_blocks_total(tensor%matrix_rep)
    1563      135796 :          END FUNCTION
    1564             : 
    1565             : ! **************************************************************************************************
    1566             : !> \brief Clear tensor (such that it does not contain any blocks) by clearing its matrix representation
    1567             : !> \author Patrick Seewald
    1568             : ! **************************************************************************************************
    1569      733477 :          SUBROUTINE dbt_clear(tensor)
    1570             :             TYPE(dbt_type), INTENT(INOUT) :: tensor
    1571             : 
    1572      733477 :             CALL dbt_tas_clear(tensor%matrix_rep)
    1573      733477 :          END SUBROUTINE
    1574             : 
    1575             : ! **************************************************************************************************
    1576             : !> \brief Finalize tensor, as block_finalize. This should be taken care of internally in DBT
    1577             : !>        tensors; there should not be any need to call this routine outside of DBT tensors.
    1578             : !> \author Patrick Seewald
    1579             : ! **************************************************************************************************
    1580      864652 :          SUBROUTINE dbt_finalize(tensor)
    1581             :             TYPE(dbt_type), INTENT(INOUT) :: tensor
    1582      864652 :             CALL dbt_tas_finalize(tensor%matrix_rep)
    1583      864652 :          END SUBROUTINE
    1584             : 
    1585             : ! **************************************************************************************************
    1586             : !> \brief as block_scale: applies dbm_scale with factor alpha to tensor%matrix_rep%matrix
    1587             : !> \author Patrick Seewald
    1588             : ! **************************************************************************************************
    1589       24769 :          SUBROUTINE dbt_scale(tensor, alpha)
    1590             :             TYPE(dbt_type), INTENT(INOUT) :: tensor
    1591             :             REAL(dp), INTENT(IN) :: alpha
    1592       24769 :             CALL dbm_scale(tensor%matrix_rep%matrix, alpha)
    1593       24769 :          END SUBROUTINE
    1594             : 
    1595             : ! **************************************************************************************************
    1596             : !> \author Patrick Seewald
    1597             : ! **************************************************************************************************
    1598      133748 :          PURE FUNCTION dbt_get_nze(tensor)
    1599             :             TYPE(dbt_type), INTENT(IN) :: tensor
    1600             :             INTEGER                        :: dbt_get_nze  ! result of dbt_tas_get_nze, default integer kind
    1601      133748 :             dbt_get_nze = dbt_tas_get_nze(tensor%matrix_rep)
    1602      133748 :          END FUNCTION
    1603             : 
    1604             : ! **************************************************************************************************
    1605             : !> \author Patrick Seewald
    1606             : ! **************************************************************************************************
    1607      245335 :          FUNCTION dbt_get_nze_total(tensor)
    1608             :             TYPE(dbt_type), INTENT(IN) :: tensor
    1609             :             INTEGER(KIND=int_8)            :: dbt_get_nze_total  ! result of dbt_tas_get_nze_total, kind int_8
    1610      245335 :             dbt_get_nze_total = dbt_tas_get_nze_total(tensor%matrix_rep)
    1611      245335 :          END FUNCTION
    1612             : 
    1613             : ! **************************************************************************************************
    1614             : !> \brief block size of block with index ind along dimension idim (0 if idim > ndims_tensor(tensor))
    1615             : !> \author Patrick Seewald
    1616             : ! **************************************************************************************************
    1617           0 :          PURE FUNCTION dbt_blk_size(tensor, ind, idim)
    1618             :             TYPE(dbt_type), INTENT(IN) :: tensor
    1619             :             INTEGER, DIMENSION(ndims_tensor(tensor)), &
    1620             :                INTENT(IN) :: ind
    1621             :             INTEGER, INTENT(IN) :: idim
    1622           0 :             INTEGER, DIMENSION(ndims_tensor(tensor)) :: blk_size
    1623             :             INTEGER :: dbt_blk_size
    1624             : ! the sizes along all dimensions are fetched from tensor%blk_sizes, then entry idim is returned
    1625           0 :             IF (idim > ndims_tensor(tensor)) THEN
    1626             :                dbt_blk_size = 0
    1627             :             ELSE
    1628           0 :                blk_size(:) = get_array_elements(tensor%blk_sizes, ind)
    1629           0 :                dbt_blk_size = blk_size(idim)
    1630             :             END IF
    1631           0 :          END FUNCTION
    1632             : 
    1633             : ! **************************************************************************************************
    1634             : !> \brief returns an estimate of the maximum number of local blocks in tensor
    1635             : !>        (irrespective of the actual number of currently present blocks).
    1636             : !>        This estimate is based on the following assumption: tensor data is dense and
    1637             : !>        load balancing is within a factor of 2 (max_load_imbalance)
    1638             : !> \author Patrick Seewald
    1639             : ! **************************************************************************************************
    1640           0 :          PURE FUNCTION dbt_max_nblks_local(tensor) RESULT(blk_count)
    1641             :             TYPE(dbt_type), INTENT(IN) :: tensor
    1642             :             INTEGER :: blk_count, nproc
    1643           0 :             INTEGER, DIMENSION(ndims_tensor(tensor)) :: bdims
    1644             :             INTEGER(int_8) :: blk_count_total
    1645             :             INTEGER, PARAMETER :: max_load_imbalance = 2
    1646             : 
    1647           0 :             CALL dbt_get_mapping_info(tensor%nd_index_blk, dims_nd=bdims)
    1648             : ! total block count is the product of the block counts per dimension, accumulated in int_8
    1649           0 :             blk_count_total = PRODUCT(INT(bdims, int_8))
    1650             : 
    1651             :             ! nproc is read from tensor%pgrid because an MPI routine can not be called in a PURE function
    1652           0 :             nproc = tensor%pgrid%nproc
    1653             : ! NOTE(review): the division by nproc is evaluated first and truncates (0 if fewer blocks than processes)
    1654           0 :             blk_count = INT(blk_count_total/nproc*max_load_imbalance)
    1655             : 
    1656           0 :          END FUNCTION
    1657             : 
    1658             : ! **************************************************************************************************
    1659             : !> \brief get a load-balanced and randomized distribution along one tensor dimension
    1660             : !> \param nblk number of blocks (along one tensor dimension)
    1661             : !> \param nproc number of processes (along one process grid dimension)
    1662             : !> \param blk_size block sizes, one entry per block
    1663             : !> \param dist distribution, one entry per block
    1664             : !> \author Patrick Seewald
    1665             : ! **************************************************************************************************
    1666       36230 :          SUBROUTINE dbt_default_distvec(nblk, nproc, blk_size, dist)
    1667             :             INTEGER, INTENT(IN)                                :: nblk
    1668             :             INTEGER, INTENT(IN)                                :: nproc
    1669             :             INTEGER, DIMENSION(nblk), INTENT(IN)                :: blk_size
    1670             :             INTEGER, DIMENSION(nblk), INTENT(OUT)               :: dist
    1671             : ! all arguments are handed unchanged to dbt_tas_default_distvec, which fills dist
    1672       36230 :             CALL dbt_tas_default_distvec(nblk, nproc, blk_size, dist)
    1673       36230 :          END SUBROUTINE
    1674             : 
    1675             : ! **************************************************************************************************
    1676             : !> \author Patrick Seewald
    1677             : ! **************************************************************************************************
    1678      451867 :          SUBROUTINE dbt_copy_contraction_storage(tensor_in, tensor_out)
    1679             :             TYPE(dbt_type), INTENT(IN) :: tensor_in
    1680             :             TYPE(dbt_type), INTENT(INOUT) :: tensor_out
    1681      451867 :             TYPE(dbt_contraction_storage), ALLOCATABLE :: tensor_storage_tmp
    1682      451867 :             TYPE(dbt_tas_mm_storage), ALLOCATABLE :: tas_storage_tmp
    1683             : ! copies mm_storage, the batched state and contraction_storage from tensor_in to tensor_out
    1684      451867 :             IF (tensor_in%matrix_rep%do_batched > 0) THEN
    1685      124437 :                ALLOCATE (tas_storage_tmp, SOURCE=tensor_in%matrix_rep%mm_storage)
    1686             :                ! transfer data for batched contraction: a copy is made, then moved into tensor_out
    1687      124437 :                IF (ALLOCATED(tensor_out%matrix_rep%mm_storage)) DEALLOCATE (tensor_out%matrix_rep%mm_storage)
    1688      124437 :                CALL move_alloc(tas_storage_tmp, tensor_out%matrix_rep%mm_storage)
    1689             :             END IF
    1690             :             CALL dbt_tas_set_batched_state(tensor_out%matrix_rep, state=tensor_in%matrix_rep%do_batched, &
    1691      451867 :                                            opt_grid=tensor_in%matrix_rep%has_opt_pgrid)
    1692      451867 :             IF (ALLOCATED(tensor_in%contraction_storage)) THEN
    1693      366961 :                ALLOCATE (tensor_storage_tmp, SOURCE=tensor_in%contraction_storage)
    1694             :             END IF
    1695      451867 :             IF (ALLOCATED(tensor_out%contraction_storage)) DEALLOCATE (tensor_out%contraction_storage)
    1696      451867 :             IF (ALLOCATED(tensor_storage_tmp)) CALL move_alloc(tensor_storage_tmp, tensor_out%contraction_storage)
    1697             : ! if tensor_in has no contraction_storage, tensor_out is left without one as well
    1698      451867 :          END SUBROUTINE
    1699             : 
    1700    26251030 :       END MODULE

Generated by: LCOV version 1.15