Line data Source code
1 : !--------------------------------------------------------------------------------------------------!
2 : ! CP2K: A general program to perform molecular dynamics simulations !
3 : ! Copyright 2000-2024 CP2K developers group <https://cp2k.org> !
4 : ! !
5 : ! SPDX-License-Identifier: GPL-2.0-or-later !
6 : !--------------------------------------------------------------------------------------------------!
7 :
8 : ! **************************************************************************************************
9 : !> \brief DBT tensor framework for block-sparse tensor contraction.
10 : !> Representation of n-rank tensors as DBT tall-and-skinny matrices.
11 : !> Support for arbitrary redistribution between different representations.
12 : !> Support for arbitrary tensor contractions
13 : !> \todo implement checks and error messages
14 : !> \author Patrick Seewald
15 : ! **************************************************************************************************
16 : MODULE dbt_methods
17 : #:include "dbt_macros.fypp"
18 : #:set maxdim = maxrank
19 : #:set ndims = range(2,maxdim+1)
20 :
21 : USE dbcsr_api, ONLY: &
22 : dbcsr_type, dbcsr_release, &
23 : dbcsr_iterator_type, dbcsr_iterator_start, dbcsr_iterator_blocks_left, dbcsr_iterator_next_block, &
24 : dbcsr_has_symmetry, dbcsr_desymmetrize, dbcsr_put_block, dbcsr_clear, dbcsr_iterator_stop
25 : USE dbt_allocate_wrap, ONLY: &
26 : allocate_any
27 : USE dbt_array_list_methods, ONLY: &
28 : get_arrays, reorder_arrays, get_ith_array, array_list, array_sublist, check_equal, array_eq_i, &
29 : create_array_list, destroy_array_list, sizes_of_arrays
30 : USE dbm_api, ONLY: &
31 : dbm_clear
32 : USE dbt_tas_types, ONLY: &
33 : dbt_tas_split_info
34 : USE dbt_tas_base, ONLY: &
35 : dbt_tas_copy, dbt_tas_finalize, dbt_tas_get_info, dbt_tas_info
36 : USE dbt_tas_mm, ONLY: &
37 : dbt_tas_multiply, dbt_tas_batched_mm_init, dbt_tas_batched_mm_finalize, &
38 : dbt_tas_batched_mm_complete, dbt_tas_set_batched_state
39 : USE dbt_block, ONLY: &
40 : dbt_iterator_type, dbt_get_block, dbt_put_block, dbt_iterator_start, &
41 : dbt_iterator_blocks_left, dbt_iterator_stop, dbt_iterator_next_block, &
42 : ndims_iterator, dbt_reserve_blocks, block_nd, destroy_block, checker_tr
43 : USE dbt_index, ONLY: &
44 : dbt_get_mapping_info, nd_to_2d_mapping, dbt_inverse_order, permute_index, get_nd_indices_tensor, &
45 : ndims_mapping_row, ndims_mapping_column, ndims_mapping
46 : USE dbt_types, ONLY: &
47 : dbt_create, dbt_type, ndims_tensor, dims_tensor, &
48 : dbt_distribution_type, dbt_distribution, dbt_nd_mp_comm, dbt_destroy, &
49 : dbt_distribution_destroy, dbt_distribution_new_expert, dbt_get_stored_coordinates, &
50 : blk_dims_tensor, dbt_hold, dbt_pgrid_type, mp_environ_pgrid, dbt_filter, &
51 : dbt_clear, dbt_finalize, dbt_get_num_blocks, dbt_scale, &
52 : dbt_get_num_blocks_total, dbt_get_info, ndims_matrix_row, ndims_matrix_column, &
53 : dbt_max_nblks_local, dbt_default_distvec, dbt_contraction_storage, dbt_nblks_total, &
54 : dbt_distribution_new, dbt_copy_contraction_storage, dbt_pgrid_destroy
55 : USE kinds, ONLY: &
56 : dp, default_string_length, int_8, dp
57 : USE message_passing, ONLY: &
58 : mp_cart_type
59 : USE util, ONLY: &
60 : sort
61 : USE dbt_reshape_ops, ONLY: &
62 : dbt_reshape
63 : USE dbt_tas_split, ONLY: &
64 : dbt_tas_mp_comm, rowsplit, colsplit, dbt_tas_info_hold, dbt_tas_release_info, default_nsplit_accept_ratio, &
65 : default_pdims_accept_ratio, dbt_tas_create_split
66 : USE dbt_split, ONLY: &
67 : dbt_split_copyback, dbt_make_compatible_blocks, dbt_crop
68 : USE dbt_io, ONLY: &
69 : dbt_write_tensor_info, dbt_write_tensor_dist, prep_output_unit, dbt_write_split_info
70 : USE message_passing, ONLY: mp_comm_type
71 :
72 : #include "../base/base_uses.f90"
73 :
74 : IMPLICIT NONE
75 : PRIVATE
76 : CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'dbt_methods'
77 :
78 : PUBLIC :: &
79 : dbt_contract, &
80 : dbt_copy, &
81 : dbt_get_block, &
82 : dbt_get_stored_coordinates, &
83 : dbt_inverse_order, &
84 : dbt_iterator_blocks_left, &
85 : dbt_iterator_next_block, &
86 : dbt_iterator_start, &
87 : dbt_iterator_stop, &
88 : dbt_iterator_type, &
89 : dbt_put_block, &
90 : dbt_reserve_blocks, &
91 : dbt_copy_matrix_to_tensor, &
92 : dbt_copy_tensor_to_matrix, &
93 : dbt_batched_contract_init, &
94 : dbt_batched_contract_finalize
95 :
96 : CONTAINS
97 :
98 : ! **************************************************************************************************
99 : !> \brief Copy tensor data.
100 : !> Redistributes tensor data according to distributions of target and source tensor.
101 : !> Permutes tensor index according to `order` argument (if present).
102 : !> Source and target tensor formats are arbitrary as long as the following requirements are met:
103 : !> * source and target tensors have the same rank and the same sizes in each dimension in terms
104 : !> of tensor elements (block sizes don't need to be the same).
105 : !> If `order` argument is present, sizes must match after index permutation.
106 : !> OR
107 : !> * target tensor is not yet created, in this case an exact copy of source tensor is returned.
108 : !> \param tensor_in Source
109 : !> \param tensor_out Target
110 : !> \param order Permutation of target tensor index.
111 : !> Exact same convention as order argument of RESHAPE intrinsic.
112 : !> \param bounds crop tensor data: start and end index for each tensor dimension
113 : !> \author Patrick Seewald
114 : ! **************************************************************************************************
115 710234 : SUBROUTINE dbt_copy(tensor_in, tensor_out, order, summation, bounds, move_data, unit_nr)
116 : TYPE(dbt_type), INTENT(INOUT), TARGET :: tensor_in, tensor_out
117 : INTEGER, DIMENSION(ndims_tensor(tensor_in)), &
118 : INTENT(IN), OPTIONAL :: order
119 : LOGICAL, INTENT(IN), OPTIONAL :: summation, move_data
120 : INTEGER, DIMENSION(2, ndims_tensor(tensor_in)), &
121 : INTENT(IN), OPTIONAL :: bounds
122 : INTEGER, INTENT(IN), OPTIONAL :: unit_nr
123 : INTEGER :: handle
124 :
125 355117 : CALL tensor_in%pgrid%mp_comm_2d%sync()
126 355117 : CALL timeset("dbt_total", handle)
127 :
128 : ! make sure that it is safe to use dbt_copy during a batched contraction
129 355117 : CALL dbt_tas_batched_mm_complete(tensor_in%matrix_rep, warn=.TRUE.)
130 355117 : CALL dbt_tas_batched_mm_complete(tensor_out%matrix_rep, warn=.TRUE.)
131 :
132 355117 : CALL dbt_copy_expert(tensor_in, tensor_out, order, summation, bounds, move_data, unit_nr)
133 355117 : CALL tensor_in%pgrid%mp_comm_2d%sync()
134 355117 : CALL timestop(handle)
135 355117 : END SUBROUTINE
136 :
! **************************************************************************************************
!> \brief expert routine for copying a tensor. For internal use only.
!> \param tensor_in Source tensor
!> \param tensor_out Target tensor; must already be created (tensor_out%valid)
!> \param order permutation of the target tensor index
!>        (same convention as the ORDER argument of the RESHAPE intrinsic)
!> \param summation if .TRUE., sum into tensor_out instead of overwriting it
!> \param bounds crop tensor data: start and end index for each tensor dimension
!> \param move_data if .TRUE., data of tensor_in may be moved (tensor_in emptied) to save memory
!> \param unit_nr output unit for logging
!> \author Patrick Seewald
! **************************************************************************************************
   SUBROUTINE dbt_copy_expert(tensor_in, tensor_out, order, summation, bounds, move_data, unit_nr)
      TYPE(dbt_type), INTENT(INOUT), TARGET :: tensor_in, tensor_out
      INTEGER, DIMENSION(ndims_tensor(tensor_in)), &
         INTENT(IN), OPTIONAL :: order
      LOGICAL, INTENT(IN), OPTIONAL :: summation, move_data
      INTEGER, DIMENSION(2, ndims_tensor(tensor_in)), &
         INTENT(IN), OPTIONAL :: bounds
      INTEGER, INTENT(IN), OPTIONAL :: unit_nr

      ! stages of the copy pipeline: each pointer either aliases the previous stage or
      ! points to a newly allocated temporary, tracked by the corresponding new_* flag
      TYPE(dbt_type), POINTER :: in_tmp_1, in_tmp_2, &
                                 in_tmp_3, out_tmp_1
      INTEGER :: handle, unit_nr_prv
      INTEGER, DIMENSION(:), ALLOCATABLE :: map1_in_1, map1_in_2, map2_in_1, map2_in_2

      ! timer label deliberately matches the public entry point dbt_copy
      CHARACTER(LEN=*), PARAMETER :: routineN = 'dbt_copy'
      LOGICAL :: dist_compatible_tas, dist_compatible_tensor, &
                 summation_prv, new_in_1, new_in_2, &
                 new_in_3, new_out_1, block_compatible, &
                 move_prv
      TYPE(array_list) :: blk_sizes_in

      CALL timeset(routineN, handle)

      CPASSERT(tensor_out%valid)

      unit_nr_prv = prep_output_unit(unit_nr)

      IF (PRESENT(move_data)) THEN
         move_prv = move_data
      ELSE
         move_prv = .FALSE.
      END IF

      dist_compatible_tas = .FALSE.
      dist_compatible_tensor = .FALSE.
      block_compatible = .FALSE.
      new_in_1 = .FALSE.
      new_in_2 = .FALSE.
      new_in_3 = .FALSE.
      new_out_1 = .FALSE.

      IF (PRESENT(summation)) THEN
         summation_prv = summation
      ELSE
         summation_prv = .FALSE.
      END IF

      ! stage 1: crop the source tensor to the requested bounds
      IF (PRESENT(bounds)) THEN
         ALLOCATE (in_tmp_1)
         CALL dbt_crop(tensor_in, in_tmp_1, bounds=bounds, move_data=move_prv)
         new_in_1 = .TRUE.
         ! the cropped temporary is owned here, so later stages may consume its data
         move_prv = .TRUE.
      ELSE
         in_tmp_1 => tensor_in
      END IF

      ! stage 2: check whether source and target block sizes agree (after permutation,
      ! if an index order is given); if not, create block-compatible temporaries
      IF (PRESENT(order)) THEN
         CALL reorder_arrays(in_tmp_1%blk_sizes, blk_sizes_in, order=order)
         block_compatible = check_equal(blk_sizes_in, tensor_out%blk_sizes)
      ELSE
         block_compatible = check_equal(in_tmp_1%blk_sizes, tensor_out%blk_sizes)
      END IF

      IF (.NOT. block_compatible) THEN
         ALLOCATE (in_tmp_2, out_tmp_1)
         ! nodata2: the split target needs existing data only if we sum into it
         CALL dbt_make_compatible_blocks(in_tmp_1, tensor_out, in_tmp_2, out_tmp_1, order=order, &
                                         nodata2=.NOT. summation_prv, move_data=move_prv)
         new_in_2 = .TRUE.; new_out_1 = .TRUE.
         move_prv = .TRUE.
      ELSE
         in_tmp_2 => in_tmp_1
         out_tmp_1 => tensor_out
      END IF

      ! stage 3: apply the index permutation to the source
      IF (PRESENT(order)) THEN
         ALLOCATE (in_tmp_3)
         CALL dbt_permute_index(in_tmp_2, in_tmp_3, order)
         new_in_3 = .TRUE.
      ELSE
         in_tmp_3 => in_tmp_2
      END IF

      ! retrieve the tensor-index-to-matrix (row/column) mappings of source and target
      ALLOCATE (map1_in_1(ndims_matrix_row(in_tmp_3)))
      ALLOCATE (map1_in_2(ndims_matrix_column(in_tmp_3)))
      CALL dbt_get_mapping_info(in_tmp_3%nd_index, map1_2d=map1_in_1, map2_2d=map1_in_2)

      ALLOCATE (map2_in_1(ndims_matrix_row(out_tmp_1)))
      ALLOCATE (map2_in_2(ndims_matrix_column(out_tmp_1)))
      CALL dbt_get_mapping_info(out_tmp_1%nd_index, map1_2d=map2_in_1, map2_2d=map2_in_2)

      IF (.NOT. PRESENT(order)) THEN
         IF (array_eq_i(map1_in_1, map2_in_1) .AND. array_eq_i(map1_in_2, map2_in_2)) THEN
            ! identical row and column mappings: a direct matrix-representation copy is
            ! possible if the distributions also agree
            dist_compatible_tas = check_equal(in_tmp_3%nd_dist, out_tmp_1%nd_dist)
         ELSEIF (array_eq_i([map1_in_1, map1_in_2], [map2_in_1, map2_in_2])) THEN
            ! same overall index order but different row/column split: a block-wise copy
            ! without communication is possible if the distributions agree
            dist_compatible_tensor = check_equal(in_tmp_3%nd_dist, out_tmp_1%nd_dist)
         END IF
      END IF

      IF (dist_compatible_tas) THEN
         ! copy directly on the tall-and-skinny matrix representation
         CALL dbt_tas_copy(out_tmp_1%matrix_rep, in_tmp_3%matrix_rep, summation)
         IF (move_prv) CALL dbt_clear(in_tmp_3)
      ELSEIF (dist_compatible_tensor) THEN
         ! block-wise local copy, no communication needed
         CALL dbt_copy_nocomm(in_tmp_3, out_tmp_1, summation)
         IF (move_prv) CALL dbt_clear(in_tmp_3)
      ELSE
         ! general path: arbitrary redistribution via dbt_reshape
         CALL dbt_reshape(in_tmp_3, out_tmp_1, summation, move_data=move_prv)
      END IF

      ! release temporaries owned by this routine (aliased stages are left untouched)
      IF (new_in_1) THEN
         CALL dbt_destroy(in_tmp_1)
         DEALLOCATE (in_tmp_1)
      END IF

      IF (new_in_2) THEN
         CALL dbt_destroy(in_tmp_2)
         DEALLOCATE (in_tmp_2)
      END IF

      IF (new_in_3) THEN
         CALL dbt_destroy(in_tmp_3)
         DEALLOCATE (in_tmp_3)
      END IF

      IF (new_out_1) THEN
         IF (unit_nr_prv /= 0) THEN
            CALL dbt_write_tensor_dist(out_tmp_1, unit_nr)
         END IF
         ! copy data from the block-split temporary back into the actual target tensor
         CALL dbt_split_copyback(out_tmp_1, tensor_out, summation)
         CALL dbt_destroy(out_tmp_1)
         DEALLOCATE (out_tmp_1)
      END IF

      CALL timestop(handle)

   END SUBROUTINE
276 :
! **************************************************************************************************
!> \brief copy without communication, requires that both tensors have same process grid and distribution
!> \param tensor_in Source tensor
!> \param tensor_out Target tensor (must be valid)
!> \param summation Whether to sum matrices b = a + b
!> \author Patrick Seewald
! **************************************************************************************************
   SUBROUTINE dbt_copy_nocomm(tensor_in, tensor_out, summation)
      TYPE(dbt_type), INTENT(INOUT) :: tensor_in
      TYPE(dbt_type), INTENT(INOUT) :: tensor_out
      LOGICAL, INTENT(IN), OPTIONAL :: summation
      TYPE(dbt_iterator_type) :: iter
      INTEGER, DIMENSION(ndims_tensor(tensor_in)) :: ind_nd
      TYPE(block_nd) :: blk_data
      LOGICAL :: found

      CHARACTER(LEN=*), PARAMETER :: routineN = 'dbt_copy_nocomm'
      INTEGER :: handle

      CALL timeset(routineN, handle)
      CPASSERT(tensor_out%valid)

      ! overwrite semantics unless summation is explicitly requested
      IF (PRESENT(summation)) THEN
         IF (.NOT. summation) CALL dbt_clear(tensor_out)
      ELSE
         CALL dbt_clear(tensor_out)
      END IF

      ! reserve all blocks of tensor_in in tensor_out before the threaded copy loop
      CALL dbt_reserve_blocks(tensor_in, tensor_out)

      ! thread-parallel loop over the local blocks of tensor_in; each thread owns its
      ! private iterator and block buffer
!$OMP PARALLEL DEFAULT(NONE) SHARED(tensor_in,tensor_out,summation) &
!$OMP PRIVATE(iter,ind_nd,blk_data,found)
      CALL dbt_iterator_start(iter, tensor_in)
      DO WHILE (dbt_iterator_blocks_left(iter))
         CALL dbt_iterator_next_block(iter, ind_nd)
         CALL dbt_get_block(tensor_in, ind_nd, blk_data, found)
         ! block must be local since both tensors share grid and distribution
         CPASSERT(found)
         CALL dbt_put_block(tensor_out, ind_nd, blk_data, summation=summation)
         CALL destroy_block(blk_data)
      END DO
      CALL dbt_iterator_stop(iter)
!$OMP END PARALLEL

      CALL timestop(handle)
   END SUBROUTINE
320 :
! **************************************************************************************************
!> \brief copy matrix to tensor.
!> \param matrix_in Source DBCSR matrix
!> \param tensor_out Target rank-2 tensor (must be valid)
!> \param summation tensor_out = tensor_out + matrix_in
!> \author Patrick Seewald
! **************************************************************************************************
   SUBROUTINE dbt_copy_matrix_to_tensor(matrix_in, tensor_out, summation)
      TYPE(dbcsr_type), TARGET, INTENT(IN) :: matrix_in
      TYPE(dbt_type), INTENT(INOUT) :: tensor_out
      LOGICAL, INTENT(IN), OPTIONAL :: summation
      ! either aliases matrix_in or owns a desymmetrized copy (see below)
      TYPE(dbcsr_type), POINTER :: matrix_in_desym

      INTEGER, DIMENSION(2) :: ind_2d
      REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :) :: block_arr
      REAL(KIND=dp), DIMENSION(:, :), POINTER :: block
      TYPE(dbcsr_iterator_type) :: iter
      LOGICAL :: tr

      INTEGER :: handle
      CHARACTER(LEN=*), PARAMETER :: routineN = 'dbt_copy_matrix_to_tensor'

      CALL timeset(routineN, handle)
      CPASSERT(tensor_out%valid)

      NULLIFY (block)

      ! expand a symmetric matrix to full storage so that every block can be copied
      IF (dbcsr_has_symmetry(matrix_in)) THEN
         ALLOCATE (matrix_in_desym)
         CALL dbcsr_desymmetrize(matrix_in, matrix_in_desym)
      ELSE
         matrix_in_desym => matrix_in
      END IF

      ! overwrite semantics unless summation is explicitly requested
      IF (PRESENT(summation)) THEN
         IF (.NOT. summation) CALL dbt_clear(tensor_out)
      ELSE
         CALL dbt_clear(tensor_out)
      END IF

      ! reserve all matrix blocks in the tensor before the threaded copy loop
      CALL dbt_reserve_blocks(matrix_in_desym, tensor_out)

      ! thread-parallel loop over the local matrix blocks
