LCOV - code coverage report
Current view: top level - src/dbt/tas - dbt_tas_mm.F (source / functions) Hit Total Coverage
Test: CP2K Regtests (git:b279b6b) Lines: 688 742 92.7 %
Date: 2024-04-24 07:13:09 Functions: 15 15 100.0 %

          Line data    Source code
       1             : !--------------------------------------------------------------------------------------------------!
       2             : !   CP2K: A general program to perform molecular dynamics simulations                              !
       3             : !   Copyright 2000-2024 CP2K developers group <https://cp2k.org>                                   !
       4             : !                                                                                                  !
       5             : !   SPDX-License-Identifier: GPL-2.0-or-later                                                      !
       6             : !--------------------------------------------------------------------------------------------------!
       7             : 
       8             : ! **************************************************************************************************
       9             : !> \brief Matrix multiplication for tall-and-skinny matrices.
      10             : !>        This uses the k-split (non-recursive) CARMA algorithm that is communication-optimal
      11             : !>        as long as the two smaller dimensions have the same size.
      12             : !>        Submatrices are obtained by splitting a dimension of the process grid. Multiplication of
      13             : !>        submatrices uses the DBM Cannon algorithm. Because the sparsity pattern of the result
      14             : !>        matrix is unknown, parameters (group sizes and process grid dimensions) cannot be
      15             : !>        derived from the matrix dimensions and need to be set manually.
      16             : !> \author Patrick Seewald
      17             : ! **************************************************************************************************
      18             : MODULE dbt_tas_mm
      19             :    USE dbm_api,                         ONLY: &
      20             :         dbm_add, dbm_clear, dbm_copy, dbm_create, dbm_create_from_template, dbm_distribution_new, &
      21             :         dbm_distribution_obj, dbm_distribution_release, dbm_get_col_block_sizes, &
      22             :         dbm_get_distribution, dbm_get_name, dbm_get_nze, dbm_get_row_block_sizes, dbm_multiply, &
      23             :         dbm_redistribute, dbm_release, dbm_scale, dbm_type, dbm_zero
      24             :    USE dbt_tas_base,                    ONLY: &
      25             :         dbt_tas_clear, dbt_tas_copy, dbt_tas_create, dbt_tas_destroy, dbt_tas_distribution_new, &
      26             :         dbt_tas_filter, dbt_tas_get_info, dbt_tas_get_nze_total, dbt_tas_info, &
      27             :         dbt_tas_iterator_blocks_left, dbt_tas_iterator_next_block, dbt_tas_iterator_start, &
      28             :         dbt_tas_iterator_stop, dbt_tas_nblkcols_total, dbt_tas_nblkrows_total, dbt_tas_put_block, &
      29             :         dbt_tas_reserve_blocks
      30             :    USE dbt_tas_global,                  ONLY: dbt_tas_blk_size_one,&
      31             :                                               dbt_tas_default_distvec,&
      32             :                                               dbt_tas_dist_arb,&
      33             :                                               dbt_tas_dist_arb_default,&
      34             :                                               dbt_tas_dist_cyclic,&
      35             :                                               dbt_tas_distribution,&
      36             :                                               dbt_tas_rowcol_data
      37             :    USE dbt_tas_io,                      ONLY: dbt_tas_write_dist,&
      38             :                                               dbt_tas_write_matrix_info,&
      39             :                                               dbt_tas_write_split_info,&
      40             :                                               prep_output_unit
      41             :    USE dbt_tas_reshape_ops,             ONLY: dbt_tas_merge,&
      42             :                                               dbt_tas_replicate,&
      43             :                                               dbt_tas_reshape
      44             :    USE dbt_tas_split,                   ONLY: &
      45             :         accept_pgrid_dims, colsplit, dbt_tas_create_split, dbt_tas_get_split_info, &
      46             :         dbt_tas_info_hold, dbt_tas_mp_comm, dbt_tas_release_info, default_nsplit_accept_ratio, &
      47             :         rowsplit
      48             :    USE dbt_tas_types,                   ONLY: dbt_tas_distribution_type,&
      49             :                                               dbt_tas_iterator,&
      50             :                                               dbt_tas_split_info,&
      51             :                                               dbt_tas_type
      52             :    USE dbt_tas_util,                    ONLY: array_eq,&
      53             :                                               swap
      54             :    USE kinds,                           ONLY: default_string_length,&
      55             :                                               dp,&
      56             :                                               int_8
      57             :    USE message_passing,                 ONLY: mp_cart_type
      58             : #include "../../base/base_uses.f90"
      59             : 
      60             :    IMPLICIT NONE
      61             :    PRIVATE
      62             : 
      63             :    CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'dbt_tas_mm'
      64             : 
      65             :    PUBLIC :: &
      66             :       dbt_tas_multiply, &
      67             :       dbt_tas_batched_mm_init, &
      68             :       dbt_tas_batched_mm_finalize, &
      69             :       dbt_tas_set_batched_state, &
      70             :       dbt_tas_batched_mm_complete
      71             : 
      72             : CONTAINS
      73             : 
      74             : ! **************************************************************************************************
      75             : !> \brief tall-and-skinny matrix-matrix multiplication. Undocumented dummy arguments are identical
      76             : !>        to arguments of dbm_multiply (see dbm_mm, dbm_multiply_generic).
      77             : !> \param transa ...
      78             : !> \param transb ...
      79             : !> \param transc ...
      80             : !> \param alpha ...
      81             : !> \param matrix_a ...
      82             : !> \param matrix_b ...
      83             : !> \param beta ...
      84             : !> \param matrix_c ...
      85             : !> \param optimize_dist Whether distribution should be optimized internally. In the current
      86             : !>                      implementation this guarantees optimal parameters only for dense matrices.
      87             : !> \param split_opt optionally return split info containing optimal grid and split parameters.
      88             : !>                  This can be used to choose optimal process grids for subsequent matrix
      89             : !>                  multiplications with matrices of similar shape and sparsity.
      90             : !> \param filter_eps ...
      91             : !> \param flop ...
      92             : !> \param move_data_a memory optimization: move data to matrix_c such that matrix_a is empty on return
      93             : !>                   (for internal use only)
      94             : !> \param move_data_b memory optimization: move data to matrix_c such that matrix_b is empty on return
      95             : !>                   (for internal use only)
      96             : !> \param retain_sparsity ...
      97             : !> \param simple_split ...
      98             : !> \param unit_nr unit number for logging output
      99             : !> \param log_verbose only for testing: verbose output
     100             : !> \author Patrick Seewald
     101             : ! **************************************************************************************************
     102      518945 :    RECURSIVE SUBROUTINE dbt_tas_multiply(transa, transb, transc, alpha, matrix_a, matrix_b, beta, matrix_c, &
     103             :                                          optimize_dist, split_opt, filter_eps, flop, move_data_a, &
     104             :                                          move_data_b, retain_sparsity, simple_split, unit_nr, log_verbose)
     105             : 
     106             :       LOGICAL, INTENT(IN)                                :: transa, transb, transc
     107             :       REAL(dp), INTENT(IN)                               :: alpha
     108             :       TYPE(dbt_tas_type), INTENT(INOUT), TARGET          :: matrix_a, matrix_b
     109             :       REAL(dp), INTENT(IN)                               :: beta
     110             :       TYPE(dbt_tas_type), INTENT(INOUT), TARGET          :: matrix_c
     111             :       LOGICAL, INTENT(IN), OPTIONAL                      :: optimize_dist
     112             :       TYPE(dbt_tas_split_info), INTENT(OUT), OPTIONAL    :: split_opt
     113             :       REAL(KIND=dp), INTENT(IN), OPTIONAL                :: filter_eps
     114             :       INTEGER(KIND=int_8), INTENT(OUT), OPTIONAL         :: flop
     115             :       LOGICAL, INTENT(IN), OPTIONAL                      :: move_data_a, move_data_b, &
     116             :                                                             retain_sparsity, simple_split
     117             :       INTEGER, INTENT(IN), OPTIONAL                      :: unit_nr
     118             :       LOGICAL, INTENT(IN), OPTIONAL                      :: log_verbose
     119             : 
     120             :       CHARACTER(LEN=*), PARAMETER                        :: routineN = 'dbt_tas_multiply'
     121             : 
     122             :       INTEGER :: batched_repl, handle, handle2, handle3, handle4, max_mm_dim, max_mm_dim_batched, &
     123             :          nsplit, nsplit_batched, nsplit_opt, numproc, split_a, split_b, split_c, split_rc, &
     124             :          unit_nr_prv
     125             :       INTEGER(KIND=int_8)                                :: nze_a, nze_b, nze_c, nze_c_sum
     126             :       INTEGER(KIND=int_8), DIMENSION(2)                  :: dims_a, dims_b, dims_c
     127             :       INTEGER(KIND=int_8), DIMENSION(3)                  :: dims
     128             :       INTEGER, DIMENSION(2)                              :: pdims, pdims_sub
     129             :       LOGICAL :: do_batched, move_a, move_b, new_a, new_b, new_c, nodata_3, opt_pgrid, &
     130             :          simple_split_prv, tr_case, transa_prv, transb_prv, transc_prv
     131             :       REAL(KIND=dp)                                      :: filter_eps_prv
     132             :       TYPE(dbm_type)                                     :: matrix_a_mm, matrix_b_mm, matrix_c_mm
     133     2135523 :       TYPE(dbt_tas_split_info)                           :: info, info_a, info_b, info_c
     134             :       TYPE(dbt_tas_type), POINTER                        :: matrix_a_rep, matrix_a_rs, matrix_b_rep, &
     135             :                                                             matrix_b_rs, matrix_c_rep, matrix_c_rs
     136      251238 :       TYPE(mp_cart_type)                                 :: comm_tmp, mp_comm, mp_comm_group, &
     137      125619 :                                                             mp_comm_mm, mp_comm_opt
     138             : 
     139      125619 :       CALL timeset(routineN, handle)
     140      125619 :       CALL matrix_a%dist%info%mp_comm%sync()
     141      125619 :       CALL timeset("dbt_tas_total", handle2)
     142             : 
     143      125619 :       NULLIFY (matrix_b_rs, matrix_a_rs, matrix_c_rs)
     144             : 
     145      125619 :       unit_nr_prv = prep_output_unit(unit_nr)
     146             : 
     147      125619 :       IF (PRESENT(simple_split)) THEN
     148       21782 :          simple_split_prv = simple_split
     149             :       ELSE
     150      103837 :          simple_split_prv = .FALSE.
     151             : 
     152      103837 :          info_a = dbt_tas_info(matrix_a); info_b = dbt_tas_info(matrix_b); info_c = dbt_tas_info(matrix_c)
     153      103837 :          IF (info_a%strict_split(1) .OR. info_b%strict_split(1) .OR. info_c%strict_split(1)) simple_split_prv = .TRUE.
     154             :       END IF
     155             : 
     156      125619 :       nodata_3 = .TRUE.
     157      125619 :       IF (PRESENT(retain_sparsity)) THEN
     158        4762 :          IF (retain_sparsity) nodata_3 = .FALSE.
     159             :       END IF
     160             : 
     161             :       ! get prestored info for multiplication strategy in case of batched mm
     162      125619 :       batched_repl = 0
     163      125619 :       do_batched = .FALSE.
     164      125619 :       IF (matrix_a%do_batched > 0) THEN
     165       40477 :          do_batched = .TRUE.
     166       40477 :          IF (matrix_a%do_batched == 3) THEN
     167             :             CPASSERT(batched_repl == 0)
     168       13241 :             batched_repl = 1
     169             :             CALL dbt_tas_get_split_info( &
     170             :                dbt_tas_info(matrix_a%mm_storage%store_batched_repl), &
     171       13241 :                nsplit=nsplit_batched)
     172       13241 :             CPASSERT(nsplit_batched > 0)
     173             :             max_mm_dim_batched = 3
     174             :          END IF
     175             :       END IF
     176             : 
     177      125619 :       IF (matrix_b%do_batched > 0) THEN
     178       14488 :          do_batched = .TRUE.
     179       14488 :          IF (matrix_b%do_batched == 3) THEN
     180        2960 :             CPASSERT(batched_repl == 0)
     181        2960 :             batched_repl = 2
     182             :             CALL dbt_tas_get_split_info( &
     183             :                dbt_tas_info(matrix_b%mm_storage%store_batched_repl), &
     184        2960 :                nsplit=nsplit_batched)
     185        2960 :             CPASSERT(nsplit_batched > 0)
     186             :             max_mm_dim_batched = 1
     187             :          END IF
     188             :       END IF
     189             : 
     190      125619 :       IF (matrix_c%do_batched > 0) THEN
     191       32424 :          do_batched = .TRUE.
     192       32424 :          IF (matrix_c%do_batched == 3) THEN
     193        6452 :             CPASSERT(batched_repl == 0)
     194        6452 :             batched_repl = 3
     195             :             CALL dbt_tas_get_split_info( &
     196             :                dbt_tas_info(matrix_c%mm_storage%store_batched_repl), &
     197        6452 :                nsplit=nsplit_batched)
     198        6452 :             CPASSERT(nsplit_batched > 0)
     199             :             max_mm_dim_batched = 2
     200             :          END IF
     201             :       END IF
     202             : 
     203      125619 :       move_a = .FALSE.
     204      125619 :       move_b = .FALSE.
     205             : 
     206      125619 :       IF (PRESENT(move_data_a)) move_a = move_data_a
     207      125619 :       IF (PRESENT(move_data_b)) move_b = move_data_b
     208             : 
     209      125619 :       transa_prv = transa; transb_prv = transb; transc_prv = transc
     210             : 
     211      376857 :       dims_a = [dbt_tas_nblkrows_total(matrix_a), dbt_tas_nblkcols_total(matrix_a)]
     212      376857 :       dims_b = [dbt_tas_nblkrows_total(matrix_b), dbt_tas_nblkcols_total(matrix_b)]
     213      376857 :       dims_c = [dbt_tas_nblkrows_total(matrix_c), dbt_tas_nblkcols_total(matrix_c)]
     214             : 
     215      125619 :       IF (unit_nr_prv > 0) THEN
     216          34 :          WRITE (unit_nr_prv, "(A)") REPEAT("-", 80)
     217             :          WRITE (unit_nr_prv, "(A)") &
     218             :             "DBT TAS MATRIX MULTIPLICATION: "// &
     219             :             TRIM(dbm_get_name(matrix_a%matrix))//" x "// &
     220             :             TRIM(dbm_get_name(matrix_b%matrix))//" = "// &
     221          34 :             TRIM(dbm_get_name(matrix_c%matrix))
     222          34 :          WRITE (unit_nr_prv, "(A)") REPEAT("-", 80)
     223             :       END IF
     224      125619 :       IF (do_batched) THEN
     225       84949 :          IF (unit_nr_prv > 0) THEN
     226             :             WRITE (unit_nr_prv, "(T2,A)") &
     227           0 :                "BATCHED PROCESSING OF MATMUL"
     228           0 :             IF (batched_repl > 0) THEN
     229           0 :                WRITE (unit_nr_prv, "(T4,A,T80,I1)") "reusing replicated matrix:", batched_repl
     230             :             END IF
     231             :          END IF
     232             :       END IF
     233             : 
     234      125619 :       IF (transa_prv) THEN
     235       36847 :          CALL swap(dims_a)
     236             :       END IF
     237             : 
     238      125619 :       IF (transb_prv) THEN
     239       60806 :          CALL swap(dims_b)
     240             :       END IF
     241             : 
     242      376857 :       dims_c = [dims_a(1), dims_b(2)]
     243             : 
     244      125619 :       IF (.NOT. (dims_a(2) .EQ. dims_b(1))) THEN
     245           0 :          CPABORT("inconsistent matrix dimensions")
     246             :       END IF
     247             : 
     248      502476 :       dims(:) = [dims_a(1), dims_a(2), dims_b(2)]
     249             : 
     250      125619 :       IF (unit_nr_prv > 0) THEN
     251          34 :          WRITE (unit_nr_prv, "(T2,A, 1X, I12, 1X, I12, 1X, I12)") "mm dims:", dims(1), dims(2), dims(3)
     252             :       END IF
     253             : 
     254      125619 :       CALL dbt_tas_get_split_info(dbt_tas_info(matrix_a), mp_comm=mp_comm)
     255      125619 :       numproc = mp_comm%num_pe
     256             : 
     257             :       ! derive optimal matrix layout and split factor from occupancies
     258      125619 :       nze_a = dbt_tas_get_nze_total(matrix_a)
     259      125619 :       nze_b = dbt_tas_get_nze_total(matrix_b)
     260             : 
     261      125619 :       IF (.NOT. simple_split_prv) THEN
     262             :          CALL dbt_tas_estimate_result_nze(transa, transb, transc, matrix_a, matrix_b, matrix_c, &
     263             :                                           estimated_nze=nze_c, filter_eps=filter_eps, &
     264       21898 :                                           retain_sparsity=retain_sparsity)
     265             : 
     266       87592 :          max_mm_dim = MAXLOC(dims, 1)
     267       21898 :          nsplit = split_factor_estimate(max_mm_dim, nze_a, nze_b, nze_c, numproc)
     268       21898 :          nsplit_opt = nsplit
     269             : 
     270       21898 :          IF (unit_nr_prv > 0) THEN
     271             :             WRITE (unit_nr_prv, "(T2,A)") &
     272          34 :                "MM PARAMETERS"
     273          34 :             WRITE (unit_nr_prv, "(T4,A,T68,I13)") "Est. number of matrix elements per CPU of result matrix:", &
     274          68 :                (nze_c + numproc - 1)/numproc
     275             : 
     276          34 :             WRITE (unit_nr_prv, "(T4,A,T68,I13)") "Est. optimal split factor:", nsplit
     277             :          END IF
     278             : 
     279      103721 :       ELSEIF (batched_repl > 0) THEN
     280       22653 :          nsplit = nsplit_batched
     281       22653 :          nsplit_opt = nsplit
     282       22653 :          max_mm_dim = max_mm_dim_batched
     283       22653 :          IF (unit_nr_prv > 0) THEN
     284             :             WRITE (unit_nr_prv, "(T2,A)") &
     285           0 :                "MM PARAMETERS"
     286           0 :             WRITE (unit_nr_prv, "(T4,A,T68,I13)") "Est. optimal split factor:", nsplit
     287             :          END IF
     288             : 
     289             :       ELSE
     290       81068 :          nsplit = 0
     291      324272 :          max_mm_dim = MAXLOC(dims, 1)
     292             :       END IF
     293             : 
     294             :       ! reshape matrices to the optimal layout and split factor
     295      125619 :       split_a = rowsplit; split_b = rowsplit; split_c = rowsplit
     296       37082 :       SELECT CASE (max_mm_dim)
     297             :       CASE (1)
     298             : 
     299             :          split_a = rowsplit; split_c = rowsplit
     300             :          CALL reshape_mm_compatible(matrix_a, matrix_c, matrix_a_rs, matrix_c_rs, &
     301             :                                     new_a, new_c, transa_prv, transc_prv, optimize_dist=optimize_dist, &
     302             :                                     nsplit=nsplit, &
     303             :                                     opt_nsplit=batched_repl == 0, &
     304             :                                     split_rc_1=split_a, split_rc_2=split_c, &
     305             :                                     nodata2=nodata_3, comm_new=comm_tmp, &
     306       37082 :                                     move_data_1=move_a, unit_nr=unit_nr_prv)
     307             : 
     308       37082 :          info = dbt_tas_info(matrix_a_rs)
     309       37082 :          CALL dbt_tas_get_split_info(info, split_rowcol=split_rc, mp_comm=mp_comm)
     310             : 
     311       37082 :          new_b = .FALSE.
     312       37082 :          IF (matrix_b%do_batched <= 2) THEN
     313      170610 :             ALLOCATE (matrix_b_rs)
     314       34122 :             CALL reshape_mm_small(mp_comm, matrix_b, matrix_b_rs, transb_prv, move_data=move_b)
     315       34122 :             transb_prv = .FALSE.
     316       34122 :             new_b = .TRUE.
     317             :          END IF
     318             : 
     319       37082 :          tr_case = transa_prv
     320             : 
     321       37082 :          IF (unit_nr_prv > 0) THEN
     322          11 :             IF (.NOT. tr_case) THEN
     323          11 :                WRITE (unit_nr_prv, "(T2,A, 1X, A)") "mm case:", "| x + = |"
     324             :             ELSE
     325           0 :                WRITE (unit_nr_prv, "(T2,A, 1X, A)") "mm case:", "--T x + = --T"
     326             :             END IF
     327             :          END IF
     328             : 
     329             :       CASE (2)
     330             : 
     331       38648 :          split_a = colsplit; split_b = rowsplit
     332             :          CALL reshape_mm_compatible(matrix_a, matrix_b, matrix_a_rs, matrix_b_rs, new_a, new_b, transa_prv, transb_prv, &
     333             :                                     optimize_dist=optimize_dist, &
     334             :                                     nsplit=nsplit, &
     335             :                                     opt_nsplit=batched_repl == 0, &
     336             :                                     split_rc_1=split_a, split_rc_2=split_b, &
     337             :                                     comm_new=comm_tmp, &
     338       38648 :                                     move_data_1=move_a, move_data_2=move_b, unit_nr=unit_nr_prv)
     339             : 
     340       38648 :          info = dbt_tas_info(matrix_a_rs)
     341       38648 :          CALL dbt_tas_get_split_info(info, split_rowcol=split_rc, mp_comm=mp_comm)
     342             : 
     343       38648 :          IF (matrix_c%do_batched == 1) THEN
     344       24534 :             matrix_c%mm_storage%batched_beta = beta
     345       14114 :          ELSEIF (matrix_c%do_batched > 1) THEN
     346        7608 :             matrix_c%mm_storage%batched_beta = matrix_c%mm_storage%batched_beta*beta
     347             :          END IF
     348             : 
     349       38648 :          IF (matrix_c%do_batched <= 2) THEN
     350      160980 :             ALLOCATE (matrix_c_rs)
     351       32196 :             CALL reshape_mm_small(mp_comm, matrix_c, matrix_c_rs, transc_prv, nodata=nodata_3)
     352       32196 :             transc_prv = .FALSE.
     353             : 
     354             :             ! just leave sparsity structure for retain sparsity but no values
     355       32196 :             IF (.NOT. nodata_3) CALL dbm_zero(matrix_c_rs%matrix)
     356             : 
     357       32196 :             IF (matrix_c%do_batched >= 1) matrix_c%mm_storage%store_batched => matrix_c_rs
     358        6452 :          ELSEIF (matrix_c%do_batched == 3) THEN
     359        6452 :             matrix_c_rs => matrix_c%mm_storage%store_batched
     360             :          END IF
     361             : 
     362       38648 :          new_c = matrix_c%do_batched == 0
     363       38648 :          tr_case = transa_prv
     364             : 
     365       38648 :          IF (unit_nr_prv > 0) THEN
     366          13 :             IF (.NOT. tr_case) THEN
     367           2 :                WRITE (unit_nr_prv, "(T2,A, 1X, A)") "mm case:", "-- x --T = +"
     368             :             ELSE
     369          11 :                WRITE (unit_nr_prv, "(T2,A, 1X, A)") "mm case:", "|T x | = +"
     370             :             END IF
     371             :          END IF
     372             : 
     373             :       CASE (3)
     374             : 
     375       49889 :          split_b = colsplit; split_c = colsplit
     376             :          CALL reshape_mm_compatible(matrix_b, matrix_c, matrix_b_rs, matrix_c_rs, new_b, new_c, transb_prv, &
     377             :                                     transc_prv, optimize_dist=optimize_dist, &
     378             :                                     nsplit=nsplit, &
     379             :                                     opt_nsplit=batched_repl == 0, &
     380             :                                     split_rc_1=split_b, split_rc_2=split_c, &
     381             :                                     nodata2=nodata_3, comm_new=comm_tmp, &
     382       49889 :                                     move_data_1=move_b, unit_nr=unit_nr_prv)
     383       49889 :          info = dbt_tas_info(matrix_b_rs)
     384       49889 :          CALL dbt_tas_get_split_info(info, split_rowcol=split_rc, mp_comm=mp_comm)
     385             : 
     386       49889 :          new_a = .FALSE.
     387       49889 :          IF (matrix_a%do_batched <= 2) THEN
     388      183240 :             ALLOCATE (matrix_a_rs)
     389       36648 :             CALL reshape_mm_small(mp_comm, matrix_a, matrix_a_rs, transa_prv, move_data=move_a)
     390       36648 :             transa_prv = .FALSE.
     391       36648 :             new_a = .TRUE.
     392             :          END IF
     393             : 
     394       49889 :          tr_case = transb_prv
     395             : 
     396      301127 :          IF (unit_nr_prv > 0) THEN
     397          10 :             IF (.NOT. tr_case) THEN
     398           0 :                WRITE (unit_nr_prv, "(T2,A, 1X, A)") "mm case:", "+ x -- = --"
     399             :             ELSE
     400          10 :                WRITE (unit_nr_prv, "(T2,A, 1X, A)") "mm case:", "+ x |T = |T"
     401             :             END IF
     402             :          END IF
     403             : 
     404             :       END SELECT
     405             : 
     406      125619 :       CALL dbt_tas_get_split_info(info, nsplit=nsplit, mp_comm=mp_comm, mp_comm_group=mp_comm_group)
     407             : 
     408      125619 :       numproc = mp_comm%num_pe
     409      376857 :       pdims_sub = mp_comm_group%num_pe_cart
     410             : 
     411      125619 :       opt_pgrid = .NOT. accept_pgrid_dims(pdims_sub, relative=.TRUE.)
     412             : 
     413      125619 :       IF (PRESENT(filter_eps)) THEN
     414      119747 :          filter_eps_prv = filter_eps
     415             :       ELSE
     416        5872 :          filter_eps_prv = 0.0_dp
     417             :       END IF
     418             : 
     419      125619 :       IF (unit_nr_prv /= 0) THEN
     420       46194 :          IF (unit_nr_prv > 0) THEN
     421          34 :             WRITE (unit_nr_prv, "(T2, A)") "SPLIT / PARALLELIZATION INFO"
     422             :          END IF
     423       46194 :          CALL dbt_tas_write_split_info(info, unit_nr_prv)
     424       46194 :          IF (ASSOCIATED(matrix_a_rs)) CALL dbt_tas_write_matrix_info(matrix_a_rs, unit_nr_prv, full_info=log_verbose)
     425       46194 :          IF (ASSOCIATED(matrix_b_rs)) CALL dbt_tas_write_matrix_info(matrix_b_rs, unit_nr_prv, full_info=log_verbose)
     426       46194 :          IF (ASSOCIATED(matrix_c_rs)) CALL dbt_tas_write_matrix_info(matrix_c_rs, unit_nr_prv, full_info=log_verbose)
     427       46194 :          IF (unit_nr_prv > 0) THEN
     428          34 :             IF (opt_pgrid) THEN
     429           0 :                WRITE (unit_nr_prv, "(T4, A, 1X, A)") "Change process grid:", "Yes"
     430             :             ELSE
     431          34 :                WRITE (unit_nr_prv, "(T4, A, 1X, A)") "Change process grid:", "No"
     432             :             END IF
     433             :          END IF
     434             :       END IF
     435             : 
     436      125619 :       pdims = 0
     437      125619 :       CALL mp_comm_mm%create(mp_comm_group, 2, pdims)
     438             : 
     439             :       ! Convert DBM submatrices to optimized process grids and multiply
     440       37082 :       SELECT CASE (max_mm_dim)
     441             :       CASE (1)
     442       37082 :          IF (matrix_b%do_batched <= 2) THEN
     443      170610 :             ALLOCATE (matrix_b_rep)
     444       34122 :             CALL dbt_tas_replicate(matrix_b_rs%matrix, dbt_tas_info(matrix_a_rs), matrix_b_rep, move_data=.TRUE.)
     445       34122 :             IF (matrix_b%do_batched == 1 .OR. matrix_b%do_batched == 2) THEN
     446        8574 :                matrix_b%mm_storage%store_batched_repl => matrix_b_rep
     447        8574 :                CALL dbt_tas_set_batched_state(matrix_b, state=3)
     448             :             END IF
     449        2960 :          ELSEIF (matrix_b%do_batched == 3) THEN
     450        2960 :             matrix_b_rep => matrix_b%mm_storage%store_batched_repl
     451             :          END IF
     452             : 
     453       37082 :          IF (new_b) THEN
     454       34122 :             CALL dbt_tas_destroy(matrix_b_rs)
     455       34122 :             DEALLOCATE (matrix_b_rs)
     456             :          END IF
     457       37082 :          IF (unit_nr_prv /= 0) THEN
     458         418 :             CALL dbt_tas_write_dist(matrix_a_rs, unit_nr_prv)
     459         418 :             CALL dbt_tas_write_dist(matrix_b_rep, unit_nr_prv, full_info=log_verbose)
     460             :          END IF
     461             : 
     462       37082 :          CALL convert_to_new_pgrid(mp_comm_mm, matrix_a_rs%matrix, matrix_a_mm, optimize_pgrid=opt_pgrid, move_data=move_a)
     463             : 
     464             :          ! keep communicators alive even after releasing TAS matrices (communicator management does not work between DBM and TAS)
     465       37082 :          info_a = dbt_tas_info(matrix_a_rs)
     466       37082 :          CALL dbt_tas_info_hold(info_a)
     467             : 
     468       37082 :          IF (new_a) THEN
     469        5262 :             CALL dbt_tas_destroy(matrix_a_rs)
     470        5262 :             DEALLOCATE (matrix_a_rs)
     471             :          END IF
     472             :          CALL convert_to_new_pgrid(mp_comm_mm, matrix_b_rep%matrix, matrix_b_mm, optimize_pgrid=opt_pgrid, &
     473       37082 :                                    move_data=matrix_b%do_batched == 0)
     474             : 
     475       37082 :          info_b = dbt_tas_info(matrix_b_rep)
     476       37082 :          CALL dbt_tas_info_hold(info_b)
     477             : 
     478       37082 :          IF (matrix_b%do_batched == 0) THEN
     479       25548 :             CALL dbt_tas_destroy(matrix_b_rep)
     480       25548 :             DEALLOCATE (matrix_b_rep)
     481             :          END IF
     482             : 
     483       37082 :          CALL convert_to_new_pgrid(mp_comm_mm, matrix_c_rs%matrix, matrix_c_mm, nodata=nodata_3, optimize_pgrid=opt_pgrid)
     484             : 
     485       37082 :          info_c = dbt_tas_info(matrix_c_rs)
     486       37082 :          CALL dbt_tas_info_hold(info_c)
     487             : 
     488       37082 :          CALL matrix_a%dist%info%mp_comm%sync()
     489       37082 :          CALL timeset("dbt_tas_dbm", handle4)
     490       37082 :          IF (.NOT. tr_case) THEN
     491       32030 :             CALL timeset("dbt_tas_mm_1N", handle3)
     492             : 
     493             :             CALL dbm_multiply(transa=.FALSE., transb=.FALSE., alpha=alpha, &
     494             :                               matrix_a=matrix_a_mm, matrix_b=matrix_b_mm, beta=beta, matrix_c=matrix_c_mm, &
     495       32030 :                               filter_eps=filter_eps_prv, retain_sparsity=retain_sparsity, flop=flop)
     496       32030 :             CALL timestop(handle3)
     497             :          ELSE
     498        5052 :             CALL timeset("dbt_tas_mm_1T", handle3)
     499             :             CALL dbm_multiply(transa=.TRUE., transb=.FALSE., alpha=alpha, &
     500             :                               matrix_a=matrix_b_mm, matrix_b=matrix_a_mm, beta=beta, matrix_c=matrix_c_mm, &
     501        5052 :                               filter_eps=filter_eps_prv, retain_sparsity=retain_sparsity, flop=flop)
     502             : 
     503        5052 :             CALL timestop(handle3)
     504             :          END IF
     505       37082 :          CALL matrix_a%dist%info%mp_comm%sync()
     506       37082 :          CALL timestop(handle4)
     507             : 
     508       37082 :          CALL dbm_release(matrix_a_mm)
     509       37082 :          CALL dbm_release(matrix_b_mm)
     510             : 
     511       37082 :          nze_c = dbm_get_nze(matrix_c_mm)
     512             : 
     513       37082 :          IF (.NOT. new_c) THEN
     514       31974 :             CALL redistribute_and_sum(matrix_c_mm, matrix_c_rs%matrix, local_copy=.NOT. opt_pgrid, alpha=beta)
     515             :          ELSE
     516        5108 :             CALL redistribute_and_sum(matrix_c_mm, matrix_c_rs%matrix, local_copy=.NOT. opt_pgrid, alpha=1.0_dp)
     517             :          END IF
     518             : 
     519       37082 :          CALL dbm_release(matrix_c_mm)
     520             : 
     521       37082 :          IF (PRESENT(filter_eps)) CALL dbt_tas_filter(matrix_c_rs, filter_eps)
     522             : 
     523       37082 :          IF (unit_nr_prv /= 0) THEN
     524         418 :             CALL dbt_tas_write_dist(matrix_c_rs, unit_nr_prv)
     525             :          END IF
     526             : 
     527             :       CASE (2)
     528       38648 :          IF (matrix_c%do_batched <= 1) THEN
     529      155200 :             ALLOCATE (matrix_c_rep)
     530       31040 :             CALL dbt_tas_replicate(matrix_c_rs%matrix, dbt_tas_info(matrix_a_rs), matrix_c_rep, nodata=nodata_3)
     531       31040 :             IF (matrix_c%do_batched == 1) THEN
     532       24534 :                matrix_c%mm_storage%store_batched_repl => matrix_c_rep
     533       24534 :                CALL dbt_tas_set_batched_state(matrix_c, state=3)
     534             :             END IF
     535        7608 :          ELSEIF (matrix_c%do_batched == 2) THEN
     536        5780 :             ALLOCATE (matrix_c_rep)
     537        1156 :             CALL dbt_tas_replicate(matrix_c_rs%matrix, dbt_tas_info(matrix_a_rs), matrix_c_rep, nodata=nodata_3)
     538             :             ! just leave sparsity structure for retain sparsity but no values
     539        1156 :             IF (.NOT. nodata_3) CALL dbm_zero(matrix_c_rep%matrix)
     540        1156 :             matrix_c%mm_storage%store_batched_repl => matrix_c_rep
     541        1156 :             CALL dbt_tas_set_batched_state(matrix_c, state=3)
     542        6452 :          ELSEIF (matrix_c%do_batched == 3) THEN
     543        6452 :             matrix_c_rep => matrix_c%mm_storage%store_batched_repl
     544             :          END IF
     545             : 
     546       38648 :          IF (unit_nr_prv /= 0) THEN
     547       20468 :             CALL dbt_tas_write_dist(matrix_a_rs, unit_nr_prv)
     548       20468 :             CALL dbt_tas_write_dist(matrix_b_rs, unit_nr_prv)
     549             :          END IF
     550             : 
     551       38648 :          CALL convert_to_new_pgrid(mp_comm_mm, matrix_a_rs%matrix, matrix_a_mm, optimize_pgrid=opt_pgrid, move_data=move_a)
     552             : 
     553             :          ! keep communicators alive even after releasing TAS matrices (communicator management does not work between DBM and TAS)
     554       38648 :          info_a = dbt_tas_info(matrix_a_rs)
     555       38648 :          CALL dbt_tas_info_hold(info_a)
     556             : 
     557       38648 :          IF (new_a) THEN
     558         494 :             CALL dbt_tas_destroy(matrix_a_rs)
     559         494 :             DEALLOCATE (matrix_a_rs)
     560             :          END IF
     561             : 
     562       38648 :          CALL convert_to_new_pgrid(mp_comm_mm, matrix_b_rs%matrix, matrix_b_mm, optimize_pgrid=opt_pgrid, move_data=move_b)
     563             : 
     564       38648 :          info_b = dbt_tas_info(matrix_b_rs)
     565       38648 :          CALL dbt_tas_info_hold(info_b)
     566             : 
     567       38648 :          IF (new_b) THEN
     568         650 :             CALL dbt_tas_destroy(matrix_b_rs)
     569         650 :             DEALLOCATE (matrix_b_rs)
     570             :          END IF
     571             : 
     572       38648 :          CALL convert_to_new_pgrid(mp_comm_mm, matrix_c_rep%matrix, matrix_c_mm, nodata=nodata_3, optimize_pgrid=opt_pgrid)
     573             : 
     574       38648 :          info_c = dbt_tas_info(matrix_c_rep)
     575       38648 :          CALL dbt_tas_info_hold(info_c)
     576             : 
     577       38648 :          CALL matrix_a%dist%info%mp_comm%sync()
     578       38648 :          CALL timeset("dbt_tas_dbm", handle4)
     579       38648 :          CALL timeset("dbt_tas_mm_2", handle3)
     580             :          CALL dbm_multiply(transa=transa_prv, transb=transb_prv, alpha=alpha, matrix_a=matrix_a_mm, &
     581             :                            matrix_b=matrix_b_mm, beta=beta, matrix_c=matrix_c_mm, &
     582       38648 :                            filter_eps=filter_eps_prv/REAL(nsplit, KIND=dp), retain_sparsity=retain_sparsity, flop=flop)
     583       38648 :          CALL matrix_a%dist%info%mp_comm%sync()
     584       38648 :          CALL timestop(handle3)
     585       38648 :          CALL timestop(handle4)
     586             : 
     587       38648 :          CALL dbm_release(matrix_a_mm)
     588       38648 :          CALL dbm_release(matrix_b_mm)
     589             : 
     590       38648 :          nze_c = dbm_get_nze(matrix_c_mm)
     591             : 
     592       38648 :          CALL redistribute_and_sum(matrix_c_mm, matrix_c_rep%matrix, local_copy=.NOT. opt_pgrid, alpha=beta)
     593       38648 :          nze_c_sum = dbt_tas_get_nze_total(matrix_c_rep)
     594             : 
     595       38648 :          CALL dbm_release(matrix_c_mm)
     596             : 
     597       38648 :          IF (unit_nr_prv /= 0) THEN
     598       20468 :             CALL dbt_tas_write_dist(matrix_c_rep, unit_nr_prv, full_info=log_verbose)
     599             :          END IF
     600             : 
     601       38648 :          IF (matrix_c%do_batched == 0) THEN
     602        6506 :             CALL dbt_tas_merge(matrix_c_rs%matrix, matrix_c_rep, move_data=.TRUE.)
     603             :          ELSE
     604       32142 :             matrix_c%mm_storage%batched_out = .TRUE. ! postpone merging submatrices to dbt_tas_batched_mm_finalize
     605             :          END IF
     606             : 
     607       38648 :          IF (matrix_c%do_batched == 0) THEN
     608        6506 :             CALL dbt_tas_destroy(matrix_c_rep)
     609        6506 :             DEALLOCATE (matrix_c_rep)
     610             :          END IF
     611             : 
     612       38648 :          IF (PRESENT(filter_eps)) CALL dbt_tas_filter(matrix_c_rs, filter_eps)
     613             : 
     614             :          ! set upper limit on memory consumption for replicated matrix and complete batched mm
     615             :          ! if limit is exceeded
     616       38648 :          IF (nze_c_sum > default_nsplit_accept_ratio*MAX(nze_a, nze_b)) THEN
     617        1484 :             CALL dbt_tas_batched_mm_complete(matrix_c)
     618             :          END IF
     619             : 
     620             :       CASE (3)
     621       49889 :          IF (matrix_a%do_batched <= 2) THEN
     622      183240 :             ALLOCATE (matrix_a_rep)
     623       36648 :             CALL dbt_tas_replicate(matrix_a_rs%matrix, dbt_tas_info(matrix_b_rs), matrix_a_rep, move_data=.TRUE.)
     624       36648 :             IF (matrix_a%do_batched == 1 .OR. matrix_a%do_batched == 2) THEN
     625       23000 :                matrix_a%mm_storage%store_batched_repl => matrix_a_rep
     626       23000 :                CALL dbt_tas_set_batched_state(matrix_a, state=3)
     627             :             END IF
     628       13241 :          ELSEIF (matrix_a%do_batched == 3) THEN
     629       13241 :             matrix_a_rep => matrix_a%mm_storage%store_batched_repl
     630             :          END IF
     631             : 
     632       49889 :          IF (new_a) THEN
     633       36648 :             CALL dbt_tas_destroy(matrix_a_rs)
     634       36648 :             DEALLOCATE (matrix_a_rs)
     635             :          END IF
     636       49889 :          IF (unit_nr_prv /= 0) THEN
     637       25308 :             CALL dbt_tas_write_dist(matrix_a_rep, unit_nr_prv, full_info=log_verbose)
     638       25308 :             CALL dbt_tas_write_dist(matrix_b_rs, unit_nr_prv)
     639             :          END IF
     640             : 
     641             :          CALL convert_to_new_pgrid(mp_comm_mm, matrix_a_rep%matrix, matrix_a_mm, optimize_pgrid=opt_pgrid, &
     642       49889 :                                    move_data=matrix_a%do_batched == 0)
     643             : 
     644             :          ! keep communicators alive even after releasing TAS matrices (communicator management does not work between DBM and TAS)
     645       49889 :          info_a = dbt_tas_info(matrix_a_rep)
     646       49889 :          CALL dbt_tas_info_hold(info_a)
     647             : 
     648       49889 :          IF (matrix_a%do_batched == 0) THEN
     649       13648 :             CALL dbt_tas_destroy(matrix_a_rep)
     650       13648 :             DEALLOCATE (matrix_a_rep)
     651             :          END IF
     652             : 
     653       49889 :          CALL convert_to_new_pgrid(mp_comm_mm, matrix_b_rs%matrix, matrix_b_mm, optimize_pgrid=opt_pgrid, move_data=move_b)
     654             : 
     655       49889 :          info_b = dbt_tas_info(matrix_b_rs)
     656       49889 :          CALL dbt_tas_info_hold(info_b)
     657             : 
     658       49889 :          IF (new_b) THEN
     659          16 :             CALL dbt_tas_destroy(matrix_b_rs)
     660          16 :             DEALLOCATE (matrix_b_rs)
     661             :          END IF
     662       49889 :          CALL convert_to_new_pgrid(mp_comm_mm, matrix_c_rs%matrix, matrix_c_mm, nodata=nodata_3, optimize_pgrid=opt_pgrid)
     663             : 
     664       49889 :          info_c = dbt_tas_info(matrix_c_rs)
     665       49889 :          CALL dbt_tas_info_hold(info_c)
     666             : 
     667       49889 :          CALL matrix_a%dist%info%mp_comm%sync()
     668       49889 :          CALL timeset("dbt_tas_dbm", handle4)
     669       49889 :          IF (.NOT. tr_case) THEN
     670       22243 :             CALL timeset("dbt_tas_mm_3N", handle3)
     671             :             CALL dbm_multiply(transa=.FALSE., transb=.FALSE., alpha=alpha, &
     672             :                               matrix_a=matrix_a_mm, matrix_b=matrix_b_mm, beta=beta, matrix_c=matrix_c_mm, &
     673       22243 :                               filter_eps=filter_eps_prv, retain_sparsity=retain_sparsity, flop=flop)
     674       22243 :             CALL timestop(handle3)
     675             :          ELSE
     676       27646 :             CALL timeset("dbt_tas_mm_3T", handle3)
     677             :             CALL dbm_multiply(transa=.FALSE., transb=.TRUE., alpha=alpha, &
     678             :                               matrix_a=matrix_b_mm, matrix_b=matrix_a_mm, beta=beta, matrix_c=matrix_c_mm, &
     679       27646 :                               filter_eps=filter_eps_prv, retain_sparsity=retain_sparsity, flop=flop)
     680       27646 :             CALL timestop(handle3)
     681             :          END IF
     682       49889 :          CALL matrix_a%dist%info%mp_comm%sync()
     683       49889 :          CALL timestop(handle4)
     684             : 
     685       49889 :          CALL dbm_release(matrix_a_mm)
     686       49889 :          CALL dbm_release(matrix_b_mm)
     687             : 
     688       49889 :          nze_c = dbm_get_nze(matrix_c_mm)
     689             : 
     690       49889 :          IF (.NOT. new_c) THEN
     691       49145 :             CALL redistribute_and_sum(matrix_c_mm, matrix_c_rs%matrix, local_copy=.NOT. opt_pgrid, alpha=beta)
     692             :          ELSE
     693         744 :             CALL redistribute_and_sum(matrix_c_mm, matrix_c_rs%matrix, local_copy=.NOT. opt_pgrid, alpha=1.0_dp)
     694             :          END IF
     695             : 
     696       49889 :          CALL dbm_release(matrix_c_mm)
     697             : 
     698       49889 :          IF (PRESENT(filter_eps)) CALL dbt_tas_filter(matrix_c_rs, filter_eps)
     699             : 
     700      301127 :          IF (unit_nr_prv /= 0) THEN
     701       25308 :             CALL dbt_tas_write_dist(matrix_c_rs, unit_nr_prv)
     702             :          END IF
     703             :       END SELECT
     704             : 
     705      125619 :       CALL mp_comm_mm%free()
     706             : 
     707      125619 :       CALL dbt_tas_get_split_info(info_c, mp_comm=mp_comm)
     708             : 
     709      125619 :       IF (PRESENT(split_opt)) THEN
     710       66801 :          SELECT CASE (max_mm_dim)
     711             :          CASE (1, 3)
     712       66801 :             CALL mp_comm%sum(nze_c)
     713             :          CASE (2)
     714       36988 :             CALL dbt_tas_get_split_info(info_c, mp_comm=mp_comm, mp_comm_group=mp_comm_group)
     715       36988 :             CALL mp_comm%sum(nze_c)
     716      140777 :             CALL mp_comm%max(nze_c)
     717             : 
     718             :          END SELECT
     719      103789 :          nsplit_opt = split_factor_estimate(max_mm_dim, nze_a, nze_b, nze_c, numproc)
     720             :          ! ideally we should rederive the split factor from the actual sparsity of C, but
     721             :          ! due to parameter beta, we can not get the sparsity of AxB from DBM if not new_c
     722      103789 :          mp_comm_opt = dbt_tas_mp_comm(mp_comm, split_rc, nsplit_opt)
     723      103789 :          CALL dbt_tas_create_split(split_opt, mp_comm_opt, split_rc, nsplit_opt, own_comm=.TRUE.)
     724      103789 :          IF (unit_nr_prv > 0) THEN
     725             :             WRITE (unit_nr_prv, "(T2,A)") &
     726          10 :                "MM PARAMETERS"
     727          10 :             WRITE (unit_nr_prv, "(T4,A,T68,I13)") "Number of matrix elements per CPU of result matrix:", &
     728          20 :                (nze_c + numproc - 1)/numproc
     729             : 
     730          10 :             WRITE (unit_nr_prv, "(T4,A,T68,I13)") "Optimal split factor:", nsplit_opt
     731             :          END IF
     732             : 
     733             :       END IF
     734             : 
     735      125619 :       IF (new_c) THEN
     736       12358 :          CALL dbm_scale(matrix_c%matrix, beta)
     737             :          CALL dbt_tas_reshape(matrix_c_rs, matrix_c, summation=.TRUE., &
     738             :                               transposed=(transc_prv .NEQV. transc), &
     739       12358 :                               move_data=.TRUE.)
     740       12358 :          CALL dbt_tas_destroy(matrix_c_rs)
     741       12358 :          DEALLOCATE (matrix_c_rs)
     742       12358 :          IF (PRESENT(filter_eps)) CALL dbt_tas_filter(matrix_c, filter_eps)
     743      113261 :       ELSEIF (matrix_c%do_batched > 0) THEN
     744       32416 :          IF (matrix_c%mm_storage%batched_out) THEN
     745       32142 :             matrix_c%mm_storage%batched_trans = (transc_prv .NEQV. transc)
     746             :          END IF
     747             :       END IF
     748             : 
     749      125619 :       IF (PRESENT(move_data_a)) THEN
     750      125571 :          IF (move_data_a) CALL dbt_tas_clear(matrix_a)
     751             :       END IF
     752      125619 :       IF (PRESENT(move_data_b)) THEN
     753      125571 :          IF (move_data_b) CALL dbt_tas_clear(matrix_b)
     754             :       END IF
     755             : 
     756      125619 :       IF (PRESENT(flop)) THEN
     757       90645 :          CALL mp_comm%sum(flop)
     758       90645 :          flop = (flop + numproc - 1)/numproc
     759             :       END IF
     760             : 
     761      125619 :       IF (PRESENT(optimize_dist)) THEN
     762          48 :          IF (optimize_dist) CALL comm_tmp%free()
     763             :       END IF
     764      125619 :       IF (unit_nr_prv > 0) THEN
     765          34 :          WRITE (unit_nr_prv, '(A)') REPEAT("-", 80)
     766          34 :          WRITE (unit_nr_prv, '(A,1X,A,1X,A,1X,A,1X,A,1X,A)') "TAS MATRIX MULTIPLICATION DONE"
     767          34 :          WRITE (unit_nr_prv, '(A)') REPEAT("-", 80)
     768             :       END IF
     769             : 
     770      125619 :       CALL dbt_tas_release_info(info_a)
     771      125619 :       CALL dbt_tas_release_info(info_b)
     772      125619 :       CALL dbt_tas_release_info(info_c)
     773             : 
     774      125619 :       CALL matrix_a%dist%info%mp_comm%sync()
     775      125619 :       CALL timestop(handle2)
     776      125619 :       CALL timestop(handle)
     777             : 
     778      251238 :    END SUBROUTINE
     779             : 
     780             : ! **************************************************************************************************
     781             : !> \brief Scale matrix_out by alpha, then add matrix_in to it (redistributing matrix_in first unless local_copy).
     782             : !> \param matrix_in matrix to be added to matrix_out (read only)
     783             : !> \param matrix_out matrix that is scaled by alpha (only if alpha /= 1) and then receives the sum
     784             : !> \param local_copy if .TRUE. add matrix_in directly, skipping the redistribution into a temporary (default .FALSE.)
     785             : !> \param alpha factor applied to matrix_out before matrix_in is added
     786             : !> \author Patrick Seewald
     787             : ! **************************************************************************************************
     788      125619 :    SUBROUTINE redistribute_and_sum(matrix_in, matrix_out, local_copy, alpha)
     789             :       TYPE(dbm_type), INTENT(IN)                         :: matrix_in
     790             :       TYPE(dbm_type), INTENT(INOUT)                      :: matrix_out
     791             :       LOGICAL, INTENT(IN), OPTIONAL                      :: local_copy
     792             :       REAL(dp), INTENT(IN)                               :: alpha
     793             : 
     794             :       LOGICAL                                            :: local_copy_prv
     795             :       TYPE(dbm_type)                                     :: matrix_tmp
     796             : 
     797      125619 :       IF (PRESENT(local_copy)) THEN
     798      125619 :          local_copy_prv = local_copy
     799             :       ELSE
     800             :          local_copy_prv = .FALSE. ! default: redistribute before summing
     801             :       END IF
     802             : 
     803      125619 :       IF (alpha /= 1.0_dp) THEN ! scaling by exactly 1 is skipped
     804       74207 :          CALL dbm_scale(matrix_out, alpha)
     805             :       END IF
     806             : 
     807      125619 :       IF (.NOT. local_copy_prv) THEN
     808           0 :          CALL dbm_create_from_template(matrix_tmp, name="tmp", template=matrix_out) ! temporary built from matrix_out
     809           0 :          CALL dbm_redistribute(matrix_in, matrix_tmp)
     810           0 :          CALL dbm_add(matrix_out, matrix_tmp)
     811           0 :          CALL dbm_release(matrix_tmp)
     812             :       ELSE
     813      125619 :          CALL dbm_add(matrix_out, matrix_in) ! direct sum, no temporary
     814             :       END IF
     815             : 
     816      125619 :    END SUBROUTINE
     817             : 
     818             : ! **************************************************************************************************
     819             : !> \brief Make sure that smallest matrix involved in a multiplication is not split and bring it to
     820             : !>        the same process grid as the other 2 matrices.
     821             : !> \param mp_comm communicator that defines Cartesian topology
     822             : !> \param matrix_in input matrix; its data is reshaped into matrix_out unless nodata is set
     823             : !> \param matrix_out new matrix created on the grid of mp_comm with a nosplit distribution
     824             : !> \param transposed Whether matrix_out should be transposed
     825             : !> \param nodata Data of matrix_in should not be copied to matrix_out
     826             : !> \param move_data memory optimization: move data such that matrix_in is empty on return.
     827             : !> \author Patrick Seewald
     828             : ! **************************************************************************************************
     829      617796 :    SUBROUTINE reshape_mm_small(mp_comm, matrix_in, matrix_out, transposed, nodata, move_data)
     830             :       TYPE(mp_cart_type), INTENT(IN)                     :: mp_comm
     831             :       TYPE(dbt_tas_type), INTENT(INOUT)                  :: matrix_in
     832             :       TYPE(dbt_tas_type), INTENT(OUT)                    :: matrix_out
     833             :       LOGICAL, INTENT(IN)                                :: transposed
     834             :       LOGICAL, INTENT(IN), OPTIONAL                      :: nodata, move_data
     835             : 
     836             :       CHARACTER(LEN=*), PARAMETER                        :: routineN = 'reshape_mm_small'
     837             : 
     838             :       INTEGER                                            :: handle
     839             :       INTEGER(KIND=int_8), DIMENSION(2)                  :: dims
     840             :       INTEGER, DIMENSION(2)                              :: pdims
     841             :       LOGICAL                                            :: nodata_prv
     842      102966 :       TYPE(dbt_tas_dist_arb)                             :: new_col_dist, new_row_dist
     843      514830 :       TYPE(dbt_tas_distribution_type)                    :: dist
     844             : 
     845      102966 :       CALL timeset(routineN, handle)
     846             : 
     847      102966 :       IF (PRESENT(nodata)) THEN
     848       32196 :          nodata_prv = nodata
     849             :       ELSE
     850             :          nodata_prv = .FALSE. ! default: copy the data
     851             :       END IF
     852             : 
     853      308898 :       pdims = mp_comm%num_pe_cart ! process grid dimensions of the Cartesian communicator
     854             : 
     855      308898 :       dims = [dbt_tas_nblkrows_total(matrix_in), dbt_tas_nblkcols_total(matrix_in)]
     856             : 
     857      102966 :       IF (transposed) CALL swap(dims) ! block dimensions of matrix_out are those of matrix_in exchanged
     858             : 
     859      102966 :       IF (.NOT. transposed) THEN
     860       78958 :          new_row_dist = dbt_tas_dist_arb_default(pdims(1), dims(1), matrix_in%row_blk_size)
     861       78958 :          new_col_dist = dbt_tas_dist_arb_default(pdims(2), dims(2), matrix_in%col_blk_size)
     862       78958 :          CALL dbt_tas_distribution_new(dist, mp_comm, new_row_dist, new_col_dist, nosplit=.TRUE.)
     863             :          CALL dbt_tas_create(matrix_out, dbm_get_name(matrix_in%matrix), dist, &
     864       78958 :                              matrix_in%row_blk_size, matrix_in%col_blk_size, own_dist=.TRUE.)
     865             :       ELSE
     866       24008 :          new_row_dist = dbt_tas_dist_arb_default(pdims(1), dims(1), matrix_in%col_blk_size) ! row/col block sizes exchanged
     867       24008 :          new_col_dist = dbt_tas_dist_arb_default(pdims(2), dims(2), matrix_in%row_blk_size)
     868       24008 :          CALL dbt_tas_distribution_new(dist, mp_comm, new_row_dist, new_col_dist, nosplit=.TRUE.)
     869             :          CALL dbt_tas_create(matrix_out, dbm_get_name(matrix_in%matrix), dist, &
     870       24008 :                              matrix_in%col_blk_size, matrix_in%row_blk_size, own_dist=.TRUE.)
     871             :       END IF
     872      102966 :       IF (.NOT. nodata_prv) CALL dbt_tas_reshape(matrix_in, matrix_out, transposed=transposed, move_data=move_data) ! optional move_data is forwarded as is
     873             : 
     874      102966 :       CALL timestop(handle)
     875             : 
     876      102966 :    END SUBROUTINE
     877             : 
     878             : ! **************************************************************************************************
     879             : !> \brief Reshape either matrix1 or matrix2 to make sure that their process grids are compatible
     880             : !>        with the same split factor.
     881             : !> \param matrix1_in ...
     882             : !> \param matrix2_in ...
     883             : !> \param matrix1_out ...
     884             : !> \param matrix2_out ...
     885             : !> \param new1 Whether matrix1_out is a new matrix or simply pointing to matrix1_in
     886             : !> \param new2 Whether matrix2_out is a new matrix or simply pointing to matrix2_in
     887             : !> \param trans1 transpose flag of matrix1_in for multiplication
     888             : !> \param trans2 transpose flag of matrix2_in for multiplication
     889             : !> \param optimize_dist experimental: optimize matrix splitting and distribution
     890             : !> \param nsplit Optimal split factor (set to 0 if split factor should not be changed)
     891             : !> \param opt_nsplit ...
     892             : !> \param split_rc_1 Whether to split rows or columns for matrix 1
     893             : !> \param split_rc_2 Whether to split rows or columns for matrix 2
     894             : !> \param nodata1 Don't copy matrix data from matrix1_in to matrix1_out
     895             : !> \param nodata2 Don't copy matrix data from matrix2_in to matrix2_out
     896             : !> \param move_data_1 memory optimization: move data such that matrix1_in may be empty on return.
     897             : !> \param move_data_2 memory optimization: move data such that matrix2_in may be empty on return.
     898             : !> \param comm_new returns the new communicator only if optimize_dist
     899             : !> \param unit_nr output unit
     900             : !> \author Patrick Seewald
     901             : ! **************************************************************************************************
     902      125619 :    SUBROUTINE reshape_mm_compatible(matrix1_in, matrix2_in, matrix1_out, matrix2_out, new1, new2, trans1, trans2, &
     903             :                                     optimize_dist, nsplit, opt_nsplit, split_rc_1, split_rc_2, nodata1, nodata2, &
     904             :                                     move_data_1, move_data_2, comm_new, unit_nr)
     905             :       TYPE(dbt_tas_type), INTENT(INOUT), TARGET          :: matrix1_in, matrix2_in
     906             :       TYPE(dbt_tas_type), INTENT(OUT), POINTER           :: matrix1_out, matrix2_out
     907             :       LOGICAL, INTENT(OUT)                               :: new1, new2
     908             :       LOGICAL, INTENT(INOUT)                             :: trans1, trans2
     909             :       LOGICAL, INTENT(IN), OPTIONAL                      :: optimize_dist
     910             :       INTEGER, INTENT(IN), OPTIONAL                      :: nsplit
     911             :       LOGICAL, INTENT(IN), OPTIONAL                      :: opt_nsplit
     912             :       INTEGER, INTENT(INOUT)                             :: split_rc_1, split_rc_2
     913             :       LOGICAL, INTENT(IN), OPTIONAL                      :: nodata1, nodata2
     914             :       LOGICAL, INTENT(INOUT), OPTIONAL                   :: move_data_1, move_data_2
     915             :       TYPE(mp_cart_type), INTENT(OUT), OPTIONAL          :: comm_new
     916             :       INTEGER, INTENT(IN), OPTIONAL                      :: unit_nr
     917             : 
     918             :       CHARACTER(LEN=*), PARAMETER :: routineN = 'reshape_mm_compatible'
     919             : 
     920             :       INTEGER                                            :: handle, nsplit_prv, ref, split_rc_ref, &
     921             :                                                             unit_nr_prv
     922             :       INTEGER(KIND=int_8)                                :: d1, d2, nze1, nze2
     923             :       INTEGER(KIND=int_8), DIMENSION(2)                  :: dims1, dims2, dims_ref
     924             :       INTEGER, DIMENSION(2)                              :: pdims
     925             :       LOGICAL                                            :: nodata1_prv, nodata2_prv, &
     926             :                                                             optimize_dist_prv, trans1_newdist, &
     927             :                                                             trans2_newdist
     928             :       TYPE(dbt_tas_dist_cyclic)                          :: col_dist_1, col_dist_2, row_dist_1, &
     929             :                                                             row_dist_2
     930     1130571 :       TYPE(dbt_tas_distribution_type)                    :: dist_1, dist_2
     931      628095 :       TYPE(dbt_tas_split_info)                           :: split_info
     932      125619 :       TYPE(mp_cart_type)                                 :: mp_comm
     933             : 
     934      125619 :       CALL timeset(routineN, handle)
     935      125619 :       new1 = .FALSE.; new2 = .FALSE.
     936             : 
     937      125619 :       IF (PRESENT(nodata1)) THEN
     938           0 :          nodata1_prv = nodata1
     939             :       ELSE
     940             :          nodata1_prv = .FALSE.
     941             :       END IF
     942             : 
     943      125619 :       IF (PRESENT(nodata2)) THEN
     944       86971 :          nodata2_prv = nodata2
     945             :       ELSE
     946             :          nodata2_prv = .FALSE.
     947             :       END IF
     948             : 
     949      125619 :       unit_nr_prv = prep_output_unit(unit_nr)
     950             : 
     951      125619 :       NULLIFY (matrix1_out, matrix2_out)
     952             : 
     953      125619 :       IF (PRESENT(optimize_dist)) THEN
     954          48 :          optimize_dist_prv = optimize_dist
     955             :       ELSE
     956             :          optimize_dist_prv = .FALSE.
     957             :       END IF
     958             : 
     959      376857 :       dims1 = [dbt_tas_nblkrows_total(matrix1_in), dbt_tas_nblkcols_total(matrix1_in)]
     960      376857 :       dims2 = [dbt_tas_nblkrows_total(matrix2_in), dbt_tas_nblkcols_total(matrix2_in)]
     961      125619 :       nze1 = dbt_tas_get_nze_total(matrix1_in)
     962      125619 :       nze2 = dbt_tas_get_nze_total(matrix2_in)
     963             : 
     964      125619 :       IF (trans1) split_rc_1 = MOD(split_rc_1, 2) + 1
     965             : 
     966      125619 :       IF (trans2) split_rc_2 = MOD(split_rc_2, 2) + 1
     967             : 
     968      125619 :       IF (nze1 >= nze2) THEN
     969      120620 :          ref = 1
     970      120620 :          split_rc_ref = split_rc_1
     971      120620 :          dims_ref = dims1
     972             :       ELSE
     973        4999 :          ref = 2
     974        4999 :          split_rc_ref = split_rc_2
     975        4999 :          dims_ref = dims2
     976             :       END IF
     977             : 
     978      125619 :       IF (PRESENT(nsplit)) THEN
     979      125619 :          nsplit_prv = nsplit
     980             :       ELSE
     981           0 :          nsplit_prv = 0
     982             :       END IF
     983             : 
     984      125619 :       IF (optimize_dist_prv) THEN
     985          48 :          CPASSERT(PRESENT(comm_new))
     986             :       END IF
     987             : 
     988      125619 :       IF ((.NOT. optimize_dist_prv) .AND. dist_compatible(matrix1_in, matrix2_in, split_rc_1, split_rc_2)) THEN
     989             :          CALL change_split(matrix1_in, matrix1_out, nsplit_prv, split_rc_1, new1, &
     990      118461 :                            move_data=move_data_1, nodata=nodata1, opt_nsplit=opt_nsplit)
     991      118461 :          CALL dbt_tas_get_split_info(dbt_tas_info(matrix1_out), nsplit=nsplit_prv)
     992             :          CALL change_split(matrix2_in, matrix2_out, nsplit_prv, split_rc_2, new2, &
     993      118461 :                            move_data=move_data_2, nodata=nodata2, opt_nsplit=.FALSE.)
     994      118461 :          IF (unit_nr_prv > 0) THEN
     995          10 :             WRITE (unit_nr_prv, "(T2,A,1X,A,1X,A,1X,A)") "No redistribution of", &
     996          10 :                TRIM(dbm_get_name(matrix1_in%matrix)), &
     997          20 :                "and", TRIM(dbm_get_name(matrix2_in%matrix))
     998          10 :             IF (new1) THEN
     999           0 :                WRITE (unit_nr_prv, "(T2,A,1X,A,1X,A)") "Change split factor of", &
    1000           0 :                   TRIM(dbm_get_name(matrix1_in%matrix)), ": Yes"
    1001             :             ELSE
    1002          10 :                WRITE (unit_nr_prv, "(T2,A,1X,A,1X,A)") "Change split factor of", &
    1003          20 :                   TRIM(dbm_get_name(matrix1_in%matrix)), ": No"
    1004             :             END IF
    1005          10 :             IF (new2) THEN
    1006           0 :                WRITE (unit_nr_prv, "(T2,A,1X,A,1X,A)") "Change split factor of", &
    1007           0 :                   TRIM(dbm_get_name(matrix2_in%matrix)), ": Yes"
    1008             :             ELSE
    1009          10 :                WRITE (unit_nr_prv, "(T2,A,1X,A,1X,A)") "Change split factor of", &
    1010          20 :                   TRIM(dbm_get_name(matrix2_in%matrix)), ": No"
    1011             :             END IF
    1012             :          END IF
    1013             :       ELSE
    1014             : 
    1015        7158 :          IF (optimize_dist_prv) THEN
    1016          48 :             IF (unit_nr_prv > 0) THEN
    1017          24 :                WRITE (unit_nr_prv, "(T2,A,1X,A,1X,A,1X,A)") "Optimizing distribution of", &
    1018          24 :                   TRIM(dbm_get_name(matrix1_in%matrix)), &
    1019          48 :                   "and", TRIM(dbm_get_name(matrix2_in%matrix))
    1020             :             END IF
    1021             : 
    1022          48 :             trans1_newdist = (split_rc_1 == colsplit)
    1023          48 :             trans2_newdist = (split_rc_2 == colsplit)
    1024             : 
    1025          48 :             IF (trans1_newdist) THEN
    1026          24 :                CALL swap(dims1)
    1027          24 :                trans1 = .NOT. trans1
    1028             :             END IF
    1029             : 
    1030          48 :             IF (trans2_newdist) THEN
    1031          24 :                CALL swap(dims2)
    1032          24 :                trans2 = .NOT. trans2
    1033             :             END IF
    1034             : 
    1035          48 :             IF (nsplit_prv == 0) THEN
    1036           0 :                SELECT CASE (split_rc_ref)
    1037             :                CASE (rowsplit)
    1038           0 :                   d1 = dims_ref(1)
    1039           0 :                   d2 = dims_ref(2)
    1040             :                CASE (colsplit)
    1041           0 :                   d1 = dims_ref(2)
    1042           0 :                   d2 = dims_ref(1)
    1043             :                END SELECT
    1044           0 :                nsplit_prv = INT((d1 - 1)/d2 + 1)
    1045             :             END IF
    1046             : 
    1047          48 :             CPASSERT(nsplit_prv > 0)
    1048             : 
    1049          48 :             CALL dbt_tas_get_split_info(dbt_tas_info(matrix1_in), mp_comm=mp_comm)
    1050          48 :             comm_new = dbt_tas_mp_comm(mp_comm, rowsplit, nsplit_prv)
    1051          48 :             CALL dbt_tas_create_split(split_info, comm_new, rowsplit, nsplit_prv)
    1052             : 
    1053         144 :             pdims = comm_new%num_pe_cart
    1054             : 
    1055             :             ! use a very simple cyclic distribution that may not be load balanced if block
    1056             :             ! sizes are not equal. However we can not use arbitrary distributions
    1057             :             ! for large dimensions since this would require storing distribution vectors as arrays
    1058             :             ! which can not be stored for large dimensions.
    1059          48 :             row_dist_1 = dbt_tas_dist_cyclic(1, pdims(1), dims1(1))
    1060          48 :             col_dist_1 = dbt_tas_dist_cyclic(1, pdims(2), dims1(2))
    1061             : 
    1062          48 :             row_dist_2 = dbt_tas_dist_cyclic(1, pdims(1), dims2(1))
    1063          48 :             col_dist_2 = dbt_tas_dist_cyclic(1, pdims(2), dims2(2))
    1064             : 
    1065          48 :             CALL dbt_tas_distribution_new(dist_1, comm_new, row_dist_1, col_dist_1, split_info=split_info)
    1066          48 :             CALL dbt_tas_distribution_new(dist_2, comm_new, row_dist_2, col_dist_2, split_info=split_info)
    1067          48 :             CALL dbt_tas_release_info(split_info)
    1068             : 
    1069         240 :             ALLOCATE (matrix1_out)
    1070          48 :             IF (.NOT. trans1_newdist) THEN
    1071             :                CALL dbt_tas_create(matrix1_out, dbm_get_name(matrix1_in%matrix), dist_1, &
    1072          24 :                                    matrix1_in%row_blk_size, matrix1_in%col_blk_size, own_dist=.TRUE.)
    1073             : 
    1074             :             ELSE
    1075             :                CALL dbt_tas_create(matrix1_out, dbm_get_name(matrix1_in%matrix), dist_1, &
    1076          24 :                                    matrix1_in%col_blk_size, matrix1_in%row_blk_size, own_dist=.TRUE.)
    1077             :             END IF
    1078             : 
    1079         240 :             ALLOCATE (matrix2_out)
    1080          48 :             IF (.NOT. trans2_newdist) THEN
    1081             :                CALL dbt_tas_create(matrix2_out, dbm_get_name(matrix2_in%matrix), dist_2, &
    1082          24 :                                    matrix2_in%row_blk_size, matrix2_in%col_blk_size, own_dist=.TRUE.)
    1083             :             ELSE
    1084             :                CALL dbt_tas_create(matrix2_out, dbm_get_name(matrix2_in%matrix), dist_2, &
    1085          24 :                                    matrix2_in%col_blk_size, matrix2_in%row_blk_size, own_dist=.TRUE.)
    1086             :             END IF
    1087             : 
    1088          48 :             IF (.NOT. nodata1_prv) CALL dbt_tas_reshape(matrix1_in, matrix1_out, transposed=trans1_newdist, move_data=move_data_1)
    1089          48 :             IF (.NOT. nodata2_prv) CALL dbt_tas_reshape(matrix2_in, matrix2_out, transposed=trans2_newdist, move_data=move_data_2)
    1090          48 :             new1 = .TRUE.
    1091          48 :             new2 = .TRUE.
    1092             : 
    1093             :          ELSE
    1094        6430 :             SELECT CASE (ref)
    1095             :             CASE (1)
    1096        6430 :                IF (unit_nr_prv > 0) THEN
    1097           0 :                   WRITE (unit_nr_prv, "(T2,A,1X,A)") "Redistribution of", &
    1098           0 :                      TRIM(dbm_get_name(matrix2_in%matrix))
    1099             :                END IF
    1100             : 
    1101             :                CALL change_split(matrix1_in, matrix1_out, nsplit_prv, split_rc_1, new1, &
    1102        6430 :                                  move_data=move_data_1, nodata=nodata1, opt_nsplit=opt_nsplit)
    1103             : 
    1104       32150 :                ALLOCATE (matrix2_out)
    1105             :                CALL reshape_mm_template(matrix1_out, matrix2_in, matrix2_out, trans2, split_rc_2, &
    1106        6430 :                                         nodata=nodata2, move_data=move_data_2)
    1107        6430 :                new2 = .TRUE.
    1108             :             CASE (2)
    1109         680 :                IF (unit_nr_prv > 0) THEN
    1110           0 :                   WRITE (unit_nr_prv, "(T2,A,1X,A)") "Redistribution of", &
    1111           0 :                      TRIM(dbm_get_name(matrix1_in%matrix))
    1112             :                END IF
    1113             : 
    1114             :                CALL change_split(matrix2_in, matrix2_out, nsplit_prv, split_rc_2, new2, &
    1115         680 :                                  move_data=move_data_2, nodata=nodata2, opt_nsplit=opt_nsplit)
    1116             : 
    1117        3400 :                ALLOCATE (matrix1_out)
    1118             :                CALL reshape_mm_template(matrix2_out, matrix1_in, matrix1_out, trans1, split_rc_1, &
    1119         680 :                                         nodata=nodata1, move_data=move_data_1)
    1120        7790 :                new1 = .TRUE.
    1121             :             END SELECT
    1122             :          END IF
    1123             :       END IF
    1124             : 
    1125      125619 :       IF (PRESENT(move_data_1) .AND. new1) move_data_1 = .TRUE.
    1126      125619 :       IF (PRESENT(move_data_2) .AND. new2) move_data_2 = .TRUE.
    1127             : 
    1128      125619 :       CALL timestop(handle)
    1129             : 
    1130      376857 :    END SUBROUTINE
    1131             : 
    ! **************************************************************************************************
    !> \brief Change the split (split dimension and/or split factor) of a matrix without redistribution.
    !>        If the requested split is already in place, matrix_out simply points to matrix_in.
    !>        Otherwise a new matrix is allocated that reuses the name, block sizes, process grid and
    !>        row/column distributions of matrix_in but carries the new split info.
    !> \param matrix_in input matrix; it is cleared if a new matrix is created, data is copied and
    !>        move_data is .TRUE. on input
    !> \param matrix_out either a pointer to matrix_in or a newly allocated matrix (see is_new)
    !> \param nsplit new split factor. If 0 and matrix_in is already split along split_rowcol,
    !>        matrix_in is returned unchanged; if 0 and the split dimension differs, a split factor
    !>        of 1 is requested for the new split dimension
    !> \param split_rowcol split rows or columns
    !> \param is_new whether matrix_out is new or a pointer to matrix_in
    !> \param opt_nsplit whether nsplit should be optimized for current process grid
    !>        (forwarded to dbt_tas_create_split)
    !> \param move_data memory optimization: move data such that matrix_in is empty on return.
    !>        On output it is set to .TRUE. whenever a new matrix was created and data was copied.
    !> \param nodata Data of matrix_in should not be copied to matrix_out
    !> \author Patrick Seewald
    ! **************************************************************************************************
    SUBROUTINE change_split(matrix_in, matrix_out, nsplit, split_rowcol, is_new, opt_nsplit, move_data, nodata)
       TYPE(dbt_tas_type), INTENT(INOUT), TARGET          :: matrix_in
       TYPE(dbt_tas_type), INTENT(OUT), POINTER           :: matrix_out
       INTEGER, INTENT(IN)                                :: nsplit, split_rowcol
       LOGICAL, INTENT(OUT)                               :: is_new
       LOGICAL, INTENT(IN), OPTIONAL                      :: opt_nsplit
       LOGICAL, INTENT(INOUT), OPTIONAL                   :: move_data
       LOGICAL, INTENT(IN), OPTIONAL                      :: nodata

       CHARACTER(len=default_string_length)               :: name
       INTEGER                                            :: handle, nsplit_new, nsplit_old, &
                                                             nsplit_prv, split_rc
       LOGICAL                                            :: nodata_prv
       TYPE(dbt_tas_distribution_type)                    :: dist
       TYPE(dbt_tas_split_info)                           :: split_info
       TYPE(mp_cart_type)                                 :: mp_comm

       CLASS(dbt_tas_distribution), ALLOCATABLE :: rdist, cdist
       CLASS(dbt_tas_rowcol_data), ALLOCATABLE  :: rbsize, cbsize
       CHARACTER(LEN=*), PARAMETER                :: routineN = 'change_split'

       NULLIFY (matrix_out)

       is_new = .TRUE.

       CALL dbt_tas_get_split_info(dbt_tas_info(matrix_in), mp_comm=mp_comm, &
                                   split_rowcol=split_rc, nsplit=nsplit_old)

       ! Cheapest exit (taken before the timer is started): no split factor was requested and
       ! the matrix is already split along the requested dimension.
       IF (nsplit == 0) THEN
          IF (split_rowcol == split_rc) THEN
             matrix_out => matrix_in
             is_new = .FALSE.
             RETURN
          ELSE
             ! split dimension has to change but no factor was given: request the trivial factor
             nsplit_prv = 1
          END IF
       ELSE
          nsplit_prv = nsplit
       END IF

       CALL timeset(routineN, handle)

       nodata_prv = .FALSE.
       IF (PRESENT(nodata)) nodata_prv = nodata

       ! Build the candidate split first: dbt_tas_create_split may return a split factor that
       ! differs from nsplit_prv, so compare against what was actually obtained.
       CALL dbt_tas_create_split(split_info, mp_comm, split_rowcol, nsplit_prv, opt_nsplit=opt_nsplit)

       CALL dbt_tas_get_split_info(split_info, nsplit=nsplit_new)

       IF (nsplit_old == nsplit_new .AND. split_rc == split_rowcol) THEN
          ! effective split is unchanged: hand back the input matrix
          matrix_out => matrix_in
          is_new = .FALSE.
          CALL dbt_tas_release_info(split_info)
          CALL timestop(handle)
          RETURN
       END IF

       ! Only now query name, block sizes and distributions: these are copies that are needed
       ! exclusively for creating the new matrix, so the early exits above should not pay for them.
       CALL dbt_tas_get_info(matrix_in, name=name, &
                             row_blk_size=rbsize, col_blk_size=cbsize, &
                             proc_row_dist=rdist, proc_col_dist=cdist)

       CALL dbt_tas_distribution_new(dist, mp_comm, rdist, cdist, &
                                     split_info=split_info)

       CALL dbt_tas_release_info(split_info)

       ALLOCATE (matrix_out)
       CALL dbt_tas_create(matrix_out, name, dist, rbsize, cbsize, own_dist=.TRUE.)

       IF (.NOT. nodata_prv) CALL dbt_tas_copy(matrix_out, matrix_in)

       IF (PRESENT(move_data)) THEN
          IF (.NOT. nodata_prv) THEN
             ! free the input if the caller allowed it, and report that data now lives in matrix_out
             IF (move_data) CALL dbt_tas_clear(matrix_in)
             move_data = .TRUE.
          END IF
       END IF

       CALL timestop(handle)
    END SUBROUTINE
    1224             : 
    ! **************************************************************************************************
    !> \brief Check whether matrices have same distribution and same split.
    !>        Three conditions are tested in order: (1) both matrices are split along the requested
    !>        dimension and that dimension is the same for both, (2) the Cartesian process grids have
    !>        identical dimensions, (3) on every process the local rows (row split) or local columns
    !>        (column split) of both matrices coincide.
    !> \param mat_a first matrix
    !> \param mat_b second matrix
    !> \param split_rc_a split dimension expected for mat_a
    !> \param split_rc_b split dimension expected for mat_b
    !> \param unit_nr output unit for diagnostics (processed by prep_output_unit); the reason for an
    !>        incompatibility is written if the resulting unit is positive
    !> \return .TRUE. only if all three conditions hold
    !> \author Patrick Seewald
    ! **************************************************************************************************
    FUNCTION dist_compatible(mat_a, mat_b, split_rc_a, split_rc_b, unit_nr)
       TYPE(dbt_tas_type), INTENT(IN)                     :: mat_a, mat_b
       INTEGER, INTENT(IN)                                :: split_rc_a, split_rc_b
       INTEGER, INTENT(IN), OPTIONAL                      :: unit_nr
       LOGICAL                                            :: dist_compatible

       INTEGER                                            :: iounit, n_agree, n_ranks, &
                                                             split_stored_a, split_stored_b
       INTEGER(int_8), ALLOCATABLE, DIMENSION(:)          :: my_rowcols_a, my_rowcols_b
       INTEGER, DIMENSION(2)                              :: grid_a, grid_b
       LOGICAL                                            :: same_split
       TYPE(dbt_tas_split_info)                           :: info_a, info_b

       iounit = prep_output_unit(unit_nr)

       ! pessimistic default, only overwritten once every check has passed
       dist_compatible = .FALSE.

       info_a = dbt_tas_info(mat_a)
       info_b = dbt_tas_info(mat_b)
       CALL dbt_tas_get_split_info(info_a, split_rowcol=split_stored_a)
       CALL dbt_tas_get_split_info(info_b, split_rowcol=split_stored_b)

       ! (1) stored split dimensions must match the requested ones, and both requests must agree
       same_split = (split_stored_b == split_rc_b) .AND. &
                    (split_stored_a == split_rc_a) .AND. &
                    (split_rc_a == split_rc_b)
       IF (.NOT. same_split) THEN
          IF (iounit > 0) THEN
             WRITE (iounit, *) "matrix layout a not compatible", split_stored_a, split_rc_a
             WRITE (iounit, *) "matrix layout b not compatible", split_stored_b, split_rc_b
          END IF
          RETURN
       END IF

       ! (2) check if communicators are equivalent
       ! Note: mpi_comm_compare is not sufficient since this does not compare associated Cartesian grids.
       ! It's sufficient to check dimensions of global grid, subgrids will be determined later on (change_split)
       n_ranks = info_b%mp_comm%num_pe
       grid_a = info_a%mp_comm%num_pe_cart
       grid_b = info_b%mp_comm%num_pe_cart
       IF (.NOT. array_eq(grid_a, grid_b)) THEN
          IF (iounit > 0) THEN
             WRITE (iounit, *) "mp dims not compatible:", grid_a, "|", grid_b
          END IF
          RETURN
       END IF

       ! (3) check that distribution is the same by comparing local rows / columns for each matrix
       IF (split_rc_a == rowsplit) THEN
          CALL dbt_tas_get_info(mat_a, local_rows=my_rowcols_a)
          CALL dbt_tas_get_info(mat_b, local_rows=my_rowcols_b)
       ELSE IF (split_rc_a == colsplit) THEN
          CALL dbt_tas_get_info(mat_a, local_cols=my_rowcols_a)
          CALL dbt_tas_get_info(mat_b, local_cols=my_rowcols_b)
       END IF

       ! every process contributes 1 if its local index sets agree; the contributions are then
       ! summed over the communicator of mat_a
       IF (array_eq(my_rowcols_a, my_rowcols_b)) THEN
          n_agree = 1
       ELSE
          n_agree = 0
       END IF

       CALL info_a%mp_comm%sum(n_agree)

       ! compatible only if the sum equals the number of processes
       dist_compatible = (n_agree == n_ranks)

       IF ((.NOT. dist_compatible) .AND. iounit > 0) THEN
          WRITE (iounit, *) "local rowcols not compatible"
          WRITE (iounit, *) "local rowcols A", my_rowcols_a
          WRITE (iounit, *) "local rowcols B", my_rowcols_b
       END IF

    END FUNCTION
    1302             : 
    ! **************************************************************************************************
    !> \brief Reshape matrix_in s.t. it has same process grid, distribution and split as template
    !>        Along the split dimension the distribution of template is reused; along the other
    !>        dimension a distribution is generated with dbt_tas_dist_arb_default from the block
    !>        sizes of matrix_in. If the split dimension of template differs from split_rc,
    !>        matrix_out is created with rows and columns of matrix_in exchanged.
    !> \param template matrix providing process grid, split info and the distribution along the
    !>        split dimension
    !> \param matrix_in matrix to be reshaped
    !> \param matrix_out newly created matrix; left without data if nodata is .TRUE.
    !> \param trans transposition flag of matrix_in, toggled if matrix_out is created transposed
    !> \param split_rc dimension (rowsplit or colsplit) along which matrix_in is to be split
    !> \param nodata Data of matrix_in should not be copied to matrix_out
    !> \param move_data forwarded to dbt_tas_reshape
    !> \author Patrick Seewald
    ! **************************************************************************************************
    SUBROUTINE reshape_mm_template(template, matrix_in, matrix_out, trans, split_rc, nodata, move_data)
       TYPE(dbt_tas_type), INTENT(IN)                     :: template
       TYPE(dbt_tas_type), INTENT(INOUT)                  :: matrix_in
       TYPE(dbt_tas_type), INTENT(OUT)                    :: matrix_out
       LOGICAL, INTENT(INOUT)                             :: trans
       INTEGER, INTENT(IN)                                :: split_rc
       LOGICAL, INTENT(IN), OPTIONAL                      :: nodata, move_data

       CLASS(dbt_tas_distribution), ALLOCATABLE :: row_dist, col_dist

       TYPE(dbt_tas_distribution_type)          :: dist_new
       TYPE(dbt_tas_split_info)                 :: info_template
       INTEGER                                    :: dim_split_template, handle
       INTEGER, DIMENSION(2)                      :: pdims
       LOGICAL                                    :: nodata_prv, transposed
       TYPE(mp_cart_type) :: mp_comm
       CHARACTER(LEN=*), PARAMETER :: routineN = 'reshape_mm_template'

       CALL timeset(routineN, handle)

       IF (PRESENT(nodata)) THEN
          nodata_prv = nodata
       ELSE
          nodata_prv = .FALSE.
       END IF

       info_template = dbt_tas_info(template)

       dim_split_template = info_template%split_rowcol
       ! any other value would leave row_dist / col_dist unallocated below
       CPASSERT(dim_split_template == rowsplit .OR. dim_split_template == colsplit)

       ! matrix_out has to be split like the template; if matrix_in is to be split along the
       ! other dimension, its rows and columns are exchanged and the caller's flag is flipped
       transposed = dim_split_template /= split_rc
       IF (transposed) trans = .NOT. trans

       pdims = info_template%mp_comm%num_pe_cart

       SELECT CASE (dim_split_template)
       CASE (rowsplit)
          ! rows follow the template, columns get a generated distribution
          ALLOCATE (row_dist, source=template%dist%row_dist)
          IF (.NOT. transposed) THEN
             ALLOCATE (col_dist, source=dbt_tas_dist_arb_default(pdims(2), matrix_in%nblkcols, &
                                                                 matrix_in%col_blk_size))
          ELSE
             ! columns of matrix_out are the rows of matrix_in
             ALLOCATE (col_dist, source=dbt_tas_dist_arb_default(pdims(2), matrix_in%nblkrows, &
                                                                 matrix_in%row_blk_size))
          END IF
       CASE (colsplit)
          ! columns follow the template, rows get a generated distribution
          IF (.NOT. transposed) THEN
             ALLOCATE (row_dist, source=dbt_tas_dist_arb_default(pdims(1), matrix_in%nblkrows, &
                                                                 matrix_in%row_blk_size))
          ELSE
             ! rows of matrix_out are the columns of matrix_in
             ALLOCATE (row_dist, source=dbt_tas_dist_arb_default(pdims(1), matrix_in%nblkcols, &
                                                                 matrix_in%col_blk_size))
          END IF
          ALLOCATE (col_dist, source=template%dist%col_dist)
       END SELECT

       CALL dbt_tas_get_split_info(info_template, mp_comm=mp_comm)
       CALL dbt_tas_distribution_new(dist_new, mp_comm, row_dist, col_dist, split_info=info_template)
       IF (.NOT. transposed) THEN
          CALL dbt_tas_create(matrix_out, dbm_get_name(matrix_in%matrix), dist_new, &
                              matrix_in%row_blk_size, matrix_in%col_blk_size, own_dist=.TRUE.)
       ELSE
          ! block sizes are exchanged together with the dimensions
          CALL dbt_tas_create(matrix_out, dbm_get_name(matrix_in%matrix), dist_new, &
                              matrix_in%col_blk_size, matrix_in%row_blk_size, own_dist=.TRUE.)
       END IF

       IF (.NOT. nodata_prv) CALL dbt_tas_reshape(matrix_in, matrix_out, transposed=transposed, move_data=move_data)

       CALL timestop(handle)

    END SUBROUTINE
    1386             : 
    1387             : ! **************************************************************************************************
    1388             : !> \brief Estimate sparsity pattern of C resulting from A x B = C
    1389             : !>         by multiplying the block norms of A and B Same dummy arguments as dbt_tas_multiply
    1390             : !> \param transa ...
    1391             : !> \param transb ...
    1392             : !> \param transc ...
    1393             : !> \param matrix_a ...
    1394             : !> \param matrix_b ...
    1395             : !> \param matrix_c ...
    1396             : !> \param estimated_nze ...
    1397             : !> \param filter_eps ...
    1398             : !> \param unit_nr ...
    1399             : !> \param retain_sparsity ...
    1400             : !> \author Patrick Seewald
    1401             : ! **************************************************************************************************
    1402       21898 :    SUBROUTINE dbt_tas_estimate_result_nze(transa, transb, transc, matrix_a, matrix_b, matrix_c, &
    1403             :                                           estimated_nze, filter_eps, unit_nr, retain_sparsity)
    1404             :       LOGICAL, INTENT(IN)                                :: transa, transb, transc
    1405             :       TYPE(dbt_tas_type), INTENT(INOUT), TARGET          :: matrix_a, matrix_b, matrix_c
    1406             :       INTEGER(int_8), INTENT(OUT)                        :: estimated_nze
    1407             :       REAL(KIND=dp), INTENT(IN), OPTIONAL                :: filter_eps
    1408             :       INTEGER, INTENT(IN), OPTIONAL                      :: unit_nr
    1409             :       LOGICAL, INTENT(IN), OPTIONAL                      :: retain_sparsity
    1410             : 
    1411             :       CHARACTER(LEN=*), PARAMETER :: routineN = 'dbt_tas_estimate_result_nze'
    1412             : 
    1413             :       INTEGER                                            :: col_size, handle, row_size
    1414             :       INTEGER(int_8)                                     :: col, row
    1415             :       LOGICAL                                            :: retain_sparsity_prv
    1416             :       TYPE(dbt_tas_iterator)                             :: iter
    1417             :       TYPE(dbt_tas_type), POINTER                        :: matrix_a_bnorm, matrix_b_bnorm, &
    1418             :                                                             matrix_c_bnorm
    1419       21898 :       TYPE(mp_cart_type)                                 :: mp_comm
    1420             : 
    1421       21898 :       CALL timeset(routineN, handle)
    1422             : 
    1423       21898 :       IF (PRESENT(retain_sparsity)) THEN
    1424         116 :          retain_sparsity_prv = retain_sparsity
    1425             :       ELSE
    1426             :          retain_sparsity_prv = .FALSE.
    1427             :       END IF
    1428             : 
    1429         116 :       IF (.NOT. retain_sparsity_prv) THEN
    1430      413858 :          ALLOCATE (matrix_a_bnorm, matrix_b_bnorm, matrix_c_bnorm)
    1431       21782 :          CALL create_block_norms_matrix(matrix_a, matrix_a_bnorm)
    1432       21782 :          CALL create_block_norms_matrix(matrix_b, matrix_b_bnorm)
    1433       21782 :          CALL create_block_norms_matrix(matrix_c, matrix_c_bnorm, nodata=.TRUE.)
    1434             : 
    1435             :          CALL dbt_tas_multiply(transa, transb, transc, 1.0_dp, matrix_a_bnorm, &
    1436             :                                matrix_b_bnorm, 0.0_dp, matrix_c_bnorm, &
    1437             :                                filter_eps=filter_eps, move_data_a=.TRUE., move_data_b=.TRUE., &
    1438       21782 :                                simple_split=.TRUE., unit_nr=unit_nr)
    1439       21782 :          CALL dbt_tas_destroy(matrix_a_bnorm)
    1440       21782 :          CALL dbt_tas_destroy(matrix_b_bnorm)
    1441             : 
    1442       21782 :          DEALLOCATE (matrix_a_bnorm, matrix_b_bnorm)
    1443             :       ELSE
    1444             :          matrix_c_bnorm => matrix_c
    1445             :       END IF
    1446             : 
    1447       21898 :       estimated_nze = 0
    1448             : !$OMP PARALLEL DEFAULT(NONE) REDUCTION(+:estimated_nze) SHARED(matrix_c_bnorm,matrix_c) &
    1449       21898 : !$OMP PRIVATE(iter,row,col,row_size,col_size)
    1450             :       CALL dbt_tas_iterator_start(iter, matrix_c_bnorm)
    1451             :       DO WHILE (dbt_tas_iterator_blocks_left(iter))
    1452             :          CALL dbt_tas_iterator_next_block(iter, row, col)
    1453             :          row_size = matrix_c%row_blk_size%data(row)
    1454             :          col_size = matrix_c%col_blk_size%data(col)
    1455             :          estimated_nze = estimated_nze + row_size*col_size
    1456             :       END DO
    1457             :       CALL dbt_tas_iterator_stop(iter)
    1458             : !$OMP END PARALLEL
    1459             : 
    1460       21898 :       CALL dbt_tas_get_split_info(dbt_tas_info(matrix_a), mp_comm=mp_comm)
    1461       21898 :       CALL mp_comm%sum(estimated_nze)
    1462             : 
    1463       21898 :       IF (.NOT. retain_sparsity_prv) THEN
    1464       21782 :          CALL dbt_tas_destroy(matrix_c_bnorm)
    1465       21782 :          DEALLOCATE (matrix_c_bnorm)
    1466             :       END IF
    1467             : 
    1468       21898 :       CALL timestop(handle)
    1469             : 
    1470       43796 :    END SUBROUTINE
    1471             : 
    1472             : ! **************************************************************************************************
    1473             : !> \brief Estimate optimal split factor for AxB=C from occupancies (number of non-zero elements)
    1474             : !>        This estimate is based on the minimization of communication volume whereby the
    1475             : !>        communication of CARMA n-split step and CANNON-multiplication of submatrices are considered.
    1476             : !> \param max_mm_dim ...
    1477             : !> \param nze_a number of non-zeroes in A
    1478             : !> \param nze_b number of non-zeroes in B
    1479             : !> \param nze_c number of non-zeroes in C
    1480             : !> \param numnodes number of MPI ranks
    1481             : !> \return estimated split factor
    1482             : !> \author Patrick Seewald
    1483             : ! **************************************************************************************************
    1484      125687 :    FUNCTION split_factor_estimate(max_mm_dim, nze_a, nze_b, nze_c, numnodes) RESULT(nsplit)
    1485             :       INTEGER, INTENT(IN)                                :: max_mm_dim
    1486             :       INTEGER(KIND=int_8), INTENT(IN)                    :: nze_a, nze_b, nze_c
    1487             :       INTEGER, INTENT(IN)                                :: numnodes
    1488             :       INTEGER                                            :: nsplit
    1489             : 
    1490             :       INTEGER(KIND=int_8)                                :: max_nze, min_nze
    1491             :       REAL(dp)                                           :: s_opt_factor
    1492             : 
    1493      125687 :       s_opt_factor = 1.0_dp ! Could be further tuned.
    1494             : 
    1495      162753 :       SELECT CASE (max_mm_dim)
    1496             :       CASE (1)
    1497       37066 :          min_nze = MAX(nze_b, 1_int_8)
    1498      111198 :          max_nze = MAX(MAXVAL([nze_a, nze_c]), 1_int_8)
    1499             :       CASE (2)
    1500       38632 :          min_nze = MAX(nze_c, 1_int_8)
    1501      115896 :          max_nze = MAX(MAXVAL([nze_a, nze_b]), 1_int_8)
    1502             :       CASE (3)
    1503       49989 :          min_nze = MAX(nze_a, 1_int_8)
    1504      149967 :          max_nze = MAX(MAXVAL([nze_b, nze_c]), 1_int_8)
    1505             :       CASE DEFAULT
    1506      125687 :          CPABORT("")
    1507             :       END SELECT
    1508             : 
    1509      125687 :       nsplit = INT(MIN(INT(numnodes, KIND=int_8), NINT(REAL(max_nze, dp)/(REAL(min_nze, dp)*s_opt_factor), KIND=int_8)))
    1510      125687 :       IF (nsplit == 0) nsplit = 1
    1511             : 
    1512      125687 :    END FUNCTION
    1513             : 
    1514             : ! **************************************************************************************************
    1515             : !> \brief Create a matrix with block sizes one that contains the block norms of matrix_in
    1516             : !> \param matrix_in ...
    1517             : !> \param matrix_out ...
    1518             : !> \param nodata ...
    1519             : !> \author Patrick Seewald
    1520             : ! **************************************************************************************************
    1521      326730 :    SUBROUTINE create_block_norms_matrix(matrix_in, matrix_out, nodata)
    1522             :       TYPE(dbt_tas_type), INTENT(INOUT)                  :: matrix_in
    1523             :       TYPE(dbt_tas_type), INTENT(OUT)                    :: matrix_out
    1524             :       LOGICAL, INTENT(IN), OPTIONAL                      :: nodata
    1525             : 
    1526             :       CHARACTER(len=default_string_length)               :: name
    1527             :       INTEGER(KIND=int_8)                                :: column, nblkcols, nblkrows, row
    1528             :       LOGICAL                                            :: nodata_prv
    1529             :       REAL(dp), DIMENSION(1, 1)                          :: blk_put
    1530       65346 :       REAL(dp), DIMENSION(:, :), POINTER                 :: blk_get
    1531             :       TYPE(dbt_tas_blk_size_one)                         :: col_blk_size, row_blk_size
    1532             :       TYPE(dbt_tas_iterator)                             :: iter
    1533             : 
    1534             : !REAL(dp), DIMENSION(:, :), POINTER        :: dbt_put
    1535             : 
    1536       65346 :       CPASSERT(matrix_in%valid)
    1537             : 
    1538       65346 :       IF (PRESENT(nodata)) THEN
    1539       21782 :          nodata_prv = nodata
    1540             :       ELSE
    1541             :          nodata_prv = .FALSE.
    1542             :       END IF
    1543             : 
    1544       65346 :       CALL dbt_tas_get_info(matrix_in, name=name, nblkrows_total=nblkrows, nblkcols_total=nblkcols)
    1545       65346 :       row_blk_size = dbt_tas_blk_size_one(nblkrows)
    1546       65346 :       col_blk_size = dbt_tas_blk_size_one(nblkcols)
    1547             : 
    1548             :       ! not sure if assumption that same distribution can be taken still holds
    1549       65346 :       CALL dbt_tas_create(matrix_out, name, matrix_in%dist, row_blk_size, col_blk_size)
    1550             : 
    1551       65346 :       IF (.NOT. nodata_prv) THEN
    1552       43564 :          CALL dbt_tas_reserve_blocks(matrix_in, matrix_out)
    1553             : !$OMP PARALLEL DEFAULT(NONE) SHARED(matrix_in,matrix_out) &
    1554       43564 : !$OMP PRIVATE(iter,row,column,blk_get,blk_put)
    1555             :          CALL dbt_tas_iterator_start(iter, matrix_in)
    1556             :          DO WHILE (dbt_tas_iterator_blocks_left(iter))
    1557             :             CALL dbt_tas_iterator_next_block(iter, row, column, blk_get)
    1558             :             blk_put(1, 1) = NORM2(blk_get)
    1559             :             CALL dbt_tas_put_block(matrix_out, row, column, blk_put)
    1560             :          END DO
    1561             :          CALL dbt_tas_iterator_stop(iter)
    1562             : !$OMP END PARALLEL
    1563             :       END IF
    1564             : 
    1565       65346 :    END SUBROUTINE
    1566             : 
    1567             : ! **************************************************************************************************
    1568             : !> \brief Convert a DBM matrix to a new process grid
    1569             : !> \param mp_comm_cart new process grid
    1570             : !> \param matrix_in ...
    1571             : !> \param matrix_out ...
    1572             : !> \param move_data memory optimization: move data such that matrix_in is empty on return.
    1573             : !> \param nodata Data of matrix_in should not be copied to matrix_out
    1574             : !> \param optimize_pgrid Whether to change process grid
    1575             : !> \author Patrick Seewald
    1576             : ! **************************************************************************************************
    1577      376857 :    SUBROUTINE convert_to_new_pgrid(mp_comm_cart, matrix_in, matrix_out, move_data, nodata, optimize_pgrid)
    1578             :       TYPE(mp_cart_type), INTENT(IN)                     :: mp_comm_cart
    1579             :       TYPE(dbm_type), INTENT(INOUT)                      :: matrix_in
    1580             :       TYPE(dbm_type), INTENT(OUT)                        :: matrix_out
    1581             :       LOGICAL, INTENT(IN), OPTIONAL                      :: move_data, nodata, optimize_pgrid
    1582             : 
    1583             :       CHARACTER(LEN=*), PARAMETER :: routineN = 'convert_to_new_pgrid'
    1584             : 
    1585             :       CHARACTER(len=default_string_length)               :: name
    1586             :       INTEGER                                            :: handle, nbcols, nbrows
    1587      376857 :       INTEGER, CONTIGUOUS, DIMENSION(:), POINTER         :: col_dist, rbsize, rcsize, row_dist
    1588             :       INTEGER, DIMENSION(2)                              :: pdims
    1589             :       LOGICAL                                            :: nodata_prv, optimize_pgrid_prv
    1590             :       TYPE(dbm_distribution_obj)                         :: dist, dist_old
    1591             : 
    1592      376857 :       NULLIFY (row_dist, col_dist, rbsize, rcsize)
    1593             : 
    1594      376857 :       CALL timeset(routineN, handle)
    1595             : 
    1596      376857 :       IF (PRESENT(optimize_pgrid)) THEN
    1597      376857 :          optimize_pgrid_prv = optimize_pgrid
    1598             :       ELSE
    1599             :          optimize_pgrid_prv = .TRUE.
    1600             :       END IF
    1601             : 
    1602      376857 :       IF (PRESENT(nodata)) THEN
    1603      125619 :          nodata_prv = nodata
    1604             :       ELSE
    1605             :          nodata_prv = .FALSE.
    1606             :       END IF
    1607             : 
    1608      376857 :       name = dbm_get_name(matrix_in)
    1609             : 
    1610      376857 :       IF (.NOT. optimize_pgrid_prv) THEN
    1611      376857 :          CALL dbm_create_from_template(matrix_out, name=name, template=matrix_in)
    1612      376857 :          IF (.NOT. nodata_prv) CALL dbm_copy(matrix_out, matrix_in)
    1613      376857 :          CALL timestop(handle)
    1614      376857 :          RETURN
    1615             :       END IF
    1616             : 
    1617           0 :       rbsize => dbm_get_row_block_sizes(matrix_in)
    1618           0 :       rcsize => dbm_get_col_block_sizes(matrix_in)
    1619           0 :       nbrows = SIZE(rbsize)
    1620           0 :       nbcols = SIZE(rcsize)
    1621           0 :       dist_old = dbm_get_distribution(matrix_in)
    1622           0 :       pdims = mp_comm_cart%num_pe_cart
    1623             : 
    1624           0 :       ALLOCATE (row_dist(nbrows), col_dist(nbcols))
    1625           0 :       CALL dbt_tas_default_distvec(nbrows, pdims(1), rbsize, row_dist)
    1626           0 :       CALL dbt_tas_default_distvec(nbcols, pdims(2), rcsize, col_dist)
    1627             : 
    1628           0 :       CALL dbm_distribution_new(dist, mp_comm_cart, row_dist, col_dist)
    1629           0 :       DEALLOCATE (row_dist, col_dist)
    1630             : 
    1631           0 :       CALL dbm_create(matrix_out, name, dist, rbsize, rcsize)
    1632           0 :       CALL dbm_distribution_release(dist)
    1633             : 
    1634           0 :       IF (.NOT. nodata_prv) THEN
    1635           0 :          CALL dbm_redistribute(matrix_in, matrix_out)
    1636           0 :          IF (PRESENT(move_data)) THEN
    1637           0 :             IF (move_data) CALL dbm_clear(matrix_in)
    1638             :          END IF
    1639             :       END IF
    1640             : 
    1641           0 :       CALL timestop(handle)
    1642      376857 :    END SUBROUTINE
    1643             : 
    1644             : ! **************************************************************************************************
    1645             : !> \brief ...
    1646             : !> \param matrix ...
    1647             : !> \author Patrick Seewald
    1648             : ! **************************************************************************************************
    1649       68117 :    SUBROUTINE dbt_tas_batched_mm_init(matrix)
    1650             :       TYPE(dbt_tas_type), INTENT(INOUT)                  :: matrix
    1651             : 
    1652       68117 :       CALL dbt_tas_set_batched_state(matrix, state=1)
    1653       68117 :       ALLOCATE (matrix%mm_storage)
    1654             :       matrix%mm_storage%batched_out = .FALSE.
    1655       68117 :    END SUBROUTINE
    1656             : 
    1657             : ! **************************************************************************************************
    1658             : !> \brief ...
    1659             : !> \param matrix ...
    1660             : !> \author Patrick Seewald
    1661             : ! **************************************************************************************************
    1662      136234 :    SUBROUTINE dbt_tas_batched_mm_finalize(matrix)
    1663             :       TYPE(dbt_tas_type), INTENT(INOUT)                  :: matrix
    1664             : 
    1665             :       INTEGER                                            :: handle
    1666             : 
    1667       68117 :       CALL matrix%dist%info%mp_comm%sync()
    1668       68117 :       CALL timeset("dbt_tas_total", handle)
    1669             : 
    1670       68117 :       IF (matrix%do_batched == 0) RETURN
    1671             : 
    1672       68117 :       IF (matrix%mm_storage%batched_out) THEN
    1673       24534 :          CALL dbm_scale(matrix%matrix, matrix%mm_storage%batched_beta)
    1674             :       END IF
    1675             : 
    1676       68117 :       CALL dbt_tas_batched_mm_complete(matrix)
    1677             : 
    1678       68117 :       matrix%mm_storage%batched_out = .FALSE.
    1679             : 
    1680       68117 :       DEALLOCATE (matrix%mm_storage)
    1681       68117 :       CALL dbt_tas_set_batched_state(matrix, state=0)
    1682             : 
    1683       68117 :       CALL matrix%dist%info%mp_comm%sync()
    1684       68117 :       CALL timestop(handle)
    1685             : 
    1686             :    END SUBROUTINE
    1687             : 
    1688             : ! **************************************************************************************************
    1689             : !> \brief set state flags during batched multiplication
    1690             : !> \param matrix ...
    1691             : !> \param state 0 no batched MM
    1692             : !>              1 batched MM but mm_storage not yet initialized
    1693             : !>              2 batched MM and mm_storage requires update
    1694             : !>              3 batched MM and mm_storage initialized
    1695             : !> \param opt_grid whether process grid was already optimized and should not be changed
    1696             : !> \author Patrick Seewald
    1697             : ! **************************************************************************************************
    1698      956837 :    SUBROUTINE dbt_tas_set_batched_state(matrix, state, opt_grid)
    1699             :       TYPE(dbt_tas_type), INTENT(INOUT)                  :: matrix
    1700             :       INTEGER, INTENT(IN), OPTIONAL                      :: state
    1701             :       LOGICAL, INTENT(IN), OPTIONAL                      :: opt_grid
    1702             : 
    1703      956837 :       IF (PRESENT(opt_grid)) THEN
    1704      691618 :          matrix%has_opt_pgrid = opt_grid
    1705      691618 :          matrix%dist%info%strict_split(1) = .TRUE.
    1706             :       END IF
    1707             : 
    1708      956837 :       IF (PRESENT(state)) THEN
    1709      717086 :          matrix%do_batched = state
    1710      486471 :          SELECT CASE (state)
    1711             :          CASE (0, 1)
    1712             :             ! reset to default
    1713      486471 :             IF (matrix%has_opt_pgrid) THEN
    1714      349590 :                matrix%dist%info%strict_split(1) = .TRUE.
    1715             :             ELSE
    1716      136881 :                matrix%dist%info%strict_split(1) = matrix%dist%info%strict_split(2)
    1717             :             END IF
    1718             :          CASE (2, 3)
    1719      230615 :             matrix%dist%info%strict_split(1) = .TRUE.
    1720             :          CASE DEFAULT
    1721      717086 :             CPABORT("should not happen")
    1722             :          END SELECT
    1723             :       END IF
    1724      956837 :    END SUBROUTINE
    1725             : 
    1726             : ! **************************************************************************************************
    1727             : !> \brief ...
    1728             : !> \param matrix ...
    1729             : !> \param warn ...
    1730             : !> \author Patrick Seewald
    1731             : ! **************************************************************************************************
    1732      780209 :    SUBROUTINE dbt_tas_batched_mm_complete(matrix, warn)
    1733             :       TYPE(dbt_tas_type), INTENT(INOUT)                  :: matrix
    1734             :       LOGICAL, INTENT(IN), OPTIONAL                      :: warn
    1735             : 
    1736      780209 :       IF (matrix%do_batched == 0) RETURN
    1737             :       ASSOCIATE (storage => matrix%mm_storage)
    1738       71721 :          IF (PRESENT(warn)) THEN
    1739        1840 :             IF (warn .AND. matrix%do_batched == 3) THEN
    1740             :                CALL cp_warn(__LOCATION__, &
    1741           0 :                             "Optimizations for batched multiplication are disabled because of conflicting data access")
    1742             :             END IF
    1743             :          END IF
    1744       71721 :          IF (storage%batched_out .AND. matrix%do_batched == 3) THEN
    1745             : 
    1746             :             CALL dbt_tas_merge(storage%store_batched%matrix, &
    1747       25690 :                                storage%store_batched_repl, move_data=.TRUE.)
    1748             : 
    1749             :             CALL dbt_tas_reshape(storage%store_batched, matrix, summation=.TRUE., &
    1750       25690 :                                  transposed=storage%batched_trans, move_data=.TRUE.)
    1751       25690 :             CALL dbt_tas_destroy(storage%store_batched)
    1752       25690 :             DEALLOCATE (storage%store_batched)
    1753             :          END IF
    1754             : 
    1755      143442 :          IF (ASSOCIATED(storage%store_batched_repl)) THEN
    1756       57264 :             CALL dbt_tas_destroy(storage%store_batched_repl)
    1757       57264 :             DEALLOCATE (storage%store_batched_repl)
    1758             :          END IF
    1759             :       END ASSOCIATE
    1760             : 
    1761       71721 :       CALL dbt_tas_set_batched_state(matrix, state=2)
    1762             : 
    1763             :    END SUBROUTINE
    1764             : 
    1765      258300 : END MODULE

Generated by: LCOV version 1.15