LCOV - code coverage report
Current view: top level - src/dbt/tas - dbt_tas_test.F (source / functions) Coverage Total Hit
Test: CP2K Regtests (git:42dac4a) Lines: 64.0 % 211 135
Test Date: 2025-07-25 12:55:17 Functions: 83.3 % 6 5

            Line data    Source code
       1              : !--------------------------------------------------------------------------------------------------!
       2              : !   CP2K: A general program to perform molecular dynamics simulations                              !
       3              : !   Copyright 2000-2025 CP2K developers group <https://cp2k.org>                                   !
       4              : !                                                                                                  !
       5              : !   SPDX-License-Identifier: GPL-2.0-or-later                                                      !
       6              : !--------------------------------------------------------------------------------------------------!
       7              : 
       8              : ! **************************************************************************************************
       9              : !> \brief testing infrastructure for tall-and-skinny matrices
      10              : !> \author Patrick Seewald
      11              : ! **************************************************************************************************
      12              : MODULE dbt_tas_test
      13              :    USE dbm_api,                         ONLY: &
      14              :         dbm_add, dbm_checksum, dbm_create, dbm_distribution_new, dbm_distribution_obj, &
      15              :         dbm_distribution_release, dbm_finalize, dbm_get_col_block_sizes, dbm_get_name, &
      16              :         dbm_get_row_block_sizes, dbm_maxabs, dbm_multiply, dbm_redistribute, dbm_release, &
      17              :         dbm_scale, dbm_type
      18              :    USE dbm_tests,                       ONLY: generate_larnv_seed
      19              :    USE dbt_tas_base,                    ONLY: &
      20              :         dbt_tas_convert_to_dbm, dbt_tas_create, dbt_tas_distribution_new, dbt_tas_finalize, &
      21              :         dbt_tas_get_stored_coordinates, dbt_tas_info, dbt_tas_nblkcols_total, &
      22              :         dbt_tas_nblkrows_total, dbt_tas_put_block
      23              :    USE dbt_tas_global,                  ONLY: dbt_tas_blk_size_arb,&
      24              :                                               dbt_tas_default_distvec,&
      25              :                                               dbt_tas_dist_cyclic
      26              :    USE dbt_tas_mm,                      ONLY: dbt_tas_multiply
      27              :    USE dbt_tas_split,                   ONLY: dbt_tas_get_split_info,&
      28              :                                               dbt_tas_mp_comm
      29              :    USE dbt_tas_types,                   ONLY: dbt_tas_distribution_type,&
      30              :                                               dbt_tas_type
      31              :    USE kinds,                           ONLY: dp,&
      32              :                                               int_8
      33              :    USE message_passing,                 ONLY: mp_cart_type,&
      34              :                                               mp_comm_type
      35              : #include "../../base/base_uses.f90"
      36              : 
      37              :    IMPLICIT NONE
      38              :    PRIVATE
      39              : 
      40              :    PUBLIC :: &
      41              :       dbt_tas_benchmark_mm, &
      42              :       dbt_tas_checksum, &
      43              :       dbt_tas_random_bsizes, &
      44              :       dbt_tas_setup_test_matrix, &
      45              :       dbt_tas_test_mm, &
      46              :       dbt_tas_reset_randmat_seed
      47              : 
      48              :    INTEGER, SAVE :: randmat_counter = 0
      49              :    INTEGER, PARAMETER, PRIVATE :: rand_seed_init = 12341313
      50              : 
      51              : CONTAINS
      52              : 
      53              : ! **************************************************************************************************
      54              : !> \brief Setup tall-and-skinny matrix for testing
      55              : !> \param matrix ...
      56              : !> \param mp_comm_out ...
      57              : !> \param mp_comm ...
      58              : !> \param nrows ...
      59              : !> \param ncols ...
      60              : !> \param rbsizes ...
      61              : !> \param cbsizes ...
      62              : !> \param dist_splitsize ...
      63              : !> \param name ...
      64              : !> \param sparsity ...
      65              : !> \param reuse_comm ...
      66              : !> \author Patrick Seewald
      67              : ! **************************************************************************************************
      68           72 :    SUBROUTINE dbt_tas_setup_test_matrix(matrix, mp_comm_out, mp_comm, nrows, ncols, rbsizes, &
      69           12 :                                         cbsizes, dist_splitsize, name, sparsity, reuse_comm)
      70              :       TYPE(dbt_tas_type), INTENT(OUT)                    :: matrix
      71              :       TYPE(mp_cart_type), INTENT(OUT)                    :: mp_comm_out
      72              : 
      73              :       CLASS(mp_comm_type), INTENT(IN)                     :: mp_comm
      74              :       INTEGER(KIND=int_8), INTENT(IN)                    :: nrows, ncols
      75              :       INTEGER, DIMENSION(nrows), INTENT(IN)              :: rbsizes
      76              :       INTEGER, DIMENSION(ncols), INTENT(IN)              :: cbsizes
      77              :       INTEGER, DIMENSION(2), INTENT(IN)                  :: dist_splitsize
      78              :       CHARACTER(len=*), INTENT(IN)                       :: name
      79              :       REAL(KIND=dp), INTENT(IN)                          :: sparsity
      80              :       LOGICAL, INTENT(IN), OPTIONAL                      :: reuse_comm
      81              : 
      82              :       CHARACTER(LEN=*), PARAMETER :: routineN = 'dbt_tas_setup_test_matrix'
      83              : 
      84              :       INTEGER                                            :: col_size, handle, max_col_size, max_nze, &
      85              :                                                             max_row_size, mynode, node_holds_blk, &
      86              :                                                             nze, row_size
      87              :       INTEGER(KIND=int_8)                                :: col, col_s, ncol, nrow, row, row_s
      88              :       INTEGER, DIMENSION(2)                              :: pdims
      89              :       INTEGER, DIMENSION(4)                              :: iseed, jseed
      90              :       LOGICAL                                            :: reuse_comm_prv, tr
      91           12 :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :)        :: values
      92              :       REAL(KIND=dp), DIMENSION(1)                        :: rn
      93           24 :       TYPE(dbt_tas_blk_size_arb)                         :: cbsize_obj, rbsize_obj
      94              :       TYPE(dbt_tas_dist_cyclic)                          :: col_dist_obj, row_dist_obj
      95           60 :       TYPE(dbt_tas_distribution_type)                    :: dist
      96              : 
      97              :       ! we don't reserve blocks prior to putting them, so this time is meaningless and should not
      98              :       ! be considered in benchmark!
      99           12 :       CALL timeset(routineN, handle)
     100              : 
     101              :       ! Check that the counter was initialised (or has not overflowed)
     102           12 :       CPASSERT(randmat_counter .NE. 0)
     103              :       ! the counter goes into the seed. Every new call gives a new random matrix
     104           12 :       randmat_counter = randmat_counter + 1
     105              : 
     106           12 :       IF (PRESENT(reuse_comm)) THEN
     107            0 :          reuse_comm_prv = reuse_comm
     108              :       ELSE
     109              :          reuse_comm_prv = .FALSE.
     110              :       END IF
     111              : 
     112            0 :       IF (reuse_comm_prv) THEN
     113            0 :          mp_comm_out = mp_comm
     114              :       ELSE
     115           12 :          mp_comm_out = dbt_tas_mp_comm(mp_comm, nrows, ncols)
     116              :       END IF
     117              : 
     118           12 :       mynode = mp_comm_out%mepos
     119           36 :       pdims = mp_comm_out%num_pe_cart
     120              : 
     121           12 :       row_dist_obj = dbt_tas_dist_cyclic(dist_splitsize(1), pdims(1), nrows)
     122           12 :       col_dist_obj = dbt_tas_dist_cyclic(dist_splitsize(2), pdims(2), ncols)
     123              : 
     124           12 :       rbsize_obj = dbt_tas_blk_size_arb(rbsizes)
     125           12 :       cbsize_obj = dbt_tas_blk_size_arb(cbsizes)
     126              : 
     127           12 :       CALL dbt_tas_distribution_new(dist, mp_comm_out, row_dist_obj, col_dist_obj)
     128              :       CALL dbt_tas_create(matrix, name=TRIM(name), dist=dist, &
     129           12 :                           row_blk_size=rbsize_obj, col_blk_size=cbsize_obj, own_dist=.TRUE.)
     130              : 
     131          532 :       max_row_size = MAXVAL(rbsizes)
     132          532 :       max_col_size = MAXVAL(cbsizes)
     133           12 :       max_nze = max_row_size*max_col_size
     134              : 
     135           12 :       nrow = dbt_tas_nblkrows_total(matrix)
     136           12 :       ncol = dbt_tas_nblkcols_total(matrix)
     137              : 
     138           48 :       ALLOCATE (values(max_row_size, max_col_size))
     139              : 
     140           12 :       jseed = generate_larnv_seed(7, 42, 3, 42, randmat_counter)
     141              : 
     142          532 :       DO row = 1, dbt_tas_nblkrows_total(matrix)
     143        13332 :          DO col = 1, dbt_tas_nblkcols_total(matrix)
     144        12800 :             CALL dlarnv(1, jseed, 1, rn)
     145        13320 :             IF (rn(1) .LT. sparsity) THEN
     146         1306 :                tr = .FALSE.
     147         1306 :                row_s = row; col_s = col
     148         1306 :                CALL dbt_tas_get_stored_coordinates(matrix, row_s, col_s, node_holds_blk)
     149              : 
     150         1306 :                IF (node_holds_blk .EQ. mynode) THEN
     151          653 :                   row_size = rbsize_obj%data(row_s)
     152          653 :                   col_size = cbsize_obj%data(col_s)
     153          653 :                   nze = row_size*col_size
     154          653 :                   iseed = generate_larnv_seed(INT(row_s), INT(nrow), INT(col_s), INT(ncol), randmat_counter)
     155          653 :                   CALL dlarnv(1, iseed, max_nze, values)
     156          653 :                   CALL dbt_tas_put_block(matrix, row_s, col_s, values(1:row_size, 1:col_size))
     157              :                END IF
     158              :             END IF
     159              :          END DO
     160              :       END DO
     161              : 
     162           12 :       CALL dbt_tas_finalize(matrix)
     163              : 
     164           12 :       CALL timestop(handle)
     165              : 
     166           48 :    END SUBROUTINE
     167              : 
     168              : ! **************************************************************************************************
     169              : !> \brief Benchmark routine. Due to random sparsity (as opposed to structured sparsity pattern),
     170              : !>        this may not be representative for actual applications.
     171              : !> \param transa ...
     172              : !> \param transb ...
     173              : !> \param transc ...
     174              : !> \param matrix_a ...
     175              : !> \param matrix_b ...
     176              : !> \param matrix_c ...
     177              : !> \param compare_dbm ...
     178              : !> \param filter_eps ...
     179              : !> \param io_unit ...
     180              : !> \author Patrick Seewald
     181              : ! **************************************************************************************************
     182            0 :    SUBROUTINE dbt_tas_benchmark_mm(transa, transb, transc, matrix_a, matrix_b, matrix_c, compare_dbm, filter_eps, io_unit)
     183              : 
     184              :       LOGICAL, INTENT(IN)                                :: transa, transb, transc
     185              :       TYPE(dbt_tas_type), INTENT(INOUT)                  :: matrix_a, matrix_b, matrix_c
     186              :       LOGICAL, INTENT(IN)                                :: compare_dbm
     187              :       REAL(KIND=dp), INTENT(IN), OPTIONAL                :: filter_eps
     188              :       INTEGER, INTENT(IN), OPTIONAL                      :: io_unit
     189              : 
     190              :       INTEGER                                            :: handle1, handle2
     191            0 :       INTEGER, CONTIGUOUS, DIMENSION(:), POINTER :: cd_a, cd_b, cd_c, col_block_sizes_a, &
     192            0 :          col_block_sizes_b, col_block_sizes_c, rd_a, rd_b, rd_c, row_block_sizes_a, &
     193            0 :          row_block_sizes_b, row_block_sizes_c
     194              :       INTEGER, DIMENSION(2)                              :: npdims
     195              :       TYPE(dbm_distribution_obj)                         :: dist_a, dist_b, dist_c
     196              :       TYPE(dbm_type)                                     :: dbm_a, dbm_a_mm, dbm_b, dbm_b_mm, dbm_c, &
     197              :                                                             dbm_c_mm
     198            0 :       TYPE(mp_cart_type)                                 :: comm_dbm, mp_comm
     199              : 
     200              : !
     201              : ! TODO: Dedup with code in dbt_tas_test_mm.
     202              : !
     203            0 :       IF (PRESENT(io_unit)) THEN
     204            0 :       IF (io_unit > 0) THEN
     205            0 :          WRITE (io_unit, "(A)") "starting tall-and-skinny benchmark"
     206              :       END IF
     207              :       END IF
     208            0 :       CALL timeset("benchmark_tas_mm", handle1)
     209              :       CALL dbt_tas_multiply(transa, transb, transc, 1.0_dp, matrix_a, matrix_b, &
     210              :                             0.0_dp, matrix_c, &
     211            0 :                             filter_eps=filter_eps, unit_nr=io_unit)
     212            0 :       CALL timestop(handle1)
     213            0 :       IF (PRESENT(io_unit)) THEN
     214            0 :       IF (io_unit > 0) THEN
     215            0 :          WRITE (io_unit, "(A)") "tall-and-skinny benchmark completed"
     216              :       END IF
     217              :       END IF
     218              : 
     219            0 :       IF (compare_dbm) THEN
     220            0 :          CALL dbt_tas_convert_to_dbm(matrix_a, dbm_a)
     221            0 :          CALL dbt_tas_convert_to_dbm(matrix_b, dbm_b)
     222            0 :          CALL dbt_tas_convert_to_dbm(matrix_c, dbm_c)
     223              : 
     224            0 :          CALL dbt_tas_get_split_info(dbt_tas_info(matrix_a), mp_comm=mp_comm)
     225            0 :          npdims(:) = 0
     226            0 :          CALL comm_dbm%create(mp_comm, 2, npdims)
     227              : 
     228            0 :          ALLOCATE (rd_a(SIZE(dbm_get_row_block_sizes(dbm_a))))
     229            0 :          ALLOCATE (rd_b(SIZE(dbm_get_row_block_sizes(dbm_b))))
     230            0 :          ALLOCATE (rd_c(SIZE(dbm_get_row_block_sizes(dbm_c))))
     231            0 :          ALLOCATE (cd_a(SIZE(dbm_get_col_block_sizes(dbm_a))))
     232            0 :          ALLOCATE (cd_b(SIZE(dbm_get_col_block_sizes(dbm_b))))
     233            0 :          ALLOCATE (cd_c(SIZE(dbm_get_col_block_sizes(dbm_c))))
     234              : 
     235              :          CALL dbt_tas_default_distvec(INT(SIZE(dbm_get_row_block_sizes(dbm_a))), &
     236            0 :                                       npdims(1), dbm_get_row_block_sizes(dbm_a), rd_a)
     237              :          CALL dbt_tas_default_distvec(INT(SIZE(dbm_get_col_block_sizes(dbm_a))), &
     238            0 :                                       npdims(2), dbm_get_col_block_sizes(dbm_a), cd_a)
     239              :          CALL dbt_tas_default_distvec(INT(SIZE(dbm_get_row_block_sizes(dbm_b))), &
     240            0 :                                       npdims(1), dbm_get_row_block_sizes(dbm_b), rd_b)
     241              :          CALL dbt_tas_default_distvec(INT(SIZE(dbm_get_col_block_sizes(dbm_b))), &
     242            0 :                                       npdims(2), dbm_get_col_block_sizes(dbm_b), cd_b)
     243              :          CALL dbt_tas_default_distvec(INT(SIZE(dbm_get_row_block_sizes(dbm_c))), &
     244            0 :                                       npdims(1), dbm_get_row_block_sizes(dbm_c), rd_c)
     245              :          CALL dbt_tas_default_distvec(INT(SIZE(dbm_get_col_block_sizes(dbm_c))), &
     246            0 :                                       npdims(2), dbm_get_col_block_sizes(dbm_c), cd_c)
     247              : 
     248            0 :          CALL dbm_distribution_new(dist_a, comm_dbm, rd_a, cd_a)
     249            0 :          CALL dbm_distribution_new(dist_b, comm_dbm, rd_b, cd_b)
     250            0 :          CALL dbm_distribution_new(dist_c, comm_dbm, rd_c, cd_c)
     251            0 :          DEALLOCATE (rd_a, rd_b, rd_c, cd_a, cd_b, cd_c)
     252              : 
     253              :          ! Store pointers in intermediate variables to workaround a CCE error.
     254            0 :          row_block_sizes_a => dbm_get_row_block_sizes(dbm_a)
     255            0 :          col_block_sizes_a => dbm_get_col_block_sizes(dbm_a)
     256            0 :          row_block_sizes_b => dbm_get_row_block_sizes(dbm_b)
     257            0 :          col_block_sizes_b => dbm_get_col_block_sizes(dbm_b)
     258            0 :          row_block_sizes_c => dbm_get_row_block_sizes(dbm_c)
     259            0 :          col_block_sizes_c => dbm_get_col_block_sizes(dbm_c)
     260              : 
     261              :          CALL dbm_create(matrix=dbm_a_mm, name=dbm_get_name(dbm_a), dist=dist_a, &
     262            0 :                          row_block_sizes=row_block_sizes_a, col_block_sizes=col_block_sizes_a)
     263              : 
     264              :          CALL dbm_create(matrix=dbm_b_mm, name=dbm_get_name(dbm_b), dist=dist_b, &
     265            0 :                          row_block_sizes=row_block_sizes_b, col_block_sizes=col_block_sizes_b)
     266              : 
     267              :          CALL dbm_create(matrix=dbm_c_mm, name=dbm_get_name(dbm_c), dist=dist_c, &
     268            0 :                          row_block_sizes=row_block_sizes_c, col_block_sizes=col_block_sizes_c)
     269              : 
     270            0 :          CALL dbm_finalize(dbm_a_mm)
     271            0 :          CALL dbm_finalize(dbm_b_mm)
     272            0 :          CALL dbm_finalize(dbm_c_mm)
     273              : 
     274            0 :          CALL dbm_redistribute(dbm_a, dbm_a_mm)
     275            0 :          CALL dbm_redistribute(dbm_b, dbm_b_mm)
     276            0 :          IF (PRESENT(io_unit)) THEN
     277            0 :          IF (io_unit > 0) THEN
     278            0 :             WRITE (io_unit, "(A)") "starting dbm benchmark"
     279              :          END IF
     280              :          END IF
     281            0 :          CALL timeset("benchmark_block_mm", handle2)
     282              :          CALL dbm_multiply(transa, transb, 1.0_dp, dbm_a_mm, dbm_b_mm, &
     283            0 :                            0.0_dp, dbm_c_mm, filter_eps=filter_eps)
     284            0 :          CALL timestop(handle2)
     285            0 :          IF (PRESENT(io_unit)) THEN
     286            0 :          IF (io_unit > 0) THEN
     287            0 :             WRITE (io_unit, "(A)") "dbm benchmark completed"
     288              :          END IF
     289              :          END IF
     290              : 
     291            0 :          CALL dbm_release(dbm_a)
     292            0 :          CALL dbm_release(dbm_b)
     293            0 :          CALL dbm_release(dbm_c)
     294            0 :          CALL dbm_release(dbm_a_mm)
     295            0 :          CALL dbm_release(dbm_b_mm)
     296            0 :          CALL dbm_release(dbm_c_mm)
     297            0 :          CALL dbm_distribution_release(dist_a)
     298            0 :          CALL dbm_distribution_release(dist_b)
     299            0 :          CALL dbm_distribution_release(dist_c)
     300              : 
     301            0 :          CALL comm_dbm%free()
     302              :       END IF
     303              : 
     304            0 :    END SUBROUTINE
     305              : 
     306              : ! **************************************************************************************************
     307              : !> \brief Test tall-and-skinny matrix multiplication for accuracy
     308              : !> \param transa ...
     309              : !> \param transb ...
     310              : !> \param transc ...
     311              : !> \param matrix_a ...
     312              : !> \param matrix_b ...
     313              : !> \param matrix_c ...
     314              : !> \param filter_eps ...
     315              : !> \param unit_nr ...
     316              : !> \param log_verbose ...
     317              : !> \author Patrick Seewald
     318              : ! **************************************************************************************************
     319           48 :    SUBROUTINE dbt_tas_test_mm(transa, transb, transc, matrix_a, matrix_b, matrix_c, filter_eps, unit_nr, log_verbose)
     320              :       LOGICAL, INTENT(IN)                                :: transa, transb, transc
     321              :       TYPE(dbt_tas_type), INTENT(INOUT)                  :: matrix_a, matrix_b, matrix_c
     322              :       REAL(KIND=dp), INTENT(IN), OPTIONAL                :: filter_eps
     323              :       INTEGER, INTENT(IN)                                :: unit_nr
     324              :       LOGICAL, INTENT(IN), OPTIONAL                      :: log_verbose
     325              : 
     326              :       REAL(KIND=dp), PARAMETER                           :: test_tol = 1.0E-10_dp
     327              : 
     328              :       CHARACTER(LEN=8)                                   :: status_str
     329              :       INTEGER                                            :: io_unit, mynode
     330           48 :       INTEGER, CONTIGUOUS, DIMENSION(:), POINTER :: cd_a, cd_b, cd_c, col_block_sizes_a, &
     331           48 :          col_block_sizes_b, col_block_sizes_c, rd_a, rd_b, rd_c, row_block_sizes_a, &
     332           48 :          row_block_sizes_b, row_block_sizes_c
     333              :       INTEGER, DIMENSION(2)                              :: npdims
     334              :       LOGICAL                                            :: abort, transa_prv, transb_prv
     335              :       REAL(KIND=dp)                                      :: norm, rc_cs, sq_cs
     336              :       TYPE(dbm_distribution_obj)                         :: dist_a, dist_b, dist_c
     337              :       TYPE(dbm_type)                                     :: dbm_a, dbm_a_mm, dbm_b, dbm_b_mm, dbm_c, &
     338              :                                                             dbm_c_mm, dbm_c_mm_check
     339           48 :       TYPE(mp_cart_type)                                 :: comm_dbm, mp_comm
     340              : 
     341              : !
     342              : ! TODO: Dedup with code in dbt_tas_benchmark_mm.
     343              : !
     344              : 
     345           48 :       CALL dbt_tas_get_split_info(dbt_tas_info(matrix_a), mp_comm=mp_comm)
     346           48 :       mynode = mp_comm%mepos
     347           48 :       abort = .FALSE.
     348           48 :       io_unit = -1
     349           48 :       IF (mynode .EQ. 0) io_unit = unit_nr
     350              : 
     351              :       CALL dbt_tas_multiply(transa, transb, transc, 1.0_dp, matrix_a, matrix_b, &
     352              :                             0.0_dp, matrix_c, &
     353           48 :                             filter_eps=filter_eps, unit_nr=io_unit, log_verbose=log_verbose, optimize_dist=.TRUE.)
     354              : 
     355           48 :       CALL dbt_tas_convert_to_dbm(matrix_a, dbm_a)
     356           48 :       CALL dbt_tas_convert_to_dbm(matrix_b, dbm_b)
     357           48 :       CALL dbt_tas_convert_to_dbm(matrix_c, dbm_c)
     358              : 
     359           48 :       npdims(:) = 0
     360           48 :       CALL comm_dbm%create(mp_comm, 2, npdims)
     361              : 
     362          144 :       ALLOCATE (rd_a(SIZE(dbm_get_row_block_sizes(dbm_a))))
     363          144 :       ALLOCATE (rd_b(SIZE(dbm_get_row_block_sizes(dbm_b))))
     364          144 :       ALLOCATE (rd_c(SIZE(dbm_get_row_block_sizes(dbm_c))))
     365          144 :       ALLOCATE (cd_a(SIZE(dbm_get_col_block_sizes(dbm_a))))
     366          144 :       ALLOCATE (cd_b(SIZE(dbm_get_col_block_sizes(dbm_b))))
     367          144 :       ALLOCATE (cd_c(SIZE(dbm_get_col_block_sizes(dbm_c))))
     368              : 
     369              :       CALL dbt_tas_default_distvec(INT(SIZE(dbm_get_row_block_sizes(dbm_a))), &
     370           48 :                                    npdims(1), dbm_get_row_block_sizes(dbm_a), rd_a)
     371              :       CALL dbt_tas_default_distvec(INT(SIZE(dbm_get_col_block_sizes(dbm_a))), &
     372           48 :                                    npdims(2), dbm_get_col_block_sizes(dbm_a), cd_a)
     373              :       CALL dbt_tas_default_distvec(INT(SIZE(dbm_get_row_block_sizes(dbm_b))), &
     374           48 :                                    npdims(1), dbm_get_row_block_sizes(dbm_b), rd_b)
     375              :       CALL dbt_tas_default_distvec(INT(SIZE(dbm_get_col_block_sizes(dbm_b))), &
     376           48 :                                    npdims(2), dbm_get_col_block_sizes(dbm_b), cd_b)
     377              :       CALL dbt_tas_default_distvec(INT(SIZE(dbm_get_row_block_sizes(dbm_c))), &
     378           48 :                                    npdims(1), dbm_get_row_block_sizes(dbm_c), rd_c)
     379              :       CALL dbt_tas_default_distvec(INT(SIZE(dbm_get_col_block_sizes(dbm_c))), &
     380           48 :                                    npdims(2), dbm_get_col_block_sizes(dbm_c), cd_c)
     381              : 
     382           48 :       CALL dbm_distribution_new(dist_a, comm_dbm, rd_a, cd_a)
     383           48 :       CALL dbm_distribution_new(dist_b, comm_dbm, rd_b, cd_b)
     384           48 :       CALL dbm_distribution_new(dist_c, comm_dbm, rd_c, cd_c)
     385           48 :       DEALLOCATE (rd_a, rd_b, rd_c, cd_a, cd_b, cd_c)
     386              : 
     387              :       ! Store pointers in intermediate variables to workaround a CCE error.
     388           48 :       row_block_sizes_a => dbm_get_row_block_sizes(dbm_a)
     389           48 :       col_block_sizes_a => dbm_get_col_block_sizes(dbm_a)
     390           48 :       row_block_sizes_b => dbm_get_row_block_sizes(dbm_b)
     391           48 :       col_block_sizes_b => dbm_get_col_block_sizes(dbm_b)
     392           48 :       row_block_sizes_c => dbm_get_row_block_sizes(dbm_c)
     393           48 :       col_block_sizes_c => dbm_get_col_block_sizes(dbm_c)
     394              : 
     395              :       CALL dbm_create(matrix=dbm_a_mm, name="matrix a", dist=dist_a, &
     396           48 :                       row_block_sizes=row_block_sizes_a, col_block_sizes=col_block_sizes_a)
     397              : 
     398              :       CALL dbm_create(matrix=dbm_b_mm, name="matrix b", dist=dist_b, &
     399           48 :                       row_block_sizes=row_block_sizes_b, col_block_sizes=col_block_sizes_b)
     400              : 
     401              :       CALL dbm_create(matrix=dbm_c_mm, name="matrix c", dist=dist_c, &
     402           48 :                       row_block_sizes=row_block_sizes_c, col_block_sizes=col_block_sizes_c)
     403              : 
     404              :       CALL dbm_create(matrix=dbm_c_mm_check, name="matrix c check", dist=dist_c, &
     405           48 :                       row_block_sizes=row_block_sizes_c, col_block_sizes=col_block_sizes_c)
     406              : 
     407           48 :       CALL dbm_finalize(dbm_a_mm)
     408           48 :       CALL dbm_finalize(dbm_b_mm)
     409           48 :       CALL dbm_finalize(dbm_c_mm)
     410           48 :       CALL dbm_finalize(dbm_c_mm_check)
     411              : 
     412           48 :       CALL dbm_redistribute(dbm_a, dbm_a_mm)
     413           48 :       CALL dbm_redistribute(dbm_b, dbm_b_mm)
     414           48 :       CALL dbm_redistribute(dbm_c, dbm_c_mm_check)
     415              : 
     416           48 :       transa_prv = transa; transb_prv = transb
     417              : 
     418           48 :       IF (.NOT. transc) THEN
     419              :          CALL dbm_multiply(transa_prv, transb_prv, 1.0_dp, &
     420              :                            dbm_a_mm, dbm_b_mm, &
     421           24 :                            0.0_dp, dbm_c_mm, filter_eps=filter_eps)
     422              :       ELSE
     423           24 :          transa_prv = .NOT. transa_prv
     424           24 :          transb_prv = .NOT. transb_prv
     425              :          CALL dbm_multiply(transb_prv, transa_prv, 1.0_dp, &
     426              :                            dbm_b_mm, dbm_a_mm, &
     427           24 :                            0.0_dp, dbm_c_mm, filter_eps=filter_eps)
     428              :       END IF
     429              : 
     430           48 :       sq_cs = dbm_checksum(dbm_c_mm)
     431           48 :       rc_cs = dbm_checksum(dbm_c_mm_check)
     432           48 :       CALL dbm_scale(dbm_c_mm_check, -1.0_dp)
     433           48 :       CALL dbm_add(dbm_c_mm_check, dbm_c_mm)
     434           48 :       norm = dbm_maxabs(dbm_c_mm_check)
     435              : 
     436           48 :       IF (io_unit > 0) THEN
     437           24 :          IF (ABS(norm) > test_tol) THEN
     438            0 :             status_str = " failed!"
     439            0 :             abort = .TRUE.
     440              :          ELSE
     441           24 :             status_str = " passed!"
     442           24 :             abort = .FALSE.
     443              :          END IF
     444              :          WRITE (io_unit, "(A)") &
     445              :             TRIM(dbm_get_name(matrix_a%matrix))//" x "// &
     446           24 :             TRIM(dbm_get_name(matrix_b%matrix))//TRIM(status_str)
     447           24 :          WRITE (io_unit, "(A,1X,E9.2,1X,E9.2)") "checksums", sq_cs, rc_cs
     448           24 :          WRITE (io_unit, "(A,1X,E9.2)") "difference norm", norm
     449           24 :          IF (abort) CPABORT("DBT TAS test failed")
     450              :       END IF
     451              : 
     452           48 :       CALL dbm_release(dbm_a)
     453           48 :       CALL dbm_release(dbm_a_mm)
     454           48 :       CALL dbm_release(dbm_b)
     455           48 :       CALL dbm_release(dbm_b_mm)
     456           48 :       CALL dbm_release(dbm_c)
     457           48 :       CALL dbm_release(dbm_c_mm)
     458           48 :       CALL dbm_release(dbm_c_mm_check)
     459              : 
     460           48 :       CALL dbm_distribution_release(dist_a)
     461           48 :       CALL dbm_distribution_release(dist_b)
     462           48 :       CALL dbm_distribution_release(dist_c)
     463              : 
     464           48 :       CALL comm_dbm%free()
     465              : 
     466          432 :    END SUBROUTINE dbt_tas_test_mm
     467              : 
     468              : ! **************************************************************************************************
     469              : !> \brief Calculate checksum of tall-and-skinny matrix consistent with dbm_checksum
     470              : !> \param matrix ...
     471              : !> \return ...
     472              : !> \author Patrick Seewald
     473              : ! **************************************************************************************************
     474           80 :    FUNCTION dbt_tas_checksum(matrix)
     475              :       TYPE(dbt_tas_type), INTENT(IN)                     :: matrix
     476              :       REAL(KIND=dp)                                      :: dbt_tas_checksum
     477              : 
     478              :       TYPE(dbm_type)                                     :: dbm_m
     479              : 
     480           80 :       CALL dbt_tas_convert_to_dbm(matrix, dbm_m)
     481           80 :       dbt_tas_checksum = dbm_checksum(dbm_m)
     482           80 :       CALL dbm_release(dbm_m)
     483           80 :    END FUNCTION
     484              : 
     485              : ! **************************************************************************************************
     486              : !> \brief Create random block sizes
     487              : !> \param sizes ...
     488              : !> \param repeat ...
     489              : !> \param dbt_sizes ...
     490              : !> \author Patrick Seewald
     491              : ! **************************************************************************************************
     492            6 :    SUBROUTINE dbt_tas_random_bsizes(sizes, repeat, dbt_sizes)
     493              :       INTEGER, DIMENSION(:), INTENT(IN)                  :: sizes
     494              :       INTEGER, INTENT(IN)                                :: repeat
     495              :       INTEGER, DIMENSION(:), INTENT(OUT)                 :: dbt_sizes
     496              : 
     497              :       INTEGER                                            :: d, size_i
     498              : 
     499          266 :       DO d = 1, SIZE(dbt_sizes)
     500          260 :          size_i = MOD((d - 1)/repeat, SIZE(sizes)) + 1
     501          266 :          dbt_sizes(d) = sizes(size_i)
     502              :       END DO
     503            6 :    END SUBROUTINE
     504              : 
     505              : ! **************************************************************************************************
     506              : !> \brief Reset the seed used for generating random matrices to default value
     507              : !> \author Patrick Seewald
     508              : ! **************************************************************************************************
     509            2 :    SUBROUTINE dbt_tas_reset_randmat_seed()
     510            2 :       randmat_counter = rand_seed_init
     511            2 :    END SUBROUTINE
     512              : 
     513              : END MODULE
        

Generated by: LCOV version 2.0-1