Line data Source code
1 : !--------------------------------------------------------------------------------------------------!
2 : ! CP2K: A general program to perform molecular dynamics simulations !
3 : ! Copyright 2000-2025 CP2K developers group <https://cp2k.org> !
4 : ! !
5 : ! SPDX-License-Identifier: GPL-2.0-or-later !
6 : !--------------------------------------------------------------------------------------------------!
7 :
8 : ! **************************************************************************************************
9 : !> \brief Routines to reshape / redistribute tensors
10 : !> \author Patrick Seewald
11 : ! **************************************************************************************************
12 : MODULE dbt_reshape_ops
13 : #:include "dbt_macros.fypp"
14 : #:set maxdim = maxrank
15 : #:set ndims = range(2,maxdim+1)
16 :
17 : USE dbt_allocate_wrap, ONLY: allocate_any
18 : USE dbt_tas_base, ONLY: dbt_tas_copy, dbt_tas_get_info, dbt_tas_info
19 : USE dbt_block, ONLY: &
20 : block_nd, create_block, destroy_block, dbt_iterator_type, dbt_iterator_next_block, &
21 : dbt_iterator_blocks_left, dbt_iterator_start, dbt_iterator_stop, dbt_get_block, &
22 : dbt_reserve_blocks, dbt_put_block
23 : USE dbt_types, ONLY: dbt_blk_sizes, &
24 : dbt_create, &
25 : dbt_type, &
26 : ndims_tensor, &
27 : dbt_get_stored_coordinates, &
28 : dbt_clear
29 : USE kinds, ONLY: default_string_length
30 : USE kinds, ONLY: dp, dp
31 : USE message_passing, ONLY: &
32 : mp_waitall, mp_comm_type, mp_request_type
33 :
34 : #include "../base/base_uses.f90"
35 :
! Module-wide defaults: explicit typing only; every entity is private unless exported below.
   IMPLICIT NONE
   PRIVATE
   CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'dbt_reshape_ops'

   PUBLIC :: dbt_reshape

   ! Communication buffer holding all blocks exchanged with one MPI rank.
   TYPE block_buffer_type
      ! One row per block. As allocated in dbt_reshape the second dimension runs from 0:
      ! column 0 stores the offset of the block inside "data", columns 1: store the block index.
      INTEGER, ALLOCATABLE :: blocks(:, :)
      ! Block entries of all blocks, stored back to back.
      REAL(dp), ALLOCATABLE :: data(:)
   END TYPE block_buffer_type
46 :
47 : CONTAINS
48 :
! **************************************************************************************************
!> \brief Copy the blocks of tensor_in into tensor_out, redistributing them over the MPI ranks
!>        according to the distribution of tensor_out (involves reshape).
!>        Without summation tensor_out is cleared first, i.e. tensor_out = tensor_in;
!>        with summation the request is forwarded to dbt_put_block (tensor_out = tensor_out +
!>        tensor_in).
!> \param tensor_in source tensor; cleared on return if move_data is .TRUE.
!> \param tensor_out target tensor, must be valid
!> \param summation if .TRUE. the existing content of tensor_out is kept (default .FALSE.)
!> \param move_data memory optimization: transfer data from tensor_in to tensor_out s.t.
!>        tensor_in is empty on return (default .FALSE.)
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE dbt_reshape(tensor_in, tensor_out, summation, move_data)

      TYPE(dbt_type), INTENT(INOUT)                      :: tensor_in, tensor_out
      LOGICAL, INTENT(IN), OPTIONAL                      :: summation
      LOGICAL, INTENT(IN), OPTIONAL                      :: move_data

      CHARACTER(LEN=*), PARAMETER :: routineN = 'dbt_reshape'

      INTEGER                                            :: iproc, numnodes, &
                                                            handle, iblk, jblk, offset, ndata, &
                                                            nblks_recv_mythread
      INTEGER, ALLOCATABLE, DIMENSION(:, :)              :: blks_to_allocate
      TYPE(dbt_iterator_type)                            :: iter
      TYPE(block_nd)                                     :: blk_data
      TYPE(block_buffer_type), ALLOCATABLE, DIMENSION(:) :: buffer_recv, buffer_send
      INTEGER, DIMENSION(ndims_tensor(tensor_in))        :: blk_size, ind_nd
      LOGICAL                                            :: found, summation_prv, move_prv

      INTEGER, ALLOCATABLE, DIMENSION(:)                 :: nblks_send_total, ndata_send_total, &
                                                            nblks_recv_total, ndata_recv_total, &
                                                            nblks_send_mythread, ndata_send_mythread
      TYPE(mp_comm_type)                                 :: mp_comm

      CALL timeset(routineN, handle)

      IF (PRESENT(summation)) THEN
         summation_prv = summation
      ELSE
         summation_prv = .FALSE.
      END IF

      IF (PRESENT(move_data)) THEN
         move_prv = move_data
      ELSE
         move_prv = .FALSE.
      END IF

      CPASSERT(tensor_out%valid)

      IF (.NOT. summation_prv) CALL dbt_clear(tensor_out)

      ! one send and one receive buffer per rank of the 2d process grid communicator of tensor_in
      mp_comm = tensor_in%pgrid%mp_comm_2d
      numnodes = mp_comm%num_pe
      ALLOCATE (buffer_send(0:numnodes - 1), buffer_recv(0:numnodes - 1))
      ALLOCATE (nblks_send_total(0:numnodes - 1), ndata_send_total(0:numnodes - 1), source=0)
      ALLOCATE (nblks_recv_total(0:numnodes - 1), ndata_recv_total(0:numnodes - 1), source=0)