!$OMP PARALLEL DEFAULT(NONE) SHARED(matrix_in_desym,tensor_out,summation) &
!$OMP PRIVATE(iter,ind_2d,block,tr,block_arr)
      CALL dbcsr_iterator_start(iter, matrix_in_desym)
      DO WHILE (dbcsr_iterator_blocks_left(iter))
         CALL dbcsr_iterator_next_block(iter, ind_2d(1), ind_2d(2), block, tr)
         ! copy the block data into an allocatable buffer for dbt_put_block
         CALL allocate_any(block_arr, source=block)
         CALL dbt_put_block(tensor_out, ind_2d, SHAPE(block_arr), block_arr, summation=summation)
         DEALLOCATE (block_arr)
      END DO
      CALL dbcsr_iterator_stop(iter)
!$OMP END PARALLEL

      ! release the desymmetrized copy if one was created
      IF (dbcsr_has_symmetry(matrix_in)) THEN
         CALL dbcsr_release(matrix_in_desym)
         DEALLOCATE (matrix_in_desym)
      END IF

      CALL timestop(handle)

   END SUBROUTINE
381 :
! **************************************************************************************************
!> \brief copy tensor to matrix
!> \param tensor_in Source rank-2 tensor
!> \param matrix_out Target DBCSR matrix
!> \param summation matrix_out = matrix_out + tensor_in
!> \author Patrick Seewald
! **************************************************************************************************
   SUBROUTINE dbt_copy_tensor_to_matrix(tensor_in, matrix_out, summation)
      TYPE(dbt_type), INTENT(INOUT) :: tensor_in
      TYPE(dbcsr_type), INTENT(INOUT) :: matrix_out
      LOGICAL, INTENT(IN), OPTIONAL :: summation
      TYPE(dbt_iterator_type) :: iter
      INTEGER :: handle
      INTEGER, DIMENSION(2) :: ind_2d
      REAL(KIND=dp), DIMENSION(:, :), ALLOCATABLE :: block
      CHARACTER(LEN=*), PARAMETER :: routineN = 'dbt_copy_tensor_to_matrix'
      LOGICAL :: found

      CALL timeset(routineN, handle)

      ! overwrite semantics unless summation is explicitly requested
      IF (PRESENT(summation)) THEN
         IF (.NOT. summation) CALL dbcsr_clear(matrix_out)
      ELSE
         CALL dbcsr_clear(matrix_out)
      END IF

      ! reserve all tensor blocks in the matrix before the threaded copy loop
      CALL dbt_reserve_blocks(tensor_in, matrix_out)

      ! thread-parallel loop over the local tensor blocks
!$OMP PARALLEL DEFAULT(NONE) SHARED(tensor_in,matrix_out,summation) &
!$OMP PRIVATE(iter,ind_2d,block,found)
      CALL dbt_iterator_start(iter, tensor_in)
      DO WHILE (dbt_iterator_blocks_left(iter))
         CALL dbt_iterator_next_block(iter, ind_2d)
         ! for a symmetric target only one of the two symmetry-equivalent blocks is
         ! written; checker_tr selects which one to skip (presumably a checkerboard
         ! criterion — see dbt_block)
         IF (dbcsr_has_symmetry(matrix_out) .AND. checker_tr(ind_2d(1), ind_2d(2))) CYCLE

         CALL dbt_get_block(tensor_in, ind_2d, block, found)
         CPASSERT(found)

         ! a lower-triangle block of a symmetric matrix is stored transposed at the
         ! mirrored (upper-triangle) index
         IF (dbcsr_has_symmetry(matrix_out) .AND. ind_2d(1) > ind_2d(2)) THEN
            CALL dbcsr_put_block(matrix_out, ind_2d(2), ind_2d(1), TRANSPOSE(block), summation=summation)
         ELSE
            CALL dbcsr_put_block(matrix_out, ind_2d(1), ind_2d(2), block, summation=summation)
         END IF
         DEALLOCATE (block)
      END DO
      CALL dbt_iterator_stop(iter)
!$OMP END PARALLEL

      CALL timestop(handle)

   END SUBROUTINE
431 :
! **************************************************************************************************
!> \brief Contract tensors by multiplying matrix representations.
!>        tensor_3(map_1, map_2) := alpha * tensor_1(notcontract_1, contract_1)
!>                                        * tensor_2(contract_2, notcontract_2)
!>                                + beta * tensor_3(map_1, map_2)
!>
!> \note
!>      note 1: block sizes of the corresponding indices need to be the same in all tensors.
!>
!>      note 2: for best performance the tensors should have been created in matrix layouts
!>      compatible with the contraction, e.g. tensor_1 should have been created with either
!>      map1_2d == contract_1 and map2_2d == notcontract_1 or map1_2d == notcontract_1 and
!>      map2_2d == contract_1 (the same with tensor_2 and contract_2 / notcontract_2 and with
!>      tensor_3 and map_1 / map_2).
!>      Furthermore the two largest tensors involved in the contraction should map both to either
!>      tall or short matrices: the largest matrix dimension should be "on the same side"
!>      and should have identical distribution (which is always the case if the distributions were
!>      obtained with dbt_default_distvec).
!>
!>      note 3: if the same tensor occurs in multiple contractions, a different tensor object should
!>      be created for each contraction and the data should be copied between the tensors by use of
!>      dbt_copy. If the same tensor object is used in multiple contractions,
!>      matrix layouts are not compatible for all contractions (see note 2).
!>
!>      note 4: automatic optimizations are enabled by using the feature of batched contraction, see
!>      dbt_batched_contract_init, dbt_batched_contract_finalize.
!>      The arguments bounds_1, bounds_2, bounds_3 give the index ranges of the batches.
!>
!> \param alpha scalar prefactor of the tensor_1 x tensor_2 product
!> \param tensor_1 first tensor (in)
!> \param tensor_2 second tensor (in)
!> \param beta scalar prefactor of the existing tensor_3 data
!> \param contract_1 indices of tensor_1 to contract
!> \param contract_2 indices of tensor_2 to contract (1:1 with contract_1)
!> \param map_1 which indices of tensor_3 map to non-contracted indices of tensor_1 (1:1 with notcontract_1)
!> \param map_2 which indices of tensor_3 map to non-contracted indices of tensor_2 (1:1 with notcontract_2)
!> \param notcontract_1 indices of tensor_1 not to contract
!> \param notcontract_2 indices of tensor_2 not to contract
!> \param tensor_3 contracted tensor (out)
!> \param bounds_1 bounds corresponding to contract_1 AKA contract_2:
!>        start and end index of an index range over which to contract.
!>        For use in batched contraction.
!> \param bounds_2 bounds corresponding to notcontract_1: start and end index of an index range.
!>        For use in batched contraction.
!> \param bounds_3 bounds corresponding to notcontract_2: start and end index of an index range.
!>        For use in batched contraction.
!> \param optimize_dist Whether distribution should be optimized internally. In the current
!>        implementation this guarantees optimal parameters only for dense matrices.
!> \param pgrid_opt_1 Optionally return optimal process grid for tensor_1.
!>        This can be used to choose optimal process grids for subsequent tensor
!>        contractions with tensors of similar shape and sparsity. Under some conditions,
!>        pgrid_opt_1 can not be returned, in this case the pointer is not associated.
!> \param pgrid_opt_2 Optionally return optimal process grid for tensor_2.
!> \param pgrid_opt_3 Optionally return optimal process grid for tensor_3.
!> \param filter_eps As in DBM mm
!> \param flop As in DBM mm
!> \param move_data memory optimization: transfer data such that tensor_1 and tensor_2 are empty on return
!> \param retain_sparsity enforce the sparsity pattern of the existing tensor_3; default is no
!> \param unit_nr output unit for logging
!>        set it to -1 on ranks that should not write (and any valid unit number on
!>        ranks that should write output) if 0 on ALL ranks, no output is written
!> \param log_verbose verbose logging (for testing only)
!> \author Patrick Seewald
! **************************************************************************************************
   SUBROUTINE dbt_contract(alpha, tensor_1, tensor_2, beta, tensor_3, &
                           contract_1, notcontract_1, &
                           contract_2, notcontract_2, &
                           map_1, map_2, &
                           bounds_1, bounds_2, bounds_3, &
                           optimize_dist, pgrid_opt_1, pgrid_opt_2, pgrid_opt_3, &
                           filter_eps, flop, move_data, retain_sparsity, unit_nr, log_verbose)
      REAL(dp), INTENT(IN) :: alpha
      TYPE(dbt_type), INTENT(INOUT), TARGET :: tensor_1
      TYPE(dbt_type), INTENT(INOUT), TARGET :: tensor_2
      REAL(dp), INTENT(IN) :: beta
      INTEGER, DIMENSION(:), INTENT(IN) :: contract_1
      INTEGER, DIMENSION(:), INTENT(IN) :: contract_2
      INTEGER, DIMENSION(:), INTENT(IN) :: map_1
      INTEGER, DIMENSION(:), INTENT(IN) :: map_2
      INTEGER, DIMENSION(:), INTENT(IN) :: notcontract_1
      INTEGER, DIMENSION(:), INTENT(IN) :: notcontract_2
      TYPE(dbt_type), INTENT(INOUT), TARGET :: tensor_3
      INTEGER, DIMENSION(2, SIZE(contract_1)), &
         INTENT(IN), OPTIONAL :: bounds_1
      INTEGER, DIMENSION(2, SIZE(notcontract_1)), &
         INTENT(IN), OPTIONAL :: bounds_2
      INTEGER, DIMENSION(2, SIZE(notcontract_2)), &
         INTENT(IN), OPTIONAL :: bounds_3
      LOGICAL, INTENT(IN), OPTIONAL :: optimize_dist
      TYPE(dbt_pgrid_type), INTENT(OUT), &
         POINTER, OPTIONAL :: pgrid_opt_1
      TYPE(dbt_pgrid_type), INTENT(OUT), &
         POINTER, OPTIONAL :: pgrid_opt_2
      TYPE(dbt_pgrid_type), INTENT(OUT), &
         POINTER, OPTIONAL :: pgrid_opt_3
      REAL(KIND=dp), INTENT(IN), OPTIONAL :: filter_eps
      INTEGER(KIND=int_8), INTENT(OUT), OPTIONAL :: flop
      LOGICAL, INTENT(IN), OPTIONAL :: move_data
      LOGICAL, INTENT(IN), OPTIONAL :: retain_sparsity
      INTEGER, OPTIONAL, INTENT(IN) :: unit_nr
      LOGICAL, INTENT(IN), OPTIONAL :: log_verbose

      INTEGER :: handle

      ! synchronize all ranks of the 2d process grid so the "dbt_total" timer brackets
      ! the collective operation consistently
      CALL tensor_1%pgrid%mp_comm_2d%sync()
      CALL timeset("dbt_total", handle)
      ! thin public wrapper: all arguments are forwarded by keyword to the expert routine
      CALL dbt_contract_expert(alpha, tensor_1, tensor_2, beta, tensor_3, &
                               contract_1, notcontract_1, &
                               contract_2, notcontract_2, &
                               map_1, map_2, &
                               bounds_1=bounds_1, &
                               bounds_2=bounds_2, &
                               bounds_3=bounds_3, &
                               optimize_dist=optimize_dist, &
                               pgrid_opt_1=pgrid_opt_1, &
                               pgrid_opt_2=pgrid_opt_2, &
                               pgrid_opt_3=pgrid_opt_3, &
                               filter_eps=filter_eps, &
                               flop=flop, &
                               move_data=move_data, &
                               retain_sparsity=retain_sparsity, &
                               unit_nr=unit_nr, &
                               log_verbose=log_verbose)
      CALL tensor_1%pgrid%mp_comm_2d%sync()
      CALL timestop(handle)

   END SUBROUTINE
