LCOV - code coverage report
Current view: top level - src/dbt/tas - dbt_tas_unittest.F (source / functions) Coverage Total Hit
Test: CP2K Regtests (git:42dac4a) Lines: 99.0 % 100 99
Test Date: 2025-07-25 12:55:17 Functions: 100.0 % 2 2

            Line data    Source code
       1              : !--------------------------------------------------------------------------------------------------!
       2              : !   CP2K: A general program to perform molecular dynamics simulations                              !
       3              : !   Copyright 2000-2025 CP2K developers group <https://cp2k.org>                                   !
       4              : !                                                                                                  !
       5              : !   SPDX-License-Identifier: GPL-2.0-or-later                                                      !
       6              : !--------------------------------------------------------------------------------------------------!
       7              : 
       8              : ! **************************************************************************************************
       9              : !> \brief Unit testing for tall-and-skinny matrices
      10              : !> \author Patrick Seewald
      11              : ! **************************************************************************************************
      12            2 : PROGRAM dbt_tas_unittest
      13            2 :    USE cp_dbcsr_api,                    ONLY: dbcsr_finalize_lib,&
      14              :                                               dbcsr_init_lib
      15              :    USE dbm_api,                         ONLY: dbm_get_name,&
      16              :                                               dbm_library_finalize,&
      17              :                                               dbm_library_init,&
      18              :                                               dbm_library_print_stats
      19              :    USE dbt_tas_base,                    ONLY: dbt_tas_create,&
      20              :                                               dbt_tas_destroy,&
      21              :                                               dbt_tas_info,&
      22              :                                               dbt_tas_nblkcols_total,&
      23              :                                               dbt_tas_nblkrows_total
      24              :    USE dbt_tas_io,                      ONLY: dbt_tas_write_split_info
      25              :    USE dbt_tas_test,                    ONLY: dbt_tas_random_bsizes,&
      26              :                                               dbt_tas_reset_randmat_seed,&
      27              :                                               dbt_tas_setup_test_matrix,&
      28              :                                               dbt_tas_test_mm
      29              :    USE dbt_tas_types,                   ONLY: dbt_tas_type
      30              :    USE kinds,                           ONLY: dp,&
      31              :                                               int_8
      32              :    USE machine,                         ONLY: default_output_unit
      33              :    USE message_passing,                 ONLY: mp_cart_type,&
      34              :                                               mp_comm_type,&
      35              :                                               mp_world_finalize,&
      36              :                                               mp_world_init
      37              :    USE offload_api,                     ONLY: offload_get_device_count,&
      38              :                                               offload_set_chosen_device
      39              : #include "../../base/base_uses.f90"
      40              : 
      41              :    IMPLICIT NONE
      42              : 
      43              :    INTEGER(KIND=int_8), PARAMETER :: m = 100, k = 20, n = 10
      44           98 :    TYPE(dbt_tas_type)             :: A, B, C, At, Bt, Ct, A_out, B_out, C_out, At_out, Bt_out, Ct_out
      45              :    INTEGER, DIMENSION(m)          :: bsize_m
      46              :    INTEGER, DIMENSION(n)          :: bsize_n
      47              :    INTEGER, DIMENSION(k)          :: bsize_k
      48              :    REAL(KIND=dp), PARAMETER   :: sparsity = 0.1
      49              :    INTEGER                        :: mynode, io_unit
      50              :    TYPE(mp_comm_type)             :: mp_comm
      51            2 :    TYPE(mp_cart_type) :: mp_comm_A, mp_comm_At, mp_comm_B, mp_comm_Bt, mp_comm_C, mp_comm_Ct
      52              :    REAL(KIND=dp), PARAMETER   :: filter_eps = 1.0E-08
      53              : 
      54            2 :    CALL mp_world_init(mp_comm)
      55              : 
      56            2 :    mynode = mp_comm%mepos
      57              : 
      58              :    ! Select active offload device when available.
      59            2 :    IF (offload_get_device_count() > 0) THEN
      60            0 :       CALL offload_set_chosen_device(MOD(mynode, offload_get_device_count()))
      61              :    END IF
      62              : 
      63            2 :    io_unit = -1
      64            2 :    IF (mynode .EQ. 0) io_unit = default_output_unit
      65              : 
      66            2 :    CALL dbcsr_init_lib(mp_comm%get_handle(), io_unit) ! Needed for DBM_VALIDATE_AGAINST_DBCSR.
      67            2 :    CALL dbm_library_init()
      68              : 
      69            2 :    CALL dbt_tas_reset_randmat_seed()
      70              : 
      71            2 :    CALL dbt_tas_random_bsizes([13, 8, 5, 25, 12], 2, bsize_m)
      72            2 :    CALL dbt_tas_random_bsizes([3, 78, 33, 12, 3, 15], 1, bsize_n)
      73            2 :    CALL dbt_tas_random_bsizes([9, 64, 23, 2], 3, bsize_k)
      74              : 
      75            2 :    CALL dbt_tas_setup_test_matrix(A, mp_comm_A, mp_comm, m, k, bsize_m, bsize_k, [5, 1], "A", sparsity)
      76            2 :    CALL dbt_tas_setup_test_matrix(At, mp_comm_At, mp_comm, k, m, bsize_k, bsize_m, [3, 8], "A^t", sparsity)
      77            2 :    CALL dbt_tas_setup_test_matrix(B, mp_comm_B, mp_comm, n, m, bsize_n, bsize_m, [3, 2], "B", sparsity)
      78            2 :    CALL dbt_tas_setup_test_matrix(Bt, mp_comm_Bt, mp_comm, m, n, bsize_m, bsize_n, [1, 3], "B^t", sparsity)
      79            2 :    CALL dbt_tas_setup_test_matrix(C, mp_comm_C, mp_comm, k, n, bsize_k, bsize_n, [5, 7], "C", sparsity)
      80            2 :    CALL dbt_tas_setup_test_matrix(Ct, mp_comm_Ct, mp_comm, n, k, bsize_n, bsize_k, [1, 1], "C^t", sparsity)
      81              : 
      82            2 :    CALL dbt_tas_create(A, A_out)
      83            2 :    CALL dbt_tas_create(At, At_out)
      84            2 :    CALL dbt_tas_create(B, B_out)
      85            2 :    CALL dbt_tas_create(Bt, Bt_out)
      86            2 :    CALL dbt_tas_create(C, C_out)
      87            2 :    CALL dbt_tas_create(Ct, Ct_out)
      88              : 
      89            2 :    IF (mynode == 0) WRITE (io_unit, '(A)') "DBM TALL-AND-SKINNY MATRICES"
      90            1 :    IF (mynode == 0) WRITE (io_unit, '(1X, A, 1X, A, I10, 1X, A, 1X, I10)') "Split info for matrix", &
      91            1 :       TRIM(dbm_get_name(A%matrix)), &
      92            2 :       dbt_tas_nblkrows_total(A), 'X', dbt_tas_nblkcols_total(A)
      93            2 :    CALL dbt_tas_write_split_info(dbt_tas_info(A), io_unit, name="A")
      94            3 :    IF (mynode == 0) WRITE (io_unit, '(1X, A, 1X, A, I10, 1X, A, 1X, I10)') "Split info for matrix", &
      95            1 :       TRIM(dbm_get_name(At%matrix)), &
      96            2 :       dbt_tas_nblkrows_total(At), 'X', dbt_tas_nblkcols_total(At)
      97            2 :    CALL dbt_tas_write_split_info(dbt_tas_info(At), io_unit, name="At")
      98            3 :    IF (mynode == 0) WRITE (io_unit, '(1X, A, 1X, A, I10, 1X, A, 1X, I10)') "Split info for matrix", &
      99            1 :       TRIM(dbm_get_name(B%matrix)), &
     100            2 :       dbt_tas_nblkrows_total(B), 'X', dbt_tas_nblkcols_total(B)
     101            2 :    CALL dbt_tas_write_split_info(dbt_tas_info(B), io_unit, name="B")
     102            3 :    IF (mynode == 0) WRITE (io_unit, '(1X, A, 1X, A, I10, 1X, A, 1X, I10)') "Split info for matrix", &
     103            1 :       TRIM(dbm_get_name(Bt%matrix)), &
     104            2 :       dbt_tas_nblkrows_total(Bt), 'X', dbt_tas_nblkcols_total(Bt)
     105            2 :    CALL dbt_tas_write_split_info(dbt_tas_info(Bt), io_unit, name="Bt")
     106            3 :    IF (mynode == 0) WRITE (io_unit, '(1X, A, 1X, A, I10, 1X, A, 1X, I10)') "Split info for matrix", &
     107            1 :       TRIM(dbm_get_name(C%matrix)), &
     108            2 :       dbt_tas_nblkrows_total(C), 'X', dbt_tas_nblkcols_total(C)
     109            2 :    CALL dbt_tas_write_split_info(dbt_tas_info(C), io_unit, name="C")
     110            3 :    IF (mynode == 0) WRITE (io_unit, '(1X, A, 1X, A, I10, 1X, A, 1X, I10)') "Split info for matrix", &
     111            1 :       TRIM(dbm_get_name(Ct%matrix)), &
     112            2 :       dbt_tas_nblkrows_total(Ct), 'X', dbt_tas_nblkcols_total(Ct)
     113            2 :    CALL dbt_tas_write_split_info(dbt_tas_info(Ct), io_unit, name="Ct")
     114              : 
     115            2 :    CALL dbt_tas_test_mm(.FALSE., .FALSE., .FALSE., B, A, Ct_out, unit_nr=io_unit, filter_eps=filter_eps)
     116            2 :    CALL dbt_tas_test_mm(.TRUE., .FALSE., .FALSE., Bt, A, Ct_out, unit_nr=io_unit, filter_eps=filter_eps)
     117            2 :    CALL dbt_tas_test_mm(.FALSE., .TRUE., .FALSE., B, At, Ct_out, unit_nr=io_unit, filter_eps=filter_eps)
     118            2 :    CALL dbt_tas_test_mm(.TRUE., .TRUE., .FALSE., Bt, At, Ct_out, unit_nr=io_unit, filter_eps=filter_eps)
     119            2 :    CALL dbt_tas_test_mm(.FALSE., .FALSE., .TRUE., B, A, C_out, unit_nr=io_unit, filter_eps=filter_eps)
     120            2 :    CALL dbt_tas_test_mm(.TRUE., .FALSE., .TRUE., Bt, A, C_out, unit_nr=io_unit, filter_eps=filter_eps)
     121            2 :    CALL dbt_tas_test_mm(.FALSE., .TRUE., .TRUE., B, At, C_out, unit_nr=io_unit, filter_eps=filter_eps)
     122            2 :    CALL dbt_tas_test_mm(.TRUE., .TRUE., .TRUE., Bt, At, C_out, unit_nr=io_unit, filter_eps=filter_eps)
     123              : 
     124            2 :    CALL dbt_tas_test_mm(.FALSE., .FALSE., .FALSE., A, C, Bt_out, unit_nr=io_unit, filter_eps=filter_eps)
     125            2 :    CALL dbt_tas_test_mm(.TRUE., .FALSE., .FALSE., At, C, Bt_out, unit_nr=io_unit, filter_eps=filter_eps)
     126            2 :    CALL dbt_tas_test_mm(.FALSE., .TRUE., .FALSE., A, Ct, Bt_out, unit_nr=io_unit, filter_eps=filter_eps)
     127            2 :    CALL dbt_tas_test_mm(.TRUE., .TRUE., .FALSE., At, Ct, Bt_out, unit_nr=io_unit, filter_eps=filter_eps)
     128              : 
     129            2 :    CALL dbt_tas_test_mm(.FALSE., .FALSE., .TRUE., A, C, B_out, unit_nr=io_unit, filter_eps=filter_eps)
     130            2 :    CALL dbt_tas_test_mm(.TRUE., .FALSE., .TRUE., At, C, B_out, unit_nr=io_unit, filter_eps=filter_eps)
     131            2 :    CALL dbt_tas_test_mm(.FALSE., .TRUE., .TRUE., A, Ct, B_out, unit_nr=io_unit, filter_eps=filter_eps)
     132            2 :    CALL dbt_tas_test_mm(.TRUE., .TRUE., .TRUE., At, Ct, B_out, unit_nr=io_unit, filter_eps=filter_eps)
     133              : 
     134            2 :    CALL dbt_tas_test_mm(.FALSE., .FALSE., .FALSE., C, B, At_out, unit_nr=io_unit, filter_eps=filter_eps)
     135            2 :    CALL dbt_tas_test_mm(.TRUE., .FALSE., .FALSE., Ct, B, At_out, unit_nr=io_unit, filter_eps=filter_eps)
     136            2 :    CALL dbt_tas_test_mm(.FALSE., .TRUE., .FALSE., C, Bt, At_out, unit_nr=io_unit, filter_eps=filter_eps)
     137            2 :    CALL dbt_tas_test_mm(.TRUE., .TRUE., .FALSE., Ct, Bt, At_out, unit_nr=io_unit, filter_eps=filter_eps)
     138              : 
     139            2 :    CALL dbt_tas_test_mm(.FALSE., .FALSE., .TRUE., C, B, A_out, unit_nr=io_unit, filter_eps=filter_eps)
     140            2 :    CALL dbt_tas_test_mm(.TRUE., .FALSE., .TRUE., Ct, B, A_out, unit_nr=io_unit, filter_eps=filter_eps)
     141            2 :    CALL dbt_tas_test_mm(.FALSE., .TRUE., .TRUE., C, Bt, A_out, unit_nr=io_unit, filter_eps=filter_eps)
     142            2 :    CALL dbt_tas_test_mm(.TRUE., .TRUE., .TRUE., Ct, Bt, A_out, unit_nr=io_unit, filter_eps=filter_eps)
     143              : 
     144            2 :    CALL dbt_tas_destroy(A)
     145            2 :    CALL dbt_tas_destroy(At)
     146            2 :    CALL dbt_tas_destroy(B)
     147            2 :    CALL dbt_tas_destroy(Bt)
     148            2 :    CALL dbt_tas_destroy(C)
     149            2 :    CALL dbt_tas_destroy(Ct)
     150            2 :    CALL dbt_tas_destroy(A_out)
     151            2 :    CALL dbt_tas_destroy(At_out)
     152            2 :    CALL dbt_tas_destroy(B_out)
     153            2 :    CALL dbt_tas_destroy(Bt_out)
     154            2 :    CALL dbt_tas_destroy(C_out)
     155            2 :    CALL dbt_tas_destroy(Ct_out)
     156              : 
     157            2 :    CALL mp_comm_A%free()
     158            2 :    CALL mp_comm_At%free()
     159            2 :    CALL mp_comm_B%free()
     160            2 :    CALL mp_comm_Bt%free()
     161            2 :    CALL mp_comm_C%free()
     162            2 :    CALL mp_comm_Ct%free()
     163              : 
     164            2 :    CALL dbm_library_print_stats(mp_comm, io_unit)
     165            2 :    CALL dbm_library_finalize()
     166            2 :    CALL dbcsr_finalize_lib() ! Needed for DBM_VALIDATE_AGAINST_DBCSR.
     167            2 :    CALL mp_world_finalize()
     168              : 
     169            2 : END PROGRAM
        

Generated by: LCOV version 2.0-1