!$OMP PARALLEL DEFAULT(OMP_DEFAULT_NONE_WITH_OOP) &
!$OMP SHARED(tensor_in,tensor_out,summation) &
!$OMP SHARED(buffer_send,buffer_recv,mp_comm,numnodes) &
!$OMP SHARED(nblks_send_total,ndata_send_total,nblks_recv_total,ndata_recv_total) &
!$OMP PRIVATE(nblks_send_mythread,ndata_send_mythread,nblks_recv_mythread) &
!$OMP PRIVATE(iter,ind_nd,blk_size,blk_data,found,iproc) &
!$OMP PRIVATE(blks_to_allocate,offset,ndata,iblk,jblk)
      ALLOCATE (nblks_send_mythread(0:numnodes - 1), ndata_send_mythread(0:numnodes - 1), source=0)

      ! Pass 1: count, per destination rank, the blocks and the amount of data this thread sends.
      ! The destination of a block is the rank that stores it in tensor_out.
      CALL dbt_iterator_start(iter, tensor_in)
      DO WHILE (dbt_iterator_blocks_left(iter))
         CALL dbt_iterator_next_block(iter, ind_nd, blk_size=blk_size)
         CALL dbt_get_stored_coordinates(tensor_out, ind_nd, iproc)
         nblks_send_mythread(iproc) = nblks_send_mythread(iproc) + 1
         ndata_send_mythread(iproc) = ndata_send_mythread(iproc) + PRODUCT(blk_size)
      END DO
      CALL dbt_iterator_stop(iter)

      ! Accumulate the per-thread counts. The running total seen by a thread right after adding
      ! its own contribution is the end of the slot it owns in the send buffers.
!$OMP CRITICAL(omp_dbt_reshape)
      nblks_send_total(:) = nblks_send_total(:) + nblks_send_mythread(:)
      ndata_send_total(:) = ndata_send_total(:) + ndata_send_mythread(:)
      nblks_send_mythread(:) = nblks_send_total(:) ! current totals indicate slot for this thread
      ndata_send_mythread(:) = ndata_send_total(:)
!$OMP END CRITICAL(omp_dbt_reshape)
      ! all threads must have contributed before the totals are exchanged
!$OMP BARRIER

      ! exchange the counts so that every rank knows how much it is going to receive
!$OMP MASTER
      CALL mp_comm%alltoall(nblks_send_total, nblks_recv_total, 1)
      CALL mp_comm%alltoall(ndata_send_total, ndata_recv_total, 1)
!$OMP END MASTER
!$OMP BARRIER

!$OMP DO
      DO iproc = 0, numnodes - 1
         ALLOCATE (buffer_send(iproc)%data(ndata_send_total(iproc)))
         ALLOCATE (buffer_recv(iproc)%data(ndata_recv_total(iproc)))
         ! going to use buffer%blocks(:,0) to store data offsets
         ALLOCATE (buffer_send(iproc)%blocks(nblks_send_total(iproc), 0:ndims_tensor(tensor_in)))
         ALLOCATE (buffer_recv(iproc)%blocks(nblks_recv_total(iproc), 0:ndims_tensor(tensor_in)))
      END DO