557 :
558 : ! **************************************************************************************************
559 : !> \brief expert routine for tensor contraction. For internal use only.
560 : !> \param nblks_local number of local blocks on this MPI rank
561 : !> \author Patrick Seewald
562 : ! **************************************************************************************************
563 111236 : SUBROUTINE dbt_contract_expert(alpha, tensor_1, tensor_2, beta, tensor_3, &
564 111236 : contract_1, notcontract_1, &
565 111236 : contract_2, notcontract_2, &
566 111236 : map_1, map_2, &
567 111236 : bounds_1, bounds_2, bounds_3, &
568 : optimize_dist, pgrid_opt_1, pgrid_opt_2, pgrid_opt_3, &
569 : filter_eps, flop, move_data, retain_sparsity, &
570 : nblks_local, unit_nr, log_verbose)
571 : REAL(dp), INTENT(IN) :: alpha
572 : TYPE(dbt_type), INTENT(INOUT), TARGET :: tensor_1
573 : TYPE(dbt_type), INTENT(INOUT), TARGET :: tensor_2
574 : REAL(dp), INTENT(IN) :: beta
575 : INTEGER, DIMENSION(:), INTENT(IN) :: contract_1
576 : INTEGER, DIMENSION(:), INTENT(IN) :: contract_2
577 : INTEGER, DIMENSION(:), INTENT(IN) :: map_1
578 : INTEGER, DIMENSION(:), INTENT(IN) :: map_2
579 : INTEGER, DIMENSION(:), INTENT(IN) :: notcontract_1
580 : INTEGER, DIMENSION(:), INTENT(IN) :: notcontract_2
581 : TYPE(dbt_type), INTENT(INOUT), TARGET :: tensor_3
582 : INTEGER, DIMENSION(2, SIZE(contract_1)), &
583 : INTENT(IN), OPTIONAL :: bounds_1
584 : INTEGER, DIMENSION(2, SIZE(notcontract_1)), &
585 : INTENT(IN), OPTIONAL :: bounds_2
586 : INTEGER, DIMENSION(2, SIZE(notcontract_2)), &
587 : INTENT(IN), OPTIONAL :: bounds_3
588 : LOGICAL, INTENT(IN), OPTIONAL :: optimize_dist
589 : TYPE(dbt_pgrid_type), INTENT(OUT), &
590 : POINTER, OPTIONAL :: pgrid_opt_1
591 : TYPE(dbt_pgrid_type), INTENT(OUT), &
592 : POINTER, OPTIONAL :: pgrid_opt_2
593 : TYPE(dbt_pgrid_type), INTENT(OUT), &
594 : POINTER, OPTIONAL :: pgrid_opt_3
595 : REAL(KIND=dp), INTENT(IN), OPTIONAL :: filter_eps
596 : INTEGER(KIND=int_8), INTENT(OUT), OPTIONAL :: flop
597 : LOGICAL, INTENT(IN), OPTIONAL :: move_data
598 : LOGICAL, INTENT(IN), OPTIONAL :: retain_sparsity
599 : INTEGER, INTENT(OUT), OPTIONAL :: nblks_local
600 : INTEGER, OPTIONAL, INTENT(IN) :: unit_nr
601 : LOGICAL, INTENT(IN), OPTIONAL :: log_verbose
602 :
603 : TYPE(dbt_type), POINTER :: tensor_contr_1, tensor_contr_2, tensor_contr_3
604 2113484 : TYPE(dbt_type), TARGET :: tensor_algn_1, tensor_algn_2, tensor_algn_3
605 : TYPE(dbt_type), POINTER :: tensor_crop_1, tensor_crop_2
606 : TYPE(dbt_type), POINTER :: tensor_small, tensor_large
607 :
608 : LOGICAL :: assert_stmt, tensors_remapped
609 : INTEGER :: max_mm_dim, max_tensor, &
610 : unit_nr_prv, ref_tensor, handle
611 111236 : TYPE(mp_cart_type) :: mp_comm_opt
612 222472 : INTEGER, DIMENSION(SIZE(contract_1)) :: contract_1_mod
613 222472 : INTEGER, DIMENSION(SIZE(notcontract_1)) :: notcontract_1_mod
614 222472 : INTEGER, DIMENSION(SIZE(contract_2)) :: contract_2_mod
615 222472 : INTEGER, DIMENSION(SIZE(notcontract_2)) :: notcontract_2_mod
616 222472 : INTEGER, DIMENSION(SIZE(map_1)) :: map_1_mod
617 222472 : INTEGER, DIMENSION(SIZE(map_2)) :: map_2_mod
618 : LOGICAL :: trans_1, trans_2, trans_3
619 : LOGICAL :: new_1, new_2, new_3, move_data_1, move_data_2
620 : INTEGER :: ndims1, ndims2, ndims3
621 : INTEGER :: occ_1, occ_2
622 111236 : INTEGER, DIMENSION(:), ALLOCATABLE :: dims1, dims2, dims3
623 :
624 : CHARACTER(LEN=*), PARAMETER :: routineN = 'dbt_contract'
625 111236 : CHARACTER(LEN=1), DIMENSION(:), ALLOCATABLE :: indchar1, indchar2, indchar3, indchar1_mod, &
626 111236 : indchar2_mod, indchar3_mod
627 : CHARACTER(LEN=1), DIMENSION(15), SAVE :: alph = &
628 : ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o']
629 222472 : INTEGER, DIMENSION(2, ndims_tensor(tensor_1)) :: bounds_t1
630 222472 : INTEGER, DIMENSION(2, ndims_tensor(tensor_2)) :: bounds_t2
631 : LOGICAL :: do_crop_1, do_crop_2, do_write_3, nodata_3, do_batched, pgrid_changed, &
632 : pgrid_changed_any, do_change_pgrid(2)
633 1446068 : TYPE(dbt_tas_split_info) :: split_opt, split, split_opt_avg
634 : INTEGER, DIMENSION(2) :: pdims_2d_opt, pdims_sub, pdims_sub_opt
635 : REAL(dp) :: pdim_ratio, pdim_ratio_opt
636 :
637 111236 : NULLIFY (tensor_contr_1, tensor_contr_2, tensor_contr_3, tensor_crop_1, tensor_crop_2, &
638 111236 : tensor_small)
639 :
640 111236 : CALL timeset(routineN, handle)
641 :
642 111236 : CPASSERT(tensor_1%valid)
643 111236 : CPASSERT(tensor_2%valid)
644 111236 : CPASSERT(tensor_3%valid)
645 :
646 111236 : assert_stmt = SIZE(contract_1) .EQ. SIZE(contract_2)
647 111236 : CPASSERT(assert_stmt)
648 :
649 111236 : assert_stmt = SIZE(map_1) .EQ. SIZE(notcontract_1)
650 111236 : CPASSERT(assert_stmt)
651 :
652 111236 : assert_stmt = SIZE(map_2) .EQ. SIZE(notcontract_2)
653 111236 : CPASSERT(assert_stmt)
654 :
655 111236 : assert_stmt = SIZE(notcontract_1) + SIZE(contract_1) .EQ. ndims_tensor(tensor_1)
656 111236 : CPASSERT(assert_stmt)
657 :
658 111236 : assert_stmt = SIZE(notcontract_2) + SIZE(contract_2) .EQ. ndims_tensor(tensor_2)
659 111236 : CPASSERT(assert_stmt)
660 :
661 111236 : assert_stmt = SIZE(map_1) + SIZE(map_2) .EQ. ndims_tensor(tensor_3)
662 111236 : CPASSERT(assert_stmt)
663 :
664 111236 : unit_nr_prv = prep_output_unit(unit_nr)
665 :
666 111236 : IF (PRESENT(flop)) flop = 0
667 111236 : IF (PRESENT(nblks_local)) nblks_local = 0
668 :
669 111236 : IF (PRESENT(move_data)) THEN
670 30817 : move_data_1 = move_data
671 30817 : move_data_2 = move_data
672 : ELSE
673 80419 : move_data_1 = .FALSE.
674 80419 : move_data_2 = .FALSE.
675 : END IF
676 :
677 111236 : nodata_3 = .TRUE.
678 111236 : IF (PRESENT(retain_sparsity)) THEN
679 4762 : IF (retain_sparsity) nodata_3 = .FALSE.
680 : END IF
681 :
682 : CALL dbt_map_bounds_to_tensors(tensor_1, tensor_2, &
683 : contract_1, notcontract_1, &
684 : contract_2, notcontract_2, &
685 : bounds_t1, bounds_t2, &
686 : bounds_1=bounds_1, bounds_2=bounds_2, bounds_3=bounds_3, &
687 111236 : do_crop_1=do_crop_1, do_crop_2=do_crop_2)
688 :
689 111236 : IF (do_crop_1) THEN
690 476588 : ALLOCATE (tensor_crop_1)
691 68084 : CALL dbt_crop(tensor_1, tensor_crop_1, bounds_t1, move_data=move_data_1)
692 68084 : move_data_1 = .TRUE.
693 : ELSE
694 : tensor_crop_1 => tensor_1
695 : END IF
696 :
697 111236 : IF (do_crop_2) THEN
698 461986 : ALLOCATE (tensor_crop_2)
699 65998 : CALL dbt_crop(tensor_2, tensor_crop_2, bounds_t2, move_data=move_data_2)
700 65998 : move_data_2 = .TRUE.
701 : ELSE
702 : tensor_crop_2 => tensor_2
703 : END IF
704 :
705 : ! shortcut for empty tensors
706 : ! this is needed to avoid unnecessary work in case user contracts different portions of a
707 : ! tensor consecutively to save memory
708 : ASSOCIATE (mp_comm => tensor_crop_1%pgrid%mp_comm_2d)
709 111236 : occ_1 = dbt_get_num_blocks(tensor_crop_1)
710 111236 : CALL mp_comm%max(occ_1)
711 111236 : occ_2 = dbt_get_num_blocks(tensor_crop_2)
712 111236 : CALL mp_comm%max(occ_2)
713 : END ASSOCIATE
714 :
715 111236 : IF (occ_1 == 0 .OR. occ_2 == 0) THEN
716 7447 : CALL dbt_scale(tensor_3, beta)
717 7447 : IF (do_crop_1) THEN
718 2778 : CALL dbt_destroy(tensor_crop_1)
719 2778 : DEALLOCATE (tensor_crop_1)
720 : END IF
721 7447 : IF (do_crop_2) THEN
722 2850 : CALL dbt_destroy(tensor_crop_2)
723 2850 : DEALLOCATE (tensor_crop_2)
724 : END IF
725 :
726 7447 : CALL timestop(handle)
727 7447 : RETURN
728 : END IF
729 :
730 103789 : IF (unit_nr_prv /= 0) THEN
731 46146 : IF (unit_nr_prv > 0) THEN
732 10 : WRITE (unit_nr_prv, '(A)') repeat("-", 80)
733 10 : WRITE (unit_nr_prv, '(A,1X,A,1X,A,1X,A,1X,A,1X,A)') "DBT TENSOR CONTRACTION:", &
734 20 : TRIM(tensor_crop_1%name), 'x', TRIM(tensor_crop_2%name), '=', TRIM(tensor_3%name)
735 10 : WRITE (unit_nr_prv, '(A)') repeat("-", 80)
736 : END IF
737 46146 : CALL dbt_write_tensor_info(tensor_crop_1, unit_nr_prv, full_info=log_verbose)
738 46146 : CALL dbt_write_tensor_dist(tensor_crop_1, unit_nr_prv)
739 46146 : CALL dbt_write_tensor_info(tensor_crop_2, unit_nr_prv, full_info=log_verbose)
740 46146 : CALL dbt_write_tensor_dist(tensor_crop_2, unit_nr_prv)
741 : END IF
742 :
743 : ! align tensor index with data, tensor data is not modified
744 103789 : ndims1 = ndims_tensor(tensor_crop_1)
745 103789 : ndims2 = ndims_tensor(tensor_crop_2)
746 103789 : ndims3 = ndims_tensor(tensor_3)
747 415156 : ALLOCATE (indchar1(ndims1), indchar1_mod(ndims1))
748 415156 : ALLOCATE (indchar2(ndims2), indchar2_mod(ndims2))
749 415156 : ALLOCATE (indchar3(ndims3), indchar3_mod(ndims3))
750 :
751 : ! labeling tensor index with letters
752 :
753 884995 : indchar1([notcontract_1, contract_1]) = alph(1:ndims1) ! arb. choice
754 250775 : indchar2(notcontract_2) = alph(ndims1 + 1:ndims1 + SIZE(notcontract_2)) ! arb. choice
755 244164 : indchar2(contract_2) = indchar1(contract_1)
756 223816 : indchar3(map_1) = indchar1(notcontract_1)
757 250775 : indchar3(map_2) = indchar2(notcontract_2)
758 :
759 103789 : IF (unit_nr_prv /= 0) CALL dbt_print_contraction_index(tensor_crop_1, indchar1, &
760 : tensor_crop_2, indchar2, &
761 46146 : tensor_3, indchar3, unit_nr_prv)
762 103789 : IF (unit_nr_prv > 0) THEN
763 10 : WRITE (unit_nr_prv, '(T2,A)') "aligning tensor index with data"
764 : END IF
765 :
766 : CALL align_tensor(tensor_crop_1, contract_1, notcontract_1, &
767 103789 : tensor_algn_1, contract_1_mod, notcontract_1_mod, indchar1, indchar1_mod)
768 :
769 : CALL align_tensor(tensor_crop_2, contract_2, notcontract_2, &
770 103789 : tensor_algn_2, contract_2_mod, notcontract_2_mod, indchar2, indchar2_mod)
771 :
772 : CALL align_tensor(tensor_3, map_1, map_2, &
773 103789 : tensor_algn_3, map_1_mod, map_2_mod, indchar3, indchar3_mod)
774 :
775 103789 : IF (unit_nr_prv /= 0) CALL dbt_print_contraction_index(tensor_algn_1, indchar1_mod, &
776 : tensor_algn_2, indchar2_mod, &
777 46146 : tensor_algn_3, indchar3_mod, unit_nr_prv)
778 :
779 311367 : ALLOCATE (dims1(ndims1))
780 311367 : ALLOCATE (dims2(ndims2))
781 311367 : ALLOCATE (dims3(ndims3))
782 :
783 : ! ideally we should consider block sizes and occupancy to measure tensor sizes but current solution should work for most
784 : ! cases and is more elegant. Note that we can not easily consider occupancy since it is unknown for result tensor
785 103789 : CALL blk_dims_tensor(tensor_crop_1, dims1)
786 103789 : CALL blk_dims_tensor(tensor_crop_2, dims2)
787 103789 : CALL blk_dims_tensor(tensor_3, dims3)
788 :
789 : max_mm_dim = MAXLOC([PRODUCT(INT(dims1(notcontract_1), int_8)), &
790 : PRODUCT(INT(dims1(contract_1), int_8)), &
791 822544 : PRODUCT(INT(dims2(notcontract_2), int_8))], DIM=1)
792 1229932 : max_tensor = MAXLOC([PRODUCT(INT(dims1, int_8)), PRODUCT(INT(dims2, int_8)), PRODUCT(INT(dims3, int_8))], DIM=1)
793 24126 : SELECT CASE (max_mm_dim)
794 : CASE (1)
795 24126 : IF (unit_nr_prv > 0) THEN
796 3 : WRITE (unit_nr_prv, '(T2,A)') "large tensors: 1, 3; small tensor: 2"
797 3 : WRITE (unit_nr_prv, '(T2,A)') "sorting contraction indices"
798 : END IF
799 24126 : CALL index_linked_sort(contract_1_mod, contract_2_mod)
800 24126 : CALL index_linked_sort(map_2_mod, notcontract_2_mod)
801 23698 : SELECT CASE (max_tensor)
802 : CASE (1)
803 23698 : CALL index_linked_sort(notcontract_1_mod, map_1_mod)
804 : CASE (3)
805 428 : CALL index_linked_sort(map_1_mod, notcontract_1_mod)
806 : CASE DEFAULT
807 24126 : CPABORT("should not happen")
808 : END SELECT
809 :
810 : CALL reshape_mm_compatible(tensor_algn_1, tensor_algn_3, tensor_contr_1, tensor_contr_3, &
811 : contract_1_mod, notcontract_1_mod, map_2_mod, map_1_mod, &
812 : trans_1, trans_3, new_1, new_3, ref_tensor, nodata2=nodata_3, optimize_dist=optimize_dist, &
813 24126 : move_data_1=move_data_1, unit_nr=unit_nr_prv)
814 :
815 : CALL reshape_mm_small(tensor_algn_2, contract_2_mod, notcontract_2_mod, tensor_contr_2, trans_2, &
816 24126 : new_2, move_data=move_data_2, unit_nr=unit_nr_prv)
817 :
818 23698 : SELECT CASE (ref_tensor)
819 : CASE (1)
820 23698 : tensor_large => tensor_contr_1
821 : CASE (2)
822 24126 : tensor_large => tensor_contr_3
823 : END SELECT
824 24126 : tensor_small => tensor_contr_2
825 :
826 : CASE (2)
827 36988 : IF (unit_nr_prv > 0) THEN
828 5 : WRITE (unit_nr_prv, '(T2,A)') "large tensors: 1, 2; small tensor: 3"
829 5 : WRITE (unit_nr_prv, '(T2,A)') "sorting contraction indices"
830 : END IF
831 :
832 36988 : CALL index_linked_sort(notcontract_1_mod, map_1_mod)
833 36988 : CALL index_linked_sort(notcontract_2_mod, map_2_mod)
834 36238 : SELECT CASE (max_tensor)
835 : CASE (1)
836 36238 : CALL index_linked_sort(contract_1_mod, contract_2_mod)
837 : CASE (2)
838 750 : CALL index_linked_sort(contract_2_mod, contract_1_mod)
839 : CASE DEFAULT
840 36988 : CPABORT("should not happen")
841 : END SELECT
842 :
843 : CALL reshape_mm_compatible(tensor_algn_1, tensor_algn_2, tensor_contr_1, tensor_contr_2, &
844 : notcontract_1_mod, contract_1_mod, notcontract_2_mod, contract_2_mod, &
845 : trans_1, trans_2, new_1, new_2, ref_tensor, optimize_dist=optimize_dist, &
846 36988 : move_data_1=move_data_1, move_data_2=move_data_2, unit_nr=unit_nr_prv)
847 36988 : trans_1 = .NOT. trans_1
848 :
849 : CALL reshape_mm_small(tensor_algn_3, map_1_mod, map_2_mod, tensor_contr_3, trans_3, &
850 36988 : new_3, nodata=nodata_3, unit_nr=unit_nr_prv)
851 :
852 36238 : SELECT CASE (ref_tensor)
853 : CASE (1)
854 36238 : tensor_large => tensor_contr_1
855 : CASE (2)
856 36988 : tensor_large => tensor_contr_2
857 : END SELECT
858 36988 : tensor_small => tensor_contr_3
859 :
860 : CASE (3)
861 42675 : IF (unit_nr_prv > 0) THEN
862 2 : WRITE (unit_nr_prv, '(T2,A)') "large tensors: 2, 3; small tensor: 1"
863 2 : WRITE (unit_nr_prv, '(T2,A)') "sorting contraction indices"
864 : END IF
865 42675 : CALL index_linked_sort(map_1_mod, notcontract_1_mod)
866 42675 : CALL index_linked_sort(contract_2_mod, contract_1_mod)
867 42293 : SELECT CASE (max_tensor)
868 : CASE (2)
869 42293 : CALL index_linked_sort(notcontract_2_mod, map_2_mod)
870 : CASE (3)
871 382 : CALL index_linked_sort(map_2_mod, notcontract_2_mod)
872 : CASE DEFAULT
873 42675 : CPABORT("should not happen")
874 : END SELECT
875 :
876 : CALL reshape_mm_compatible(tensor_algn_2, tensor_algn_3, tensor_contr_2, tensor_contr_3, &
877 : contract_2_mod, notcontract_2_mod, map_1_mod, map_2_mod, &
878 : trans_2, trans_3, new_2, new_3, ref_tensor, nodata2=nodata_3, optimize_dist=optimize_dist, &
879 42675 : move_data_1=move_data_2, unit_nr=unit_nr_prv)
880 :
881 42675 : trans_2 = .NOT. trans_2
882 42675 : trans_3 = .NOT. trans_3
883 :
884 : CALL reshape_mm_small(tensor_algn_1, notcontract_1_mod, contract_1_mod, tensor_contr_1, &
885 42675 : trans_1, new_1, move_data=move_data_1, unit_nr=unit_nr_prv)
886 :
887 42293 : SELECT CASE (ref_tensor)
888 : CASE (1)
889 42293 : tensor_large => tensor_contr_2
890 : CASE (2)
891 42675 : tensor_large => tensor_contr_3
892 : END SELECT
893 146464 : tensor_small => tensor_contr_1
894 :
895 : END SELECT
896 :
897 103789 : IF (unit_nr_prv /= 0) CALL dbt_print_contraction_index(tensor_contr_1, indchar1_mod, &
898 : tensor_contr_2, indchar2_mod, &
899 46146 : tensor_contr_3, indchar3_mod, unit_nr_prv)
900 103789 : IF (unit_nr_prv /= 0) THEN
901 46146 : IF (new_1) CALL dbt_write_tensor_info(tensor_contr_1, unit_nr_prv, full_info=log_verbose)
902 46146 : IF (new_1) CALL dbt_write_tensor_dist(tensor_contr_1, unit_nr_prv)
903 46146 : IF (new_2) CALL dbt_write_tensor_info(tensor_contr_2, unit_nr_prv, full_info=log_verbose)
904 46146 : IF (new_2) CALL dbt_write_tensor_dist(tensor_contr_2, unit_nr_prv)
905 : END IF
906 :
907 : CALL dbt_tas_multiply(trans_1, trans_2, trans_3, alpha, &
908 : tensor_contr_1%matrix_rep, tensor_contr_2%matrix_rep, &
909 : beta, &
910 : tensor_contr_3%matrix_rep, filter_eps=filter_eps, flop=flop, &
911 : unit_nr=unit_nr_prv, log_verbose=log_verbose, &
912 : split_opt=split_opt, &
913 103789 : move_data_a=move_data_1, move_data_b=move_data_2, retain_sparsity=retain_sparsity)
914 :
915 103789 : IF (PRESENT(pgrid_opt_1)) THEN
916 0 : IF (.NOT. new_1) THEN
917 0 : ALLOCATE (pgrid_opt_1)
918 0 : pgrid_opt_1 = opt_pgrid(tensor_1, split_opt)
919 : END IF
920 : END IF
921 :
922 103789 : IF (PRESENT(pgrid_opt_2)) THEN
923 0 : IF (.NOT. new_2) THEN
924 0 : ALLOCATE (pgrid_opt_2)
925 0 : pgrid_opt_2 = opt_pgrid(tensor_2, split_opt)
926 : END IF
927 : END IF
928 :
929 103789 : IF (PRESENT(pgrid_opt_3)) THEN
930 0 : IF (.NOT. new_3) THEN
931 0 : ALLOCATE (pgrid_opt_3)
932 0 : pgrid_opt_3 = opt_pgrid(tensor_3, split_opt)
933 : END IF
934 : END IF
935 :
936 103789 : do_batched = tensor_small%matrix_rep%do_batched > 0
937 :
938 103789 : tensors_remapped = .FALSE.
939 103789 : IF (new_1 .OR. new_2 .OR. new_3) tensors_remapped = .TRUE.
940 :
941 103789 : IF (tensors_remapped .AND. do_batched) THEN
942 : CALL cp_warn(__LOCATION__, &
943 0 : "Internal process grid optimization disabled because tensors are not in contraction-compatible format")
944 : END IF
945 :
946 : ! optimize process grid during batched contraction
947 103789 : do_change_pgrid(:) = .FALSE.
948 103789 : IF ((.NOT. tensors_remapped) .AND. do_batched) THEN
949 559419 : ASSOCIATE (storage => tensor_small%contraction_storage)
950 0 : CPASSERT(storage%static)
951 79917 : split = dbt_tas_info(tensor_large%matrix_rep)
952 : do_change_pgrid(:) = &
953 79917 : update_contraction_storage(storage, split_opt, split)
954 :
955 798206 : IF (ANY(do_change_pgrid)) THEN
956 964 : mp_comm_opt = dbt_tas_mp_comm(tensor_small%pgrid%mp_comm_2d, split_opt%split_rowcol, NINT(storage%nsplit_avg))
957 : CALL dbt_tas_create_split(split_opt_avg, mp_comm_opt, split_opt%split_rowcol, &
958 964 : NINT(storage%nsplit_avg), own_comm=.TRUE.)
959 2892 : pdims_2d_opt = split_opt_avg%mp_comm%num_pe_cart
960 : END IF
961 :
962 : END ASSOCIATE
963 :
964 79917 : IF (do_change_pgrid(1) .AND. .NOT. do_change_pgrid(2)) THEN
965 : ! check if new grid has better subgrid, if not there is no need to change process grid
966 2892 : pdims_sub_opt = split_opt_avg%mp_comm_group%num_pe_cart
967 2892 : pdims_sub = split%mp_comm_group%num_pe_cart
968 :
969 4820 : pdim_ratio = MAXVAL(REAL(pdims_sub, dp))/MINVAL(pdims_sub)
970 4820 : pdim_ratio_opt = MAXVAL(REAL(pdims_sub_opt, dp))/MINVAL(pdims_sub_opt)
971 964 : IF (pdim_ratio/pdim_ratio_opt <= default_pdims_accept_ratio**2) THEN
972 0 : do_change_pgrid(1) = .FALSE.
973 0 : CALL dbt_tas_release_info(split_opt_avg)
974 : END IF
975 : END IF
976 : END IF
977 :
978 103789 : IF (unit_nr_prv /= 0) THEN
979 46146 : do_write_3 = .TRUE.
980 46146 : IF (tensor_contr_3%matrix_rep%do_batched > 0) THEN
981 20424 : IF (tensor_contr_3%matrix_rep%mm_storage%batched_out) do_write_3 = .FALSE.
982 : END IF
983 : IF (do_write_3) THEN
984 25760 : CALL dbt_write_tensor_info(tensor_contr_3, unit_nr_prv, full_info=log_verbose)
985 25760 : CALL dbt_write_tensor_dist(tensor_contr_3, unit_nr_prv)
986 : END IF
987 : END IF
988 :
989 103789 : IF (new_3) THEN
990 : ! need redistribute if we created new tensor for tensor 3
991 188 : CALL dbt_scale(tensor_algn_3, beta)
992 188 : CALL dbt_copy_expert(tensor_contr_3, tensor_algn_3, summation=.TRUE., move_data=.TRUE.)
993 188 : IF (PRESENT(filter_eps)) CALL dbt_filter(tensor_algn_3, filter_eps)
994 : ! tensor_3 automatically has correct data because tensor_algn_3 contains a matrix
995 : ! pointer to data of tensor_3
996 : END IF
997 :
998 : ! transfer contraction storage
999 103789 : CALL dbt_copy_contraction_storage(tensor_contr_1, tensor_1)
1000 103789 : CALL dbt_copy_contraction_storage(tensor_contr_2, tensor_2)
1001 103789 : CALL dbt_copy_contraction_storage(tensor_contr_3, tensor_3)
1002 :
1003 103789 : IF (unit_nr_prv /= 0) THEN
1004 46146 : IF (new_3 .AND. do_write_3) CALL dbt_write_tensor_info(tensor_3, unit_nr_prv, full_info=log_verbose)
1005 46146 : IF (new_3 .AND. do_write_3) CALL dbt_write_tensor_dist(tensor_3, unit_nr_prv)
1006 : END IF
1007 :
1008 103789 : CALL dbt_destroy(tensor_algn_1)
1009 103789 : CALL dbt_destroy(tensor_algn_2)
1010 103789 : CALL dbt_destroy(tensor_algn_3)
1011 :
1012 103789 : IF (do_crop_1) THEN
1013 65306 : CALL dbt_destroy(tensor_crop_1)
1014 65306 : DEALLOCATE (tensor_crop_1)
1015 : END IF
1016 :
1017 103789 : IF (do_crop_2) THEN
1018 63148 : CALL dbt_destroy(tensor_crop_2)
1019 63148 : DEALLOCATE (tensor_crop_2)
1020 : END IF
1021 :
1022 103789 : IF (new_1) THEN
1023 202 : CALL dbt_destroy(tensor_contr_1)
1024 202 : DEALLOCATE (tensor_contr_1)
1025 : END IF
1026 103789 : IF (new_2) THEN
1027 80 : CALL dbt_destroy(tensor_contr_2)
1028 80 : DEALLOCATE (tensor_contr_2)
1029 : END IF
1030 103789 : IF (new_3) THEN
1031 188 : CALL dbt_destroy(tensor_contr_3)
1032 188 : DEALLOCATE (tensor_contr_3)
1033 : END IF
1034 :
1035 103789 : IF (PRESENT(move_data)) THEN
1036 29876 : IF (move_data) THEN
1037 26052 : CALL dbt_clear(tensor_1)
1038 26052 : CALL dbt_clear(tensor_2)
1039 : END IF
1040 : END IF
1041 :
1042 103789 : IF (unit_nr_prv > 0) THEN
1043 10 : WRITE (unit_nr_prv, '(A)') repeat("-", 80)
1044 10 : WRITE (unit_nr_prv, '(A)') "TENSOR CONTRACTION DONE"
1045 10 : WRITE (unit_nr_prv, '(A)') repeat("-", 80)
1046 : END IF
1047 :
1048 309439 : IF (ANY(do_change_pgrid)) THEN
1049 964 : pgrid_changed_any = .FALSE.
1050 264 : SELECT CASE (max_mm_dim)
1051 : CASE (1)
1052 264 : IF (ALLOCATED(tensor_1%contraction_storage) .AND. ALLOCATED(tensor_3%contraction_storage)) THEN
1053 : CALL dbt_change_pgrid_2d(tensor_1, tensor_1%pgrid%mp_comm_2d, pdims=pdims_2d_opt, &
1054 : nsplit=split_opt_avg%ngroup, dimsplit=split_opt_avg%split_rowcol, &
1055 : pgrid_changed=pgrid_changed, &
1056 0 : unit_nr=unit_nr_prv)
1057 0 : IF (pgrid_changed) pgrid_changed_any = .TRUE.
1058 : CALL dbt_change_pgrid_2d(tensor_3, tensor_3%pgrid%mp_comm_2d, pdims=pdims_2d_opt, &
1059 : nsplit=split_opt_avg%ngroup, dimsplit=split_opt_avg%split_rowcol, &
1060 : pgrid_changed=pgrid_changed, &
1061 0 : unit_nr=unit_nr_prv)
1062 0 : IF (pgrid_changed) pgrid_changed_any = .TRUE.
1063 : END IF
1064 0 : IF (pgrid_changed_any) THEN
1065 0 : IF (tensor_2%matrix_rep%do_batched == 3) THEN
1066 : ! set flag that process grid has been optimized to make sure that no grid optimizations are done
1067 : ! in TAS multiply algorithm
1068 0 : CALL dbt_tas_batched_mm_complete(tensor_2%matrix_rep)
1069 : END IF
1070 : END IF
1071 : CASE (2)
1072 172 : IF (ALLOCATED(tensor_1%contraction_storage) .AND. ALLOCATED(tensor_2%contraction_storage)) THEN
1073 : CALL dbt_change_pgrid_2d(tensor_1, tensor_1%pgrid%mp_comm_2d, pdims=pdims_2d_opt, &
1074 : nsplit=split_opt_avg%ngroup, dimsplit=split_opt_avg%split_rowcol, &
1075 : pgrid_changed=pgrid_changed, &
1076 172 : unit_nr=unit_nr_prv)
1077 172 : IF (pgrid_changed) pgrid_changed_any = .TRUE.
1078 : CALL dbt_change_pgrid_2d(tensor_2, tensor_2%pgrid%mp_comm_2d, pdims=pdims_2d_opt, &
1079 : nsplit=split_opt_avg%ngroup, dimsplit=split_opt_avg%split_rowcol, &
1080 : pgrid_changed=pgrid_changed, &
1081 172 : unit_nr=unit_nr_prv)
1082 172 : IF (pgrid_changed) pgrid_changed_any = .TRUE.
1083 : END IF
1084 8 : IF (pgrid_changed_any) THEN
1085 172 : IF (tensor_3%matrix_rep%do_batched == 3) THEN
1086 160 : CALL dbt_tas_batched_mm_complete(tensor_3%matrix_rep)
1087 : END IF
1088 : END IF
1089 : CASE (3)
1090 528 : IF (ALLOCATED(tensor_2%contraction_storage) .AND. ALLOCATED(tensor_3%contraction_storage)) THEN
1091 : CALL dbt_change_pgrid_2d(tensor_2, tensor_2%pgrid%mp_comm_2d, pdims=pdims_2d_opt, &
1092 : nsplit=split_opt_avg%ngroup, dimsplit=split_opt_avg%split_rowcol, &
1093 : pgrid_changed=pgrid_changed, &
1094 214 : unit_nr=unit_nr_prv)
1095 214 : IF (pgrid_changed) pgrid_changed_any = .TRUE.
1096 : CALL dbt_change_pgrid_2d(tensor_3, tensor_3%pgrid%mp_comm_2d, pdims=pdims_2d_opt, &
1097 : nsplit=split_opt_avg%ngroup, dimsplit=split_opt_avg%split_rowcol, &
1098 : pgrid_changed=pgrid_changed, &
1099 214 : unit_nr=unit_nr_prv)
1100 214 : IF (pgrid_changed) pgrid_changed_any = .TRUE.
1101 : END IF
1102 964 : IF (pgrid_changed_any) THEN
1103 214 : IF (tensor_1%matrix_rep%do_batched == 3) THEN
1104 214 : CALL dbt_tas_batched_mm_complete(tensor_1%matrix_rep)
1105 : END IF
1106 : END IF
1107 : END SELECT
1108 964 : CALL dbt_tas_release_info(split_opt_avg)
1109 : END IF
1110 :
1111 103789 : IF ((.NOT. tensors_remapped) .AND. do_batched) THEN
1112 : ! freeze TAS process grids if tensor grids were optimized
1113 79917 : CALL dbt_tas_set_batched_state(tensor_1%matrix_rep, opt_grid=.TRUE.)
1114 79917 : CALL dbt_tas_set_batched_state(tensor_2%matrix_rep, opt_grid=.TRUE.)
1115 79917 : CALL dbt_tas_set_batched_state(tensor_3%matrix_rep, opt_grid=.TRUE.)
1116 : END IF
1117 :
1118 103789 : CALL dbt_tas_release_info(split_opt)
1119 :
1120 103789 : CALL timestop(handle)
1121 :
1122 222472 : END SUBROUTINE
1123 :
1124 : ! **************************************************************************************************
1125 : !> \brief align tensor index with data
 : !> Creates tensor_out as a permuted view of tensor_in (via dbt_align_index) and remaps
 : !> the caller's contraction bookkeeping to the new index order.
 : !> \param tensor_in tensor whose index order is to be aligned with its data layout
 : !> \param contract_in contraction indices w.r.t. the index order of tensor_in
 : !> \param notcontract_in non-contraction indices w.r.t. the index order of tensor_in
 : !> \param tensor_out aligned tensor (created here by dbt_align_index)
 : !> \param contract_out contraction indices remapped to the index order of tensor_out
 : !> \param notcontract_out non-contraction indices remapped to the index order of tensor_out
 : !> \param indp_in one-letter index labels of tensor_in (one per dimension)
 : !> \param indp_out the same labels permuted to the index order of tensor_out
