LCOV - code coverage report
Current view: top level - src/dbt - dbt_reshape_ops.F (source / functions) Hit Total Coverage
Test: CP2K Regtests (git:0de0cc2) Lines: 55 56 98.2 %
Date: 2024-03-28 07:31:50 Functions: 2 4 50.0 %

          Line data    Source code
       1             : !--------------------------------------------------------------------------------------------------!
       2             : !   CP2K: A general program to perform molecular dynamics simulations                              !
       3             : !   Copyright 2000-2024 CP2K developers group <https://cp2k.org>                                   !
       4             : !                                                                                                  !
       5             : !   SPDX-License-Identifier: GPL-2.0-or-later                                                      !
       6             : !--------------------------------------------------------------------------------------------------!
       7             : 
       8             : ! **************************************************************************************************
       9             : !> \brief Routines to reshape / redistribute tensors
      10             : !> \author Patrick Seewald
      11             : ! **************************************************************************************************
      12             : MODULE dbt_reshape_ops
      13             :    #:include "dbt_macros.fypp"
      14             :    #:set maxdim = maxrank
      15             :    #:set ndims = range(2,maxdim+1)
      16             : 
      17             :    USE dbt_allocate_wrap, ONLY: allocate_any
      18             :    USE dbt_tas_base, ONLY: dbt_tas_copy, dbt_tas_get_info, dbt_tas_info
      19             :    USE dbt_block, ONLY: &
      20             :       block_nd, create_block, destroy_block, dbt_iterator_type, dbt_iterator_next_block, &
      21             :       dbt_iterator_blocks_left, dbt_iterator_start, dbt_iterator_stop, dbt_get_block, &
      22             :       dbt_reserve_blocks, dbt_put_block
      23             :    USE dbt_types, ONLY: dbt_blk_sizes, &
      24             :                         dbt_create, &
      25             :                         dbt_type, &
      26             :                         ndims_tensor, &
      27             :                         dbt_get_stored_coordinates, &
      28             :                         dbt_clear
      29             :    USE kinds, ONLY: default_string_length
      30             :    USE kinds, ONLY: dp, dp
      31             :    USE message_passing, ONLY: &
      32             :       mp_waitall, mp_comm_type, mp_request_type
      33             : 
      34             : #include "../base/base_uses.f90"
      35             : 
      36             :    IMPLICIT NONE
      37             :    PRIVATE
      38             :    CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'dbt_reshape_ops'
      39             : 
      40             :    PUBLIC :: dbt_reshape
      41             : 
      42             :    TYPE block_buffer_type
      43             :       INTEGER, DIMENSION(:, :), ALLOCATABLE      :: blocks
      44             :       REAL(dp), DIMENSION(:), ALLOCATABLE        :: data
      45             :    END TYPE
      46             : 
      47             : CONTAINS
      48             : 
      49             : ! **************************************************************************************************
      50             : !> \brief copy data (involves reshape)
      51             : !>        tensor_out = tensor_out + tensor_in move_data memory optimization:
      52             : !>        transfer data from tensor_in to tensor_out s.t. tensor_in is empty on return
      53             : !> \author Ole Schuett
      54             : ! **************************************************************************************************
      55      147344 :    SUBROUTINE dbt_reshape(tensor_in, tensor_out, summation, move_data)
      56             : 
      57             :       TYPE(dbt_type), INTENT(INOUT)               :: tensor_in, tensor_out
      58             :       LOGICAL, INTENT(IN), OPTIONAL                    :: summation
      59             :       LOGICAL, INTENT(IN), OPTIONAL                    :: move_data
      60             : 
      61             :       CHARACTER(LEN=*), PARAMETER :: routineN = 'dbt_reshape'
      62             : 
      63             :       INTEGER                                            :: iproc, numnodes, &
      64             :                                                             handle, iblk, jblk, offset, ndata, &
      65             :                                                             nblks_recv_mythread
      66      147344 :       INTEGER, ALLOCATABLE, DIMENSION(:, :)              :: blks_to_allocate
      67             :       TYPE(dbt_iterator_type)                            :: iter
      68      147344 :       TYPE(block_nd)                                     :: blk_data
      69      147344 :       TYPE(block_buffer_type), ALLOCATABLE, DIMENSION(:) :: buffer_recv, buffer_send
      70      147344 :       INTEGER, DIMENSION(ndims_tensor(tensor_in))        :: blk_size, ind_nd
      71             :       LOGICAL :: found, summation_prv, move_prv
      72             : 
      73      147344 :       INTEGER, ALLOCATABLE, DIMENSION(:)                 :: nblks_send_total, ndata_send_total, &
      74      147344 :                                                             nblks_recv_total, ndata_recv_total, &
      75      147344 :                                                             nblks_send_mythread, ndata_send_mythread
      76             :       TYPE(mp_comm_type) :: mp_comm
      77             : 
      78      147344 :       CALL timeset(routineN, handle)
      79             : 
      80      147344 :       IF (PRESENT(summation)) THEN
      81       55278 :          summation_prv = summation
      82             :       ELSE
      83             :          summation_prv = .FALSE.
      84             :       END IF
      85             : 
      86      147344 :       IF (PRESENT(move_data)) THEN
      87      147344 :          move_prv = move_data
      88             :       ELSE
      89             :          move_prv = .FALSE.
      90             :       END IF
      91             : 
      92      147344 :       CPASSERT(tensor_out%valid)
      93             : 
      94      147344 :       IF (.NOT. summation_prv) CALL dbt_clear(tensor_out)
      95             : 
      96      147344 :       mp_comm = tensor_in%pgrid%mp_comm_2d
      97      147344 :       numnodes = mp_comm%num_pe
      98     1314100 :       ALLOCATE (buffer_send(0:numnodes - 1), buffer_recv(0:numnodes - 1))
      99     1314100 :       ALLOCATE (nblks_send_total(0:numnodes - 1), ndata_send_total(0:numnodes - 1), source=0)
     100     1314100 :       ALLOCATE (nblks_recv_total(0:numnodes - 1), ndata_recv_total(0:numnodes - 1), source=0)
     101             : 
     102             : !$OMP PARALLEL DEFAULT(OMP_DEFAULT_NONE_WITH_OOP) &
     103             : !$OMP SHARED(tensor_in,tensor_out,summation) &
     104             : !$OMP SHARED(buffer_send,buffer_recv,mp_comm,numnodes) &
     105             : !$OMP SHARED(nblks_send_total,ndata_send_total,nblks_recv_total,ndata_recv_total) &
     106             : !$OMP PRIVATE(nblks_send_mythread,ndata_send_mythread,nblks_recv_mythread) &
     107             : !$OMP PRIVATE(iter,ind_nd,blk_size,blk_data,found,iproc) &
     108      147344 : !$OMP PRIVATE(blks_to_allocate,offset,ndata,iblk,jblk)
     109             :       ALLOCATE (nblks_send_mythread(0:numnodes - 1), ndata_send_mythread(0:numnodes - 1), source=0)
     110             : 
     111             :       CALL dbt_iterator_start(iter, tensor_in)
     112             :       DO WHILE (dbt_iterator_blocks_left(iter))
     113             :          CALL dbt_iterator_next_block(iter, ind_nd, blk_size=blk_size)
     114             :          CALL dbt_get_stored_coordinates(tensor_out, ind_nd, iproc)
     115             :          nblks_send_mythread(iproc) = nblks_send_mythread(iproc) + 1
     116             :          ndata_send_mythread(iproc) = ndata_send_mythread(iproc) + PRODUCT(blk_size)
     117             :       END DO
     118             :       CALL dbt_iterator_stop(iter)
     119             : !$OMP CRITICAL
     120             :       nblks_send_total(:) = nblks_send_total(:) + nblks_send_mythread(:)
     121             :       ndata_send_total(:) = ndata_send_total(:) + ndata_send_mythread(:)
     122             :       nblks_send_mythread(:) = nblks_send_total(:) ! current totals indicate slot for this thread
     123             :       ndata_send_mythread(:) = ndata_send_total(:)
     124             : !$OMP END CRITICAL
     125             : !$OMP BARRIER
     126             : 
     127             : !$OMP MASTER
     128             :       CALL mp_comm%alltoall(nblks_send_total, nblks_recv_total, 1)
     129             :       CALL mp_comm%alltoall(ndata_send_total, ndata_recv_total, 1)
     130             : !$OMP END MASTER
     131             : !$OMP BARRIER
     132             : 
     133             : !$OMP DO
     134             :       DO iproc = 0, numnodes - 1
     135             :          ALLOCATE (buffer_send(iproc)%data(ndata_send_total(iproc)))
     136             :          ALLOCATE (buffer_recv(iproc)%data(ndata_recv_total(iproc)))
     137             :          ! going to use buffer%blocks(:,0) to store data offsets
     138             :          ALLOCATE (buffer_send(iproc)%blocks(nblks_send_total(iproc), 0:ndims_tensor(tensor_in)))
     139             :          ALLOCATE (buffer_recv(iproc)%blocks(nblks_recv_total(iproc), 0:ndims_tensor(tensor_in)))
     140             :       END DO
     141             : !$OMP END DO
     142             : !$OMP BARRIER
     143             : 
     144             :       CALL dbt_iterator_start(iter, tensor_in)
     145             :       DO WHILE (dbt_iterator_blocks_left(iter))
     146             :          CALL dbt_iterator_next_block(iter, ind_nd, blk_size=blk_size)
     147             :          CALL dbt_get_stored_coordinates(tensor_out, ind_nd, iproc)
     148             :          CALL dbt_get_block(tensor_in, ind_nd, blk_data, found)
     149             :          CPASSERT(found)
     150             :          ! insert block data
     151             :          ndata = PRODUCT(blk_size)
     152             :          ndata_send_mythread(iproc) = ndata_send_mythread(iproc) - ndata
     153             :          offset = ndata_send_mythread(iproc)
     154             :          buffer_send(iproc)%data(offset + 1:offset + ndata) = blk_data%blk(:)
     155             :          ! insert block index
     156             :          nblks_send_mythread(iproc) = nblks_send_mythread(iproc) - 1
     157             :          iblk = nblks_send_mythread(iproc) + 1
     158             :          buffer_send(iproc)%blocks(iblk, 1:) = ind_nd(:)
     159             :          buffer_send(iproc)%blocks(iblk, 0) = offset
     160             :          CALL destroy_block(blk_data)
     161             :       END DO
     162             :       CALL dbt_iterator_stop(iter)
     163             : !$OMP BARRIER
     164             : 
     165             :       CALL dbt_communicate_buffer(mp_comm, buffer_recv, buffer_send)
     166             : !$OMP BARRIER
     167             : 
     168             : !$OMP DO
     169             :       DO iproc = 0, numnodes - 1
     170             :          DEALLOCATE (buffer_send(iproc)%blocks, buffer_send(iproc)%data)
     171             :       END DO
     172             : !$OMP END DO
     173             : 
     174             :       nblks_recv_mythread = 0
     175             :       DO iproc = 0, numnodes - 1
     176             : !$OMP DO
     177             :          DO iblk = 1, nblks_recv_total(iproc)
     178             :             nblks_recv_mythread = nblks_recv_mythread + 1
     179             :          END DO
     180             : !$OMP END DO
     181             :       END DO
     182             :       ALLOCATE (blks_to_allocate(nblks_recv_mythread, ndims_tensor(tensor_in)))
     183             : 
     184             :       jblk = 0
     185             :       DO iproc = 0, numnodes - 1
     186             : !$OMP DO
     187             :          DO iblk = 1, nblks_recv_total(iproc)
     188             :             jblk = jblk + 1
     189             :             blks_to_allocate(jblk, :) = buffer_recv(iproc)%blocks(iblk, 1:)
     190             :          END DO
     191             : !$OMP END DO
     192             :       END DO
     193             :       CPASSERT(jblk == nblks_recv_mythread)
     194             :       CALL dbt_reserve_blocks(tensor_out, blks_to_allocate)
     195             :       DEALLOCATE (blks_to_allocate)
     196             : 
     197             :       DO iproc = 0, numnodes - 1
     198             : !$OMP DO
     199             :          DO iblk = 1, nblks_recv_total(iproc)
     200             :             ind_nd(:) = buffer_recv(iproc)%blocks(iblk, 1:)
     201             :             CALL dbt_blk_sizes(tensor_out, ind_nd, blk_size)
     202             :             offset = buffer_recv(iproc)%blocks(iblk, 0)
     203             :             ndata = PRODUCT(blk_size)
     204             :             CALL create_block(blk_data, blk_size, &
     205             :                               array=buffer_recv(iproc)%data(offset + 1:offset + ndata))
     206             :             CALL dbt_put_block(tensor_out, ind_nd, blk_data, summation=summation)
     207             :             CALL destroy_block(blk_data)
     208             :          END DO
     209             : !$OMP END DO
     210             :       END DO
     211             : !$OMP END PARALLEL
     212             : 
     213      436034 :       DO iproc = 0, numnodes - 1
     214      436034 :          DEALLOCATE (buffer_recv(iproc)%blocks, buffer_recv(iproc)%data)
     215             :       END DO
     216             : 
     217      147344 :       IF (move_prv) CALL dbt_clear(tensor_in)
     218             : 
     219      147344 :       CALL timestop(handle)
     220      872068 :    END SUBROUTINE dbt_reshape
     221             : 
     222             : ! **************************************************************************************************
     223             : !> \brief communicate buffer
     224             : !> \author Patrick Seewald
     225             : ! **************************************************************************************************
     226      147344 :    SUBROUTINE dbt_communicate_buffer(mp_comm, buffer_recv, buffer_send)
     227             :       TYPE(mp_comm_type), INTENT(IN)                    :: mp_comm
     228             :       TYPE(block_buffer_type), DIMENSION(0:), INTENT(INOUT) :: buffer_recv, buffer_send
     229             : 
     230             :       CHARACTER(LEN=*), PARAMETER :: routineN = 'dbt_communicate_buffer'
     231             : 
     232             :       INTEGER                                               :: iproc, numnodes, &
     233             :                                                                rec_counter, send_counter, i
     234      147344 :       TYPE(mp_request_type), ALLOCATABLE, DIMENSION(:, :)                 :: req_array
     235             :       INTEGER                                               :: handle
     236             : 
     237      147344 :       CALL timeset(routineN, handle)
     238      147344 :       numnodes = mp_comm%num_pe
     239             : 
     240      147344 :       IF (numnodes > 1) THEN
     241      141346 : !$OMP MASTER
     242      141346 :          send_counter = 0
     243      141346 :          rec_counter = 0
     244             : 
     245     2120190 :          ALLOCATE (req_array(1:numnodes, 4))
     246             : 
     247      424038 :          DO iproc = 0, numnodes - 1
     248      424038 :             IF (SIZE(buffer_recv(iproc)%blocks) > 0) THEN
     249      180078 :                rec_counter = rec_counter + 1
     250      180078 :                CALL mp_comm%irecv(buffer_recv(iproc)%blocks, iproc, req_array(rec_counter, 3), tag=4)
     251      180078 :                CALL mp_comm%irecv(buffer_recv(iproc)%data, iproc, req_array(rec_counter, 4), tag=7)
     252             :             END IF
     253             :          END DO
     254             : 
     255      424038 :          DO iproc = 0, numnodes - 1
     256      424038 :             IF (SIZE(buffer_send(iproc)%blocks) > 0) THEN
     257      180078 :                send_counter = send_counter + 1
     258      180078 :                CALL mp_comm%isend(buffer_send(iproc)%blocks, iproc, req_array(send_counter, 1), tag=4)
     259      180078 :                CALL mp_comm%isend(buffer_send(iproc)%data, iproc, req_array(send_counter, 2), tag=7)
     260             :             END IF
     261             :          END DO
     262             : 
     263      141346 :          IF (send_counter > 0) THEN
     264      120396 :             CALL mp_waitall(req_array(1:send_counter, 1:2))
     265             :          END IF
     266      141346 :          IF (rec_counter > 0) THEN
     267      113028 :             CALL mp_waitall(req_array(1:rec_counter, 3:4))
     268             :          END IF
     269             : !$OMP END MASTER
     270             : 
     271             :       ELSE
     272        5998 : !$OMP DO SCHEDULE(static, 512)
     273             :          DO i = 1, SIZE(buffer_send(0)%blocks, 1)
     274     3696320 :             buffer_recv(0)%blocks(i, :) = buffer_send(0)%blocks(i, :)
     275             :          END DO
     276             : !$OMP END DO
     277        5998 : !$OMP DO SCHEDULE(static, 512)
     278             :          DO i = 1, SIZE(buffer_send(0)%data)
     279   415351583 :             buffer_recv(0)%data(i) = buffer_send(0)%data(i)
     280             :          END DO
     281             : !$OMP END DO
     282             :       END IF
     283      147344 :       CALL timestop(handle)
     284             : 
     285      147344 :    END SUBROUTINE dbt_communicate_buffer
     286             : 
     287           0 : END MODULE dbt_reshape_ops

Generated by: LCOV version 1.15