!$OMP END DO
!$OMP BARRIER

      ! Pass 2: pack the blocks. Each thread fills its slot from the end towards the start by
      ! decrementing the slot markers obtained in the critical section above.
      CALL dbt_iterator_start(iter, tensor_in)
      DO WHILE (dbt_iterator_blocks_left(iter))
         CALL dbt_iterator_next_block(iter, ind_nd, blk_size=blk_size)
         CALL dbt_get_stored_coordinates(tensor_out, ind_nd, iproc)
         CALL dbt_get_block(tensor_in, ind_nd, blk_data, found)
         CPASSERT(found)
         ! insert block data
         ndata = PRODUCT(blk_size)
         ndata_send_mythread(iproc) = ndata_send_mythread(iproc) - ndata
         offset = ndata_send_mythread(iproc)
         buffer_send(iproc)%data(offset + 1:offset + ndata) = blk_data%blk(:)
         ! insert block index
         nblks_send_mythread(iproc) = nblks_send_mythread(iproc) - 1
         iblk = nblks_send_mythread(iproc) + 1
         buffer_send(iproc)%blocks(iblk, 1:) = ind_nd(:)
         buffer_send(iproc)%blocks(iblk, 0) = offset
         CALL destroy_block(blk_data)
      END DO
      CALL dbt_iterator_stop(iter)
      DEALLOCATE (nblks_send_mythread, ndata_send_mythread)
      ! send buffers must be complete before they are communicated
!$OMP BARRIER

      CALL dbt_communicate_buffer(mp_comm, buffer_recv, buffer_send)
      ! receive buffers must be complete before any thread unpacks them
!$OMP BARRIER

!$OMP DO
      DO iproc = 0, numnodes - 1
         DEALLOCATE (buffer_send(iproc)%blocks, buffer_send(iproc)%data)
      END DO
!$OMP END DO NOWAIT

      ! The three loop nests below (count, collect indices, insert blocks) rely on every thread
      ! being assigned the same iterations in each of them. OpenMP guarantees this only for
      ! loops with an explicit static schedule and identical iteration counts, hence
      ! SCHEDULE(static) instead of the implementation-defined default schedule.
      nblks_recv_mythread = 0
      DO iproc = 0, numnodes - 1
!$OMP DO SCHEDULE(static)
         DO iblk = 1, nblks_recv_total(iproc)
            nblks_recv_mythread = nblks_recv_mythread + 1
         END DO
!$OMP END DO
      END DO
      ALLOCATE (blks_to_allocate(nblks_recv_mythread, ndims_tensor(tensor_in)))

      ! collect the indices of the received blocks handled by this thread and reserve them
      jblk = 0
      DO iproc = 0, numnodes - 1
!$OMP DO SCHEDULE(static)
         DO iblk = 1, nblks_recv_total(iproc)
            jblk = jblk + 1
            blks_to_allocate(jblk, :) = buffer_recv(iproc)%blocks(iblk, 1:)
         END DO
!$OMP END DO
      END DO
      CPASSERT(jblk == nblks_recv_mythread)
      CALL dbt_reserve_blocks(tensor_out, blks_to_allocate)
      DEALLOCATE (blks_to_allocate)

      ! unpack the received blocks into tensor_out
      DO iproc = 0, numnodes - 1
!$OMP DO SCHEDULE(static)
         DO iblk = 1, nblks_recv_total(iproc)
            ind_nd(:) = buffer_recv(iproc)%blocks(iblk, 1:)
            CALL dbt_blk_sizes(tensor_out, ind_nd, blk_size)
            offset = buffer_recv(iproc)%blocks(iblk, 0)
            ndata = PRODUCT(blk_size)
            CALL create_block(blk_data, blk_size, &
                              array=buffer_recv(iproc)%data(offset + 1:offset + ndata))
            CALL dbt_put_block(tensor_out, ind_nd, blk_data, summation=summation)
            CALL destroy_block(blk_data)
         END DO
!$OMP END DO
      END DO

!$OMP DO
      DO iproc = 0, numnodes - 1
         DEALLOCATE (buffer_recv(iproc)%blocks, buffer_recv(iproc)%data)
      END DO