1126 : !> \author Patrick Seewald
1127 : ! **************************************************************************************************
1128 2490936 : SUBROUTINE align_tensor(tensor_in, contract_in, notcontract_in, &
1129 311367 : tensor_out, contract_out, notcontract_out, indp_in, indp_out)
1130 : TYPE(dbt_type), INTENT(INOUT) :: tensor_in
1131 : INTEGER, DIMENSION(:), INTENT(IN) :: contract_in, notcontract_in
1132 : TYPE(dbt_type), INTENT(OUT) :: tensor_out
1133 : INTEGER, DIMENSION(SIZE(contract_in)), &
1134 : INTENT(OUT) :: contract_out
1135 : INTEGER, DIMENSION(SIZE(notcontract_in)), &
1136 : INTENT(OUT) :: notcontract_out
1137 : CHARACTER(LEN=1), DIMENSION(ndims_tensor(tensor_in)), INTENT(IN) :: indp_in
1138 : CHARACTER(LEN=1), DIMENSION(ndims_tensor(tensor_in)), INTENT(OUT) :: indp_out
 : ! align(i) = position of tensor_in dimension i in tensor_out (permutation returned by dbt_align_index)
1139 311367 : INTEGER, DIMENSION(ndims_tensor(tensor_in)) :: align
1140 :
1141 311367 : CALL dbt_align_index(tensor_in, tensor_out, order=align)
 : ! translate index lists through the permutation so they refer to tensor_out's dimension order
