LCOV - code coverage report
Current view: top level - src/dbt - dbt_types.F (source / functions) Coverage Total Hit
Test: CP2K Regtests (git:42dac4a) Lines: 90.0 % 580 522
Test Date: 2025-07-25 12:55:17 Functions: 77.4 % 62 48

            Line data    Source code
       1              : !--------------------------------------------------------------------------------------------------!
       2              : !   CP2K: A general program to perform molecular dynamics simulations                              !
       3              : !   Copyright 2000-2025 CP2K developers group <https://cp2k.org>                                   !
       4              : !                                                                                                  !
       5              : !   SPDX-License-Identifier: GPL-2.0-or-later                                                      !
       6              : !--------------------------------------------------------------------------------------------------!
       7              : 
       8              : ! **************************************************************************************************
       9              : !> \brief DBT tensor framework for block-sparse tensor contraction: Types and create/destroy routines.
      10              : !> \author Patrick Seewald
      11              : ! **************************************************************************************************
      12              : MODULE dbt_types
      13              :    #:include "dbt_macros.fypp"
      14              :    #:set maxdim = maxrank
      15              :    #:set ndims = range(2,maxdim+1)
      16              : 
      17              :    USE cp_dbcsr_api, ONLY: dbcsr_type, dbcsr_get_info, dbcsr_distribution_type, dbcsr_distribution_get
      18              :    USE dbt_array_list_methods, ONLY: &
      19              :       array_list, array_offsets, create_array_list, destroy_array_list, get_array_elements, &
      20              :       sizes_of_arrays, sum_of_arrays, array_sublist, get_arrays, get_ith_array, array_eq_i
      21              :    USE dbm_api, ONLY: &
      22              :       dbm_distribution_obj, dbm_type
      23              :    USE kinds, ONLY: dp, dp, default_string_length
      24              :    USE dbt_tas_base, ONLY: &
      25              :       dbt_tas_create, dbt_tas_distribution_new, &
      26              :       dbt_tas_distribution_destroy, dbt_tas_finalize, dbt_tas_get_info, &
      27              :       dbt_tas_destroy, dbt_tas_get_stored_coordinates, dbt_tas_filter, &
      28              :       dbt_tas_get_num_blocks, dbt_tas_get_num_blocks_total, dbt_tas_get_nze, &
      29              :       dbt_tas_get_nze_total, dbt_tas_clear
      30              :    USE dbt_tas_types, ONLY: &
      31              :       dbt_tas_type, dbt_tas_distribution_type, dbt_tas_split_info, dbt_tas_mm_storage
      32              :    USE dbt_tas_mm, ONLY: dbt_tas_set_batched_state
      33              :    USE dbt_index, ONLY: &
      34              :       get_2d_indices_tensor, get_nd_indices_pgrid, create_nd_to_2d_mapping, destroy_nd_to_2d_mapping, &
      35              :       dbt_get_mapping_info, nd_to_2d_mapping, split_tensor_index, combine_tensor_index, combine_pgrid_index, &
      36              :       split_pgrid_index, ndims_mapping, ndims_mapping_row, ndims_mapping_column
      37              :    USE dbt_tas_split, ONLY: &
      38              :       dbt_tas_create_split_rows_or_cols, dbt_tas_release_info, dbt_tas_info_hold, &
      39              :       dbt_tas_create_split, dbt_tas_get_split_info, dbt_tas_set_strict_split
      40              :    USE kinds, ONLY: default_string_length, int_8, dp
      41              :    USE message_passing, ONLY: &
      42              :       mp_cart_type, mp_dims_create, mp_comm_type
      43              :    USE dbt_tas_global, ONLY: dbt_tas_distribution, dbt_tas_rowcol_data, dbt_tas_default_distvec
      44              :    USE dbt_allocate_wrap, ONLY: allocate_any
      45              :    USE dbm_api, ONLY: dbm_scale
      46              :    USE util, ONLY: sort
      47              : #include "../base/base_uses.f90"
      48              : 
      49              :    IMPLICIT NONE
      50              :    PRIVATE
      51              :    CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'dbt_types'
      52              : 
      53              :    PUBLIC  :: &
      54              :       blk_dims_tensor, &
      55              :       dbt_blk_offsets, &
      56              :       dbt_blk_sizes, &
      57              :       dbt_clear, &
      58              :       dbt_create, &
      59              :       dbt_destroy, &
      60              :       dbt_distribution, &
      61              :       dbt_distribution_destroy, &
      62              :       dbt_distribution_new, &
      63              :       dbt_distribution_new_expert, &
      64              :       dbt_distribution_type, &
      65              :       dbt_filter, &
      66              :       dbt_finalize, &
      67              :       dbt_get_info, &
      68              :       dbt_get_num_blocks, &
      69              :       dbt_get_num_blocks_total, &
      70              :       dbt_get_nze, &
      71              :       dbt_get_nze_total, &
      72              :       dbt_get_stored_coordinates, &
      73              :       dbt_hold, &
      74              :       dbt_mp_dims_create, &
      75              :       dbt_nd_mp_comm, &
      76              :       dbt_nd_mp_free, &
      77              :       dbt_pgrid_change_dims, &
      78              :       dbt_pgrid_create, &
      79              :       dbt_pgrid_create_expert, &
      80              :       dbt_pgrid_destroy, &
      81              :       dbt_pgrid_type, &
      82              :       dbt_pgrid_set_strict_split, &
      83              :       dbt_scale, &
      84              :       dbt_type, &
      85              :       dims_tensor, &
      86              :       mp_environ_pgrid, &
      87              :       ndims_tensor, &
      88              :       ndims_matrix_row, &
      89              :       ndims_matrix_column, &
      90              :       dbt_nblks_local, &
      91              :       dbt_nblks_total, &
      92              :       dbt_blk_size, &
      93              :       dbt_max_nblks_local, &
      94              :       dbt_default_distvec, &
      95              :       dbt_contraction_storage, &
      96              :       dbt_copy_contraction_storage
      97              : 
      98              :    TYPE dbt_pgrid_type ! process grid data; used as the pgrid component of dbt_type and dbt_distribution_type
      99              :       TYPE(nd_to_2d_mapping)                  :: nd_index_grid ! nd <-> 2d index mapping of the grid
     100              :       TYPE(mp_cart_type)                      :: mp_comm_2d ! presumably the 2d Cartesian communicator - TODO confirm
     101              :       TYPE(dbt_tas_split_info), ALLOCATABLE   :: tas_split_info ! optional: unallocated unless explicitly set
     102              :       INTEGER                                 :: nproc = -1 ! defaults to -1; presumably the process count once set
     103              :    END TYPE
     104              : 
     105              :    TYPE dbt_contraction_storage ! held as the allocatable component contraction_storage of dbt_type
     106              :       REAL(dp)         :: nsplit_avg = 0.0_dp ! defaults to 0; presumably an average split factor - TODO confirm
     107              :       INTEGER          :: ibatch = -1 ! defaults to -1; presumably the current batch index - TODO confirm
     108              :       TYPE(array_list) :: batch_ranges ! presumably index ranges of the batches, one array per dimension - TODO confirm
     109              :       LOGICAL          :: static = .FALSE. ! defaults to .FALSE.; meaning not visible in this part of the file
     110              :    END TYPE
     111              : 
     112              :    TYPE dbt_type ! tensor object; matrix_rep points to its matrix (tas) representation
     113              :       TYPE(dbt_tas_type), POINTER                :: matrix_rep => NULL() ! unassociated by default
     114              :       TYPE(nd_to_2d_mapping)                     :: nd_index_blk ! nd <-> 2d mapping, presumably for block indices
     115              :       TYPE(nd_to_2d_mapping)                     :: nd_index ! nd <-> 2d mapping, presumably for element indices
     116              :       TYPE(array_list)                           :: blk_sizes ! presumably one block size array per dimension
     117              :       TYPE(array_list)                           :: blk_offsets ! presumably one block offset array per dimension
     118              :       TYPE(array_list)                           :: nd_dist ! presumably one distribution vector per dimension
     119              :       TYPE(dbt_pgrid_type)                       :: pgrid ! process grid data of this tensor
     120              :       TYPE(array_list)                           :: blks_local ! presumably the locally held block indices - TODO confirm
     121              :       INTEGER, DIMENSION(:), ALLOCATABLE         :: nblks_local ! presumably local block counts - TODO confirm
     122              :       INTEGER, DIMENSION(:), ALLOCATABLE         :: nfull_local ! presumably local element counts - TODO confirm
     123              :       LOGICAL                                    :: valid = .FALSE. ! .FALSE. by default
     124              :       LOGICAL                                    :: owns_matrix = .FALSE. ! presumably whether matrix_rep is owned here
     125              :       CHARACTER(LEN=default_string_length)       :: name = "" ! empty by default
     126              :       ! lightweight reference counting for communicators (pointer is unassociated by default):
     127              :       INTEGER, POINTER                           :: refcount => NULL()
     128              :       TYPE(dbt_contraction_storage), ALLOCATABLE :: contraction_storage ! optional: unallocated unless explicitly set
     129              :    END TYPE dbt_type
     130              : 
     131              :    TYPE dbt_distribution_type ! distribution object: tas distribution plus process grid and dist vectors
     132              :       TYPE(dbt_tas_distribution_type) :: dist ! distribution of the matrix (tas) representation
     133              :       TYPE(dbt_pgrid_type)            :: pgrid ! process grid data of this distribution
     134              :       TYPE(array_list)                :: nd_dist ! presumably one distribution vector per dimension - TODO confirm
     135              :       ! lightweight reference counting for communicators (pointer is unassociated by default):
     136              :       INTEGER, POINTER                :: refcount => NULL()
     137              :    END TYPE
     138              : 
     139              : ! **************************************************************************************************
     140              : !> \brief tas matrix distribution function object for one matrix index
     141              : !> \var dims            tensor dimensions only for this matrix dimension
     142              : !> \var dims_grid       grid dimensions only for this matrix dimension
     143              : !> \var nd_dist         dist only for tensor dimensions belonging to this matrix dimension
     144              : !> \var tas_dist_t      map matrix index to process grid
     145              : !> \var tas_rowcols_t   map process grid to matrix index
     146              : ! **************************************************************************************************
     147              :    TYPE, EXTENDS(dbt_tas_distribution) :: dbt_tas_dist_t ! instances are built by new_dbt_tas_dist_t
     148              :       INTEGER, DIMENSION(:), ALLOCATABLE :: dims ! allocated and filled in new_dbt_tas_dist_t
     149              :       INTEGER, DIMENSION(:), ALLOCATABLE :: dims_grid ! allocated and filled in new_dbt_tas_dist_t
     150              :       TYPE(array_list)                   :: nd_dist ! sublist of the full nd_dist selected via array_sublist
     151              :    CONTAINS
     152              :       PROCEDURE                          :: dist => tas_dist_t ! binding dist is implemented by tas_dist_t
     153              :       PROCEDURE                          :: rowcols => tas_rowcols_t ! binding rowcols is implemented by tas_rowcols_t
     154              :    END TYPE
     155              : 
     156              : ! **************************************************************************************************
     157              : !> \brief  block size object for one matrix index
     158              : !> \var dims tensor dimensions only for this matrix dimension
     159              : !> \var blk_size block size only for this matrix dimension
     160              : ! **************************************************************************************************
     161              :    TYPE, EXTENDS(dbt_tas_rowcol_data) :: dbt_tas_blk_size_t ! instances are built by new_dbt_tas_blk_size_t
     162              :       INTEGER, DIMENSION(:), ALLOCATABLE :: dims ! allocated and filled in new_dbt_tas_blk_size_t
     163              :       TYPE(array_list)                   :: blk_size ! sublist of the full block size list selected via array_sublist
     164              :    CONTAINS
     165              :       PROCEDURE                          :: data => tas_blk_size_t ! binding data is implemented by tas_blk_size_t
     166              :    END TYPE
     167              : 
     168              :    INTERFACE dbt_create ! generic name resolving to one of the three specific procedures below
     169              :       MODULE PROCEDURE dbt_create_new
     170              :       MODULE PROCEDURE dbt_create_template
     171              :       MODULE PROCEDURE dbt_create_matrix
     172              :    END INTERFACE
     173              : 
     174              :    INTERFACE dbt_tas_dist_t ! generic with the same name as the type, so it reads like a constructor call
     175              :       MODULE PROCEDURE new_dbt_tas_dist_t
     176              :    END INTERFACE
     177              : 
     178              :    INTERFACE dbt_tas_blk_size_t ! generic with the same name as the type, so it reads like a constructor call
     179              :       MODULE PROCEDURE new_dbt_tas_blk_size_t
     180              :    END INTERFACE
     181              : 
     182              : CONTAINS
     183              : 
     184              : ! **************************************************************************************************
     185              : !> \brief Create distribution object for one matrix dimension
     186              : !> \param nd_dist arrays for distribution vectors along all dimensions
     187              : !> \param map_blks tensor to matrix mapping object for blocks
     188              : !> \param map_grid tensor to matrix mapping object for process grid
     189              : !> \param which_dim for which dimension (1 or 2) distribution should be created
     190              : !> \return distribution object
     191              : !> \author Patrick Seewald
     192              : ! **************************************************************************************************
     193       877316 :    FUNCTION new_dbt_tas_dist_t(nd_dist, map_blks, map_grid, which_dim) ! specific behind generic dbt_tas_dist_t
     194              :       TYPE(array_list), INTENT(IN)       :: nd_dist
     195              :       TYPE(nd_to_2d_mapping), INTENT(IN) :: map_blks, map_grid
     196              :       INTEGER, INTENT(IN)                :: which_dim
     197              : 
     198              :       TYPE(dbt_tas_dist_t)               :: new_dbt_tas_dist_t
     199              :       INTEGER, DIMENSION(2)              :: grid_dims ! 2d grid dimensions obtained from map_grid
     200              :       INTEGER(KIND=int_8), DIMENSION(2)  :: matrix_dims ! 2d matrix dimensions obtained from map_blks
     201       877316 :       INTEGER, DIMENSION(:), ALLOCATABLE :: index_map ! tensor dimensions mapped to the chosen matrix dimension
     202              : 
     203       877316 :       IF (which_dim == 1) THEN ! matrix rows: query the row part (map1_2d/dims1_2d) of both mappings
     204      1315974 :          ALLOCATE (new_dbt_tas_dist_t%dims(ndims_mapping_row(map_blks)))
     205      1315974 :          ALLOCATE (index_map(ndims_mapping_row(map_blks)))
     206              :          CALL dbt_get_mapping_info(map_blks, &
     207              :                                    dims_2d_i8=matrix_dims, &
     208              :                                    map1_2d=index_map, &
     209       438658 :                                    dims1_2d=new_dbt_tas_dist_t%dims)
     210      1315974 :          ALLOCATE (new_dbt_tas_dist_t%dims_grid(ndims_mapping_row(map_grid)))
     211              :          CALL dbt_get_mapping_info(map_grid, &
     212              :                                    dims_2d=grid_dims, &
     213       438658 :                                    dims1_2d=new_dbt_tas_dist_t%dims_grid)
     214       438658 :       ELSEIF (which_dim == 2) THEN ! matrix columns: query the column part (map2_2d/dims2_2d) of both mappings
     215      1315974 :          ALLOCATE (new_dbt_tas_dist_t%dims(ndims_mapping_column(map_blks)))
     216      1315974 :          ALLOCATE (index_map(ndims_mapping_column(map_blks)))
     217              :          CALL dbt_get_mapping_info(map_blks, &
     218              :                                    dims_2d_i8=matrix_dims, &
     219              :                                    map2_2d=index_map, &
     220       438658 :                                    dims2_2d=new_dbt_tas_dist_t%dims)
     221      1315974 :          ALLOCATE (new_dbt_tas_dist_t%dims_grid(ndims_mapping_column(map_grid)))
     222              :          CALL dbt_get_mapping_info(map_grid, &
     223              :                                    dims_2d=grid_dims, &
     224       438658 :                                    dims2_2d=new_dbt_tas_dist_t%dims_grid)
     225              :       ELSE
     226            0 :          CPABORT("Unknown value for which_dim")
     227              :       END IF
     228              : 
     229       877316 :       new_dbt_tas_dist_t%nd_dist = array_sublist(nd_dist, index_map) ! keep only the arrays selected by index_map
     230       877316 :       new_dbt_tas_dist_t%nprowcol = grid_dims(which_dim) ! grid extent along the chosen matrix dimension
     231       877316 :       new_dbt_tas_dist_t%nmrowcol = matrix_dims(which_dim) ! matrix extent along the chosen matrix dimension
     232      1754632 :    END FUNCTION
     233              : 
     234              : ! **************************************************************************************************
     235              : !> \author Patrick Seewald
     236              : ! **************************************************************************************************
     237     33739084 :    FUNCTION tas_dist_t(t, rowcol) ! bound as dbt_tas_dist_t%dist: matrix index rowcol -> process grid index
     238              :       CLASS(dbt_tas_dist_t), INTENT(IN) :: t
     239              :       INTEGER(KIND=int_8), INTENT(IN) :: rowcol
     240              :       INTEGER, DIMENSION(${maxrank}$) :: ind_blk ! fixed-size buffer; only the first SIZE(t%dims) entries are used
     241              :       INTEGER, DIMENSION(${maxrank}$) :: dist_blk ! fixed-size buffer; only the first SIZE(t%dims) entries are used
     242              :       INTEGER :: tas_dist_t
     243              : 
     244     33739084 :       ind_blk(:SIZE(t%dims)) = split_tensor_index(rowcol, t%dims) ! one index per tensor dimension in t%dims
     245     33739084 :       dist_blk(:SIZE(t%dims)) = get_array_elements(t%nd_dist, ind_blk(:SIZE(t%dims))) ! dist entry at each index
     246     33739084 :       tas_dist_t = combine_pgrid_index(dist_blk(:SIZE(t%dims)), t%dims_grid) ! fold entries into one grid index
     247     33739084 :    END FUNCTION
     248              : 
     249              : ! **************************************************************************************************
     250              : !> \author Patrick Seewald
     251              : ! **************************************************************************************************
     252       918268 :    FUNCTION tas_rowcols_t(t, dist) ! bound as dbt_tas_dist_t%rowcols: process grid index dist -> matrix indices
     253              :       CLASS(dbt_tas_dist_t), INTENT(IN) :: t
     254              :       INTEGER, INTENT(IN) :: dist
     255              :       INTEGER(KIND=int_8), DIMENSION(:), ALLOCATABLE :: tas_rowcols_t
     256              :       INTEGER, DIMENSION(${maxrank}$) :: dist_blk ! fixed-size buffer; only the first SIZE(t%dims) entries are used
     257       918268 :       INTEGER, DIMENSION(:), ALLOCATABLE :: ${varlist("dist")}$, ${varlist("blks")}$, blks_tmp, nd_ind
     258              :       INTEGER :: ${varlist("i")}$, i, iblk, iblk_count, nblks
     259              :       INTEGER(KIND=int_8) :: nrowcols
     260       918268 :       TYPE(array_list) :: blks
     261              : 
     262       918268 :       dist_blk(:SIZE(t%dims)) = split_pgrid_index(dist, t%dims_grid) ! one grid coordinate per dimension
     263              : 
     264              :       #:for ndim in range(1, maxdim+1)
     265      1249188 :          IF (SIZE(t%dims) == ${ndim}$) THEN ! only the branch matching the actual number of dimensions runs
     266       330920 :             CALL get_arrays(t%nd_dist, ${varlist("dist", nmax=ndim)}$)
     267              :          END IF
     268              :       #:endfor
     269              : 
     270              :       #:for idim in range(1, maxdim+1)
     271      2167652 :          IF (SIZE(t%dims) .GE. ${idim}$) THEN ! per dimension: collect the positions whose dist entry matches
     272      1249384 :             nblks = SIZE(dist_${idim}$)
     273      3748152 :             ALLOCATE (blks_tmp(nblks)) ! sized for the case that every entry matches
     274     14569612 :             iblk_count = 0
     275     14569612 :             DO iblk = 1, nblks
     276     14569612 :                IF (dist_${idim}$ (iblk) == dist_blk(${idim}$)) THEN
     277     11715472 :                   iblk_count = iblk_count + 1
     278     11715472 :                   blks_tmp(iblk_count) = iblk
     279              :                END IF
     280              :             END DO
     281      3742410 :             ALLOCATE (blks_${idim}$ (iblk_count))
     282     12964856 :             blks_${idim}$ (:) = blks_tmp(:iblk_count) ! copy only the filled part of the scratch buffer
     283      1249384 :             DEALLOCATE (blks_tmp)
     284              :          END IF
     285              :       #:endfor
     286              : 
     287              :       #:for ndim in range(1, maxdim+1)
     288      1836536 :          IF (SIZE(t%dims) == ${ndim}$) THEN ! bundle the per-dimension position arrays into one array_list
     289       918268 :             CALL create_array_list(blks, ${ndim}$, ${varlist("blks", nmax=ndim)}$)
     290              :          END IF
     291              :       #:endfor
     292              : 
     293      2167652 :       nrowcols = PRODUCT(INT(sizes_of_arrays(blks), int_8)) ! number of index combinations, computed in int_8
     294      2749062 :       ALLOCATE (tas_rowcols_t(nrowcols))
     295              : 
     296              :       #:for ndim in range(1, maxdim+1)
     297      1836536 :          IF (SIZE(t%dims) == ${ndim}$) THEN
     298       918268 :             ALLOCATE (nd_ind(${ndim}$))
     299       918268 :             i = 0 ! running position in the result array
     300              :             #:for idim in range(1,ndim+1)
     301     27946518 :                DO i_${idim}$ = 1, SIZE(blks_${idim}$) ! generated loop nest; i_1 is the outermost loop
     302              :                   #:endfor
     303     23649408 :                   i = i + 1
     304              : 
     305     64187602 :                   nd_ind(:) = get_array_elements(blks, [${varlist("i", nmax=ndim)}$])
     306     27614402 :                   tas_rowcols_t(i) = combine_tensor_index(nd_ind, t%dims) ! nd index -> single matrix index
     307              :                   #:for idim in range(1,ndim+1)
     308              :                      END DO
     309              :                   #:endfor
     310              :                END IF
     311              :             #:endfor
     312              : 
     313              :          END FUNCTION
     314              : 
     315              : ! **************************************************************************************************
     316              : !> \brief Create block size object for one matrix dimension
     317              : !> \param blk_size arrays for block sizes along all dimensions
     318              : !> \param map_blks tensor to matrix mapping object for blocks
     319              : !> \param which_dim for which dimension (1 or 2) distribution should be created
     320              : !> \return block size object
     321              : !> \author Patrick Seewald
     322              : ! **************************************************************************************************
     323       439166 :          FUNCTION new_dbt_tas_blk_size_t(blk_size, map_blks, which_dim) ! specific behind generic dbt_tas_blk_size_t
     324              :             TYPE(array_list), INTENT(IN)                   :: blk_size
     325              :             TYPE(nd_to_2d_mapping), INTENT(IN)             :: map_blks
     326              :             INTEGER, INTENT(IN) :: which_dim
     327              :             INTEGER(KIND=int_8), DIMENSION(2) :: matrix_dims ! 2d matrix dimensions obtained from map_blks
     328       439166 :             INTEGER, DIMENSION(:), ALLOCATABLE :: index_map ! tensor dimensions mapped to the chosen matrix dimension
     329              :             TYPE(dbt_tas_blk_size_t) :: new_dbt_tas_blk_size_t
     330              : 
     331       439166 :             IF (which_dim == 1) THEN ! matrix rows: query the row part (map1_2d/dims1_2d) of the mapping
     332       658749 :                ALLOCATE (index_map(ndims_mapping_row(map_blks)))
     333       658749 :                ALLOCATE (new_dbt_tas_blk_size_t%dims(ndims_mapping_row(map_blks)))
     334              :                CALL dbt_get_mapping_info(map_blks, &
     335              :                                          dims_2d_i8=matrix_dims, &
     336              :                                          map1_2d=index_map, &
     337       219583 :                                          dims1_2d=new_dbt_tas_blk_size_t%dims)
     338       219583 :             ELSEIF (which_dim == 2) THEN ! matrix columns: query the column part (map2_2d/dims2_2d) of the mapping
     339       658749 :                ALLOCATE (index_map(ndims_mapping_column(map_blks)))
     340       658749 :                ALLOCATE (new_dbt_tas_blk_size_t%dims(ndims_mapping_column(map_blks)))
     341              :                CALL dbt_get_mapping_info(map_blks, &
     342              :                                          dims_2d_i8=matrix_dims, &
     343              :                                          map2_2d=index_map, &
     344       219583 :                                          dims2_2d=new_dbt_tas_blk_size_t%dims)
     345              :             ELSE
     346            0 :                CPABORT("Unknown value for which_dim")
     347              :             END IF
     348              : 
     349       439166 :             new_dbt_tas_blk_size_t%blk_size = array_sublist(blk_size, index_map) ! keep only the selected arrays
     350       439166 :             new_dbt_tas_blk_size_t%nmrowcol = matrix_dims(which_dim) ! matrix extent along the chosen dimension
     351              : 
     352              :             new_dbt_tas_blk_size_t%nfullrowcol = PRODUCT(INT(sum_of_arrays(new_dbt_tas_blk_size_t%blk_size), &
     353       987901 :                                                              KIND=int_8)) ! product of per-array sums, in int_8
     354       878332 :          END FUNCTION
     355              : 
     356              : ! **************************************************************************************************
     357              : !> \author Patrick Seewald
     358              : ! **************************************************************************************************
     359     15694191 :          FUNCTION tas_blk_size_t(t, rowcol) ! bound as dbt_tas_blk_size_t%data: block size at matrix index rowcol
     360              :             CLASS(dbt_tas_blk_size_t), INTENT(IN) :: t
     361              :             INTEGER(KIND=int_8), INTENT(IN) :: rowcol
     362              :             INTEGER :: tas_blk_size_t
     363     31388382 :             INTEGER, DIMENSION(SIZE(t%dims)) :: ind_blk ! automatic array, one entry per dimension in t%dims
     364     31388382 :             INTEGER, DIMENSION(SIZE(t%dims)) :: blk_size ! automatic array, one entry per dimension in t%dims
     365              : 
     366     15694191 :             ind_blk(:) = split_tensor_index(rowcol, t%dims) ! one index per tensor dimension
     367     15694191 :             blk_size(:) = get_array_elements(t%blk_size, ind_blk) ! per-dimension size at those indices
     368     36256644 :             tas_blk_size_t = PRODUCT(blk_size) ! result is the product over the dimensions (default INTEGER)
     369              : 
     370     15694191 :          END FUNCTION
     371              : 
     372              : ! **************************************************************************************************
     373              : !> \brief load balancing criterion whether to accept process grid dimension based on total number of
     374              : !>        cores and tensor dimension
     375              : !> \param pdims_avail available process grid dimensions (total number of cores)
     376              : !> \param pdim process grid dimension to test
     377              : !> \param tdim tensor dimension corresponding to pdim
     378              : !> \param lb_ratio load imbalance acceptance factor
     379              : !> \author Patrick Seewald
     380              : ! **************************************************************************************************
     381        14250 :          PURE FUNCTION accept_pdims_loadbalancing(pdims_avail, pdim, tdim, lb_ratio)
     382              :             INTEGER, INTENT(IN) :: pdims_avail
     383              :             INTEGER, INTENT(IN) :: pdim
     384              :             INTEGER, INTENT(IN) :: tdim
     385              :             REAL(dp), INTENT(IN) :: lb_ratio
     386              :             LOGICAL :: accept_pdims_loadbalancing
     387              : 
     388        14250 :             accept_pdims_loadbalancing = .FALSE. ! reject unless one of the cases below accepts
     389        14250 :             IF (MOD(pdims_avail, pdim) == 0) THEN ! pdim must divide the available number of processes
     390        11704 :                IF (REAL(tdim, dp)*lb_ratio < REAL(pdim, dp)) THEN ! pdim exceeds tdim*lb_ratio: be strict
     391         9190 :                   IF (MOD(tdim, pdim) == 0) accept_pdims_loadbalancing = .TRUE. ! accept only if pdim divides tdim
     392              :                ELSE
     393              :                   accept_pdims_loadbalancing = .TRUE. ! pdim <= tdim*lb_ratio: accept without divisibility test
     394              :                END IF
     395              :             END IF
     396              : 
     397        14250 :          END FUNCTION
     398              : 
     399              : ! **************************************************************************************************
     400              : !> \brief Create process grid dimensions corresponding to one dimension of the matrix representation
     401              : !>        of a tensor, imposing that no process grid dimension is greater than the corresponding
     402              : !>        tensor dimension.
     403              : !> \param nodes Total number of nodes available for this matrix dimension
     404              : !> \param dims process grid dimension corresponding to tensor_dims
     405              : !> \param tensor_dims tensor dimensions
     406              : !> \param lb_ratio load imbalance acceptance factor
     407              : !> \author Patrick Seewald
     408              : ! **************************************************************************************************
     409         3680 :          RECURSIVE SUBROUTINE dbt_mp_dims_create(nodes, dims, tensor_dims, lb_ratio)
     410              :             INTEGER, INTENT(IN) :: nodes
     411              :             INTEGER, DIMENSION(:), INTENT(INOUT) :: dims
     412              :             INTEGER, DIMENSION(:), INTENT(IN) :: tensor_dims
     413              :             REAL(dp), INTENT(IN), OPTIONAL :: lb_ratio
     414              : 
     415         3680 :             INTEGER, DIMENSION(:), ALLOCATABLE :: tensor_dims_sorted, sort_indices, dims_store
     416         3680 :             REAL(dp), DIMENSION(:), ALLOCATABLE :: sort_key
     417              :             INTEGER :: pdims_rem, idim, pdim
     418              :             REAL(dp) :: lb_ratio_prv
     419              : 
     420         3680 :             IF (PRESENT(lb_ratio)) THEN
     421          544 :                lb_ratio_prv = lb_ratio
     422              :             ELSE
     423         3136 :                lb_ratio_prv = 0.1_dp ! default load imbalance acceptance factor
     424              :             END IF
     425              : 
     426        19362 :             ALLOCATE (dims_store, source=dims) ! copy of the input dims, restored by the fallback branches below
     427              : 
     428              :             ! get default process grid dimensions
     429         3680 :             IF (any(dims == 0)) THEN
     430         3680 :                CALL mp_dims_create(nodes, dims)
     431              :             END IF
     432              : 
     433              :             ! sort dimensions such that problematic grid dimensions (those that should be corrected) come first
     434        11040 :             ALLOCATE (sort_key(SIZE(tensor_dims)))
     435        12002 :             sort_key(:) = REAL(tensor_dims, dp)/dims ! tensor extent per grid extent, used as the sort key
     436              : 
     437        19362 :             ALLOCATE (tensor_dims_sorted, source=tensor_dims)
     438        11040 :             ALLOCATE (sort_indices(SIZE(sort_key)))
     439         3680 :             CALL sort(sort_key, SIZE(sort_key), sort_indices) ! sort_indices presumably receives the permutation
     440        20324 :             tensor_dims_sorted(:) = tensor_dims_sorted(sort_indices)
     441        20324 :             dims(:) = dims(sort_indices) ! dims is kept in sorted order until the final statement
     442              : 
     443              :             ! remaining number of nodes
     444         3680 :             pdims_rem = nodes
     445              : 
     446        11122 :             DO idim = 1, SIZE(tensor_dims_sorted)
     447        11122 :                IF (.NOT. accept_pdims_loadbalancing(pdims_rem, dims(idim), tensor_dims_sorted(idim), lb_ratio_prv)) THEN
     448         2216 :                   pdim = tensor_dims_sorted(idim) ! start the search at the tensor extent and count down
     449         5928 :                   DO WHILE (.NOT. accept_pdims_loadbalancing(pdims_rem, pdim, tensor_dims_sorted(idim), lb_ratio_prv))
     450         3712 :                      pdim = pdim - 1 ! pdim = 1 passes both MOD tests, so the search stops for a positive extent
     451              :                   END DO
     452         2216 :                   dims(idim) = pdim
     453         2216 :                   pdims_rem = pdims_rem/dims(idim)
     454              : 
     455         2216 :                   IF (idim .NE. SIZE(tensor_dims_sorted)) THEN
     456         3238 :                      dims(idim + 1:) = 0 ! reset the remaining dims and redistribute the remaining processes
     457         1336 :                      CALL mp_dims_create(pdims_rem, dims(idim + 1:))
     458          880 :                   ELSEIF (lb_ratio_prv < 0.5_dp) THEN
     459              :                      ! resort to a less strict load imbalance factor
     460         1894 :                      dims(:) = dims_store
     461          544 :                      CALL dbt_mp_dims_create(nodes, dims, tensor_dims, 0.5_dp) ! retry; 0.5 then takes the ELSE branch
     462          544 :                      RETURN
     463              :                   ELSE
     464              :                      ! resort to default process grid dimensions
     465         1198 :                      dims(:) = dims_store
     466          336 :                      CALL mp_dims_create(nodes, dims)
     467          336 :                      RETURN
     468              :                   END IF
     469              : 
     470              :                ELSE
     471         6106 :                   pdims_rem = pdims_rem/dims(idim) ! accepted as is: only update the remaining process count
     472              :                END IF
     473              :             END DO
     474              : 
     475        15020 :             dims(sort_indices) = dims ! undo the sort permutation so dims matches the order of tensor_dims
     476              : 
     477         3680 :          END SUBROUTINE
     478              : 
     479              : ! **************************************************************************************************
     480              : !> \brief Create an n-dimensional process grid.
     481              : !>        We can not use an n-dimensional MPI cartesian grid for tensors since the mapping between
     482              : !>        n-dim. and 2-dim. index allows for an arbitrary reordering of tensor index. Therefore we
     483              : !>        can not use n-dim. MPI Cartesian grid because it may not be consistent with the respective
     484              : !>        2d grid. The 2d Cartesian MPI grid is the reference grid (since tensor data is stored as
     485              : !>        DBM matrix) and this routine creates an object that is an n-dim. interface to this grid.
     486              : !>        map1_2d and map2_2d don't need to be specified (correctly), grid may be redefined in
     487              : !>        dbt_distribution_new. Note that pgrid is equivalent to a MPI cartesian grid only
     488              : !>        if map1_2d and map2_2d don't reorder indices (which is the case if
     489              : !>        [map1_2d, map2_2d] == [1, 2, ..., ndims]). Otherwise the mapping of grid coordinates to
     490              : !>        processes depends on the ordering of the indices and is not equivalent to a MPI cartesian
     491              : !>        grid.
     492              : !> \param mp_comm simple MPI Communicator
     493              : !> \param dims grid dimensions - if entries are 0, dimensions are chosen automatically.
                        !>             The array is INTENT(INOUT): when it contains a 0 it is overwritten by
                        !>             mp_dims_create / dbt_mp_dims_create before the grid is built.
     494              : !> \param pgrid n-dimensional grid object
     495              : !> \param map1_2d which nd-indices map to first matrix index and in which order
     496              : !> \param map2_2d which nd-indices map to second matrix index and in which order
     497              : !> \param tensor_dims tensor block dimensions. If present, process grid dimensions are created such
     498              : !>                    that good load balancing is ensured even if some of the tensor dimensions are
     499              : !>                    small (i.e. on the same order or smaller than nproc**(1/ndim) where ndim is
     500              : !>                    the tensor rank)
     501              : !> \param nsplit impose a constant split factor (dimsplit must then be present as well)
     502              : !> \param dimsplit which matrix dimension to split (only read if nsplit is present)
     503              : !> \author Patrick Seewald
     504              : ! **************************************************************************************************
     505      2225280 :          SUBROUTINE dbt_pgrid_create_expert(mp_comm, dims, pgrid, map1_2d, map2_2d, tensor_dims, nsplit, dimsplit)
     506              :             CLASS(mp_comm_type), INTENT(IN) :: mp_comm
     507              :             INTEGER, DIMENSION(:), INTENT(INOUT) :: dims
     508              :             TYPE(dbt_pgrid_type), INTENT(OUT) :: pgrid
     509              :             INTEGER, DIMENSION(:), INTENT(IN) :: map1_2d, map2_2d
     510              :             INTEGER, DIMENSION(:), INTENT(IN), OPTIONAL :: tensor_dims
     511              :             INTEGER, INTENT(IN), OPTIONAL :: nsplit, dimsplit
     512              :             INTEGER, DIMENSION(2) :: pdims_2d
     513              :             INTEGER :: nproc, ndims, handle
     514      2670336 :             TYPE(dbt_tas_split_info) :: info
     515              : 
     516              :             CHARACTER(LEN=*), PARAMETER :: routineN = 'dbt_pgrid_create_expert'
     517              : 
     518       445056 :             CALL timeset(routineN, handle)
     519              : 
     520       445056 :             ndims = SIZE(dims)
     521              : 
     522       445056 :             nproc = mp_comm%num_pe
                                    ! fill in grid dimensions only if the caller left at least one entry at 0;
                                    ! the load-balancing aware variant is used when tensor_dims is supplied
     523      1550816 :             IF (ANY(dims == 0)) THEN
     524         2434 :                IF (.NOT. PRESENT(tensor_dims)) THEN
     525          988 :                   CALL mp_dims_create(nproc, dims)
     526              :                ELSE
     527         1446 :                   CALL dbt_mp_dims_create(nproc, dims, tensor_dims)
     528              :                END IF
     529              :             END IF
                                    ! build the nd <-> 2d grid index mapping (0-based, not column-major), read back the
                                    ! resulting 2d grid dimensions and create the 2d communicator from mp_comm with them
     530       445056 :             CALL create_nd_to_2d_mapping(pgrid%nd_index_grid, dims, map1_2d, map2_2d, base=0, col_major=.FALSE.)
     531       445056 :             CALL dbt_get_mapping_info(pgrid%nd_index_grid, dims_2d=pdims_2d)
     532       445056 :             CALL pgrid%mp_comm_2d%create(mp_comm, 2, pdims_2d)
     533              : 
                                    ! optional fixed split: created on the new 2d communicator and a copy is stored in pgrid
     534       445056 :             IF (PRESENT(nsplit)) THEN
     535          780 :                CPASSERT(PRESENT(dimsplit))
     536          780 :                CALL dbt_tas_create_split(info, pgrid%mp_comm_2d, dimsplit, nsplit, opt_nsplit=.FALSE.)
     537          780 :                ALLOCATE (pgrid%tas_split_info, SOURCE=info)
     538              :             END IF
     539              : 
     540              :             ! store number of MPI ranks because we need it for PURE function dbt_max_nblks_local
     541       445056 :             pgrid%nproc = nproc
     542              : 
     543       445056 :             CALL timestop(handle)
     544       445056 :          END SUBROUTINE
     545              : 
     546              : ! **************************************************************************************************
     547              : !> \brief Create a default nd process topology that is consistent with a given 2d topology.
     548              : !>        Purpose: a nd tensor defined on the returned process grid can be represented as a DBM
     549              : !>        matrix with the given 2d topology.
     550              : !>        This is needed to enable contraction of 2 tensors (must have the same 2d process grid).
     551              : !> \param comm_2d communicator with 2-dimensional topology
     552              : !> \param map1_2d which nd-indices map to first matrix index and in which order
     553              : !> \param map2_2d which nd-indices map to second matrix index and in which order
     554              : !> \param dims_nd nd dimensions; if present they are used as given after asserting that their
                        !>                products over map1_2d / map2_2d equal the 2d grid dimensions
                        !> \param dims1_nd grid dimensions for the indices in map1_2d (only read if dims_nd is absent)
                        !> \param dims2_nd grid dimensions for the indices in map2_2d (only read if dims_nd is absent)
     555              : !> \param pdims_2d if comm_2d does not have a cartesian topology associated, can input dimensions
     556              : !>                 with pdims_2d
     557              : !> \param tdims tensor block dimensions. If present, process grid dimensions are created such that
     558              : !>              good load balancing is ensured even if some of the tensor dimensions are small
     559              : !>              (i.e. on the same order or smaller than nproc**(1/ndim) where ndim is the tensor rank)
                        !> \param nsplit forwarded unchanged to dbt_pgrid_create_expert
                        !> \param dimsplit forwarded unchanged to dbt_pgrid_create_expert
     560              : !> \return process grid object (dbt_pgrid_type) created by dbt_pgrid_create_expert on comm_2d
     561              : !> \author Patrick Seewald
     562              : ! **************************************************************************************************
     563        66607 :          FUNCTION dbt_nd_mp_comm(comm_2d, map1_2d, map2_2d, dims_nd, dims1_nd, dims2_nd, pdims_2d, tdims, &
     564              :                                  nsplit, dimsplit)
     565              :             CLASS(mp_comm_type), INTENT(IN)                               :: comm_2d
     566              :             INTEGER, DIMENSION(:), INTENT(IN)                 :: map1_2d, map2_2d
     567              :             INTEGER, DIMENSION(SIZE(map1_2d) + SIZE(map2_2d)), &
     568              :                INTENT(IN), OPTIONAL                           :: dims_nd
     569              :             INTEGER, DIMENSION(SIZE(map1_2d)), INTENT(IN), OPTIONAL :: dims1_nd
     570              :             INTEGER, DIMENSION(SIZE(map2_2d)), INTENT(IN), OPTIONAL :: dims2_nd
     571              :             INTEGER, DIMENSION(2), INTENT(IN), OPTIONAL           :: pdims_2d
     572              :             INTEGER, DIMENSION(SIZE(map1_2d) + SIZE(map2_2d)), &
     573              :                INTENT(IN), OPTIONAL                           :: tdims
     574              :             INTEGER, INTENT(IN), OPTIONAL :: nsplit, dimsplit
     575              :             INTEGER                                           :: ndim1, ndim2
     576              :             INTEGER, DIMENSION(2)                             :: dims_2d
     577              : 
     578       109130 :             INTEGER, DIMENSION(SIZE(map1_2d)) :: dims1_nd_prv
     579       109130 :             INTEGER, DIMENSION(SIZE(map2_2d)) :: dims2_nd_prv
     580       109130 :             INTEGER, DIMENSION(SIZE(map1_2d) + SIZE(map2_2d)) :: dims_nd_prv
     581              :             INTEGER                                           :: handle
     582              :             CHARACTER(LEN=*), PARAMETER :: routineN = 'dbt_nd_mp_comm'
     583              :             TYPE(dbt_pgrid_type)                          :: dbt_nd_mp_comm
     584              : 
     585        54565 :             CALL timeset(routineN, handle)
     586              : 
     587        54565 :             ndim1 = SIZE(map1_2d); ndim2 = SIZE(map2_2d)
     588              : 
                                    ! 2d grid dimensions: explicit pdims_2d wins, otherwise they are read from the
                                    ! cartesian communicator
     589        54565 :             IF (PRESENT(pdims_2d)) THEN
     590        28874 :                dims_2d(:) = pdims_2d
     591              :             ELSE
     592              : ! This branch allows us to call this routine with a plain mp_comm_type without actually requiring an mp_cart_type
     593              : ! In a few cases in CP2K, this prevents erroneous calls to mpi_cart_get with a non-cartesian communicator
     594              :                SELECT TYPE (comm_2d)
     595              :                CLASS IS (mp_cart_type)
     596        77073 :                   dims_2d = comm_2d%num_pe_cart
     597              :                CLASS DEFAULT
     598              :                   CALL cp_abort(__LOCATION__, "If the argument pdims_2d is not given, the "// &
     599            0 :                                 "communicator comm_2d must be of class mp_cart_type.")
     600              :                END SELECT
     601              :             END IF
     602              : 
     603        54565 :             IF (.NOT. PRESENT(dims_nd)) THEN
                                       ! row and column groups are handled independently: each group either takes the
                                       ! caller's dimensions or starts from all zeros and is filled by
                                       ! dbt_mp_dims_create (if tdims is given) or mp_dims_create
     604       244605 :                dims1_nd_prv = 0; dims2_nd_prv = 0
     605        54565 :                IF (PRESENT(dims1_nd)) THEN
     606        17276 :                   dims1_nd_prv(:) = dims1_nd
     607              :                ELSE
     608              : 
     609        48763 :                   IF (PRESENT(tdims)) THEN
     610         1560 :                      CALL dbt_mp_dims_create(dims_2d(1), dims1_nd_prv, tdims(map1_2d))
     611              :                   ELSE
     612        47983 :                      CALL mp_dims_create(dims_2d(1), dims1_nd_prv)
     613              :                   END IF
     614              :                END IF
     615              : 
     616        54565 :                IF (PRESENT(dims2_nd)) THEN
     617        16372 :                   dims2_nd_prv(:) = dims2_nd
     618              :                ELSE
     619        49105 :                   IF (PRESENT(tdims)) THEN
     620         2340 :                      CALL dbt_mp_dims_create(dims_2d(2), dims2_nd_prv, tdims(map2_2d))
     621              :                   ELSE
     622        48325 :                      CALL mp_dims_create(dims_2d(2), dims2_nd_prv)
     623              :                   END IF
     624              :                END IF
                                       ! scatter the per-group dimensions into nd order
     625       129107 :                dims_nd_prv(map1_2d) = dims1_nd_prv
     626       115498 :                dims_nd_prv(map2_2d) = dims2_nd_prv
     627              :             ELSE
                                       ! caller-provided nd dimensions must reproduce the 2d grid
     628            0 :                CPASSERT(PRODUCT(dims_nd(map1_2d)) == dims_2d(1))
     629            0 :                CPASSERT(PRODUCT(dims_nd(map2_2d)) == dims_2d(2))
     630            0 :                dims_nd_prv = dims_nd
     631              :             END IF
     632              : 
                                    ! absent optionals (tdims, nsplit, dimsplit) are passed through as absent
     633              :             CALL dbt_pgrid_create_expert(comm_2d, dims_nd_prv, dbt_nd_mp_comm, &
     634       108350 :                                          tensor_dims=tdims, map1_2d=map1_2d, map2_2d=map2_2d, nsplit=nsplit, dimsplit=dimsplit)
     635              : 
     636        54565 :             CALL timestop(handle)
     637              : 
     638       224842 :          END FUNCTION
     639              : 
     640              : ! **************************************************************************************************
     641              : !> \brief Release the MPI communicator.
                        !> \param mp_comm communicator to release; the routine only calls its free() method
     642              : !> \author Patrick Seewald
     643              : ! **************************************************************************************************
     644            0 :          SUBROUTINE dbt_nd_mp_free(mp_comm)
     645              :             TYPE(mp_comm_type), INTENT(INOUT)                               :: mp_comm
     646              : 
     647            0 :             CALL mp_comm%free()
     648            0 :          END SUBROUTINE dbt_nd_mp_free
     649              : 
     650              : ! **************************************************************************************************
     651              : !> \brief remap a process grid (needed when mapping between tensor and matrix index is changed)
                        !>        The nd grid dimensions of pgrid_in are kept; a new grid object is created on
                        !>        pgrid_in's 2d communicator with the new index mapping.
                        !> \param pgrid_in existing process grid
     652              : !> \param map1_2d new mapping
     653              : !> \param map2_2d new mapping
                        !> \param pgrid_out newly created process grid; it receives a held copy of pgrid_in's split
                        !>                  info only if the new mapping is identical to the old one
     654              : !> \author Patrick Seewald
     655              : ! **************************************************************************************************
     656      1933045 :          SUBROUTINE dbt_pgrid_remap(pgrid_in, map1_2d, map2_2d, pgrid_out)
     657              :             TYPE(dbt_pgrid_type), INTENT(IN) :: pgrid_in
     658              :             INTEGER, DIMENSION(:), INTENT(IN) :: map1_2d, map2_2d
     659              :             TYPE(dbt_pgrid_type), INTENT(OUT) :: pgrid_out
     660       386609 :             INTEGER, DIMENSION(:), ALLOCATABLE :: dims
     661       773218 :             INTEGER, DIMENSION(ndims_mapping_row(pgrid_in%nd_index_grid)) :: map1_2d_old
     662       386609 :             INTEGER, DIMENSION(ndims_mapping_column(pgrid_in%nd_index_grid)) :: map2_2d_old
     663              : 
                                    ! NOTE(review): dims is sized from the new maps but filled from pgrid_in, so the new
                                    ! maps are presumably expected to cover the same number of dimensions - confirm
     664      1159827 :             ALLOCATE (dims(SIZE(map1_2d) + SIZE(map2_2d)))
     665       386609 :             CALL dbt_get_mapping_info(pgrid_in%nd_index_grid, dims_nd=dims, map1_2d=map1_2d_old, map2_2d=map2_2d_old)
     666       386609 :             CALL dbt_pgrid_create_expert(pgrid_in%mp_comm_2d, dims, pgrid_out, map1_2d=map1_2d, map2_2d=map2_2d)
                                    ! the split info is carried over only when the mapping did not change
     667       386609 :             IF (array_eq_i(map1_2d_old, map1_2d) .AND. array_eq_i(map2_2d_old, map2_2d)) THEN
     668       383087 :                IF (ALLOCATED(pgrid_in%tas_split_info)) THEN
     669       365857 :                   ALLOCATE (pgrid_out%tas_split_info, SOURCE=pgrid_in%tas_split_info)
     670       365857 :                   CALL dbt_tas_info_hold(pgrid_out%tas_split_info)
     671              :                END IF
     672              :             END IF
     673       386609 :          END SUBROUTINE
     674              : 
     675              : ! **************************************************************************************************
     676              : !> \brief as mp_environ but for special pgrid type
                        !> \param pgrid process grid to query
                        !> \param dims nd grid dimensions, read from the grid's index mapping
                        !> \param task_coor nd coordinates of the calling process, obtained by converting its position
                        !>                  in the 2d cartesian communicator with get_nd_indices_pgrid
     677              : !> \author Patrick Seewald
     678              : ! **************************************************************************************************
     679       851752 :          SUBROUTINE mp_environ_pgrid(pgrid, dims, task_coor)
     680              :             TYPE(dbt_pgrid_type), INTENT(IN) :: pgrid
     681              :             INTEGER, DIMENSION(ndims_mapping(pgrid%nd_index_grid)), INTENT(OUT) :: dims
     682              :             INTEGER, DIMENSION(ndims_mapping(pgrid%nd_index_grid)), INTENT(OUT) :: task_coor
     683              :             INTEGER, DIMENSION(2)                                          :: task_coor_2d
     684              : 
     685      2555256 :             task_coor_2d = pgrid%mp_comm_2d%mepos_cart
     686       851752 :             CALL dbt_get_mapping_info(pgrid%nd_index_grid, dims_nd=dims)
     687       851752 :             task_coor = get_nd_indices_pgrid(pgrid%nd_index_grid, task_coor_2d)
     688       851752 :          END SUBROUTINE
     689              : 
     690              : ! **************************************************************************************************
     691              : !> \brief Create a tensor distribution.
                        !> \param dist distribution object to create; its reference counter is allocated and set to 1
     692              : !> \param pgrid process grid
     693              : !> \param map1_2d which nd-indices map to first matrix index and in which order
     694              : !> \param map2_2d which nd-indices map to second matrix index and in which order
                        !> \param nd_dist_i distribution vectors for the tensor dimensions
     695              : !> \param own_comm whether distribution should own communicator. If present and .TRUE., pgrid is
                        !>                 used as it is (map1_2d / map2_2d must then equal the mapping stored in pgrid,
                        !>                 otherwise the routine aborts); in all other cases pgrid is remapped first.
     696              : !> \author Patrick Seewald
     697              : ! **************************************************************************************************
     698      3947922 :          SUBROUTINE dbt_distribution_new_expert(dist, pgrid, map1_2d, map2_2d, ${varlist("nd_dist")}$, own_comm)
     699              :             TYPE(dbt_distribution_type), INTENT(OUT)    :: dist
     700              :             TYPE(dbt_pgrid_type), INTENT(IN)            :: pgrid
     701              :             INTEGER, DIMENSION(:), INTENT(IN)               :: map1_2d
     702              :             INTEGER, DIMENSION(:), INTENT(IN)               :: map2_2d
     703              :             INTEGER, DIMENSION(:), INTENT(IN), OPTIONAL     :: ${varlist("nd_dist")}$
     704              :             LOGICAL, INTENT(IN), OPTIONAL                   :: own_comm
     705              :             INTEGER                                         :: ndims
     706       438658 :             TYPE(mp_cart_type)                              :: comm_2d
     707              :             INTEGER, DIMENSION(2)                           :: pdims_2d_check, &
     708              :                                                                pdims_2d
     709      2631948 :             INTEGER, DIMENSION(SIZE(map1_2d) + SIZE(map2_2d)) :: dims, nblks_nd, task_coor
     710       438658 :             TYPE(array_list)                                :: nd_dist
     711      2193290 :             TYPE(nd_to_2d_mapping)                          :: map_blks, map_grid
     712              :             INTEGER                                         :: handle
     713       877316 :             TYPE(dbt_tas_dist_t)                          :: row_dist_obj, col_dist_obj
     714      1315974 :             TYPE(dbt_pgrid_type)                        :: pgrid_prv
     715              :             LOGICAL                                         :: need_pgrid_remap
     716       877316 :             INTEGER, DIMENSION(ndims_mapping_row(pgrid%nd_index_grid)) :: map1_2d_check
     717       438658 :             INTEGER, DIMENSION(ndims_mapping_column(pgrid%nd_index_grid)) :: map2_2d_check
     718              :             CHARACTER(LEN=*), PARAMETER :: routineN = 'dbt_distribution_new_expert'
     719              : 
     720       438658 :             CALL timeset(routineN, handle)
     721       438658 :             ndims = SIZE(map1_2d) + SIZE(map2_2d)
     722       438658 :             CPASSERT(ndims .GE. 2 .AND. ndims .LE. ${maxdim}$)
     723              : 
                                    ! collect the per-dimension distribution vectors; the number of blocks per dimension
                                    ! is taken from the sizes of these vectors
     724      1097344 :             CALL create_array_list(nd_dist, ndims, ${varlist("nd_dist")}$)
     725              : 
     726      1534604 :             nblks_nd(:) = sizes_of_arrays(nd_dist)
     727              : 
                                    ! decide whether pgrid can be used directly (own_comm present and .TRUE.) or has to
                                    ! be remapped to the requested index mapping
     728       438658 :             need_pgrid_remap = .TRUE.
     729       438658 :             IF (PRESENT(own_comm)) THEN
     730        52049 :                CALL dbt_get_mapping_info(pgrid%nd_index_grid, map1_2d=map1_2d_check, map2_2d=map2_2d_check)
     731        52049 :                IF (own_comm) THEN
     732        52049 :                   IF (.NOT. array_eq_i(map1_2d_check, map1_2d) .OR. .NOT. array_eq_i(map2_2d_check, map2_2d)) THEN
     733            0 :                      CPABORT("map1_2d / map2_2d are not consistent with pgrid")
     734              :                   END IF
     735        52049 :                   pgrid_prv = pgrid
     736              :                   need_pgrid_remap = .FALSE.
     737              :                END IF
     738              :             END IF
     739              : 
     740       386609 :             IF (need_pgrid_remap) CALL dbt_pgrid_remap(pgrid, map1_2d, map2_2d, pgrid_prv)
     741              : 
     742              :             ! query the nd grid dimensions (task_coor is not used further in this routine)
     743       438658 :             CALL mp_environ_pgrid(pgrid_prv, dims, task_coor)
     744              : 
     745              :             ! process grid index mapping
     746       438658 :             CALL create_nd_to_2d_mapping(map_grid, dims, map1_2d, map2_2d, base=0, col_major=.FALSE.)
     747              : 
     748              :             ! blk index mapping
     749       438658 :             CALL create_nd_to_2d_mapping(map_blks, nblks_nd, map1_2d, map2_2d)
     750              : 
                                    ! row (1) and column (2) distribution objects of the matrix representation
     751       438658 :             row_dist_obj = dbt_tas_dist_t(nd_dist, map_blks, map_grid, 1)
     752       438658 :             col_dist_obj = dbt_tas_dist_t(nd_dist, map_blks, map_grid, 2)
     753              : 
     754       438658 :             CALL dbt_get_mapping_info(map_grid, dims_2d=pdims_2d)
     755              : 
     756       438658 :             comm_2d = pgrid_prv%mp_comm_2d
     757              : 
                                    ! check that 2d process topology is consistent with nd topology
     758      1315974 :             pdims_2d_check = comm_2d%num_pe_cart
     759      1315974 :             IF (ANY(pdims_2d_check .NE. pdims_2d)) THEN
     760            0 :                CPABORT("inconsistent process grid dimensions")
     761              :             END IF
     762              : 
                                    ! reuse the grid's split info if it has one; otherwise keep a held copy of the
                                    ! split info found in the newly created distribution
     763       438658 :             IF (ALLOCATED(pgrid_prv%tas_split_info)) THEN
     764       365857 :                CALL dbt_tas_distribution_new(dist%dist, comm_2d, row_dist_obj, col_dist_obj, split_info=pgrid_prv%tas_split_info)
     765              :             ELSE
     766        72801 :                CALL dbt_tas_distribution_new(dist%dist, comm_2d, row_dist_obj, col_dist_obj)
     767        72801 :                ALLOCATE (pgrid_prv%tas_split_info, SOURCE=dist%dist%info)
     768        72801 :                CALL dbt_tas_info_hold(pgrid_prv%tas_split_info)
     769              :             END IF
     770              : 
     771       438658 :             dist%nd_dist = nd_dist
     772       438658 :             dist%pgrid = pgrid_prv
     773              : 
                                    ! a fresh distribution starts with exactly one reference
     774       438658 :             ALLOCATE (dist%refcount)
     775       438658 :             dist%refcount = 1
     776       438658 :             CALL timestop(handle)
     777              : 
     778       438658 :          END SUBROUTINE
     779              : 
     780              : ! **************************************************************************************************
     781              : !> \brief Create a tensor distribution.
                        !>        Thin wrapper around dbt_distribution_new_expert that uses the index mapping already
                        !>        stored in pgrid and does not pass own_comm.
                        !> \param dist distribution object to create
     782              : !> \param pgrid process grid
     783              : !> \param nd_dist_i distribution vectors for all tensor dimensions
     784              : !> \author Patrick Seewald
     785              : ! **************************************************************************************************
     786       162018 :          SUBROUTINE dbt_distribution_new(dist, pgrid, ${varlist("nd_dist")}$)
     787              :             TYPE(dbt_distribution_type), INTENT(OUT)    :: dist
     788              :             TYPE(dbt_pgrid_type), INTENT(IN)            :: pgrid
     789              :             INTEGER, DIMENSION(:), INTENT(IN), OPTIONAL     :: ${varlist("nd_dist")}$
     790        36004 :             INTEGER, DIMENSION(ndims_mapping_row(pgrid%nd_index_grid)) :: map1_2d
     791        18002 :             INTEGER, DIMENSION(ndims_mapping_column(pgrid%nd_index_grid)) :: map2_2d
     792              :             INTEGER :: ndims
     793              : 
                                    ! ndims is retrieved here but not used afterwards in this routine
     794        18002 :             CALL dbt_get_mapping_info(pgrid%nd_index_grid, map1_2d=map1_2d, map2_2d=map2_2d, ndim_nd=ndims)
     795              : 
     796        45306 :             CALL dbt_distribution_new_expert(dist, pgrid, map1_2d, map2_2d, ${varlist("nd_dist")}$)
     797              : 
     798        18002 :          END SUBROUTINE
     799              : 
     800              : ! **************************************************************************************************
     801              : !> \brief destroy process grid
                        !> \param pgrid process grid to destroy; its index mapping is always destroyed
     802              : !> \param keep_comm  if .TRUE. communicator is not freed and the split info is left untouched
                        !>                   (defaults to .FALSE. when absent)
     803              : !> \author Patrick Seewald
     804              : ! **************************************************************************************************
     805      1534534 :          SUBROUTINE dbt_pgrid_destroy(pgrid, keep_comm)
     806              :             TYPE(dbt_pgrid_type), INTENT(INOUT) :: pgrid
     807              :             LOGICAL, INTENT(IN), OPTIONAL           :: keep_comm
     808              :             LOGICAL :: keep_comm_prv
     809      1534534 :             IF (PRESENT(keep_comm)) THEN
     810      1089478 :                keep_comm_prv = keep_comm
     811              :             ELSE
     812              :                keep_comm_prv = .FALSE.
     813              :             END IF
     814      1534534 :             IF (.NOT. keep_comm_prv) CALL pgrid%mp_comm_2d%free()
     815      1534534 :             CALL destroy_nd_to_2d_mapping(pgrid%nd_index_grid)
                                    ! the split info is released together with the communicator, never on its own
     816      1534534 :             IF (ALLOCATED(pgrid%tas_split_info) .AND. .NOT. keep_comm_prv) THEN
     817       439438 :                CALL dbt_tas_release_info(pgrid%tas_split_info)
     818       439438 :                DEALLOCATE (pgrid%tas_split_info)
     819              :             END IF
     820      1534534 :          END SUBROUTINE
     821              : 
     822              : ! **************************************************************************************************
     823              : !> \brief Destroy tensor distribution
                        !>        Decrements the reference counter. Only when it reaches 0 the process grid is destroyed
                        !>        including its communicator and the counter is deallocated; otherwise the grid is
                        !>        destroyed with keep_comm=.TRUE.
                        !> \param dist distribution to destroy; aborts if its reference counter is not associated or < 1
     824              : !> \author Patrick Seewald
     825              : ! **************************************************************************************************
     826       438658 :          SUBROUTINE dbt_distribution_destroy(dist)
     827              :             TYPE(dbt_distribution_type), INTENT(INOUT) :: dist
     828              :             INTEGER                                   :: handle
     829              :             CHARACTER(LEN=*), PARAMETER :: routineN = 'dbt_distribution_destroy'
     830              :             LOGICAL :: abort
     831              : 
     832       438658 :             CALL timeset(routineN, handle)
                                    ! NOTE(review): these two destroy calls run before the validity check below
     833       438658 :             CALL dbt_tas_distribution_destroy(dist%dist)
     834       438658 :             CALL destroy_array_list(dist%nd_dist)
     835              : 
                                    ! two-step test so that refcount is only read when the pointer is associated
     836       438658 :             abort = .FALSE.
     837       438658 :             IF (.NOT. ASSOCIATED(dist%refcount)) THEN
     838              :                abort = .TRUE.
     839       438658 :             ELSEIF (dist%refcount < 1) THEN
     840              :                abort = .TRUE.
     841              :             END IF
     842              : 
     843              :             IF (abort) THEN
     844            0 :                CPABORT("can not destroy non-existing tensor distribution")
     845              :             END IF
     846              : 
     847       438658 :             dist%refcount = dist%refcount - 1
     848              : 
     849       438658 :             IF (dist%refcount == 0) THEN
                                       ! last reference: free communicator and split info as well
     850       219075 :                CALL dbt_pgrid_destroy(dist%pgrid)
     851       219075 :                DEALLOCATE (dist%refcount)
     852              :             ELSE
                                       ! other references remain: keep communicator and split info alive
     853       219583 :                CALL dbt_pgrid_destroy(dist%pgrid, keep_comm=.TRUE.)
     854              :             END IF
     855              : 
     856       438658 :             CALL timestop(handle)
     857       438658 :          END SUBROUTINE
     858              : 
     859              : ! **************************************************************************************************
     860              : !> \brief reference counting for distribution
     861              : !>        (only needed for communicator handle that must be freed when no longer needed)
                        !>        Increments the reference counter that dist%refcount points to. Aborts if the
                        !>        counter is not associated or smaller than 1.
                        !> \param dist distribution to hold. It is INTENT(IN); the counter is modified through a local
                        !>             pointer to its target.
     862              : !> \author Patrick Seewald
     863              : ! **************************************************************************************************
     864       219583 :          SUBROUTINE dbt_distribution_hold(dist)
     865              :             TYPE(dbt_distribution_type), INTENT(IN) :: dist
     866              :             INTEGER, POINTER                            :: ref
     867              : 
                                    ! test association first (same two-step check as in dbt_distribution_destroy):
                                    ! reading the value of a disassociated pointer is invalid
                                    IF (.NOT. ASSOCIATED(dist%refcount)) THEN
                                       CPABORT("can not hold non-existing tensor distribution")
     868       219583 :             ELSEIF (dist%refcount < 1) THEN
     869            0 :                CPABORT("can not hold non-existing tensor distribution")
     870              :             END IF
     871       219583 :             ref => dist%refcount
     872       219583 :             ref = ref + 1
     873       219583 :          END SUBROUTINE
     874              : 
     875              : ! **************************************************************************************************
     876              : !> \brief get distribution from tensor
                        !>        The matrix distribution is queried from the tensor's matrix representation; pgrid
                        !>        and nd_dist are assigned from the tensor. The reference counter of the result is
                        !>        associated with the tensor's counter; no reference is added here.
                        !> \param tensor tensor whose distribution is returned
     877              : !> \return distribution
     878              : !> \author Patrick Seewald
     879              : ! **************************************************************************************************
     880       160626 :          FUNCTION dbt_distribution(tensor)
     881              :             TYPE(dbt_type), INTENT(IN)  :: tensor
     882              :             TYPE(dbt_distribution_type) :: dbt_distribution
     883              : 
     884       160626 :             CALL dbt_tas_get_info(tensor%matrix_rep, distribution=dbt_distribution%dist)
     885       160626 :             dbt_distribution%pgrid = tensor%pgrid
     886       160626 :             dbt_distribution%nd_dist = tensor%nd_dist
                                    ! was a self-association (refcount => its own refcount), i.e. a no-op that never
                                    ! attached the tensor's counter to the result
                                    ! NOTE(review): relies on dbt_type having a refcount pointer component - confirm
     887              :             dbt_distribution%refcount => tensor%refcount
     888      1124382 :          END FUNCTION
     889              : 
     890              : ! **************************************************************************************************
     891              : !> \author Patrick Seewald
     892              : ! **************************************************************************************************
     893      2415413 :          SUBROUTINE dbt_distribution_remap(dist_in, map1_2d, map2_2d, dist_out)
                                    ! Create dist_out from dist_in using a different nd-to-2d mapping.
                                    ! The rank is taken as SIZE(map1_2d) + SIZE(map2_2d). For the matching rank
                                    ! the per-dimension distribution arrays are extracted from dist_in%nd_dist
                                    ! and passed, together with dist_in%pgrid and the new maps, to
                                    ! dbt_distribution_new_expert.
                                    ! NOTE(review): if the rank matches none of the generated branches,
                                    ! dist_out (INTENT(OUT)) is never assigned in this routine.
     894              :             TYPE(dbt_distribution_type), INTENT(IN)    :: dist_in
     895              :             INTEGER, DIMENSION(:), INTENT(IN) :: map1_2d, map2_2d
     896              :             TYPE(dbt_distribution_type), INTENT(OUT)    :: dist_out
     897       219583 :             INTEGER, DIMENSION(:), ALLOCATABLE :: ${varlist("dist")}$
     898              :             INTEGER :: ndims
     899       219583 :             ndims = SIZE(map1_2d) + SIZE(map2_2d)
                                    ! one branch per supported rank is generated by the fypp loop below
     900              :             #:for ndim in range(1, maxdim+1)
     901       658445 :                IF (ndims == ${ndim}$) THEN
     902       219583 :                   CALL get_arrays(dist_in%nd_dist, ${varlist("dist", nmax=ndim)}$)
     903       219583 :                   CALL dbt_distribution_new_expert(dist_out, dist_in%pgrid, map1_2d, map2_2d, ${varlist("dist", nmax=ndim)}$)
     904              :                END IF
     905              :             #:endfor
     906       219583 :          END SUBROUTINE
     907              : 
     908              : ! **************************************************************************************************
     909              : !> \brief create a tensor.
     910              : !>        For performance, the arguments map1_2d and map2_2d (controlling matrix representation of
     911              : !>        tensor) should be consistent with the contraction to be performed (see documentation
     912              : !>        of dbt_contract).
     913              : !> \param map1_2d which nd-indices to map to first 2d index and in which order
     914              : !> \param map2_2d which nd-indices to map to second 2d index and in which order
     915              : !> \param blk_size_i blk sizes in each dimension
     916              : !> \author Patrick Seewald
     917              : ! **************************************************************************************************
     918      2415413 :          SUBROUTINE dbt_create_new(tensor, name, dist, map1_2d, map2_2d, &
     919       219583 :                                    ${varlist("blk_size")}$)
                                    ! Build a new tensor from block sizes, a distribution and a 2d mapping.
                                    ! The rank is SIZE(map1_2d) + SIZE(map2_2d). The distribution dist is not
                                    ! used directly: it is remapped to (map1_2d, map2_2d) into the local
                                    ! dist_new, from which the matrix representation, the per-dimension
                                    ! distribution, the process grid and the reference counter are taken.
     920              :             TYPE(dbt_type), INTENT(OUT)                   :: tensor
     921              :             CHARACTER(len=*), INTENT(IN)                      :: name
     922              :             TYPE(dbt_distribution_type), INTENT(INOUT)    :: dist
     923              :             INTEGER, DIMENSION(:), INTENT(IN)                 :: map1_2d
     924              :             INTEGER, DIMENSION(:), INTENT(IN)                 :: map2_2d
     925              :             INTEGER, DIMENSION(:), INTENT(IN), OPTIONAL       :: ${varlist("blk_size")}$
     926              :             INTEGER                                           :: ndims
     927              :             INTEGER(KIND=int_8), DIMENSION(2)                             :: dims_2d
     928      1097915 :             INTEGER, DIMENSION(SIZE(map1_2d) + SIZE(map2_2d)) :: dims, pdims, task_coor
     929       439166 :             TYPE(dbt_tas_blk_size_t)                        :: col_blk_size_obj, row_blk_size_obj
     930      1976247 :             TYPE(dbt_distribution_type)                   :: dist_new
     931       219583 :             TYPE(array_list)                                  :: blk_size, blks_local
     932       658749 :             TYPE(nd_to_2d_mapping)                            :: map
     933              :             INTEGER                                   :: handle
     934              :             CHARACTER(LEN=*), PARAMETER :: routineN = 'dbt_create_new'
     935       219583 :             INTEGER, DIMENSION(:), ALLOCATABLE              :: ${varlist("blks_local")}$
     936       219583 :             INTEGER, DIMENSION(:), ALLOCATABLE              :: ${varlist("dist")}$
     937              :             INTEGER                                         :: iblk_count, iblk
     938       219583 :             INTEGER, DIMENSION(:), ALLOCATABLE              :: nblks_local, nfull_local
     939              : 
     940       219583 :             CALL timeset(routineN, handle)
     941       219583 :             ndims = SIZE(map1_2d) + SIZE(map2_2d)
                                    ! gather the per-dimension block size arrays in one list; dims receives the
                                    ! sizes of these arrays and defines the block-index mapping (nd_index_blk)
     942       549180 :             CALL create_array_list(blk_size, ndims, ${varlist("blk_size")}$)
     943       768318 :             dims = sizes_of_arrays(blk_size)
     944              : 
     945       219583 :             CALL create_nd_to_2d_mapping(map, dims, map1_2d, map2_2d)
     946       219583 :             CALL dbt_get_mapping_info(map, dims_2d_i8=dims_2d)
     947              : 
                                    ! block size objects for the two matrix dimensions (1 and 2)
     948       219583 :             row_blk_size_obj = dbt_tas_blk_size_t(blk_size, map, 1)
     949       219583 :             col_blk_size_obj = dbt_tas_blk_size_t(blk_size, map, 2)
     950              : 
     951       219583 :             CALL dbt_distribution_remap(dist, map1_2d, map2_2d, dist_new)
     952              : 
     953      1537081 :             ALLOCATE (tensor%matrix_rep)
     954              :             CALL dbt_tas_create(matrix=tensor%matrix_rep, &
     955              :                                 name=TRIM(name)//" matrix", &
     956              :                                 dist=dist_new%dist, &
     957              :                                 row_blk_size=row_blk_size_obj, &
     958       219583 :                                 col_blk_size=col_blk_size_obj)
     959              : 
                                    ! the matrix was allocated here, so dbt_destroy has to free it
     960       219583 :             tensor%owns_matrix = .TRUE.
     961              : 
     962       219583 :             tensor%nd_index_blk = map
     963       219583 :             tensor%name = name
     964              : 
     965       219583 :             CALL dbt_tas_finalize(tensor%matrix_rep)
     966       219583 :             CALL destroy_nd_to_2d_mapping(map)
     967              : 
     968              :             ! map element-wise tensor index
     969       219583 :             CALL create_nd_to_2d_mapping(map, sum_of_arrays(blk_size), map1_2d, map2_2d)
     970       219583 :             tensor%nd_index = map
     971       219583 :             tensor%blk_sizes = blk_size
     972              : 
                                    ! process grid dimensions and the coordinates of this process in the grid
     973       219583 :             CALL mp_environ_pgrid(dist_new%pgrid, pdims, task_coor)
     974              : 
     975              :             #:for ndim in range(1, maxdim+1)
     976       439166 :                IF (ndims == ${ndim}$) THEN
     977       219583 :                   CALL get_arrays(dist_new%nd_dist, ${varlist("dist", nmax=ndim)}$)
     978              :                END IF
     979              :             #:endfor
     980              : 
                                    ! For every dimension collect the blocks whose distribution entry equals
                                    ! the coordinate of this process (task_coor): nblks_local counts them,
                                    ! blks_local_<idim> lists their indices and nfull_local sums their sizes.
     981       658749 :             ALLOCATE (nblks_local(ndims))
     982       439166 :             ALLOCATE (nfull_local(ndims))
     983       768318 :             nfull_local(:) = 0
     984              :             #:for idim in range(1, maxdim+1)
     985       768014 :                IF (ndims .GE. ${idim}$) THEN
     986      5126457 :                   nblks_local(${idim}$) = COUNT(dist_${idim}$ == task_coor(${idim}$))
     987      1644714 :                   ALLOCATE (blks_local_${idim}$ (nblks_local(${idim}$)))
     988       548735 :                   iblk_count = 0
     989      5126457 :                   DO iblk = 1, SIZE(dist_${idim}$)
     990      5126457 :                      IF (dist_${idim}$ (iblk) == task_coor(${idim}$)) THEN
     991      4104430 :                         iblk_count = iblk_count + 1
     992      4104430 :                         blks_local_${idim}$ (iblk_count) = iblk
     993      4104430 :                         nfull_local(${idim}$) = nfull_local(${idim}$) + blk_size_${idim}$ (iblk)
     994              :                      END IF
     995              :                   END DO
     996              :                END IF
     997              :             #:endfor
     998              : 
     999              :             #:for ndim in range(1, maxdim+1)
    1000       438862 :                IF (ndims == ${ndim}$) THEN
    1001       219583 :                   CALL create_array_list(blks_local, ${ndim}$, ${varlist("blks_local", nmax=ndim)}$)
    1002              :                END IF
    1003              :             #:endfor
    1004              : 
    1005       658749 :             ALLOCATE (tensor%nblks_local(ndims))
    1006       439166 :             ALLOCATE (tensor%nfull_local(ndims))
    1007       768318 :             tensor%nblks_local(:) = nblks_local
    1008       768318 :             tensor%nfull_local(:) = nfull_local
    1009              : 
    1010       219583 :             tensor%blks_local = blks_local
    1011              : 
    1012       219583 :             tensor%nd_dist = dist_new%nd_dist
    1013       219583 :             tensor%pgrid = dist_new%pgrid
    1014              : 
                                    ! Hand the reference counter of dist_new over to the tensor: the counter is
                                    ! incremented first, then aliased by tensor%refcount, and only then
                                    ! dist_new is destroyed. Presumably this keeps the resources shared with
                                    ! dist_new alive for the tensor - confirm against dbt_distribution_destroy.
    1015       219583 :             CALL dbt_distribution_hold(dist_new)
    1016       219583 :             tensor%refcount => dist_new%refcount
    1017       219583 :             CALL dbt_distribution_destroy(dist_new)
    1018              : 
    1019       219583 :             CALL array_offsets(tensor%blk_sizes, tensor%blk_offsets)
    1020              : 
    1021       219583 :             tensor%valid = .TRUE.
    1022       219583 :             CALL timestop(handle)
    1023       658749 :          END SUBROUTINE
    1024              : 
    1025              : ! **************************************************************************************************
    1026              : !> \brief reference counting for tensors
    1027              : !>        (only needed for communicator handle that must be freed when no longer needed)
    1028              : !> \author Patrick Seewald
    1029              : ! **************************************************************************************************
    1030       869895 :          SUBROUTINE dbt_hold(tensor)
                                    ! Increment the reference counter of tensor by one.
                                    ! Aborts through CPABORT when the counter is smaller than 1.
                                    ! tensor is INTENT(IN): the counter is a pointer component and is modified
                                    ! through the local pointer ref.
                                    ! NOTE(review): tensor%refcount is read without an ASSOCIATED check
                                    ! (dbt_destroy does check association) - confirm callers guarantee it.
    1031              :             TYPE(dbt_type), INTENT(IN) :: tensor
    1032              :             INTEGER, POINTER :: ref
    1033              : 
    1034       869895 :             IF (tensor%refcount < 1) THEN
    1035            0 :                CPABORT("can not hold non-existing tensor")
    1036              :             END IF
    1037       869895 :             ref => tensor%refcount
    1038       869895 :             ref = ref + 1
    1039              : 
    1040       869895 :          END SUBROUTINE
    1041              : 
    1042              : ! **************************************************************************************************
    1043              : !> \brief how many tensor dimensions are mapped to matrix row
    1044              : !> \author Patrick Seewald
    1045              : ! **************************************************************************************************
    1046      2024906 :          PURE FUNCTION ndims_matrix_row(tensor)
                                    ! Result of ndims_mapping_row applied to the block-index mapping
                                    ! (tensor%nd_index_blk). Note that the result kind is int_8.
    1047              :             TYPE(dbt_type), INTENT(IN) :: tensor
    1048              :             INTEGER(int_8) :: ndims_matrix_row
    1049              : 
    1050      2024906 :             ndims_matrix_row = ndims_mapping_row(tensor%nd_index_blk)
    1051              : 
    1052      2024906 :          END FUNCTION
    1053              : 
    1054              : ! **************************************************************************************************
    1055              : !> \brief how many tensor dimensions are mapped to matrix column
    1056              : !> \author Patrick Seewald
    1057              : ! **************************************************************************************************
    1058      2024906 :          PURE FUNCTION ndims_matrix_column(tensor)
                                    ! Result of ndims_mapping_column applied to the block-index mapping
                                    ! (tensor%nd_index_blk). Note that the result kind is int_8.
    1059              :             TYPE(dbt_type), INTENT(IN) :: tensor
    1060              :             INTEGER(int_8) :: ndims_matrix_column
    1061              : 
    1062      2024906 :             ndims_matrix_column = ndims_mapping_column(tensor%nd_index_blk)
    1063      2024906 :          END FUNCTION
    1064              : 
    1065              : ! **************************************************************************************************
    1066              : !> \brief tensor rank
    1067              : !> \author Patrick Seewald
    1068              : ! **************************************************************************************************
    1069    106168763 :          PURE FUNCTION ndims_tensor(tensor)
                                    ! Rank of the tensor, read from the element-index mapping
                                    ! (tensor%nd_index%ndim_nd). Used in this file as the extent of
                                    ! explicit-shape dummy arguments of other routines.
    1070              :             TYPE(dbt_type), INTENT(IN) :: tensor
    1071              :             INTEGER                        :: ndims_tensor
    1072              : 
    1073    106168763 :             ndims_tensor = tensor%nd_index%ndim_nd
    1074    106168763 :          END FUNCTION
    1075              : 
    1076              : ! **************************************************************************************************
    1077              : !> \brief tensor dimensions
    1078              : !> \author Patrick Seewald
    1079              : ! **************************************************************************************************
    1080         3580 :          SUBROUTINE dims_tensor(tensor, dims)
                                    ! Return in dims the dimensions stored in the element-index mapping
                                    ! (tensor%nd_index%dims_nd); dims has one entry per tensor dimension.
                                    ! Asserts that the tensor is valid.
    1081              :             TYPE(dbt_type), INTENT(IN)              :: tensor
    1082              :             INTEGER, DIMENSION(ndims_tensor(tensor)), &
    1083              :                INTENT(OUT)                              :: dims
    1084              : 
    1085         3580 :             CPASSERT(tensor%valid)
    1086        17176 :             dims = tensor%nd_index%dims_nd
    1087         3580 :          END SUBROUTINE
    1088              : 
    1089              : ! **************************************************************************************************
    1090              : !> \brief create a tensor from template
    1091              : !> \author Patrick Seewald
    1092              : ! **************************************************************************************************
    1093      2633643 :          SUBROUTINE dbt_create_template(tensor_in, tensor, name, dist, map1_2d, map2_2d)
                                    ! Create tensor with the structure of tensor_in.
                                    ! If any of dist, map1_2d, map2_2d is present, the tensor is built from
                                    ! scratch through dbt_create using the block sizes of tensor_in. Otherwise
                                    ! the matrix representation is created from the one of tensor_in and the
                                    ! index mappings, block sizes/offsets, distribution and process grid are
                                    ! copied; the reference counter is shared with tensor_in.
                                    ! name defaults to the name of tensor_in.
                                    ! NOTE(review): map1_2d and map2_2d are only honoured when BOTH are
                                    ! present; if just one is passed, the mapping of tensor_in is used.
    1094              :             TYPE(dbt_type), INTENT(INOUT)      :: tensor_in
    1095              :             TYPE(dbt_type), INTENT(OUT)        :: tensor
    1096              :             CHARACTER(len=*), INTENT(IN), OPTIONAL :: name
    1097              :             TYPE(dbt_distribution_type), &
    1098              :                INTENT(INOUT), OPTIONAL             :: dist
    1099              :             INTEGER, DIMENSION(:), INTENT(IN), &
    1100              :                OPTIONAL                            :: map1_2d, map2_2d
    1101              :             INTEGER                                :: handle
    1102              :             CHARACTER(LEN=*), PARAMETER :: routineN = 'dbt_create_template'
    1103       292627 :             INTEGER, DIMENSION(:), ALLOCATABLE     :: ${varlist("bsize")}$
    1104       292627 :             INTEGER, DIMENSION(:), ALLOCATABLE     :: map1_2d_prv, map2_2d_prv
    1105              :             CHARACTER(len=default_string_length)   :: name_prv
    1106      2048389 :             TYPE(dbt_distribution_type)        :: dist_prv
    1107              : 
    1108       292627 :             CALL timeset(routineN, handle)
    1109              : 
    1110       292627 :             IF (PRESENT(dist) .OR. PRESENT(map1_2d) .OR. PRESENT(map2_2d)) THEN
    1111              :                ! need to create matrix representation from scratch
    1112          340 :                IF (PRESENT(dist)) THEN
    1113            0 :                   dist_prv = dist
    1114              :                ELSE
    1115          340 :                   dist_prv = dbt_distribution(tensor_in)
    1116              :                END IF
    1117          340 :                IF (PRESENT(map1_2d) .AND. PRESENT(map2_2d)) THEN
    1118         1360 :                   ALLOCATE (map1_2d_prv, source=map1_2d)
    1119         1700 :                   ALLOCATE (map2_2d_prv, source=map2_2d)
    1120              :                ELSE
                                          ! fall back to the mapping stored in the template
    1121            0 :                   ALLOCATE (map1_2d_prv(ndims_matrix_row(tensor_in)))
    1122            0 :                   ALLOCATE (map2_2d_prv(ndims_matrix_column(tensor_in)))
    1123            0 :                   CALL dbt_get_mapping_info(tensor_in%nd_index_blk, map1_2d=map1_2d_prv, map2_2d=map2_2d_prv)
    1124              :                END IF
    1125          340 :                IF (PRESENT(name)) THEN
    1126            0 :                   name_prv = name
    1127              :                ELSE
    1128          340 :                   name_prv = tensor_in%name
    1129              :                END IF
    1130              : 
                                       ! one branch per supported rank: pass the block sizes of the template
    1131              :                #:for ndim in range(1, maxdim+1)
    1132         1020 :                   IF (ndims_tensor(tensor_in) == ${ndim}$) THEN
    1133          340 :                      CALL get_arrays(tensor_in%blk_sizes, ${varlist("bsize", nmax=ndim)}$)
    1134              :                      CALL dbt_create(tensor, name_prv, dist_prv, map1_2d_prv, map2_2d_prv, &
    1135          340 :                                      ${varlist("bsize", nmax=ndim)}$)
    1136              :                   END IF
    1137              :                #:endfor
    1138              :             ELSE
    1139              :                ! create matrix representation from template
    1140      1461435 :                ALLOCATE (tensor%matrix_rep)
    1141       292287 :                IF (.NOT. PRESENT(name)) THEN
    1142              :                   CALL dbt_tas_create(tensor_in%matrix_rep, tensor%matrix_rep, &
    1143       276613 :                                       name=TRIM(tensor_in%name)//" matrix")
    1144              :                ELSE
    1145        15674 :                   CALL dbt_tas_create(tensor_in%matrix_rep, tensor%matrix_rep, name=TRIM(name)//" matrix")
    1146              :                END IF
    1147       292287 :                tensor%owns_matrix = .TRUE.
    1148       292287 :                CALL dbt_tas_finalize(tensor%matrix_rep)
    1149              : 
    1150       292287 :                tensor%nd_index_blk = tensor_in%nd_index_blk
    1151       292287 :                tensor%nd_index = tensor_in%nd_index
    1152       292287 :                tensor%blk_sizes = tensor_in%blk_sizes
    1153       292287 :                tensor%blk_offsets = tensor_in%blk_offsets
    1154       292287 :                tensor%nd_dist = tensor_in%nd_dist
    1155       292287 :                tensor%blks_local = tensor_in%blks_local
    1156       876861 :                ALLOCATE (tensor%nblks_local(ndims_tensor(tensor_in)))
    1157      1082478 :                tensor%nblks_local(:) = tensor_in%nblks_local
    1158       876861 :                ALLOCATE (tensor%nfull_local(ndims_tensor(tensor_in)))
    1159      1082478 :                tensor%nfull_local(:) = tensor_in%nfull_local
    1160       292287 :                tensor%pgrid = tensor_in%pgrid
    1161              : 
                                       ! share the reference counter of the template and count the new
                                       ! tensor as an additional holder
    1162       292287 :                tensor%refcount => tensor_in%refcount
    1163       292287 :                CALL dbt_hold(tensor)
    1164              : 
    1165       292287 :                tensor%valid = .TRUE.
    1166       292287 :                IF (PRESENT(name)) THEN
    1167        15674 :                   tensor%name = name
    1168              :                ELSE
    1169       276613 :                   tensor%name = tensor_in%name
    1170              :                END IF
    1171              :             END IF
    1172       292627 :             CALL timestop(handle)
    1173       585254 :          END SUBROUTINE
    1174              : 
    1175              : ! **************************************************************************************************
    1176              : !> \brief Create 2-rank tensor from matrix.
    1177              : !> \author Patrick Seewald
    1178              : ! **************************************************************************************************
    1179       263580 :          SUBROUTINE dbt_create_matrix(matrix_in, tensor, order, name)
                                    ! Create a rank-2 tensor whose distribution and block sizes are taken from
                                    ! the DBCSR matrix matrix_in.
                                    ! order selects which tensor dimension is mapped to the matrix row
                                    ! (order(1)) and column (order(2)); it defaults to [1, 2].
                                    ! name defaults to the name reported by dbcsr_get_info for matrix_in.
                                    ! Only the tensor structure is set up here; no call that copies matrix
                                    ! data is made in this routine.
    1180              :             TYPE(dbcsr_type), INTENT(IN)                :: matrix_in
    1181              :             TYPE(dbt_type), INTENT(OUT)        :: tensor
    1182              :             INTEGER, DIMENSION(2), INTENT(IN), OPTIONAL :: order
    1183              :             CHARACTER(len=*), INTENT(IN), OPTIONAL      :: name
    1184              : 
    1185              :             CHARACTER(len=default_string_length)        :: name_in
    1186              :             INTEGER, DIMENSION(2)                       :: order_in
    1187              :             TYPE(mp_comm_type)                          :: comm_2d
    1188              :             TYPE(dbcsr_distribution_type)               :: matrix_dist
    1189       237222 :             TYPE(dbt_distribution_type)                 :: dist
    1190        52716 :             INTEGER, DIMENSION(:), POINTER              :: row_blk_size, col_blk_size
    1191        52716 :             INTEGER, DIMENSION(:), POINTER              :: col_dist, row_dist
    1192              :             INTEGER                                   :: handle, comm_2d_handle
    1193              :             CHARACTER(LEN=*), PARAMETER :: routineN = 'dbt_create_matrix'
    1194        79074 :             TYPE(dbt_pgrid_type)                  :: comm_nd
    1195              :             INTEGER, DIMENSION(2)                     :: pdims_2d
    1196              : 
    1197        26358 :             CALL timeset(routineN, handle)
    1198              : 
    1199        26358 :             NULLIFY (row_blk_size, col_blk_size, col_dist, row_dist)
    1200        26358 :             IF (PRESENT(name)) THEN
    1201          850 :                name_in = name
    1202              :             ELSE
    1203        25508 :                CALL dbcsr_get_info(matrix_in, name=name_in)
    1204              :             END IF
    1205              : 
    1206        26358 :             IF (PRESENT(order)) THEN
    1207            0 :                order_in = order
    1208              :             ELSE
    1209        26358 :                order_in = [1, 2]
    1210              :             END IF
    1211              : 
                                    ! query communicator handle, row/column distribution and process grid
                                    ! dimensions of the matrix and build the tensor process grid from them
    1212        26358 :             CALL dbcsr_get_info(matrix_in, distribution=matrix_dist)
    1213              :             CALL dbcsr_distribution_get(matrix_dist, group=comm_2d_handle, row_dist=row_dist, col_dist=col_dist, &
    1214        26358 :                                         nprows=pdims_2d(1), npcols=pdims_2d(2))
    1215        26358 :             CALL comm_2d%set_handle(comm_2d_handle)
    1216        79074 :             comm_nd = dbt_nd_mp_comm(comm_2d, [order_in(1)], [order_in(2)], pdims_2d=pdims_2d)
    1217              : 
    1218              :             CALL dbt_distribution_new_expert( &
    1219              :                dist, &
    1220              :                comm_nd, &
    1221              :                [order_in(1)], [order_in(2)], &
    1222        79074 :                row_dist, col_dist, own_comm=.TRUE.)
    1223              : 
    1224        26358 :             CALL dbcsr_get_info(matrix_in, row_blk_size=row_blk_size, col_blk_size=col_blk_size)
    1225              : 
    1226              :             CALL dbt_create_new(tensor, name_in, dist, &
    1227              :                                 [order_in(1)], [order_in(2)], &
    1228              :                                 row_blk_size, &
    1229        79074 :                                 col_blk_size)
    1230              : 
                                    ! the temporary distribution is released once the tensor exists
    1231        26358 :             CALL dbt_distribution_destroy(dist)
    1232        26358 :             CALL timestop(handle)
    1233       131790 :          END SUBROUTINE
    1234              : 
    1235              : ! **************************************************************************************************
    1236              : !> \brief Destroy a tensor
    1237              : !> \author Patrick Seewald
    1238              : ! **************************************************************************************************
    1239      1089478 :          SUBROUTINE dbt_destroy(tensor)
                                    ! Release everything held by tensor and mark it invalid.
                                    ! The matrix representation is destroyed and deallocated only when the
                                    ! tensor owns it; otherwise the pointer is just nullified.
                                    ! The reference counter is decremented; when it reaches zero the process
                                    ! grid is destroyed and the counter is deallocated, otherwise the process
                                    ! grid is destroyed with keep_comm=.TRUE..
                                    ! Aborts through CPABORT if the counter is not associated or smaller than 1.
                                    ! NOTE(review): nblks_local/nfull_local are deallocated before the counter
                                    ! check, so an already destroyed tensor may fail there first - confirm.
    1240              :             TYPE(dbt_type), INTENT(INOUT) :: tensor
    1241              :             INTEGER                                   :: handle
    1242              :             CHARACTER(LEN=*), PARAMETER :: routineN = 'dbt_destroy'
    1243              :             LOGICAL :: abort
    1244              : 
    1245      1089478 :             CALL timeset(routineN, handle)
    1246      1089478 :             IF (tensor%owns_matrix) THEN
    1247       511870 :                CALL dbt_tas_destroy(tensor%matrix_rep)
    1248       511870 :                DEALLOCATE (tensor%matrix_rep)
    1249              :             ELSE
    1250       577608 :                NULLIFY (tensor%matrix_rep)
    1251              :             END IF
    1252      1089478 :             tensor%owns_matrix = .FALSE.
    1253              : 
    1254      1089478 :             CALL destroy_nd_to_2d_mapping(tensor%nd_index_blk)
    1255      1089478 :             CALL destroy_nd_to_2d_mapping(tensor%nd_index)
    1256              :             !CALL destroy_nd_to_2d_mapping(tensor%nd_index_grid)
    1257      1089478 :             CALL destroy_array_list(tensor%blk_sizes)
    1258      1089478 :             CALL destroy_array_list(tensor%blk_offsets)
    1259      1089478 :             CALL destroy_array_list(tensor%nd_dist)
    1260      1089478 :             CALL destroy_array_list(tensor%blks_local)
    1261              : 
    1262      1089478 :             DEALLOCATE (tensor%nblks_local, tensor%nfull_local)
    1263              : 
                                    ! association is tested first so that the counter value is only read
                                    ! when the pointer is associated
    1264      1089478 :             abort = .FALSE.
    1265      1089478 :             IF (.NOT. ASSOCIATED(tensor%refcount)) THEN
    1266              :                abort = .TRUE.
    1267      1089478 :             ELSEIF (tensor%refcount < 1) THEN
    1268              :                abort = .TRUE.
    1269              :             END IF
    1270              : 
    1271              :             IF (abort) THEN
    1272            0 :                CPABORT("can not destroy non-existing tensor")
    1273              :             END IF
    1274              : 
    1275      1089478 :             tensor%refcount = tensor%refcount - 1
    1276              : 
    1277      1089478 :             IF (tensor%refcount == 0) THEN
    1278       219583 :                CALL dbt_pgrid_destroy(tensor%pgrid)
    1279              :                !CALL tensor%comm_2d%free()
    1280              :                !CALL tensor%comm_nd%free()
    1281       219583 :                DEALLOCATE (tensor%refcount)
    1282              :             ELSE
    1283       869895 :                CALL dbt_pgrid_destroy(tensor%pgrid, keep_comm=.TRUE.)
    1284              :             END IF
    1285              : 
    1286      1089478 :             tensor%valid = .FALSE.
    1287      1089478 :             tensor%name = ""
    1288      1089478 :             CALL timestop(handle)
    1289      1089478 :          END SUBROUTINE
    1290              : 
    1291              : ! **************************************************************************************************
    1292              : !> \brief tensor block dimensions
    1293              : !> \author Patrick Seewald
    1294              : ! **************************************************************************************************
    1295       910802 :          SUBROUTINE blk_dims_tensor(tensor, dims)
                                    ! Return in dims the dimensions stored in the block-index mapping
                                    ! (tensor%nd_index_blk%dims_nd); dims has one entry per tensor dimension.
                                    ! Asserts that the tensor is valid.
    1296              :             TYPE(dbt_type), INTENT(IN)              :: tensor
    1297              :             INTEGER, DIMENSION(ndims_tensor(tensor)), &
    1298              :                INTENT(OUT)                              :: dims
    1299              : 
    1300       910802 :             CPASSERT(tensor%valid)
    1301      3406986 :             dims = tensor%nd_index_blk%dims_nd
    1302       910802 :          END SUBROUTINE
    1303              : 
    1304              : ! **************************************************************************************************
    1305              : !> \brief Size of tensor block
    1306              : !> \author Patrick Seewald
    1307              : ! **************************************************************************************************
    1308     27080711 :          SUBROUTINE dbt_blk_sizes(tensor, ind, blk_size)
                                    ! Return in blk_size the entries of tensor%blk_sizes selected by the
                                    ! nd block index ind (one entry per tensor dimension).
                                    ! NOTE(review): unlike dbt_blk_offsets this routine does not assert
                                    ! tensor%valid - possibly intentional, confirm.
    1309              :             TYPE(dbt_type), INTENT(IN)              :: tensor
    1310              :             INTEGER, DIMENSION(ndims_tensor(tensor)), &
    1311              :                INTENT(IN)                               :: ind
    1312              :             INTEGER, DIMENSION(ndims_tensor(tensor)), &
    1313              :                INTENT(OUT)                              :: blk_size
    1314              : 
    1315     27080711 :             blk_size(:) = get_array_elements(tensor%blk_sizes, ind)
    1316     27080711 :          END SUBROUTINE
    1317              : 
    1318              : ! **************************************************************************************************
    1319              : !> \brief offset of tensor block
    1320              : !> \param ind block index
    1321              : !> \param blk_offset block offset
    1322              : !> \author Patrick Seewald
    1323              : ! **************************************************************************************************
    1324            0 :          SUBROUTINE dbt_blk_offsets(tensor, ind, blk_offset)
                                    ! Return in blk_offset the entries of tensor%blk_offsets selected by the
                                    ! nd block index ind (one entry per tensor dimension).
                                    ! Asserts that the tensor is valid.
    1325              :             TYPE(dbt_type), INTENT(IN)              :: tensor
    1326              :             INTEGER, DIMENSION(ndims_tensor(tensor)), &
    1327              :                INTENT(IN)                               :: ind
    1328              :             INTEGER, DIMENSION(ndims_tensor(tensor)), &
    1329              :                INTENT(OUT)                              :: blk_offset
    1330              : 
    1331            0 :             CPASSERT(tensor%valid)
    1332            0 :             blk_offset(:) = get_array_elements(tensor%blk_offsets, ind)
    1333            0 :          END SUBROUTINE
    1334              : 
    1335              : ! **************************************************************************************************
    1336              : !> \brief Generalization of block_get_stored_coordinates for tensors (maps the nd index to 2d first).
    1337              : !> \author Patrick Seewald
    1338              : ! **************************************************************************************************
    1339     15220556 :          SUBROUTINE dbt_get_stored_coordinates(tensor, ind_nd, processor)
    1340              :             TYPE(dbt_type), INTENT(IN)               :: tensor
    1341              :             INTEGER, DIMENSION(ndims_tensor(tensor)), &
    1342              :                INTENT(IN)                                :: ind_nd
    1343              :             INTEGER, INTENT(OUT)                         :: processor
    1344              :             ! ind_nd: nd block index (one entry per tensor dimension); processor: result of the tas query
    1345              :             INTEGER(KIND=int_8), DIMENSION(2)                        :: ind_2d
    1346              :             ! translate the nd block index into the 2d index pair of the matrix representation, then query it
    1347      7610278 :             ind_2d(:) = get_2d_indices_tensor(tensor%nd_index_blk, ind_nd)
    1348      7610278 :             CALL dbt_tas_get_stored_coordinates(tensor%matrix_rep, ind_2d(1), ind_2d(2), processor)
    1349      7610278 :          END SUBROUTINE
    1350              : 
    1351              : ! **************************************************************************************************
    1352              : !> \author Patrick Seewald
    1353              : ! **************************************************************************************************
    1354        19410 :          SUBROUTINE dbt_pgrid_create(mp_comm, dims, pgrid, tensor_dims)
    1355              :             CLASS(mp_comm_type), INTENT(IN) :: mp_comm
    1356              :             INTEGER, DIMENSION(:), INTENT(INOUT) :: dims
    1357              :             TYPE(dbt_pgrid_type), INTENT(OUT) :: pgrid
    1358              :             INTEGER, DIMENSION(:), INTENT(IN), OPTIONAL :: tensor_dims
    1359         3882 :             INTEGER, DIMENSION(:), ALLOCATABLE :: map1_2d, map2_2d
    1360              :             INTEGER :: i, ndims
    1361              :             ! Create a process grid with a default nd-to-2d mapping by calling dbt_pgrid_create_expert.
    1362         3882 :             ndims = SIZE(dims)
    1363              :             ! the first ndims/2 grid dimensions go to map1_2d, the remaining ones to map2_2d
    1364        11646 :             ALLOCATE (map1_2d(ndims/2))
    1365        11646 :             ALLOCATE (map2_2d(ndims - ndims/2))
    1366        15536 :             map1_2d(:) = (/(i, i=1, SIZE(map1_2d))/)
    1367        21248 :             map2_2d(:) = (/(i, i=SIZE(map1_2d) + 1, SIZE(map1_2d) + SIZE(map2_2d))/)
    1368              :             ! NOTE(review): dims is INTENT(INOUT), so dbt_pgrid_create_expert may modify it -- confirm
    1369         6264 :             CALL dbt_pgrid_create_expert(mp_comm, dims, pgrid, map1_2d, map2_2d, tensor_dims)
    1370              : 
    1371         3882 :          END SUBROUTINE
    1372              : 
    1373              : ! **************************************************************************************************
    1374              : !> \brief Freeze current split factor so contraction never changes it (no-op without split info).
    1375              : !> \author Patrick Seewald
    1376              : ! **************************************************************************************************
    1377            0 :          SUBROUTINE dbt_pgrid_set_strict_split(pgrid)
    1378              :             TYPE(dbt_pgrid_type), INTENT(INOUT) :: pgrid
    1379            0 :             IF (ALLOCATED(pgrid%tas_split_info)) CALL dbt_tas_set_strict_split(pgrid%tas_split_info)
    1380            0 :          END SUBROUTINE
    1381              : 
    1382              : ! **************************************************************************************************
    1383              : !> \brief Change dimensions of an existing process grid by building a new grid that replaces it.
    1384              : !> \param pgrid process grid to be changed (destroyed and overwritten with the newly created grid)
    1385              : !> \param pdims new process grid dimensions, should all be set > 0 (this is asserted)
    1386              : !> \author Patrick Seewald
    1387              : ! **************************************************************************************************
    1388            0 :          SUBROUTINE dbt_pgrid_change_dims(pgrid, pdims)
    1389              :             TYPE(dbt_pgrid_type), INTENT(INOUT) :: pgrid
    1390              :             INTEGER, DIMENSION(:), INTENT(INOUT)    :: pdims
    1391            0 :             TYPE(dbt_pgrid_type)                :: pgrid_tmp
    1392              :             INTEGER                                 :: nsplit, dimsplit
    1393            0 :             INTEGER, DIMENSION(ndims_mapping_row(pgrid%nd_index_grid)) :: map1_2d
    1394            0 :             INTEGER, DIMENSION(ndims_mapping_column(pgrid%nd_index_grid)) :: map2_2d
    1395            0 :             TYPe(nd_to_2d_mapping)                  :: nd_index_grid
    1396              :             INTEGER, DIMENSION(2)                   :: pdims_2d
    1397              :             ! NOTE(review): tas_split_info is used without the ALLOCATED check of dbt_pgrid_set_strict_split
    1398            0 :             CPASSERT(ALL(pdims > 0))
    1399            0 :             CALL dbt_tas_get_split_info(pgrid%tas_split_info, nsplit=nsplit, split_rowcol=dimsplit)
    1400            0 :             CALL dbt_get_mapping_info(pgrid%nd_index_grid, map1_2d=map1_2d, map2_2d=map2_2d)
    1401            0 :             CALL create_nd_to_2d_mapping(nd_index_grid, pdims, map1_2d, map2_2d, base=0, col_major=.FALSE.)
    1402            0 :             CALL dbt_get_mapping_info(nd_index_grid, dims_2d=pdims_2d)
    1403            0 :             IF (MOD(pdims_2d(dimsplit), nsplit) == 0) THEN ! old split still divides the new 2d dimension
    1404              :                CALL dbt_pgrid_create_expert(pgrid%mp_comm_2d, pdims, pgrid_tmp, map1_2d=map1_2d, map2_2d=map2_2d, &
    1405            0 :                                             nsplit=nsplit, dimsplit=dimsplit)
    1406              :             ELSE ! old split does not fit: nsplit and dimsplit are not passed on
    1407            0 :                CALL dbt_pgrid_create_expert(pgrid%mp_comm_2d, pdims, pgrid_tmp, map1_2d=map1_2d, map2_2d=map2_2d)
    1408              :             END IF ! NOTE(review): the local nd_index_grid is not released here -- confirm none is needed
    1409            0 :             CALL dbt_pgrid_destroy(pgrid)
    1410            0 :             pgrid = pgrid_tmp ! intrinsic assignment of the temporary grid to the caller's grid
    1411            0 :          END SUBROUTINE
    1412              : 
    1413              : ! **************************************************************************************************
    1414              : !> \brief As block_filter: filters the tensor with threshold eps via its tas matrix representation.
    1415              : !> \author Patrick Seewald
    1416              : ! **************************************************************************************************
    1417       198471 :          SUBROUTINE dbt_filter(tensor, eps)
    1418              :             TYPE(dbt_type), INTENT(INOUT)    :: tensor
    1419              :             REAL(dp), INTENT(IN)                :: eps
    1420              :             ! thin wrapper: the filtering itself is done by dbt_tas_filter on tensor%matrix_rep
    1421       198471 :             CALL dbt_tas_filter(tensor%matrix_rep, eps)
    1422              : 
    1423       198471 :          END SUBROUTINE
    1424              : 
    1425              : ! **************************************************************************************************
    1426              : !> \brief Local number of blocks along dimension idim; 0 if idim exceeds the rank of the tensor.
    1427              : !> \author Patrick Seewald
    1428              : ! **************************************************************************************************
    1429       405374 :          PURE FUNCTION dbt_nblks_local(tensor, idim)
    1430              :             TYPE(dbt_type), INTENT(IN) :: tensor
    1431              :             INTEGER, INTENT(IN) :: idim
    1432              :             INTEGER :: dbt_nblks_local
    1433              :             ! returning 0 beyond the tensor rank lets dbt_get_info use this function as an array bound
    1434       405374 :             IF (idim > ndims_tensor(tensor)) THEN
    1435              :                dbt_nblks_local = 0
    1436              :             ELSE
    1437       405374 :                dbt_nblks_local = tensor%nblks_local(idim)
    1438              :             END IF
    1439              :             ! NOTE(review): idim < 1 is not checked here
    1440       405374 :          END FUNCTION
    1441              : 
    1442              : ! **************************************************************************************************
    1443              : !> \brief Total number of blocks along dimension idim; 0 if idim exceeds the rank of the tensor.
    1444              : !> \author Patrick Seewald
    1445              : ! **************************************************************************************************
    1446      2696430 :          PURE FUNCTION dbt_nblks_total(tensor, idim)
    1447              :             TYPE(dbt_type), INTENT(IN) :: tensor
    1448              :             INTEGER, INTENT(IN) :: idim
    1449              :             INTEGER :: dbt_nblks_total
    1450              :             ! returning 0 beyond the tensor rank lets dbt_get_info use this function as an array bound
    1451      2696430 :             IF (idim > ndims_tensor(tensor)) THEN
    1452              :                dbt_nblks_total = 0
    1453              :             ELSE
    1454      2101248 :                dbt_nblks_total = tensor%nd_index_blk%dims_nd(idim)
    1455              :             END IF ! NOTE(review): idim < 1 is not checked here
    1456      2696430 :          END FUNCTION
    1457              : 
    1458              : ! **************************************************************************************************
    1459              : !> \brief As block_get_info but for tensors; every output argument is optional.
    1460              : !> \param nblks_total number of blocks along each dimension
    1461              : !> \param nfull_total number of elements along each dimension
    1462              : !> \param nblks_local local number of blocks along each dimension
    1463              : !> \param nfull_local local number of elements along each dimension
    1464              : !> \param my_ploc process coordinates in process grid
    1465              : !> \param pdims process grid dimensions
    1466              : !> \param blks_local_${idim}$ local blocks along dimension ${idim}$
    1467              : !> \param proc_dist_${idim}$ distribution along dimension ${idim}$
    1468              : !> \param blk_size_${idim}$ block sizes along dimension ${idim}$
    1469              : !> \param blk_offset_${idim}$ block offsets along dimension ${idim}$
    1470              : !> \param distribution distribution object, obtained from dbt_distribution(tensor)
    1471              : !> \param name name of tensor, copied from tensor%name
    1472              : !> \author Patrick Seewald
    1473              : ! **************************************************************************************************
    1474            0 :          SUBROUTINE dbt_get_info(tensor, nblks_total, &
    1475              :                                  nfull_total, &
    1476       150942 :                                  nblks_local, &
    1477       150942 :                                  nfull_local, &
    1478              :                                  pdims, &
    1479              :                                  my_ploc, &
    1480              :                                  ${varlist("blks_local")}$, &
    1481              :                                  ${varlist("proc_dist")}$, &
    1482              :                                  ${varlist("blk_size")}$, &
    1483              :                                  ${varlist("blk_offset")}$, &
    1484              :                                  distribution, &
    1485              :                                  name)
    1486              :             TYPE(dbt_type), INTENT(IN) :: tensor
    1487              :             INTEGER, INTENT(OUT), OPTIONAL, DIMENSION(ndims_tensor(tensor)) :: nblks_total
    1488              :             INTEGER, INTENT(OUT), OPTIONAL, DIMENSION(ndims_tensor(tensor)) :: nfull_total
    1489              :             INTEGER, INTENT(OUT), OPTIONAL, DIMENSION(ndims_tensor(tensor)) :: nblks_local
    1490              :             INTEGER, INTENT(OUT), OPTIONAL, DIMENSION(ndims_tensor(tensor)) :: nfull_local
    1491              :             INTEGER, INTENT(OUT), OPTIONAL, DIMENSION(ndims_tensor(tensor)) :: my_ploc
    1492              :             INTEGER, INTENT(OUT), OPTIONAL, DIMENSION(ndims_tensor(tensor)) :: pdims
    1493              :             #:for idim in range(1, maxdim+1)
    1494              :                INTEGER, DIMENSION(dbt_nblks_local(tensor, ${idim}$)), INTENT(OUT), OPTIONAL :: blks_local_${idim}$
    1495              :                INTEGER, DIMENSION(dbt_nblks_total(tensor, ${idim}$)), INTENT(OUT), OPTIONAL :: proc_dist_${idim}$
    1496              :                INTEGER, DIMENSION(dbt_nblks_total(tensor, ${idim}$)), INTENT(OUT), OPTIONAL :: blk_size_${idim}$
    1497              :                INTEGER, DIMENSION(dbt_nblks_total(tensor, ${idim}$)), INTENT(OUT), OPTIONAL :: blk_offset_${idim}$
    1498              :             #:endfor
    1499              :             TYPE(dbt_distribution_type), INTENT(OUT), OPTIONAL    :: distribution
    1500              :             CHARACTER(len=*), INTENT(OUT), OPTIONAL                   :: name
    1501       847769 :             INTEGER, DIMENSION(ndims_tensor(tensor))                  :: pdims_tmp, my_ploc_tmp
    1502              :             ! only the quantities whose optional argument is present are filled in
    1503       847769 :             IF (PRESENT(nblks_total)) CALL dbt_get_mapping_info(tensor%nd_index_blk, dims_nd=nblks_total)
    1504       847769 :             IF (PRESENT(nfull_total)) CALL dbt_get_mapping_info(tensor%nd_index, dims_nd=nfull_total)
    1505      1253143 :             IF (PRESENT(nblks_local)) nblks_local(:) = tensor%nblks_local
    1506      1253143 :             IF (PRESENT(nfull_local)) nfull_local(:) = tensor%nfull_local
    1507              :             ! grid dimensions and own coordinates come from one mp_environ_pgrid call into local temporaries
    1508       847769 :             IF (PRESENT(my_ploc) .OR. PRESENT(pdims)) CALL mp_environ_pgrid(tensor%pgrid, pdims_tmp, my_ploc_tmp)
    1509       556518 :             IF (PRESENT(my_ploc)) my_ploc = my_ploc_tmp
    1510      1253749 :             IF (PRESENT(pdims)) pdims = pdims_tmp
    1511              :             ! generated per dimension up to maxdim; sections for dimensions beyond the tensor rank are skipped
    1512              :             #:for idim in range(1, maxdim+1)
    1513      3092286 :                IF (${idim}$ <= ndims_tensor(tensor)) THEN
    1514      2244569 :                   IF (PRESENT(blks_local_${idim}$)) CALL get_ith_array(tensor%blks_local, ${idim}$, &
    1515              :                                                                        dbt_nblks_local(tensor, ${idim}$), &
    1516       405374 :                                                                        blks_local_${idim}$)
    1517      2244569 :                   IF (PRESENT(proc_dist_${idim}$)) CALL get_ith_array(tensor%nd_dist, ${idim}$, &
    1518              :                                                                       dbt_nblks_total(tensor, ${idim}$), &
    1519       407512 :                                                                       proc_dist_${idim}$)
    1520      2244569 :                   IF (PRESENT(blk_size_${idim}$)) CALL get_ith_array(tensor%blk_sizes, ${idim}$, &
    1521              :                                                                      dbt_nblks_total(tensor, ${idim}$), &
    1522       472200 :                                                                      blk_size_${idim}$)
    1523      2244569 :                   IF (PRESENT(blk_offset_${idim}$)) CALL get_ith_array(tensor%blk_offsets, ${idim}$, &
    1524              :                                                                        dbt_nblks_total(tensor, ${idim}$), &
    1525           62 :                                                                        blk_offset_${idim}$)
    1526              :                END IF
    1527              :             #:endfor
    1528              : 
    1529       847769 :             IF (PRESENT(distribution)) distribution = dbt_distribution(tensor)
    1530       847769 :             IF (PRESENT(name)) name = tensor%name
    1531              : 
    1532       998711 :          END SUBROUTINE
    1533              : 
    1534              : ! **************************************************************************************************
    1535              : !> \brief As block_get_num_blocks: get number of local blocks (from the tas matrix representation)
    1536              : !> \author Patrick Seewald
    1537              : ! **************************************************************************************************
    1538       557728 :          PURE FUNCTION dbt_get_num_blocks(tensor) RESULT(num_blocks)
    1539              :             TYPE(dbt_type), INTENT(IN)    :: tensor
    1540              :             INTEGER                           :: num_blocks ! default INTEGER kind; the total count uses int_8
    1541       557728 :             num_blocks = dbt_tas_get_num_blocks(tensor%matrix_rep)
    1542       557728 :          END FUNCTION
    1543              : 
    1544              : ! **************************************************************************************************
    1545              : !> \brief Get total number of blocks as INTEGER(int_8), from the tas matrix representation
    1546              : !> \author Patrick Seewald
    1547              : ! **************************************************************************************************
    1548       153404 :          FUNCTION dbt_get_num_blocks_total(tensor) RESULT(num_blocks)
    1549              :             TYPE(dbt_type), INTENT(IN)    :: tensor
    1550              :             INTEGER(KIND=int_8)               :: num_blocks
    1551       153404 :             num_blocks = dbt_tas_get_num_blocks_total(tensor%matrix_rep) ! not PURE, unlike the local count
    1552       153404 :          END FUNCTION
    1553              : 
    1554              : ! **************************************************************************************************
    1555              : !> \brief Clear tensor (s.t. it does not contain any blocks) by clearing its tas matrix representation
    1556              : !> \author Patrick Seewald
    1557              : ! **************************************************************************************************
    1558      1002360 :          SUBROUTINE dbt_clear(tensor)
    1559              :             TYPE(dbt_type), INTENT(INOUT) :: tensor
    1560              :             ! thin wrapper: only tensor%matrix_rep is passed on, no other component is touched here
    1561      1002360 :             CALL dbt_tas_clear(tensor%matrix_rep)
    1562      1002360 :          END SUBROUTINE
    1563              : 
    1564              : ! **************************************************************************************************
    1565              : !> \brief Finalize tensor, as block_finalize. This should be taken care of internally in DBT
    1566              : !>        tensors, there should not be any need to call this routine outside of DBT tensors.
    1567              : !> \author Patrick Seewald
    1568              : ! **************************************************************************************************
    1569       994502 :          SUBROUTINE dbt_finalize(tensor)
    1570              :             TYPE(dbt_type), INTENT(INOUT) :: tensor
    1571       994502 :             CALL dbt_tas_finalize(tensor%matrix_rep) ! thin wrapper around the tas matrix representation
    1572       994502 :          END SUBROUTINE
    1573              : 
    1574              : ! **************************************************************************************************
    1575              : !> \brief As block_scale: scale the tensor by alpha via dbm_scale on the underlying matrix
    1576              : !> \author Patrick Seewald
    1577              : ! **************************************************************************************************
    1578        48203 :          SUBROUTINE dbt_scale(tensor, alpha)
    1579              :             TYPE(dbt_type), INTENT(INOUT) :: tensor
    1580              :             REAL(dp), INTENT(IN) :: alpha
    1581        48203 :             CALL dbm_scale(tensor%matrix_rep%matrix, alpha) ! acts on the matrix inside the tas wrapper
    1582        48203 :          END SUBROUTINE
    1583              : 
    1584              : ! **************************************************************************************************
    1585              : !> \author Patrick Seewald
    1586              : ! **************************************************************************************************
    1587       150942 :          PURE FUNCTION dbt_get_nze(tensor)
    1588              :             TYPE(dbt_type), INTENT(IN) :: tensor
    1589              :             INTEGER                        :: dbt_get_nze ! presumably the local nonzero count -- confirm
    1590       150942 :             dbt_get_nze = dbt_tas_get_nze(tensor%matrix_rep) ! forwarded to the tas matrix representation
    1591       150942 :          END FUNCTION
    1592              : 
    1593              : ! **************************************************************************************************
    1594              : !> \author Patrick Seewald
    1595              : ! **************************************************************************************************
    1596       290286 :          FUNCTION dbt_get_nze_total(tensor)
    1597              :             TYPE(dbt_type), INTENT(IN) :: tensor
    1598              :             INTEGER(KIND=int_8)            :: dbt_get_nze_total ! presumably the global nonzero count -- confirm
    1599       290286 :             dbt_get_nze_total = dbt_tas_get_nze_total(tensor%matrix_rep) ! not PURE, unlike dbt_get_nze
    1600       290286 :          END FUNCTION
    1601              : 
    1602              : ! **************************************************************************************************
    1603              : !> \brief Block size of block with index ind along dimension idim; 0 if idim exceeds the tensor rank.
    1604              : !> \author Patrick Seewald
    1605              : ! **************************************************************************************************
    1606            0 :          PURE FUNCTION dbt_blk_size(tensor, ind, idim)
    1607              :             TYPE(dbt_type), INTENT(IN) :: tensor
    1608              :             INTEGER, DIMENSION(ndims_tensor(tensor)), &
    1609              :                INTENT(IN) :: ind
    1610              :             INTEGER, INTENT(IN) :: idim
    1611            0 :             INTEGER, DIMENSION(ndims_tensor(tensor)) :: blk_size
    1612              :             INTEGER :: dbt_blk_size
    1613              :             ! NOTE(review): idim < 1 is not checked here
    1614            0 :             IF (idim > ndims_tensor(tensor)) THEN
    1615              :                dbt_blk_size = 0
    1616              :             ELSE
    1617            0 :                blk_size(:) = get_array_elements(tensor%blk_sizes, ind) ! sizes for all dimensions are fetched
    1618            0 :                dbt_blk_size = blk_size(idim) ! only the requested dimension is returned
    1619              :             END IF
    1620            0 :          END FUNCTION
    1621              : 
    1622              : ! **************************************************************************************************
    1623              : !> \brief returns an estimate of maximum number of local blocks in tensor
    1624              : !>        (irrespective of the actual number of currently present blocks)
    1625              : !>        this estimate is based on the following assumption: tensor data is dense and
    1626              : !>        load balancing is within a factor of 2
    1627              : !> \author Patrick Seewald
    1628              : ! **************************************************************************************************
    1629            0 :          PURE FUNCTION dbt_max_nblks_local(tensor) RESULT(blk_count)
    1630              :             TYPE(dbt_type), INTENT(IN) :: tensor
    1631              :             INTEGER :: blk_count, nproc
    1632            0 :             INTEGER, DIMENSION(ndims_tensor(tensor)) :: bdims
    1633              :             INTEGER(int_8) :: blk_count_total
    1634              :             INTEGER, PARAMETER :: max_load_imbalance = 2
    1635              :             ! bdims: number of blocks along each dimension, taken from the block index mapping
    1636            0 :             CALL dbt_get_mapping_info(tensor%nd_index_blk, dims_nd=bdims)
    1637              :             ! the product is formed in int_8
    1638            0 :             blk_count_total = PRODUCT(INT(bdims, int_8))
    1639              : 
    1640              :             ! can not call an MPI routine due to PURE
    1641            0 :             nproc = tensor%pgrid%nproc
    1642              :             ! NOTE(review): the integer division by nproc truncates before the factor 2 is applied
    1643            0 :             blk_count = INT(blk_count_total/nproc*max_load_imbalance)
    1644              :             ! NOTE(review): INT() narrows int_8 to default INTEGER -- confirm the estimate always fits
    1645            0 :          END FUNCTION
    1646              : 
    1647              : ! **************************************************************************************************
    1648              : !> \brief get a load-balanced and randomized distribution along one tensor dimension
    1649              : !> \param nblk number of blocks (along one tensor dimension)
    1650              : !> \param nproc number of processes (along one process grid dimension)
    1651              : !> \param blk_size block sizes (array of length nblk)
    1652              : !> \param dist distribution (array of length nblk, computed by dbt_tas_default_distvec)
    1653              : !> \author Patrick Seewald
    1654              : ! **************************************************************************************************
    1655        94779 :          SUBROUTINE dbt_default_distvec(nblk, nproc, blk_size, dist)
    1656              :             INTEGER, INTENT(IN)                                :: nblk
    1657              :             INTEGER, INTENT(IN)                                :: nproc
    1658              :             INTEGER, DIMENSION(nblk), INTENT(IN)                :: blk_size
    1659              :             INTEGER, DIMENSION(nblk), INTENT(OUT)               :: dist
    1660              :             ! thin wrapper: all arguments are forwarded unchanged to dbt_tas_default_distvec
    1661        94779 :             CALL dbt_tas_default_distvec(nblk, nproc, blk_size, dist)
    1662        94779 :          END SUBROUTINE
    1663              : 
    1664              : ! **************************************************************************************************
    1665              : !> \author Patrick Seewald
    1666              : ! **************************************************************************************************
    1667       604924 :          SUBROUTINE dbt_copy_contraction_storage(tensor_in, tensor_out)
    1668              :             TYPE(dbt_type), INTENT(IN) :: tensor_in
    1669              :             TYPE(dbt_type), INTENT(INOUT) :: tensor_out
    1670       604924 :             TYPE(dbt_contraction_storage), ALLOCATABLE :: tensor_storage_tmp
    1671       604924 :             TYPE(dbt_tas_mm_storage), ALLOCATABLE :: tas_storage_tmp
    1672              :             ! Copy batched state and storage from tensor_in to tensor_out: temporaries, then move_alloc.
    1673       604924 :             IF (tensor_in%matrix_rep%do_batched > 0) THEN
    1674       142964 :                ALLOCATE (tas_storage_tmp, SOURCE=tensor_in%matrix_rep%mm_storage)
    1675              :                ! transfer data for batched contraction
    1676       142964 :                IF (ALLOCATED(tensor_out%matrix_rep%mm_storage)) DEALLOCATE (tensor_out%matrix_rep%mm_storage)
    1677       142964 :                CALL move_alloc(tas_storage_tmp, tensor_out%matrix_rep%mm_storage)
    1678              :             END IF ! the batched state below is set whether or not mm_storage was copied
    1679              :             CALL dbt_tas_set_batched_state(tensor_out%matrix_rep, state=tensor_in%matrix_rep%do_batched, &
    1680       604924 :                                            opt_grid=tensor_in%matrix_rep%has_opt_pgrid)
    1681       604924 :             IF (ALLOCATED(tensor_in%contraction_storage)) THEN
    1682       420820 :                ALLOCATE (tensor_storage_tmp, SOURCE=tensor_in%contraction_storage)
    1683              :             END IF ! if tensor_in has no contraction storage, tensor_out is left without one as well
    1684       604924 :             IF (ALLOCATED(tensor_out%contraction_storage)) DEALLOCATE (tensor_out%contraction_storage)
    1685       604924 :             IF (ALLOCATED(tensor_storage_tmp)) CALL move_alloc(tensor_storage_tmp, tensor_out%contraction_storage)
    1686              : 
    1687       604924 :          END SUBROUTINE
    1688              : 
    1689     21056275 :       END MODULE
        

Generated by: LCOV version 2.0-1