!$OMP END DO
!$OMP END PARALLEL

      DEALLOCATE (nblks_recv_total, ndata_recv_total)
      DEALLOCATE (nblks_send_total, ndata_send_total)
      DEALLOCATE (buffer_send, buffer_recv)

      IF (move_prv) CALL dbt_clear(tensor_in)

      CALL timestop(handle)
   END SUBROUTINE dbt_reshape
228 :
! **************************************************************************************************
!> \brief Transfer the content of the send buffers into the matching receive buffers.
!>        With more than one rank the master thread posts non-blocking receives and sends for
!>        every non-empty buffer and waits for their completion; with a single rank the
!>        content of buffer_send(0) is copied into buffer_recv(0).
!> \param mp_comm communicator; its number of ranks selects between the two code paths
!> \param buffer_recv receive buffers indexed by rank (starting at 0); the routine does not
!>        allocate them, their components must already have the final size
!> \param buffer_send send buffers indexed by rank (starting at 0)
!> \note The routine contains orphaned OpenMP constructs (MASTER, DO); dbt_reshape calls it
!>       from inside its parallel region. There is no barrier after the MASTER section, so the
!>       caller has to synchronise the threads before buffer_recv is read.
!> \author Patrick Seewald
! **************************************************************************************************
   SUBROUTINE dbt_communicate_buffer(mp_comm, buffer_recv, buffer_send)
      TYPE(mp_comm_type), INTENT(IN)                     :: mp_comm
      TYPE(block_buffer_type), DIMENSION(0:), &
         INTENT(INOUT)                                   :: buffer_recv, buffer_send

      CHARACTER(LEN=*), PARAMETER :: routineN = 'dbt_communicate_buffer'
      ! message tags distinguishing block-index messages from block-data messages
      INTEGER, PARAMETER                                 :: tag_blocks = 4, tag_data = 7

      INTEGER                                            :: handle, idata, irow, n_recv, n_send, &
                                                            nproc, peer
      TYPE(mp_request_type), ALLOCATABLE, &
         DIMENSION(:, :)                                 :: requests

      CALL timeset(routineN, handle)
      nproc = mp_comm%num_pe

      IF (nproc <= 1) THEN
         ! single rank: no message passing, the threads share the copy of the index rows ...
!$OMP DO SCHEDULE(static)
         DO irow = 1, SIZE(buffer_send(0)%blocks, 1)
            buffer_recv(0)%blocks(irow, :) = buffer_send(0)%blocks(irow, :)
         END DO
!$OMP END DO NOWAIT
         ! ... and of the block data
!$OMP DO SCHEDULE(static)
         DO idata = 1, SIZE(buffer_send(0)%data)
            buffer_recv(0)%data(idata) = buffer_send(0)%data(idata)
         END DO
!$OMP END DO
      ELSE
!$OMP MASTER
         n_send = 0
         n_recv = 0

         ! request columns: 1 = send blocks, 2 = send data, 3 = receive blocks, 4 = receive data
         ALLOCATE (requests(1:nproc, 4))

         ! post all receives first, skipping ranks from which nothing is expected
         DO peer = 0, nproc - 1
            IF (SIZE(buffer_recv(peer)%blocks) > 0) THEN
               n_recv = n_recv + 1
               CALL mp_comm%irecv(buffer_recv(peer)%blocks, peer, requests(n_recv, 3), &
                                  tag=tag_blocks)
               CALL mp_comm%irecv(buffer_recv(peer)%data, peer, requests(n_recv, 4), &
                                  tag=tag_data)
            END IF
         END DO

         ! then post the sends, skipping ranks for which there is nothing to send
         DO peer = 0, nproc - 1
            IF (SIZE(buffer_send(peer)%blocks) > 0) THEN
               n_send = n_send + 1
               CALL mp_comm%isend(buffer_send(peer)%blocks, peer, requests(n_send, 1), &
                                  tag=tag_blocks)
               CALL mp_comm%isend(buffer_send(peer)%data, peer, requests(n_send, 2), &
                                  tag=tag_data)
            END IF
         END DO

         ! wait for completion of everything that has been posted
         IF (n_send > 0) THEN
            CALL mp_waitall(requests(1:n_send, 1:2))
         END IF
         IF (n_recv > 0) THEN
            CALL mp_waitall(requests(1:n_recv, 3:4))
         END IF
!$OMP END MASTER

      END IF
      CALL timestop(handle)

   END SUBROUTINE dbt_communicate_buffer
293 :
294 0 : END MODULE dbt_reshape_ops
|