1142 712144 : contract_out = align(contract_in)
1143 725366 : notcontract_out = align(notcontract_in)
 : ! scatter labels: label of old dimension i lands at its new position align(i)
1144 1126143 : indp_out(align) = indp_in
1145 :
1146 311367 : END SUBROUTINE
1147 :
1148 : ! **************************************************************************************************
1149 : !> \brief Prepare tensor for contraction: redistribute to a 2d format which can be contracted by
1150 : !> matrix multiplication. This routine reshapes the two largest of the three tensors.
1151 : !> Redistribution is avoided if tensors already in a consistent layout.
 : !> \param tensor1 first tensor to make contraction compatible
 : !> \param tensor2 second tensor to make contraction compatible
 : !> \param tensor1_out matrix-compatible representation of tensor1 (alias of tensor1 if no remap needed)
 : !> \param tensor2_out matrix-compatible representation of tensor2 (alias of tensor2 if no remap needed)
1152 : !> \param ind1_free indices of tensor 1 that are "free" (not linked to any index of tensor 2)
1153 : !> \param ind1_linked indices of tensor 1 that are linked to indices of tensor 2
1154 : !> 1:1 correspondence with ind2_linked
 : !> \param ind2_free indices of tensor 2 that are "free" (not linked to any index of tensor 1)
 : !> \param ind2_linked indices of tensor 2 that are linked to indices of tensor 1
1155 : !> \param trans1 transpose flag of matrix rep. tensor 1
1156 : !> \param trans2 transpose flag of matrix rep. tensor 2
1157 : !> \param new1 whether a new tensor 1 was created
1158 : !> \param new2 whether a new tensor 2 was created
 : !> \param ref_tensor which tensor serves as reference layout: 1 if tensor1 is at least as large
 : !> (by block-dimension product) as tensor2, else 2; the other tensor is reshaped to match it
1159 : !> \param nodata1 don't copy data of tensor 1
1160 : !> \param nodata2 don't copy data of tensor 2
1161 : !> \param move_data_1 memory optimization: transfer data s.t. tensor1 may be empty on return
1162 : !> \param move_data_2 memory optimization: transfer data s.t. tensor2 may be empty on return
1163 : !> \param optimize_dist experimental: optimize distribution
1164 : !> \param unit_nr output unit
1165 : !> \author Patrick Seewald
1166 : ! **************************************************************************************************
1167 103789 : SUBROUTINE reshape_mm_compatible(tensor1, tensor2, tensor1_out, tensor2_out, ind1_free, ind1_linked, &
1168 103789 : ind2_free, ind2_linked, trans1, trans2, new1, new2, ref_tensor, &
1169 : nodata1, nodata2, move_data_1, &
1170 : move_data_2, optimize_dist, unit_nr)
1171 : TYPE(dbt_type), TARGET, INTENT(INOUT) :: tensor1
1172 : TYPE(dbt_type), TARGET, INTENT(INOUT) :: tensor2
1173 : TYPE(dbt_type), POINTER, INTENT(OUT) :: tensor1_out, tensor2_out
1174 : INTEGER, DIMENSION(:), INTENT(IN) :: ind1_free, ind2_free
1175 : INTEGER, DIMENSION(:), INTENT(IN) :: ind1_linked, ind2_linked
1176 : LOGICAL, INTENT(OUT) :: trans1, trans2
1177 : LOGICAL, INTENT(OUT) :: new1, new2
1178 : INTEGER, INTENT(OUT) :: ref_tensor
1179 : LOGICAL, INTENT(IN), OPTIONAL :: nodata1, nodata2
1180 : LOGICAL, INTENT(INOUT), OPTIONAL :: move_data_1, move_data_2
1181 : LOGICAL, INTENT(IN), OPTIONAL :: optimize_dist
1182 : INTEGER, INTENT(IN), OPTIONAL :: unit_nr
 : ! compat codes (from compat_map, see SELECT CASE printouts below):
 : ! 0 = index mapping not contraction compatible, 1 = "Normal" (linked index is first
 : ! 2d dimension), 2 = "Transposed" (linked index is second 2d dimension)
1183 : INTEGER :: compat1, compat1_old, compat2, compat2_old, &
1184 : unit_nr_prv
1185 103789 : TYPE(mp_cart_type) :: comm_2d
1186 103789 : TYPE(array_list) :: dist_list
1187 103789 : INTEGER, DIMENSION(:), ALLOCATABLE :: mp_dims
1188 726523 : TYPE(dbt_distribution_type) :: dist_in
1189 : INTEGER(KIND=int_8) :: nblkrows, nblkcols
1190 : LOGICAL :: optimize_dist_prv
1191 207578 : INTEGER, DIMENSION(ndims_tensor(tensor1)) :: dims1
1192 207578 : INTEGER, DIMENSION(ndims_tensor(tensor2)) :: dims2
1193 :
1194 103789 : NULLIFY (tensor1_out, tensor2_out)
1195 :
1196 103789 : unit_nr_prv = prep_output_unit(unit_nr)
1197 :
1198 103789 : CALL blk_dims_tensor(tensor1, dims1)
1199 103789 : CALL blk_dims_tensor(tensor2, dims2)
1200 :
 : ! the larger tensor (by total number of blocks) keeps its layout; ties favor tensor 1
1201 709535 : IF (PRODUCT(int(dims1, int_8)) .GE. PRODUCT(int(dims2, int_8))) THEN
1202 102229 : ref_tensor = 1
1203 : ELSE
1204 1560 : ref_tensor = 2
1205 : END IF
1206 :
1207 103789 : IF (PRESENT(optimize_dist)) THEN
1208 182 : optimize_dist_prv = optimize_dist
1209 : ELSE
1210 : optimize_dist_prv = .FALSE.
1211 : END IF
1212 :
1213 103789 : compat1 = compat_map(tensor1%nd_index, ind1_linked)
1214 103789 : compat2 = compat_map(tensor2%nd_index, ind2_linked)
 : ! keep originals so we only re-print compatibility below if it changed
1215 103789 : compat1_old = compat1
1216 103789 : compat2_old = compat2
1217 :
1218 103789 : IF (unit_nr_prv > 0) THEN
1219 10 : WRITE (unit_nr_prv, '(T2,A,1X,A,A,1X)', advance='no') "compatibility of", TRIM(tensor1%name), ":"
1220 6 : SELECT CASE (compat1)
1221 : CASE (0)
1222 6 : WRITE (unit_nr_prv, '(A)') "Not compatible"
1223 : CASE (1)
1224 3 : WRITE (unit_nr_prv, '(A)') "Normal"
1225 : CASE (2)
1226 10 : WRITE (unit_nr_prv, '(A)') "Transposed"
1227 : END SELECT
1228 10 : WRITE (unit_nr_prv, '(T2,A,1X,A,A,1X)', advance='no') "compatibility of", TRIM(tensor2%name), ":"
1229 5 : SELECT CASE (compat2)
1230 : CASE (0)
1231 5 : WRITE (unit_nr_prv, '(A)') "Not compatible"
1232 : CASE (1)
1233 4 : WRITE (unit_nr_prv, '(A)') "Normal"
1234 : CASE (2)
1235 10 : WRITE (unit_nr_prv, '(A)') "Transposed"
1236 : END SELECT
1237 : END IF
1238 :
1239 103789 : new1 = .FALSE.
1240 103789 : new2 = .FALSE.
1241 :
 : ! a tensor is remapped (new* = .TRUE.) when incompatible or when distribution optimization is forced
1242 103789 : IF (compat1 == 0 .OR. optimize_dist_prv) THEN
1243 194 : new1 = .TRUE.
1244 : END IF
1245 :
1246 103789 : IF (compat2 == 0 .OR. optimize_dist_prv) THEN
1247 254 : new2 = .TRUE.
1248 : END IF
1249 :
1250 103789 : IF (ref_tensor == 1) THEN ! tensor 1 is reference and tensor 2 is reshaped compatible with tensor 1
1251 102229 : IF (compat1 == 0 .OR. optimize_dist_prv) THEN ! tensor 1 is not contraction compatible --> reshape
1252 114 : IF (unit_nr_prv > 0) WRITE (unit_nr_prv, '(T2,A,1X,A)') "Redistribution of", TRIM(tensor1%name)
1253 342 : nblkrows = PRODUCT(INT(dims1(ind1_linked), KIND=int_8))
1254 230 : nblkcols = PRODUCT(INT(dims1(ind1_free), KIND=int_8))
1255 114 : comm_2d = dbt_tas_mp_comm(tensor1%pgrid%mp_comm_2d, nblkrows, nblkcols)
1256 798 : ALLOCATE (tensor1_out)
1257 : CALL dbt_remap(tensor1, ind1_linked, ind1_free, tensor1_out, comm_2d=comm_2d, &
1258 114 : nodata=nodata1, move_data=move_data_1)
1259 114 : CALL comm_2d%free()
1260 114 : compat1 = 1
1261 : ELSE
1262 102115 : IF (unit_nr_prv > 0) WRITE (unit_nr_prv, '(T2,A,1X,A)') "No redistribution of", TRIM(tensor1%name)
1263 102115 : tensor1_out => tensor1
1264 : END IF
1265 102229 : IF (compat2 == 0 .OR. optimize_dist_prv) THEN ! tensor 2 is not contraction compatible --> reshape
1266 174 : IF (unit_nr_prv > 0) WRITE (unit_nr_prv, '(T2,A,1X,A,1X,A,1X,A)') "Redistribution of", &
1267 8 : TRIM(tensor2%name), "compatible with", TRIM(tensor1%name)
1268 170 : dist_in = dbt_distribution(tensor1_out)
1269 170 : dist_list = array_sublist(dist_in%nd_dist, ind1_linked)
1270 170 : IF (compat1 == 1) THEN ! linked index is first 2d dimension
1271 : ! get distribution of linked index, tensor 2 must adopt this distribution
1272 : ! get grid dimensions of linked index
1273 480 : ALLOCATE (mp_dims(ndims_mapping_row(dist_in%pgrid%nd_index_grid)))
1274 160 : CALL dbt_get_mapping_info(dist_in%pgrid%nd_index_grid, dims1_2d=mp_dims)
1275 1120 : ALLOCATE (tensor2_out)
1276 : CALL dbt_remap(tensor2, ind2_linked, ind2_free, tensor2_out, comm_2d=dist_in%pgrid%mp_comm_2d, &
1277 160 : dist1=dist_list, mp_dims_1=mp_dims, nodata=nodata2, move_data=move_data_2)
1278 10 : ELSEIF (compat1 == 2) THEN ! linked index is second 2d dimension
1279 : ! get distribution of linked index, tensor 2 must adopt this distribution
1280 : ! get grid dimensions of linked index
1281 30 : ALLOCATE (mp_dims(ndims_mapping_column(dist_in%pgrid%nd_index_grid)))
1282 10 : CALL dbt_get_mapping_info(dist_in%pgrid%nd_index_grid, dims2_2d=mp_dims)
1283 70 : ALLOCATE (tensor2_out)
1284 : CALL dbt_remap(tensor2, ind2_free, ind2_linked, tensor2_out, comm_2d=dist_in%pgrid%mp_comm_2d, &
1285 10 : dist2=dist_list, mp_dims_2=mp_dims, nodata=nodata2, move_data=move_data_2)
1286 : ELSE
1287 0 : CPABORT("should not happen")
1288 : END IF
1289 : compat2 = compat1
1290 : ELSE
1291 102059 : IF (unit_nr_prv > 0) WRITE (unit_nr_prv, '(T2,A,1X,A)') "No redistribution of", TRIM(tensor2%name)
1292 102059 : tensor2_out => tensor2
1293 : END IF
1294 : ELSE ! tensor 2 is reference and tensor 1 is reshaped compatible with tensor 2
1295 1560 : IF (compat2 == 0 .OR. optimize_dist_prv) THEN ! tensor 2 is not contraction compatible --> reshape
1296 84 : IF (unit_nr_prv > 0) WRITE (unit_nr_prv, '(T2,A,1X,A)') "Redistribution of", TRIM(tensor2%name)
1297 178 : nblkrows = PRODUCT(INT(dims2(ind2_linked), KIND=int_8))
1298 170 : nblkcols = PRODUCT(INT(dims2(ind2_free), KIND=int_8))
1299 84 : comm_2d = dbt_tas_mp_comm(tensor2%pgrid%mp_comm_2d, nblkrows, nblkcols)
1300 588 : ALLOCATE (tensor2_out)
 : ! NOTE(review): comm_2d is created above and freed below but, unlike the symmetric
 : ! ref_tensor == 1 branch (src line 1257), is NOT passed to dbt_remap here -- presumably
 : ! comm_2d=comm_2d was intended; verify against dbt_remap's default communicator behavior
1301 84 : CALL dbt_remap(tensor2, ind2_linked, ind2_free, tensor2_out, nodata=nodata2, move_data=move_data_2)
1302 84 : CALL comm_2d%free()
1303 84 : compat2 = 1
1304 : ELSE
1305 1476 : IF (unit_nr_prv > 0) WRITE (unit_nr_prv, '(T2,A,1X,A)') "No redistribution of", TRIM(tensor2%name)
1306 1476 : tensor2_out => tensor2
1307 : END IF
1308 1560 : IF (compat1 == 0 .OR. optimize_dist_prv) THEN ! tensor 1 is not contraction compatible --> reshape
1309 83 : IF (unit_nr_prv > 0) WRITE (unit_nr_prv, '(T2,A,1X,A,1X,A,1X,A)') "Redistribution of", TRIM(tensor1%name), &
1310 6 : "compatible with", TRIM(tensor2%name)
1311 80 : dist_in = dbt_distribution(tensor2_out)
1312 80 : dist_list = array_sublist(dist_in%nd_dist, ind2_linked)
1313 80 : IF (compat2 == 1) THEN
1314 234 : ALLOCATE (mp_dims(ndims_mapping_row(dist_in%pgrid%nd_index_grid)))
1315 78 : CALL dbt_get_mapping_info(dist_in%pgrid%nd_index_grid, dims1_2d=mp_dims)
1316 546 : ALLOCATE (tensor1_out)
1317 : CALL dbt_remap(tensor1, ind1_linked, ind1_free, tensor1_out, comm_2d=dist_in%pgrid%mp_comm_2d, &
1318 78 : dist1=dist_list, mp_dims_1=mp_dims, nodata=nodata1, move_data=move_data_1)
1319 2 : ELSEIF (compat2 == 2) THEN
1320 6 : ALLOCATE (mp_dims(ndims_mapping_column(dist_in%pgrid%nd_index_grid)))
1321 2 : CALL dbt_get_mapping_info(dist_in%pgrid%nd_index_grid, dims2_2d=mp_dims)
1322 14 : ALLOCATE (tensor1_out)
1323 : CALL dbt_remap(tensor1, ind1_free, ind1_linked, tensor1_out, comm_2d=dist_in%pgrid%mp_comm_2d, &
1324 2 : dist2=dist_list, mp_dims_2=mp_dims, nodata=nodata1, move_data=move_data_1)
1325 : ELSE
1326 0 : CPABORT("should not happen")
1327 : END IF
1328 : compat1 = compat2
1329 : ELSE
1330 1480 : IF (unit_nr_prv > 0) WRITE (unit_nr_prv, '(T2,A,1X,A)') "No redistribution of", TRIM(tensor1%name)
1331 1480 : tensor1_out => tensor1
1332 : END IF
1333 : END IF
1334 :
 : ! derive matrix transpose flags from final compatibility codes (0 is impossible here)
1335 66965 : SELECT CASE (compat1)
1336 : CASE (1)
1337 66965 : trans1 = .FALSE.
1338 : CASE (2)
1339 36824 : trans1 = .TRUE.
1340 : CASE DEFAULT
1341 103789 : CPABORT("should not happen")
1342 : END SELECT
1343 :
1344 69776 : SELECT CASE (compat2)
1345 : CASE (1)
1346 69776 : trans2 = .FALSE.
1347 : CASE (2)
1348 34013 : trans2 = .TRUE.
1349 : CASE DEFAULT
1350 103789 : CPABORT("should not happen")
1351 : END SELECT
1352 :
 : ! report only the compatibilities that changed due to remapping
