LCOV - code coverage report
Current view: top level - src/dbt/tas - dbt_tas_mm.F (source / functions) Coverage Total Hit
Test: CP2K Regtests (git:42dac4a) Lines: 92.7 % 742 688
Test Date: 2025-07-25 12:55:17 Functions: 100.0 % 15 15

            Line data    Source code
       1              : !--------------------------------------------------------------------------------------------------!
       2              : !   CP2K: A general program to perform molecular dynamics simulations                              !
       3              : !   Copyright 2000-2025 CP2K developers group <https://cp2k.org>                                   !
       4              : !                                                                                                  !
       5              : !   SPDX-License-Identifier: GPL-2.0-or-later                                                      !
       6              : !--------------------------------------------------------------------------------------------------!
       7              : 
       8              : ! **************************************************************************************************
       9              : !> \brief Matrix multiplication for tall-and-skinny matrices.
      10              : !>        This uses the k-split (non-recursive) CARMA algorithm that is communication-optimal
      11              : !>        as long as the two smaller dimensions have the same size.
      12              : !>        Submatrices are obtained by splitting a dimension of the process grid. Multiplication of
       13              : !>        submatrices uses the DBM Cannon algorithm. Because the sparsity pattern of the result
       14              : !>        matrix is unknown, parameters (group sizes and process grid dimensions) cannot be
       15              : !>        derived from matrix dimensions and need to be set manually.
      16              : !> \author Patrick Seewald
      17              : ! **************************************************************************************************
      18              : MODULE dbt_tas_mm
      19              :    USE dbm_api,                         ONLY: &
      20              :         dbm_add, dbm_clear, dbm_copy, dbm_create, dbm_create_from_template, dbm_distribution_new, &
      21              :         dbm_distribution_obj, dbm_distribution_release, dbm_get_col_block_sizes, &
      22              :         dbm_get_distribution, dbm_get_name, dbm_get_nze, dbm_get_row_block_sizes, dbm_multiply, &
      23              :         dbm_redistribute, dbm_release, dbm_scale, dbm_type, dbm_zero
      24              :    USE dbt_tas_base,                    ONLY: &
      25              :         dbt_tas_clear, dbt_tas_copy, dbt_tas_create, dbt_tas_destroy, dbt_tas_distribution_new, &
      26              :         dbt_tas_filter, dbt_tas_get_info, dbt_tas_get_nze_total, dbt_tas_info, &
      27              :         dbt_tas_iterator_blocks_left, dbt_tas_iterator_next_block, dbt_tas_iterator_start, &
      28              :         dbt_tas_iterator_stop, dbt_tas_nblkcols_total, dbt_tas_nblkrows_total, dbt_tas_put_block, &
      29              :         dbt_tas_reserve_blocks
      30              :    USE dbt_tas_global,                  ONLY: dbt_tas_blk_size_one,&
      31              :                                               dbt_tas_default_distvec,&
      32              :                                               dbt_tas_dist_arb,&
      33              :                                               dbt_tas_dist_arb_default,&
      34              :                                               dbt_tas_dist_cyclic,&
      35              :                                               dbt_tas_distribution,&
      36              :                                               dbt_tas_rowcol_data
      37              :    USE dbt_tas_io,                      ONLY: dbt_tas_write_dist,&
      38              :                                               dbt_tas_write_matrix_info,&
      39              :                                               dbt_tas_write_split_info,&
      40              :                                               prep_output_unit
      41              :    USE dbt_tas_reshape_ops,             ONLY: dbt_tas_merge,&
      42              :                                               dbt_tas_replicate,&
      43              :                                               dbt_tas_reshape
      44              :    USE dbt_tas_split,                   ONLY: &
      45              :         accept_pgrid_dims, colsplit, dbt_tas_create_split, dbt_tas_get_split_info, &
      46              :         dbt_tas_info_hold, dbt_tas_mp_comm, dbt_tas_release_info, default_nsplit_accept_ratio, &
      47              :         rowsplit
      48              :    USE dbt_tas_types,                   ONLY: dbt_tas_distribution_type,&
      49              :                                               dbt_tas_iterator,&
      50              :                                               dbt_tas_split_info,&
      51              :                                               dbt_tas_type
      52              :    USE dbt_tas_util,                    ONLY: array_eq,&
      53              :                                               swap
      54              :    USE kinds,                           ONLY: default_string_length,&
      55              :                                               dp,&
      56              :                                               int_8
      57              :    USE message_passing,                 ONLY: mp_cart_type
      58              : #include "../../base/base_uses.f90"
      59              : 
      60              :    IMPLICIT NONE
      61              :    PRIVATE
      62              : 
      63              :    CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'dbt_tas_mm'
      64              : 
      65              :    PUBLIC :: &
      66              :       dbt_tas_multiply, &
      67              :       dbt_tas_batched_mm_init, &
      68              :       dbt_tas_batched_mm_finalize, &
      69              :       dbt_tas_set_batched_state, &
      70              :       dbt_tas_batched_mm_complete
      71              : 
      72              : CONTAINS
      73              : 
      74              : ! **************************************************************************************************
      75              : !> \brief tall-and-skinny matrix-matrix multiplication. Undocumented dummy arguments are identical
      76              : !>        to arguments of dbm_multiply (see dbm_mm, dbm_multiply_generic).
      77              : !> \param transa ...
      78              : !> \param transb ...
      79              : !> \param transc ...
      80              : !> \param alpha ...
      81              : !> \param matrix_a ...
      82              : !> \param matrix_b ...
      83              : !> \param beta ...
      84              : !> \param matrix_c ...
      85              : !> \param optimize_dist Whether distribution should be optimized internally. In the current
      86              : !>                      implementation this guarantees optimal parameters only for dense matrices.
      87              : !> \param split_opt optionally return split info containing optimal grid and split parameters.
      88              : !>                  This can be used to choose optimal process grids for subsequent matrix
      89              : !>                  multiplications with matrices of similar shape and sparsity.
      90              : !> \param filter_eps ...
      91              : !> \param flop ...
      92              : !> \param move_data_a memory optimization: move data to matrix_c such that matrix_a is empty on return
      93              : !>                   (for internal use only)
      94              : !> \param move_data_b memory optimization: move data to matrix_c such that matrix_b is empty on return
      95              : !>                   (for internal use only)
      96              : !> \param retain_sparsity ...
      97              : !> \param simple_split ...
      98              : !> \param unit_nr unit number for logging output
      99              : !> \param log_verbose only for testing: verbose output
     100              : !> \author Patrick Seewald
     101              : ! **************************************************************************************************
     102       910896 :    RECURSIVE SUBROUTINE dbt_tas_multiply(transa, transb, transc, alpha, matrix_a, matrix_b, beta, matrix_c, &
     103              :                                          optimize_dist, split_opt, filter_eps, flop, move_data_a, &
     104              :                                          move_data_b, retain_sparsity, simple_split, unit_nr, log_verbose)
     105              : 
     106              :       LOGICAL, INTENT(IN)                                :: transa, transb, transc
     107              :       REAL(dp), INTENT(IN)                               :: alpha
     108              :       TYPE(dbt_tas_type), INTENT(INOUT), TARGET          :: matrix_a, matrix_b
     109              :       REAL(dp), INTENT(IN)                               :: beta
     110              :       TYPE(dbt_tas_type), INTENT(INOUT), TARGET          :: matrix_c
     111              :       LOGICAL, INTENT(IN), OPTIONAL                      :: optimize_dist
     112              :       TYPE(dbt_tas_split_info), INTENT(OUT), OPTIONAL    :: split_opt
     113              :       REAL(KIND=dp), INTENT(IN), OPTIONAL                :: filter_eps
     114              :       INTEGER(KIND=int_8), INTENT(OUT), OPTIONAL         :: flop
     115              :       LOGICAL, INTENT(IN), OPTIONAL                      :: move_data_a, move_data_b, &
     116              :                                                             retain_sparsity, simple_split
     117              :       INTEGER, INTENT(IN), OPTIONAL                      :: unit_nr
     118              :       LOGICAL, INTENT(IN), OPTIONAL                      :: log_verbose
     119              : 
     120              :       CHARACTER(LEN=*), PARAMETER                        :: routineN = 'dbt_tas_multiply'
     121              : 
     122              :       INTEGER :: batched_repl, handle, handle2, handle3, handle4, max_mm_dim, max_mm_dim_batched, &
     123              :          nsplit, nsplit_batched, nsplit_opt, numproc, split_a, split_b, split_c, split_rc, &
     124              :          unit_nr_prv
     125              :       INTEGER(KIND=int_8)                                :: nze_a, nze_b, nze_c, nze_c_sum
     126              :       INTEGER(KIND=int_8), DIMENSION(2)                  :: dims_a, dims_b, dims_c
     127              :       INTEGER(KIND=int_8), DIMENSION(3)                  :: dims
     128              :       INTEGER, DIMENSION(2)                              :: pdims, pdims_sub
     129              :       LOGICAL :: do_batched, move_a, move_b, new_a, new_b, new_c, nodata_3, opt_pgrid, &
     130              :          simple_split_prv, tr_case, transa_prv, transb_prv, transc_prv
     131              :       REAL(KIND=dp)                                      :: filter_eps_prv
     132              :       TYPE(dbm_type)                                     :: matrix_a_mm, matrix_b_mm, matrix_c_mm
     133      3543786 :       TYPE(dbt_tas_split_info)                           :: info, info_a, info_b, info_c
     134              :       TYPE(dbt_tas_type), POINTER                        :: matrix_a_rep, matrix_a_rs, matrix_b_rep, &
     135              :                                                             matrix_b_rs, matrix_c_rep, matrix_c_rs
     136       208458 :       TYPE(mp_cart_type)                                 :: comm_tmp, mp_comm, mp_comm_group, &
     137       208458 :                                                             mp_comm_mm, mp_comm_opt
     138              : 
     139       208458 :       CALL timeset(routineN, handle)
     140       208458 :       CALL matrix_a%dist%info%mp_comm%sync()
     141       208458 :       CALL timeset("dbt_tas_total", handle2)
     142              : 
     143       208458 :       NULLIFY (matrix_b_rs, matrix_a_rs, matrix_c_rs)
     144              : 
     145       208458 :       unit_nr_prv = prep_output_unit(unit_nr)
     146              : 
     147       208458 :       IF (PRESENT(simple_split)) THEN
     148        56594 :          simple_split_prv = simple_split
     149              :       ELSE
     150       151864 :          simple_split_prv = .FALSE.
     151              : 
     152       455592 :          info_a = dbt_tas_info(matrix_a); info_b = dbt_tas_info(matrix_b); info_c = dbt_tas_info(matrix_c)
     153       151864 :          IF (info_a%strict_split(1) .OR. info_b%strict_split(1) .OR. info_c%strict_split(1)) simple_split_prv = .TRUE.
     154              :       END IF
     155              : 
     156       208458 :       nodata_3 = .TRUE.
     157       208458 :       IF (PRESENT(retain_sparsity)) THEN
     158         4794 :          IF (retain_sparsity) nodata_3 = .FALSE.
     159              :       END IF
     160              : 
     161              :       ! get prestored info for multiplication strategy in case of batched mm
     162       208458 :       batched_repl = 0
     163       208458 :       do_batched = .FALSE.
     164       208458 :       IF (matrix_a%do_batched > 0) THEN
     165        49248 :          do_batched = .TRUE.
     166        49248 :          IF (matrix_a%do_batched == 3) THEN
     167              :             CPASSERT(batched_repl == 0)
     168        17746 :             batched_repl = 1
     169              :             CALL dbt_tas_get_split_info( &
     170              :                dbt_tas_info(matrix_a%mm_storage%store_batched_repl), &
     171        17746 :                nsplit=nsplit_batched)
     172        17746 :             CPASSERT(nsplit_batched > 0)
     173              :             max_mm_dim_batched = 3
     174              :          END IF
     175              :       END IF
     176              : 
     177       208458 :       IF (matrix_b%do_batched > 0) THEN
     178        15182 :          do_batched = .TRUE.
     179        15182 :          IF (matrix_b%do_batched == 3) THEN
     180         2816 :             CPASSERT(batched_repl == 0)
     181         2816 :             batched_repl = 2
     182              :             CALL dbt_tas_get_split_info( &
     183              :                dbt_tas_info(matrix_b%mm_storage%store_batched_repl), &
     184         2816 :                nsplit=nsplit_batched)
     185         2816 :             CPASSERT(nsplit_batched > 0)
     186              :             max_mm_dim_batched = 1
     187              :          END IF
     188              :       END IF
     189              : 
     190       208458 :       IF (matrix_c%do_batched > 0) THEN
     191        37992 :          do_batched = .TRUE.
     192        37992 :          IF (matrix_c%do_batched == 3) THEN
     193         8964 :             CPASSERT(batched_repl == 0)
     194         8964 :             batched_repl = 3
     195              :             CALL dbt_tas_get_split_info( &
     196              :                dbt_tas_info(matrix_c%mm_storage%store_batched_repl), &
     197         8964 :                nsplit=nsplit_batched)
     198         8964 :             CPASSERT(nsplit_batched > 0)
     199              :             max_mm_dim_batched = 2
     200              :          END IF
     201              :       END IF
     202              : 
     203       208458 :       move_a = .FALSE.
     204       208458 :       move_b = .FALSE.
     205              : 
     206       208458 :       IF (PRESENT(move_data_a)) move_a = move_data_a
     207       208458 :       IF (PRESENT(move_data_b)) move_b = move_data_b
     208              : 
     209       208458 :       transa_prv = transa; transb_prv = transb; transc_prv = transc
     210              : 
     211       625374 :       dims_a = [dbt_tas_nblkrows_total(matrix_a), dbt_tas_nblkcols_total(matrix_a)]
     212       625374 :       dims_b = [dbt_tas_nblkrows_total(matrix_b), dbt_tas_nblkcols_total(matrix_b)]
     213       625374 :       dims_c = [dbt_tas_nblkrows_total(matrix_c), dbt_tas_nblkcols_total(matrix_c)]
     214              : 
     215       208458 :       IF (unit_nr_prv > 0) THEN
     216           34 :          WRITE (unit_nr_prv, "(A)") REPEAT("-", 80)
     217              :          WRITE (unit_nr_prv, "(A)") &
     218              :             "DBT TAS MATRIX MULTIPLICATION: "// &
     219              :             TRIM(dbm_get_name(matrix_a%matrix))//" x "// &
     220              :             TRIM(dbm_get_name(matrix_b%matrix))//" = "// &
     221           34 :             TRIM(dbm_get_name(matrix_c%matrix))
     222           34 :          WRITE (unit_nr_prv, "(A)") REPEAT("-", 80)
     223              :       END IF
     224       208458 :       IF (do_batched) THEN
     225        98686 :          IF (unit_nr_prv > 0) THEN
     226              :             WRITE (unit_nr_prv, "(T2,A)") &
     227            0 :                "BATCHED PROCESSING OF MATMUL"
     228            0 :             IF (batched_repl > 0) THEN
     229            0 :                WRITE (unit_nr_prv, "(T4,A,T80,I1)") "reusing replicated matrix:", batched_repl
     230              :             END IF
     231              :          END IF
     232              :       END IF
     233              : 
     234       208458 :       IF (transa_prv) THEN
     235        65612 :          CALL swap(dims_a)
     236              :       END IF
     237              : 
     238       208458 :       IF (transb_prv) THEN
     239       106340 :          CALL swap(dims_b)
     240              :       END IF
     241              : 
     242       625374 :       dims_c = [dims_a(1), dims_b(2)]
     243              : 
     244       208458 :       IF (.NOT. (dims_a(2) .EQ. dims_b(1))) THEN
     245            0 :          CPABORT("inconsistent matrix dimensions")
     246              :       END IF
     247              : 
     248       833832 :       dims(:) = [dims_a(1), dims_a(2), dims_b(2)]
     249              : 
     250       208458 :       IF (unit_nr_prv > 0) THEN
     251           34 :          WRITE (unit_nr_prv, "(T2,A, 1X, I12, 1X, I12, 1X, I12)") "mm dims:", dims(1), dims(2), dims(3)
     252              :       END IF
     253              : 
     254       208458 :       CALL dbt_tas_get_split_info(dbt_tas_info(matrix_a), mp_comm=mp_comm)
     255       208458 :       numproc = mp_comm%num_pe
     256              : 
     257              :       ! derive optimal matrix layout and split factor from occupancies
     258       208458 :       nze_a = dbt_tas_get_nze_total(matrix_a)
     259       208458 :       nze_b = dbt_tas_get_nze_total(matrix_b)
     260              : 
     261       208458 :       IF (.NOT. simple_split_prv) THEN
     262              :          CALL dbt_tas_estimate_result_nze(transa, transb, transc, matrix_a, matrix_b, matrix_c, &
     263              :                                           estimated_nze=nze_c, filter_eps=filter_eps, &
     264        56712 :                                           retain_sparsity=retain_sparsity)
     265              : 
     266       226848 :          max_mm_dim = MAXLOC(dims, 1)
     267        56712 :          nsplit = split_factor_estimate(max_mm_dim, nze_a, nze_b, nze_c, numproc)
     268        56712 :          nsplit_opt = nsplit
     269              : 
     270        56712 :          IF (unit_nr_prv > 0) THEN
     271              :             WRITE (unit_nr_prv, "(T2,A)") &
     272           34 :                "MM PARAMETERS"
     273           34 :             WRITE (unit_nr_prv, "(T4,A,T68,I13)") "Est. number of matrix elements per CPU of result matrix:", &
     274           68 :                (nze_c + numproc - 1)/numproc
     275              : 
     276           34 :             WRITE (unit_nr_prv, "(T4,A,T68,I13)") "Est. optimal split factor:", nsplit
     277              :          END IF
     278              : 
     279       151746 :       ELSEIF (batched_repl > 0) THEN
     280        29526 :          nsplit = nsplit_batched
     281        29526 :          nsplit_opt = nsplit
     282        29526 :          max_mm_dim = max_mm_dim_batched
     283        29526 :          IF (unit_nr_prv > 0) THEN
     284              :             WRITE (unit_nr_prv, "(T2,A)") &
     285            0 :                "MM PARAMETERS"
     286            0 :             WRITE (unit_nr_prv, "(T4,A,T68,I13)") "Est. optimal split factor:", nsplit
     287              :          END IF
     288              : 
     289              :       ELSE
     290       122220 :          nsplit = 0
     291       488880 :          max_mm_dim = MAXLOC(dims, 1)
     292              :       END IF
     293              : 
     294              :       ! reshape matrices to the optimal layout and split factor
     295       208458 :       split_a = rowsplit; split_b = rowsplit; split_c = rowsplit
     296        59176 :       SELECT CASE (max_mm_dim)
     297              :       CASE (1)
     298              : 
     299              :          split_a = rowsplit; split_c = rowsplit
     300              :          CALL reshape_mm_compatible(matrix_a, matrix_c, matrix_a_rs, matrix_c_rs, &
     301              :                                     new_a, new_c, transa_prv, transc_prv, optimize_dist=optimize_dist, &
     302              :                                     nsplit=nsplit, &
     303              :                                     opt_nsplit=batched_repl == 0, &
     304              :                                     split_rc_1=split_a, split_rc_2=split_c, &
     305              :                                     nodata2=nodata_3, comm_new=comm_tmp, &
     306        59176 :                                     move_data_1=move_a, unit_nr=unit_nr_prv)
     307              : 
     308        59176 :          info = dbt_tas_info(matrix_a_rs)
     309        59176 :          CALL dbt_tas_get_split_info(info, split_rowcol=split_rc, mp_comm=mp_comm)
     310              : 
     311        59176 :          new_b = .FALSE.
     312        59176 :          IF (matrix_b%do_batched <= 2) THEN
     313       281800 :             ALLOCATE (matrix_b_rs)
     314        56360 :             CALL reshape_mm_small(mp_comm, matrix_b, matrix_b_rs, transb_prv, move_data=move_b)
     315        56360 :             transb_prv = .FALSE.
     316        56360 :             new_b = .TRUE.
     317              :          END IF
     318              : 
     319        59176 :          tr_case = transa_prv
     320              : 
     321       118363 :          IF (unit_nr_prv > 0) THEN
     322           11 :             IF (.NOT. tr_case) THEN
     323           11 :                WRITE (unit_nr_prv, "(T2,A, 1X, A)") "mm case:", "| x + = |"
     324              :             ELSE
     325            0 :                WRITE (unit_nr_prv, "(T2,A, 1X, A)") "mm case:", "--T x + = --T"
     326              :             END IF
     327              :          END IF
     328              : 
     329              :       CASE (2)
     330              : 
     331        64112 :          split_a = colsplit; split_b = rowsplit
     332              :          CALL reshape_mm_compatible(matrix_a, matrix_b, matrix_a_rs, matrix_b_rs, new_a, new_b, transa_prv, transb_prv, &
     333              :                                     optimize_dist=optimize_dist, &
     334              :                                     nsplit=nsplit, &
     335              :                                     opt_nsplit=batched_repl == 0, &
     336              :                                     split_rc_1=split_a, split_rc_2=split_b, &
     337              :                                     comm_new=comm_tmp, &
     338        64112 :                                     move_data_1=move_a, move_data_2=move_b, unit_nr=unit_nr_prv)
     339              : 
     340        64112 :          info = dbt_tas_info(matrix_a_rs)
     341        64112 :          CALL dbt_tas_get_split_info(info, split_rowcol=split_rc, mp_comm=mp_comm)
     342              : 
     343        64112 :          IF (matrix_c%do_batched == 1) THEN
     344        27562 :             matrix_c%mm_storage%batched_beta = beta
     345        36550 :          ELSEIF (matrix_c%do_batched > 1) THEN
     346        10116 :             matrix_c%mm_storage%batched_beta = matrix_c%mm_storage%batched_beta*beta
     347              :          END IF
     348              : 
     349        64112 :          IF (matrix_c%do_batched <= 2) THEN
     350       275740 :             ALLOCATE (matrix_c_rs)
     351        55148 :             CALL reshape_mm_small(mp_comm, matrix_c, matrix_c_rs, transc_prv, nodata=nodata_3)
     352        55148 :             transc_prv = .FALSE.
     353              : 
     354              :             ! just leave sparsity structure for retain sparsity but no values
     355        55148 :             IF (.NOT. nodata_3) CALL dbm_zero(matrix_c_rs%matrix)
     356              : 
     357        55148 :             IF (matrix_c%do_batched >= 1) matrix_c%mm_storage%store_batched => matrix_c_rs
     358         8964 :          ELSEIF (matrix_c%do_batched == 3) THEN
     359         8964 :             matrix_c_rs => matrix_c%mm_storage%store_batched
     360              :          END IF
     361              : 
     362        64112 :          new_c = matrix_c%do_batched == 0
     363        64112 :          tr_case = transa_prv
     364              : 
     365       128237 :          IF (unit_nr_prv > 0) THEN
     366           13 :             IF (.NOT. tr_case) THEN
     367            2 :                WRITE (unit_nr_prv, "(T2,A, 1X, A)") "mm case:", "-- x --T = +"
     368              :             ELSE
     369           11 :                WRITE (unit_nr_prv, "(T2,A, 1X, A)") "mm case:", "|T x | = +"
     370              :             END IF
     371              :          END IF
     372              : 
     373              :       CASE (3)
     374              : 
     375        85170 :          split_b = colsplit; split_c = colsplit
     376              :          CALL reshape_mm_compatible(matrix_b, matrix_c, matrix_b_rs, matrix_c_rs, new_b, new_c, transb_prv, &
     377              :                                     transc_prv, optimize_dist=optimize_dist, &
     378              :                                     nsplit=nsplit, &
     379              :                                     opt_nsplit=batched_repl == 0, &
     380              :                                     split_rc_1=split_b, split_rc_2=split_c, &
     381              :                                     nodata2=nodata_3, comm_new=comm_tmp, &
     382        85170 :                                     move_data_1=move_b, unit_nr=unit_nr_prv)
     383        85170 :          info = dbt_tas_info(matrix_b_rs)
     384        85170 :          CALL dbt_tas_get_split_info(info, split_rowcol=split_rc, mp_comm=mp_comm)
     385              : 
     386        85170 :          new_a = .FALSE.
     387        85170 :          IF (matrix_a%do_batched <= 2) THEN
     388       337120 :             ALLOCATE (matrix_a_rs)
     389        67424 :             CALL reshape_mm_small(mp_comm, matrix_a, matrix_a_rs, transa_prv, move_data=move_a)
     390        67424 :             transa_prv = .FALSE.
     391        67424 :             new_a = .TRUE.
     392              :          END IF
     393              : 
     394        85170 :          tr_case = transb_prv
     395              : 
     396       378798 :          IF (unit_nr_prv > 0) THEN
     397           10 :             IF (.NOT. tr_case) THEN
     398            0 :                WRITE (unit_nr_prv, "(T2,A, 1X, A)") "mm case:", "+ x -- = --"
     399              :             ELSE
     400           10 :                WRITE (unit_nr_prv, "(T2,A, 1X, A)") "mm case:", "+ x |T = |T"
     401              :             END IF
     402              :          END IF
     403              : 
     404              :       END SELECT
     405              : 
     406       208458 :       CALL dbt_tas_get_split_info(info, nsplit=nsplit, mp_comm=mp_comm, mp_comm_group=mp_comm_group)
     407              : 
     408       208458 :       numproc = mp_comm%num_pe
     409       625374 :       pdims_sub = mp_comm_group%num_pe_cart
     410              : 
     411       208458 :       opt_pgrid = .NOT. accept_pgrid_dims(pdims_sub, relative=.TRUE.)
     412              : 
     413       208458 :       IF (PRESENT(filter_eps)) THEN
     414       168710 :          filter_eps_prv = filter_eps
     415              :       ELSE
     416        39748 :          filter_eps_prv = 0.0_dp
     417              :       END IF
     418              : 
     419       208458 :       IF (unit_nr_prv /= 0) THEN
     420        52152 :          IF (unit_nr_prv > 0) THEN
     421           34 :             WRITE (unit_nr_prv, "(T2, A)") "SPLIT / PARALLELIZATION INFO"
     422              :          END IF
     423        52152 :          CALL dbt_tas_write_split_info(info, unit_nr_prv)
     424        52152 :          IF (ASSOCIATED(matrix_a_rs)) CALL dbt_tas_write_matrix_info(matrix_a_rs, unit_nr_prv, full_info=log_verbose)
     425        52152 :          IF (ASSOCIATED(matrix_b_rs)) CALL dbt_tas_write_matrix_info(matrix_b_rs, unit_nr_prv, full_info=log_verbose)
     426        52152 :          IF (ASSOCIATED(matrix_c_rs)) CALL dbt_tas_write_matrix_info(matrix_c_rs, unit_nr_prv, full_info=log_verbose)
     427        52152 :          IF (unit_nr_prv > 0) THEN
     428           34 :             IF (opt_pgrid) THEN
     429            0 :                WRITE (unit_nr_prv, "(T4, A, 1X, A)") "Change process grid:", "Yes"
     430              :             ELSE
     431           34 :                WRITE (unit_nr_prv, "(T4, A, 1X, A)") "Change process grid:", "No"
     432              :             END IF
     433              :          END IF
     434              :       END IF
     435              : 
     436       208458 :       pdims = 0
     437       208458 :       CALL mp_comm_mm%create(mp_comm_group, 2, pdims)
     438              : 
     439              :       ! Convert DBM submatrices to optimized process grids and multiply
     440        59176 :       SELECT CASE (max_mm_dim)
     441              :       CASE (1)
     442        59176 :          IF (matrix_b%do_batched <= 2) THEN
     443       281800 :             ALLOCATE (matrix_b_rep)
     444        56360 :             CALL dbt_tas_replicate(matrix_b_rs%matrix, dbt_tas_info(matrix_a_rs), matrix_b_rep, move_data=.TRUE.)
     445        56360 :             IF (matrix_b%do_batched == 1 .OR. matrix_b%do_batched == 2) THEN
     446         7922 :                matrix_b%mm_storage%store_batched_repl => matrix_b_rep
     447         7922 :                CALL dbt_tas_set_batched_state(matrix_b, state=3)
     448              :             END IF
     449         2816 :          ELSEIF (matrix_b%do_batched == 3) THEN
     450         2816 :             matrix_b_rep => matrix_b%mm_storage%store_batched_repl
     451              :          END IF
     452              : 
     453        59176 :          IF (new_b) THEN
     454        56360 :             CALL dbt_tas_destroy(matrix_b_rs)
     455        56360 :             DEALLOCATE (matrix_b_rs)
     456              :          END IF
     457        59176 :          IF (unit_nr_prv /= 0) THEN
     458          438 :             CALL dbt_tas_write_dist(matrix_a_rs, unit_nr_prv)
     459          438 :             CALL dbt_tas_write_dist(matrix_b_rep, unit_nr_prv, full_info=log_verbose)
     460              :          END IF
     461              : 
     462        59176 :          CALL convert_to_new_pgrid(mp_comm_mm, matrix_a_rs%matrix, matrix_a_mm, optimize_pgrid=opt_pgrid, move_data=move_a)
     463              : 
     464              :          ! keep communicators alive even after releasing TAS matrices (communicator management does not work between DBM and TAS)
     465        59176 :          info_a = dbt_tas_info(matrix_a_rs)
     466        59176 :          CALL dbt_tas_info_hold(info_a)
     467              : 
     468        59176 :          IF (new_a) THEN
     469         6082 :             CALL dbt_tas_destroy(matrix_a_rs)
     470         6082 :             DEALLOCATE (matrix_a_rs)
     471              :          END IF
     472              :          CALL convert_to_new_pgrid(mp_comm_mm, matrix_b_rep%matrix, matrix_b_mm, optimize_pgrid=opt_pgrid, &
     473        59176 :                                    move_data=matrix_b%do_batched == 0)
     474              : 
     475        59176 :          info_b = dbt_tas_info(matrix_b_rep)
     476        59176 :          CALL dbt_tas_info_hold(info_b)
     477              : 
     478        59176 :          IF (matrix_b%do_batched == 0) THEN
     479        48438 :             CALL dbt_tas_destroy(matrix_b_rep)
     480        48438 :             DEALLOCATE (matrix_b_rep)
     481              :          END IF
     482              : 
     483        59176 :          CALL convert_to_new_pgrid(mp_comm_mm, matrix_c_rs%matrix, matrix_c_mm, nodata=nodata_3, optimize_pgrid=opt_pgrid)
     484              : 
     485        59176 :          info_c = dbt_tas_info(matrix_c_rs)
     486        59176 :          CALL dbt_tas_info_hold(info_c)
     487              : 
     488        59176 :          CALL matrix_a%dist%info%mp_comm%sync()
     489        59176 :          CALL timeset("dbt_tas_dbm", handle4)
     490        59176 :          IF (.NOT. tr_case) THEN
     491        53380 :             CALL timeset("dbt_tas_mm_1N", handle3)
     492              : 
     493              :             CALL dbm_multiply(transa=.FALSE., transb=.FALSE., alpha=alpha, &
     494              :                               matrix_a=matrix_a_mm, matrix_b=matrix_b_mm, beta=beta, matrix_c=matrix_c_mm, &
     495        53380 :                               filter_eps=filter_eps_prv, retain_sparsity=retain_sparsity, flop=flop)
     496        53380 :             CALL timestop(handle3)
     497              :          ELSE
     498         5796 :             CALL timeset("dbt_tas_mm_1T", handle3)
     499              :             CALL dbm_multiply(transa=.TRUE., transb=.FALSE., alpha=alpha, &
     500              :                               matrix_a=matrix_b_mm, matrix_b=matrix_a_mm, beta=beta, matrix_c=matrix_c_mm, &
     501         5796 :                               filter_eps=filter_eps_prv, retain_sparsity=retain_sparsity, flop=flop)
     502              : 
     503         5796 :             CALL timestop(handle3)
     504              :          END IF
     505        59176 :          CALL matrix_a%dist%info%mp_comm%sync()
     506        59176 :          CALL timestop(handle4)
     507              : 
     508        59176 :          CALL dbm_release(matrix_a_mm)
     509        59176 :          CALL dbm_release(matrix_b_mm)
     510              : 
     511        59176 :          nze_c = dbm_get_nze(matrix_c_mm)
     512              : 
     513        59176 :          IF (.NOT. new_c) THEN
     514        53324 :             CALL redistribute_and_sum(matrix_c_mm, matrix_c_rs%matrix, local_copy=.NOT. opt_pgrid, alpha=beta)
     515              :          ELSE
     516         5852 :             CALL redistribute_and_sum(matrix_c_mm, matrix_c_rs%matrix, local_copy=.NOT. opt_pgrid, alpha=1.0_dp)
     517              :          END IF
     518              : 
     519        59176 :          CALL dbm_release(matrix_c_mm)
     520              : 
     521        59176 :          IF (PRESENT(filter_eps)) CALL dbt_tas_filter(matrix_c_rs, filter_eps)
     522              : 
     523       237142 :          IF (unit_nr_prv /= 0) THEN
     524          438 :             CALL dbt_tas_write_dist(matrix_c_rs, unit_nr_prv)
     525              :          END IF
     526              : 
     527              :       CASE (2)
     528        64112 :          IF (matrix_c%do_batched <= 1) THEN
     529       269980 :             ALLOCATE (matrix_c_rep)
     530        53996 :             CALL dbt_tas_replicate(matrix_c_rs%matrix, dbt_tas_info(matrix_a_rs), matrix_c_rep, nodata=nodata_3)
     531        53996 :             IF (matrix_c%do_batched == 1) THEN
     532        27562 :                matrix_c%mm_storage%store_batched_repl => matrix_c_rep
     533        27562 :                CALL dbt_tas_set_batched_state(matrix_c, state=3)
     534              :             END IF
     535        10116 :          ELSEIF (matrix_c%do_batched == 2) THEN
     536         5760 :             ALLOCATE (matrix_c_rep)
     537         1152 :             CALL dbt_tas_replicate(matrix_c_rs%matrix, dbt_tas_info(matrix_a_rs), matrix_c_rep, nodata=nodata_3)
     538              :             ! just leave sparsity structure for retain sparsity but no values
     539         1152 :             IF (.NOT. nodata_3) CALL dbm_zero(matrix_c_rep%matrix)
     540         1152 :             matrix_c%mm_storage%store_batched_repl => matrix_c_rep
     541         1152 :             CALL dbt_tas_set_batched_state(matrix_c, state=3)
     542         8964 :          ELSEIF (matrix_c%do_batched == 3) THEN
     543         8964 :             matrix_c_rep => matrix_c%mm_storage%store_batched_repl
     544              :          END IF
     545              : 
     546        64112 :          IF (unit_nr_prv /= 0) THEN
     547        22724 :             CALL dbt_tas_write_dist(matrix_a_rs, unit_nr_prv)
     548        22724 :             CALL dbt_tas_write_dist(matrix_b_rs, unit_nr_prv)
     549              :          END IF
     550              : 
     551        64112 :          CALL convert_to_new_pgrid(mp_comm_mm, matrix_a_rs%matrix, matrix_a_mm, optimize_pgrid=opt_pgrid, move_data=move_a)
     552              : 
     553              :          ! keep communicators alive even after releasing TAS matrices (communicator management does not work between DBM and TAS)
     554        64112 :          info_a = dbt_tas_info(matrix_a_rs)
     555        64112 :          CALL dbt_tas_info_hold(info_a)
     556              : 
     557        64112 :          IF (new_a) THEN
     558          594 :             CALL dbt_tas_destroy(matrix_a_rs)
     559          594 :             DEALLOCATE (matrix_a_rs)
     560              :          END IF
     561              : 
     562        64112 :          CALL convert_to_new_pgrid(mp_comm_mm, matrix_b_rs%matrix, matrix_b_mm, optimize_pgrid=opt_pgrid, move_data=move_b)
     563              : 
     564        64112 :          info_b = dbt_tas_info(matrix_b_rs)
     565        64112 :          CALL dbt_tas_info_hold(info_b)
     566              : 
     567        64112 :          IF (new_b) THEN
     568          938 :             CALL dbt_tas_destroy(matrix_b_rs)
     569          938 :             DEALLOCATE (matrix_b_rs)
     570              :          END IF
     571              : 
     572        64112 :          CALL convert_to_new_pgrid(mp_comm_mm, matrix_c_rep%matrix, matrix_c_mm, nodata=nodata_3, optimize_pgrid=opt_pgrid)
     573              : 
     574        64112 :          info_c = dbt_tas_info(matrix_c_rep)
     575        64112 :          CALL dbt_tas_info_hold(info_c)
     576              : 
     577        64112 :          CALL matrix_a%dist%info%mp_comm%sync()
     578        64112 :          CALL timeset("dbt_tas_dbm", handle4)
     579        64112 :          CALL timeset("dbt_tas_mm_2", handle3)
     580              :          CALL dbm_multiply(transa=transa_prv, transb=transb_prv, alpha=alpha, matrix_a=matrix_a_mm, &
     581              :                            matrix_b=matrix_b_mm, beta=beta, matrix_c=matrix_c_mm, &
     582        64112 :                            filter_eps=filter_eps_prv/REAL(nsplit, KIND=dp), retain_sparsity=retain_sparsity, flop=flop)
     583        64112 :          CALL matrix_a%dist%info%mp_comm%sync()
     584        64112 :          CALL timestop(handle3)
     585        64112 :          CALL timestop(handle4)
     586              : 
     587        64112 :          CALL dbm_release(matrix_a_mm)
     588        64112 :          CALL dbm_release(matrix_b_mm)
     589              : 
     590        64112 :          nze_c = dbm_get_nze(matrix_c_mm)
     591              : 
     592        64112 :          CALL redistribute_and_sum(matrix_c_mm, matrix_c_rep%matrix, local_copy=.NOT. opt_pgrid, alpha=beta)
     593        64112 :          nze_c_sum = dbt_tas_get_nze_total(matrix_c_rep)
     594              : 
     595        64112 :          CALL dbm_release(matrix_c_mm)
     596              : 
     597        64112 :          IF (unit_nr_prv /= 0) THEN
     598        22724 :             CALL dbt_tas_write_dist(matrix_c_rep, unit_nr_prv, full_info=log_verbose)
     599              :          END IF
     600              : 
     601        64112 :          IF (matrix_c%do_batched == 0) THEN
     602        26434 :             CALL dbt_tas_merge(matrix_c_rs%matrix, matrix_c_rep, move_data=.TRUE.)
     603              :          ELSE
     604        37678 :             matrix_c%mm_storage%batched_out = .TRUE. ! postpone merging submatrices to dbt_tas_batched_mm_finalize
     605              :          END IF
     606              : 
     607        64112 :          IF (matrix_c%do_batched == 0) THEN
     608        26434 :             CALL dbt_tas_destroy(matrix_c_rep)
     609        26434 :             DEALLOCATE (matrix_c_rep)
     610              :          END IF
     611              : 
     612        64112 :          IF (PRESENT(filter_eps)) CALL dbt_tas_filter(matrix_c_rs, filter_eps)
     613              : 
     614              :          ! set upper limit on memory consumption for replicated matrix and complete batched mm
     615              :          ! if limit is exceeded
     616       322312 :          IF (nze_c_sum > default_nsplit_accept_ratio*MAX(nze_a, nze_b)) THEN
     617         1752 :             CALL dbt_tas_batched_mm_complete(matrix_c)
     618              :          END IF
     619              : 
     620              :       CASE (3)
     621        85170 :          IF (matrix_a%do_batched <= 2) THEN
     622       337120 :             ALLOCATE (matrix_a_rep)
     623        67424 :             CALL dbt_tas_replicate(matrix_a_rs%matrix, dbt_tas_info(matrix_b_rs), matrix_a_rep, move_data=.TRUE.)
     624        67424 :             IF (matrix_a%do_batched == 1 .OR. matrix_a%do_batched == 2) THEN
     625        27542 :                matrix_a%mm_storage%store_batched_repl => matrix_a_rep
     626        27542 :                CALL dbt_tas_set_batched_state(matrix_a, state=3)
     627              :             END IF
     628        17746 :          ELSEIF (matrix_a%do_batched == 3) THEN
     629        17746 :             matrix_a_rep => matrix_a%mm_storage%store_batched_repl
     630              :          END IF
     631              : 
     632        85170 :          IF (new_a) THEN
     633        67424 :             CALL dbt_tas_destroy(matrix_a_rs)
     634        67424 :             DEALLOCATE (matrix_a_rs)
     635              :          END IF
     636        85170 :          IF (unit_nr_prv /= 0) THEN
     637        28990 :             CALL dbt_tas_write_dist(matrix_a_rep, unit_nr_prv, full_info=log_verbose)
     638        28990 :             CALL dbt_tas_write_dist(matrix_b_rs, unit_nr_prv)
     639              :          END IF
     640              : 
     641              :          CALL convert_to_new_pgrid(mp_comm_mm, matrix_a_rep%matrix, matrix_a_mm, optimize_pgrid=opt_pgrid, &
     642        85170 :                                    move_data=matrix_a%do_batched == 0)
     643              : 
     644              :          ! keep communicators alive even after releasing TAS matrices (communicator management does not work between DBM and TAS)
     645        85170 :          info_a = dbt_tas_info(matrix_a_rep)
     646        85170 :          CALL dbt_tas_info_hold(info_a)
     647              : 
     648        85170 :          IF (matrix_a%do_batched == 0) THEN
     649        39882 :             CALL dbt_tas_destroy(matrix_a_rep)
     650        39882 :             DEALLOCATE (matrix_a_rep)
     651              :          END IF
     652              : 
     653        85170 :          CALL convert_to_new_pgrid(mp_comm_mm, matrix_b_rs%matrix, matrix_b_mm, optimize_pgrid=opt_pgrid, move_data=move_b)
     654              : 
     655        85170 :          info_b = dbt_tas_info(matrix_b_rs)
     656        85170 :          CALL dbt_tas_info_hold(info_b)
     657              : 
     658        85170 :          IF (new_b) THEN
     659           16 :             CALL dbt_tas_destroy(matrix_b_rs)
     660           16 :             DEALLOCATE (matrix_b_rs)
     661              :          END IF
     662        85170 :          CALL convert_to_new_pgrid(mp_comm_mm, matrix_c_rs%matrix, matrix_c_mm, nodata=nodata_3, optimize_pgrid=opt_pgrid)
     663              : 
     664        85170 :          info_c = dbt_tas_info(matrix_c_rs)
     665        85170 :          CALL dbt_tas_info_hold(info_c)
     666              : 
     667        85170 :          CALL matrix_a%dist%info%mp_comm%sync()
     668        85170 :          CALL timeset("dbt_tas_dbm", handle4)
     669        85170 :          IF (.NOT. tr_case) THEN
     670        39586 :             CALL timeset("dbt_tas_mm_3N", handle3)
     671              :             CALL dbm_multiply(transa=.FALSE., transb=.FALSE., alpha=alpha, &
     672              :                               matrix_a=matrix_a_mm, matrix_b=matrix_b_mm, beta=beta, matrix_c=matrix_c_mm, &
     673        39586 :                               filter_eps=filter_eps_prv, retain_sparsity=retain_sparsity, flop=flop)
     674        39586 :             CALL timestop(handle3)
     675              :          ELSE
     676        45584 :             CALL timeset("dbt_tas_mm_3T", handle3)
     677              :             CALL dbm_multiply(transa=.FALSE., transb=.TRUE., alpha=alpha, &
     678              :                               matrix_a=matrix_b_mm, matrix_b=matrix_a_mm, beta=beta, matrix_c=matrix_c_mm, &
     679        45584 :                               filter_eps=filter_eps_prv, retain_sparsity=retain_sparsity, flop=flop)
     680        45584 :             CALL timestop(handle3)
     681              :          END IF
     682        85170 :          CALL matrix_a%dist%info%mp_comm%sync()
     683        85170 :          CALL timestop(handle4)
     684              : 
     685        85170 :          CALL dbm_release(matrix_a_mm)
     686        85170 :          CALL dbm_release(matrix_b_mm)
     687              : 
     688        85170 :          nze_c = dbm_get_nze(matrix_c_mm)
     689              : 
     690        85170 :          IF (.NOT. new_c) THEN
     691        78032 :             CALL redistribute_and_sum(matrix_c_mm, matrix_c_rs%matrix, local_copy=.NOT. opt_pgrid, alpha=beta)
     692              :          ELSE
     693         7138 :             CALL redistribute_and_sum(matrix_c_mm, matrix_c_rs%matrix, local_copy=.NOT. opt_pgrid, alpha=1.0_dp)
     694              :          END IF
     695              : 
     696        85170 :          CALL dbm_release(matrix_c_mm)
     697              : 
     698        85170 :          IF (PRESENT(filter_eps)) CALL dbt_tas_filter(matrix_c_rs, filter_eps)
     699              : 
     700       549138 :          IF (unit_nr_prv /= 0) THEN
     701        28990 :             CALL dbt_tas_write_dist(matrix_c_rs, unit_nr_prv)
     702              :          END IF
     703              :       END SELECT
     704              : 
     705       208458 :       CALL mp_comm_mm%free()
     706              : 
     707       208458 :       CALL dbt_tas_get_split_info(info_c, mp_comm=mp_comm)
     708              : 
     709       208458 :       IF (PRESENT(split_opt)) THEN
     710        99599 :          SELECT CASE (max_mm_dim)
     711              :          CASE (1, 3)
     712        99599 :             CALL mp_comm%sum(nze_c)
     713              :          CASE (2)
     714        52217 :             CALL dbt_tas_get_split_info(info_c, mp_comm=mp_comm, mp_comm_group=mp_comm_group)
     715        52217 :             CALL mp_comm%sum(nze_c)
     716       204033 :             CALL mp_comm%max(nze_c)
     717              : 
     718              :          END SELECT
     719       151816 :          nsplit_opt = split_factor_estimate(max_mm_dim, nze_a, nze_b, nze_c, numproc)
     720              :          ! ideally we should rederive the split factor from the actual sparsity of C, but
     721              :          ! due to parameter beta, we can not get the sparsity of AxB from DBM if not new_c
     722       151816 :          mp_comm_opt = dbt_tas_mp_comm(mp_comm, split_rc, nsplit_opt)
     723       151816 :          CALL dbt_tas_create_split(split_opt, mp_comm_opt, split_rc, nsplit_opt, own_comm=.TRUE.)
     724       151816 :          IF (unit_nr_prv > 0) THEN
     725              :             WRITE (unit_nr_prv, "(T2,A)") &
     726           10 :                "MM PARAMETERS"
     727           10 :             WRITE (unit_nr_prv, "(T4,A,T68,I13)") "Number of matrix elements per CPU of result matrix:", &
     728           20 :                (nze_c + numproc - 1)/numproc
     729              : 
     730           10 :             WRITE (unit_nr_prv, "(T4,A,T68,I13)") "Optimal split factor:", nsplit_opt
     731              :          END IF
     732              : 
     733              :       END IF
     734              : 
     735       208458 :       IF (new_c) THEN
     736        39424 :          CALL dbm_scale(matrix_c%matrix, beta)
     737              :          CALL dbt_tas_reshape(matrix_c_rs, matrix_c, summation=.TRUE., &
     738              :                               transposed=(transc_prv .NEQV. transc), &
     739        39424 :                               move_data=.TRUE.)
     740        39424 :          CALL dbt_tas_destroy(matrix_c_rs)
     741        39424 :          DEALLOCATE (matrix_c_rs)
     742        39424 :          IF (PRESENT(filter_eps)) CALL dbt_tas_filter(matrix_c, filter_eps)
     743       169034 :       ELSEIF (matrix_c%do_batched > 0) THEN
     744        37984 :          IF (matrix_c%mm_storage%batched_out) THEN
     745        37678 :             matrix_c%mm_storage%batched_trans = (transc_prv .NEQV. transc)
     746              :          END IF
     747              :       END IF
     748              : 
     749       208458 :       IF (PRESENT(move_data_a)) THEN
     750       208410 :          IF (move_data_a) CALL dbt_tas_clear(matrix_a)
     751              :       END IF
     752       208458 :       IF (PRESENT(move_data_b)) THEN
     753       208410 :          IF (move_data_b) CALL dbt_tas_clear(matrix_b)
     754              :       END IF
     755              : 
     756       208458 :       IF (PRESENT(flop)) THEN
     757       119093 :          CALL mp_comm%sum(flop)
     758       119093 :          flop = (flop + numproc - 1)/numproc
     759              :       END IF
     760              : 
     761       208458 :       IF (PRESENT(optimize_dist)) THEN
     762           48 :          IF (optimize_dist) CALL comm_tmp%free()
     763              :       END IF
     764       208458 :       IF (unit_nr_prv > 0) THEN
     765           34 :          WRITE (unit_nr_prv, '(A)') REPEAT("-", 80)
     766           34 :          WRITE (unit_nr_prv, '(A,1X,A,1X,A,1X,A,1X,A,1X,A)') "TAS MATRIX MULTIPLICATION DONE"
     767           34 :          WRITE (unit_nr_prv, '(A)') REPEAT("-", 80)
     768              :       END IF
     769              : 
     770       208458 :       CALL dbt_tas_release_info(info_a)
     771       208458 :       CALL dbt_tas_release_info(info_b)
     772       208458 :       CALL dbt_tas_release_info(info_c)
     773              : 
     774       208458 :       CALL matrix_a%dist%info%mp_comm%sync()
     775       208458 :       CALL timestop(handle2)
     776       208458 :       CALL timestop(handle)
     777              : 
     778       416916 :    END SUBROUTINE
     779              : 
     780              : ! **************************************************************************************************
     781              : !> \brief ...
     782              : !> \param matrix_in ...
     783              : !> \param matrix_out ...
     784              : !> \param local_copy ...
     785              : !> \param alpha ...
     786              : !> \author Patrick Seewald
     787              : ! **************************************************************************************************
     788       208458 :    SUBROUTINE redistribute_and_sum(matrix_in, matrix_out, local_copy, alpha)
     789              :       TYPE(dbm_type), INTENT(IN)                         :: matrix_in
     790              :       TYPE(dbm_type), INTENT(INOUT)                      :: matrix_out
     791              :       LOGICAL, INTENT(IN), OPTIONAL                      :: local_copy
     792              :       REAL(dp), INTENT(IN)                               :: alpha
     793              : 
     794              :       LOGICAL                                            :: local_copy_prv
     795              :       TYPE(dbm_type)                                     :: matrix_tmp
     796              : 
     797       208458 :       IF (PRESENT(local_copy)) THEN
     798       208458 :          local_copy_prv = local_copy
     799              :       ELSE
     800              :          local_copy_prv = .FALSE.
     801              :       END IF
     802              : 
     803       208458 :       IF (alpha /= 1.0_dp) THEN
     804       129979 :          CALL dbm_scale(matrix_out, alpha)
     805              :       END IF
     806              : 
     807       208458 :       IF (.NOT. local_copy_prv) THEN
     808            0 :          CALL dbm_create_from_template(matrix_tmp, name="tmp", template=matrix_out)
     809            0 :          CALL dbm_redistribute(matrix_in, matrix_tmp)
     810            0 :          CALL dbm_add(matrix_out, matrix_tmp)
     811            0 :          CALL dbm_release(matrix_tmp)
     812              :       ELSE
     813       208458 :          CALL dbm_add(matrix_out, matrix_in)
     814              :       END IF
     815              : 
     816       208458 :    END SUBROUTINE
     817              : 
     818              : ! **************************************************************************************************
     819              : !> \brief Make sure that smallest matrix involved in a multiplication is not split and bring it to
     820              : !>        the same process grid as the other 2 matrices.
     821              : !> \param mp_comm communicator that defines Cartesian topology
     822              : !> \param matrix_in ...
     823              : !> \param matrix_out ...
     824              : !> \param transposed Whether matrix_out should be transposed
     825              : !> \param nodata Data of matrix_in should not be copied to matrix_out
     826              : !> \param move_data memory optimization: move data such that matrix_in is empty on return.
     827              : !> \author Patrick Seewald
     828              : ! **************************************************************************************************
     829      1252524 :    SUBROUTINE reshape_mm_small(mp_comm, matrix_in, matrix_out, transposed, nodata, move_data)
     830              :       TYPE(mp_cart_type), INTENT(IN)                     :: mp_comm
     831              :       TYPE(dbt_tas_type), INTENT(INOUT)                  :: matrix_in
     832              :       TYPE(dbt_tas_type), INTENT(OUT)                    :: matrix_out
     833              :       LOGICAL, INTENT(IN)                                :: transposed
     834              :       LOGICAL, INTENT(IN), OPTIONAL                      :: nodata, move_data
     835              : 
     836              :       CHARACTER(LEN=*), PARAMETER                        :: routineN = 'reshape_mm_small'
     837              : 
     838              :       INTEGER                                            :: handle
     839              :       INTEGER(KIND=int_8), DIMENSION(2)                  :: dims
     840              :       INTEGER, DIMENSION(2)                              :: pdims
     841              :       LOGICAL                                            :: nodata_prv
     842       178932 :       TYPE(dbt_tas_dist_arb)                             :: new_col_dist, new_row_dist
     843       894660 :       TYPE(dbt_tas_distribution_type)                    :: dist
     844              : 
     845       178932 :       CALL timeset(routineN, handle)
     846              : 
     847       178932 :       IF (PRESENT(nodata)) THEN
     848        55148 :          nodata_prv = nodata
     849              :       ELSE
     850              :          nodata_prv = .FALSE.
     851              :       END IF
     852              : 
     853       536796 :       pdims = mp_comm%num_pe_cart
     854              : 
     855       536796 :       dims = [dbt_tas_nblkrows_total(matrix_in), dbt_tas_nblkcols_total(matrix_in)]
     856              : 
     857       178932 :       IF (transposed) CALL swap(dims)
     858              : 
     859       178932 :       IF (.NOT. transposed) THEN
     860       125332 :          new_row_dist = dbt_tas_dist_arb_default(pdims(1), dims(1), matrix_in%row_blk_size)
     861       125332 :          new_col_dist = dbt_tas_dist_arb_default(pdims(2), dims(2), matrix_in%col_blk_size)
     862       125332 :          CALL dbt_tas_distribution_new(dist, mp_comm, new_row_dist, new_col_dist, nosplit=.TRUE.)
     863              :          CALL dbt_tas_create(matrix_out, dbm_get_name(matrix_in%matrix), dist, &
     864       125332 :                              matrix_in%row_blk_size, matrix_in%col_blk_size, own_dist=.TRUE.)
     865              :       ELSE
     866        53600 :          new_row_dist = dbt_tas_dist_arb_default(pdims(1), dims(1), matrix_in%col_blk_size)
     867        53600 :          new_col_dist = dbt_tas_dist_arb_default(pdims(2), dims(2), matrix_in%row_blk_size)
     868        53600 :          CALL dbt_tas_distribution_new(dist, mp_comm, new_row_dist, new_col_dist, nosplit=.TRUE.)
     869              :          CALL dbt_tas_create(matrix_out, dbm_get_name(matrix_in%matrix), dist, &
     870        53600 :                              matrix_in%col_blk_size, matrix_in%row_blk_size, own_dist=.TRUE.)
     871              :       END IF
     872       178932 :       IF (.NOT. nodata_prv) CALL dbt_tas_reshape(matrix_in, matrix_out, transposed=transposed, move_data=move_data)
     873              : 
     874       178932 :       CALL timestop(handle)
     875              : 
     876       178932 :    END SUBROUTINE
     877              : 
     878              : ! **************************************************************************************************
     879              : !> \brief Reshape either matrix1 or matrix2 to make sure that their process grids are compatible
     880              : !>        with the same split factor.
     881              : !> \param matrix1_in ...
     882              : !> \param matrix2_in ...
     883              : !> \param matrix1_out ...
     884              : !> \param matrix2_out ...
     885              : !> \param new1 Whether matrix1_out is a new matrix or simply pointing to matrix1_in
     886              : !> \param new2 Whether matrix2_out is a new matrix or simply pointing to matrix2_in
     887              : !> \param trans1 transpose flag of matrix1_in for multiplication
     888              : !> \param trans2 transpose flag of matrix2_in for multiplication
     889              : !> \param optimize_dist experimental: optimize matrix splitting and distribution
     890              : !> \param nsplit Optimal split factor (set to 0 if split factor should not be changed)
     891              : !> \param opt_nsplit ...
     892              : !> \param split_rc_1 Whether to split rows or columns for matrix 1
     893              : !> \param split_rc_2 Whether to split rows or columns for matrix 2
     894              : !> \param nodata1 Don't copy matrix data from matrix1_in to matrix1_out
     895              : !> \param nodata2 Don't copy matrix data from matrix2_in to matrix2_out
     896              : !> \param move_data_1 memory optimization: move data such that matrix1_in may be empty on return.
     897              : !> \param move_data_2 memory optimization: move data such that matrix2_in may be empty on return.
     898              : !> \param comm_new returns the new communicator only if optimize_dist
     899              : !> \param unit_nr output unit
     900              : !> \author Patrick Seewald
     901              : ! **************************************************************************************************
     902       208458 :    SUBROUTINE reshape_mm_compatible(matrix1_in, matrix2_in, matrix1_out, matrix2_out, new1, new2, trans1, trans2, &
     903              :                                     optimize_dist, nsplit, opt_nsplit, split_rc_1, split_rc_2, nodata1, nodata2, &
     904              :                                     move_data_1, move_data_2, comm_new, unit_nr)
     905              :       TYPE(dbt_tas_type), INTENT(INOUT), TARGET          :: matrix1_in, matrix2_in
     906              :       TYPE(dbt_tas_type), INTENT(OUT), POINTER           :: matrix1_out, matrix2_out
     907              :       LOGICAL, INTENT(OUT)                               :: new1, new2
     908              :       LOGICAL, INTENT(INOUT)                             :: trans1, trans2
     909              :       LOGICAL, INTENT(IN), OPTIONAL                      :: optimize_dist
     910              :       INTEGER, INTENT(IN), OPTIONAL                      :: nsplit
     911              :       LOGICAL, INTENT(IN), OPTIONAL                      :: opt_nsplit
     912              :       INTEGER, INTENT(INOUT)                             :: split_rc_1, split_rc_2
     913              :       LOGICAL, INTENT(IN), OPTIONAL                      :: nodata1, nodata2
     914              :       LOGICAL, INTENT(INOUT), OPTIONAL                   :: move_data_1, move_data_2
     915              :       TYPE(mp_cart_type), INTENT(OUT), OPTIONAL          :: comm_new
     916              :       INTEGER, INTENT(IN), OPTIONAL                      :: unit_nr
     917              : 
     918              :       CHARACTER(LEN=*), PARAMETER :: routineN = 'reshape_mm_compatible'
     919              : 
     920              :       INTEGER                                            :: handle, nsplit_prv, ref, split_rc_ref, &
     921              :                                                             unit_nr_prv
     922              :       INTEGER(KIND=int_8)                                :: d1, d2, nze1, nze2
     923              :       INTEGER(KIND=int_8), DIMENSION(2)                  :: dims1, dims2, dims_ref
     924              :       INTEGER, DIMENSION(2)                              :: pdims
     925              :       LOGICAL                                            :: nodata1_prv, nodata2_prv, &
     926              :                                                             optimize_dist_prv, trans1_newdist, &
     927              :                                                             trans2_newdist
     928              :       TYPE(dbt_tas_dist_cyclic)                          :: col_dist_1, col_dist_2, row_dist_1, &
     929              :                                                             row_dist_2
     930      1876122 :       TYPE(dbt_tas_distribution_type)                    :: dist_1, dist_2
     931      1042290 :       TYPE(dbt_tas_split_info)                           :: split_info
     932       208458 :       TYPE(mp_cart_type)                                 :: mp_comm
     933              : 
     934       208458 :       CALL timeset(routineN, handle)
     935       208458 :       new1 = .FALSE.; new2 = .FALSE.
     936              : 
     937       208458 :       IF (PRESENT(nodata1)) THEN
     938            0 :          nodata1_prv = nodata1
     939              :       ELSE
     940              :          nodata1_prv = .FALSE.
     941              :       END IF
     942              : 
     943       208458 :       IF (PRESENT(nodata2)) THEN
     944       144346 :          nodata2_prv = nodata2
     945              :       ELSE
     946              :          nodata2_prv = .FALSE.
     947              :       END IF
     948              : 
     949       208458 :       unit_nr_prv = prep_output_unit(unit_nr)
     950              : 
     951       208458 :       NULLIFY (matrix1_out, matrix2_out)
     952              : 
     953       208458 :       IF (PRESENT(optimize_dist)) THEN
     954           48 :          optimize_dist_prv = optimize_dist
     955              :       ELSE
     956              :          optimize_dist_prv = .FALSE.
     957              :       END IF
     958              : 
     959       625374 :       dims1 = [dbt_tas_nblkrows_total(matrix1_in), dbt_tas_nblkcols_total(matrix1_in)]
     960       625374 :       dims2 = [dbt_tas_nblkrows_total(matrix2_in), dbt_tas_nblkcols_total(matrix2_in)]
     961       208458 :       nze1 = dbt_tas_get_nze_total(matrix1_in)
     962       208458 :       nze2 = dbt_tas_get_nze_total(matrix2_in)
     963              : 
     964       208458 :       IF (trans1) split_rc_1 = MOD(split_rc_1, 2) + 1
     965              : 
     966       208458 :       IF (trans2) split_rc_2 = MOD(split_rc_2, 2) + 1
     967              : 
     968       208458 :       IF (nze1 >= nze2) THEN
     969       195311 :          ref = 1
     970       195311 :          split_rc_ref = split_rc_1
     971       195311 :          dims_ref = dims1
     972              :       ELSE
     973        13147 :          ref = 2
     974        13147 :          split_rc_ref = split_rc_2
     975        13147 :          dims_ref = dims2
     976              :       END IF
     977              : 
     978       208458 :       IF (PRESENT(nsplit)) THEN
     979       208458 :          nsplit_prv = nsplit
     980              :       ELSE
     981            0 :          nsplit_prv = 0
     982              :       END IF
     983              : 
     984       208458 :       IF (optimize_dist_prv) THEN
     985           48 :          CPASSERT(PRESENT(comm_new))
     986              :       END IF
     987              : 
     988       208410 :       IF ((.NOT. optimize_dist_prv) .AND. dist_compatible(matrix1_in, matrix2_in, split_rc_1, split_rc_2)) THEN
     989              :          CALL change_split(matrix1_in, matrix1_out, nsplit_prv, split_rc_1, new1, &
     990       193698 :                            move_data=move_data_1, nodata=nodata1, opt_nsplit=opt_nsplit)
     991       193698 :          CALL dbt_tas_get_split_info(dbt_tas_info(matrix1_out), nsplit=nsplit_prv)
     992              :          CALL change_split(matrix2_in, matrix2_out, nsplit_prv, split_rc_2, new2, &
     993       193698 :                            move_data=move_data_2, nodata=nodata2, opt_nsplit=.FALSE.)
     994       193698 :          IF (unit_nr_prv > 0) THEN
     995           10 :             WRITE (unit_nr_prv, "(T2,A,1X,A,1X,A,1X,A)") "No redistribution of", &
     996           10 :                TRIM(dbm_get_name(matrix1_in%matrix)), &
     997           20 :                "and", TRIM(dbm_get_name(matrix2_in%matrix))
     998           10 :             IF (new1) THEN
     999            0 :                WRITE (unit_nr_prv, "(T2,A,1X,A,1X,A)") "Change split factor of", &
    1000            0 :                   TRIM(dbm_get_name(matrix1_in%matrix)), ": Yes"
    1001              :             ELSE
    1002           10 :                WRITE (unit_nr_prv, "(T2,A,1X,A,1X,A)") "Change split factor of", &
    1003           20 :                   TRIM(dbm_get_name(matrix1_in%matrix)), ": No"
    1004              :             END IF
    1005           10 :             IF (new2) THEN
    1006            0 :                WRITE (unit_nr_prv, "(T2,A,1X,A,1X,A)") "Change split factor of", &
    1007            0 :                   TRIM(dbm_get_name(matrix2_in%matrix)), ": Yes"
    1008              :             ELSE
    1009           10 :                WRITE (unit_nr_prv, "(T2,A,1X,A,1X,A)") "Change split factor of", &
    1010           20 :                   TRIM(dbm_get_name(matrix2_in%matrix)), ": No"
    1011              :             END IF
    1012              :          END IF
    1013              :       ELSE
    1014              : 
    1015        14712 :          IF (optimize_dist_prv) THEN
    1016           48 :             IF (unit_nr_prv > 0) THEN
    1017           24 :                WRITE (unit_nr_prv, "(T2,A,1X,A,1X,A,1X,A)") "Optimizing distribution of", &
    1018           24 :                   TRIM(dbm_get_name(matrix1_in%matrix)), &
    1019           48 :                   "and", TRIM(dbm_get_name(matrix2_in%matrix))
    1020              :             END IF
    1021              : 
    1022           48 :             trans1_newdist = (split_rc_1 == colsplit)
    1023           48 :             trans2_newdist = (split_rc_2 == colsplit)
    1024              : 
    1025           48 :             IF (trans1_newdist) THEN
    1026           24 :                CALL swap(dims1)
    1027           24 :                trans1 = .NOT. trans1
    1028              :             END IF
    1029              : 
    1030           48 :             IF (trans2_newdist) THEN
    1031           24 :                CALL swap(dims2)
    1032           24 :                trans2 = .NOT. trans2
    1033              :             END IF
    1034              : 
    1035           48 :             IF (nsplit_prv == 0) THEN
    1036            0 :                SELECT CASE (split_rc_ref)
    1037              :                CASE (rowsplit)
    1038            0 :                   d1 = dims_ref(1)
    1039            0 :                   d2 = dims_ref(2)
    1040              :                CASE (colsplit)
    1041            0 :                   d1 = dims_ref(2)
    1042            0 :                   d2 = dims_ref(1)
    1043              :                END SELECT
    1044            0 :                nsplit_prv = INT((d1 - 1)/d2 + 1)
    1045              :             END IF
    1046              : 
    1047           48 :             CPASSERT(nsplit_prv > 0)
    1048              : 
    1049           48 :             CALL dbt_tas_get_split_info(dbt_tas_info(matrix1_in), mp_comm=mp_comm)
    1050           48 :             comm_new = dbt_tas_mp_comm(mp_comm, rowsplit, nsplit_prv)
    1051           48 :             CALL dbt_tas_create_split(split_info, comm_new, rowsplit, nsplit_prv)
    1052              : 
    1053          144 :             pdims = comm_new%num_pe_cart
    1054              : 
    1055              :             ! use a very simple cyclic distribution that may not be load balanced if block
    1056              :             ! sizes are not equal. However we can not use arbitrary distributions
    1057              :             ! for large dimensions since this would require storing distribution vectors as arrays
    1058              :             ! which can not be stored for large dimensions.
    1059           48 :             row_dist_1 = dbt_tas_dist_cyclic(1, pdims(1), dims1(1))
    1060           48 :             col_dist_1 = dbt_tas_dist_cyclic(1, pdims(2), dims1(2))
    1061              : 
    1062           48 :             row_dist_2 = dbt_tas_dist_cyclic(1, pdims(1), dims2(1))
    1063           48 :             col_dist_2 = dbt_tas_dist_cyclic(1, pdims(2), dims2(2))
    1064              : 
    1065           48 :             CALL dbt_tas_distribution_new(dist_1, comm_new, row_dist_1, col_dist_1, split_info=split_info)
    1066           48 :             CALL dbt_tas_distribution_new(dist_2, comm_new, row_dist_2, col_dist_2, split_info=split_info)
    1067           48 :             CALL dbt_tas_release_info(split_info)
    1068              : 
    1069          240 :             ALLOCATE (matrix1_out)
    1070           48 :             IF (.NOT. trans1_newdist) THEN
    1071              :                CALL dbt_tas_create(matrix1_out, dbm_get_name(matrix1_in%matrix), dist_1, &
    1072           24 :                                    matrix1_in%row_blk_size, matrix1_in%col_blk_size, own_dist=.TRUE.)
    1073              : 
    1074              :             ELSE
    1075              :                CALL dbt_tas_create(matrix1_out, dbm_get_name(matrix1_in%matrix), dist_1, &
    1076           24 :                                    matrix1_in%col_blk_size, matrix1_in%row_blk_size, own_dist=.TRUE.)
    1077              :             END IF
    1078              : 
    1079          240 :             ALLOCATE (matrix2_out)
    1080           48 :             IF (.NOT. trans2_newdist) THEN
    1081              :                CALL dbt_tas_create(matrix2_out, dbm_get_name(matrix2_in%matrix), dist_2, &
    1082           24 :                                    matrix2_in%row_blk_size, matrix2_in%col_blk_size, own_dist=.TRUE.)
    1083              :             ELSE
    1084              :                CALL dbt_tas_create(matrix2_out, dbm_get_name(matrix2_in%matrix), dist_2, &
    1085           24 :                                    matrix2_in%col_blk_size, matrix2_in%row_blk_size, own_dist=.TRUE.)
    1086              :             END IF
    1087              : 
    1088           48 :             IF (.NOT. nodata1_prv) CALL dbt_tas_reshape(matrix1_in, matrix1_out, transposed=trans1_newdist, move_data=move_data_1)
    1089           48 :             IF (.NOT. nodata2_prv) CALL dbt_tas_reshape(matrix2_in, matrix2_out, transposed=trans2_newdist, move_data=move_data_2)
    1090           48 :             new1 = .TRUE.
    1091           48 :             new2 = .TRUE.
    1092              : 
    1093              :          ELSE
    1094        13856 :             SELECT CASE (ref)
    1095              :             CASE (1)
    1096        13856 :                IF (unit_nr_prv > 0) THEN
    1097            0 :                   WRITE (unit_nr_prv, "(T2,A,1X,A)") "Redistribution of", &
    1098            0 :                      TRIM(dbm_get_name(matrix2_in%matrix))
    1099              :                END IF
    1100              : 
    1101              :                CALL change_split(matrix1_in, matrix1_out, nsplit_prv, split_rc_1, new1, &
    1102        13856 :                                  move_data=move_data_1, nodata=nodata1, opt_nsplit=opt_nsplit)
    1103              : 
    1104        69280 :                ALLOCATE (matrix2_out)
    1105              :                CALL reshape_mm_template(matrix1_out, matrix2_in, matrix2_out, trans2, split_rc_2, &
    1106        13856 :                                         nodata=nodata2, move_data=move_data_2)
    1107        13856 :                new2 = .TRUE.
    1108              :             CASE (2)
    1109          856 :                IF (unit_nr_prv > 0) THEN
    1110            0 :                   WRITE (unit_nr_prv, "(T2,A,1X,A)") "Redistribution of", &
    1111            0 :                      TRIM(dbm_get_name(matrix1_in%matrix))
    1112              :                END IF
    1113              : 
    1114              :                CALL change_split(matrix2_in, matrix2_out, nsplit_prv, split_rc_2, new2, &
    1115          856 :                                  move_data=move_data_2, nodata=nodata2, opt_nsplit=opt_nsplit)
    1116              : 
    1117         4280 :                ALLOCATE (matrix1_out)
    1118              :                CALL reshape_mm_template(matrix2_out, matrix1_in, matrix1_out, trans1, split_rc_1, &
    1119          856 :                                         nodata=nodata1, move_data=move_data_1)
    1120        30280 :                new1 = .TRUE.
    1121              :             END SELECT
    1122              :          END IF
    1123              :       END IF
    1124              : 
    1125       208458 :       IF (PRESENT(move_data_1) .AND. new1) move_data_1 = .TRUE.
    1126       208458 :       IF (PRESENT(move_data_2) .AND. new2) move_data_2 = .TRUE.
    1127              : 
    1128       208458 :       CALL timestop(handle)
    1129              : 
    1130       625374 :    END SUBROUTINE
    1131              : 
    1132              : ! **************************************************************************************************
    1133              : !> \brief Change split factor without redistribution
    1134              : !> \param matrix_in ...
    1135              : !> \param matrix_out ...
    1136              : !> \param nsplit new split factor, set to 0 to not change split of matrix_in
    1137              : !> \param split_rowcol split rows or columns
    1138              : !> \param is_new whether matrix_out is new or a pointer to matrix_in
    1139              : !> \param opt_nsplit whether nsplit should be optimized for current process grid
    1140              : !> \param move_data memory optimization: move data such that matrix_in is empty on return.
    1141              : !> \param nodata Data of matrix_in should not be copied to matrix_out
    1142              : !> \author Patrick Seewald
    1143              : ! **************************************************************************************************
    1144       402108 :    SUBROUTINE change_split(matrix_in, matrix_out, nsplit, split_rowcol, is_new, opt_nsplit, move_data, nodata)
    1145              :       TYPE(dbt_tas_type), INTENT(INOUT), TARGET          :: matrix_in
    1146              :       TYPE(dbt_tas_type), INTENT(OUT), POINTER           :: matrix_out
    1147              :       INTEGER, INTENT(IN)                                :: nsplit, split_rowcol
    1148              :       LOGICAL, INTENT(OUT)                               :: is_new
    1149              :       LOGICAL, INTENT(IN), OPTIONAL                      :: opt_nsplit
    1150              :       LOGICAL, INTENT(INOUT), OPTIONAL                   :: move_data
    1151              :       LOGICAL, INTENT(IN), OPTIONAL                      :: nodata
    1152              : 
    1153              :       CHARACTER(len=default_string_length)               :: name
    1154              :       INTEGER                                            :: handle, nsplit_new, nsplit_old, &
    1155              :                                                             nsplit_prv, split_rc
    1156              :       LOGICAL                                            :: nodata_prv
    1157      2010540 :       TYPE(dbt_tas_distribution_type)                    :: dist
    1158      2010540 :       TYPE(dbt_tas_split_info)                           :: split_info
    1159       402108 :       TYPE(mp_cart_type)                                 :: mp_comm
    1160              : 
    1161      1608432 :       CLASS(dbt_tas_distribution), ALLOCATABLE :: rdist, cdist
    1162       804216 :       CLASS(dbt_tas_rowcol_data), ALLOCATABLE  :: rbsize, cbsize
    1163              :       CHARACTER(LEN=*), PARAMETER                :: routineN = 'change_split'
    1164              : 
    1165       402108 :       NULLIFY (matrix_out)
    1166              : 
    1167       402108 :       is_new = .TRUE.
    1168              : 
    1169              :       CALL dbt_tas_get_split_info(dbt_tas_info(matrix_in), mp_comm=mp_comm, &
    1170       402108 :                                   split_rowcol=split_rc, nsplit=nsplit_old)
    1171              : 
    1172       402108 :       IF (nsplit == 0) THEN
    1173       122220 :          IF (split_rowcol == split_rc) THEN
    1174       119163 :             matrix_out => matrix_in
    1175       119163 :             is_new = .FALSE.
    1176       119163 :             RETURN
    1177              :          ELSE
    1178         3057 :             nsplit_prv = 1
    1179              :          END IF
    1180              :       ELSE
    1181       279888 :          nsplit_prv = nsplit
    1182              :       END IF
    1183              : 
    1184       282945 :       CALL timeset(routineN, handle)
    1185              : 
    1186       282945 :       nodata_prv = .FALSE.
    1187       282945 :       IF (PRESENT(nodata)) nodata_prv = nodata
    1188              : 
    1189              :       CALL dbt_tas_get_info(matrix_in, name=name, &
    1190              :                             row_blk_size=rbsize, col_blk_size=cbsize, &
    1191              :                             proc_row_dist=rdist, proc_col_dist=cdist)
    1192              : 
    1193       282945 :       CALL dbt_tas_create_split(split_info, mp_comm, split_rowcol, nsplit_prv, opt_nsplit=opt_nsplit)
    1194              : 
    1195       282945 :       CALL dbt_tas_get_split_info(split_info, nsplit=nsplit_new)
    1196              : 
    1197       282945 :       IF (nsplit_old == nsplit_new .AND. split_rc == split_rowcol) THEN
    1198       277133 :          matrix_out => matrix_in
    1199       277133 :          is_new = .FALSE.
    1200       277133 :          CALL dbt_tas_release_info(split_info)
    1201       277133 :          CALL timestop(handle)
    1202       277133 :          RETURN
    1203              :       END IF
    1204              : 
    1205              :       CALL dbt_tas_distribution_new(dist, mp_comm, rdist, cdist, &
    1206         5812 :                                     split_info=split_info)
    1207              : 
    1208         5812 :       CALL dbt_tas_release_info(split_info)
    1209              : 
    1210        29060 :       ALLOCATE (matrix_out)
    1211         5812 :       CALL dbt_tas_create(matrix_out, name, dist, rbsize, cbsize, own_dist=.TRUE.)
    1212              : 
    1213         5812 :       IF (.NOT. nodata_prv) CALL dbt_tas_copy(matrix_out, matrix_in)
    1214              : 
    1215         5812 :       IF (PRESENT(move_data)) THEN
    1216         5812 :          IF (.NOT. nodata_prv) THEN
    1217         5812 :             IF (move_data) CALL dbt_tas_clear(matrix_in)
    1218         5812 :             move_data = .TRUE.
    1219              :          END IF
    1220              :       END IF
    1221              : 
    1222         5812 :       CALL timestop(handle)
    1223      1429346 :    END SUBROUTINE
    1224              : 
    1225              : ! **************************************************************************************************
    1226              : !> \brief Check whether matrices have same distribution and same split.
    1227              : !> \param mat_a ...
    1228              : !> \param mat_b ...
    1229              : !> \param split_rc_a ...
    1230              : !> \param split_rc_b ...
    1231              : !> \param unit_nr ...
    1232              : !> \return ...
    1233              : !> \author Patrick Seewald
    1234              : ! **************************************************************************************************
    1235       208410 :    FUNCTION dist_compatible(mat_a, mat_b, split_rc_a, split_rc_b, unit_nr)
    1236              :       TYPE(dbt_tas_type), INTENT(IN)                     :: mat_a, mat_b
    1237              :       INTEGER, INTENT(IN)                                :: split_rc_a, split_rc_b
    1238              :       INTEGER, INTENT(IN), OPTIONAL                      :: unit_nr
    1239              :       LOGICAL                                            :: dist_compatible
    1240              : 
    1241              :       INTEGER                                            :: numproc, same_local_rowcols, &
    1242              :                                                             split_check_a, split_check_b, &
    1243              :                                                             unit_nr_prv
    1244       208410 :       INTEGER(int_8), ALLOCATABLE, DIMENSION(:)          :: local_rowcols_a, local_rowcols_b
    1245              :       INTEGER, DIMENSION(2)                              :: pdims_a, pdims_b
    1246      1875690 :       TYPE(dbt_tas_split_info)                           :: info_a, info_b
    1247              : 
    1248       208410 :       unit_nr_prv = prep_output_unit(unit_nr)
    1249              : 
    1250       208410 :       dist_compatible = .FALSE.
    1251              : 
    1252       208410 :       info_a = dbt_tas_info(mat_a)
    1253       208410 :       info_b = dbt_tas_info(mat_b)
    1254       208410 :       CALL dbt_tas_get_split_info(info_a, split_rowcol=split_check_a)
    1255       208410 :       CALL dbt_tas_get_split_info(info_b, split_rowcol=split_check_b)
    1256       208410 :       IF (split_check_b /= split_rc_b .OR. split_check_a /= split_rc_a .OR. split_rc_a /= split_rc_b) THEN
    1257        14632 :          IF (unit_nr_prv > 0) THEN
    1258            0 :             WRITE (unit_nr_prv, *) "matrix layout a not compatible", split_check_a, split_rc_a
    1259            0 :             WRITE (unit_nr_prv, *) "matrix layout b not compatible", split_check_b, split_rc_b
    1260              :          END IF
    1261        14688 :          RETURN
    1262              :       END IF
    1263              : 
    1264              :       ! check if communicators are equivalent
    1265              :       ! Note: mpi_comm_compare is not sufficient since this does not compare associated Cartesian grids.
    1266              :       ! It's sufficient to check dimensions of global grid, subgrids will be determined later on (change_split)
    1267       193778 :       numproc = info_b%mp_comm%num_pe
    1268       581334 :       pdims_a = info_a%mp_comm%num_pe_cart
    1269       581334 :       pdims_b = info_b%mp_comm%num_pe_cart
    1270       193778 :       IF (.NOT. array_eq(pdims_a, pdims_b)) THEN
    1271           56 :          IF (unit_nr_prv > 0) THEN
    1272            0 :             WRITE (unit_nr_prv, *) "mp dims not compatible:", pdims_a, "|", pdims_b
    1273              :          END IF
    1274           56 :          RETURN
    1275              :       END IF
    1276              : 
    1277              :       ! check that distribution is the same by comparing local rows / columns for each matrix
    1278       132594 :       SELECT CASE (split_rc_a)
    1279              :       CASE (rowsplit)
    1280       132594 :          CALL dbt_tas_get_info(mat_a, local_rows=local_rowcols_a)
    1281       132594 :          CALL dbt_tas_get_info(mat_b, local_rows=local_rowcols_b)
    1282              :       CASE (colsplit)
    1283        61128 :          CALL dbt_tas_get_info(mat_a, local_cols=local_rowcols_a)
    1284       254850 :          CALL dbt_tas_get_info(mat_b, local_cols=local_rowcols_b)
    1285              :       END SELECT
    1286              : 
    1287       193722 :       same_local_rowcols = MERGE(1, 0, array_eq(local_rowcols_a, local_rowcols_b))
    1288              : 
    1289       193722 :       CALL info_a%mp_comm%sum(same_local_rowcols)
    1290              : 
    1291       193722 :       IF (same_local_rowcols == numproc) THEN
    1292              :          dist_compatible = .TRUE.
    1293              :       ELSE
    1294           24 :          IF (unit_nr_prv > 0) THEN
    1295            0 :             WRITE (unit_nr_prv, *) "local rowcols not compatible"
    1296            0 :             WRITE (unit_nr_prv, *) "local rowcols A", local_rowcols_a
    1297            0 :             WRITE (unit_nr_prv, *) "local rowcols B", local_rowcols_b
    1298              :          END IF
    1299              :       END IF
    1300              : 
    1301       416820 :    END FUNCTION
    1302              : 
    1303              : ! **************************************************************************************************
    1304              : !> \brief Reshape matrix_in s.t. it has same process grid, distribution and split as template
    1305              : !> \param template ...
    1306              : !> \param matrix_in ...
    1307              : !> \param matrix_out ...
    1308              : !> \param trans ...
    1309              : !> \param split_rc ...
    1310              : !> \param nodata ...
    1311              : !> \param move_data ...
    1312              : !> \author Patrick Seewald
    1313              : ! **************************************************************************************************
    1314       102984 :    SUBROUTINE reshape_mm_template(template, matrix_in, matrix_out, trans, split_rc, nodata, move_data)
    1315              :       TYPE(dbt_tas_type), INTENT(IN)                     :: template
    1316              :       TYPE(dbt_tas_type), INTENT(INOUT)                  :: matrix_in
    1317              :       TYPE(dbt_tas_type), INTENT(OUT)                    :: matrix_out
    1318              :       LOGICAL, INTENT(INOUT)                             :: trans
    1319              :       INTEGER, INTENT(IN)                                :: split_rc
    1320              :       LOGICAL, INTENT(IN), OPTIONAL                      :: nodata, move_data
    1321              : 
    1322        14712 :       CLASS(dbt_tas_distribution), ALLOCATABLE :: row_dist, col_dist
    1323              : 
    1324        88272 :       TYPE(dbt_tas_distribution_type)          :: dist_new
    1325       161832 :       TYPE(dbt_tas_split_info)                 :: info_template, info_matrix
    1326              :       INTEGER                                    :: dim_split_template, dim_split_matrix, &
    1327              :                                                     handle
    1328              :       INTEGER, DIMENSION(2)                      :: pdims
    1329              :       LOGICAL                                    :: nodata_prv, transposed
    1330        14712 :       TYPE(mp_cart_type) :: mp_comm
    1331              :       CHARACTER(LEN=*), PARAMETER :: routineN = 'reshape_mm_template'
    1332              : 
    1333        14712 :       CALL timeset(routineN, handle)
    1334              : 
    1335        14712 :       IF (PRESENT(nodata)) THEN
    1336        12958 :          nodata_prv = nodata
    1337              :       ELSE
    1338              :          nodata_prv = .FALSE.
    1339              :       END IF
    1340              : 
    1341        14712 :       info_template = dbt_tas_info(template)
    1342        14712 :       info_matrix = dbt_tas_info(matrix_in)
    1343              : 
    1344        14712 :       dim_split_template = info_template%split_rowcol
    1345        14712 :       dim_split_matrix = split_rc
    1346              : 
    1347        14712 :       transposed = dim_split_template .NE. dim_split_matrix
    1348        14712 :       IF (transposed) trans = .NOT. trans
    1349              : 
    1350        44136 :       pdims = info_template%mp_comm%num_pe_cart
    1351              : 
    1352         7946 :       SELECT CASE (dim_split_template)
    1353              :       CASE (1)
    1354         7946 :          IF (.NOT. transposed) THEN
    1355           44 :             ALLOCATE (row_dist, source=template%dist%row_dist)
    1356           44 :             ALLOCATE (col_dist, source=dbt_tas_dist_arb_default(pdims(2), matrix_in%nblkcols, matrix_in%col_blk_size))
    1357              :          ELSE
    1358         7902 :             ALLOCATE (row_dist, source=template%dist%row_dist)
    1359         7902 :             ALLOCATE (col_dist, source=dbt_tas_dist_arb_default(pdims(2), matrix_in%nblkrows, matrix_in%row_blk_size))
    1360              :          END IF
    1361              :       CASE (2)
    1362        14712 :          IF (.NOT. transposed) THEN
    1363          120 :             ALLOCATE (row_dist, source=dbt_tas_dist_arb_default(pdims(1), matrix_in%nblkrows, matrix_in%row_blk_size))
    1364          120 :             ALLOCATE (col_dist, source=template%dist%col_dist)
    1365              :          ELSE
    1366        13412 :             ALLOCATE (row_dist, source=dbt_tas_dist_arb_default(pdims(1), matrix_in%nblkcols, matrix_in%col_blk_size))
    1367        13412 :             ALLOCATE (col_dist, source=template%dist%col_dist)
    1368              :          END IF
    1369              :       END SELECT
    1370              : 
    1371        14712 :       CALL dbt_tas_get_split_info(info_template, mp_comm=mp_comm)
    1372        14712 :       CALL dbt_tas_distribution_new(dist_new, mp_comm, row_dist, col_dist, split_info=info_template)
    1373        14712 :       IF (.NOT. transposed) THEN
    1374              :          CALL dbt_tas_create(matrix_out, dbm_get_name(matrix_in%matrix), dist_new, &
    1375          104 :                              matrix_in%row_blk_size, matrix_in%col_blk_size, own_dist=.TRUE.)
    1376              :       ELSE
    1377              :          CALL dbt_tas_create(matrix_out, dbm_get_name(matrix_in%matrix), dist_new, &
    1378        14608 :                              matrix_in%col_blk_size, matrix_in%row_blk_size, own_dist=.TRUE.)
    1379              :       END IF
    1380              : 
    1381        14712 :       IF (.NOT. nodata_prv) CALL dbt_tas_reshape(matrix_in, matrix_out, transposed=transposed, move_data=move_data)
    1382              : 
    1383        14712 :       CALL timestop(handle)
    1384              : 
    1385        29424 :    END SUBROUTINE
    1386              : 
    1387              : ! **************************************************************************************************
    1388              : !> \brief Estimate sparsity pattern of C resulting from A x B = C
    1389              : !>         by multiplying the block norms of A and B Same dummy arguments as dbt_tas_multiply
    1390              : !> \param transa ...
    1391              : !> \param transb ...
    1392              : !> \param transc ...
    1393              : !> \param matrix_a ...
    1394              : !> \param matrix_b ...
    1395              : !> \param matrix_c ...
    1396              : !> \param estimated_nze ...
    1397              : !> \param filter_eps ...
    1398              : !> \param unit_nr ...
    1399              : !> \param retain_sparsity ...
    1400              : !> \author Patrick Seewald
    1401              : ! **************************************************************************************************
    1402        56712 :    SUBROUTINE dbt_tas_estimate_result_nze(transa, transb, transc, matrix_a, matrix_b, matrix_c, &
    1403              :                                           estimated_nze, filter_eps, unit_nr, retain_sparsity)
    1404              :       LOGICAL, INTENT(IN)                                :: transa, transb, transc
    1405              :       TYPE(dbt_tas_type), INTENT(INOUT), TARGET          :: matrix_a, matrix_b, matrix_c
    1406              :       INTEGER(int_8), INTENT(OUT)                        :: estimated_nze
    1407              :       REAL(KIND=dp), INTENT(IN), OPTIONAL                :: filter_eps
    1408              :       INTEGER, INTENT(IN), OPTIONAL                      :: unit_nr
    1409              :       LOGICAL, INTENT(IN), OPTIONAL                      :: retain_sparsity
    1410              : 
    1411              :       CHARACTER(LEN=*), PARAMETER :: routineN = 'dbt_tas_estimate_result_nze'
    1412              : 
    1413              :       INTEGER                                            :: col_size, handle, row_size
    1414              :       INTEGER(int_8)                                     :: col, row
    1415              :       LOGICAL                                            :: retain_sparsity_prv
    1416              :       TYPE(dbt_tas_iterator)                             :: iter
    1417              :       TYPE(dbt_tas_type), POINTER                        :: matrix_a_bnorm, matrix_b_bnorm, &
    1418              :                                                             matrix_c_bnorm
    1419        56712 :       TYPE(mp_cart_type)                                 :: mp_comm
    1420              : 
    1421        56712 :       CALL timeset(routineN, handle)
    1422              : 
    1423        56712 :       IF (PRESENT(retain_sparsity)) THEN
    1424          118 :          retain_sparsity_prv = retain_sparsity
    1425              :       ELSE
    1426              :          retain_sparsity_prv = .FALSE.
    1427              :       END IF
    1428              : 
    1429          118 :       IF (.NOT. retain_sparsity_prv) THEN
    1430      1075286 :          ALLOCATE (matrix_a_bnorm, matrix_b_bnorm, matrix_c_bnorm)
    1431        56594 :          CALL create_block_norms_matrix(matrix_a, matrix_a_bnorm)
    1432        56594 :          CALL create_block_norms_matrix(matrix_b, matrix_b_bnorm)
    1433        56594 :          CALL create_block_norms_matrix(matrix_c, matrix_c_bnorm, nodata=.TRUE.)
    1434              : 
    1435              :          CALL dbt_tas_multiply(transa, transb, transc, 1.0_dp, matrix_a_bnorm, &
    1436              :                                matrix_b_bnorm, 0.0_dp, matrix_c_bnorm, &
    1437              :                                filter_eps=filter_eps, move_data_a=.TRUE., move_data_b=.TRUE., &
    1438        56594 :                                simple_split=.TRUE., unit_nr=unit_nr)
    1439        56594 :          CALL dbt_tas_destroy(matrix_a_bnorm)
    1440        56594 :          CALL dbt_tas_destroy(matrix_b_bnorm)
    1441              : 
    1442        56594 :          DEALLOCATE (matrix_a_bnorm, matrix_b_bnorm)
    1443              :       ELSE
    1444              :          matrix_c_bnorm => matrix_c
    1445              :       END IF
    1446              : 
    1447        56712 :       estimated_nze = 0
    1448              : !$OMP PARALLEL DEFAULT(NONE) REDUCTION(+:estimated_nze) SHARED(matrix_c_bnorm,matrix_c) &
    1449        56712 : !$OMP PRIVATE(iter,row,col,row_size,col_size)
    1450              :       CALL dbt_tas_iterator_start(iter, matrix_c_bnorm)
    1451              :       DO WHILE (dbt_tas_iterator_blocks_left(iter))
    1452              :          CALL dbt_tas_iterator_next_block(iter, row, col)
    1453              :          row_size = matrix_c%row_blk_size%data(row)
    1454              :          col_size = matrix_c%col_blk_size%data(col)
    1455              :          estimated_nze = estimated_nze + row_size*col_size
    1456              :       END DO
    1457              :       CALL dbt_tas_iterator_stop(iter)
    1458              : !$OMP END PARALLEL
    1459              : 
    1460        56712 :       CALL dbt_tas_get_split_info(dbt_tas_info(matrix_a), mp_comm=mp_comm)
    1461        56712 :       CALL mp_comm%sum(estimated_nze)
    1462              : 
    1463        56712 :       IF (.NOT. retain_sparsity_prv) THEN
    1464        56594 :          CALL dbt_tas_destroy(matrix_c_bnorm)
    1465        56594 :          DEALLOCATE (matrix_c_bnorm)
    1466              :       END IF
    1467              : 
    1468        56712 :       CALL timestop(handle)
    1469              : 
    1470       113424 :    END SUBROUTINE
    1471              : 
    1472              : ! **************************************************************************************************
    1473              : !> \brief Estimate optimal split factor for AxB=C from occupancies (number of non-zero elements)
    1474              : !>        This estimate is based on the minimization of communication volume whereby the
    1475              : !>        communication of CARMA n-split step and CANNON-multiplication of submatrices are considered.
    1476              : !> \param max_mm_dim ...
    1477              : !> \param nze_a number of non-zeroes in A
    1478              : !> \param nze_b number of non-zeroes in B
    1479              : !> \param nze_c number of non-zeroes in C
    1480              : !> \param numnodes number of MPI ranks
    1481              : !> \return estimated split factor
    1482              : !> \author Patrick Seewald
    1483              : ! **************************************************************************************************
    1484       208528 :    FUNCTION split_factor_estimate(max_mm_dim, nze_a, nze_b, nze_c, numnodes) RESULT(nsplit)
    1485              :       INTEGER, INTENT(IN)                                :: max_mm_dim
    1486              :       INTEGER(KIND=int_8), INTENT(IN)                    :: nze_a, nze_b, nze_c
    1487              :       INTEGER, INTENT(IN)                                :: numnodes
    1488              :       INTEGER                                            :: nsplit
    1489              : 
    1490              :       INTEGER(KIND=int_8)                                :: max_nze, min_nze
    1491              :       REAL(dp)                                           :: s_opt_factor
    1492              : 
    1493       208528 :       s_opt_factor = 1.0_dp ! Could be further tuned.
    1494              : 
    1495       267688 :       SELECT CASE (max_mm_dim)
    1496              :       CASE (1)
    1497        59160 :          min_nze = MAX(nze_b, 1_int_8)
    1498       177480 :          max_nze = MAX(MAXVAL([nze_a, nze_c]), 1_int_8)
    1499              :       CASE (2)
    1500        64096 :          min_nze = MAX(nze_c, 1_int_8)
    1501       192288 :          max_nze = MAX(MAXVAL([nze_a, nze_b]), 1_int_8)
    1502              :       CASE (3)
    1503        85272 :          min_nze = MAX(nze_a, 1_int_8)
    1504       255816 :          max_nze = MAX(MAXVAL([nze_b, nze_c]), 1_int_8)
    1505              :       CASE DEFAULT
    1506       208528 :          CPABORT("")
    1507              :       END SELECT
    1508              : 
    1509       208528 :       nsplit = INT(MIN(INT(numnodes, KIND=int_8), NINT(REAL(max_nze, dp)/(REAL(min_nze, dp)*s_opt_factor), KIND=int_8)))
    1510       208528 :       IF (nsplit == 0) nsplit = 1
    1511              : 
    1512       208528 :    END FUNCTION
    1513              : 
    1514              : ! **************************************************************************************************
    1515              : !> \brief Create a matrix with block sizes one that contains the block norms of matrix_in
    1516              : !> \param matrix_in ...
    1517              : !> \param matrix_out ...
    1518              : !> \param nodata ...
    1519              : !> \author Patrick Seewald
    1520              : ! **************************************************************************************************
    1521      1018692 :    SUBROUTINE create_block_norms_matrix(matrix_in, matrix_out, nodata)
    1522              :       TYPE(dbt_tas_type), INTENT(INOUT)                  :: matrix_in
    1523              :       TYPE(dbt_tas_type), INTENT(OUT)                    :: matrix_out
    1524              :       LOGICAL, INTENT(IN), OPTIONAL                      :: nodata
    1525              : 
    1526              :       CHARACTER(len=default_string_length)               :: name
    1527              :       INTEGER(KIND=int_8)                                :: column, nblkcols, nblkrows, row
    1528              :       LOGICAL                                            :: nodata_prv
    1529              :       REAL(dp), DIMENSION(1, 1)                          :: blk_put
    1530       169782 :       REAL(dp), DIMENSION(:, :), POINTER                 :: blk_get
    1531              :       TYPE(dbt_tas_blk_size_one)                         :: col_blk_size, row_blk_size
    1532              :       TYPE(dbt_tas_iterator)                             :: iter
    1533              : 
    1534              : !REAL(dp), DIMENSION(:, :), POINTER        :: dbt_put
    1535              : 
    1536       169782 :       CPASSERT(matrix_in%valid)
    1537              : 
    1538       169782 :       IF (PRESENT(nodata)) THEN
    1539        56594 :          nodata_prv = nodata
    1540              :       ELSE
    1541              :          nodata_prv = .FALSE.
    1542              :       END IF
    1543              : 
    1544       169782 :       CALL dbt_tas_get_info(matrix_in, name=name, nblkrows_total=nblkrows, nblkcols_total=nblkcols)
    1545       169782 :       row_blk_size = dbt_tas_blk_size_one(nblkrows)
    1546       169782 :       col_blk_size = dbt_tas_blk_size_one(nblkcols)
    1547              : 
    1548              :       ! not sure if assumption that same distribution can be taken still holds
    1549       169782 :       CALL dbt_tas_create(matrix_out, name, matrix_in%dist, row_blk_size, col_blk_size)
    1550              : 
    1551       169782 :       IF (.NOT. nodata_prv) THEN
    1552       113188 :          CALL dbt_tas_reserve_blocks(matrix_in, matrix_out)
    1553              : !$OMP PARALLEL DEFAULT(NONE) SHARED(matrix_in,matrix_out) &
    1554       113188 : !$OMP PRIVATE(iter,row,column,blk_get,blk_put)
    1555              :          CALL dbt_tas_iterator_start(iter, matrix_in)
    1556              :          DO WHILE (dbt_tas_iterator_blocks_left(iter))
    1557              :             CALL dbt_tas_iterator_next_block(iter, row, column, blk_get)
    1558              :             blk_put(1, 1) = NORM2(blk_get)
    1559              :             CALL dbt_tas_put_block(matrix_out, row, column, blk_put)
    1560              :          END DO
    1561              :          CALL dbt_tas_iterator_stop(iter)
    1562              : !$OMP END PARALLEL
    1563              :       END IF
    1564              : 
    1565       169782 :    END SUBROUTINE
    1566              : 
    1567              : ! **************************************************************************************************
    1568              : !> \brief Convert a DBM matrix to a new process grid
    1569              : !> \param mp_comm_cart new process grid
    1570              : !> \param matrix_in ...
    1571              : !> \param matrix_out ...
    1572              : !> \param move_data memory optimization: move data such that matrix_in is empty on return.
    1573              : !> \param nodata Data of matrix_in should not be copied to matrix_out
    1574              : !> \param optimize_pgrid Whether to change process grid
    1575              : !> \author Patrick Seewald
    1576              : ! **************************************************************************************************
    1577       625374 :    SUBROUTINE convert_to_new_pgrid(mp_comm_cart, matrix_in, matrix_out, move_data, nodata, optimize_pgrid)
    1578              :       TYPE(mp_cart_type), INTENT(IN)                     :: mp_comm_cart
    1579              :       TYPE(dbm_type), INTENT(INOUT)                      :: matrix_in
    1580              :       TYPE(dbm_type), INTENT(OUT)                        :: matrix_out
    1581              :       LOGICAL, INTENT(IN), OPTIONAL                      :: move_data, nodata, optimize_pgrid
    1582              : 
    1583              :       CHARACTER(LEN=*), PARAMETER :: routineN = 'convert_to_new_pgrid'
    1584              : 
    1585              :       CHARACTER(len=default_string_length)               :: name
    1586              :       INTEGER                                            :: handle, nbcols, nbrows
    1587       625374 :       INTEGER, CONTIGUOUS, DIMENSION(:), POINTER         :: col_dist, rbsize, rcsize, row_dist
    1588              :       INTEGER, DIMENSION(2)                              :: pdims
    1589              :       LOGICAL                                            :: nodata_prv, optimize_pgrid_prv
    1590              :       TYPE(dbm_distribution_obj)                         :: dist, dist_old
    1591              : 
    1592       625374 :       NULLIFY (row_dist, col_dist, rbsize, rcsize)
    1593              : 
    1594       625374 :       CALL timeset(routineN, handle)
    1595              : 
    1596       625374 :       IF (PRESENT(optimize_pgrid)) THEN
    1597       625374 :          optimize_pgrid_prv = optimize_pgrid
    1598              :       ELSE
    1599              :          optimize_pgrid_prv = .TRUE.
    1600              :       END IF
    1601              : 
    1602       625374 :       IF (PRESENT(nodata)) THEN
    1603       208458 :          nodata_prv = nodata
    1604              :       ELSE
    1605              :          nodata_prv = .FALSE.
    1606              :       END IF
    1607              : 
    1608       625374 :       name = dbm_get_name(matrix_in)
    1609              : 
    1610       625374 :       IF (.NOT. optimize_pgrid_prv) THEN
    1611       625374 :          CALL dbm_create_from_template(matrix_out, name=name, template=matrix_in)
    1612       625374 :          IF (.NOT. nodata_prv) CALL dbm_copy(matrix_out, matrix_in)
    1613       625374 :          CALL timestop(handle)
    1614       625374 :          RETURN
    1615              :       END IF
    1616              : 
    1617            0 :       rbsize => dbm_get_row_block_sizes(matrix_in)
    1618            0 :       rcsize => dbm_get_col_block_sizes(matrix_in)
    1619            0 :       nbrows = SIZE(rbsize)
    1620            0 :       nbcols = SIZE(rcsize)
    1621            0 :       dist_old = dbm_get_distribution(matrix_in)
    1622            0 :       pdims = mp_comm_cart%num_pe_cart
    1623              : 
    1624            0 :       ALLOCATE (row_dist(nbrows), col_dist(nbcols))
    1625            0 :       CALL dbt_tas_default_distvec(nbrows, pdims(1), rbsize, row_dist)
    1626            0 :       CALL dbt_tas_default_distvec(nbcols, pdims(2), rcsize, col_dist)
    1627              : 
    1628            0 :       CALL dbm_distribution_new(dist, mp_comm_cart, row_dist, col_dist)
    1629            0 :       DEALLOCATE (row_dist, col_dist)
    1630              : 
    1631            0 :       CALL dbm_create(matrix_out, name, dist, rbsize, rcsize)
    1632            0 :       CALL dbm_distribution_release(dist)
    1633              : 
    1634            0 :       IF (.NOT. nodata_prv) THEN
    1635            0 :          CALL dbm_redistribute(matrix_in, matrix_out)
    1636            0 :          IF (PRESENT(move_data)) THEN
    1637            0 :             IF (move_data) CALL dbm_clear(matrix_in)
    1638              :          END IF
    1639              :       END IF
    1640              : 
    1641            0 :       CALL timestop(handle)
    1642       625374 :    END SUBROUTINE
    1643              : 
    1644              : ! **************************************************************************************************
    1645              : !> \brief ...
    1646              : !> \param matrix ...
    1647              : !> \author Patrick Seewald
    1648              : ! **************************************************************************************************
    1649        79575 :    SUBROUTINE dbt_tas_batched_mm_init(matrix)
    1650              :       TYPE(dbt_tas_type), INTENT(INOUT)                  :: matrix
    1651              : 
    1652        79575 :       CALL dbt_tas_set_batched_state(matrix, state=1)
    1653        79575 :       ALLOCATE (matrix%mm_storage)
    1654              :       matrix%mm_storage%batched_out = .FALSE.
    1655        79575 :    END SUBROUTINE
    1656              : 
    1657              : ! **************************************************************************************************
    1658              : !> \brief ...
    1659              : !> \param matrix ...
    1660              : !> \author Patrick Seewald
    1661              : ! **************************************************************************************************
    1662       159150 :    SUBROUTINE dbt_tas_batched_mm_finalize(matrix)
    1663              :       TYPE(dbt_tas_type), INTENT(INOUT)                  :: matrix
    1664              : 
    1665              :       INTEGER                                            :: handle
    1666              : 
    1667        79575 :       CALL matrix%dist%info%mp_comm%sync()
    1668        79575 :       CALL timeset("dbt_tas_total", handle)
    1669              : 
    1670        79575 :       IF (matrix%do_batched == 0) RETURN
    1671              : 
    1672        79575 :       IF (matrix%mm_storage%batched_out) THEN
    1673        27562 :          CALL dbm_scale(matrix%matrix, matrix%mm_storage%batched_beta)
    1674              :       END IF
    1675              : 
    1676        79575 :       CALL dbt_tas_batched_mm_complete(matrix)
    1677              : 
    1678        79575 :       matrix%mm_storage%batched_out = .FALSE.
    1679              : 
    1680        79575 :       DEALLOCATE (matrix%mm_storage)
    1681        79575 :       CALL dbt_tas_set_batched_state(matrix, state=0)
    1682              : 
    1683        79575 :       CALL matrix%dist%info%mp_comm%sync()
    1684        79575 :       CALL timestop(handle)
    1685              : 
    1686              :    END SUBROUTINE
    1687              : 
    1688              : ! **************************************************************************************************
    1689              : !> \brief set state flags during batched multiplication
    1690              : !> \param matrix ...
    1691              : !> \param state 0 no batched MM
    1692              : !>              1 batched MM but mm_storage not yet initialized
    1693              : !>              2 batched MM and mm_storage requires update
    1694              : !>              3 batched MM and mm_storage initialized
    1695              : !> \param opt_grid whether process grid was already optimized and should not be changed
    1696              : !> \author Patrick Seewald
    1697              : ! **************************************************************************************************
    1698      1192295 :    SUBROUTINE dbt_tas_set_batched_state(matrix, state, opt_grid)
    1699              :       TYPE(dbt_tas_type), INTENT(INOUT)                  :: matrix
    1700              :       INTEGER, INTENT(IN), OPTIONAL                      :: state
    1701              :       LOGICAL, INTENT(IN), OPTIONAL                      :: opt_grid
    1702              : 
    1703      1192295 :       IF (PRESENT(opt_grid)) THEN
    1704       886036 :          matrix%has_opt_pgrid = opt_grid
    1705       886036 :          matrix%dist%info%strict_split(1) = .TRUE.
    1706              :       END IF
    1707              : 
    1708      1192295 :       IF (PRESENT(state)) THEN
    1709       911183 :          matrix%do_batched = state
    1710       648088 :          SELECT CASE (state)
    1711              :          CASE (0, 1)
    1712              :             ! reset to default
    1713       648088 :             IF (matrix%has_opt_pgrid) THEN
    1714       401388 :                matrix%dist%info%strict_split(1) = .TRUE.
    1715              :             ELSE
    1716       246700 :                matrix%dist%info%strict_split(1) = matrix%dist%info%strict_split(2)
    1717              :             END IF
    1718              :          CASE (2, 3)
    1719       263095 :             matrix%dist%info%strict_split(1) = .TRUE.
    1720              :          CASE DEFAULT
    1721       911183 :             CPABORT("should not happen")
    1722              :          END SELECT
    1723              :       END IF
    1724      1192295 :    END SUBROUTINE
    1725              : 
    1726              : ! **************************************************************************************************
    1727              : !> \brief ...
    1728              : !> \param matrix ...
    1729              : !> \param warn ...
    1730              : !> \author Patrick Seewald
    1731              : ! **************************************************************************************************
    1732       978717 :    SUBROUTINE dbt_tas_batched_mm_complete(matrix, warn)
    1733              :       TYPE(dbt_tas_type), INTENT(INOUT)                  :: matrix
    1734              :       LOGICAL, INTENT(IN), OPTIONAL                      :: warn
    1735              : 
    1736       978717 :       IF (matrix%do_batched == 0) RETURN
    1737              :       ASSOCIATE (storage => matrix%mm_storage)
    1738        82931 :          IF (PRESENT(warn)) THEN
    1739         1588 :             IF (warn .AND. matrix%do_batched == 3) THEN
    1740              :                CALL cp_warn(__LOCATION__, &
    1741            0 :                             "Optimizations for batched multiplication are disabled because of conflicting data access")
    1742              :             END IF
    1743              :          END IF
    1744        82931 :          IF (storage%batched_out .AND. matrix%do_batched == 3) THEN
    1745              : 
    1746              :             CALL dbt_tas_merge(storage%store_batched%matrix, &
    1747        28714 :                                storage%store_batched_repl, move_data=.TRUE.)
    1748              : 
    1749              :             CALL dbt_tas_reshape(storage%store_batched, matrix, summation=.TRUE., &
    1750        28714 :                                  transposed=storage%batched_trans, move_data=.TRUE.)
    1751        28714 :             CALL dbt_tas_destroy(storage%store_batched)
    1752        28714 :             DEALLOCATE (storage%store_batched)
    1753              :          END IF
    1754              : 
    1755       165862 :          IF (ASSOCIATED(storage%store_batched_repl)) THEN
    1756        64178 :             CALL dbt_tas_destroy(storage%store_batched_repl)
    1757        64178 :             DEALLOCATE (storage%store_batched_repl)
    1758              :          END IF
    1759              :       END ASSOCIATE
    1760              : 
    1761        82931 :       CALL dbt_tas_set_batched_state(matrix, state=2)
    1762              : 
    1763              :    END SUBROUTINE
    1764              : 
    1765      1735668 : END MODULE
        

Generated by: LCOV version 2.0-1