1353 103789 : IF (unit_nr_prv > 0) THEN
1354 10 : IF (compat1 .NE. compat1_old) THEN
1355 6 : WRITE (unit_nr_prv, '(T2,A,1X,A,A,1X)', advance='no') "compatibility of", TRIM(tensor1_out%name), ":"
1356 0 : SELECT CASE (compat1)
1357 : CASE (0)
1358 0 : WRITE (unit_nr_prv, '(A)') "Not compatible"
1359 : CASE (1)
1360 5 : WRITE (unit_nr_prv, '(A)') "Normal"
1361 : CASE (2)
1362 6 : WRITE (unit_nr_prv, '(A)') "Transposed"
1363 : END SELECT
1364 : END IF
1365 10 : IF (compat2 .NE. compat2_old) THEN
1366 5 : WRITE (unit_nr_prv, '(T2,A,1X,A,A,1X)', advance='no') "compatibility of", TRIM(tensor2_out%name), ":"
1367 0 : SELECT CASE (compat2)
1368 : CASE (0)
1369 0 : WRITE (unit_nr_prv, '(A)') "Not compatible"
1370 : CASE (1)
1371 4 : WRITE (unit_nr_prv, '(A)') "Normal"
1372 : CASE (2)
1373 5 : WRITE (unit_nr_prv, '(A)') "Transposed"
1374 : END SELECT
1375 : END IF
1376 : END IF
1377 :
 : ! a remap copied the data out of the original tensor, so the caller's tensor may be emptied
1378 103789 : IF (new1 .AND. PRESENT(move_data_1)) move_data_1 = .TRUE.
1379 103789 : IF (new2 .AND. PRESENT(move_data_2)) move_data_2 = .TRUE.
1380 :
1381 103789 : END SUBROUTINE
1382 :
1383 : ! **************************************************************************************************
1384 : !> \brief Prepare tensor for contraction: redistribute to a 2d format which can be contracted by
1385 : !> matrix multiplication. This routine reshapes the smallest of the three tensors.
1386 : !> \param ind1 index that should be mapped to first matrix dimension
1387 : !> \param ind2 index that should be mapped to second matrix dimension
1388 : !> \param trans transpose flag of matrix rep.
1389 : !> \param new whether a new tensor was created for tensor_out
1390 : !> \param nodata don't copy tensor data
1391 : !> \param move_data memory optimization: transfer data s.t. tensor_in may be empty on return
1392 : !> \param unit_nr output unit
1393 : !> \author Patrick Seewald
1394 : ! **************************************************************************************************
1395 103789 : SUBROUTINE reshape_mm_small(tensor_in, ind1, ind2, tensor_out, trans, new, nodata, move_data, unit_nr)
1396 : TYPE(dbt_type), TARGET, INTENT(INOUT) :: tensor_in
1397 : INTEGER, DIMENSION(:), INTENT(IN) :: ind1, ind2
1398 : TYPE(dbt_type), POINTER, INTENT(OUT) :: tensor_out
1399 : LOGICAL, INTENT(OUT) :: trans
1400 : LOGICAL, INTENT(OUT) :: new
1401 : LOGICAL, INTENT(IN), OPTIONAL :: nodata, move_data
1402 : INTEGER, INTENT(IN), OPTIONAL :: unit_nr
1403 : INTEGER :: compat1, compat2, compat1_old, compat2_old, unit_nr_prv
1404 : LOGICAL :: nodata_prv
1405 :
1406 103789 : NULLIFY (tensor_out)
1407 : IF (PRESENT(nodata)) THEN
1408 103789 : nodata_prv = nodata
1409 : ELSE
1410 : nodata_prv = .FALSE.
1411 : END IF
1412 :
1413 103789 : unit_nr_prv = prep_output_unit(unit_nr)
1414 :
1415 103789 : new = .FALSE.
1416 103789 : compat1 = compat_map(tensor_in%nd_index, ind1)
1417 103789 : compat2 = compat_map(tensor_in%nd_index, ind2)
1418 103789 : compat1_old = compat1; compat2_old = compat2
1419 103789 : IF (unit_nr_prv > 0) THEN
1420 10 : WRITE (unit_nr_prv, '(T2,A,1X,A,A,1X)', advance='no') "compatibility of", TRIM(tensor_in%name), ":"
1421 10 : IF (compat1 == 1 .AND. compat2 == 2) THEN
1422 4 : WRITE (unit_nr_prv, '(A)') "Normal"
1423 6 : ELSEIF (compat1 == 2 .AND. compat2 == 1) THEN
1424 2 : WRITE (unit_nr_prv, '(A)') "Transposed"
1425 : ELSE
1426 4 : WRITE (unit_nr_prv, '(A)') "Not compatible"
1427 : END IF
1428 : END IF
1429 103789 : IF (compat1 == 0 .or. compat2 == 0) THEN ! index mapping not compatible with contract index
1430 :
1431 22 : IF (unit_nr_prv > 0) WRITE (unit_nr_prv, '(T2,A,1X,A)') "Redistribution of", TRIM(tensor_in%name)
1432 :
1433 154 : ALLOCATE (tensor_out)
1434 22 : CALL dbt_remap(tensor_in, ind1, ind2, tensor_out, nodata=nodata, move_data=move_data)
1435 22 : CALL dbt_copy_contraction_storage(tensor_in, tensor_out)
1436 22 : compat1 = 1
1437 22 : compat2 = 2
1438 22 : new = .TRUE.
1439 : ELSE
1440 103767 : IF (unit_nr_prv > 0) WRITE (unit_nr_prv, '(T2,A,1X,A)') "No redistribution of", TRIM(tensor_in%name)
1441 103767 : tensor_out => tensor_in
1442 : END IF
1443 :
1444 103789 : IF (compat1 == 1 .AND. compat2 == 2) THEN
1445 82117 : trans = .FALSE.
1446 21672 : ELSEIF (compat1 == 2 .AND. compat2 == 1) THEN
1447 21672 : trans = .TRUE.
1448 : ELSE
1449 0 : CPABORT("this should not happen")
1450 : END IF
1451 :
1452 103789 : IF (unit_nr_prv > 0) THEN
1453 10 : IF (compat1_old .NE. compat1 .OR. compat2_old .NE. compat2) THEN
1454 4 : WRITE (unit_nr_prv, '(T2,A,1X,A,A,1X)', advance='no') "compatibility of", TRIM(tensor_out%name), ":"
1455 4 : IF (compat1 == 1 .AND. compat2 == 2) THEN
1456 4 : WRITE (unit_nr_prv, '(A)') "Normal"
1457 0 : ELSEIF (compat1 == 2 .AND. compat2 == 1) THEN
1458 0 : WRITE (unit_nr_prv, '(A)') "Transposed"
1459 : ELSE
1460 0 : WRITE (unit_nr_prv, '(A)') "Not compatible"
1461 : END IF
1462 : END IF
1463 : END IF
1464 :
1465 103789 : END SUBROUTINE
1466 :
1467 : ! **************************************************************************************************
1468 : !> \brief update contraction storage that keeps track of process grids during a batched contraction
1469 : !> and decide if tensor process grid needs to be optimized
   !> \param storage contraction storage holding the running average of the split factor
   !> \param split_opt optimized TAS process grid
   !> \param split current TAS process grid
   !> \return dimension-wise flags indicating whether the process grid should be changed
1472 : !> \author Patrick Seewald
1473 : ! **************************************************************************************************
   FUNCTION update_contraction_storage(storage, split_opt, split) RESULT(do_change_pgrid)
      !! Fold the current batch's optimal split factor into the running average kept in
      !! storage and decide whether the tensor process grid should be re-optimized.
      !! do_change_pgrid(1): subgroup grid dimensions are too anisotropic.
      !! do_change_pgrid(2): current split factor deviates too much from the running average.
      TYPE(dbt_contraction_storage), INTENT(INOUT) :: storage
      TYPE(dbt_tas_split_info), INTENT(IN)         :: split_opt
      TYPE(dbt_tas_split_info), INTENT(IN)         :: split
      INTEGER, DIMENSION(2) :: pdims, pdims_sub
      LOGICAL, DIMENSION(2) :: do_change_pgrid
      REAL(kind=dp) :: change_criterion, pdims_ratio
      INTEGER :: nsplit_opt, nsplit

      CPASSERT(ALLOCATED(split_opt%ngroup_opt))
      nsplit_opt = split_opt%ngroup_opt
      nsplit = split%ngroup

      pdims = split%mp_comm%num_pe_cart

      storage%ibatch = storage%ibatch + 1

      ! incremental running average of the optimal split factor over all batches so far
      storage%nsplit_avg = (storage%nsplit_avg*REAL(storage%ibatch - 1, dp) + REAL(nsplit_opt, dp)) &
                           /REAL(storage%ibatch, dp)

      ! NOTE(review): pdims_ratio is computed here but not read anywhere in this function --
      ! looks like a leftover from an earlier heuristic; confirm before removing.
      SELECT CASE (split_opt%split_rowcol)
      CASE (rowsplit)
         pdims_ratio = REAL(pdims(1), dp)/pdims(2)
      CASE (colsplit)
         pdims_ratio = REAL(pdims(2), dp)/pdims(1)
      END SELECT

      do_change_pgrid(:) = .FALSE.

      ! check for process grid dimensions: flag if subgroup grid is too far from square
      pdims_sub = split%mp_comm_group%num_pe_cart
      change_criterion = MAXVAL(REAL(pdims_sub, dp))/MINVAL(pdims_sub)
      IF (change_criterion > default_pdims_accept_ratio**2) do_change_pgrid(1) = .TRUE.

      ! check for split factor: flag if current split strays from the running average
      change_criterion = MAX(REAL(nsplit, dp)/storage%nsplit_avg, REAL(storage%nsplit_avg, dp)/nsplit)
      IF (change_criterion > default_nsplit_accept_ratio) do_change_pgrid(2) = .TRUE.

   END FUNCTION
1513 :
1514 : ! **************************************************************************************************
1515 : !> \brief Check if 2d index is compatible with tensor index
1516 : !> \author Patrick Seewald
1517 : ! **************************************************************************************************
1518 415156 : FUNCTION compat_map(nd_index, compat_ind)
1519 : TYPE(nd_to_2d_mapping), INTENT(IN) :: nd_index
1520 : INTEGER, DIMENSION(:), INTENT(IN) :: compat_ind
1521 830312 : INTEGER, DIMENSION(ndims_mapping_row(nd_index)) :: map1
1522 830312 : INTEGER, DIMENSION(ndims_mapping_column(nd_index)) :: map2
1523 : INTEGER :: compat_map
1524 :
1525 415156 : CALL dbt_get_mapping_info(nd_index, map1_2d=map1, map2_2d=map2)
1526 :
1527 415156 : compat_map = 0
1528 415156 : IF (array_eq_i(map1, compat_ind)) THEN
1529 : compat_map = 1
1530 175008 : ELSEIF (array_eq_i(map2, compat_ind)) THEN
1531 174890 : compat_map = 2
1532 : END IF
1533 :
1534 415156 : END FUNCTION
1535 :
1536 : ! **************************************************************************************************
   !> \brief Sort ind_ref in ascending order and apply the same permutation to ind_dep.
1538 : !> \author Patrick Seewald
1539 : ! **************************************************************************************************
   SUBROUTINE index_linked_sort(ind_ref, ind_dep)
      !! Sort ind_ref in place and reorder ind_dep with the same permutation, so that the
      !! pairwise association (ind_ref(i), ind_dep(i)) is preserved.
      INTEGER, DIMENSION(:), INTENT(INOUT) :: ind_ref, ind_dep
      INTEGER, DIMENSION(SIZE(ind_ref)) :: sort_indices

      ! sort returns in sort_indices the permutation that was applied to ind_ref
      CALL sort(ind_ref, SIZE(ind_ref), sort_indices)
      ind_dep(:) = ind_dep(sort_indices)

   END SUBROUTINE
1548 :
1549 : ! **************************************************************************************************
   !> \brief Create a process grid for a tensor, optimized for its block dimensions and
   !>        consistent with the given TAS split info.
1551 : !> \author Patrick Seewald
1552 : ! **************************************************************************************************
1553 0 : FUNCTION opt_pgrid(tensor, tas_split_info)
1554 : TYPE(dbt_type), INTENT(IN) :: tensor
1555 : TYPE(dbt_tas_split_info), INTENT(IN) :: tas_split_info
1556 0 : INTEGER, DIMENSION(ndims_matrix_row(tensor)) :: map1
1557 0 : INTEGER, DIMENSION(ndims_matrix_column(tensor)) :: map2
1558 : TYPE(dbt_pgrid_type) :: opt_pgrid
1559 0 : INTEGER, DIMENSION(ndims_tensor(tensor)) :: dims
1560 :
1561 0 : CALL dbt_get_mapping_info(tensor%pgrid%nd_index_grid, map1_2d=map1, map2_2d=map2)
1562 0 : CALL blk_dims_tensor(tensor, dims)
1563 0 : opt_pgrid = dbt_nd_mp_comm(tas_split_info%mp_comm, map1, map2, tdims=dims)
1564 :
1565 0 : ALLOCATE (opt_pgrid%tas_split_info, SOURCE=tas_split_info)
1566 0 : CALL dbt_tas_info_hold(opt_pgrid%tas_split_info)
1567 0 : END FUNCTION
1568 :
1569 : ! **************************************************************************************************
1570 : !> \brief Copy tensor to tensor with modified index mapping
1571 : !> \param map1_2d new index mapping
1572 : !> \param map2_2d new index mapping
1573 : !> \author Patrick Seewald
1574 : ! **************************************************************************************************
   SUBROUTINE dbt_remap(tensor_in, map1_2d, map2_2d, tensor_out, comm_2d, dist1, dist2, &
                        mp_dims_1, mp_dims_2, name, nodata, move_data)
      TYPE(dbt_type), INTENT(INOUT)      :: tensor_in
      INTEGER, DIMENSION(:), INTENT(IN) :: map1_2d, map2_2d
      TYPE(dbt_type), INTENT(OUT)    :: tensor_out
      CHARACTER(len=*), INTENT(IN), OPTIONAL :: name
      LOGICAL, INTENT(IN), OPTIONAL :: nodata, move_data
      CLASS(mp_comm_type), INTENT(IN), OPTIONAL :: comm_2d
      TYPE(array_list), INTENT(IN), OPTIONAL :: dist1, dist2
      INTEGER, DIMENSION(SIZE(map1_2d)), OPTIONAL :: mp_dims_1
      INTEGER, DIMENSION(SIZE(map2_2d)), OPTIONAL :: mp_dims_2
      CHARACTER(len=default_string_length) :: name_tmp
      INTEGER, DIMENSION(:), ALLOCATABLE :: ${varlist("blk_sizes")}$, &
                                            ${varlist("nd_dist")}$
      TYPE(dbt_distribution_type) :: dist
      TYPE(mp_cart_type) :: comm_2d_prv
      INTEGER :: handle, i
      INTEGER, DIMENSION(ndims_tensor(tensor_in)) :: pdims, myploc
      CHARACTER(LEN=*), PARAMETER :: routineN = 'dbt_remap'
      LOGICAL :: nodata_prv
      TYPE(dbt_pgrid_type) :: comm_nd

      CALL timeset(routineN, handle)

      IF (PRESENT(name)) THEN
         name_tmp = name
      ELSE
         name_tmp = tensor_in%name
      END IF
      ! a caller-provided distribution requires explicit process-grid dims for the same side
      IF (PRESENT(dist1)) THEN
         CPASSERT(PRESENT(mp_dims_1))
      END IF

      IF (PRESENT(dist2)) THEN
         CPASSERT(PRESENT(mp_dims_2))
      END IF

      IF (PRESENT(comm_2d)) THEN
         comm_2d_prv = comm_2d
      ELSE
         comm_2d_prv = tensor_in%pgrid%mp_comm_2d
      END IF

      ! build an nd process grid consistent with the requested 2d index mapping
      comm_nd = dbt_nd_mp_comm(comm_2d_prv, map1_2d, map2_2d, dims1_nd=mp_dims_1, dims2_nd=mp_dims_2)
      CALL mp_environ_pgrid(comm_nd, pdims, myploc)

      #:for ndim in ndims
         IF (ndims_tensor(tensor_in) == ${ndim}$) THEN
            CALL get_arrays(tensor_in%blk_sizes, ${varlist("blk_sizes", nmax=ndim)}$)
         END IF
      #:endfor

      #:for ndim in ndims
         IF (ndims_tensor(tensor_in) == ${ndim}$) THEN
            #:for idim in range(1, ndim+1)
               ! take the caller-provided distribution for this dimension if one was given ...
               IF (PRESENT(dist1)) THEN
                  IF (ANY(map1_2d == ${idim}$)) THEN
                     i = MINLOC(map1_2d, dim=1, mask=map1_2d == ${idim}$) ! i is location of idim in map1_2d
                     CALL get_ith_array(dist1, i, nd_dist_${idim}$)
                  END IF
               END IF

               IF (PRESENT(dist2)) THEN
                  IF (ANY(map2_2d == ${idim}$)) THEN
                     i = MINLOC(map2_2d, dim=1, mask=map2_2d == ${idim}$) ! i is location of idim in map2_2d
                     CALL get_ith_array(dist2, i, nd_dist_${idim}$)
                  END IF
               END IF

               ! ... otherwise fall back to the default block distribution for this dimension
               IF (.NOT. ALLOCATED(nd_dist_${idim}$)) THEN
                  ALLOCATE (nd_dist_${idim}$ (SIZE(blk_sizes_${idim}$)))
                  CALL dbt_default_distvec(SIZE(blk_sizes_${idim}$), pdims(${idim}$), blk_sizes_${idim}$, nd_dist_${idim}$)
               END IF
            #:endfor
            CALL dbt_distribution_new_expert(dist, comm_nd, map1_2d, map2_2d, &
                                             ${varlist("nd_dist", nmax=ndim)}$, own_comm=.TRUE.)
         END IF
      #:endfor

      #:for ndim in ndims
         IF (ndims_tensor(tensor_in) == ${ndim}$) THEN
            CALL dbt_create(tensor_out, name_tmp, dist, map1_2d, map2_2d, &
                            ${varlist("blk_sizes", nmax=ndim)}$)
         END IF
      #:endfor

      IF (PRESENT(nodata)) THEN
         nodata_prv = nodata
      ELSE
         nodata_prv = .FALSE.
      END IF

      IF (.NOT. nodata_prv) CALL dbt_copy_expert(tensor_in, tensor_out, move_data=move_data)
      CALL dbt_distribution_destroy(dist)

      CALL timestop(handle)
   END SUBROUTINE
1672 :
1673 : ! **************************************************************************************************
1674 : !> \brief Align index with data
1675 : !> \param order permutation resulting from alignment
1676 : !> \author Patrick Seewald
1677 : ! **************************************************************************************************
1678 2179569 : SUBROUTINE dbt_align_index(tensor_in, tensor_out, order)
1679 : TYPE(dbt_type), INTENT(INOUT) :: tensor_in
1680 : TYPE(dbt_type), INTENT(OUT) :: tensor_out
1681 622734 : INTEGER, DIMENSION(ndims_matrix_row(tensor_in)) :: map1_2d
1682 622734 : INTEGER, DIMENSION(ndims_matrix_column(tensor_in)) :: map2_2d
1683 : INTEGER, DIMENSION(ndims_tensor(tensor_in)), &
1684 : INTENT(OUT), OPTIONAL :: order
1685 311367 : INTEGER, DIMENSION(ndims_tensor(tensor_in)) :: order_prv
1686 : CHARACTER(LEN=*), PARAMETER :: routineN = 'dbt_align_index'
1687 : INTEGER :: handle
1688 :
1689 311367 : CALL timeset(routineN, handle)
1690 :
1691 311367 : CALL dbt_get_mapping_info(tensor_in%nd_index_blk, map1_2d=map1_2d, map2_2d=map2_2d)
1692 1940919 : order_prv = dbt_inverse_order([map1_2d, map2_2d])
1693 311367 : CALL dbt_permute_index(tensor_in, tensor_out, order=order_prv)
1694 :
1695 1126143 : IF (PRESENT(order)) order = order_prv
1696 :
1697 311367 : CALL timestop(handle)
1698 311367 : END SUBROUTINE
1699 :
1700 : ! **************************************************************************************************
1701 : !> \brief Create new tensor by reordering index, data is copied exactly (shallow copy)
1702 : !> \author Patrick Seewald
1703 : ! **************************************************************************************************
   SUBROUTINE dbt_permute_index(tensor_in, tensor_out, order)
      !! Create tensor_out as a shallow copy of tensor_in with permuted index order.
      !! The matrix representation is shared, not copied: tensor_out is a view.
      TYPE(dbt_type), INTENT(INOUT)                  :: tensor_in
      TYPE(dbt_type), INTENT(OUT)                    :: tensor_out
      INTEGER, DIMENSION(ndims_tensor(tensor_in)), &
         INTENT(IN)                                  :: order

      TYPE(nd_to_2d_mapping)                         :: nd_index_blk_rs, nd_index_rs
      CHARACTER(LEN=*), PARAMETER :: routineN = 'dbt_permute_index'
      INTEGER                                        :: handle
      INTEGER                                        :: ndims

      CALL timeset(routineN, handle)

      ndims = ndims_tensor(tensor_in)

      ! permute the full-size, block and process-grid index mappings consistently
      CALL permute_index(tensor_in%nd_index, nd_index_rs, order)
      CALL permute_index(tensor_in%nd_index_blk, nd_index_blk_rs, order)
      CALL permute_index(tensor_in%pgrid%nd_index_grid, tensor_out%pgrid%nd_index_grid, order)

      ! share the matrix representation: tensor_out does not own it
      tensor_out%matrix_rep => tensor_in%matrix_rep
      tensor_out%owns_matrix = .FALSE.

      tensor_out%nd_index = nd_index_rs
      tensor_out%nd_index_blk = nd_index_blk_rs
      tensor_out%pgrid%mp_comm_2d = tensor_in%pgrid%mp_comm_2d
      IF (ALLOCATED(tensor_in%pgrid%tas_split_info)) THEN
         ALLOCATE (tensor_out%pgrid%tas_split_info, SOURCE=tensor_in%pgrid%tas_split_info)
      END IF
      ! both tensors share one reference counter; dbt_hold increments it for the copy
      tensor_out%refcount => tensor_in%refcount
      CALL dbt_hold(tensor_out)

      ! permute all per-dimension bookkeeping arrays with the same order
      CALL reorder_arrays(tensor_in%blk_sizes, tensor_out%blk_sizes, order)
      CALL reorder_arrays(tensor_in%blk_offsets, tensor_out%blk_offsets, order)
      CALL reorder_arrays(tensor_in%nd_dist, tensor_out%nd_dist, order)
      CALL reorder_arrays(tensor_in%blks_local, tensor_out%blks_local, order)
      ALLOCATE (tensor_out%nblks_local(ndims))
      ALLOCATE (tensor_out%nfull_local(ndims))
      tensor_out%nblks_local(order) = tensor_in%nblks_local(:)
      tensor_out%nfull_local(order) = tensor_in%nfull_local(:)
      tensor_out%name = tensor_in%name
      tensor_out%valid = .TRUE.

      IF (ALLOCATED(tensor_in%contraction_storage)) THEN
         ALLOCATE (tensor_out%contraction_storage, SOURCE=tensor_in%contraction_storage)
         ! the sourced copy carries unpermuted batch ranges: rebuild them in permuted order
         CALL destroy_array_list(tensor_out%contraction_storage%batch_ranges)
         CALL reorder_arrays(tensor_in%contraction_storage%batch_ranges, tensor_out%contraction_storage%batch_ranges, order)
      END IF

      CALL timestop(handle)
   END SUBROUTINE
1754 :
1755 : ! **************************************************************************************************
1756 : !> \brief Map contraction bounds to bounds referring to tensor indices
1757 : !> see dbt_contract for docu of dummy arguments
1758 : !> \param bounds_t1 bounds mapped to tensor_1
1759 : !> \param bounds_t2 bounds mapped to tensor_2
1760 : !> \param do_crop_1 whether tensor 1 should be cropped
1761 : !> \param do_crop_2 whether tensor 2 should be cropped
1762 : !> \author Patrick Seewald
1763 : ! **************************************************************************************************
1764 111236 : SUBROUTINE dbt_map_bounds_to_tensors(tensor_1, tensor_2, &
1765 111236 : contract_1, notcontract_1, &
1766 222472 : contract_2, notcontract_2, &
1767 111236 : bounds_t1, bounds_t2, &
1768 114954 : bounds_1, bounds_2, bounds_3, &
1769 : do_crop_1, do_crop_2)
1770 :
1771 : TYPE(dbt_type), INTENT(IN) :: tensor_1, tensor_2
1772 : INTEGER, DIMENSION(:), INTENT(IN) :: contract_1, contract_2, &
1773 : notcontract_1, notcontract_2
1774 : INTEGER, DIMENSION(2, ndims_tensor(tensor_1)), &
1775 : INTENT(OUT) :: bounds_t1
1776 : INTEGER, DIMENSION(2, ndims_tensor(tensor_2)), &
1777 : INTENT(OUT) :: bounds_t2
1778 : INTEGER, DIMENSION(2, SIZE(contract_1)), &
1779 : INTENT(IN), OPTIONAL :: bounds_1
1780 : INTEGER, DIMENSION(2, SIZE(notcontract_1)), &
1781 : INTENT(IN), OPTIONAL :: bounds_2
1782 : INTEGER, DIMENSION(2, SIZE(notcontract_2)), &
1783 : INTENT(IN), OPTIONAL :: bounds_3
1784 : LOGICAL, INTENT(OUT), OPTIONAL :: do_crop_1, do_crop_2
1785 : LOGICAL, DIMENSION(2) :: do_crop
1786 :
1787 111236 : do_crop = .FALSE.
1788 :
1789 390168 : bounds_t1(1, :) = 1
1790 390168 : CALL dbt_get_info(tensor_1, nfull_total=bounds_t1(2, :))
1791 :
1792 415102 : bounds_t2(1, :) = 1
1793 415102 : CALL dbt_get_info(tensor_2, nfull_total=bounds_t2(2, :))
1794 :
1795 111236 : IF (PRESENT(bounds_1)) THEN
1796 168742 : bounds_t1(:, contract_1) = bounds_1
1797 24958 : do_crop(1) = .TRUE.
1798 168742 : bounds_t2(:, contract_2) = bounds_1
1799 111236 : do_crop(2) = .TRUE.
1800 : END IF
1801 :
1802 111236 : IF (PRESENT(bounds_2)) THEN
1803 227752 : bounds_t1(:, notcontract_1) = bounds_2
1804 111236 : do_crop(1) = .TRUE.
1805 : END IF
1806 :
1807 111236 : IF (PRESENT(bounds_3)) THEN
1808 252958 : bounds_t2(:, notcontract_2) = bounds_3
1809 111236 : do_crop(2) = .TRUE.
1810 : END IF
1811 :
1812 111236 : IF (PRESENT(do_crop_1)) do_crop_1 = do_crop(1)
1813 111236 : IF (PRESENT(do_crop_2)) do_crop_2 = do_crop(2)
1814 :
1815 267674 : END SUBROUTINE
1816 :
1817 : ! **************************************************************************************************
1818 : !> \brief print tensor contraction indices in a human readable way
1819 : !> \param indchar1 characters printed for index of tensor 1
1820 : !> \param indchar2 characters printed for index of tensor 2
1821 : !> \param indchar3 characters printed for index of tensor 3
1822 : !> \param unit_nr output unit
1823 : !> \author Patrick Seewald
1824 : ! **************************************************************************************************
1825 138438 : SUBROUTINE dbt_print_contraction_index(tensor_1, indchar1, tensor_2, indchar2, tensor_3, indchar3, unit_nr)
1826 : TYPE(dbt_type), INTENT(IN) :: tensor_1, tensor_2, tensor_3
1827 : CHARACTER(LEN=1), DIMENSION(ndims_tensor(tensor_1)), INTENT(IN) :: indchar1
1828 : CHARACTER(LEN=1), DIMENSION(ndims_tensor(tensor_2)), INTENT(IN) :: indchar2
1829 : CHARACTER(LEN=1), DIMENSION(ndims_tensor(tensor_3)), INTENT(IN) :: indchar3
1830 : INTEGER, INTENT(IN) :: unit_nr
1831 276876 : INTEGER, DIMENSION(ndims_matrix_row(tensor_1)) :: map11
1832 276876 : INTEGER, DIMENSION(ndims_matrix_column(tensor_1)) :: map12
1833 276876 : INTEGER, DIMENSION(ndims_matrix_row(tensor_2)) :: map21
1834 276876 : INTEGER, DIMENSION(ndims_matrix_column(tensor_2)) :: map22
1835 276876 : INTEGER, DIMENSION(ndims_matrix_row(tensor_3)) :: map31
1836 276876 : INTEGER, DIMENSION(ndims_matrix_column(tensor_3)) :: map32
1837 : INTEGER :: ichar1, ichar2, ichar3, unit_nr_prv
1838 :
1839 138438 : unit_nr_prv = prep_output_unit(unit_nr)
1840 :
1841 138438 : IF (unit_nr_prv /= 0) THEN
1842 138438 : CALL dbt_get_mapping_info(tensor_1%nd_index_blk, map1_2d=map11, map2_2d=map12)
1843 138438 : CALL dbt_get_mapping_info(tensor_2%nd_index_blk, map1_2d=map21, map2_2d=map22)
1844 138438 : CALL dbt_get_mapping_info(tensor_3%nd_index_blk, map1_2d=map31, map2_2d=map32)
1845 : END IF
1846 :
1847 138438 : IF (unit_nr_prv > 0) THEN
1848 30 : WRITE (unit_nr_prv, '(T2,A)') "INDEX INFO"
1849 30 : WRITE (unit_nr_prv, '(T15,A)', advance='no') "tensor index: ("
1850 123 : DO ichar1 = 1, SIZE(indchar1)
1851 123 : WRITE (unit_nr_prv, '(A1)', advance='no') indchar1(ichar1)
1852 : END DO
1853 30 : WRITE (unit_nr_prv, '(A)', advance='no') ") x ("
1854 120 : DO ichar2 = 1, SIZE(indchar2)
1855 120 : WRITE (unit_nr_prv, '(A1)', advance='no') indchar2(ichar2)
1856 : END DO
1857 30 : WRITE (unit_nr_prv, '(A)', advance='no') ") = ("
1858 123 : DO ichar3 = 1, SIZE(indchar3)
1859 123 : WRITE (unit_nr_prv, '(A1)', advance='no') indchar3(ichar3)
1860 : END DO
1861 30 : WRITE (unit_nr_prv, '(A)') ")"
1862 :
1863 30 : WRITE (unit_nr_prv, '(T15,A)', advance='no') "matrix index: ("
1864 82 : DO ichar1 = 1, SIZE(map11)
1865 82 : WRITE (unit_nr_prv, '(A1)', advance='no') indchar1(map11(ichar1))
1866 : END DO
1867 30 : WRITE (unit_nr_prv, '(A1)', advance='no') "|"
1868 71 : DO ichar1 = 1, SIZE(map12)
1869 71 : WRITE (unit_nr_prv, '(A1)', advance='no') indchar1(map12(ichar1))
1870 : END DO
1871 30 : WRITE (unit_nr_prv, '(A)', advance='no') ") x ("
1872 76 : DO ichar2 = 1, SIZE(map21)
1873 76 : WRITE (unit_nr_prv, '(A1)', advance='no') indchar2(map21(ichar2))
1874 : END DO
1875 30 : WRITE (unit_nr_prv, '(A1)', advance='no') "|"
1876 74 : DO ichar2 = 1, SIZE(map22)
1877 74 : WRITE (unit_nr_prv, '(A1)', advance='no') indchar2(map22(ichar2))
1878 : END DO
1879 30 : WRITE (unit_nr_prv, '(A)', advance='no') ") = ("
1880 79 : DO ichar3 = 1, SIZE(map31)
1881 79 : WRITE (unit_nr_prv, '(A1)', advance='no') indchar3(map31(ichar3))
1882 : END DO
1883 30 : WRITE (unit_nr_prv, '(A1)', advance='no') "|"
1884 74 : DO ichar3 = 1, SIZE(map32)
1885 74 : WRITE (unit_nr_prv, '(A1)', advance='no') indchar3(map32(ichar3))
1886 : END DO
1887 30 : WRITE (unit_nr_prv, '(A)') ")"
1888 : END IF
1889 :
1890 138438 : END SUBROUTINE
1891 :
1892 : ! **************************************************************************************************
1893 : !> \brief Initialize batched contraction for this tensor.
1894 : !>
1895 : !> Explanation: A batched contraction is a contraction performed in several consecutive steps
1896 : !> by specification of bounds in dbt_contract. This can be used to reduce memory by
1897 : !> a large factor. The routines dbt_batched_contract_init and
1898 : !> dbt_batched_contract_finalize should be called to define the scope of a batched
1899 : !> contraction as this enables important optimizations (adapting communication scheme to
1900 : !> batches and adapting process grid to multiplication algorithm). The routines
1901 : !> dbt_batched_contract_init and dbt_batched_contract_finalize must be
1902 : !> called before the first and after the last contraction step on all 3 tensors.
1903 : !>
1904 : !> Requirements:
1905 : !> - the tensors are in a compatible matrix layout (see documentation of
1906 : !> `dbt_contract`, note 2 & 3). If they are not, process grid optimizations are
1907 : !> disabled and a warning is issued.
1908 : !> - within the scope of a batched contraction, it is not allowed to access or change tensor
1909 : !> data except by calling the routines dbt_contract & dbt_copy.
1910 : !> - the bounds affecting indices of the smallest tensor must not change in the course of a
1911 : !> batched contraction (todo: get rid of this requirement).
1912 : !>
1913 : !> Side effects:
1914 : !> - the parallel layout (process grid and distribution) of all tensors may change. In order
1915 : !> to disable the process grid optimization including this side effect, call this routine
1916 : !> only on the smallest of the 3 tensors.
1917 : !>
1918 : !> \note
1919 : !> Note 1: for an example of batched contraction see `examples/dbt_example.F`.
1920 : !> (todo: the example is outdated and should be updated).
1921 : !>
1922 : !> Note 2: it is meaningful to use this feature if the contraction consists of one batch only
1923 : !> but if multiple contractions involving the same 3 tensors are performed
1924 : !> (batched_contract_init and batched_contract_finalize must then be called before/after each
1925 : !> contraction call). The process grid is then optimized after the first contraction
1926 : !> and future contraction may profit from this optimization.
1927 : !>
1928 : !> \param batch_range_i refers to the ith tensor dimension and contains all block indices starting
1929 : !> a new range. The size should be the number of ranges plus one, the last
1930 : !> element being the block index plus one of the last block in the last range.
1931 : !> For internal load balancing optimizations, optionally specify the index
1932 : !> ranges of batched contraction.
1933 : !> \author Patrick Seewald
1934 : ! **************************************************************************************************
   SUBROUTINE dbt_batched_contract_init(tensor, ${varlist("batch_range")}$)
      TYPE(dbt_type), INTENT(INOUT) :: tensor
      INTEGER, DIMENSION(:), OPTIONAL, INTENT(IN) :: ${varlist("batch_range")}$
      INTEGER, DIMENSION(ndims_tensor(tensor)) :: tdims
      INTEGER, DIMENSION(:), ALLOCATABLE :: ${varlist("batch_range_prv")}$
      LOGICAL :: static_range

      CALL dbt_get_info(tensor, nblks_total=tdims)

      ! the range is "static" if no custom batch ranges were given for any dimension;
      ! a missing range defaults to one batch covering all blocks of that dimension
      static_range = .TRUE.
      #:for idim in range(1, maxdim+1)
         IF (ndims_tensor(tensor) >= ${idim}$) THEN
            IF (PRESENT(batch_range_${idim}$)) THEN
               ALLOCATE (batch_range_prv_${idim}$, source=batch_range_${idim}$)
               static_range = .FALSE.
            ELSE
               ALLOCATE (batch_range_prv_${idim}$ (2))
               batch_range_prv_${idim}$ (1) = 1
               batch_range_prv_${idim}$ (2) = tdims(${idim}$) + 1
            END IF
         END IF
      #:endfor

      ALLOCATE (tensor%contraction_storage)
      tensor%contraction_storage%static = static_range
      ! batched matrix multiplication state is only set up for the static (default) range
      IF (static_range) THEN
         CALL dbt_tas_batched_mm_init(tensor%matrix_rep)
      END IF
      tensor%contraction_storage%nsplit_avg = 0.0_dp
      tensor%contraction_storage%ibatch = 0

      #:for ndim in range(1, maxdim+1)
         IF (ndims_tensor(tensor) == ${ndim}$) THEN
            CALL create_array_list(tensor%contraction_storage%batch_ranges, ${ndim}$, &
                                   ${varlist("batch_range_prv", nmax=ndim)}$)
         END IF
      #:endfor

   END SUBROUTINE
1974 :
1975 : ! **************************************************************************************************
1976 : !> \brief finalize batched contraction. This performs all communication that has been postponed in
1977 : !> the contraction calls.
1978 : !> \author Patrick Seewald
1979 : ! **************************************************************************************************
   SUBROUTINE dbt_batched_contract_finalize(tensor, unit_nr)
      TYPE(dbt_type), INTENT(INOUT) :: tensor
      INTEGER, INTENT(IN), OPTIONAL :: unit_nr
      LOGICAL :: do_write
      INTEGER :: unit_nr_prv, handle

      ! collective: all ranks of the tensor's 2d communicator must enter this routine
      CALL tensor%pgrid%mp_comm_2d%sync()
      CALL timeset("dbt_total", handle)
      unit_nr_prv = prep_output_unit(unit_nr)

      do_write = .FALSE.

      IF (tensor%contraction_storage%static) THEN
         IF (tensor%matrix_rep%do_batched > 0) THEN
            IF (tensor%matrix_rep%mm_storage%batched_out) do_write = .TRUE.
         END IF
         ! perform the communication that was postponed during the batched contraction
         CALL dbt_tas_batched_mm_finalize(tensor%matrix_rep)
      END IF

      IF (do_write .AND. unit_nr_prv /= 0) THEN
         IF (unit_nr_prv > 0) THEN
            WRITE (unit_nr_prv, "(T2,A)") &
               "FINALIZING BATCHED PROCESSING OF MATMUL"
         END IF
         CALL dbt_write_tensor_info(tensor, unit_nr_prv)
         CALL dbt_write_tensor_dist(tensor, unit_nr_prv)
      END IF

      ! tear down the batched-contraction bookkeeping
      CALL destroy_array_list(tensor%contraction_storage%batch_ranges)
      DEALLOCATE (tensor%contraction_storage)
      CALL tensor%pgrid%mp_comm_2d%sync()
      CALL timestop(handle)

   END SUBROUTINE
2014 :
2015 : ! **************************************************************************************************
2016 : !> \brief change the process grid of a tensor
2017 : !> \param nodata optionally don't copy the tensor data (then tensor is empty on returned)
2018 : !> \param batch_range_i refers to the ith tensor dimension and contains all block indices starting
2019 : !> a new range. The size should be the number of ranges plus one, the last
2020 : !> element being the block index plus one of the last block in the last range.
2021 : !> For internal load balancing optimizations, optionally specify the index
2022 : !> ranges of batched contraction.
2023 : !> \author Patrick Seewald
2024 : ! **************************************************************************************************
   SUBROUTINE dbt_change_pgrid(tensor, pgrid, ${varlist("batch_range")}$, &
                               nodata, pgrid_changed, unit_nr)
      TYPE(dbt_type), INTENT(INOUT)                  :: tensor
      TYPE(dbt_pgrid_type), INTENT(IN)               :: pgrid
      INTEGER, DIMENSION(:), OPTIONAL, INTENT(IN)    :: ${varlist("batch_range")}$
      !!
      LOGICAL, INTENT(IN), OPTIONAL                  :: nodata
      LOGICAL, INTENT(OUT), OPTIONAL                 :: pgrid_changed
      INTEGER, INTENT(IN), OPTIONAL                  :: unit_nr
      CHARACTER(LEN=*), PARAMETER :: routineN = 'dbt_change_pgrid'
      CHARACTER(default_string_length)               :: name
      INTEGER                                        :: handle
      INTEGER, ALLOCATABLE, DIMENSION(:) :: ${varlist("bs")}$, &
                                            ${varlist("dist")}$
      INTEGER, DIMENSION(ndims_tensor(tensor)) :: pcoord, pcoord_ref, pdims, pdims_ref, &
                                                  tdims
      TYPE(dbt_type)                                 :: t_tmp
      TYPE(dbt_distribution_type)                    :: dist
      INTEGER, DIMENSION(ndims_matrix_row(tensor)) :: map1
      INTEGER, &
         DIMENSION(ndims_matrix_column(tensor)) :: map2
      LOGICAL, DIMENSION(ndims_tensor(tensor)) :: mem_aware
      INTEGER, DIMENSION(ndims_tensor(tensor)) :: nbatch
      INTEGER :: ind1, ind2, batch_size, ibatch

      IF (PRESENT(pgrid_changed)) pgrid_changed = .FALSE.
      CALL mp_environ_pgrid(pgrid, pdims, pcoord)
      CALL mp_environ_pgrid(tensor%pgrid, pdims_ref, pcoord_ref)

      ! early exit: nothing to do if the new grid has identical dimensions and split factor
      IF (ALL(pdims == pdims_ref)) THEN
         IF (ALLOCATED(pgrid%tas_split_info) .AND. ALLOCATED(tensor%pgrid%tas_split_info)) THEN
            IF (pgrid%tas_split_info%ngroup == tensor%pgrid%tas_split_info%ngroup) THEN
               RETURN
            END IF
         END IF
      END IF

      CALL timeset(routineN, handle)

      ! a dimension is distributed "memory aware" if batch ranges were supplied for it
      #:for idim in range(1, maxdim+1)
         IF (ndims_tensor(tensor) >= ${idim}$) THEN
            mem_aware(${idim}$) = PRESENT(batch_range_${idim}$)
            IF (mem_aware(${idim}$)) nbatch(${idim}$) = SIZE(batch_range_${idim}$) - 1
         END IF
      #:endfor

      CALL dbt_get_info(tensor, nblks_total=tdims, name=name)

      #:for idim in range(1, maxdim+1)
         IF (ndims_tensor(tensor) >= ${idim}$) THEN
            ALLOCATE (bs_${idim}$ (dbt_nblks_total(tensor, ${idim}$)))
            CALL get_ith_array(tensor%blk_sizes, ${idim}$, bs_${idim}$)
            ALLOCATE (dist_${idim}$ (tdims(${idim}$)))
            dist_${idim}$ = 0
            IF (mem_aware(${idim}$)) THEN
               ! distribute each batch independently so every batch is load balanced
               DO ibatch = 1, nbatch(${idim}$)
                  ind1 = batch_range_${idim}$ (ibatch)
                  ind2 = batch_range_${idim}$ (ibatch + 1) - 1
                  batch_size = ind2 - ind1 + 1
                  CALL dbt_default_distvec(batch_size, pdims(${idim}$), &
                                           bs_${idim}$ (ind1:ind2), dist_${idim}$ (ind1:ind2))
               END DO
            ELSE
               CALL dbt_default_distvec(tdims(${idim}$), pdims(${idim}$), bs_${idim}$, dist_${idim}$)
            END IF
         END IF
      #:endfor

      CALL dbt_get_mapping_info(tensor%nd_index_blk, map1_2d=map1, map2_2d=map2)
      #:for ndim in ndims
         IF (ndims_tensor(tensor) == ${ndim}$) THEN
            CALL dbt_distribution_new(dist, pgrid, ${varlist("dist", nmax=ndim)}$)
            CALL dbt_create(t_tmp, name, dist, map1, map2, ${varlist("bs", nmax=ndim)}$)
         END IF
      #:endfor
      CALL dbt_distribution_destroy(dist)

      IF (PRESENT(nodata)) THEN
         IF (.NOT. nodata) CALL dbt_copy_expert(tensor, t_tmp, move_data=.TRUE.)
      ELSE
         CALL dbt_copy_expert(tensor, t_tmp, move_data=.TRUE.)
      END IF

      CALL dbt_copy_contraction_storage(tensor, t_tmp)

      ! replace the old tensor by the regridded one
      CALL dbt_destroy(tensor)
      tensor = t_tmp

      IF (PRESENT(unit_nr)) THEN
         IF (unit_nr > 0) THEN
            WRITE (unit_nr, "(T2,A,1X,A)") "OPTIMIZED PGRID INFO FOR", TRIM(tensor%name)
            WRITE (unit_nr, "(T4,A,1X,3I6)") "process grid dimensions:", pdims
            CALL dbt_write_split_info(pgrid, unit_nr)
         END IF
      END IF

      IF (PRESENT(pgrid_changed)) pgrid_changed = .TRUE.

      CALL timestop(handle)
   END SUBROUTINE
2125 :
2126 : ! **************************************************************************************************
2127 : !> \brief map tensor to a new 2d process grid for the matrix representation.
     : !> \param tensor tensor to redistribute; replaced in place by its copy on the new grid
     : !> \param mp_comm Cartesian communicator from which the new nd process grid is derived
     : !> \param pdims requested 2d process grid dimensions (forwarded to dbt_nd_mp_comm)
     : !> \param nodata forwarded to dbt_change_pgrid; if .TRUE., block data is not redistributed
     : !> \param nsplit forwarded to dbt_nd_mp_comm (TAS split factor) — see dbt_nd_mp_comm
     : !> \param dimsplit forwarded to dbt_nd_mp_comm (which matrix dimension to split)
     : !> \param pgrid_changed set by dbt_change_pgrid: .TRUE. iff the tensor was actually moved
     : !> \param unit_nr output unit for logging the optimized grid info (active if > 0)
2128 : !> \author Patrick Seewald
2129 : ! **************************************************************************************************
2130 772 : SUBROUTINE dbt_change_pgrid_2d(tensor, mp_comm, pdims, nodata, nsplit, dimsplit, pgrid_changed, unit_nr)
2131 : TYPE(dbt_type), INTENT(INOUT) :: tensor
2132 : TYPE(mp_cart_type), INTENT(IN) :: mp_comm
2133 : INTEGER, DIMENSION(2), INTENT(IN), OPTIONAL :: pdims
2134 : LOGICAL, INTENT(IN), OPTIONAL :: nodata
2135 : INTEGER, INTENT(IN), OPTIONAL :: nsplit, dimsplit
2136 : LOGICAL, INTENT(OUT), OPTIONAL :: pgrid_changed
2137 : INTEGER, INTENT(IN), OPTIONAL :: unit_nr
2138 1544 : INTEGER, DIMENSION(ndims_matrix_row(tensor)) :: map1
2139 1544 : INTEGER, DIMENSION(ndims_matrix_column(tensor)) :: map2
2140 1544 : INTEGER, DIMENSION(ndims_tensor(tensor)) :: dims, nbatches
2141 2316 : TYPE(dbt_pgrid_type) :: pgrid
     : ! one optional batch-range array per tensor dimension (expanded by fypp up to maxdim)
2142 772 : INTEGER, DIMENSION(:), ALLOCATABLE :: ${varlist("batch_range")}$
2143 772 : INTEGER, DIMENSION(:), ALLOCATABLE :: array
2144 : INTEGER :: idim
2145 :
     : ! recover the tensor-index -> matrix-row/column mapping and the per-dimension block counts
2146 772 : CALL dbt_get_mapping_info(tensor%pgrid%nd_index_grid, map1_2d=map1, map2_2d=map2)
2147 772 : CALL blk_dims_tensor(tensor, dims)
2148 :
2149 772 : IF (ALLOCATED(tensor%contraction_storage)) THEN
2150 : ASSOCIATE (batch_ranges => tensor%contraction_storage%batch_ranges)
2151 3088 : nbatches = sizes_of_arrays(tensor%contraction_storage%batch_ranges) - 1
2152 : ! for good load balancing the process grid dimensions should be chosen adapted to the
2153 : ! tensor dimensions. For batched contraction the tensor dimensions should be divided by
2154 : ! the number of batches (number of index ranges).
2155 3860 : DO idim = 1, ndims_tensor(tensor)
2156 2316 : CALL get_ith_array(tensor%contraction_storage%batch_ranges, idim, array)
     : ! total extent covered by all batches of this dimension ...
2157 2316 : dims(idim) = array(nbatches(idim) + 1) - array(1)
2158 2316 : DEALLOCATE (array)
     : ! ... averaged over the number of batches (clamped to at least 1)
2159 2316 : dims(idim) = dims(idim)/nbatches(idim)
2160 5404 : IF (dims(idim) <= 0) dims(idim) = 1
2161 : END DO
2162 : END ASSOCIATE
2163 : END IF
2164 :
     : ! build the nd process grid matching the matrix mapping, sized for the (per-batch) dimensions
2165 772 : pgrid = dbt_nd_mp_comm(mp_comm, map1, map2, pdims_2d=pdims, tdims=dims, nsplit=nsplit, dimsplit=dimsplit)
2166 772 : IF (ALLOCATED(tensor%contraction_storage)) THEN
     : ! rank-generic dispatch: pass the stored batch ranges of each dimension through to
     : ! dbt_change_pgrid so the new distribution respects the batch boundaries
2167 : #:for ndim in range(1, maxdim+1)
2168 1544 : IF (ndims_tensor(tensor) == ${ndim}$) THEN
2169 772 : CALL get_arrays(tensor%contraction_storage%batch_ranges, ${varlist("batch_range", nmax=ndim)}$)
2170 : CALL dbt_change_pgrid(tensor, pgrid, ${varlist("batch_range", nmax=ndim)}$, &
2171 772 : nodata=nodata, pgrid_changed=pgrid_changed, unit_nr=unit_nr)
2172 : END IF
2173 : #:endfor
2174 : ELSE
     : ! no batched-contraction state: plain redistribution without batch ranges
2175 0 : CALL dbt_change_pgrid(tensor, pgrid, nodata=nodata, pgrid_changed=pgrid_changed, unit_nr=unit_nr)
2176 : END IF
     : ! pgrid was copied into the tensor by dbt_change_pgrid; release the local handle
2177 772 : CALL dbt_pgrid_destroy(pgrid)
2178 :
2179 772 : END SUBROUTINE
2180 :
2181 184948 : END MODULE
|